import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml, load_digits

from nn_layers import FullyConnect, Activation


class VAE:
    """Variational autoencoder with one hidden layer in both the encoder and
    the decoder, trained by hand-written backpropagation.

    Each layer's ``gradient()`` is assumed to cache its parameter gradients
    and return the gradient with respect to its input; ``backward()`` then
    applies the cached updates (see ``fit`` and ``backward`` below).
    """

    def __init__(self, dim_in, dim_hidden, dim_z):
        self.n_epochs, self.batch_size = 10, 32
        self.C = 1  # weight of the reconstruction term relative to the KL term

        # The architecture is hard-coded: one ReLU hidden layer on each side.
        self.encoder_hidden = FullyConnect([dim_in], [dim_hidden], lr=1e-2)
        self.encoder_act = Activation(act_type='ReLU')
        self.encoder_mu = FullyConnect([dim_hidden], [dim_z], lr=1e-2)
        self.encoder_log_sigma = FullyConnect([dim_hidden], [dim_z], lr=1e-2)

        self.decoder_hidden = FullyConnect([dim_z], [dim_hidden], lr=1e-2)
        self.decoder_act_hidden = Activation(act_type='ReLU')
        self.decoder_out = FullyConnect([dim_hidden], [dim_in], lr=1e-2)
        self.decoder_act_out = Activation(act_type='Sigmoid')

    def fit(self, x):
        for epoch in range(self.n_epochs):
            # Shuffle the sample indices and drop the remainder so that every
            # mini-batch has exactly batch_size rows.
            permut = np.random.permutation(
                x.shape[0] // self.batch_size * self.batch_size
            ).reshape([-1, self.batch_size])
            for b_idx in range(permut.shape[0]):
                x_batch = x[permut[b_idx, :]]
                # Forward pass: encode, sample via the reparameterization
                # trick, decode.
                mu, log_sigma = self.encoder_forward(x_batch)
                z = self.sampling(mu, log_sigma)
                out = self.decoder_forward(z)

                # Backpropagate the reconstruction gradient through the
                # decoder down to the latent code z.
                recon_grad = self.C * (out - x_batch)
                grad_d_act_out = self.decoder_act_out.gradient(recon_grad)
                grad_d_out = self.decoder_out.gradient(grad_d_act_out)
                grad_d_act_hidden = self.decoder_act_hidden.gradient(
                    grad_d_out)
                grad_z = self.decoder_hidden.gradient(grad_d_act_hidden)

                # Analytic gradients of the KL term
                # 0.5 * (exp(2*log_sigma) + mu^2 - 1 - 2*log_sigma)
                # with respect to mu and log_sigma.
                kl_mu_grad = mu
                kl_sigma_grad = np.exp(2 * log_sigma) - 1
                grad_mu = self.encoder_mu.gradient(grad_z + kl_mu_grad)
                # Through z = mu + noise * exp(log_sigma), the reconstruction
                # gradient reaches log_sigma scaled by
                # dz/dlog_sigma = noise * exp(log_sigma).
                grad_log_sigma = self.encoder_log_sigma.gradient(
                    grad_z * self.noise * np.exp(log_sigma) + kl_sigma_grad)
                # The two heads share the hidden layer, so their gradients add.
                grad_e_act = self.encoder_act.gradient(
                    grad_mu + grad_log_sigma)
                grad_e_hidden = self.encoder_hidden.gradient(grad_e_act)

                # Apply all cached parameter updates for this mini-batch.
                self.backward()
            # Report the losses of the epoch's last mini-batch.
            print('epoch: {}, recon loss: {}, kl loss: {}'.format(
                epoch, self.recon_loss(out, x_batch),
                self.kl_loss(mu, log_sigma)
            ))

    def encoder_forward(self, x):
        # Map the input to the mean and log standard deviation of the
        # approximate posterior q(z|x).
        hidden = self.encoder_hidden.forward(x)
        hidden = self.encoder_act.forward(hidden)
        mu = self.encoder_mu.forward(hidden)
        log_sigma = self.encoder_log_sigma.forward(hidden)
        return mu, log_sigma

    def sampling(self, mu, log_sigma):
        # Reparameterization trick: z = mu + eps * sigma with eps ~ N(0, I),
        # so gradients can flow back through mu and log_sigma. The noise is
        # kept for the backward pass in fit().
        self.noise = np.random.randn(*mu.shape)
        return mu + self.noise * np.exp(log_sigma)

    def decoder_forward(self, z):
        # Map a latent code back to pixel space; the sigmoid keeps the output
        # in [0, 1], matching the scaled input images.
        hidden = self.decoder_hidden.forward(z)
        hidden = self.decoder_act_hidden.forward(hidden)
        out = self.decoder_out.forward(hidden)
        out = self.decoder_act_out.forward(out)
        return out
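
    def reconstruct(self, x):
        # Convenience sketch, not part of the original training path: encode,
        # sample and decode in one call, e.g. to eyeball reconstructions of
        # held-out digits after fit(). The method name is an illustrative
        # addition.
        mu, log_sigma = self.encoder_forward(x)
        return self.decoder_forward(self.sampling(mu, log_sigma))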

    def backward(self):
        # Apply the parameter updates cached by the gradient() calls,
        # decoder first, then the encoder heads and hidden layer.
        self.decoder_act_out.backward()
        self.decoder_out.backward()
        self.decoder_act_hidden.backward()
        self.decoder_hidden.backward()
        self.encoder_mu.backward()
        self.encoder_log_sigma.backward()
        self.encoder_act.backward()
        self.encoder_hidden.backward()

    def recon_loss(self, pred, x):
        # Gaussian reconstruction loss: mean squared error weighted by C
        # (the negative Gaussian log-likelihood up to an additive constant).
        return 0.5 * self.C * np.square(pred - x).mean()

    def kl_loss(self, mu, log_sigma):
        # Closed-form KL divergence between N(mu, sigma^2) and N(0, I) with
        # sigma = exp(log_sigma), averaged over batch and latent dimensions.
        return 0.5 * (-2 * log_sigma + np.exp(2 * log_sigma)
                      + np.square(mu) - 1).mean()
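

# A minimal sanity check, separate from the original script: verify the
# analytic KL gradients used in fit() (kl_mu_grad = mu,
# kl_sigma_grad = exp(2*log_sigma) - 1) against central finite differences.
# The helper name and tolerances are illustrative choices.
def check_kl_gradients(eps=1e-6):
    rng = np.random.default_rng(0)
    mu = rng.normal(size=(4, 2))
    log_sigma = rng.normal(size=(4, 2))

    def kl(mu, log_sigma):
        # Same per-element KL as VAE.kl_loss, summed instead of averaged so
        # that the finite difference matches the per-element gradient.
        return 0.5 * (-2 * log_sigma + np.exp(2 * log_sigma)
                      + np.square(mu) - 1).sum()

    num_mu = np.zeros_like(mu)
    num_ls = np.zeros_like(log_sigma)
    for idx in np.ndindex(*mu.shape):
        d = np.zeros_like(mu)
        d[idx] = eps
        num_mu[idx] = (kl(mu + d, log_sigma) - kl(mu - d, log_sigma)) / (2 * eps)
        num_ls[idx] = (kl(mu, log_sigma + d) - kl(mu, log_sigma - d)) / (2 * eps)

    assert np.allclose(num_mu, mu, atol=1e-4)
    assert np.allclose(num_ls, np.exp(2 * log_sigma) - 1, atol=1e-4)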


def main():
    # The smaller scikit-learn digits dataset also works for a quick test,
    # but the plotting below assumes 28x28 MNIST images:
    # data = load_digits()
    # x, y = data.data, data.target
    x, _ = fetch_openml('mnist_784', return_X_y=True, data_home="data",
                        as_frame=False)
    vae = VAE(x.shape[1], 64, 2)
    vae.fit(x / x.max())  # scale pixels to [0, 1] to match the sigmoid output

    # Decode a regular grid over the 2-D latent space; nearby latent codes
    # should decode to smoothly varying digits.
    n_rows = 11
    for i in range(n_rows):
        for j in range(n_rows):
            plt.subplot(n_rows, n_rows, i * n_rows + j + 1)
            z = np.array([[(i - n_rows // 2) / 2, (j - n_rows // 2) / 2]])
            plt.imshow(vae.decoder_forward(z).reshape(28, 28),
                       cmap='gray', vmin=0, vmax=1)
    plt.show()
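

# A usage sketch, not in the original script: embed labelled samples with the
# trained encoder and scatter-plot the latent means coloured by digit class.
# The function name and arguments are illustrative; labels are assumed to be
# integer-convertible, as the strings returned by fetch_openml are.
def plot_latent_scatter(vae, x, y, n_samples=2000):
    mu, _ = vae.encoder_forward(x[:n_samples])
    plt.scatter(mu[:, 0], mu[:, 1], c=y[:n_samples].astype(int),
                cmap='tab10', s=5)
    plt.colorbar()
    plt.show()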


if __name__ == "__main__":
    main()