Commit 8664afd

clean the codes and add width and height

1 parent 00cfe3c

5 files changed: +118 −86 lines

.gitignore
+1

@@ -12,6 +12,7 @@ checkpoint
 
 # trash
 .dropbox
+.DS_Store
 
 # Created by https://www.gitignore.io/api/python,vim
 

README.md
+3 −3

@@ -34,8 +34,8 @@ First, download dataset with:
 
 To train a model with downloaded dataset:
 
-    $ python main.py --dataset mnist --is_train True
-    $ python main.py --dataset celebA --is_train True --is_crop True
+    $ python main.py --dataset mnist --is_train
+    $ python main.py --dataset celebA --is_train --is_crop True
 
 To test with an existing model:
 
@@ -46,7 +46,7 @@ Or, you can use your own dataset (without central crop) by:
 
     $ mkdir data/DATASET_NAME
     ... add images to data/DATASET_NAME ...
-    $ python main.py --dataset DATASET_NAME --is_train True
+    $ python main.py --dataset DATASET_NAME --is_train
     $ python main.py --dataset DATASET_NAME
 
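Note: the README now passes `--is_train` as a bare switch. This assumes `is_train` is declared as a boolean flag; a minimal sketch of that declaration (hypothetical, since the flag definition itself is not part of this diff):

    import tensorflow as tf

    flags = tf.app.flags
    flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
    FLAGS = flags.FLAGS
    # passing the bare `--is_train` switch on the command line sets FLAGS.is_train to True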

main.py
+30 −22

@@ -15,7 +15,8 @@
 flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
 flags.DEFINE_integer("image_width", 108, "The size of image to use (will be center cropped) [108]")
 flags.DEFINE_integer("image_height", 108, "The size of image to use (will be center cropped) [108]")
-flags.DEFINE_integer("output_size", 64, "The size of the output images to produce [64]")
+flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
+flags.DEFINE_integer("output_width", 64, "The size of the output images to produce [64]")
 flags.DEFINE_integer("c_dim", 3, "Dimension of image color. [3]")
 flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
 flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
@@ -35,32 +36,39 @@ def main(_):
 
   with tf.Session() as sess:
     if FLAGS.dataset == 'mnist':
-      dcgan = DCGAN(sess,
-                    image_width=FLAGS.image_width,
-                    image_height=FLAGS.image_height,
-                    batch_size=FLAGS.batch_size,
-                    y_dim=10,
-                    output_size=28,
-                    c_dim=1,
-                    dataset_name=FLAGS.dataset,
-                    is_crop=FLAGS.is_crop,
-                    checkpoint_dir=FLAGS.checkpoint_dir,
-                    sample_dir=FLAGS.sample_dir)
+      dcgan = DCGAN(
+          sess,
+          image_width=FLAGS.image_width,
+          image_height=FLAGS.image_height,
+          output_width=FLAGS.output_width,
+          output_height=FLAGS.output_height,
+          batch_size=FLAGS.batch_size,
+          y_dim=10,
+          c_dim=1,
+          dataset_name=FLAGS.dataset,
+          is_crop=FLAGS.is_crop,
+          checkpoint_dir=FLAGS.checkpoint_dir,
+          sample_dir=FLAGS.sample_dir)
     else:
-      dcgan = DCGAN(sess,
-                    image_size=FLAGS.image_size,
-                    batch_size=FLAGS.batch_size,
-                    output_size=FLAGS.output_size,
-                    c_dim=FLAGS.c_dim,
-                    dataset_name=FLAGS.dataset,
-                    is_crop=FLAGS.is_crop,
-                    checkpoint_dir=FLAGS.checkpoint_dir,
-                    sample_dir=FLAGS.sample_dir)
+      dcgan = DCGAN(
+          sess,
+          image_width=FLAGS.image_width,
+          image_height=FLAGS.image_height,
+          output_width=FLAGS.output_width,
+          output_height=FLAGS.output_height,
+          batch_size=FLAGS.batch_size,
+          c_dim=FLAGS.c_dim,
+          dataset_name=FLAGS.dataset,
+          is_crop=FLAGS.is_crop,
+          checkpoint_dir=FLAGS.checkpoint_dir,
+          sample_dir=FLAGS.sample_dir)
 
     if FLAGS.is_train:
       dcgan.train(FLAGS)
     else:
-      dcgan.load(FLAGS.checkpoint_dir)
+      if not dcgan.load(FLAGS.checkpoint_dir):
+        raise Exception("[!] Train a model first, then run test mode")
+
 
     # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
     #                 [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
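With `output_size` split into `output_height` and `output_width`, non-square outputs become expressible from the command line. A hypothetical invocation using the flags defined above (sizes chosen only for illustration):

    $ python main.py --dataset celebA --is_train --is_crop True --output_height 64 --output_width 48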

model.py
+82 −60

@@ -73,31 +73,37 @@ def build_model(self):
     if self.y_dim:
       self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')
 
-    image_dims = [self.output_height, self.output_width, self.c_dim]
+    image_dims = [None, None, self.c_dim]
 
     self.inputs = tf.placeholder(
-      tf.float32, [self.batch_size] + image_dims name='real_images')
+      tf.float32, [self.batch_size] + image_dims, name='real_images')
     self.sample_inputs = tf.placeholder(
       tf.float32, [self.sample_num] + image_dims, name='sample_inputs')
 
-    inputs = tf.image.resize_images(
-      self.inputs, [self.output_height, self.output_width])
-    sample_inputs = tf.image.resize_images(
-      self.sample_inputs, [self.output_height, self.output_width])
+    if not self.is_crop:
+      inputs = tf.image.resize_images(
+        self.inputs, [self.output_height, self.output_width])
+      sample_inputs = tf.image.resize_images(
+        self.sample_inputs, [self.output_height, self.output_width])
+    else:
+      inputs = self.inputs
+      sample_inputs = self.sample_inputs
 
     self.z = tf.placeholder(
       tf.float32, [None, self.z_dim], name='z')
     self.z_sum = histogram_summary("z", self.z)
 
     if self.y_dim:
       self.G = self.generator(self.z, self.y)
-      self.D, self.D_logits = self.discriminator(self.inputs, self.y, reuse=False)
+      self.D, self.D_logits = \
+        self.discriminator(inputs, self.y, reuse=False)
 
       self.sampler = self.sampler(self.z, self.y)
-      self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True)
+      self.D_, self.D_logits_ = \
+        self.discriminator(self.G, self.y, reuse=True)
     else:
       self.G = self.generator(self.z)
-      self.D, self.D_logits = self.discriminator(self.inputs)
+      self.D, self.D_logits = self.discriminator(inputs)
 
       self.sampler = self.sampler(self.z)
       self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
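Setting `image_dims = [None, None, self.c_dim]` leaves the placeholder's spatial shape open, and `tf.image.resize_images` fixes it inside the graph only when cropping is disabled. A standalone sketch of that pattern (toy values, TF 1.x-era API):

    import tensorflow as tf

    batch_size, c_dim = 64, 3
    # height and width left as None, so images of any resolution can be fed
    x = tf.placeholder(tf.float32, [batch_size, None, None, c_dim])
    # normalized to a fixed training resolution inside the graph
    x_resized = tf.image.resize_images(x, [64, 64])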
@@ -144,10 +150,9 @@ def train(self, config):
     g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
               .minimize(self.g_loss, var_list=self.g_vars)
     try:
-      tf.initialize_all_variables().run()
+      tf.global_variables_initializer().run()
     except:
-      init_op = tf.global_variables_initializer()
-      self.sess.run(init_op)
+      tf.initialize_all_variables().run()
 
     self.g_sum = merge_summary([self.z_sum, self.d__sum,
       self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
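The try/except order is flipped so the current initializer name is attempted first. An equivalent standalone shim (a sketch: the commit uses a bare `except`; the narrower `AttributeError` here is an assumption, and an open session `sess` is assumed):

    try:
        init_op = tf.global_variables_initializer()  # TF >= 0.12
    except AttributeError:
        init_op = tf.initialize_all_variables()      # deprecated pre-0.12 name
    sess.run(init_op)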
@@ -198,8 +203,8 @@ def train(self, config):
           batch_files = data[idx*config.batch_size:(idx+1)*config.batch_size]
           batch = [
               get_image(batch_file,
-                        self.image_height,
-                        self.image_width,
+                        image_height=self.image_height,
+                        image_width=self.image_width,
                         resize_height=self.output_height,
                         resize_width=self.output_width,
                         is_crop=self.is_crop,
@@ -263,8 +268,8 @@ def train(self, config):
             feed_dict={ self.z: batch_z })
           self.writer.add_summary(summary_str, counter)
 
-          errD_fake = self.d_loss_fake.eval({self.z: batch_z})
-          errD_real = self.d_loss_real.eval({self.inputs: batch_images})
+          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
+          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})
 
          counter += 1
@@ -325,10 +330,10 @@ def discriminator(self, image, y=None, reuse=False):
 
       h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv')))
       h1 = tf.reshape(h1, [self.batch_size, -1])
-      h1 = tf.concat(1, [h1, y])
+      h1 = tf.concat_v2([h1, y], 1)
 
       h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin')))
-      h2 = tf.concat(1, [h2, y])
+      h2 = tf.concat_v2([h2, y], 1)
 
       h3 = linear(h2, 1, 'd_h3_lin')

@@ -337,100 +342,114 @@ def discriminator(self, image, y=None, reuse=False):
   def generator(self, z, y=None):
     with tf.variable_scope("generator") as scope:
       if not self.y_dim:
-        s = self.output_size
-        s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4, s_h8, s_h16 = \
+            int(s_h/2), int(s_h/4), int(s_h/8), int(s_h/16)
+        s_w2, s_w4, s_w8, s_w16 = \
+            int(s_w/2), int(s_w/4), int(s_w/8), int(s_w/16)
 
         # project `z` and reshape
-        self.z_, self.h0_w, self.h0_b = linear(z, self.gf_dim*8*s16*s16, 'g_h0_lin', with_w=True)
+        self.z_, self.h0_w, self.h0_b = linear(
+            z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True)
 
-        self.h0 = tf.reshape(self.z_, [-1, s16, s16, self.gf_dim * 8])
+        self.h0 = tf.reshape(
+            self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
         h0 = tf.nn.relu(self.g_bn0(self.h0))
 
-        self.h1, self.h1_w, self.h1_b = deconv2d(h0,
-            [self.batch_size, s8, s8, self.gf_dim*4], name='g_h1', with_w=True)
+        self.h1, self.h1_w, self.h1_b = deconv2d(
+            h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True)
         h1 = tf.nn.relu(self.g_bn1(self.h1))
 
-        h2, self.h2_w, self.h2_b = deconv2d(h1,
-            [self.batch_size, s4, s4, self.gf_dim*2], name='g_h2', with_w=True)
+        h2, self.h2_w, self.h2_b = deconv2d(
+            h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True)
         h2 = tf.nn.relu(self.g_bn2(h2))
 
-        h3, self.h3_w, self.h3_b = deconv2d(h2,
-            [self.batch_size, s2, s2, self.gf_dim*1], name='g_h3', with_w=True)
+        h3, self.h3_w, self.h3_b = deconv2d(
+            h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True)
         h3 = tf.nn.relu(self.g_bn3(h3))
 
-        h4, self.h4_w, self.h4_b = deconv2d(h3,
-            [self.batch_size, s, s, self.c_dim], name='g_h4', with_w=True)
+        h4, self.h4_w, self.h4_b = deconv2d(
+            h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)
 
         return tf.nn.tanh(h4)
       else:
-        s = self.output_size
-        s2, s4 = int(s/2), int(s/4)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4 = int(s_h/2), int(s_h/4)
+        s_w2, s_w4 = int(s_w/2), int(s_w/4)
 
         # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
         yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
-        z = tf.concat(1, [z, y])
+        z = tf.concat_v2([z, y], 1)
 
-        h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
-        h0 = tf.concat(1, [h0, y])
+        h0 = tf.nn.relu(
+            self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
+        h0 = tf.concat_v2([h0, y], 1)
 
-        h1 = tf.nn.relu(self.g_bn1(linear(h0, self.gf_dim*2*s4*s4, 'g_h1_lin')))
-        h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])
+        h1 = tf.nn.relu(self.g_bn1(
+            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin')))
+        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
 
         h1 = conv_cond_concat(h1, yb)
 
         h2 = tf.nn.relu(self.g_bn2(deconv2d(h1,
-            [self.batch_size, s2, s2, self.gf_dim * 2], name='g_h2')))
+            [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2')))
         h2 = conv_cond_concat(h2, yb)
 
         return tf.nn.sigmoid(
-            deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g_h3'))
+            deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))
 
   def sampler(self, z, y=None):
     with tf.variable_scope("generator") as scope:
       scope.reuse_variables()
 
       if not self.y_dim:
 
-        s = self.output_size
-        s2, s4, s8, s16 = int(s/2), int(s/4), int(s/8), int(s/16)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4, s_h8, s_h16 = \
+            int(s_h/2), int(s_h/4), int(s_h/8), int(s_h/16)
+        s_w2, s_w4, s_w8, s_w16 = \
+            int(s_w/2), int(s_w/4), int(s_w/8), int(s_w/16)
 
         # project `z` and reshape
-        h0 = tf.reshape(linear(z, self.gf_dim*8*s16*s16, 'g_h0_lin'),
-                        [-1, s16, s16, self.gf_dim * 8])
+        h0 = tf.reshape(
+            linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'),
+            [-1, s_h16, s_w16, self.gf_dim * 8])
         h0 = tf.nn.relu(self.g_bn0(h0, train=False))
 
-        h1 = deconv2d(h0, [self.batch_size, s8, s8, self.gf_dim*4], name='g_h1')
+        h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1')
         h1 = tf.nn.relu(self.g_bn1(h1, train=False))
 
-        h2 = deconv2d(h1, [self.batch_size, s4, s4, self.gf_dim*2], name='g_h2')
+        h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2')
         h2 = tf.nn.relu(self.g_bn2(h2, train=False))
 
-        h3 = deconv2d(h2, [self.batch_size, s2, s2, self.gf_dim*1], name='g_h3')
+        h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3')
         h3 = tf.nn.relu(self.g_bn3(h3, train=False))
 
-        h4 = deconv2d(h3, [self.batch_size, s, s, self.c_dim], name='g_h4')
+        h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')
 
         return tf.nn.tanh(h4)
       else:
-        s = self.output_size
-        s2, s4 = int(s/2), int(s/4)
+        s_h, s_w = self.output_height, self.output_width
+        s_h2, s_h4 = int(s_h/2), int(s_h/4)
+        s_w2, s_w4 = int(s_w/2), int(s_w/4)
 
         # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
         yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
-        z = tf.concat(1, [z, y])
+        z = tf.concat_v2([z, y], 1)
 
         h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin')))
-        h0 = tf.concat(1, [h0, y])
+        h0 = tf.concat_v2([h0, y], 1)
 
-        h1 = tf.nn.relu(self.g_bn1(linear(h0, self.gf_dim*2*s4*s4, 'g_h1_lin'), train=False))
-        h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])
+        h1 = tf.nn.relu(self.g_bn1(
+            linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False))
+        h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2])
         h1 = conv_cond_concat(h1, yb)
 
         h2 = tf.nn.relu(self.g_bn2(
-            deconv2d(h1, [self.batch_size, s2, s2, self.gf_dim * 2], name='g_h2'), train=False))
+            deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False))
         h2 = conv_cond_concat(h2, yb)
 
-        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g_h3'))
+        return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))
 
   def load_mnist(self):
     data_dir = os.path.join("./data", self.dataset_name)
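Replacing the single `s = self.output_size` with per-axis sizes gives each deconv stage its own height and width schedule. A worked example via a hypothetical helper (the values follow directly from the `int(s/2) ... int(s/16)` arithmetic above):

    def deconv_sizes(size):
        # mirrors the s_h2..s_h16 / s_w2..s_w16 computation in generator/sampler
        return [int(size / d) for d in (2, 4, 8, 16)]

    print(deconv_sizes(64))  # [32, 16, 8, 4]   e.g. output_height=64
    print(deconv_sizes(48))  # [24, 12, 6, 3]   e.g. output_width=48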
@@ -468,11 +487,16 @@ def load_mnist(self):
       y_vec[i,y[i]] = 1.0
 
     return X/255.,y_vec
+
+  @property
+  def model_dir(self):
+    return "{}_{}_{}_{}".format(
+        self.dataset_name, self.batch_size,
+        self.output_height, self.output_width)
 
   def save(self, checkpoint_dir, step):
     model_name = "DCGAN.model"
-    model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
-    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
+    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
 
     if not os.path.exists(checkpoint_dir):
       os.makedirs(checkpoint_dir)
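The new `model_dir` property folds both output dimensions into the checkpoint path, so `save` and `load` resolve the same directory from one place. For example, with the defaults above (values for illustration):

    # dataset_name="celebA", batch_size=64, output_height=64, output_width=64
    "{}_{}_{}_{}".format("celebA", 64, 64, 64)  # -> "celebA_64_64_64"
    # checkpoints land under checkpoint/celebA_64_64_64/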
@@ -483,9 +507,7 @@ def save(self, checkpoint_dir, step):
 
   def load(self, checkpoint_dir):
     print(" [*] Reading checkpoints...")
-
-    model_dir = "%s_%s_%s" % (self.dataset_name, self.batch_size, self.output_size)
-    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
+    checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
 
     ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
     if ckpt and ckpt.model_checkpoint_path:

ops.py
+2 −1

@@ -39,7 +39,8 @@ def conv_cond_concat(x, y):
   """Concatenate conditioning vector on feature map axis."""
   x_shapes = x.get_shape()
   y_shapes = y.get_shape()
-  return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])])
+  return tf.concat_v2([
+      x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
 
 def conv2d(input_, output_dim,
            k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
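`tf.concat_v2` reverses the argument order of the old `tf.concat(concat_dim, values)`, taking the values first and the axis last; this hunk and the matching ones in model.py are that migration. A toy comparison (assuming a TF 0.12-era install where both names exist):

    a = tf.zeros([2, 3])
    b = tf.ones([2, 5])
    old = tf.concat(1, [a, b])     # pre-0.12 signature: axis first
    new = tf.concat_v2([a, b], 1)  # transitional name: values first, axis last
    # both produce a tensor of shape [2, 8]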
