@@ -73,31 +73,37 @@ def build_model(self):
     if self.y_dim:
       self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')

-    image_dims = [self.output_height, self.output_width, self.c_dim]
+    image_dims = [None, None, self.c_dim]

     self.inputs = tf.placeholder(
-      tf.float32, [self.batch_size] + image_dims name='real_images')
+      tf.float32, [self.batch_size] + image_dims, name='real_images')
     self.sample_inputs = tf.placeholder(
       tf.float32, [self.sample_num] + image_dims, name='sample_inputs')

-    inputs = tf.image.resize_images(
-      self.inputs, [self.output_height, self.output_width])
-    sample_inputs = tf.image.resize_images(
-      self.sample_inputs, [self.output_height, self.output_width])
+    if not self.is_crop:
+      inputs = tf.image.resize_images(
+        self.inputs, [self.output_height, self.output_width])
+      sample_inputs = tf.image.resize_images(
+        self.sample_inputs, [self.output_height, self.output_width])
+    else:
+      inputs = self.inputs
+      sample_inputs = self.sample_inputs

     self.z = tf.placeholder(
       tf.float32, [None, self.z_dim], name='z')
     self.z_sum = histogram_summary("z", self.z)

     if self.y_dim:
       self.G = self.generator(self.z, self.y)
-      self.D, self.D_logits = self.discriminator(self.inputs, self.y, reuse=False)
+      self.D, self.D_logits = \
+        self.discriminator(inputs, self.y, reuse=False)

       self.sampler = self.sampler(self.z, self.y)
-      self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)
+      self.D_, self.D_logits_ = \
+        self.discriminator(self.G, self.y, reuse=True)
     else:
       self.G = self.generator(self.z)
-      self.D, self.D_logits = self.discriminator(self.inputs)
+      self.D, self.D_logits = self.discriminator(inputs)

       self.sampler = self.sampler(self.z)
       self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
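
With image_dims = [None, None, self.c_dim], the real_images placeholder now accepts batches of any spatial size, and the explicit resize to (output_height, output_width) only runs when is_crop is off. A minimal sketch of that behaviour, assuming the TF 1.x-era API used above and purely illustrative sizes:

    import numpy as np
    import tensorflow as tf

    # Placeholder with unknown height/width, as in the new image_dims.
    x = tf.placeholder(tf.float32, [8, None, None, 3], name='real_images')
    # Resize to an assumed 64x64 output resolution.
    resized = tf.image.resize_images(x, [64, 64])

    with tf.Session() as sess:
        batch = np.random.rand(8, 96, 80, 3).astype(np.float32)  # arbitrary H, W
        print(sess.run(resized, {x: batch}).shape)  # (8, 64, 64, 3)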
@@ -144,10 +150,9 @@ def train(self, config):
     g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
       .minimize(self.g_loss, var_list=self.g_vars)
     try:
-      tf.initialize_all_variables().run()
+      tf.global_variables_initializer().run()
     except:
-      init_op = tf.global_variables_initializer()
-      self.sess.run(init_op)
+      tf.initialize_all_variables().run()

     self.g_sum = merge_summary([self.z_sum, self.d__sum,
       self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
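
The reordered try/except above is a version shim: tf.global_variables_initializer() (TensorFlow 0.12+) is attempted first, with the older tf.initialize_all_variables() as the fallback. The same idea written as an explicit feature check, as a sketch only (the session handling below is illustrative, not the repo's):

    import tensorflow as tf

    # Prefer the newer initializer name; fall back on older TF releases.
    if hasattr(tf, 'global_variables_initializer'):
        init_op = tf.global_variables_initializer()
    else:
        init_op = tf.initialize_all_variables()

    with tf.Session() as sess:
        sess.run(init_op)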
@@ -198,8 +203,8 @@ def train(self, config):
         batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size]
         batch = [
           get_image(batch_file,
-                    self.image_height,
-                    self.image_width,
+                    image_height=self.image_height,
+                    image_width=self.image_width,
                     resize_height=self.output_height,
                     resize_width=self.output_width,
                     is_crop=self.is_crop,
@@ -263,8 +268,8 @@ def train(self, config):
           feed_dict={ self.z: batch_z })
         self.writer.add_summary(summary_str, counter)

-        errD_fake = self.d_loss_fake.eval({self.z: batch_z})
-        errD_real = self.d_loss_real.eval({self.inputs: batch_images})
+        errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
+        errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
         errG = self.g_loss.eval({self.z: batch_z})

         counter += 1
@@ -325,10 +330,10 @@ def discriminator(self, image, y=None, reuse=False):

         h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
         h1 = tf.reshape(h1, [self.batch_size, -1])
-        h1 = tf.concat(1, [h1, y])
+        h1 = tf.concat_v2([h1, y], 1)

         h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
-        h2 = tf.concat(1, [h2, y])
+        h2 = tf.concat_v2([h2, y], 1)

         h3 = linear(h2, 1, 'd_h3_lin')

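The concat calls switch from the old tf.concat(axis, values) order to tf.concat_v2(values, axis), the argument order that later became plain tf.concat in TensorFlow 1.0. A small shape check of the label concatenation using the modern tf.concat and assumed dimensions (batch 64, 128 features, 10 labels):

    import tensorflow as tf

    h1 = tf.zeros([64, 128])      # flattened conv features (assumed sizes)
    y = tf.zeros([64, 10])        # one-hot labels
    h1y = tf.concat([h1, y], 1)   # values first, axis second
    print(h1y.get_shape())        # (64, 138)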
@@ -337,100 +342,114 @@ def discriminator(self, image, y=None, reuse=False):
   def generator(self, z, y=None):
     with tf.variable_scope("generator") as scope:
       if not self.y_dim:
-        s = self.output_size
-        s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4, s_h8, s_h16 = \
+          int(s_h/2), int(s_h/4), int(s_h/8), int(s_h/16)
+        s_w2, s_w4, s_w8, s_w16 = \
+          int(s_w/2), int(s_w/4), int(s_w/8), int(s_w/16)

         # project `z` and reshape
-        self.z_, self.h0_w, self.h0_b = linear(z, self.gf_dim*8*s16*s16, 'g_h0_lin', with_w=True)
+        self.z_, self.h0_w, self.h0_b = linear(
+          z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)

-        self.h0 = tf.reshape(self.z_, [-1, s16, s16, self.gf_dim * 8])
+        self.h0 = tf.reshape(
+          self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
         h0 = tf.nn.relu(self.g_bn0(self.h0))

-        self.h1, self.h1_w, self.h1_b = deconv2d(h0,
-          [self.batch_size, s8, s8, self.gf_dim*4], name='g_h1', with_w=True)
+        self.h1, self.h1_w, self.h1_b = deconv2d(
+          h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
         h1 = tf.nn.relu(self.g_bn1(self.h1))

-        h2, self.h2_w, self.h2_b = deconv2d(h1,
-          [self.batch_size, s4, s4, self.gf_dim*2], name='g_h2', with_w=True)
+        h2, self.h2_w, self.h2_b = deconv2d(
+          h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
         h2 = tf.nn.relu(self.g_bn2(h2))

-        h3, self.h3_w, self.h3_b = deconv2d(h2,
-          [self.batch_size, s2, s2, self.gf_dim*1], name='g_h3', with_w=True)
+        h3, self.h3_w, self.h3_b = deconv2d(
+          h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
         h3 = tf.nn.relu(self.g_bn3(h3))

-        h4, self.h4_w, self.h4_b = deconv2d(h3,
-          [self.batch_size, s, s, self.c_dim], name='g_h4', with_w=True)
+        h4, self.h4_w, self.h4_b = deconv2d(
+          h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

         return tf.nn.tanh(h4)
       else:
-        s = self.output_size
-        s2, s4 = int(s/2), int(s/4)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4 = int(s_h/2), int(s_h/4)
+        s_w2, s_w4 = int(s_w/2), int(s_w/4)

         # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
         yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
-        z = tf.concat(1, [z, y])
+        z = tf.concat_v2([z, y], 1)

-        h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
-        h0 = tf.concat(1, [h0, y])
+        h0 = tf.nn.relu(
+          self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
+        h0 = tf.concat_v2([h0, y], 1)

-        h1 = tf.nn.relu(self.g_bn1(linear(h0, self.gf_dim*2*s4*s4, 'g_h1_lin')))
-        h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])
+        h1 = tf.nn.relu(self.g_bn1(
+          linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
+        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])

         h1 = conv_cond_concat(h1, yb)

         h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
-          [self.batch_size, s2, s2, self.gf_dim * 2], name='g_h2')))
+          [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
         h2 = conv_cond_concat(h2, yb)

         return tf.nn.sigmoid(
-          deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g_h3'))
+          deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

   def sampler(self, z, y=None):
     with tf.variable_scope("generator") as scope:
       scope.reuse_variables()

       if not self.y_dim:

-        s = self.output_size
-        s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4, s_h8, s_h16 = \
+          int(s_h/2), int(s_h/4), int(s_h/8), int(s_h/16)
+        s_w2, s_w4, s_w8, s_w16 = \
+          int(s_w/2), int(s_w/4), int(s_w/8), int(s_w/16)

         # project `z` and reshape
-        h0 = tf.reshape(linear(z, self.gf_dim*8*s16*s16, 'g_h0_lin'),
-          [-1, s16, s16, self.gf_dim * 8])
+        h0 = tf.reshape(
+          linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),
+          [-1, s_h16, s_w16, self.gf_dim * 8])
         h0 = tf.nn.relu(self.g_bn0(h0, train=False))

-        h1 = deconv2d(h0, [self.batch_size, s8, s8, self.gf_dim*4], name='g_h1')
+        h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')
         h1 = tf.nn.relu(self.g_bn1(h1, train=False))

-        h2 = deconv2d(h1, [self.batch_size, s4, s4, self.gf_dim*2], name='g_h2')
+        h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')
         h2 = tf.nn.relu(self.g_bn2(h2, train=False))

-        h3 = deconv2d(h2, [self.batch_size, s2, s2, self.gf_dim*1], name='g_h3')
+        h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')
         h3 = tf.nn.relu(self.g_bn3(h3, train=False))

-        h4 = deconv2d(h3, [self.batch_size, s, s, self.c_dim], name='g_h4')
+        h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

         return tf.nn.tanh(h4)
       else:
-        s = self.output_size
-        s2, s4 = int(s/2), int(s/4)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4 = int(s_h/2), int(s_h/4)
+        s_w2, s_w4 = int(s_w/2), int(s_w/4)

         # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
         yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
-        z = tf.concat(1, [z, y])
+        z = tf.concat_v2([z, y], 1)

         h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
-        h0 = tf.concat(1, [h0, y])
+        h0 = tf.concat_v2([h0, y], 1)

-        h1 = tf.nn.relu(self.g_bn1(linear(h0, self.gf_dim*2*s4*s4, 'g_h1_lin'), train=False))
-        h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])
+        h1 = tf.nn.relu(self.g_bn1(
+          linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))
+        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
         h1 = conv_cond_concat(h1, yb)

         h2 = tf.nn.relu(self.g_bn2(
-          deconv2d(h1, [self.batch_size, s2, s2, self.gf_dim * 2], name='g_h2'), train=False))
+          deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))
         h2 = conv_cond_concat(h2, yb)

-        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g_h3'))
+        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))

   def load_mnist(self):
     data_dir = os.path.join("./data", self.dataset_name)
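
Splitting the single size factor s into s_h/s_w pairs lets the generator and sampler target non-square outputs; each deconv layer still doubles height and width, but now independently. A quick worked example of the per-layer feature-map sizes, with hypothetical output_height=64 and output_width=48:

    s_h, s_w = 64, 48  # assumed, non-square output
    sizes = [(s_h // 2**i, s_w // 2**i) for i in range(1, 5)]
    print(sizes)  # [(32, 24), (16, 12), (8, 6), (4, 3)]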
@@ -468,11 +487,16 @@ def load_mnist(self):
       y_vec[i, y[i]] = 1.0

     return X/255., y_vec
+
+  @property
+  def model_dir(self):
+    return "{}_{}_{}_{}".format(
+      self.dataset_name, self.batch_size,
+      self.output_height, self.output_width)

   def save(self, checkpoint_dir, step):
     model_name = "DCGAN.model"
-    model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
-    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
+    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

     if not os.path.exists(checkpoint_dir):
       os.makedirs(checkpoint_dir)
@@ -483,9 +507,7 @@ def save(self, checkpoint_dir, step):

   def load(self, checkpoint_dir):
     print(" [*] Reading checkpoints...")
-
-    model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
-    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
+    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

     ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
     if ckpt and ckpt.model_checkpoint_path:
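
With the new model_dir property, save() and load() derive the checkpoint subdirectory from the same four fields, so the two methods can no longer drift apart. An illustrative value, using assumed config settings:

    import os

    dataset_name, batch_size = 'celebA', 64       # assumed values
    output_height, output_width = 64, 64
    model_dir = "{}_{}_{}_{}".format(
        dataset_name, batch_size, output_height, output_width)
    print(os.path.join('checkpoint', model_dir))  # checkpoint/celebA_64_64_64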