
t2t-decoder with --checkpoint_path #524

Merged
4 changes: 2 additions & 2 deletions .travis.yml
@@ -6,8 +6,8 @@ before_install:
   - sudo apt-get update -qq
   - sudo apt-get install -qq libhdf5-dev
 install:
-  - pip install tensorflow
-  - pip install .[tests]
+  - pip install -q tensorflow
+  - pip install -q .[tests]
 env:
   global:
     - T2T_PROBLEM=algorithmic_reverse_binary40_test
10 changes: 9 additions & 1 deletion tensor2tensor/bin/t2t_decoder.py
@@ -47,10 +47,14 @@
 FLAGS = flags.FLAGS
 
 # Additional flags in bin/t2t_trainer.py and utils/flags.py
+flags.DEFINE_string("checkpoint_path", None,
+                    "Path to the model checkpoint. Overrides output_dir.")
 flags.DEFINE_string("decode_from_file", None,
                     "Path to the source file for decoding")
 flags.DEFINE_string("decode_to_file", None,
                     "Path to the decoded (output) file")
+flags.DEFINE_bool("keep_timestamp", True,
+                  "Set the mtime of the decoded file to the checkpoint_path+'.index' mtime.")
 flags.DEFINE_bool("decode_interactive", False,
                   "Interactive local inference mode.")
 flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
@@ -76,7 +80,11 @@ def decode(estimator, hparams, decode_hp):
     decoding.decode_interactively(estimator, hparams, decode_hp)
   elif FLAGS.decode_from_file:
     decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
-                              decode_hp, FLAGS.decode_to_file)
+                              decode_hp, FLAGS.decode_to_file,
+                              checkpoint_path=FLAGS.checkpoint_path)
+    if FLAGS.checkpoint_path and FLAGS.keep_timestamp:
+      ckpt_time = os.path.getmtime(FLAGS.checkpoint_path + '.index')
+      os.utime(FLAGS.decode_to_file, (ckpt_time, ckpt_time))
   else:
     decoding.decode_from_dataset(
         estimator,
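An illustrative invocation of the new flag (the paths are made up, and the usual --problem, --model, --hparams_set and --data_dir flags are elided):

t2t-decoder \
  --output_dir=$TRAIN_DIR \
  --checkpoint_path=$TRAIN_DIR/model.ckpt-250000 \
  --decode_from_file=newstest2014.en \
  --decode_to_file=newstest2014.translated

With --keep_timestamp (the default), the mtime of the output file is then set to the mtime of model.ckpt-250000.index, so downstream tools can match translations to checkpoints.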
1 change: 1 addition & 0 deletions tensor2tensor/bin/t2t_translate_all.py
@@ -67,6 +67,7 @@
 
 
 def main(_):
+  FLAGS = flags.FLAGS
Contributor

FLAGS defined above. rm

Contributor Author

I can delete the current line 41 (I hope), but not this one.
FLAGS needs to be a local variable because it is used a few lines below -- https://github.com/martinpopel/tensor2tensor/blob/61005b0b1e063e3e26cb403cf6e8cdea3509c3b5/tensor2tensor/bin/t2t_translate_all.py#L94
Thanks to **locals(), users can provide e.g. --decoder_command='qsub my-decoder.sh {params} --my_param={FLAGS.my_param}', which I think is more user-friendly than forcing them to write --my_param={flags.FLAGS.my_param}.

BTW: Last time someone deleted this line during an internal Google refactor, but didn't notify me during the PR review as you are doing now. It was one of the reasons why t2t_translate_all.py didn't work (another reason being that I was too lazy to write a test).
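A minimal runnable sketch of the substitution being described (DummyFlags stands in for the real flags.FLAGS object; my_param, params, and the command template are illustrative):

# DummyFlags stands in for flags.FLAGS; str.format supports attribute
# access, so {FLAGS.my_param} resolves once FLAGS is a *local* name.
class DummyFlags(object):
  my_param = 0.5

def main(_):
  FLAGS = DummyFlags()      # in the real script: FLAGS = flags.FLAGS
  params = "--beam_size=4"  # illustrative local variable
  decoder_command = "qsub my-decoder.sh {params} --my_param={FLAGS.my_param}"
  # **locals() exposes both `params` and `FLAGS` to the template:
  print(decoder_command.format(**locals()))
  # -> qsub my-decoder.sh --beam_size=4 --my_param=0.5

main(None)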

Contributor

I see. How about this:

locals_and_flags = dict()
locals_and_flags.update(FLAGS.__dict__)
locals_and_flags.update(locals())
...
.format(**locals_and_flags)

Contributor Author
OK. If you feel it's more elegant, I will do it this way (and test).

   tf.logging.set_verbosity(tf.logging.INFO)
   # pylint: disable=unused-variable
   model_dir = os.path.expanduser(FLAGS.model_dir)
9 changes: 7 additions & 2 deletions tensor2tensor/utils/bleu_hook.py
@@ -24,6 +24,7 @@
 import re
 import sys
 import time
+import glob
 import unicodedata
 
 # Dependency imports
@@ -158,6 +159,7 @@ def property_chars(self, prefix):
     return "".join(six.unichr(x) for x in range(sys.maxunicode)
                    if unicodedata.category(six.unichr(x)).startswith(prefix))
 
+uregex = UnicodeRegex()
Contributor
why is this top level now? would prefer to leave it where it is

Contributor Author

Creating the UnicodeRegex instance is quite slow, because I need to iterate twice over all Unicode characters.
It is a really bad idea to create a new instance for each line of the test set being tokenized: it made the whole t2t-bleu.py about 10000 times slower (depending on the test set size).
As I note in my commit log, I would slightly prefer my original implementation with a class containing static methods only (the class served rather as a namespace).
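A standalone sketch of the cost in question (plain chr is used here instead of six.unichr; everything else mirrors the property_chars method above):

import sys
import unicodedata

def property_chars(prefix):
  # Scans every Unicode code point (~1.1M of them) -- this is the slow part.
  return "".join(chr(x) for x in range(sys.maxunicode)
                 if unicodedata.category(chr(x)).startswith(prefix))

# Do this once at module load...
punctuation = property_chars("P")
# ...not once per tokenized line, or regex construction dominates the runtime.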

Contributor
ah ok. thanks. fine as is then.


 def bleu_tokenize(string):
   r"""Tokenize a string following the official BLEU implementation.
@@ -183,7 +185,6 @@ def bleu_tokenize(string):
   Returns:
     a list of tokens
   """
-  uregex = UnicodeRegex()
   string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
   string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
   string = uregex.symbol_re.sub(r" \1 ", string)
@@ -209,7 +210,11 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
 def _read_stepfiles_list(path_prefix, path_suffix=".index", min_steps=0):
   """Return list of StepFiles sorted by step from files at path_prefix."""
   stepfiles = []
-  for filename in tf.gfile.Glob(path_prefix + "*-[0-9]*" + path_suffix):
+  # tf.gfile.Glob may crash with
+  # tensorflow.python.framework.errors_impl.NotFoundError:
+  # xy/model.ckpt-1130761_temp_9cb4cb0b0f5f4382b5ea947aadfb7a40; No such file or directory
+  # Let's use standard glob.glob instead.
+  for filename in glob.glob(path_prefix + '*-[0-9]*' + path_suffix):
Contributor
would prefer tf.gfile.Glob. why is it crashing?

Contributor Author

When saving a checkpoint, TensorFlow always creates a temp file first (which unfortunately matches the given wildcard pattern, but I think that is not important here).
This temporary file is deleted soon afterwards, but if that happens while tf.gfile.Glob is running, the call crashes.
It crashes inside tf.gfile.Glob itself (i.e. not within the for loop in my code) with tensorflow.python.framework.errors_impl.NotFoundError, which is difficult to handle (I could catch the exception and retry the same call, but that's too hacky).
According to my extensive tests, glob.glob does not have this non-deterministic bug.

Contributor

the benefit of tf.gfile is that it supports many underlying filesystems, including Google Cloud Storage and internal filesystems. glob won't know what to do with a path that starts with gs://.

would updating the pattern to 'ckpt-[0-9]*.index' work?
i'm thinking maybe not, because it's probably failing on an IsDirectory call.

i'm inclined to advocate for a retrying_glob function to maintain portability across filesystems, though i agree it's hacky.

Contributor Author

I know the benefit of tf.gfile, and that was the reason why I tried it first.
I would prefer a different solution (see the sketch below):

  • import a tf.gfile wrapper once which overrides glob.glob, and keep glob.glob everywhere in the code
  • make the wrapper detect special filesystems (gs://), but otherwise fall back to the standard (and reliable) glob.glob

If this is not acceptable, I will catch the exception and retry the call.
Even if the bug is rare, it is annoying when running many week-long experiments and some of the evaluation jobs crash.
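A minimal sketch of the proposed fallback (the name portable_glob and the gs:// check are illustrative; this is the author's proposal here, not what was merged):

import glob
import tensorflow as tf

def portable_glob(pattern):
  # Special filesystems (e.g. Google Cloud Storage) need tf.gfile;
  # everything else uses the standard, reliable glob.glob.
  if pattern.startswith("gs://"):
    return tf.gfile.Glob(pattern)
  return glob.glob(pattern)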

would updating the pattern to 'ckpt-[0-9]*.index' work

No. I also use stepfiles_iterator in t2t-bleu, where the files being iterated are not checkpoints but translations of the test set, which I decided should have no extension (though path_suffix can be e.g. .txt as well). For this reason I call them stepfiles and not checkpoints.

Contributor

Let's have the retrying tf.gfile.Glob with exception catching. Monkey-patching a lib would be really confusing for people reading the code and not understanding why something is happening.
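A minimal sketch of such a retrying glob, assuming the TF 1.x API (the function name, retry count, and delay are illustrative):

import time
import tensorflow as tf

def retrying_glob(pattern, retries=5, delay=0.5):
  # Retry on NotFoundError, which tf.gfile.Glob raises when a checkpoint
  # temp file matching the pattern is deleted mid-listing.
  for attempt in range(retries):
    try:
      return tf.gfile.Glob(pattern)
    except tf.errors.NotFoundError:
      if attempt == retries - 1:
        raise
      time.sleep(delay)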

     basename = filename[:-len(path_suffix)] if len(path_suffix) else filename
     try:
       steps = int(basename.rsplit("-")[-1])
5 changes: 3 additions & 2 deletions tensor2tensor/utils/decoding.py
@@ -219,7 +219,8 @@ def decode_from_file(estimator,
                      filename,
                      hparams,
                      decode_hp,
-                     decode_to_file=None):
+                     decode_to_file=None,
+                     checkpoint_path=None):
   """Compute predictions on entries in filename and write them out."""
   if not decode_hp.batch_size:
     decode_hp.batch_size = 32
@@ -248,7 +249,7 @@ def input_fn():
     return _decode_input_tensor_to_features_dict(example, hparams)
 
   decodes = []
-  result_iter = estimator.predict(input_fn)
+  result_iter = estimator.predict(input_fn, checkpoint_path=checkpoint_path)
   for result in result_iter:
     if decode_hp.return_beams:
       beam_decodes = []
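This change works because tf.estimator.Estimator.predict itself accepts a checkpoint_path argument and falls back to the latest checkpoint in model_dir when it is None. A hypothetical direct call with an explicit checkpoint (the path is made up):

result_iter = estimator.predict(
    input_fn,
    checkpoint_path="/tmp/t2t_train/model.ckpt-250000")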