Skip to content

Commit 697e506

Browse files
committed
option to count bleu after my own postprocessing
requires this to be added to the registered Translate problem class in t2t_usr_dir: `def postprocess(self, text): return re.sub("@@ ","",text)` (example for standard BPE postprocessing) and `def needs_postprocessing(self): return True`
1 parent 322a8ce commit 697e506

File tree

1 file changed

+71
-17
lines changed

1 file changed

+71
-17
lines changed

tensor2tensor/bin/t2t-bleu

+71-17
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,18 @@ from tensor2tensor.utils import decoding
6262
from tensor2tensor.utils import trainer_utils
6363
from tensor2tensor.utils import usr_dir
6464
from tensor2tensor.utils import bleu_hook
65+
from tensor2tensor.utils import registry
66+
from tensor2tensor import _set_time_logging
67+
6568
import tensorflow as tf
6669

6770
flags = tf.flags
6871
FLAGS = flags.FLAGS
6972

7073
# t2t-bleu specific options
7174
flags.DEFINE_string("bleu_variant", "both", "Possible values: cased(case-sensitive), uncased, both(default).")
75+
flags.DEFINE_bool("postprocess", True, "Postprocess translation and reference before calculating BLEU. True, False(default).")
76+
flags.DEFINE_string("postprocess_suffix", ".post", "Possible values: True, False(default).")
7277
flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.")
7378
flags.DEFINE_string("translation", None, "Path to the MT system translation file")
7479
flags.DEFINE_string("source", None, "Path to the source-language file to be translated")
@@ -92,28 +97,60 @@ flags.DEFINE_string("master", "", "Address of TensorFlow master.")
9297
flags.DEFINE_string("schedule", "train_and_evaluate",
9398
"Must be train_and_evaluate for decoding.")
9499

95-
Model = namedtuple('Model', 'filename time steps')
96-
100+
Model = namedtuple('Model', 'filename time steps')
97101

98102
def read_checkpoints_list(model_dir, min_steps):
99103
models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1]))
100104
for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))]
101105
return sorted((x for x in models if x.steps > min_steps), key=lambda x: x.steps)
102106

107+
def postprocess(pre, post, problem):
108+
if tf.gfile.Exists(post): return
109+
with open(pre, "r", encoding="utf-8") as o:
110+
with open(post, "w", encoding="utf-8") as p:
111+
for _ in range(10): tf.logging.info("postprocessing file %s" % post)
112+
p.write(problem.postprocess(o.read()))
113+
114+
def postprocess_maybe_add_suffix(filename, problem):
115+
# postprocess reference or translation file, if needed
116+
if not filename.endswith(FLAGS.postprocess_suffix):
117+
# this creates a new file with ".post" suffix (by default) in the same directory as reference
118+
post = filename + FLAGS.postprocess_suffix
119+
if not tf.gfile.Exists(post):
120+
postprocess(filename, post, problem)
121+
return post
122+
return filename
123+
103124
def main(_):
125+
_set_time_logging()
126+
104127
tf.logging.set_verbosity(tf.logging.INFO)
105-
if FLAGS.translation:
128+
129+
if FLAGS.translation: ## TODO: this variant is not tested
106130
if FLAGS.model_dir:
107131
raise ValueError('Cannot specify both --translation and --model_dir.')
108-
if FLAGS.bleu_variant in ('uncased', 'both'):
109-
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False)
110-
print("BLEU_uncased = %6.2f" % bleu)
111-
if FLAGS.bleu_variant in ('cased', 'both'):
112-
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True)
113-
print("BLEU_cased = %6.2f" % bleu)
132+
133+
def count_bleu(ref, trans, ptag=""):
134+
if FLAGS.bleu_variant in ('uncased', 'both'):
135+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=False)
136+
print("BLEU_uncased%s = %6.2f" % (ptag, bleu))
137+
if FLAGS.bleu_variant in ('cased', 'both'):
138+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, FLAGS.translation, case_sensitive=True)
139+
print("BLEU_cased%s = %6.2f" % (ptag, bleu))
140+
141+
if FLAGS.postprocess:
142+
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
143+
problem = registry.problem(FLAGS.problems)
144+
ref_post = postprocess_maybe_add_suffix(FLAGS.reference)
145+
ref_trans = postprocess_maybe_add_suffix(FLAGS.translation)
146+
count_bleu(ref_post, ref_trans, ptag="_post")
147+
else:
148+
count_bleu(FLAGS.reference, FLAGS.translation, ptag="")
114149
return
115150

116151
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
152+
problem = registry.problem(FLAGS.problems)
153+
117154
FLAGS.model = FLAGS.model or 'transformer'
118155
FLAGS.output_dir = FLAGS.model_dir
119156
trainer_utils.log_registry()
@@ -177,19 +214,36 @@ def main(_):
177214
model = models.pop(0)
178215
exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps
179216
tf.logging.info("Evaluating " + model.filename)
217+
180218
out_file = translated_base_file + '-' + str(model.steps)
219+
181220
tf.logging.set_verbosity(tf.logging.ERROR) # decode_from_file logs all the translations as INFO
182221
decoding.decode_from_file(estimator, FLAGS.source, decode_hp, out_file, checkpoint_path=model.filename)
183222
tf.logging.set_verbosity(tf.logging.INFO)
223+
224+
post_out_file = out_file + FLAGS.postprocess_suffix
225+
if problem.needs_postprocessing and FLAGS.postprocess:
226+
post_out_file = postprocess_maybe_add_suffix(out_file, problem)
227+
else:
228+
post_out_file = out_file
229+
230+
post_reference = postprocess_maybe_add_suffix(FLAGS.reference, problem)
231+
184232
values = []
185-
if FLAGS.bleu_variant in ('uncased', 'both'):
186-
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False)
187-
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu))
188-
tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu))
189-
if FLAGS.bleu_variant in ('cased', 'both'):
190-
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True)
191-
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu))
192-
tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu))
233+
def count_bleu(ref, out, ptag=""):
234+
if FLAGS.bleu_variant in ('uncased', 'both'):
235+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False)
236+
values.append(tf.Summary.Value(tag='BLEU_uncased' + ptag + FLAGS.tag_suffix, simple_value=bleu))
237+
tf.logging.info("%s: BLEU_uncased%s%s = %6.2f" % (model.filename, ptag, FLAGS.tag_suffix, bleu))
238+
if FLAGS.bleu_variant in ('cased', 'both'):
239+
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True)
240+
values.append(tf.Summary.Value(tag='BLEU_cased' + ptag + FLAGS.tag_suffix, simple_value=bleu))
241+
tf.logging.info("%s: BLEU_uncased%s%s = %6.2f" % (model.filename, ptag, FLAGS.tag_suffix, bleu))
242+
if FLAGS.postprocess:
243+
count_bleu(post_reference, post_out_file, ptag="_post")
244+
# else: ## TODO: else or not ????
245+
count_bleu(FLAGS.reference, out_file, ptag="")
246+
193247
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps))
194248
writer.flush()
195249
with open(last_step_file, 'w') as ls_file:

0 commit comments

Comments
 (0)