1
1
from __future__ import division
2
+ from __future__ import print_function
2
3
import os
3
4
import time
4
5
import math
13
14
def conv_out_size_same(size, stride):
    """Spatial output size of a SAME-padded strided convolution.

    Args:
        size: input spatial extent (height or width).
        stride: convolution stride.

    Returns:
        int: ceil(size / stride), the SAME-padding output extent.
    """
    # The file imports `from __future__ import division`, so `/` is true
    # division on both Python 2 and 3 -- the old float() casts were redundant.
    return int(math.ceil(size / stride))
15
16
17
def gen_random(mode, size):
    """Sample a random array used as the GAN's latent vector z.

    Args:
        mode: distribution name -- 'normal01' (standard normal),
            'uniform_signed' (uniform on [-1, 1)), or
            'uniform_unsigned' (uniform on [0, 1)).
        size: output shape, forwarded to numpy's `size` argument.

    Returns:
        numpy.ndarray of shape `size` drawn from the requested distribution.

    Raises:
        ValueError: if `mode` is not one of the supported names.  The
            original fell off the end and returned None, which only blew
            up later at the call site (e.g. on `.astype`).
    """
    if mode == 'normal01':
        return np.random.normal(0, 1, size=size)
    if mode == 'uniform_signed':
        return np.random.uniform(-1, 1, size=size)
    if mode == 'uniform_unsigned':
        return np.random.uniform(0, 1, size=size)
    raise ValueError("gen_random: unknown mode %r" % (mode,))
21
+
22
+
16
23
class DCGAN (object ):
17
24
def __init__ (self , sess , input_height = 108 , input_width = 108 , crop = True ,
18
25
batch_size = 64 , sample_num = 64 , output_height = 64 , output_width = 64 ,
19
26
y_dim = None , z_dim = 100 , gf_dim = 64 , df_dim = 64 ,
20
27
gfc_dim = 1024 , dfc_dim = 1024 , c_dim = 3 , dataset_name = 'default' ,
21
- input_fname_pattern = '*.jpg' , checkpoint_dir = None , sample_dir = None , data_dir = './data' ):
28
+ max_to_keep = 1 ,
29
+ input_fname_pattern = '*.jpg' , checkpoint_dir = 'ckpts' , sample_dir = 'samples' , out_dir = './out' , data_dir = './data' ):
22
30
"""
23
31
24
32
Args:
@@ -70,6 +78,8 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
70
78
self .input_fname_pattern = input_fname_pattern
71
79
self .checkpoint_dir = checkpoint_dir
72
80
self .data_dir = data_dir
81
+ self .out_dir = out_dir
82
+ self .max_to_keep = max_to_keep
73
83
74
84
if self .dataset_name == 'mnist' :
75
85
self .data_X , self .data_y = self .load_mnist ()
@@ -148,7 +158,7 @@ def sigmoid_cross_entropy_with_logits(x, y):
148
158
self .d_vars = [var for var in t_vars if 'd_' in var .name ]
149
159
self .g_vars = [var for var in t_vars if 'g_' in var .name ]
150
160
151
- self .saver = tf .train .Saver ()
161
+ self .saver = tf .train .Saver (max_to_keep = self . max_to_keep )
152
162
153
163
def train (self , config ):
154
164
d_optim = tf .train .AdamOptimizer (config .learning_rate , beta1 = config .beta1 ) \
@@ -160,13 +170,15 @@ def train(self, config):
160
170
except :
161
171
tf .initialize_all_variables ().run ()
162
172
163
- self .g_sum = merge_summary ([self .z_sum , self .d__sum ,
164
- self .G_sum , self .d_loss_fake_sum , self .g_loss_sum ])
173
+ if config .G_img_sum :
174
+ self .g_sum = merge_summary ([self .z_sum , self .d__sum , self .G_sum , self .d_loss_fake_sum , self .g_loss_sum ])
175
+ else :
176
+ self .g_sum = merge_summary ([self .z_sum , self .d__sum , self .d_loss_fake_sum , self .g_loss_sum ])
165
177
self .d_sum = merge_summary (
166
178
[self .z_sum , self .d_sum , self .d_loss_real_sum , self .d_loss_sum ])
167
- self .writer = SummaryWriter ("./ logs" , self .sess .graph )
179
+ self .writer = SummaryWriter (os . path . join ( self . out_dir , " logs") , self .sess .graph )
168
180
169
- sample_z = np . random . uniform ( - 1 , 1 , size = (self .sample_num , self .z_dim ))
181
+ sample_z = gen_random ( config . z_dist , size = (self .sample_num , self .z_dim ))
170
182
171
183
if config .dataset == 'mnist' :
172
184
sample_inputs = self .data_X [0 :self .sample_num ]
@@ -223,7 +235,7 @@ def train(self, config):
223
235
else :
224
236
batch_images = np .array (batch ).astype (np .float32 )
225
237
226
- batch_z = np . random . uniform ( - 1 , 1 , [config .batch_size , self .z_dim ]) \
238
+ batch_z = gen_random ( config . z_dist , size = [config .batch_size , self .z_dim ]) \
227
239
.astype (np .float32 )
228
240
229
241
if config .dataset == 'mnist' :
@@ -281,12 +293,11 @@ def train(self, config):
281
293
errD_real = self .d_loss_real .eval ({ self .inputs : batch_images })
282
294
errG = self .g_loss .eval ({self .z : batch_z })
283
295
284
- counter += 1
285
- print ("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
286
- % (epoch , config .epoch , idx , batch_idxs ,
296
+ print ("[%8d Epoch:[%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
297
+ % (counter , epoch , config .epoch , idx , batch_idxs ,
287
298
time .time () - start_time , errD_fake + errD_real , errG ))
288
299
289
- if np .mod (counter , 100 ) == 1 :
300
+ if np .mod (counter , config . sample_freq ) == 0 :
290
301
if config .dataset == 'mnist' :
291
302
samples , d_loss , g_loss = self .sess .run (
292
303
[self .sampler , self .d_loss , self .g_loss ],
@@ -297,7 +308,7 @@ def train(self, config):
297
308
}
298
309
)
299
310
save_images (samples , image_manifold_size (samples .shape [0 ]),
300
- './{}/train_{:02d}_{:04d} .png' .format (config .sample_dir , epoch , idx ))
311
+ './{}/train_{:08d} .png' .format (config .sample_dir , counter ))
301
312
print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss , g_loss ))
302
313
else :
303
314
try :
@@ -309,14 +320,16 @@ def train(self, config):
309
320
},
310
321
)
311
322
save_images (samples , image_manifold_size (samples .shape [0 ]),
312
- './{}/train_{:02d}_{:04d} .png' .format (config .sample_dir , epoch , idx ))
323
+ './{}/train_{:08d} .png' .format (config .sample_dir , counter ))
313
324
print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss , g_loss ))
314
325
except :
315
326
print ("one pic error!..." )
316
327
317
- if np .mod (counter , 500 ) == 2 :
328
+ if np .mod (counter , config . ckpt_freq ) == 0 :
318
329
self .save (config .checkpoint_dir , counter )
319
-
330
+
331
+ counter += 1
332
+
320
333
def discriminator (self , image , y = None , reuse = False ):
321
334
with tf .variable_scope ("discriminator" ) as scope :
322
335
if reuse :
@@ -501,28 +514,39 @@ def model_dir(self):
501
514
return "{}_{}_{}_{}" .format (
502
515
self .dataset_name , self .batch_size ,
503
516
self .output_height , self .output_width )
504
-
505
- def save (self , checkpoint_dir , step ):
506
- model_name = "DCGAN.model"
507
- checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
508
517
518
+ def save (self , checkpoint_dir , step , filename = 'model' , ckpt = True , frozen = False ):
519
+ # model_name = "DCGAN.model"
520
+ # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
521
+
522
+ filename += '.b' + str (self .batch_size )
509
523
if not os .path .exists (checkpoint_dir ):
510
524
os .makedirs (checkpoint_dir )
511
525
512
- self .saver .save (self .sess ,
513
- os .path .join (checkpoint_dir , model_name ),
514
- global_step = step )
526
+ if ckpt :
527
+ self .saver .save (self .sess ,
528
+ os .path .join (checkpoint_dir , filename ),
529
+ global_step = step )
530
+
531
+ if frozen :
532
+ tf .train .write_graph (
533
+ tf .graph_util .convert_variables_to_constants (self .sess , self .sess .graph_def , ["generator_1/Tanh" ]),
534
+ checkpoint_dir ,
535
+ '{}-{:06d}_frz.pb' .format (filename , step ),
536
+ as_text = False )
515
537
516
538
def load (self , checkpoint_dir ):
517
- import re
518
- print (" [*] Reading checkpoints..." )
519
- checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
539
+ #import re
540
+ print (" [*] Reading checkpoints..." , checkpoint_dir )
541
+ # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
542
+ # print(" ->", checkpoint_dir)
520
543
521
544
ckpt = tf .train .get_checkpoint_state (checkpoint_dir )
522
545
if ckpt and ckpt .model_checkpoint_path :
523
546
ckpt_name = os .path .basename (ckpt .model_checkpoint_path )
524
547
self .saver .restore (self .sess , os .path .join (checkpoint_dir , ckpt_name ))
525
- counter = int (next (re .finditer ("(\d+)(?!.*\d)" ,ckpt_name )).group (0 ))
548
+ #counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
549
+ counter = int (ckpt_name .split ('-' )[- 1 ])
526
550
print (" [*] Success to read {}" .format (ckpt_name ))
527
551
return True , counter
528
552
else :
0 commit comments