Skip to content
This repository was archived by the owner on Jul 18, 2024. It is now read-only.

Commit 9343c65

Browse files
author
ton
committed
Working
1 parent b4c3b65 commit 9343c65

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

Diff for: classify.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
'image_file', '', 'The name of the image to run an inference.')
3232

3333
tf.app.flags.DEFINE_integer(
34-
'batch_size', 100, 'The number of samples in each batch.')
34+
'batch_size', 1, 'The number of samples in each batch.')
3535

3636
tf.app.flags.DEFINE_integer(
37-
'max_num_batches', None,
37+
'max_num_batches', 1,
3838
'Max number of batches to evaluate by default use all.')
3939

4040
tf.app.flags.DEFINE_string(
@@ -53,7 +53,7 @@
5353
'The number of threads used to create the batches.')
5454

5555
tf.app.flags.DEFINE_string(
56-
'dataset_name', 'imagenet', 'The name of the dataset to load.')
56+
'dataset_name', 'arts', 'The name of the dataset to load.')
5757

5858
tf.app.flags.DEFINE_string(
5959
'dataset_split_name', 'test', 'The name of the train/test split.')
@@ -80,7 +80,7 @@
8080
'If left as None, then moving averages are not used.')
8181

8282
tf.app.flags.DEFINE_integer(
83-
'eval_image_size', 100, 'Eval image size')
83+
'eval_image_size', 299, 'Eval image size')
8484

8585
FLAGS = tf.app.flags.FLAGS
8686

@@ -116,7 +116,7 @@ def main(_):
116116
is_training=False)
117117

118118
eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
119-
image_data_0 = tf.placeholder(tf.string, [])
119+
image_data_0 = tf.gfile.FastGFile(FLAGS.image_file, 'rb').read()
120120
image_0 = tf.image.decode_jpeg(image_data_0, channels=3)
121121
image = image_preprocessing_fn(image_0, eval_image_size, eval_image_size)
122122
label = 0
@@ -139,7 +139,7 @@ def main(_):
139139
variables_to_restore = slim.get_variables_to_restore()
140140

141141
predictions = tf.argmax(logits, 1)
142-
142+
143143
num_batches = 1
144144

145145
if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
@@ -149,14 +149,12 @@ def main(_):
149149

150150
tf.logging.info('Restoring model checkpoint %s' % checkpoint_path)
151151

152-
raw_image_data = tf.gfile.FastGFile(FLAGS.image_file, 'rb').read()
153152
answer = slim.evaluation.evaluate_once(
154153
master=FLAGS.master,
155154
checkpoint_path=checkpoint_path,
156155
logdir=FLAGS.eval_dir,
157156
num_evals=num_batches,
158157
final_op=predictions,
159-
final_op_feed_dict={image_data_0: raw_image_data},
160158
variables_to_restore=variables_to_restore)
161159

162160
label_name = dataset.labels_to_names.get(answer[0])

0 commit comments

Comments
 (0)