|
| 1 | +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""MNIST model training using TPUs. |
| 16 | +
|
| 17 | +This program demonstrates training of the convolutional neural network model |
| 18 | +defined in mnist.py on Google Cloud TPUs (https://cloud.google.com/tpu/). |
| 19 | +
|
| 20 | +If you are not interested in TPUs, you should ignore this file. |
| 21 | +""" |
| 22 | +from __future__ import absolute_import |
| 23 | +from __future__ import division |
| 24 | +from __future__ import print_function |
| 25 | + |
| 26 | +import tensorflow as tf |
| 27 | +import dataset |
| 28 | +import mnist |
| 29 | + |
# Data location, checkpointing and optimization settings.
tf.flags.DEFINE_string(
    "data_dir", "", "Path to directory containing the MNIST dataset")
tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir")
tf.flags.DEFINE_integer(
    "batch_size", 1024,
    "Mini-batch size for the training. "
    "Note that this is the global batch size and not the per-shard batch.")
tf.flags.DEFINE_integer(
    "train_steps", 1000, "Total number of training steps.")
tf.flags.DEFINE_integer(
    "eval_steps", 0,
    "Total number of evaluation steps. "
    "If `0`, evaluation after training is skipped.")
tf.flags.DEFINE_float("learning_rate", 0.05, "Learning rate.")

# TPU-specific settings.
tf.flags.DEFINE_bool("use_tpu", True, "Use TPUs rather than plain CPUs")
tf.flags.DEFINE_string(
    "master", "local", "GRPC URL of the Cloud TPU instance.")
tf.flags.DEFINE_integer(
    "iterations", 50, "Number of iterations per TPU training loop.")
tf.flags.DEFINE_integer("num_shards", 8, "Number of shards (TPU chips).")

# Parsed flag values, read by the functions below.
FLAGS = tf.flags.FLAGS
| 49 | + |
| 50 | + |
def metric_fn(labels, logits):
  """Builds the dictionary of evaluation metrics.

  Both inputs are reduced with `tf.argmax` along axis 1 and the resulting
  indices are compared by `tf.metrics.accuracy`.

  Args:
    labels: Label tensor; its argmax along axis 1 is used as the true class.
    logits: Model output tensor; its argmax along axis 1 is the prediction.

  Returns:
    A dict mapping "accuracy" to the value returned by `tf.metrics.accuracy`.
  """
  true_classes = tf.argmax(labels, axis=1)
  predicted_classes = tf.argmax(logits, axis=1)
  return {
      "accuracy": tf.metrics.accuracy(
          labels=true_classes, predictions=predicted_classes),
  }
| 55 | + |
| 56 | + |
def model_fn(features, labels, mode, params):
  """Builds the TPUEstimatorSpec for training or evaluating the MNIST model.

  Args:
    features: Either the image tensor itself or a dict holding it under the
      "image" key.
    labels: Labels, passed as `onehot_labels` to the softmax cross-entropy
      loss.
    mode: A `tf.estimator.ModeKeys` value. PREDICT is rejected.
    params: Unused.

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec` for TRAIN or EVAL mode. For any other
    mode that is not PREDICT, None.

  Raises:
    RuntimeError: If `mode` is PREDICT.
  """
  del params  # Unused.
  if mode == tf.estimator.ModeKeys.PREDICT:
    raise RuntimeError("mode {} is not supported yet".format(mode))

  image = features["image"] if isinstance(features, dict) else features
  is_training = mode == tf.estimator.ModeKeys.TRAIN

  model = mnist.Model("channels_last")
  logits = model(image, training=is_training)
  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, eval_metrics=(metric_fn, [labels, logits]))

  if is_training:
    global_step = tf.train.get_global_step()
    decayed_rate = tf.train.exponential_decay(
        FLAGS.learning_rate,
        global_step,
        decay_steps=100000,
        decay_rate=0.96)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=decayed_rate)
    if FLAGS.use_tpu:
      # Wrap the optimizer for use across TPU shards.
      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
    train_op = optimizer.minimize(loss, global_step)
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=loss, train_op=train_op)

  return None
| 86 | + |
| 87 | + |
def train_input_fn(params):
  """Builds the training input pipeline.

  The training dataset is cached, repeated, shuffled with a 50000-element
  buffer and grouped into fixed-size batches; a trailing partial batch is
  dropped.

  Args:
    params: Dict providing "batch_size" and "data_dir".

  Returns:
    An `(images, labels)` pair taken from a one-shot iterator over the
    dataset.
  """
  # Retrieves the batch size for the current shard. The # of shards is
  # computed according to the input pipeline deployment. See
  # `tf.contrib.tpu.RunConfig` for details.
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]

  ds = dataset.train(data_dir)
  ds = ds.cache().repeat()
  ds = ds.shuffle(buffer_size=50000)
  ds = ds.apply(tf.contrib.data.batch_and_drop_remainder(batch_size))

  iterator = ds.make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels
| 99 | + |
| 100 | + |
def eval_input_fn(params):
  """Builds the evaluation input pipeline.

  The test dataset is grouped into fixed-size batches; a trailing partial
  batch is dropped. It is neither shuffled nor repeated.

  Args:
    params: Dict providing "batch_size" and "data_dir".

  Returns:
    An `(images, labels)` pair taken from a one-shot iterator over the
    dataset.
  """
  batch_size = params["batch_size"]
  data_dir = params["data_dir"]

  batcher = tf.contrib.data.batch_and_drop_remainder(batch_size)
  ds = dataset.test(data_dir).apply(batcher)

  iterator = ds.make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels
| 108 | + |
| 109 | + |
def main(argv):
  """Trains the MNIST model with TPUEstimator, then optionally evaluates it.

  The run is configured entirely from command-line flags. Evaluation runs
  only when --eval_steps is non-zero, as that flag's help text promises.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.
  tf.logging.set_verbosity(tf.logging.INFO)

  run_config = tf.contrib.tpu.RunConfig(
      master=FLAGS.master,
      evaluation_master=FLAGS.master,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(FLAGS.iterations, FLAGS.num_shards),
  )

  estimator = tf.contrib.tpu.TPUEstimator(
      model_fn=model_fn,
      use_tpu=FLAGS.use_tpu,
      train_batch_size=FLAGS.batch_size,
      eval_batch_size=FLAGS.batch_size,
      params={"data_dir": FLAGS.data_dir},
      config=run_config)
  # TPUEstimator.train *requires* a max_steps argument.
  estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)
  # TPUEstimator.evaluate *requires* a steps argument.
  # Note that the number of examples used during evaluation is
  # --eval_steps * --batch_size.
  # So if you change --batch_size then change --eval_steps too.
  # --eval_steps defaults to 0, which is documented as "skip evaluation", so
  # only call evaluate() when a non-zero step count was requested.
  if FLAGS.eval_steps:
    estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_steps)
| 137 | + |
| 138 | + |
if __name__ == "__main__":
  # Parse the command-line flags, then invoke main() with the remaining argv.
  tf.app.run(main=main)
0 commit comments