Skip to content

Commit bb3c8e6

Browse files
committed
allow specifying --checkpoint_path with t2t-decoder
and allow keeping timestamp in that case. This is needed for t2t-translate-all + t2t-bleu to work as expected (I forgot to add this commit to tensorflow#488).
1 parent 0ccc562 commit bb3c8e6

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

tensor2tensor/bin/t2t_decoder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@
4747
FLAGS = flags.FLAGS
4848

4949
# Additional flags in bin/t2t_trainer.py and utils/flags.py
50+
flags.DEFINE_string("checkpoint_path", None,
51+
"Path to the model checkpoint. Overrides output_dir.")
5052
flags.DEFINE_string("decode_from_file", None,
5153
"Path to the source file for decoding")
5254
flags.DEFINE_string("decode_to_file", None,
5355
"Path to the decoded (output) file")
56+
flags.DEFINE_bool("keep_timestamp", True,
57+
"Set the mtime of the decoded file to the checkpoint_path+'.index' mtime.")
5458
flags.DEFINE_bool("decode_interactive", False,
5559
"Interactive local inference mode.")
5660
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
@@ -76,7 +80,11 @@ def decode(estimator, hparams, decode_hp):
7680
decoding.decode_interactively(estimator, hparams, decode_hp)
7781
elif FLAGS.decode_from_file:
7882
decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
79-
decode_hp, FLAGS.decode_to_file)
83+
decode_hp, FLAGS.decode_to_file,
84+
checkpoint_path=FLAGS.checkpoint_path)
85+
if FLAGS.checkpoint_path and FLAGS.keep_timestamp:
86+
ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + '.index')
87+
os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
8088
else:
8189
decoding.decode_from_dataset(
8290
estimator,

tensor2tensor/utils/decoding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ def decode_from_file(estimator,
219219
filename,
220220
hparams,
221221
decode_hp,
222-
decode_to_file=None):
222+
decode_to_file=None,
223+
checkpoint_path=None):
223224
"""Compute predictions on entries in filename and write them out."""
224225
if not decode_hp.batch_size:
225226
decode_hp.batch_size = 32
@@ -248,7 +249,7 @@ def input_fn():
248249
return _decode_input_tensor_to_features_dict(example, hparams)
249250

250251
decodes = []
251-
result_iter = estimator.predict(input_fn)
252+
result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
252253
for result in result_iter:
253254
if decode_hp.return_beams:
254255
beam_decodes = []

0 commit comments

Comments
 (0)