@@ -90,8 +90,10 @@ def __init__(self, params):
90
90
91
91
if self .params ['model' ] == 'Gauss2D' :
92
92
self .h = Mixture_Gaussian_encoder (params )
93
+ self .q = Mixture_Gaussian_encoder (params )
93
94
elif self .params ['model' ] == 'MNIST' :
94
95
self .h = MNIST_encoder (params )
96
+ self .q = MNIST_encoder (params )
95
97
else :
96
98
raise NameError ('Unknown model ' + self .params ['model' ])
97
99
@@ -111,7 +113,7 @@ def __init__(self, params):
111
113
)
112
114
113
115
self .f = torch .nn .Sequential (
114
- torch .nn .Linear (self .g_dim + 2 * self .h_dim , H ),
116
+ torch .nn .Linear (self .g_dim + self .h_dim , H ),
115
117
torch .nn .PReLU (),
116
118
torch .nn .Linear (H , H ),
117
119
torch .nn .PReLU (),
@@ -136,8 +138,7 @@ def forward(self,data, cs, n):
136
138
assert (n == self .previous_n + 1 )
137
139
self .previous_n = self .previous_n + 1
138
140
139
- K = len (set (cs [:n ])) #num of already _assigned_clusters
140
- # K is the number of distinct classes in [0:n]
141
+ K = len (set (cs [:n ])) # num of already created clusters
141
142
142
143
if n == 1 :
143
144
@@ -159,11 +160,13 @@ def forward(self,data, cs, n):
159
160
160
161
161
162
self .hs = self .h (data ).view ([self .batch_size ,self .N , self .h_dim ])
162
- self .Q = self .hs [:,2 :,].sum (dim = 1 ) #[batch_size,h_dim]
163
-
164
163
self .Hs = torch .zeros ([self .batch_size , 1 , self .h_dim ]).to (self .device )
165
164
self .Hs [:,0 ,:] = self .hs [:,0 ,:]
166
165
166
+ self .qs = self .q (data ).view ([self .batch_size ,self .N , self .h_dim ])
167
+ self .Q = self .qs [:,2 :,].sum (dim = 1 ) #[batch_size,h_dim]
168
+
169
+
167
170
168
171
else :
169
172
if K == self .previous_K :
@@ -177,7 +180,7 @@ def forward(self,data, cs, n):
177
180
self .previous_n = 0
178
181
179
182
else :
180
- self .Q -= self .hs [:,n ,]
183
+ self .Q -= self .qs [:,n ,]
181
184
182
185
183
186
self .previous_K = K
@@ -196,7 +199,7 @@ def forward(self,data, cs, n):
196
199
gs = self .g (Hs2 ).view ([self .batch_size , K , self .g_dim ])
197
200
Gk = gs .sum (dim = 1 ) #[batch_size,g_dim]
198
201
199
- uu = torch .cat ((Gk ,self .Q , self . hs [:, n ,:] ), dim = 1 ) #prepare argument for the call to f()
202
+ uu = torch .cat ((Gk ,self .Q ), dim = 1 ) #prepare argument for the call to f()
200
203
logprobs [:,k ] = torch .squeeze (self .f (uu ))
201
204
202
205
@@ -208,7 +211,7 @@ def forward(self,data, cs, n):
208
211
209
212
Gk = gs .sum (dim = 1 )
210
213
211
- uu = torch .cat ((Gk ,self .Q , self . hs [:, n ,:] ), dim = 1 ) #prepare argument for the call to f()
214
+ uu = torch .cat ((Gk ,self .Q ), dim = 1 ) #prepare argument for the call to f()
212
215
logprobs [:,K ] = torch .squeeze (self .f (uu ))
213
216
214
217
0 commit comments