# DCN.py
import torch
import numbers
import numpy as np
import torch.nn as nn

from kmeans import batch_KMeans
from autoencoder import AutoEncoder


class DCN(nn.Module):
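
    # Interface expected of the `batch_KMeans` helper, inferred from how it is
    # used in this file (not verified against kmeans.py):
    #   .clusters                 -- centroid array, one row per cluster in latent space
    #   .init_cluster(X)          -- fit initial centroids on latent features
    #   .update_assign(X)         -- return the nearest-centroid id for each sample
    #   .update_cluster(X_k, k)   -- update centroid k from its currently assigned samples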
    def __init__(self, args):
        super(DCN, self).__init__()
        self.args = args
        self.beta = args.beta  # coefficient of the clustering term
        self.lamda = args.lamda  # coefficient of the reconstruction term
        self.device = torch.device('cuda' if args.cuda else 'cpu')

        # Validation check
        if not self.beta > 0:
            msg = 'beta should be greater than 0 but got value = {}.'
            raise ValueError(msg.format(self.beta))

        if not self.lamda > 0:
            msg = 'lambda should be greater than 0 but got value = {}.'
            raise ValueError(msg.format(self.lamda))

        if len(self.args.hidden_dims) == 0:
            raise ValueError('No hidden layer specified.')

        self.kmeans = batch_KMeans(args)
        self.autoencoder = AutoEncoder(args).to(self.device)

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=args.lr,
                                          weight_decay=args.wd)

    def _loss(self, X, cluster_id):
        """Compute the Equation (5) in the original paper on a data batch."""
        batch_size = X.size()[0]
        rec_X = self.autoencoder(X)
        latent_X = self.autoencoder(X, latent=True)

        # Reconstruction error
        rec_loss = self.lamda * self.criterion(X, rec_X)

        # Regularization term on clustering
        dist_loss = torch.tensor(0.).to(self.device)
        clusters = torch.FloatTensor(self.kmeans.clusters).to(self.device)
        for i in range(batch_size):
            diff_vec = latent_X[i] - clusters[cluster_id[i]]
            sample_dist_loss = torch.matmul(diff_vec.view(1, -1),
                                            diff_vec.view(-1, 1))
            dist_loss += 0.5 * self.beta * torch.squeeze(sample_dist_loss)

        return (rec_loss + dist_loss,
                rec_loss.detach().cpu().numpy(),
                dist_loss.detach().cpu().numpy())
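
    # Spelled out, _loss above returns (a reading of the code; f = encoder,
    # g = decoder, mu_{s_i} = centroid assigned to sample i):
    #
    #   L_batch = lamda * MSE(X, g(f(X))) + (beta / 2) * sum_i || f(x_i) - mu_{s_i} ||^2
    #
    # i.e. the weighted reconstruction error plus the distance of each latent
    # code to its assigned k-means centroid.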
    def pretrain(self, train_loader, epoch=100, verbose=True):

        if not self.args.pretrain:
            return

        if not isinstance(epoch, numbers.Integral):
            msg = '`epoch` should be an integer but got value = {}'
            raise ValueError(msg.format(epoch))

        if verbose:
            print('========== Start pretraining ==========')

        rec_loss_list = []

        self.train()
        for e in range(epoch):
            for batch_idx, (data, _) in enumerate(train_loader):
                batch_size = data.size()[0]
                data = data.to(self.device).view(batch_size, -1)
                rec_X = self.autoencoder(data)
                loss = self.criterion(data, rec_X)

                if verbose and batch_idx % self.args.log_interval == 0:
                    msg = 'Epoch: {:02d} | Batch: {:03d} | Rec-Loss: {:.3f}'
                    print(msg.format(e, batch_idx,
                                     loss.detach().cpu().numpy()))
                    rec_loss_list.append(loss.detach().cpu().numpy())

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        self.eval()

        if verbose:
            print('========== End pretraining ==========\n')

        # Initialize clusters in self.kmeans after pre-training
        batch_X = []
        for batch_idx, (data, _) in enumerate(train_loader):
            batch_size = data.size()[0]
            data = data.to(self.device).view(batch_size, -1)
            latent_X = self.autoencoder(data, latent=True)
            batch_X.append(latent_X.detach().cpu().numpy())
        batch_X = np.vstack(batch_X)
        self.kmeans.init_cluster(batch_X)

        return rec_loss_list
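
    # fit() below runs one epoch of the alternating scheme: embed each batch
    # with the encoder (under no_grad), update the k-means assignments and
    # centroids on the embeddings, then take one gradient step on the joint
    # loss from _loss().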
    def fit(self, epoch, train_loader, verbose=True):

        for batch_idx, (data, _) in enumerate(train_loader):
            batch_size = data.size()[0]
            data = data.view(batch_size, -1).to(self.device)

            # Get the latent features
            with torch.no_grad():
                latent_X = self.autoencoder(data, latent=True)
                latent_X = latent_X.cpu().numpy()

            # [Step-1] Update the assignment results
            cluster_id = self.kmeans.update_assign(latent_X)

            # [Step-2] Update clusters in batch k-means
            elem_count = np.bincount(cluster_id,
                                     minlength=self.args.n_clusters)
            for k in range(self.args.n_clusters):
                # avoid empty slicing
                if elem_count[k] == 0:
                    continue
                self.kmeans.update_cluster(latent_X[cluster_id == k], k)

            # [Step-3] Update the network parameters
            loss, rec_loss, dist_loss = self._loss(data, cluster_id)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if verbose and batch_idx % self.args.log_interval == 0:
                msg = 'Epoch: {:02d} | Batch: {:03d} | Loss: {:.3f} | Rec-' \
                      'Loss: {:.3f} | Dist-Loss: {:.3f}'
                print(msg.format(epoch, batch_idx,
                                 loss.detach().cpu().numpy(),
                                 rec_loss, dist_loss))
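

# ---------------------------------------------------------------------------
# Minimal usage sketch -- not part of the original file. It only assumes the
# attributes of `args` that DCN itself reads above (beta, lamda, cuda,
# hidden_dims, lr, wd, pretrain, log_interval, n_clusters); `input_dim` and
# `latent_dim` are guesses at what AutoEncoder / batch_KMeans may additionally
# expect and should be adjusted to the real constructors in autoencoder.py and
# kmeans.py. Random tensors stand in for an actual dataset.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    args = SimpleNamespace(
        input_dim=784, hidden_dims=[500, 500, 2000], latent_dim=10,  # guessed fields
        n_clusters=10, beta=1.0, lamda=1.0, lr=1e-3, wd=5e-4,
        pretrain=True, log_interval=100, cuda=torch.cuda.is_available())

    # Dummy (data, label) pairs with the flattened shape the model expects.
    X = torch.randn(1024, args.input_dim)
    y = torch.zeros(1024, dtype=torch.long)
    loader = DataLoader(TensorDataset(X, y), batch_size=256, shuffle=True)

    model = DCN(args)
    model.pretrain(loader, epoch=5)      # autoencoder warm-up + k-means init
    for e in range(10):
        model.fit(e, loader)             # alternating assignment / network updates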