Skip to content

Commit 60aa97b

Browse files
authored
Merge pull request #277 from genekogan/master
Make dataset root directory a flag
2 parents ddb7fc2 + 8fd1c70 commit 60aa97b

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

README.md

+8
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ Or, you can use your own dataset (without central crop) by:
5050
$ # example
5151
$ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train
5252

53+
If your dataset is located in a different root directory:
54+
55+
$ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR --train
56+
$ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR
57+
$ # example
58+
$ python main.py --dataset=eyes --data_dir ../datasets/ --input_fname_pattern="*_cropped.png" --train
59+
60+
5361
## Results
5462

5563
![result](assets/training.gif)

main.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
2121
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
2222
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
23+
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
2324
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
2425
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
2526
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
@@ -60,7 +61,8 @@ def main(_):
6061
input_fname_pattern=FLAGS.input_fname_pattern,
6162
crop=FLAGS.crop,
6263
checkpoint_dir=FLAGS.checkpoint_dir,
63-
sample_dir=FLAGS.sample_dir)
64+
sample_dir=FLAGS.sample_dir,
65+
data_dir=FLAGS.data_dir)
6466
else:
6567
dcgan = DCGAN(
6668
sess,
@@ -75,7 +77,8 @@ def main(_):
7577
input_fname_pattern=FLAGS.input_fname_pattern,
7678
crop=FLAGS.crop,
7779
checkpoint_dir=FLAGS.checkpoint_dir,
78-
sample_dir=FLAGS.sample_dir)
80+
sample_dir=FLAGS.sample_dir,
81+
data_dir=FLAGS.data_dir)
7982

8083
show_all_variables()
8184

model.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
1818
batch_size=64, sample_num = 64, output_height=64, output_width=64,
1919
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
2020
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
21-
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None):
21+
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
2222
"""
2323
2424
Args:
@@ -69,12 +69,13 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
6969
self.dataset_name = dataset_name
7070
self.input_fname_pattern = input_fname_pattern
7171
self.checkpoint_dir = checkpoint_dir
72+
self.data_dir = data_dir
7273

7374
if self.dataset_name == 'mnist':
7475
self.data_X, self.data_y = self.load_mnist()
7576
self.c_dim = self.data_X[0].shape[-1]
7677
else:
77-
self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))
78+
self.data = glob(os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern))
7879
imreadImg = imread(self.data[0])
7980
if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel number
8081
self.c_dim = imread(self.data[0]).shape[-1]
@@ -192,7 +193,7 @@ def train(self, config):
192193
batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size
193194
else:
194195
self.data = glob(os.path.join(
195-
"./data", config.dataset, self.input_fname_pattern))
196+
config.data_dir, config.dataset, self.input_fname_pattern))
196197
batch_idxs = min(len(self.data), config.train_size) // config.batch_size
197198

198199
for idx in xrange(0, batch_idxs):
@@ -451,7 +452,7 @@ def sampler(self, z, y=None):
451452
return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))
452453

453454
def load_mnist(self):
454-
data_dir = os.path.join("./data", self.dataset_name)
455+
data_dir = os.path.join(self.data_dir, self.dataset_name)
455456

456457
fd = open(os.path.join(data_dir,'train-images-idx3-ubyte'))
457458
loaded = np.fromfile(file=fd,dtype=np.uint8)

0 commit comments

Comments
 (0)