from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    confusion_matrix,
    roc_curve,
    auc,
    brier_score_loss,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
import matplotlib.pyplot as plt
import numpy as np
import itertools


def get_confusion_matrix_plot(
    predicted_y,
    true_y,
    classes=None,
    normalize=False,
    title="Confusion matrix",
    cmap=plt.get_cmap("binary"),
    figsize=(10, 10),
):
    """
    Na podstawie przykładu sklearn
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    :param figsize: wielkość obrazu wejściowego
    :param predicted_y: wartości prognozowane przez model
    :param true_y: rzeczywiste wartości etykiet
    :param classes: nazwy obu klas
    :param normalize: czy wykres ma być znormalizowany?
    :param title: tytuł wykresu
    :param cmap: mapa kolorów
    :return: obraz macierzy pomyłek
    """
    if classes is None:
        classes = ["Niska jakość", "Wysoka jakość"]

    cm = confusion_matrix(true_y, predicted_y)
    if normalize:
        cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=figsize)
    ax = plt.gca()
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)

    title_obj = plt.title(title, fontsize=30)
    title_obj.set_position([0.5, 1.15])

    plt.colorbar(im)

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, fontsize=15)
    plt.yticks(tick_marks, classes, fontsize=15)

    fmt = ".2f" if normalize else "d"
    thresh = (cm.max() - cm.min()) / 2.0 + cm.min()
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(
            j,
            i,
            format(cm[i, j], fmt),
            horizontalalignment="center",
            color="white" if cm[i, j] > thresh else "black",
            fontsize=40,
        )

    plt.tight_layout()
    plt.ylabel("Rzeczywista etykieta", fontsize=20)
    plt.xlabel("Prognozowana etykieta", fontsize=20)


def get_roc_plot(
    predicted_proba_y, true_y, tpr_bar=-1, fpr_bar=-1, figsize=(10, 10)
):
    """
    Na podstawie przykładu sklearn
    https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html
    :param fpr_bar: graniczna wartość fałszywie pozytywna
    :param tpr_bar: graniczna wartość fałszywie negatywna
    :param figsize: wielkość obrazu wyjściowego
    :param predicted_proba_y: wartości prawdopodobieństwa prognozowane przez model dla każdego przykładu
    :param true_y: trzeczywista wartość etykiety
    :return: wykres krzywej ROC
    """
    fpr, tpr, thresholds = roc_curve(true_y, predicted_proba_y)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=figsize)
    plt.plot(
        fpr,
        tpr,
        lw=1,
        alpha=1,
        color="black",
        label="Krzywa ROC (AUC = %0.2f)" % roc_auc,
    )
    plt.plot(
        [0, 1],
        [0, 1],
        linestyle="--",
        lw=2,
        color="grey",
        label="Chance",
        alpha=1,
    )

    # "Oszukanie" pozycji, aby wykres był bardziej czytelny.
    plt.plot(
        [0.01, 0.01, 1],
        [0.01, 0.99, 0.99],
        linestyle=":",
        lw=2,
        color="green",
        label="Idealny model",
        alpha=1,
    )

    if tpr_bar != -1:
        plt.plot(
            [0, 1],
            [tpr_bar, tpr_bar],
            linestyle="-",
            lw=2,
            color="red",
            label="Wymagany TPR",
            alpha=1,
        )
        plt.fill_between([0, 1], [tpr_bar, tpr_bar], [1, 1], alpha=0, hatch="\\")

    if fpr_bar != -1:
        plt.plot(
            [fpr_bar, fpr_bar],
            [0, 1],
            linestyle="-",
            lw=2,
            color="red",
            label="Wymagany FPR",
            alpha=1,
        )
        plt.fill_between([fpr_bar, 1], [1, 1], alpha=0, hatch="\\")

    plt.legend(loc="lower right")

    plt.ylabel("Odsetek wyników prawdziwie pozytywnych", fontsize=20)
    plt.xlabel("Odsetek wyników prawdziwie negatywnych", fontsize=20)
    plt.xlim(0, 1)
    plt.ylim(0, 1)


def get_calibration_plot(predicted_proba_y, true_y, figsize=(10, 10)):
    """
    Na podstawie przykładu sklearn
    https://scikit-learn.org/stable/auto_examples/calibration/plot_calibration_curve.html
    :param figsize: wielkość obrazu wyjściowego
    :param predicted_proba_y: wartości prawdopodobieństwa prognozowane przez model dla każdego przykładu
    :param true_y: rzeczywista wartość etykiety
    :return: krzywa kalibracyjna
    """

    plt.figure(figsize=figsize)
    ax1 = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
    ax2 = plt.subplot2grid((3, 1), (2, 0))

    ax1.plot([0, 1], [0, 1], "k:", label="Model idealnie skalibrowany")
    clf_score = brier_score_loss(
        true_y, predicted_proba_y, pos_label=true_y.max()
    )
    print("\tOcena Briera: %1.3f" % clf_score)

    fraction_of_positives, mean_predicted_value = calibration_curve(
        true_y, predicted_proba_y, n_bins=10
    )

    ax1.plot(
        mean_predicted_value,
        fraction_of_positives,
        "s-",
        color="black",
        label="Ocena Briera: %1.3f (0 – najlepsza, 1 – najgorsza)" % clf_score,
    )

    ax2.hist(
        predicted_proba_y,
        range=(0, 1),
        bins=10,
        histtype="step",
        lw=2,
        color="black",
    )

    ax1.set_ylabel("Odsetek pozytywnych wyników")
    ax1.set_xlim([0, 1])
    ax1.set_ylim([0, 1])
    ax1.legend(loc="lower right")
    ax1.set_title("Krzywa kalibracyjna")

    ax2.set_title("Rozkład prawdopodobieństwa")
    ax2.set_xlabel("Średnia prognozowana wartość")
    ax2.set_ylabel("Liczba")
    ax2.legend(loc="upper center", ncol=2)

    plt.tight_layout()


def get_metrics(predicted_y, true_y):
    """
    Wyliczenie standardowych wskaźników klasyfikatora binarnego.
    :param predicted_y: wartości prognozowane przez model
    :param true_y: rzeczywiste wartości etykiet
    :return:
    """
    # Prawdziwie pozytywne / (Prawdziwie pozytywne + Fałszywie pozytywne)
    precision = precision_score(
        true_y, predicted_y, pos_label=None, average="weighted"
    )
    # Prawdziwie pozytywne / (Prawdziwie pozytywne + Fałszywie negatywne)
    recall = recall_score(
        true_y, predicted_y, pos_label=None, average="weighted"
    )

    # Średnia harmoniczna precyzji i rozrzutu.
    f1 = f1_score(true_y, predicted_y, pos_label=None, average="weighted")

    # (Prawdziwie pozytywne + Prawdziwie negatywne) / Suma
    accuracy = accuracy_score(true_y, predicted_y)
    return accuracy, precision, recall, f1


def get_feature_importance(clf, feature_names):
    """
    Uzyskanie listy ważności cech dla klasyfikatora.
    :param clf: klasyfikator scikit-learn
    :param feature_names: lista nazw cech ułożonych w odpowiedniej kolejności
    :return: lista nazw cech ułożonych w odpowiedniej kolejności
    """
    importances = clf.feature_importances_
    indices_sorted_by_importance = np.argsort(importances)[::-1]
    return list(
        zip(
            feature_names[indices_sorted_by_importance],
            importances[indices_sorted_by_importance],
        )
    )


def get_top_k(df, proba_col, true_label_col, k=5, decision_threshold=0.5):
    """
    W modelu klasyfikacyjnym funkcja zwraca k najbardziej
    poprawnych, niepoprawnych i niepewnych przykładów
    w każdej klasie.
    :param df: obiekt DataFrame zawierający prognozy i rzeczywiste etykiety
    :param proba_col: nazwa kolumny zawierającej prognozowane wartości prawdopodobieństwa
    :param true_label_col: nazwa kolumny zawierającej rzeczywiste etykiety
    :param k: liczba przykładów pokazywanych w każdej kategorii
    :param decision_threshold: granica decyzyjna klasyfikująca przykład jako pozytywny
    :return: correct_pos, correct_neg, incorrect_pos, incorrect_neg, unsure
    """
    # Uzyskanie poprawnej i niepoprawnej prognozy.
    correct = df[
        (df[proba_col] > decision_threshold) == df[true_label_col]
    ].copy()
    incorrect = df[
        (df[proba_col] > decision_threshold) != df[true_label_col]
    ].copy()

    top_correct_positive = correct[correct[true_label_col]].nlargest(
        k, proba_col
    )
    top_correct_negative = correct[~correct[true_label_col]].nsmallest(
        k, proba_col
    )

    top_incorrect_positive = incorrect[incorrect[true_label_col]].nsmallest(
        k, proba_col
    )
    top_incorrect_negative = incorrect[~incorrect[true_label_col]].nlargest(
        k, proba_col
    )

    # Uzyskanie przykładów najbliższych progowi decyzyjnemu.
    most_uncertain = df.iloc[
        (df[proba_col] - decision_threshold).abs().argsort()[:k]
    ]

    return (
        top_correct_positive,
        top_correct_negative,
        top_incorrect_positive,
        top_incorrect_negative,
        most_uncertain,
    )
