"""Moduł zawiera dodatkowe funkcje, które będą przydatne w procesie uczenia sieci, analizy wyników i ich wizualizacji"""

import torch as th
import matplotlib.pyplot as plt
from typing import List
from sklearn.metrics import classification_report
from utils import vis_utils


def plot_results(train_losses: List[float], val_losses: List[float], plot_name: str):
    """Funkcja rysująca wykresy straty na zbiorze treningowym i walidacyjnym.

    Parameters
    ----------
    train_losses : List[float]
        Lista wartości funkcji kosztu na zbiorze treningowym.
    val_losses : List[float]
        Lista wartości funkcji kosztu na zbiorze walidacyjnym.
    """
    fig, ax = plt.subplots(figsize=vis_utils.image_size_in_cm(12, 12))
    ax.plot(train_losses, label="Trening", linestyle="-", color="black", linewidth=2)
    ax.plot(val_losses, label="Walidacja", linestyle="--", color="grey", linewidth=2)
    ax.legend(fontsize=12)
    vis_utils.change_plot_comma_separator(ax)
    ax.set_xlabel("Epoka", fontsize=12)
    ax.set_ylabel("Funkcja kosztu", fontsize=12)
    ax.tick_params(axis="both", which="major", labelsize=12)
    plt.tight_layout()
    plt.savefig(plot_name, pil_kwargs={"compression": "tiff_lzw"})
    plt.show()


def val_scores_before_after_train(
    y_val: th.Tensor,
    yhat_val_before: th.Tensor,
    yhat_val_after: th.Tensor,
    model_name: str,
):
    print(f"Wynik klasyfikacji modelu {model_name} przed trenowaniem:")
    print(classification_report(y_val, yhat_val_before.argmax(dim=1).numpy()))
    print("\n\n\n")
    print(f"Wynik klasyfikacji modelu {model_name} po trenowaniu:")
    print(classification_report(y_val, yhat_val_after.argmax(dim=1).numpy()))
