Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit a655d49

Browse files
committed
more options for t2t-bleu
So it can be used for continuous evaluation or for resuming older evaluation from a checkpoint with a given number of steps. It is also possible to specify the name of the events subdirectory and tag suffix.
1 parent 0ca4f2b commit a655d49

File tree

1 file changed

+64
-11
lines changed

1 file changed

+64
-11
lines changed

tensor2tensor/bin/t2t-bleu

+64-11
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,27 @@ To evaluate all checkpoints in a given directory:
3030
--hparams_set=transformer_big_single_gpu
3131
--source=wmt13_deen.en
3232
--reference=wmt13_deen.de`
33+
34+
In addition to the above-mentioned compulsory parameters,
35+
there are optional parameters:
36+
37+
* bleu_variant: cased (case-sensitive), uncased, both (default).
38+
* translations_dir: Where to store the translated files? Default="translations".
39+
* event_subdir: Where in the model_dir to store the event file? Default="",
40+
which means TensorBoard will show it as the same run as the training, but it will warn
41+
about "more than one metagraph event per run". event_subdir can be used e.g. if running
42+
this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`.
43+
* tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix
44+
can be used e.g. for different beam sizes if these should be plotted in different graphs.
45+
* min_steps: Don't evaluate checkpoints with fewer steps.
46+
Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps
47+
of the last successfully evaluated checkpoint.
48+
* report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True.
49+
This is useful, so TensorBoard reports correct relative time for the remaining checkpoints.
50+
This flag is set to False if min_steps is > 0.
51+
* wait_secs: Wait up to N seconds for a new checkpoint. Default=0.
52+
This is useful for continuous evaluation of a running training,
53+
in which case this should be equal to save_checkpoints_secs plus some reserve.
3354
"""
3455
from __future__ import absolute_import
3556
from __future__ import division
@@ -53,7 +74,11 @@ flags.DEFINE_string("translation", None, "Path to the MT system translation file
5374
flags.DEFINE_string("source", None, "Path to the source-language file to be translated")
5475
flags.DEFINE_string("reference", None, "Path to the reference translation file")
5576
flags.DEFINE_string("translations_dir", "translations", "Where to store the translated files")
56-
flags.DEFINE_bool("report_zero", True, "Store BLEU=0 and guess its time via flags.txt")
77+
flags.DEFINE_string("event_subdir", "", "Where in model_dir to store the event file")
78+
flags.DEFINE_string("tag_suffix", "", "What to add to BLEU_cased and BLEU_uncased tags. Default=''.")
79+
flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with less steps.")
80+
flags.DEFINE_integer("wait_secs", 0, "Wait upto N seconds for a new checkpoint, cf. save_checkpoints_secs.")
81+
flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on flags.txt")
5782

5883
# options derived from t2t-decode
5984
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
@@ -70,6 +95,11 @@ flags.DEFINE_string("schedule", "train_and_evaluate",
7095
Model = namedtuple('Model', 'filename time steps')
7196

7297

98+
def read_checkpoints_list(model_dir, min_steps):
  """List checkpoints in model_dir with more than min_steps, ordered by step count.

  Args:
    model_dir: directory containing `model.ckpt-*.index` files.
    min_steps: checkpoints with step count <= min_steps are excluded.

  Returns:
    A list of Model(filename, time, steps) namedtuples sorted by steps, where
    filename is the checkpoint path without the '.index' suffix and time is
    the index file's creation time.
  """
  pattern = os.path.join(model_dir, 'model.ckpt-*.index')
  checkpoints = []
  for index_file in tf.gfile.Glob(pattern):
    prefix = index_file[:-6]  # strip the trailing '.index'
    steps = int(prefix.rsplit('-')[-1])
    if steps > min_steps:
      checkpoints.append(Model(prefix, os.path.getctime(index_file), steps))
  checkpoints.sort(key=lambda ckpt: ckpt.steps)
  return checkpoints
102+
73103
def main(_):
74104
tf.logging.set_verbosity(tf.logging.INFO)
75105
if FLAGS.translation:
@@ -107,22 +137,43 @@ def main(_):
107137

108138
os.makedirs(FLAGS.translations_dir, exist_ok=True)
109139
translated_base_file = os.path.join(FLAGS.translations_dir, FLAGS.problems)
110-
models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1]))
111-
for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))]
112-
models = sorted(models, key=lambda x: x.time)
140+
event_dir = os.path.join(FLAGS.model_dir, FLAGS.event_subdir)
141+
last_step_file = os.path.join(event_dir, 'last_evaluated_step.txt')
142+
if FLAGS.min_steps == -1:
143+
try:
144+
with open(last_step_file) as ls_file:
145+
FLAGS.min_steps = int(ls_file.read())
146+
except FileNotFoundError:
147+
FLAGS.min_steps = 0
148+
if FLAGS.report_zero is None:
149+
FLAGS.report_zero = FLAGS.min_steps == 0
150+
151+
models = read_checkpoints_list(model_dir, FLAGS.min_steps)
113152
tf.logging.info("Found %d models with steps: %s" % (len(models), ", ".join(str(x.steps) for x in models)))
114153

115-
writer = tf.summary.FileWriter(FLAGS.model_dir)
154+
writer = tf.summary.FileWriter(event_dir)
116155
if FLAGS.report_zero:
117156
start_time = os.path.getctime(os.path.join(model_dir, 'flags.txt'))
118157
values = []
119158
if FLAGS.bleu_variant in ('uncased', 'both'):
120-
values.append(tf.Summary.Value(tag='BLEU_uncased', simple_value=0))
159+
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0))
121160
if FLAGS.bleu_variant in ('cased', 'both'):
122-
values.append(tf.Summary.Value(tag='BLEU_cased', simple_value=0))
161+
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0))
123162
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=start_time, step=0))
124163

125-
for model in models:
164+
exit_time = time.time() + FLAGS.wait_secs
165+
min_steps = FLAGS.min_steps
166+
while True:
167+
if not models and FLAGS.wait_secs:
168+
tf.logging.info('All checkpoints evaluated. Waiting till %s if a new checkpoint appears' % time.asctime(time.localtime(exit_time)))
169+
while not models and time.time() < exit_time:
170+
time.sleep(10)
171+
models = read_checkpoints_list(model_dir, min_steps)
172+
if not models:
173+
return
174+
175+
model = models.pop(0)
176+
exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps
126177
tf.logging.info("Evaluating " + model.filename)
127178
out_file = translated_base_file + '-' + str(model.steps)
128179
tf.logging.set_verbosity(tf.logging.ERROR) # decode_from_file logs all the translations as INFO
@@ -131,15 +182,17 @@ def main(_):
131182
values = []
132183
if FLAGS.bleu_variant in ('uncased', 'both'):
133184
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False)
134-
values.append(tf.Summary.Value(tag='BLEU_uncased', simple_value=bleu))
185+
values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu))
135186
tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu))
136187
if FLAGS.bleu_variant in ('cased', 'both'):
137188
bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True)
138-
values.append(tf.Summary.Value(tag='BLEU_cased', simple_value=bleu))
189+
values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu))
139190
tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu))
140191
writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps))
192+
writer.flush()
193+
with open(last_step_file, 'w') as ls_file:
194+
ls_file.write(str(model.steps) + '\n')
141195

142-
writer.flush()
143196

144197
if __name__ == "__main__":
145198
tf.app.run()

0 commit comments

Comments
 (0)