Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit bb1173a

Browse files
authored
Merge pull request #436 from martinpopel/bleu
Proper BLEU evaluation
2 parents e3cd447 + 7ba78a2 commit bb1173a

File tree

10 files changed

+285
-10
lines changed

10 files changed

+285
-10
lines changed

README.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,13 @@ t2t-decoder \
126126
--output_dir=$TRAIN_DIR \
127127
--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
128128
--decode_from_file=$DECODE_FILE
129-
130-
cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
129+
--decode_to_file=translation.en
131130
```
132131

132+
# Eval BLEU
133+
134+
t2t-bleu --translation=translation.en --reference=ref-translation.de
135+
133136
---
134137

135138
## Installation

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
'tensor2tensor/bin/t2t-datagen',
2525
'tensor2tensor/bin/t2t-decoder',
2626
'tensor2tensor/bin/t2t-make-tf-configs',
27+
'tensor2tensor/bin/t2t-bleu',
2728
],
2829
install_requires=[
2930
'bz2file',

tensor2tensor/bin/t2t-bleu

+200
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# Copyright 2017 The Tensor2Tensor Authors.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Evaluate BLEU score for all checkpoints in a given directory.
18+
19+
This script can be used in two ways.
20+
21+
To evaluate an already translated file:
22+
`t2t-bleu --translation=my-wmt13.de --reference=wmt13_deen.de`
23+
24+
To evaluate all checkpoints in a given directory:
25+
`t2t-bleu
26+
--model_dir=t2t_train
27+
--data_dir=t2t_data
28+
--translations_dir=my-translations
29+
--problems=translate_ende_wmt32k
30+
--hparams_set=transformer_big_single_gpu
31+
--source=wmt13_deen.en
32+
--reference=wmt13_deen.de`
33+
34+
In addition to the above-mentioned compulsory parameters,
35+
there are optional parameters:
36+
37+
* bleu_variant: cased (case-sensitive), uncased, both (default).
38+
* translations_dir: Where to store the translated files? Default="translations".
39+
* even_subdir: Where in the model_dir to store the even file? Default="",
40+
which means TensorBoard will show it as the same run as the training, but it will warn
41+
about "more than one metagraph event per run". event_subdir can be used e.g. if running
42+
this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`.
43+
* tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix
44+
can be used e.g. for different beam sizes if these should be plotted in different graphs.
45+
* min_steps: Don't evaluate checkpoints with less steps.
46+
Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps
47+
of the last successfully evaluated checkpoint.
48+
* report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True.
49+
This is useful, so TensorBoard reports correct relative time for the remaining checkpoints.
50+
This flag is set to False if min_steps is > 0.
51+
* wait_secs: Wait upto N seconds for a new checkpoint. Default=0.
52+
This is useful for continuous evaluation of a running training,
53+
in which case this should be equal to save_checkpoints_secs plus some reserve.
54+
"""
55+
from __future__ import absolute_import
56+
from __future__ import division
57+
from __future__ import print_function
58+
import os
59+
import time
60+
from collections import namedtuple
61+
from tensor2tensor.utils import decoding
62+
from tensor2tensor.utils import trainer_utils
63+
from tensor2tensor.utils import usr_dir
64+
from tensor2tensor.utils import bleu_hook
65+
import tensorflow as tf
66+
67+
flags = tf.flags
68+
FLAGS = flags.FLAGS
69+
70+
# t2t-bleu specific options
71+
flags.DEFINE_string("bleu_variant", "both", "Possible values: cased(case-sensitive), uncased, both(default).")
72+
flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.")
73+
flags.DEFINE_string("translation", None, "Path to the MT system translation file")
74+
flags.DEFINE_string("source", None, "Path to the source-language file to be translated")
75+
flags.DEFINE_string("reference", None, "Path to the reference translation file")
76+
flags.DEFINE_string("translations_dir", "translations", "Where to store the translated files")
77+
flags.DEFINE_string("event_subdir", "", "Where in model_dir to store the event file")
78+
flags.DEFINE_string("tag_suffix", "", "What to add to BLEU_cased and BLEU_uncased tags. Default=''.")
79+
flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with less steps.")
80+
flags.DEFINE_integer("wait_secs", 0, "Wait upto N seconds for a new checkpoint, cf. save_checkpoints_secs.")
81+
flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on flags.txt")
82+
83+
# options derived from t2t-decode
84+
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
85+
flags.DEFINE_string("t2t_usr_dir", "",
86+
"Path to a Python module that will be imported. The "
87+
"__init__.py file should include the necessary imports. "
88+
"The imported files should contain registrations, "
89+
"e.g. @registry.register_model calls, that will then be "
90+
"available to the t2t-decoder.")
91+
flags.DEFINE_string("master", "", "Address of TensorFlow master.")
92+
flags.DEFINE_string("schedule", "train_and_evaluate",
93+
"Must be train_and_evaluate for decoding.")
94+
95+
Model = namedtuple('Model', 'filename time steps')
96+
97+
98+
def read_checkpoints_list(model_dir, min_steps):
99+
models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1]))
100+
for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))]
101+
return sorted((x for x in models if x.steps > min_steps), key=lambda x: x.steps)
102+
103+
def main(_):
104+
tf.logging.set_verbosity(tf.logging.INFO)
105+
if FLAGS.translation:
106+
if FLAGS.model_dir:
107+
raise ValueError('Cannot specify both --translation and --model_dir.')
108+
if FLAGS.bleu_variant in ('uncased', 'both'):
109+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False)
110+
print("BLEU_uncased = %6.2f" % bleu)
111+
if FLAGS.bleu_variant in ('cased', 'both'):
112+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True)
113+
print("BLEU_cased = %6.2f" % bleu)
114+
return
115+
116+
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
117+
FLAGS.model = FLAGS.model or 'transformer'
118+
FLAGS.output_dir = FLAGS.model_dir
119+
trainer_utils.log_registry()
120+
trainer_utils.validate_flags()
121+
assert FLAGS.schedule == "train_and_evaluate"
122+
data_dir = os.path.expanduser(FLAGS.data_dir)
123+
model_dir = os.path.expanduser(FLAGS.model_dir)
124+
125+
hparams = trainer_utils.create_hparams(
126+
FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams)
127+
trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
128+
estimator, _ = trainer_utils.create_experiment_components(
129+
data_dir=data_dir,
130+
model_name=FLAGS.model,
131+
hparams=hparams,
132+
run_config=trainer_utils.create_run_config(model_dir))
133+
134+
decode_hp = decoding.decode_hparams(FLAGS.decode_hparams)
135+
decode_hp.add_hparam("shards", FLAGS.decode_shards)
136+
decode_hp.add_hparam("shard_id", FLAGS.worker_id)
137+
138+
os.makedirs(FLAGS.translations_dir, exist_ok=True)
139+
translated_base_file = os.path.join(FLAGS.translations_dir, FLAGS.problems)
140+
event_dir = os.path.join(FLAGS.model_dir, FLAGS.event_subdir)
141+
last_step_file = os.path.join(event_dir, 'last_evaluated_step.txt')
142+
if FLAGS.min_steps == -1:
143+
try:
144+
with open(last_step_file) as ls_file:
145+
FLAGS.min_steps = int(ls_file.read())
146+
except FileNotFoundError:
147+
FLAGS.min_steps = 0
148+
if FLAGS.report_zero is None:
149+
FLAGS.report_zero = FLAGS.min_steps == 0
150+
151+
models = read_checkpoints_list(model_dir, FLAGS.min_steps)
152+
tf.logging.info("Found %d models with steps: %s" % (len(models), ", ".join(str(x.steps) for x in models)))
153+
154+
writer = tf.summary.FileWriter(event_dir)
155+
if FLAGS.report_zero:
156+
start_time = os.path.getctime(os.path.join(model_dir, 'flags.txt'))
157+
values = []
158+
if FLAGS.bleu_variant in ('uncased', 'both'):
159+
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0))
160+
if FLAGS.bleu_variant in ('cased', 'both'):
161+
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0))
162+
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=start_time, step=0))
163+
164+
exit_time = time.time() + FLAGS.wait_secs
165+
min_steps = FLAGS.min_steps
166+
while True:
167+
if not models and FLAGS.wait_secs:
168+
tf.logging.info('All checkpoints evaluated. Waiting till %s if a new checkpoint appears' % time.asctime(time.localtime(exit_time)))
169+
while True:
170+
time.sleep(10)
171+
models = read_checkpoints_list(model_dir, min_steps)
172+
if models or time.time() > exit_time:
173+
break
174+
if not models:
175+
return
176+
177+
model = models.pop(0)
178+
exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps
179+
tf.logging.info("Evaluating " + model.filename)
180+
out_file = translated_base_file + '-' + str(model.steps)
181+
tf.logging.set_verbosity(tf.logging.ERROR) # decode_from_file logs all the translations as INFO
182+
decoding.decode_from_file(estimator, FLAGS.source, decode_hp, out_file, checkpoint_path=model.filename)
183+
tf.logging.set_verbosity(tf.logging.INFO)
184+
values = []
185+
if FLAGS.bleu_variant in ('uncased', 'both'):
186+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False)
187+
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu))
188+
tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu))
189+
if FLAGS.bleu_variant in ('cased', 'both'):
190+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True)
191+
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu))
192+
tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu))
193+
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps))
194+
writer.flush()
195+
with open(last_step_file, 'w') as ls_file:
196+
ls_file.write(str(model.steps) + '\n')
197+
198+
199+
if __name__ == "__main__":
200+
tf.app.run()

tensor2tensor/bin/t2t-datagen

100644100755
File mode changed.

tensor2tensor/bin/t2t-decoder

100644100755
+5-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ import tensorflow as tf
4646
flags = tf.flags
4747
FLAGS = flags.FLAGS
4848

49-
flags.DEFINE_string("output_dir", "", "Training directory to load from.")
49+
flags.DEFINE_string("output_dir", "",
50+
"Training directory where the latest checkpoint is used.")
51+
flags.DEFINE_string("checkpoint_path", None,
52+
"Path to the model checkpoint. Overrides output_dir.")
5053
flags.DEFINE_string("decode_from_file", None,
5154
"Path to the source file for decoding")
5255
flags.DEFINE_string("decode_to_file", None,
@@ -90,7 +93,7 @@ def main(_):
9093
decoding.decode_interactively(estimator, decode_hp)
9194
elif FLAGS.decode_from_file:
9295
decoding.decode_from_file(estimator, FLAGS.decode_from_file, decode_hp,
93-
FLAGS.decode_to_file)
96+
FLAGS.decode_to_file, checkpoint_path=FLAGS.checkpoint_path)
9497
else:
9598
decoding.decode_from_dataset(
9699
estimator,

tensor2tensor/bin/t2t-make-tf-configs

100644100755
File mode changed.

tensor2tensor/bin/t2t-trainer

100644100755
File mode changed.

tensor2tensor/utils/bleu_hook.py

+67-1
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,17 @@
2020

2121
import collections
2222
import math
23+
import re
24+
import sys
25+
import unicodedata
2326

2427
# Dependency imports
2528

2629
import numpy as np
2730
# pylint: disable=redefined-builtin
2831
from six.moves import xrange
2932
from six.moves import zip
33+
import six
3034
# pylint: enable=redefined-builtin
3135

3236
import tensorflow as tf
@@ -92,10 +96,17 @@ def compute_bleu(reference_corpus,
9296
matches_by_order[len(ngram) - 1] += overlap[ngram]
9397
for ngram in translation_ngram_counts:
9498
possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram]
99+
assert reference_length, "no reference provided"
100+
assert translation_length, "no translation provided"
95101
precisions = [0] * max_order
102+
smooth = 1.0
96103
for i in xrange(0, max_order):
97104
if possible_matches_by_order[i] > 0:
98-
precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
105+
if matches_by_order[i] > 0:
106+
precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
107+
else:
108+
smooth *= 2
109+
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
99110
else:
100111
precisions[i] = 0.0
101112

@@ -131,3 +142,58 @@ def bleu_score(predictions, labels, **unused_kwargs):
131142

132143
bleu = tf.py_func(compute_bleu, (labels, outputs), tf.float32)
133144
return bleu, tf.constant(1.0)
145+
146+
147+
class UnicodeRegex:
148+
"""Ad-hoc hack to recognize all punctuation and symbols.
149+
150+
without dependening on https://pypi.python.org/pypi/regex/."""
151+
def _property_chars(prefix):
152+
return ''.join(six.unichr(x) for x in range(sys.maxunicode)
153+
if unicodedata.category(six.unichr(x)).startswith(prefix))
154+
punctuation = _property_chars('P')
155+
nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
156+
punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])')
157+
symbol_re = re.compile('([' + _property_chars('S') + '])')
158+
159+
160+
def bleu_tokenize(string):
161+
r"""Tokenize a string following the official BLEU implementation.
162+
163+
See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
164+
In our case, the input string is expected to be just one line
165+
and no HTML entities de-escaping is needed.
166+
So we just tokenize on punctuation and symbols,
167+
except when a punctuation is preceded and followed by a digit
168+
(e.g. a comma/dot as a thousand/decimal separator).
169+
170+
Note that a numer (e.g. a year) followed by a dot at the end of sentence is NOT tokenized,
171+
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
172+
does not match this case (unless we add a space after each sentence).
173+
However, this error is already in the original mteval-v14.pl
174+
and we want to be consistent with it.
175+
176+
Args:
177+
string: the input string
178+
179+
Returns:
180+
a list of tokens
181+
"""
182+
string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
183+
string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
184+
string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
185+
return string.split()
186+
187+
188+
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
189+
"""Compute BLEU for two files (reference and hypothesis translation)."""
190+
# TODO: Does anyone care about Python2 compatibility?
191+
ref_lines = open(ref_filename, 'rt', encoding='utf-8').read().splitlines()
192+
hyp_lines = open(hyp_filename, 'rt', encoding='utf-8').read().splitlines()
193+
assert len(ref_lines) == len(hyp_lines)
194+
if not case_sensitive:
195+
ref_lines = [x.lower() for x in ref_lines]
196+
hyp_lines = [x.lower() for x in hyp_lines]
197+
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
198+
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
199+
return compute_bleu(ref_tokens, hyp_tokens)

tensor2tensor/utils/bleu_hook_test.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ def testComputeNotEqual(self):
3939
translation_corpus = [[1, 2, 3, 4]]
4040
reference_corpus = [[5, 6, 7, 8]]
4141
bleu = bleu_hook.compute_bleu(reference_corpus, translation_corpus)
42-
actual_bleu = 0.0
43-
self.assertEqual(bleu, actual_bleu)
42+
# The smoothing prevents 0 for small corpora
43+
actual_bleu = 0.0798679
44+
self.assertAllClose(bleu, actual_bleu, atol=1e-03)
4445

4546
def testComputeMultipleBatch(self):
4647
translation_corpus = [[1, 2, 3, 4], [5, 6, 7, 0]]
@@ -53,8 +54,9 @@ def testComputeMultipleNgrams(self):
5354
reference_corpus = [[1, 2, 1, 13], [12, 6, 7, 4, 8, 9, 10]]
5455
translation_corpus = [[1, 2, 1, 3], [5, 6, 7, 4]]
5556
bleu = bleu_hook.compute_bleu(reference_corpus, translation_corpus)
56-
actual_bleu = 0.486
57+
actual_bleu = 0.3436
5758
self.assertAllClose(bleu, actual_bleu, atol=1e-03)
5859

60+
5961
if __name__ == '__main__':
6062
tf.test.main()

tensor2tensor/utils/decoding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def decode_from_dataset(estimator,
200200
tf.logging.info("Completed inference on %d samples." % num_predictions) # pylint: disable=undefined-loop-variable
201201

202202

203-
def decode_from_file(estimator, filename, decode_hp, decode_to_file=None):
203+
def decode_from_file(estimator, filename, decode_hp, decode_to_file=None, checkpoint_path=None):
204204
"""Compute predictions on entries in filename and write them out."""
205205
if not decode_hp.batch_size:
206206
decode_hp.batch_size = 32
@@ -230,7 +230,7 @@ def input_fn():
230230
return _decode_input_tensor_to_features_dict(example, hparams)
231231

232232
decodes = []
233-
result_iter = estimator.predict(input_fn)
233+
result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
234234
for result in result_iter:
235235
if decode_hp.return_beams:
236236
beam_decodes = []

0 commit comments

Comments
 (0)