
Commit 364f96d

Merge pull request tensorflow#3113 from asimshankar/mnist

[mnist]: Training and evaluation with TPUs.

2 parents: aeb8cfc + 6daaec1

File tree: 2 files changed (+142, -2 lines)


official/mnist/dataset.py (2 additions, 2 deletions)

@@ -33,7 +33,7 @@ def read32(bytestream):
 
 def check_image_file_header(filename):
   """Validate that filename corresponds to images for the MNIST dataset."""
-  with open(filename) as f:
+  with tf.gfile.Open(filename) as f:
     magic = read32(f)
     num_images = read32(f)
     rows = read32(f)

@@ -49,7 +49,7 @@ def check_image_file_header(filename):
 
 def check_labels_file_header(filename):
   """Validate that filename corresponds to labels for the MNIST dataset."""
-  with open(filename) as f:
+  with tf.gfile.Open(filename) as f:
     magic = read32(f)
     num_items = read32(f)
     if magic != 2049:
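
The only functional change in dataset.py is replacing the builtin open with tf.gfile.Open, so the MNIST header checks work on any filesystem TensorFlow supports, in particular Google Cloud Storage, where data for TPU jobs typically lives. A minimal sketch of what this enables (the gs:// bucket path is a hypothetical placeholder):

import tensorflow as tf

# The builtin open() only understands local paths; tf.gfile.Open also
# accepts gs:// URIs through the same file-like interface.
with tf.gfile.Open("gs://some-bucket/mnist/train-images-idx3-ubyte") as f:
  header = f.read(16)  # magic number, image count, rows, cols (4 bytes each)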

official/mnist/mnist_tpu.py (new file, 140 additions)
@@ -0,0 +1,140 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST model training using TPUs.

This program demonstrates training of the convolutional neural network model
defined in mnist.py on Google Cloud TPUs (https://cloud.google.com/tpu/).

If you are not interested in TPUs, you should ignore this file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import dataset  # official/mnist/dataset.py: MNIST input pipelines
import mnist    # official/mnist/mnist.py: the convolutional model definition

tf.flags.DEFINE_string("data_dir", "",
31+
"Path to directory containing the MNIST dataset")
32+
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
33+
tf.flags.DEFINE_integer("batch_size", 1024,
34+
"Mini-batch size for the training. Note that this "
35+
"is the global batch size and not the per-shard batch.")
36+
tf.flags.DEFINE_integer("train_steps", 1000, "Total number of training steps.")
37+
tf.flags.DEFINE_integer("eval_steps", 0,
38+
"Total number of evaluation steps. If `0`, evaluation "
39+
"after training is skipped.")
40+
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")
41+
42+
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
43+
tf.flags.DEFINE_string("master", "local", "GRPC URL of the Cloud TPU instance.")
44+
tf.flags.DEFINE_integer("iterations", 50,
45+
"Number of iterations per TPU training loop.")
46+
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")
47+
48+
FLAGS = tf.flags.FLAGS
49+
50+
51+
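# With the defaults above, the global batch of 1024 examples is split across
# --num_shards=8 TPU cores, i.e. 1024 / 8 = 128 examples per shard per step;
# that is the "global batch size and not the per-shard batch" caveat on
# --batch_size.
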
def metric_fn(labels, logits):
  """Returns evaluation metrics computed from one-hot labels and logits."""
  accuracy = tf.metrics.accuracy(
      labels=tf.argmax(labels, axis=1), predictions=tf.argmax(logits, axis=1))
  return {"accuracy": accuracy}

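# Note: tf.metrics.accuracy returns a (value, update_op) pair, which is the
# form TPUEstimator expects for each entry in the dict returned by metric_fn.
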
def model_fn(features, labels, mode, params):
  del params  # Unused; the input functions read the batch size instead.
  if mode == tf.estimator.ModeKeys.PREDICT:
    raise RuntimeError("mode {} is not supported yet".format(mode))
  image = features
  if isinstance(image, dict):
    image = features["image"]

  model = mnist.Model("channels_last")
  logits = model(image, training=(mode == tf.estimator.ModeKeys.TRAIN))
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.TRAIN:
    learning_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        tf.train.get_global_step(),
        decay_steps=100000,
        decay_rate=0.96)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    if FLAGS.use_tpu:
      # Average gradients across TPU shards before applying them.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_global_step()))

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

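# For comparison, a rough sketch of what the same returns would look like on
# CPU/GPU with a plain Estimator. This function is illustrative only (an
# assumption, not part of the original commit) and is never called:
def _cpu_model_fn_sketch(features, labels, mode, params):
  """Hypothetical non-TPU variant of model_fn above, for comparison."""
  del params
  image = features["image"] if isinstance(features, dict) else features
  logits = mnist.Model("channels_last")(
      image, training=(mode == tf.estimator.ModeKeys.TRAIN))
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.GradientSescentOptimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    return tf.estimator.EstimatorSpec(
        mode=mode, loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_global_step()))
  # A plain Estimator takes the computed metric ops directly, whereas
  # TPUEstimatorSpec takes a (metric_fn, tensors) pair so the metrics can be
  # computed on the host CPU after each TPU loop.
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=metric_fn(labels, logits))
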
def train_input_fn(params):
  # Retrieves the batch size for the current shard. The number of shards is
  # computed according to the input pipeline deployment. See
  # `tf.contrib.tpu.RunConfig` for details.
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
  ds = dataset.train(data_dir).cache().repeat().shuffle(
      buffer_size=50000).apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
  images, labels = ds.make_one_shot_iterator().get_next()
  return images, labels

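# NOTE: batch_and_drop_remainder matters on TPUs: XLA compiles the graph for
# fixed shapes, so every batch must contain exactly batch_size examples; the
# final short batch is dropped rather than padded. In later TensorFlow
# releases (1.10+, an assumption about versions) the same effect is spelled
# directly on Dataset.batch:
#
#   ds = dataset.train(data_dir).cache().repeat().shuffle(
#       buffer_size=50000).batch(batch_size, drop_remainder=True)
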
def eval_input_fn(params):
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]
  ds = dataset.test(data_dir).apply(
      tf.contrib.data.batch_and_drop_remainder(batch_size))
  images, labels = ds.make_one_shot_iterator().get_next()
  return images, labels

def main(argv):
  del argv  # Unused.
  tf.logging.set_verbosity(tf.logging.INFO)

  run_config = tf.contrib.tpu.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
  )

  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      params={"data_dir": FLAGS.data_dir},
      config=run_config)
  # TPUEstimator.train *requires* a max_steps argument.
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
  # TPUEstimator.evaluate *requires* a steps argument.
  # Note that the number of examples used during evaluation is
  # --eval_steps * --batch_size.
  # So if you change --batch_size then change --eval_steps too.
  if FLAGS.eval_steps:  # Skip evaluation when --eval_steps=0, as documented.
    estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)


if __name__ == "__main__":
  tf.app.run()
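
Assuming the MNIST data and a model directory live in a GCS bucket and a Cloud TPU endpoint is reachable over gRPC, a typical invocation might look like the following; the bucket and address are hypothetical placeholders, while the flags are the ones defined in this file:

python mnist_tpu.py \
  --data_dir=gs://some-bucket/mnist \
  --model_dir=gs://some-bucket/mnist_model \
  --master=grpc://10.240.1.2:8470 \
  --use_tpu=True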
