
Commit 2812edd

added separate encodings h and q
1 parent f40e4a0 commit 2812edd

2 files changed: +14 −11 lines changed

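A note on what this commit does (an editorial reading of the diff below, not part of the original commit message): previously a single per-point encoder h served both the assigned-cluster summaries and the running sum Q over not-yet-assigned points, and the current point's encoding hs[:,n,:] was also concatenated into the input of f. This commit introduces a separate encoder q whose outputs are summed into Q, and drops hs[:,n,:] from f's input, so f's first linear layer shrinks from g_dim + 2*h_dim to g_dim + h_dim. A minimal shape-bookkeeping sketch under those assumptions (dimensions are invented; the two-layer f here is an abbreviated stand-in for the real self.f):

import torch

# Illustrative sizes only; the real values come from params.
batch_size, N, h_dim, g_dim, H = 4, 10, 128, 256, 128

hs = torch.randn(batch_size, N, h_dim)  # per-point encodings from h (cluster-summary side)
qs = torch.randn(batch_size, N, h_dim)  # per-point encodings from q (unassigned side, new in this commit)

Q  = qs[:, 2:, :].sum(dim=1)            # summary of not-yet-assigned points, [batch_size, h_dim]
Gk = torch.randn(batch_size, g_dim)     # aggregated cluster summary produced by g

f = torch.nn.Sequential(                # abbreviated stand-in for self.f
    torch.nn.Linear(g_dim + h_dim, H),  # was g_dim + 2*h_dim before this commit
    torch.nn.PReLU(),
    torch.nn.Linear(H, 1),
)

uu = torch.cat((Gk, Q), dim=1)          # [batch_size, g_dim + h_dim]; hs[:, n, :] is no longer appended
logprobs_k = torch.squeeze(f(uu))       # [batch_size]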

ncp.py

+11-8
@@ -90,8 +90,10 @@ def __init__(self, params):
 
         if self.params['model'] == 'Gauss2D':
             self.h = Mixture_Gaussian_encoder(params)
+            self.q = Mixture_Gaussian_encoder(params)
         elif self.params['model'] == 'MNIST':
             self.h = MNIST_encoder(params)
+            self.q = MNIST_encoder(params)
         else:
             raise NameError('Unknown model '+ self.params['model'])

@@ -111,7 +113,7 @@ def __init__(self, params):
                 )
 
         self.f = torch.nn.Sequential(
-                torch.nn.Linear(self.g_dim +2*self.h_dim, H),
+                torch.nn.Linear(self.g_dim +self.h_dim, H),
                 torch.nn.PReLU(),
                 torch.nn.Linear(H, H),
                 torch.nn.PReLU(),
@@ -136,8 +138,7 @@ def forward(self,data, cs, n):
         assert(n == self.previous_n+1)
         self.previous_n = self.previous_n + 1
 
-        K = len(set(cs[:n])) #num of already _assigned_clusters
-        # K is the number of distinct classes in [0:n]
+        K = len(set(cs[:n])) # num of already created clusters
 
         if n==1:

@@ -159,11 +160,13 @@ def forward(self,data, cs, n):
 
 
             self.hs = self.h(data).view([self.batch_size,self.N, self.h_dim])
-            self.Q = self.hs[:,2:,].sum(dim=1) #[batch_size,h_dim]
-
             self.Hs = torch.zeros([self.batch_size, 1, self.h_dim]).to(self.device)
             self.Hs[:,0,:] = self.hs[:,0,:]
 
+            self.qs = self.q(data).view([self.batch_size,self.N, self.h_dim])
+            self.Q = self.qs[:,2:,].sum(dim=1) #[batch_size,h_dim]
+
+
 
         else:
             if K == self.previous_K:
@@ -177,7 +180,7 @@ def forward(self,data, cs, n):
                    self.previous_n = 0
 
                else:
-                   self.Q -= self.hs[:,n,]
+                   self.Q -= self.qs[:,n,]
 
 
        self.previous_K = K
@@ -196,7 +199,7 @@ def forward(self,data, cs, n):
            gs = self.g(Hs2).view([self.batch_size, K, self.g_dim])
            Gk = gs.sum(dim=1) #[batch_size,g_dim]
 
-           uu = torch.cat((Gk,self.Q,self.hs[:,n,:]), dim=1) #prepare argument for the call to f()
+           uu = torch.cat((Gk,self.Q), dim=1) #prepare argument for the call to f()
            logprobs[:,k] = torch.squeeze(self.f(uu))
 

@@ -208,7 +211,7 @@ def forward(self,data, cs, n):
 
        Gk = gs.sum(dim=1)
 
-       uu = torch.cat((Gk,self.Q,self.hs[:,n,:]), dim=1) #prepare argument for the call to f()
+       uu = torch.cat((Gk,self.Q), dim=1) #prepare argument for the call to f()
        logprobs[:,K] = torch.squeeze(self.f(uu))
 
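Taken together, the ncp.py hunks change how Q is maintained: it is now initialized from the q encodings (qs[:,2:,].sum(dim=1)) and decremented with qs[:,n,] as points leave the unassigned pool, while the h encodings remain reserved for the cluster summaries Hs. A rough standalone sketch of that running update, assuming the subtraction fires once per processed point (the exact control flow around previous_K is not shown in these hunks):

import torch

batch_size, N, h_dim = 4, 10, 128
qs = torch.randn(batch_size, N, h_dim)  # outputs of the new q encoder

Q = qs[:, 2:, :].sum(dim=1)             # at n == 1: points 2..N-1 are still unassigned
for n in range(2, N):
    # ... point n receives its cluster assignment here ...
    Q = Q - qs[:, n, :]                 # point n leaves the unassigned pool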

ncp_sampler.py

+3-3
@@ -30,7 +30,7 @@ def __init__(self, model, data):
            data = data.view([self.N, 28,28])
 
        self.hs = model.h(data)
-       self.qs = self.hs
+       self.qs = model.q(data)
 
 
        self.f = model.f
@@ -91,7 +91,7 @@ def sample(self, S):
 
            logprobs = torch.zeros([S, maxK+1]).to(self.device)
            rQ = self.Q.repeat(S,1)
-           rhn = self.hs[n,:].unsqueeze(0).repeat(S,1)
+
 
 
            for k in range(maxK+1):
@@ -110,7 +110,7 @@ def sample(self, S):
 
                Gk = gs.sum(dim=1)
 
-               uu = torch.cat((Gk,rQ,rhn), dim=1)
+               uu = torch.cat((Gk,rQ), dim=1)
                logprobs[:,k] = torch.squeeze(self.f(uu))
 
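The ncp_sampler.py change mirrors ncp.py: self.qs is now produced by model.q rather than aliased to self.hs, and because f no longer consumes the current point's h encoding, rhn disappears from the concatenation. A small sketch of how the sampler's f input is assembled after this commit (shapes are illustrative; S is the number of parallel samples):

import torch

S, g_dim, h_dim = 8, 256, 128
Q  = torch.randn(h_dim)            # summary of unassigned points, built from the q encodings
Gk = torch.randn(S, g_dim)         # per-sample cluster summary from g

rQ = Q.repeat(S, 1)                # broadcast Q across the S samples, [S, h_dim]
uu = torch.cat((Gk, rQ), dim=1)    # [S, g_dim + h_dim]; rhn is no longer needed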
