|
| 1 | +#!/usr/bin/env python |
| 2 | +# coding=utf-8 |
| 3 | +# Copyright 2017 The Tensor2Tensor Authors. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +"""Evaluate BLEU score for all checkpoints in a given directory. |
| 18 | +
|
| 19 | +This script can be used in two ways. |
| 20 | +
|
| 21 | +To evaluate an already translated file: |
| 22 | +`t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de` |
| 23 | +
|
| 24 | +To evaluate all checkpoints in a given directory: |
| 25 | +`t2t-bleu |
| 26 | + --model_dir=t2t_train |
| 27 | + --data_dir=t2t_data |
| 28 | + --translations_dir=my-translations |
| 29 | + --problems=translate_ende_wmt32k |
| 30 | + --hparams_set=transformer_big_single_gpu |
| 31 | + --source=wmt13_deen.en |
| 32 | + --reference=wmt13_deen.de` |
| 33 | +
|
| 34 | +In addition to the above-mentioned compulsory parameters, |
| 35 | +there are optional parameters: |
| 36 | +
|
| 37 | + * bleu_variant: cased (case-sensitive), uncased, both (default). |
| 38 | + * translations_dir: Where to store the translated files? Default="translations". |
| 39 | + * event_subdir: Where in the model_dir to store the event file? Default="", |
| 40 | + which means TensorBoard will show it as the same run as the training, but it will warn |
| 41 | + about "more than one metagraph event per run". event_subdir can be used e.g. if running |
| 42 | + this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`. |
| 43 | + * tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix |
| 44 | + can be used e.g. for different beam sizes if these should be plotted in different graphs. |
| 45 | + * min_steps: Don't evaluate checkpoints with less steps. |
| 46 | + Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps |
| 47 | + of the last successfully evaluated checkpoint. |
| 48 | + * report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True. |
| 49 | + This is useful, so TensorBoard reports correct relative time for the remaining checkpoints. |
| 50 | + This flag is set to False if min_steps is > 0. |
| 51 | + * wait_secs: Wait up to N seconds for a new checkpoint. Default=0. |
| 52 | + This is useful for continuous evaluation of a running training, |
| 53 | + in which case this should be equal to save_checkpoints_secs plus some reserve. |
| 54 | +""" |
| 55 | +from __future__ import absolute_import |
| 56 | +from __future__ import division |
| 57 | +from __future__ import print_function |
| 58 | +import os |
| 59 | +import time |
| 60 | +from collections import namedtuple |
| 61 | +from tensor2tensor.utils import decoding |
| 62 | +from tensor2tensor.utils import trainer_utils |
| 63 | +from tensor2tensor.utils import usr_dir |
| 64 | +from tensor2tensor.utils import bleu_hook |
| 65 | +import tensorflow as tf |
| 66 | + |
| 67 | +flags = tf.flags |
| 68 | +FLAGS = flags.FLAGS |
| 69 | + |
| 70 | +# t2t-bleu specific options |
| 71 | +flags.DEFINE_string("bleu_variant", "both", "Possible values: cased(case-sensitive), uncased, both(default).") |
| 72 | +flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.") |
| 73 | +flags.DEFINE_string("translation", None, "Path to the MT system translation file") |
| 74 | +flags.DEFINE_string("source", None, "Path to the source-language file to be translated") |
| 75 | +flags.DEFINE_string("reference", None, "Path to the reference translation file") |
| 76 | +flags.DEFINE_string("translations_dir", "translations", "Where to store the translated files") |
| 77 | +flags.DEFINE_string("event_subdir", "", "Where in model_dir to store the event file") |
| 78 | +flags.DEFINE_string("tag_suffix", "", "What to add to BLEU_cased and BLEU_uncased tags. Default=''.") |
| 79 | +flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with less steps.") |
| 80 | +flags.DEFINE_integer("wait_secs", 0, "Wait upto N seconds for a new checkpoint, cf. save_checkpoints_secs.") |
| 81 | +flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on flags.txt") |
| 82 | + |
| 83 | +# options derived from t2t-decode |
| 84 | +flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.") |
| 85 | +flags.DEFINE_string("t2t_usr_dir", "", |
| 86 | + "Path to a Python module that will be imported. The " |
| 87 | + "__init__.py file should include the necessary imports. " |
| 88 | + "The imported files should contain registrations, " |
| 89 | + "e.g. @registry.register_model calls, that will then be " |
| 90 | + "available to the t2t-decoder.") |
| 91 | +flags.DEFINE_string("master", "", "Address of TensorFlow master.") |
| 92 | +flags.DEFINE_string("schedule", "train_and_evaluate", |
| 93 | + "Must be train_and_evaluate for decoding.") |
| 94 | + |
| 95 | +Model = namedtuple('Model', 'filename time steps') |
| 96 | + |
| 97 | + |
| 98 | +def read_checkpoints_list(model_dir, min_steps): |
| 99 | + models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1])) |
| 100 | + for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))] |
| 101 | + return sorted((x for x in models if x.steps > min_steps), key=lambda x: x.steps) |
| 102 | + |
| 103 | +def main(_): |
| 104 | + tf.logging.set_verbosity(tf.logging.INFO) |
| 105 | + if FLAGS.translation: |
| 106 | + if FLAGS.model_dir: |
| 107 | + raise ValueError('Cannot specify both --translation and --model_dir.') |
| 108 | + if FLAGS.bleu_variant in ('uncased', 'both'): |
| 109 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False) |
| 110 | + print("BLEU_uncased = %6.2f" % bleu) |
| 111 | + if FLAGS.bleu_variant in ('cased', 'both'): |
| 112 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True) |
| 113 | + print("BLEU_cased = %6.2f" % bleu) |
| 114 | + return |
| 115 | + |
| 116 | + usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) |
| 117 | + FLAGS.model = FLAGS.model or 'transformer' |
| 118 | + FLAGS.output_dir = FLAGS.model_dir |
| 119 | + trainer_utils.log_registry() |
| 120 | + trainer_utils.validate_flags() |
| 121 | + assert FLAGS.schedule == "train_and_evaluate" |
| 122 | + data_dir = os.path.expanduser(FLAGS.data_dir) |
| 123 | + model_dir = os.path.expanduser(FLAGS.model_dir) |
| 124 | + |
| 125 | + hparams = trainer_utils.create_hparams( |
| 126 | + FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) |
| 127 | + trainer_utils.add_problem_hparams(hparams, FLAGS.problems) |
| 128 | + estimator, _ = trainer_utils.create_experiment_components( |
| 129 | + data_dir=data_dir, |
| 130 | + model_name=FLAGS.model, |
| 131 | + hparams=hparams, |
| 132 | + run_config=trainer_utils.create_run_config(model_dir)) |
| 133 | + |
| 134 | + decode_hp = decoding.decode_hparams(FLAGS.decode_hparams) |
| 135 | + decode_hp.add_hparam("shards", FLAGS.decode_shards) |
| 136 | + decode_hp.add_hparam("shard_id", FLAGS.worker_id) |
| 137 | + |
| 138 | + os.makedirs(FLAGS.translations_dir, exist_ok=True) |
| 139 | + translated_base_file = os.path.join(FLAGS.translations_dir, FLAGS.problems) |
| 140 | + event_dir = os.path.join(FLAGS.model_dir, FLAGS.event_subdir) |
| 141 | + last_step_file = os.path.join(event_dir, 'last_evaluated_step.txt') |
| 142 | + if FLAGS.min_steps == -1: |
| 143 | + try: |
| 144 | + with open(last_step_file) as ls_file: |
| 145 | + FLAGS.min_steps = int(ls_file.read()) |
| 146 | + except FileNotFoundError: |
| 147 | + FLAGS.min_steps = 0 |
| 148 | + if FLAGS.report_zero is None: |
| 149 | + FLAGS.report_zero = FLAGS.min_steps == 0 |
| 150 | + |
| 151 | + models = read_checkpoints_list(model_dir, FLAGS.min_steps) |
| 152 | + tf.logging.info("Found %d models with steps: %s" % (len(models), ", ".join(str(x.steps) for x in models))) |
| 153 | + |
| 154 | + writer = tf.summary.FileWriter(event_dir) |
| 155 | + if FLAGS.report_zero: |
| 156 | + start_time = os.path.getctime(os.path.join(model_dir, 'flags.txt')) |
| 157 | + values = [] |
| 158 | + if FLAGS.bleu_variant in ('uncased', 'both'): |
| 159 | + values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0)) |
| 160 | + if FLAGS.bleu_variant in ('cased', 'both'): |
| 161 | + values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0)) |
| 162 | + writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=start_time, step=0)) |
| 163 | + |
| 164 | + exit_time = time.time() + FLAGS.wait_secs |
| 165 | + min_steps = FLAGS.min_steps |
| 166 | + while True: |
| 167 | + if not models and FLAGS.wait_secs: |
| 168 | + tf.logging.info('All checkpoints evaluated. Waiting till %s if a new checkpoint appears' % time.asctime(time.localtime(exit_time))) |
| 169 | + while True: |
| 170 | + time.sleep(10) |
| 171 | + models = read_checkpoints_list(model_dir, min_steps) |
| 172 | + if models or time.time() > exit_time: |
| 173 | + break |
| 174 | + if not models: |
| 175 | + return |
| 176 | + |
| 177 | + model = models.pop(0) |
| 178 | + exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps |
| 179 | + tf.logging.info("Evaluating " + model.filename) |
| 180 | + out_file = translated_base_file + '-' + str(model.steps) |
| 181 | + tf.logging.set_verbosity(tf.logging.ERROR) # decode_from_file logs all the translations as INFO |
| 182 | + decoding.decode_from_file(estimator, FLAGS.source, decode_hp, out_file, checkpoint_path=model.filename) |
| 183 | + tf.logging.set_verbosity(tf.logging.INFO) |
| 184 | + values = [] |
| 185 | + if FLAGS.bleu_variant in ('uncased', 'both'): |
| 186 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False) |
| 187 | + values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu)) |
| 188 | + tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu)) |
| 189 | + if FLAGS.bleu_variant in ('cased', 'both'): |
| 190 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True) |
| 191 | + values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu)) |
| 192 | + tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu)) |
| 193 | + writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps)) |
| 194 | + writer.flush() |
| 195 | + with open(last_step_file, 'w') as ls_file: |
| 196 | + ls_file.write(str(model.steps) + '\n') |
| 197 | + |
| 198 | + |
| 199 | +if __name__ == "__main__": |
| 200 | + tf.app.run() |
0 commit comments