#
#  Plik:  mnist_cnn_fcn.py
#
#  Tworzy pełny model splotowy MNIST i zapełnia go wagami
#  z modelu wytrenowanego za pomocą cyfr MNIST.
#
#  RTK, 20.10.2019
#  Ostatnia aktualizacja:  04.03.2022
#
################################################################

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D

#  Wczytuje wagi z modelu referencyjnego
weights = load_model('mnist_cnn_base_model.h5').get_weights()

#  Tworzy identyczną strukturę, w której warstwy gęste zostają zastąpione
#  pełnymi warstwami splotowymi
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),    # Dowolne wymiary sygnału wejściowego
                 activation='relu',         # ale ma on jedynie odcienie szarości
                 input_shape=(None,None,1)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

#  Warstwa gęsta staje się warstwą Conv2D zawierającą jądro 12x12 i 128 filtrów
model.add(Conv2D(128, (12,12), activation='relu'))
model.add(Dropout(0.5))

#  Warstwa wyjściowa to również Conv2D, ale zawiera jądro 1x1 i 10 "filtrów"
model.add(Conv2D(10, (1,1), activation='softmax'))

#  W razie potrzeby skopiuj odwzorowania wytrenowanych wag
model.layers[0].set_weights([weights[0], weights[1]])
model.layers[1].set_weights([weights[2], weights[3]])
model.layers[4].set_weights([weights[4].reshape([12,12,64,128]), weights[5]])
model.layers[6].set_weights([weights[6].reshape([1,1,128,10]), weights[7]])

#  Wynikowy pełny model splotowy
model.compile(optimizer='adam', loss='binary_crossentropy')
model.save('mnist_cnn_fcn_model.h5')

