Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 584a290

Browse files
committed
scripts for proper BLEU evaluation, batch translation and averaging
`t2t-bleu` computes the "real" BLEU (giving the same result as [sacréBLEU](https://github.com/awslabs/sockeye/tree/master/contrib/sacrebleu) with `--tokenization intl` and as [mteval-v14.pl](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl) with `--international-tokenization`). It can be used in two ways: * To evaluate an already translated file: `t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de` * To evaluate all translations in a given directory: `t2t-bleu --translations_dir=my-translations --reference=wmt13_deen.de --event_dir=events`. `t2t-translate-all` translates a source file with every checkpoint in a given directory. A custom command (e.g. an SGE cluster wrapper) can be used instead of `t2t-decoder` for the translation. `t2t-avg-all` averages each checkpoint in a given directory with the ones preceding it, so that N checkpoints are averaged in total. All three scripts can wait a given number of minutes for new checkpoints (produced by t2t-decoder, which can be run concurrently with these scripts).
1 parent c1cd875 commit 584a290

File tree

4 files changed

+402
-0
lines changed

4 files changed

+402
-0
lines changed

tensor2tensor/bin/t2t-avg-all

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2017 The Tensor2Tensor Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Script to continuously average the last N checkpoints in a given directory."""
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import os
23+
import logging
24+
25+
# Dependency imports
26+
27+
import numpy as np
28+
import six
29+
from six.moves import zip # pylint: disable=redefined-builtin
30+
from collections import deque
31+
import shutil
32+
import tensorflow as tf
33+
from tensor2tensor.utils import bleu_hook
34+
35+
# Command-line flags of t2t-avg-all; parsed values are read through FLAGS.
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string(
    "model_dir", "",
    "Directory to load model checkpoints from.")
flags.DEFINE_string(
    "output_dir", "avg/",
    "Directory to output the averaged checkpoints to.")
flags.DEFINE_integer(
    "n", 8,
    "How many checkpoints should be averaged?")
flags.DEFINE_integer(
    "min_steps", 0,
    "Ignore checkpoints with less steps.")
flags.DEFINE_integer(
    "wait_minutes", 0,
    "Wait upto N minutes for a new checkpoint.")
43+
44+
45+
def main(_):
  """Continuously average the last `--n` checkpoints found in `--model_dir`.

  Every checkpoint yielded by `bleu_hook.stepfiles_iterator` is added to a
  running sum in which each tensor is pre-divided by `--n`.  As soon as the
  sliding window holds `--n` checkpoints, the sum is saved as
  `<output_dir>/model.ckpt-<steps>` and the oldest checkpoint of the window is
  subtracted again, so each checkpoint file is loaded at most twice.

  Args:
    _: Unused positional arguments passed in by `tf.app.run`.
  """
  # Prefix every log record with a timestamp.
  # NOTE(review): `tf.logging._handler` is a private attribute and may not
  # exist in every TensorFlow release -- confirm for the targeted version.
  tf.logging._handler.setFormatter(
      logging.Formatter("%(asctime)s:" + logging.BASIC_FORMAT, None))
  tf.logging.set_verbosity(tf.logging.INFO)

  model_dir = os.path.expanduser(FLAGS.model_dir)
  output_dir = os.path.expanduser(FLAGS.output_dir)
  out_base_file = os.path.join(output_dir, "model.ckpt")

  # Create the *expanded* output directory: the checkpoints and flags.txt are
  # written below `output_dir`, not below the raw flag value.  The try/except
  # form is used instead of `exist_ok=True` so this also works on Python 2.
  try:
    os.makedirs(output_dir)
  except OSError:
    if not os.path.isdir(output_dir):
      raise

  # Copy flags.txt with the original time, so t2t-bleu can report correct
  # relative time.  A missing source file is not fatal for the averaging.
  flags_src = os.path.join(model_dir, "flags.txt")
  flags_dst = os.path.join(output_dir, "flags.txt")
  if not os.path.exists(flags_dst):
    if os.path.exists(flags_src):
      shutil.copy2(flags_src, flags_dst)
    else:
      tf.logging.warning("%s not found, so it is not copied to %s"
                         % (flags_src, output_dir))

  models_processed = 0
  queue = deque()  # Checkpoints currently contributing to `avg_values`.
  avg_values = {}  # Variable name -> running sum of (tensor / n).
  for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                            FLAGS.min_steps):
    if models_processed == 0:
      # Variable names and shapes are taken from the first checkpoint;
      # global_step is not averaged, it is written separately below.
      for (name, shape) in tf.contrib.framework.list_variables(model.filename):
        if not name.startswith("global_step"):
          avg_values[name] = np.zeros(shape)
    models_processed += 1

    tf.logging.info("Loading [%d]: %s" % (models_processed, model.filename))
    reader = tf.contrib.framework.load_checkpoint(model.filename)
    for name in avg_values:
      avg_values[name] += reader.get_tensor(name) / FLAGS.n
    queue.append(model)
    if len(queue) < FLAGS.n:
      # The window is not full yet, nothing to store.
      continue

    out_file = "%s-%d" % (out_base_file, model.steps)
    tf.logging.info("Averaging %s" % (out_file))
    # TODO: pass dtype=var_dtypes[name] to tf.get_variable.
    tf_vars = [tf.get_variable(name, shape=value.shape)
               for (name, value) in six.iteritems(avg_values)]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]

    global_step = tf.Variable(model.steps, name="global_step",
                              trainable=False, dtype=tf.int64)
    saver = tf.train.Saver(tf.global_variables())

    tf.logging.info("Running session for %s" % (out_file))
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      # `avg_values` is not modified between building `tf_vars` and this loop,
      # so its iteration order matches the placeholders and assign ops.
      for p, assign_op, value in zip(placeholders, assign_ops,
                                     six.itervalues(avg_values)):
        sess.run(assign_op, {p: value})
      tf.logging.info("Storing to %s" % out_file)
      saver.save(sess, out_base_file, global_step=global_step)
    # Give the stored index file the modification time of the newest
    # checkpoint that went into the average.
    os.utime(out_file + ".index", (model.mtime, model.mtime))

    tf.reset_default_graph()

    # Slide the window: remove the contribution of the oldest checkpoint.
    first_model = queue.popleft()
    reader = tf.contrib.framework.load_checkpoint(first_model.filename)
    for name in avg_values:
      avg_values[name] -= reader.get_tensor(name) / FLAGS.n


if __name__ == "__main__":
  tf.app.run()

tensor2tensor/bin/t2t-bleu

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2017 The Tensor2Tensor Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Evaluate BLEU score for all checkpoints/translations in a given directory.
18+
19+
This script can be used in two ways.
20+
21+
To evaluate one already translated file:
22+
`t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de`
23+
24+
To evaluate all translations in a given directory (translated by t2t-translate-all):
25+
`t2t-bleu
26+
--translations_dir=my-translations
27+
--reference=wmt13_deen.de
28+
--event_dir=events`
29+
30+
In addition to the above-mentioned compulsory parameters,
31+
there are optional parameters:
32+
33+
* bleu_variant: cased (case-sensitive), uncased, both (default).
34+
* tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. tag_suffix
35+
can be used e.g. for different beam sizes if these should be plotted in different graphs.
36+
* min_steps: Don't evaluate checkpoints with fewer steps.
37+
Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps
38+
of the last successfully evaluated checkpoint.
39+
* report_zero: Store BLEU=0 and guess its time based on the oldest file in the translations_dir.
40+
Default=True. This is useful, so TensorBoard reports correct relative time for the remaining
41+
checkpoints. This flag is set to False if min_steps is > 0.
42+
* wait_minutes: Wait up to N minutes for a new translated file. Default=0.
43+
This is useful for continuous evaluation of a running training,
44+
in which case this should be equal to save_checkpoints_secs/60 plus time needed for translation
45+
plus some reserve.
46+
"""
47+
from __future__ import absolute_import
48+
from __future__ import division
49+
from __future__ import print_function
50+
import os
51+
from tensor2tensor.utils import bleu_hook
52+
import tensorflow as tf
53+
54+
55+
# Command-line flags of t2t-bleu; parsed values are read through FLAGS.
flags = tf.flags
FLAGS = flags.FLAGS

# Input files and directories.
flags.DEFINE_string(
    "source", None,
    "Path to the source-language file to be translated")
flags.DEFINE_string(
    "reference", None,
    "Path to the reference translation file")
flags.DEFINE_string(
    "translation", None,
    "Path to the MT system translation file")
flags.DEFINE_string(
    "translations_dir", None,
    "Directory with translated files to be evaulated.")
flags.DEFINE_string(
    "event_dir", None,
    "Where to store the event file.")

# Evaluation options.
flags.DEFINE_string(
    "bleu_variant", "both",
    "Possible values: cased(case-sensitive), uncased, both(default).")
flags.DEFINE_string(
    "tag_suffix", "",
    "What to add to BLEU_cased and BLEU_uncased tags. Default=''.")
flags.DEFINE_integer(
    "min_steps", -1,
    "Don't evaluate checkpoints with less steps.")
flags.DEFINE_integer(
    "wait_minutes", 0,
    "Wait upto N minutes for a new checkpoint, cf. save_checkpoints_secs.")
flags.DEFINE_bool(
    "report_zero", None,
    "Store BLEU=0 and guess its time based on the oldest file.")
72+
73+
74+
def main(_):
  """Compute BLEU for one translation file or for a directory of them.

  With `--translation`, the selected BLEU variants are printed to stdout and
  the function returns.  With `--translations_dir`, every file yielded by
  `bleu_hook.stepfiles_iterator` is scored, the scores are written as summary
  events to `--event_dir`, and the step of the last evaluated file is stored
  in `<event_dir>/last_evaluated_step.txt`.

  Args:
    _: Unused positional arguments passed in by `tf.app.run`.

  Raises:
    ValueError: If a required flag is missing, if both `--translation` and
      `--translations_dir` are given, or if `--bleu_variant` is not one of
      cased, uncased, both.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  if not FLAGS.reference:
    raise ValueError("--reference must be specified.")

  # (tag, case_sensitive) pairs selected by --bleu_variant.
  variants = []
  if FLAGS.bleu_variant in ("uncased", "both"):
    variants.append(("BLEU_uncased", False))
  if FLAGS.bleu_variant in ("cased", "both"):
    variants.append(("BLEU_cased", True))
  if not variants:
    # Previously an unknown variant silently produced no score at all.
    raise ValueError("--bleu_variant must be one of: cased, uncased, both.")

  if FLAGS.translation:
    if FLAGS.translations_dir:
      raise ValueError(
          "Cannot specify both --translation and --translations_dir.")
    for tag, case_sensitive in variants:
      bleu = 100 * bleu_hook.bleu_wrapper(
          FLAGS.reference, FLAGS.translation, case_sensitive=case_sensitive)
      print("%s = %6.2f" % (tag, bleu))
    return

  if not FLAGS.translations_dir:
    raise ValueError(
        "Either --translation or --translations_dir must be specified.")
  if not FLAGS.event_dir:
    raise ValueError(
        "--event_dir must be specified together with --translations_dir.")
  transl_dir = os.path.expanduser(FLAGS.translations_dir)
  event_dir = os.path.expanduser(FLAGS.event_dir)

  last_step_file = os.path.join(event_dir, "last_evaluated_step.txt")
  min_steps = FLAGS.min_steps
  if min_steps == -1:
    # Resume after the last successfully evaluated file, if there is one.
    # An existence check is used instead of catching FileNotFoundError,
    # because that exception class does not exist on Python 2.
    if os.path.exists(last_step_file):
      with open(last_step_file) as ls_file:
        min_steps = int(ls_file.read())
    else:
      min_steps = 0
  report_zero = FLAGS.report_zero
  if report_zero is None:
    report_zero = min_steps == 0

  writer = tf.summary.FileWriter(event_dir)
  for transl_file in bleu_hook.stepfiles_iterator(
      transl_dir, FLAGS.wait_minutes, min_steps, path_suffix=""):
    # report_zero handling must be inside the for-loop,
    # so we are sure the transl_dir is already created.
    if report_zero:
      all_files = (os.path.join(transl_dir, f) for f in os.listdir(transl_dir))
      start_time = min(
          os.path.getmtime(f) for f in all_files if os.path.isfile(f))
      values = [
          tf.Summary.Value(tag=tag + FLAGS.tag_suffix, simple_value=0)
          for tag, _ in variants
      ]
      writer.add_event(tf.summary.Event(summary=tf.Summary(value=values),
                                        wall_time=start_time, step=0))
      report_zero = False  # BLEU=0 is reported only once.

    filename = transl_file.filename
    tf.logging.info("Evaluating " + filename)
    values = []
    for tag, case_sensitive in variants:
      bleu = 100 * bleu_hook.bleu_wrapper(
          FLAGS.reference, filename, case_sensitive=case_sensitive)
      values.append(
          tf.Summary.Value(tag=tag + FLAGS.tag_suffix, simple_value=bleu))
      tf.logging.info("%s: %s = %6.2f" % (filename, tag, bleu))
    writer.add_event(tf.summary.Event(summary=tf.Summary(value=values),
                                      wall_time=transl_file.mtime,
                                      step=transl_file.steps))
    writer.flush()
    # Remember the progress, so a later run with --min_steps=-1 can resume.
    with open(last_step_file, "w") as ls_file:
      ls_file.write(str(transl_file.steps) + "\n")


if __name__ == "__main__":
  tf.app.run()

tensor2tensor/bin/t2t-translate-all

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2017 The Tensor2Tensor Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Translate a file with all checkpoints in a given directory.
18+
19+
t2t-decoder will be executed with these parameters:
20+
--problems
21+
--data_dir
22+
--output_dir with the value of --model_dir
23+
--decode_from_file with the value of --source
24+
--decode_hparams with properly formatted --beam_size and --alpha
25+
--checkpoint_path automatically filled
26+
--decode_to_file automatically filled
27+
"""
28+
from __future__ import absolute_import
29+
from __future__ import division
30+
from __future__ import print_function
31+
import os
32+
import shutil
33+
import tensorflow as tf
34+
from tensor2tensor.utils import bleu_hook
35+
36+
37+
flags = tf.flags

# t2t-translate-all specific options
# The help text below is built by implicit string concatenation, so each piece
# except the last must end with a space.
flags.DEFINE_string("decoder_command", "t2t-decoder {params}",
                    "Which command to execute instead of t2t-decoder. "
                    "{params} is replaced by the parameters. "
                    "Useful e.g. for qsub wrapper.")
flags.DEFINE_string("model_dir", "",
                    "Directory to load model checkpoints from.")
flags.DEFINE_string("source", None,
                    "Path to the source-language file to be translated")
flags.DEFINE_string("translations_dir", "translations",
                    "Where to store the translated files.")
flags.DEFINE_integer("min_steps", 0, "Ignore checkpoints with less steps.")
flags.DEFINE_integer("wait_minutes", 0,
                     "Wait upto N minutes for a new checkpoint")

# options derived from t2t-decoder
flags.DEFINE_integer("beam_size", 4, "Beam-search width.")
flags.DEFINE_float("alpha", 0.6, "Beam-search alpha.")
flags.DEFINE_string("model", "transformer", "see t2t-decoder")
flags.DEFINE_string("t2t_usr_dir", None, "see t2t-decoder")
flags.DEFINE_string("data_dir", None, "see t2t-decoder")
flags.DEFINE_string("problems", None, "see t2t-decoder")
flags.DEFINE_string("hparams_set", "transformer_big_single_gpu",
                    "see t2t-decoder")
57+
58+
59+
def main(_):
  """Translate `--source` with every checkpoint found in `--model_dir`.

  For each checkpoint yielded by `bleu_hook.stepfiles_iterator`, the decoder
  command is run through the shell unless the output file
  `<translations_dir>/<problems>-<steps>` already exists.

  Args:
    _: Unused positional arguments passed in by `tf.app.run`.

  Raises:
    ValueError: If `--source` or `--problems` is not specified.
  """
  FLAGS = flags.FLAGS
  tf.logging.set_verbosity(tf.logging.INFO)

  # Fail early with a clear message instead of an opaque error from
  # os.path.expanduser(None) / os.path.join(..., None).
  if not FLAGS.source:
    raise ValueError("--source must be specified.")
  if not FLAGS.problems:
    raise ValueError("--problems must be specified.")

  model_dir = os.path.expanduser(FLAGS.model_dir)
  translations_dir = os.path.expanduser(FLAGS.translations_dir)
  source = os.path.expanduser(FLAGS.source)

  # try/except instead of `exist_ok=True`, so this also works on Python 2.
  try:
    os.makedirs(translations_dir)
  except OSError:
    if not os.path.isdir(translations_dir):
      raise
  translated_base_file = os.path.join(translations_dir, FLAGS.problems)

  # Copy flags.txt with the original time, so t2t-bleu can report correct
  # relative time.  A missing source file is not fatal for the translation.
  flags_src = os.path.join(model_dir, "flags.txt")
  flags_path = os.path.join(translations_dir, FLAGS.problems + "-flags.txt")
  if not os.path.exists(flags_path):
    if os.path.exists(flags_src):
      shutil.copy2(flags_src, flags_path)
    else:
      tf.logging.warning("%s not found, so it is not copied to %s"
                         % (flags_src, flags_path))

  for model in bleu_hook.stepfiles_iterator(model_dir, FLAGS.wait_minutes,
                                            FLAGS.min_steps):
    tf.logging.info("Translating " + model.filename)
    out_file = translated_base_file + "-" + str(model.steps)
    if os.path.exists(out_file):
      tf.logging.info(out_file + " already exists, so skipping it.")
      continue

    tf.logging.info("Translating " + out_file)
    decoder_flags = [
        ("t2t_usr_dir", FLAGS.t2t_usr_dir),
        ("output_dir", model_dir),
        ("data_dir", FLAGS.data_dir),
        ("problems", FLAGS.problems),
        ("decode_hparams",
         "beam_size={},alpha={}".format(FLAGS.beam_size, FLAGS.alpha)),
        ("model", FLAGS.model),
        ("hparams_set", FLAGS.hparams_set),
        ("checkpoint_path", model.filename),
        ("decode_from_file", source),
        ("decode_to_file", out_file),
    ]
    # Unset optional flags are left out, so the decoder does not receive the
    # literal string "None" as a value (e.g. --t2t_usr_dir=None).
    params = " ".join("--%s=%s" % (key, value)
                      for (key, value) in decoder_flags if value is not None)
    # {params} is the documented placeholder; all local names stay available
    # to the template as before (e.g. {out_file}, {model_dir}).
    command = FLAGS.decoder_command.format(**locals())
    tf.logging.info("Running:\n" + command)
    status = os.system(command)
    if status != 0:
      # Keep going with the next checkpoint, but do not hide the failure.
      tf.logging.error("Command finished with non-zero status %d:\n%s"
                       % (status, command))


if __name__ == "__main__":
  tf.app.run()

0 commit comments

Comments
 (0)