|
| 1 | +#!/usr/bin/env python |
| 2 | +# coding=utf-8 |
| 3 | +# Copyright 2017 The Tensor2Tensor Authors. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +"""Evaluate BLEU score for all checkpoints/translations in a given directory. |
| 18 | +
|
| 19 | +This script can be used in two ways. |
| 20 | +
|
| 21 | +To evaluate one already translated file: |
| 22 | +`t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de` |
| 23 | +
|
| 24 | +To evaluate all translations in a given directory (translated by t2t-translate-all): |
| 25 | +`t2t-bleu |
| 26 | + --translations_dir=my-translations |
| 27 | + --reference=wmt13_deen.de |
| 28 | + --event_dir=events` |
| 29 | +
|
| 30 | +In addition to the above-mentioned compulsory parameters, |
| 31 | +there are optional parameters: |
| 32 | +
|
| 33 | + * bleu_variant: cased (case-sensitive), uncased, both (default). |
| 34 | + * tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. tag_suffix |
| 35 | + can be used e.g. for different beam sizes if these should be plotted in different graphs. |
| 36 | + * min_steps: Don't evaluate checkpoints with less steps. |
| 37 | + Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps |
| 38 | + of the last successfully evaluated checkpoint. |
| 39 | + * report_zero: Store BLEU=0 and guess its time based on the oldest file in the translations_dir. |
| 40 | + Default=True. This is useful, so TensorBoard reports correct relative time for the remaining |
| 41 | + checkpoints. This flag is set to False if min_steps is > 0. |
| 42 | + * wait_minutes: Wait upto N minutes for a new translated file. Default=0. |
| 43 | + This is useful for continuous evaluation of a running training, |
| 44 | + in which case this should be equal to save_checkpoints_secs/60 plus time needed for translation |
| 45 | + plus some reserve. |
| 46 | +""" |
| 47 | +from __future__ import absolute_import |
| 48 | +from __future__ import division |
| 49 | +from __future__ import print_function |
| 50 | +import os |
| 51 | +from tensor2tensor.utils import bleu_hook |
| 52 | +import tensorflow as tf |
| 53 | + |
| 54 | + |
| 55 | +flags = tf.flags |
| 56 | +FLAGS = flags.FLAGS |
| 57 | + |
| 58 | +flags.DEFINE_string("source", None, "Path to the source-language file to be translated") |
| 59 | +flags.DEFINE_string("reference", None, "Path to the reference translation file") |
| 60 | +flags.DEFINE_string("translation", None, "Path to the MT system translation file") |
| 61 | +flags.DEFINE_string("translations_dir", None, "Directory with translated files to be evaulated.") |
| 62 | +flags.DEFINE_string("event_dir", None, "Where to store the event file.") |
| 63 | + |
| 64 | +flags.DEFINE_string("bleu_variant", "both", |
| 65 | + "Possible values: cased(case-sensitive), uncased, both(default).") |
| 66 | +flags.DEFINE_string("tag_suffix", "", |
| 67 | + "What to add to BLEU_cased and BLEU_uncased tags. Default=''.") |
| 68 | +flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with less steps.") |
| 69 | +flags.DEFINE_integer("wait_minutes", 0, |
| 70 | + "Wait upto N minutes for a new checkpoint, cf. save_checkpoints_secs.") |
| 71 | +flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on the oldest file.") |
| 72 | + |
| 73 | + |
| 74 | +def main(_): |
| 75 | + tf.logging.set_verbosity(tf.logging.INFO) |
| 76 | + if FLAGS.translation: |
| 77 | + if FLAGS.translations_dir: |
| 78 | + raise ValueError('Cannot specify both --translation and --translations_dir.') |
| 79 | + if FLAGS.bleu_variant in ('uncased', 'both'): |
| 80 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False) |
| 81 | + print("BLEU_uncased = %6.2f" % bleu) |
| 82 | + if FLAGS.bleu_variant in ('cased', 'both'): |
| 83 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True) |
| 84 | + print("BLEU_cased = %6.2f" % bleu) |
| 85 | + return |
| 86 | + |
| 87 | + if not FLAGS.translations_dir: |
| 88 | + raise ValueError('Either --translation or --translations_dir must be specified.') |
| 89 | + transl_dir = os.path.expanduser(FLAGS.translations_dir) |
| 90 | + |
| 91 | + last_step_file = os.path.join(FLAGS.event_dir, 'last_evaluated_step.txt') |
| 92 | + if FLAGS.min_steps == -1: |
| 93 | + try: |
| 94 | + with open(last_step_file) as ls_file: |
| 95 | + FLAGS.min_steps = int(ls_file.read()) |
| 96 | + except FileNotFoundError: |
| 97 | + FLAGS.min_steps = 0 |
| 98 | + if FLAGS.report_zero is None: |
| 99 | + FLAGS.report_zero = FLAGS.min_steps == 0 |
| 100 | + |
| 101 | + writer = tf.summary.FileWriter(FLAGS.event_dir) |
| 102 | + for transl_file in bleu_hook.stepfiles_iterator(transl_dir, FLAGS.wait_minutes, |
| 103 | + FLAGS.min_steps, path_suffix=''): |
| 104 | + # report_zero handling must be inside the for-loop, |
| 105 | + # so we are sure the transl_dir is already created. |
| 106 | + if FLAGS.report_zero: |
| 107 | + all_files = (os.path.join(transl_dir, f) for f in os.listdir(transl_dir)) |
| 108 | + start_time = min(os.path.getmtime(f) for f in all_files if os.path.isfile(f)) |
| 109 | + values = [] |
| 110 | + if FLAGS.bleu_variant in ('uncased', 'both'): |
| 111 | + values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0)) |
| 112 | + if FLAGS.bleu_variant in ('cased', 'both'): |
| 113 | + values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0)) |
| 114 | + writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), |
| 115 | + wall_time=start_time, step=0)) |
| 116 | + FLAGS.report_zero = False |
| 117 | + |
| 118 | + filename = transl_file.filename |
| 119 | + tf.logging.info("Evaluating " + filename) |
| 120 | + values = [] |
| 121 | + if FLAGS.bleu_variant in ('uncased', 'both'): |
| 122 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, filename, case_sensitive=False) |
| 123 | + values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu)) |
| 124 | + tf.logging.info("%s: BLEU_uncased = %6.2f" % (filename, bleu)) |
| 125 | + if FLAGS.bleu_variant in ('cased', 'both'): |
| 126 | + bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, filename, case_sensitive=True) |
| 127 | + values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu)) |
| 128 | + tf.logging.info("%s: BLEU_cased = %6.2f" % (transl_file.filename, bleu)) |
| 129 | + writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), |
| 130 | + wall_time=transl_file.mtime, step=transl_file.steps)) |
| 131 | + writer.flush() |
| 132 | + with open(last_step_file, 'w') as ls_file: |
| 133 | + ls_file.write(str(transl_file.steps) + '\n') |
| 134 | + |
| 135 | + |
| 136 | +if __name__ == "__main__": |
| 137 | + tf.app.run() |
0 commit comments