package org.deeplearning4j.examples.unsupervised.variational;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.examples.unsupervised.variational.plot.PlotUtil;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * A simple example of training a variational autoencoder on MNIST.
 * This example intentionally has a small hidden state Z (2 values) for visualization on a 2-grid.
 *
 * After training, this example plots 2 things:
 * 1. The MNIST digit reconstructions vs. the latent space
 * 2. The latent space values for the MNIST test set, as training progresses (every N minibatches)
 *
 * Note that for both plots, there is a slider at the top - change this to see how the reconstructions and latent
 * space changes over time.
 *
 * @author Alex Black
 */
public class VariationalAutoEncoderExample {
    private static final Logger log = LoggerFactory.getLogger(VariationalAutoEncoderExample.class);

    public static void main(String[] args) throws IOException {
        int minibatchSize = 128;
        int rngSeed = 12345;
        int nEpochs = 20;                   // Całkowita liczba epok treningowych.

        // Konfiguracja wykresu.
        int plotEveryNMinibatches = 100;    // Częstotliwość zbierania danych do późniejszego wyświetlenia.
        double plotMin = -5;                // Minimalna wartość na wykresie (osie x i y).
        double plotMax = 5;                 // Maksymalna wartość na wykresie (osie x i y).
        int plotNumSteps = 16;              // Liczba kroków rekonstrukcji pomiędzy wartościami plotMin i plotMax.

        // Zbiór MNIST do treningu.
        DataSetIterator trainIter = new MnistDataSetIterator(minibatchSize, true, rngSeed);

        // Konfiguracja sieci neuronowej.
        Nd4j.getRandom().setSeed(rngSeed);
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .iterations(1).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(1e-2)
            .updater(Updater.RMSPROP).rmsDecay(0.95)
            .weightInit(WeightInit.XAVIER)
            .regularization(true).l2(1e-4)
            .list()
            .layer(0, new VariationalAutoencoder.Builder()
                .activation(Activation.LEAKYRELU)
                .encoderLayerSizes(256, 256)        // Dwie warstwy kodujące, każda o rozmiarze 256.
                .decoderLayerSizes(256, 256)        // Dwie warstwy dekodujące, każda o rozmiarze 256.
                .pzxActivationFunction("identity")  // Funkcja aktywacji p(z|data).
                .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID.getActivationFunction()))     // Rozkład Bernoulliego funkcji p(data|z) (tylko wartości binarne, czyli 0 lub 1).
                .nIn(28 * 28)                       // Wielkość wejściowa: 28x28.
                .nOut(2)                            // Wielkość przestrzeni ukrytej: p(z|x). Dwa wymiary dla wykresu, ale może być ich więcej.
                .build())
            .pretrain(true).backprop(false).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Utworzenie warstwy autokodera wariacyjnego.
        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder vae
            = (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0);


        // Dane testowe dla wykresu.
        DataSet testdata = new MnistDataSetIterator(10000, false, rngSeed).next();
        INDArray testFeatures = testdata.getFeatures();
        INDArray testLabels = testdata.getLabels();
        INDArray latentSpaceGrid = getLatentSpaceGrid(plotMin, plotMax, plotNumSteps);  // Wartości X/Y siatki, między plotMin a plotMax.

        // Listy z danymi do późniejszego wykreślenia.
        List<INDArray> latentSpaceVsEpoch = new ArrayList<>(nEpochs + 1);
        INDArray latentSpaceValues = vae.activate(testFeatures, false);     // Zebranie i zarejestrowanie wartości przestrzeni ukrytej przed rozpoczęciem treningu.
        latentSpaceVsEpoch.add(latentSpaceValues);
        List<INDArray> digitsGrid = new ArrayList<>();

        // Trening.
        int iterationCount = 0;
        for (int i = 0; i < nEpochs; i++) {
            log.info("Początek epoki {} / {}",(i+1),nEpochs);
            while (trainIter.hasNext()) {
                DataSet ds = trainIter.next();
                net.fit(ds);

                // Co N=100 minipaczek:
                // (a) zbierane są wartości przestrzeni ukrytej do późniejszego wykreślenia,
                // (b) zbierane są rekonstrukcje w każdym punkcie siatki.
                if (iterationCount++ % plotEveryNMinibatches == 0) {
                    latentSpaceValues = vae.activate(testFeatures, false);
                    latentSpaceVsEpoch.add(latentSpaceValues);

                    INDArray out = vae.generateAtMeanGivenZ(latentSpaceGrid);
                    digitsGrid.add(out);
                }
            }

            trainIter.reset();
        }

        // Wykres zbioru testowego — przestrzeń ukryta vs. iteracje (domyślnie co 100 minipaczek).
        PlotUtil.plotData(latentSpaceVsEpoch, testLabels, plotMin, plotMax, plotEveryNMinibatches);

        // Wykres rekonstrukcji — przestrzeń ukryta vs. siatka.
        double imageScale = 2.0;  // Zwiększ lub zmniejsz tę wartość, aby zmienić skalę cyfr.
        PlotUtil.MNISTLatentSpaceVisualizer v = new PlotUtil.MNISTLatentSpaceVisualizer(imageScale, digitsGrid, plotEveryNMinibatches);
        v.visualize();
    }


    // Utworzenie dwuwymiarowej siatki: x od plotMin do plotMax, y od plotMin do plotMax.
    private static INDArray getLatentSpaceGrid(double plotMin, double plotMax, int plotSteps) {
        INDArray data = Nd4j.create(plotSteps * plotSteps, 2);
        INDArray linspaceRow = Nd4j.linspace(plotMin, plotMax, plotSteps);
        for (int i = 0; i < plotSteps; i++) {
            data.get(NDArrayIndex.interval(i * plotSteps, (i + 1) * plotSteps), NDArrayIndex.point(0)).assign(linspaceRow);
            int yStart = plotSteps - i - 1;
            data.get(NDArrayIndex.interval(yStart * plotSteps, (yStart + 1) * plotSteps), NDArrayIndex.point(1)).assign(linspaceRow.getDouble(i));
        }
        return data;
    }
}
