Commit f40e4a0

more efficient gradient computation
1 parent c471eb7 commit f40e4a0

File tree

1 file changed: +9 −9 lines changed


main.py (+9 −9)
@@ -22,6 +22,8 @@ def main(args):
     model = args.model
     params = get_parameters(model)
     params['device'] = torch.device("cuda:0" if args.cuda else "cpu")
+
+    print(params['device'])
 
 
     dpmm = NeuralClustering(params).to(params['device'])
@@ -127,14 +129,14 @@ def main(args):
 
         for perm in range(perms):
             arr = np.arange(N)
-            np.random.shuffle(arr)
+            np.random.shuffle(arr)    # permute the order in which the points are queried
             cs = cs[arr]
             data = data[:,arr,:]
 
             cs = relabel(cs)          # this makes cluster labels appear in cs[] in increasing order
 
 
-            this_loss=0
+            this_loss=0
             dpmm.previous_n=0
 
             for n in range(1,N):
@@ -147,23 +149,21 @@ def main(args):
 
                 this_loss -= logprobs[:,c].mean()
 
-
+
+            this_loss.backward()      # this accumulates the gradients for each permutation
             loss_values[perm] = this_loss.item()/N
             loss += this_loss
 
-
 
         perm_vars.append(loss_values.var())
         losses.append(loss.item()/N)
        accs.append(accuracies.mean())
-
-
+
+        optimizer.step()              # the gradients used in this step are the sum of the gradients for each permutation
         optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
 
 
-        print('{0:4d} N:{1:2d} K:{2} Mean NLL:{3:.3f} Mean Acc:{4:.3f} Mean Variance: {5:.7f} Mean Time/Iteration: {6:.1f}'\
+        print('{0:4d} N:{1:2d} K:{2} Mean NLL:{3:.3f} Mean Acc:{4:.3f} Mean Permutation Variance: {5:.7f} Mean Time/Iteration: {6:.1f}'\
             .format(it, N, K, np.mean(losses[-50:]), np.mean(accs[-50:]), np.mean(perm_vars[-50:]), (time.time()-t_start)/(it - itt)))
 
             break
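
What the commit does, in short: the old code summed this_loss over all permutations and called loss.backward() once at the end, which keeps every permutation's computation graph alive in memory until that single backward pass. The new code calls this_loss.backward() inside the permutation loop instead. Because gradients are additive, the per-permutation backward calls accumulate into each parameter's .grad, so the single optimizer.step() applies the same update as before, while each permutation's graph can be freed as soon as its backward finishes, lowering peak memory.

A minimal sketch of that accumulation pattern, assuming a toy linear model, an SGD optimizer, and random data as stand-ins (none of these are the repo's actual NeuralClustering setup):

import torch

model = torch.nn.Linear(2, 3)                  # hypothetical stand-in model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(8, 2)                       # hypothetical stand-in data
perms = 4

optimizer.zero_grad()
for perm in range(perms):
    arr = torch.randperm(data.shape[0])        # permute the query order, as in the diff
    this_loss = -model(data[arr]).log_softmax(dim=1)[:, 0].mean()
    this_loss.backward()                       # accumulates into each p.grad; this
                                               # permutation's graph is freed right here
optimizer.step()                               # one step on the summed gradients

Note that the diff calls optimizer.step() before optimizer.zero_grad(): zeroing at the end of one training iteration is equivalent to zeroing at the start of the next, which is what this sketch does.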
