/*
 * Decompiled with CFR 0.152.
 */
package org.generallib.deeplearning.neuralnetwork;

import org.generallib.deeplearning.neuralnetwork.ActivationFunction;
import org.generallib.deeplearning.neuralnetwork.InvalidLayerCountException;
import org.generallib.deeplearning.neuralnetwork.NeuralNetworkInitializeException;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * A fully connected feed-forward neural network trained with full-batch gradient descent on an
 * (optionally L2-regularized) cross-entropy cost, backed by jblas matrices.
 *
 * <p>The weights for the transition from layer {@code l} to layer {@code l + 1} are stored in
 * {@code theta[l]}. Each is a {@code layerCounts[l + 1] x (layerCounts[l] + 1)} matrix whose
 * column 0 holds the bias weights. Inputs are given one example per row.
 *
 * <p>NOTE: back-propagation uses {@code a .* (1 - a)} as the activation derivative. The cost also
 * takes {@code log(output)} and {@code log(1 - output)}. Training is therefore only correct when
 * the supplied {@link ActivationFunction} is the logistic sigmoid, with outputs strictly in (0, 1).
 */
public class NeuralNetwork {
    /** Units per layer (input, hidden..., output), not counting bias units. */
    private final int[] layerCounts;
    /** Per-transition weight matrices; column 0 of each is the bias weight column. */
    private final DoubleMatrix[] theta;
    /** Number of output units; class labels must lie in {@code [0, outputRange)}. */
    private final int outputRange;
    /** Element-wise activation applied to every hidden and output layer. */
    private final ActivationFunction act;

    /**
     * Demo: trains a 5-6-3 network on 17 five-character strings given as character codes.
     * Training is repeated for ten regularization strengths, and a few predictions are printed.
     */
    public static void main(String[] ar) throws Exception {
        NeuralNetwork net = new NeuralNetwork(new int[]{5, 6, 3}, new ActivationFunction(){

            @Override
            public DoubleMatrix activate(DoubleMatrix matrix) {
                // logistic sigmoid: 1 / (1 + e^-x)
                return MatrixFunctions.exp(matrix.mul(-1.0)).add(1.0).rdiv(1.0);
            }
        });
        System.out.println(net);
        DoubleMatrix dataset = new DoubleMatrix(new double[][]{{72.0, 69.0, 76.0, 76.0, 79.0}, {72.0, 69.0, 76.0, 76.0, 111.0}, {72.0, 69.0, 76.0, 108.0, 79.0}, {72.0, 69.0, 76.0, 108.0, 111.0}, {72.0, 69.0, 108.0, 76.0, 79.0}, {72.0, 69.0, 108.0, 108.0, 79.0}, {72.0, 69.0, 108.0, 76.0, 111.0}, {104.0, 101.0, 108.0, 108.0, 111.0}, {104.0, 105.0, 32.0, 32.0, 32.0}, {104.0, 73.0, 32.0, 32.0, 32.0}, {72.0, 105.0, 32.0, 32.0, 32.0}, {72.0, 73.0, 32.0, 32.0, 32.0}, {32.0, 104.0, 105.0, 32.0, 32.0}, {32.0, 72.0, 105.0, 32.0, 32.0}, {32.0, 104.0, 73.0, 32.0, 32.0}, {32.0, 72.0, 73.0, 32.0, 32.0}, {104.0, 32.0, 73.0, 32.0, 32.0}});
        // The network is trained on scaled inputs, so every input passed to predict()
        // must be divided by this same factor.
        double scale = dataset.max() - dataset.min();
        dataset = dataset.div(scale);
        // Loop-invariant: one class label per dataset row.
        DoubleMatrix labels = new DoubleMatrix(new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
        for (int count = 0; count < 10; count++) {
            net.resetLayers();
            double lambda = 0.001 * (double) count * (double) count;
            double cost = 0.0;
            for (int i = 0; i < 300; i++) {
                cost = net.trainNetwork(dataset, labels, lambda);
            }
            System.out.println("\nlambda [" + lambda + "] >>> " + cost + "\n");
            System.out.println("hello: " + net.predict(new DoubleMatrix(new double[][]{{104.0, 101.0, 108.0, 108.0, 111.0}}).div(scale)));
            System.out.println("hi: " + net.predict(new DoubleMatrix(new double[][]{{104.0, 105.0, 32.0, 32.0, 32.0}}).div(scale)));
            System.out.println("happy: " + net.predict(new DoubleMatrix(new double[][]{{104.0, 97.0, 112.0, 112.0, 121.0}}).div(scale)));
        }
    }

    /**
     * Creates a network with the given layer sizes and randomly initialized weights.
     *
     * @param layerCounts units per layer (input, at least one hidden, output); the array is copied
     * @param act activation applied element-wise to every hidden and output layer
     * @throws InvalidLayerCountException if fewer than three layer sizes are given
     * @throws NeuralNetworkInitializeException declared for initialization failures
     */
    public NeuralNetwork(int[] layerCounts, ActivationFunction act) throws NeuralNetworkInitializeException {
        if (layerCounts.length < 3) {
            throw new InvalidLayerCountException();
        }
        this.act = act;
        // Defensive copy: later changes to the caller's array must not reshape the network.
        this.layerCounts = layerCounts.clone();
        this.theta = new DoubleMatrix[this.layerCounts.length - 1];
        this.outputRange = this.layerCounts[this.layerCounts.length - 1];
        // Private helper, so the constructor does not invoke an overridable method.
        this.initializeWeights();
    }

    /**
     * Re-initializes every weight matrix. The bias column is set to 1 and the remaining weights
     * are drawn from {@link DoubleMatrix#rand(int, int)}.
     */
    public void resetLayers() {
        this.initializeWeights();
    }

    /** Fills {@code theta} with a ones bias column followed by random weights. */
    private void initializeWeights() {
        for (int l = 0; l < this.theta.length; l++) {
            int units = this.layerCounts[l + 1];
            DoubleMatrix weights = DoubleMatrix.rand(units, this.layerCounts[l]);
            this.theta[l] = DoubleMatrix.concatHorizontally(DoubleMatrix.ones(units), weights);
        }
    }

    /**
     * Performs one unregularized gradient-descent step with a step size of 1.
     *
     * @param X one example per row, with {@code layerCounts[0]} columns
     * @param y one class label per example, each in {@code [0, outputRange)}
     * @return the cost evaluated with the weights as they were before this step
     */
    public double trainNetwork(DoubleMatrix X, DoubleMatrix y) {
        return this.trainNetwork(X, y, 0.0);
    }

    /**
     * Performs one L2-regularized gradient-descent step with a step size of 1.
     *
     * @param X one example per row, with {@code layerCounts[0]} columns
     * @param y one class label per example, each in {@code [0, outputRange)}
     * @param lambda L2 regularization strength (bias weights are not penalized)
     * @return the cost evaluated with the weights as they were before this step
     */
    public double trainNetwork(DoubleMatrix X, DoubleMatrix y, double lambda) {
        return this.trainNetwork(X, y, lambda, 1.0);
    }

    /**
     * Performs one full-batch, L2-regularized gradient-descent step using back-propagation.
     *
     * @param X one example per row, with {@code layerCounts[0]} columns
     * @param y one class label per example; values are truncated to int and must be in
     *     {@code [0, outputRange)}
     * @param lambda L2 regularization strength (bias weights are not penalized)
     * @param learningRate factor applied to the gradient before it is subtracted from the weights
     * @return the cost evaluated with the weights as they were before this step
     */
    public double trainNetwork(DoubleMatrix X, DoubleMatrix y, double lambda, double learningRate) {
        int m = X.rows;
        int layers = this.layerCounts.length;
        DoubleMatrix[] activations = this.feedForward(X);
        DoubleMatrix output = activations[layers - 1];
        DoubleMatrix yMat = this.oneHot(y);
        // Reuse the forward pass above instead of running a second one inside cost().
        double cost = this.cost(output, yMat, m, lambda);

        DoubleMatrix[] deltas = new DoubleMatrix[layers];
        deltas[layers - 1] = output.sub(yMat);
        for (int l = layers - 2; l > 0; l--) {
            DoubleMatrix a = activations[l];
            // Sigmoid derivative a .* (1 - a), with the bias column dropped.
            DoubleMatrix sigGrad = a.mul(a.rsub(1.0));
            sigGrad = sigGrad.getRange(0, sigGrad.rows, 1, sigGrad.columns);
            // Weights of THIS transition without their bias column (previously theta[1] was used).
            DoubleMatrix weightsNoBias = this.theta[l].getRange(0, this.theta[l].rows, 1, this.theta[l].columns);
            deltas[l] = deltas[l + 1].mmul(weightsNoBias).mul(sigGrad);
        }

        for (int l = 0; l < this.theta.length; l++) {
            // dup() first: mulColumn scales in place, and the live weights must not be modified.
            DoubleMatrix penalty = this.theta[l].dup().mulColumn(0, 0.0).mul(lambda / (double) m);
            // The bias gradient is kept; only the penalty term excludes the bias column.
            DoubleMatrix grad = deltas[l + 1].transpose().mmul(activations[l]).div((double) m).add(penalty);
            this.theta[l] = this.theta[l].sub(grad.mul(learningRate));
        }
        return cost;
    }

    /**
     * Computes the regularized cross-entropy cost of an already-computed network output.
     *
     * @param output network output, one row per example, values expected in (0, 1)
     * @param yMat one-hot label matrix with the same shape as {@code output}
     * @param m number of examples
     * @param lambda L2 regularization strength; bias columns are excluded
     */
    private double cost(DoubleMatrix output, DoubleMatrix yMat, int m, double lambda) {
        // -y .* log(h)
        DoubleMatrix positive = yMat.mul(-1.0).mul(MatrixFunctions.log(output));
        // (1 - y) .* log(1 - h)
        DoubleMatrix negative = yMat.rsub(1.0).mul(MatrixFunctions.log(output.rsub(1.0)));
        double dataTerm = positive.sub(negative).sum() / (double) m;

        double squares = 0.0;
        for (DoubleMatrix weights : this.theta) {
            DoubleMatrix noBias = weights.dup().mulColumn(0, 0.0);
            squares += MatrixFunctions.pow(noBias, 2.0).sum();
        }
        return dataTerm + squares * lambda / (2.0 * (double) m);
    }

    /** Expands integer labels into one-hot rows of width {@code outputRange}. */
    private DoubleMatrix oneHot(DoubleMatrix y) {
        return DoubleMatrix.eye(this.outputRange).getRows(y.toIntArray());
    }

    /**
     * Runs the network on the given examples.
     *
     * @param X one example per row, with {@code layerCounts[0]} columns
     * @return the output-layer activations, one row per example and {@code outputRange} columns
     */
    public DoubleMatrix predict(DoubleMatrix X) {
        DoubleMatrix[] activations = this.feedForward(X);
        return activations[activations.length - 1];
    }

    /**
     * Forward pass shared by training and prediction.
     *
     * @return activations per layer. Entries {@code 0 .. n-2} have a leading column of ones for
     *     the bias unit; the last entry is the output layer without a bias column.
     */
    private DoubleMatrix[] feedForward(DoubleMatrix X) {
        int layers = this.layerCounts.length;
        DoubleMatrix[] activations = new DoubleMatrix[layers];
        activations[0] = DoubleMatrix.concatHorizontally(DoubleMatrix.ones(X.rows), X);
        for (int l = 1; l < layers - 1; l++) {
            DoubleMatrix z = activations[l - 1].mmul(this.theta[l - 1].transpose());
            activations[l] = DoubleMatrix.concatHorizontally(DoubleMatrix.ones(z.rows), this.act.activate(z));
        }
        DoubleMatrix z = activations[layers - 2].mmul(this.theta[layers - 2].transpose());
        activations[layers - 1] = this.act.activate(z);
        return activations;
    }

    /** Returns the layer sizes followed by every weight matrix printed with five decimals. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("NeuralNetwork -- [ ");
        for (int layer : this.layerCounts) {
            sb.append(layer).append(' ');
        }
        sb.append("]\n");
        for (int i = 0; i < this.theta.length; i++) {
            sb.append("Layer").append(i + 1).append(" -> ").append(i + 2).append(" \n")
                    .append(this.theta[i].toString("%.5f", "", "", " ", "\n"))
                    .append('\n');
        }
        return sb.toString();
    }
}

