@@ -59,7 +59,7 @@ def build_model(self):
59
59
self .z = tf .placeholder (tf .float32 , [None , self .z_dim ])
60
60
61
61
self .image_ = self .generator (self .z )
62
- self .sampler = self .generator (self .z )
62
+ self .sampler = self .sampler (self .z )
63
63
64
64
self .D = self .discriminator (self .image )
65
65
self .D_ = self .discriminator (self .image_ , reuse = True )
@@ -207,36 +207,37 @@ def generator(self, z, y=None):
207
207
def sampler (self , z , y = None ):
208
208
tf .get_variable_scope ().reuse_variables ()
209
209
210
- if y :
210
+ if self . y_dim :
211
211
yb = tf .reshape (y , [None , 1 , 1 , self .y_dim ])
212
212
z = tf .concat (1 , [z , y ])
213
213
214
- h0 = tf .nn .relu (self .bn0 (linear (z , self .gfc_dim , 's_h0_lin ' )))
214
+ h0 = tf .nn .relu (self .bn0 (linear (z , self .gfc_dim , 'g_h0_lin ' )))
215
215
h0 = tf .concat (1 , [h0 , y ])
216
216
217
- h1 = tf .nn .relu (self .g_bn1 (linear (z , self .gf_dim * 2 * 7 * 7 , 's_h1_lin ' )))
217
+ h1 = tf .nn .relu (self .g_bn1 (linear (z , self .gf_dim * 2 * 7 * 7 , 'g_h1_lin ' )))
218
218
h1 = tf .reshape (h1 , [None , 7 , 7 , self .gf_dim * 2 ])
219
219
h1 = conv_cond_concat (h1 , yb )
220
220
221
- h2 = tf .nn .relu (self .bn2 (deconv2d (h1 , self .gf_dim , name = 'h2 ' )))
221
+ h2 = tf .nn .relu (self .bn2 (deconv2d (h1 , self .gf_dim , name = 'g_h2 ' )))
222
222
h2 = conv_cond_concat (h2 , yb )
223
223
224
- return tf .nn .sigmoid (deconv2d (h2 , self .c_dim , name = 'h3 ' ))
224
+ return tf .nn .sigmoid (deconv2d (h2 , self .c_dim , name = 'g_h3 ' ))
225
225
else :
226
- h0 = tf .nn .relu (self .g_bn0 (linear (z , self .gf_dim * 8 * 4 * 4 , 's_h0_lin' ),
227
- train = False ))
228
- h0 = tf .reshape (h1 , [None , 4 , 4 , self .gf_dim * 8 ])
226
+ # project `z` and reshape
227
+ h0 = tf .reshape (linear (z , self .gf_dim * 8 * 4 * 4 , 'g_h0_lin' ),
228
+ [- 1 , 4 , 4 , self .gf_dim * 8 ])
229
+ h0 = tf .nn .relu (self .g_bn0 (h0 , train = False ))
229
230
230
- h1 = deconv2d (h0 , [None , 8 , 8 , self .gf_dim * 4 ], name = 'h1 ' )
231
- h1 = tf .relu (self .g_bn1 (h1 , train = False ))
231
+ h1 = deconv2d (h0 , [self . batch_size , 8 , 8 , self .gf_dim * 4 ], name = 'g_h1 ' )
232
+ h1 = tf .nn . relu (self .g_bn1 (h1 , train = False ))
232
233
233
- h2 = deconv2d (h1 , [None , 16 , 16 , self .gf_dim * 2 ], name = 'h2 ' )
234
- h2 = tf .relu (self .g_bn2 (h2 , train = False ))
234
+ h2 = deconv2d (h1 , [self . batch_size , 16 , 16 , self .gf_dim * 2 ], name = 'g_h2 ' )
235
+ h2 = tf .nn . relu (self .g_bn2 (h2 , train = False ))
235
236
236
- h3 = deconv2d (h2 , [None , 16 , 16 , self .gf_dim * 1 ], name = 'h3 ' )
237
- h3 = tf .relu (self .g_bn3 (h3 , train = False ))
237
+ h3 = deconv2d (h2 , [self . batch_size , 16 , 16 , self .gf_dim * 1 ], name = 'g_h3 ' )
238
+ h3 = tf .nn . relu (self .g_bn3 (h3 , train = False ))
238
239
239
- h4 = deconv2d (h3 , [None , 64 , 64 , 3 ], name = 'h4 ' )
240
+ h4 = deconv2d (h3 , [None , 64 , 64 , 3 ], name = 'g_h4 ' )
240
241
241
242
return tf .nn .tanh (h4 )
242
243
0 commit comments