# variational_autoencoder.py
import numpy as np
from sklearn.datasets import load_digits, fetch_openml
from nn_layers import FullyConnect, Activation
import matplotlib.pyplot as plt
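
# A variational autoencoder (VAE) built from the hand-rolled layers in
# nn_layers: an MLP encoder maps each input to the mean and log standard
# deviation of a diagonal Gaussian q(z|x), a latent sample z is drawn with the
# reparameterization trick, and an MLP decoder reconstructs the input from z.
# Training minimizes reconstruction error plus KL(q(z|x) || N(0, I)).
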
class VAE(object):
    def __init__(self, dim_in, dim_hidden, dim_z):
        self.n_epochs, self.batch_size = 10, 32
        self.C = 1  # trade-off between reconstruction and KL divergence
        # architecture is hard-coded: a single hidden layer on each side
        self.encoder_hidden = FullyConnect([dim_in], [dim_hidden], lr=1e-2)
        self.encoder_act = Activation(act_type='ReLU')
        # two heads on the shared hidden layer: mean and log std of q(z|x)
        self.encoder_mu = FullyConnect([dim_hidden], [dim_z], lr=1e-2)
        self.encoder_log_sigma = FullyConnect([dim_hidden], [dim_z], lr=1e-2)
        self.decoder_hidden = FullyConnect([dim_z], [dim_hidden], lr=1e-2)
        self.decoder_act_hidden = Activation(act_type='ReLU')
        self.decoder_out = FullyConnect([dim_hidden], [dim_in], lr=1e-2)
        self.decoder_act_out = Activation(act_type='Sigmoid')  # outputs in [0, 1]
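
    # With the arguments used in main() below (dim_in=784, dim_hidden=64,
    # dim_z=2), the network is 784 -> 64 -> 2 -> 64 -> 784.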
    def fit(self, x):
        for epoch in range(self.n_epochs):
            # shuffle sample indices, dropping the tail so every batch is full
            permut = np.random.permutation(
                x.shape[0] // self.batch_size * self.batch_size
            ).reshape([-1, self.batch_size])
            for b_idx in range(permut.shape[0]):
                x_batch = x[permut[b_idx, :]]
                # forward pass: encode, sample z, decode
                mu, log_sigma = self.encoder_forward(x_batch)
                z = self.sampling(mu, log_sigma)
                out = self.decoder_forward(z)
                # backpropagate the reconstruction gradient through the decoder
                recon_grad = self.C * (out - x_batch)
                grad_d_act_out = self.decoder_act_out.gradient(recon_grad)
                grad_d_out = self.decoder_out.gradient(grad_d_act_out)
                grad_d_act_hidden = self.decoder_act_hidden.gradient(
                    grad_d_out)
                grad_z = self.decoder_hidden.gradient(grad_d_act_hidden)
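                # KL(q || p) for q = N(mu, sigma^2) and p = N(0, 1) is
                # 0.5 * (mu^2 + sigma^2 - 2*log_sigma - 1) per latent dimension
                # (see kl_loss below); its gradients are d/d mu = mu and
                # d/d log_sigma = exp(2*log_sigma) - 1, used just below.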
                kl_mu_grad = mu
                kl_sigma_grad = np.exp(2 * log_sigma) - 1
                grad_mu = self.encoder_mu.gradient(grad_z + kl_mu_grad)
                # dz/dlog_sigma = noise * exp(log_sigma) by the chain rule
                # through the reparameterized sample, hence the extra factor
                grad_log_sigma = self.encoder_log_sigma.gradient(
                    grad_z * self.noise * np.exp(log_sigma) + kl_sigma_grad)
                grad_e_act = self.encoder_act.gradient(
                    grad_mu + grad_log_sigma)
                grad_e_hidden = self.encoder_hidden.gradient(grad_e_act)
                self.backward()  # apply all cached parameter updates
            print('epoch: {}, log loss: {}, kl loss: {}'.format(
                epoch, self.log_loss(out, x_batch), self.kl_loss(mu, log_sigma)
            ))

    def encoder_forward(self, x):
        hidden = self.encoder_hidden.forward(x)
        hidden = self.encoder_act.forward(hidden)
        mu = self.encoder_mu.forward(hidden)
        log_sigma = self.encoder_log_sigma.forward(hidden)
        return mu, log_sigma

    def sampling(self, mu, log_sigma):
        # reparameterization trick: z = mu + eps * exp(log_sigma) with
        # eps ~ N(0, I); eps is kept so fit() can backpropagate through it
        self.noise = np.random.randn(mu.shape[0], mu.shape[1])
        return mu + self.noise * np.exp(log_sigma)

    def decoder_forward(self, z):
        hidden = self.decoder_hidden.forward(z)
        hidden = self.decoder_act_hidden.forward(hidden)
        out = self.decoder_out.forward(hidden)
        out = self.decoder_act_out.forward(out)
        return out

    def backward(self):
        self.decoder_act_out.backward()
        self.decoder_out.backward()
        self.decoder_act_hidden.backward()
        self.decoder_hidden.backward()
        self.encoder_mu.backward()
        self.encoder_log_sigma.backward()
        self.encoder_act.backward()
        self.encoder_hidden.backward()
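
    # Assumed contract of the nn_layers API (the module is not shown here):
    # layer.gradient(upstream) caches parameter gradients and returns the
    # gradient w.r.t. the layer input, and layer.backward() applies the
    # cached update with the layer's learning rate.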

    def log_loss(self, pred, x):
        # Gaussian reconstruction loss (mean squared error, up to a constant)
        return 0.5 * self.C * np.square(pred - x).mean()

    def kl_loss(self, mu, log_sigma):
        # closed-form KL(N(mu, sigma^2) || N(0, 1)) averaged over the batch
        return 0.5 * (-2 * log_sigma + np.exp(2 * log_sigma) + np.square(mu) - 1).mean()
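
    # Sanity check: when q(z|x) matches the prior (mu = 0, log_sigma = 0),
    # kl_loss gives 0.5 * (-0 + exp(0) + 0 - 1) = 0, as expected.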

def main():
    # data = load_digits()
    # x, y = data.data, data.target
    x, _ = fetch_openml('mnist_784', return_X_y=True, data_home="data", as_frame=False)
    vae = VAE(x.shape[1], 64, 2)
    vae.fit(x / x.max())  # scale pixels to [0, 1] to match the Sigmoid output
    # sweep a grid over the 2-D latent space and plot the decoded digits
    n_rows = 11
    for i in range(n_rows):
        for j in range(n_rows):
            plt.subplot(n_rows, n_rows, i * n_rows + j + 1)
            plt.imshow(
                vae.decoder_forward(
                    np.array([[(i - n_rows // 2) / 2, (j - n_rows // 2) / 2]])).reshape(28, 28),
                cmap='gray', vmin=0, vmax=1
            )
    plt.show()

if __name__ == "__main__":
    main()
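
# A minimal sketch (not part of the original script) of how the trained
# encoder could be used to inspect the learned latent space: encode the data
# and scatter the 2-D means colored by digit label. Assumes `vae` and the
# fetch_openml labels `y` are in scope.
#
#   x, y = fetch_openml('mnist_784', return_X_y=True, data_home="data", as_frame=False)
#   mu, _ = vae.encoder_forward(x / x.max())
#   plt.scatter(mu[:, 0], mu[:, 1], c=y.astype(int), s=2, cmap='tab10')
#   plt.colorbar()
#   plt.show()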