Commit 98d2810

Merge pull request #302 from memo/master
A number of extra options
2 parents 5a02f11 + 407ae28 commit 98d2810
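
This merge adds run-management options on top of the original flags: a configurable z dimension and prior (z_dim, z_dist), sampling and checkpoint frequencies (sample_freq, ckpt_freq), a per-run output folder with the full flag set written to FLAGS.json, optional generator image summaries (G_img_sum), and export/freeze switches for saving the generator graph. As an illustration only (paths and values here are invented, not part of the commit), a training run using the new flags might look like:

python main.py --dataset celebA --data_dir $HOME/data --out_dir $HOME/out --train --z_dim 128 --z_dist normal01 --sample_freq 500 --ckpt_freq 1000 --max_to_keep 3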

3 files changed: +134 −54 lines

Diff for: main.py

+66 −22
@@ -1,9 +1,10 @@
 import os
 import scipy.misc
 import numpy as np
+import json
 
 from model import DCGAN
-from utils import pp, visualize, to_json, show_all_variables
+from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp
 
 import tensorflow as tf
 
@@ -19,27 +20,56 @@
 flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
 flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
 flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
-flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
-flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
-flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
+flags.DEFINE_string("data_dir", "./data", "path to datasets [e.g. $HOME/data]")
+flags.DEFINE_string("out_dir", "./out", "Root directory for outputs [e.g. $HOME/out]")
+flags.DEFINE_string("out_name", "", "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []")
+flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]")
+flags.DEFINE_string("sample_dir", "samples", "Folder (under out_root_dir/out_name) to save samples [samples]")
 flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
 flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
 flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
-flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
+flags.DEFINE_boolean("export", False, "True for exporting with new batch size")
+flags.DEFINE_boolean("freeze", False, "True for freezing the graph with new batch size")
+flags.DEFINE_integer("max_to_keep", 1, "maximum number of checkpoints to keep")
+flags.DEFINE_integer("sample_freq", 200, "sample every this many iterations")
+flags.DEFINE_integer("ckpt_freq", 200, "save checkpoint every this many iterations")
+flags.DEFINE_integer("z_dim", 100, "dimensions of z")
+flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or 'uniform_signed'")
+flags.DEFINE_boolean("G_img_sum", False, "Save generator image summaries in log")
+#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
 FLAGS = flags.FLAGS
 
 def main(_):
   pp.pprint(flags.FLAGS.__flags)
+
+  # expand user name and environment variables
+  FLAGS.data_dir = expand_path(FLAGS.data_dir)
+  FLAGS.out_dir = expand_path(FLAGS.out_dir)
+  FLAGS.out_name = expand_path(FLAGS.out_name)
+  FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
+  FLAGS.sample_dir = expand_path(FLAGS.sample_dir)
 
-  if FLAGS.input_width is None:
-    FLAGS.input_width = FLAGS.input_height
-  if FLAGS.output_width is None:
-    FLAGS.output_width = FLAGS.output_height
+  if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
+  if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
+  if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height
 
-  if not os.path.exists(FLAGS.checkpoint_dir):
-    os.makedirs(FLAGS.checkpoint_dir)
-  if not os.path.exists(FLAGS.sample_dir):
-    os.makedirs(FLAGS.sample_dir)
+  # output folders
+  if FLAGS.out_name == "":
+    FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
+    if FLAGS.train:
+      FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size)
+
+  FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
+  FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
+  FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)
+
+  if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir)
+  if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)
+
+  with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
+    flags_dict = {k: FLAGS[k].value for k in FLAGS}
+    json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)
 
   #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
   run_config = tf.ConfigProto()
@@ -56,13 +86,15 @@ def main(_):
         batch_size=FLAGS.batch_size,
         sample_num=FLAGS.batch_size,
         y_dim=10,
-        z_dim=FLAGS.generate_test_images,
+        z_dim=FLAGS.z_dim,
         dataset_name=FLAGS.dataset,
         input_fname_pattern=FLAGS.input_fname_pattern,
         crop=FLAGS.crop,
         checkpoint_dir=FLAGS.checkpoint_dir,
         sample_dir=FLAGS.sample_dir,
-        data_dir=FLAGS.data_dir)
+        data_dir=FLAGS.data_dir,
+        out_dir=FLAGS.out_dir,
+        max_to_keep=FLAGS.max_to_keep)
   else:
     dcgan = DCGAN(
         sess,
@@ -72,22 +104,25 @@ def main(_):
         output_height=FLAGS.output_height,
         batch_size=FLAGS.batch_size,
         sample_num=FLAGS.batch_size,
-        z_dim=FLAGS.generate_test_images,
+        z_dim=FLAGS.z_dim,
         dataset_name=FLAGS.dataset,
         input_fname_pattern=FLAGS.input_fname_pattern,
         crop=FLAGS.crop,
         checkpoint_dir=FLAGS.checkpoint_dir,
         sample_dir=FLAGS.sample_dir,
-        data_dir=FLAGS.data_dir)
+        data_dir=FLAGS.data_dir,
+        out_dir=FLAGS.out_dir,
+        max_to_keep=FLAGS.max_to_keep)
 
   show_all_variables()
 
   if FLAGS.train:
     dcgan.train(FLAGS)
   else:
-    if not dcgan.load(FLAGS.checkpoint_dir)[0]:
-      raise Exception("[!] Train a model first, then run test mode")
-
+    load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
+    if not load_success:
+      raise Exception("Checkpoint not found in " + FLAGS.checkpoint_dir)
+
 
   # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
   #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
@@ -96,8 +131,17 @@ def main(_):
   #                 [dcgan.h4_w, dcgan.h4_b, None])
 
   # Below is codes for visualization
-  OPTION = 1
-  visualize(sess, dcgan, FLAGS, OPTION)
+  if FLAGS.export:
+    export_dir = os.path.join(FLAGS.checkpoint_dir, 'export_b'+str(FLAGS.batch_size))
+    dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)
+
+  if FLAGS.freeze:
+    export_dir = os.path.join(FLAGS.checkpoint_dir, 'frozen_b'+str(FLAGS.batch_size))
+    dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)
+
+  if FLAGS.visualize:
+    OPTION = 1
+    visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)
 
 if __name__ == '__main__':
   tf.app.run()
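
Taken together, the main.py changes route everything a run produces into one folder per run. A minimal sketch of the resulting layout, assuming the defaults above (the run name below is invented; the real one comes from timestamp() and the format strings in the diff, with the x/z/y/b suffix appended for training runs):

import os

out_dir = os.path.expanduser('~/out')                 # --out_dir after expand_path
out_name = '20180101.120000 - data - celebA'          # auto-generated when --out_name is blank
run_dir = os.path.join(out_dir, out_name)

checkpoint_dir = os.path.join(run_dir, 'checkpoint')  # checkpoints, plus export_b*/frozen_b* subfolders
sample_dir = os.path.join(run_dir, 'samples')         # periodic train_########.png sample grids
flags_json = os.path.join(run_dir, 'FLAGS.json')      # every flag value, for reproducibility
logs_dir = os.path.join(run_dir, 'logs')              # TensorBoard summaries (see model.py diff)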

Diff for: model.py

+50 −26
@@ -1,4 +1,5 @@
 from __future__ import division
+from __future__ import print_function
 import os
 import time
 import math
@@ -13,12 +14,19 @@
 def conv_out_size_same(size, stride):
   return int(math.ceil(float(size) / float(stride)))
 
+def gen_random(mode, size):
+  if mode=='normal01': return np.random.normal(0,1,size=size)
+  if mode=='uniform_signed': return np.random.uniform(-1,1,size=size)
+  if mode=='uniform_unsigned': return np.random.uniform(0,1,size=size)
+
+
 class DCGAN(object):
   def __init__(self, sess, input_height=108, input_width=108, crop=True,
          batch_size=64, sample_num = 64, output_height=64, output_width=64,
          y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
          gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
-         input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
+         max_to_keep=1,
+         input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):
     """
 
     Args:
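
One caveat with the new gen_random (an editorial note, not part of the commit): it falls through and returns None for an unrecognized mode string, which only surfaces later inside numpy operations. A stricter sketch that fails fast, assuming numpy is imported as np as in model.py:

def gen_random_strict(mode, size):
  # Same three priors as gen_random above, but unknown modes raise immediately.
  if mode == 'normal01': return np.random.normal(0, 1, size=size)
  if mode == 'uniform_signed': return np.random.uniform(-1, 1, size=size)
  if mode == 'uniform_unsigned': return np.random.uniform(0, 1, size=size)
  raise ValueError('unknown z_dist: {}'.format(mode))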
@@ -70,6 +78,8 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
     self.input_fname_pattern = input_fname_pattern
     self.checkpoint_dir = checkpoint_dir
     self.data_dir = data_dir
+    self.out_dir = out_dir
+    self.max_to_keep = max_to_keep
 
     if self.dataset_name == 'mnist':
       self.data_X, self.data_y = self.load_mnist()
@@ -148,7 +158,7 @@ def sigmoid_cross_entropy_with_logits(x, y):
     self.d_vars = [var for var in t_vars if 'd_' in var.name]
     self.g_vars = [var for var in t_vars if 'g_' in var.name]
 
-    self.saver = tf.train.Saver()
+    self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)
 
   def train(self, config):
     d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
@@ -160,13 +170,15 @@ def train(self, config):
     except:
       tf.initialize_all_variables().run()
 
-    self.g_sum = merge_summary([self.z_sum, self.d__sum,
-      self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
+    if config.G_img_sum:
+      self.g_sum = merge_summary([self.z_sum, self.d__sum, self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
+    else:
+      self.g_sum = merge_summary([self.z_sum, self.d__sum, self.d_loss_fake_sum, self.g_loss_sum])
     self.d_sum = merge_summary(
         [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
-    self.writer = SummaryWriter("./logs", self.sess.graph)
+    self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"), self.sess.graph)
 
-    sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim))
+    sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim))
 
     if config.dataset == 'mnist':
       sample_inputs = self.data_X[0:self.sample_num]
@@ -223,7 +235,7 @@ def train(self, config):
       else:
         batch_images = np.array(batch).astype(np.float32)
 
-      batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
+      batch_z = gen_random(config.z_dist, size=[config.batch_size, self.z_dim]) \
             .astype(np.float32)
 
       if config.dataset == 'mnist':
@@ -281,12 +293,11 @@ def train(self, config):
       errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
       errG = self.g_loss.eval({self.z: batch_z})
 
-      counter += 1
-      print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
-        % (epoch, config.epoch, idx, batch_idxs,
+      print("[%8d] Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
+        % (counter, epoch, config.epoch, idx, batch_idxs,
           time.time() - start_time, errD_fake+errD_real, errG))
 
-      if np.mod(counter, 100) == 1:
+      if np.mod(counter, config.sample_freq) == 0:
         if config.dataset == 'mnist':
           samples, d_loss, g_loss = self.sess.run(
             [self.sampler, self.d_loss, self.g_loss],
@@ -297,7 +308,7 @@ def train(self, config):
             }
           )
           save_images(samples, image_manifold_size(samples.shape[0]),
-                './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
+                './{}/train_{:08d}.png'.format(config.sample_dir, counter))
           print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
         else:
           try:
@@ -309,14 +320,16 @@ def train(self, config):
             },
           )
           save_images(samples, image_manifold_size(samples.shape[0]),
-                './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
+                './{}/train_{:08d}.png'.format(config.sample_dir, counter))
           print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
         except:
           print("one pic error!...")
 
-        if np.mod(counter, 500) == 2:
+        if np.mod(counter, config.ckpt_freq) == 0:
           self.save(config.checkpoint_dir, counter)
-
+
+        counter += 1
+
   def discriminator(self, image, y=None, reuse=False):
     with tf.variable_scope("discriminator") as scope:
       if reuse:
501514
return "{}_{}_{}_{}".format(
502515
self.dataset_name, self.batch_size,
503516
self.output_height, self.output_width)
504-
505-
def save(self, checkpoint_dir, step):
506-
model_name = "DCGAN.model"
507-
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
508517

518+
def save(self, checkpoint_dir, step, filename='model', ckpt=True, frozen=False):
519+
# model_name = "DCGAN.model"
520+
# checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
521+
522+
filename += '.b' + str(self.batch_size)
509523
if not os.path.exists(checkpoint_dir):
510524
os.makedirs(checkpoint_dir)
511525

512-
self.saver.save(self.sess,
513-
os.path.join(checkpoint_dir, model_name),
514-
global_step=step)
526+
if ckpt:
527+
self.saver.save(self.sess,
528+
os.path.join(checkpoint_dir, filename),
529+
global_step=step)
530+
531+
if frozen:
532+
tf.train.write_graph(
533+
tf.graph_util.convert_variables_to_constants(self.sess, self.sess.graph_def, ["generator_1/Tanh"]),
534+
checkpoint_dir,
535+
'{}-{:06d}_frz.pb'.format(filename, step),
536+
as_text=False)
515537

516538
def load(self, checkpoint_dir):
517-
import re
518-
print(" [*] Reading checkpoints...")
519-
checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
539+
#import re
540+
print(" [*] Reading checkpoints...", checkpoint_dir)
541+
# checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
542+
# print(" ->", checkpoint_dir)
520543

521544
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
522545
if ckpt and ckpt.model_checkpoint_path:
523546
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
524547
self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
525-
counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
548+
#counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
549+
counter = int(ckpt_name.split('-')[-1])
526550
print(" [*] Success to read {}".format(ckpt_name))
527551
return True, counter
528552
else:

Diff for: utils.py

+18 −6
@@ -9,6 +9,9 @@
 import scipy.misc
 import cv2
 import numpy as np
+import os
+import time
+import datetime
 from time import gmtime, strftime
 from six.moves import xrange
 
@@ -19,6 +22,15 @@
 
 get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
 
+
+def expand_path(path):
+  return os.path.expanduser(os.path.expandvars(path))
+
+def timestamp(s='%Y%m%d.%H%M%S', ts=None):
+  if not ts: ts = time.time()
+  st = datetime.datetime.fromtimestamp(ts).strftime(s)
+  return st
+
 def show_all_variables():
   model_vars = tf.trainable_variables()
   slim.model_analyzer.analyze_vars(model_vars, print_info=True)
@@ -174,12 +186,12 @@ def make_frame(t):
   clip = mpy.VideoClip(make_frame, duration=duration)
   clip.write_gif(fname, fps = len(images) / duration)
 
-def visualize(sess, dcgan, config, option):
+def visualize(sess, dcgan, config, option, sample_dir='samples'):
   image_frame_dim = int(math.ceil(config.batch_size**.5))
   if option == 0:
     z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
     samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
-    save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
+    save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime())))
   elif option == 1:
     values = np.arange(0, 1, 1./config.batch_size)
     for idx in xrange(dcgan.z_dim):
@@ -197,7 +209,7 @@ def visualize(sess, dcgan, config, option):
       else:
         samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
 
-      save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx))
+      save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_arange_%s.png' % (idx)))
   elif option == 2:
     values = np.arange(0, 1, 1./config.batch_size)
     for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
@@ -220,7 +232,7 @@ def visualize(sess, dcgan, config, option):
       try:
         make_gif(samples, './samples/test_gif_%s.gif' % (idx))
       except:
-        save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
+        save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime())))
   elif option == 3:
     values = np.arange(0, 1, 1./config.batch_size)
     for idx in xrange(dcgan.z_dim):
@@ -230,7 +242,7 @@ def visualize(sess, dcgan, config, option):
         z[idx] = values[kdx]
 
       samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
-      make_gif(samples, './samples/test_gif_%s.gif' % (idx))
+      make_gif(samples, os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))
   elif option == 4:
     image_set = []
     values = np.arange(0, 1, 1./config.batch_size)
@@ -241,7 +253,7 @@ def visualize(sess, dcgan, config, option):
       for kdx, z in enumerate(z_sample): z[idx] = values[kdx]
 
       image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
-      make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
+      make_gif(image_set[-1], os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))
 
     new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
         for idx in range(64) + range(63, -1, -1)]
