@@ -18,7 +18,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
18
18
batch_size = 64 , sample_num = 64 , output_height = 64 , output_width = 64 ,
19
19
y_dim = None , z_dim = 100 , gf_dim = 64 , df_dim = 64 ,
20
20
gfc_dim = 1024 , dfc_dim = 1024 , c_dim = 3 , dataset_name = 'default' ,
21
- input_fname_pattern = '*.jpg' , checkpoint_dir = None , sample_dir = None ):
21
+ input_fname_pattern = '*.jpg' , checkpoint_dir = None , sample_dir = None , data_dir = './data' ):
22
22
"""
23
23
24
24
Args:
@@ -69,12 +69,13 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
69
69
self .dataset_name = dataset_name
70
70
self .input_fname_pattern = input_fname_pattern
71
71
self .checkpoint_dir = checkpoint_dir
72
+ self .data_dir = data_dir
72
73
73
74
if self .dataset_name == 'mnist' :
74
75
self .data_X , self .data_y = self .load_mnist ()
75
76
self .c_dim = self .data_X [0 ].shape [- 1 ]
76
77
else :
77
- self .data = glob (os .path .join ("./data" , self .dataset_name , self .input_fname_pattern ))
78
+ self .data = glob (os .path .join (self . data_dir , self .dataset_name , self .input_fname_pattern ))
78
79
imreadImg = imread (self .data [0 ])
79
80
if len (imreadImg .shape ) >= 3 : #check if image is a non-grayscale image by checking channel number
80
81
self .c_dim = imread (self .data [0 ]).shape [- 1 ]
@@ -192,7 +193,7 @@ def train(self, config):
192
193
batch_idxs = min (len (self .data_X ), config .train_size ) // config .batch_size
193
194
else :
194
195
self .data = glob (os .path .join (
195
- "./data" , config .dataset , self .input_fname_pattern ))
196
+ config . data_dir , config .dataset , self .input_fname_pattern ))
196
197
batch_idxs = min (len (self .data ), config .train_size ) // config .batch_size
197
198
198
199
for idx in xrange (0 , batch_idxs ):
@@ -451,7 +452,7 @@ def sampler(self, z, y=None):
451
452
return tf .nn .sigmoid (deconv2d (h2 , [self .batch_size , s_h , s_w , self .c_dim ], name = 'g_h3' ))
452
453
453
454
def load_mnist (self ):
454
- data_dir = os .path .join ("./data" , self .dataset_name )
455
+ data_dir = os .path .join (self . data_dir , self .dataset_name )
455
456
456
457
fd = open (os .path .join (data_dir ,'train-images-idx3-ubyte' ))
457
458
loaded = np .fromfile (file = fd ,dtype = np .uint8 )
0 commit comments