#
#  Plik:  cifar10_one_vs_many.py
#
#  Porównanie jednego modelu wieloklasowego z wieloma modelami binarnymi
#
#  RTK, 21.10.2019
#  Ostatnia aktualizacja:  06.11.2019
#
################################################################

import numpy as np
from keras.models import load_model

def main():
    x_test = np.load("../data/cifar10/cifar10_test_images.npy")/255.0
    y_test = np.load("../data/cifar10/cifar10_test_labels.npy")

    #  Wczytuje modele
    mm = load_model("cifar10_cnn_model.h5")
    m = []
    for i in range(10):
        m.append(load_model("cifar10_cnn_%d_model.h5" % i))

    #  Predykcje wieloklasowe
    mp = np.argmax(mm.predict(x_test), axis=1)

    #  Pojedyncze predykcje binarne
    p = np.zeros((10,10000), dtype="float32")

    for i in range(10):
        p[i,:] = m[i].predict(x_test)[:,1]

    bp = np.argmax(p, axis=0)

    #  Macierze pomyłek
    cm = np.zeros((10,10), dtype="uint16")
    cb = np.zeros((10,10), dtype="uint16")

    for i in range(10000):
        cm[y_test[i],mp[i]] += 1
        cb[y_test[i],bp[i]] += 1

    np.save("cifar10_multiclass_conf_mat.npy", cm)
    np.save("cifar10_binary_conf_mat.npy", cb)

    #  Macierze pomyłek
    print()
    print("Macierz pomyłek strategii jeden przeciw reszcie (wiersze: rzeczywiste, kolumny: przewidywane):")
    print("%s" % np.array2string(100*(cb/1000.0), precision=1))
    print()
    print("Macierz pomyłek modelu wieloklasowego:")
    print("%s"  % np.array2string(100*(cm/1000.0), precision=1))

    #  Wyświetla różnice
    db = np.diag(100*(cb/1000.0))
    dm = np.diag(100*(cm/1000.0))
    df = db - dm
    sb = np.array2string(db, precision=1)[1:-1]
    sm = np.array2string(dm, precision=1)[1:-1]
    sd = np.array2string(df, precision=1)[1:-1]
    print()
    print("Porównanie dokładności dla poszczególnych klas, model jeden przeciw reszcie a model wieloklasowy:")
    print()
    print("    jeden przeciw reszcie: %s" % sb)
    print("    wieloklasowy : %s" % sm)
    print("    różnica : %s" % sd)

    #  Wyświetla dokładność ogólną
    print()
    print("Dokładność ogólna:")
    print("    jeden przeciw reszcie: %0.1f%%" % np.diag(100*(cb/1000.0)).mean())
    print("    wieloklasowy : %0.1f%%" % np.diag(100*(cm/1000.0)).mean())
    print()

main()

