-
Notifications
You must be signed in to change notification settings - Fork 3.6k
t2t-decoder with --checkpoint_path #524
Changes from 6 commits
0ccc562
bb3c8e6
3754f32
c4a6549
747d429
61005b0
9ca3087
bc09d9a
d394be3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import re | ||
import sys | ||
import time | ||
import glob | ||
import unicodedata | ||
|
||
# Dependency imports | ||
|
@@ -158,6 +159,7 @@ def property_chars(self, prefix): | |
return "".join(six.unichr(x) for x in range(sys.maxunicode) | ||
if unicodedata.category(six.unichr(x)).startswith(prefix)) | ||
|
||
uregex = UnicodeRegex() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this top level now? would prefer to leave it where it is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah ok. thanks. fine as is then. |
||
|
||
def bleu_tokenize(string): | ||
r"""Tokenize a string following the official BLEU implementation. | ||
|
@@ -183,7 +185,6 @@ def bleu_tokenize(string): | |
Returns: | ||
a list of tokens | ||
""" | ||
uregex = UnicodeRegex() | ||
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string) | ||
string = uregex.punct_nondigit_re.sub(r" \1 \2", string) | ||
string = uregex.symbol_re.sub(r" \1 ", string) | ||
|
@@ -209,7 +210,11 @@ def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False): | |
def _read_stepfiles_list(path_prefix, path_suffix=".index", min_steps=0): | ||
"""Return list of StepFiles sorted by step from files at path_prefix.""" | ||
stepfiles = [] | ||
for filename in tf.gfile.Glob(path_prefix + "*-[0-9]*" + path_suffix): | ||
# tf.gfile.Glob may crash with | ||
# tensorflow.python.framework.errors_impl.NotFoundError: | ||
# xy/model.ckpt-1130761_temp_9cb4cb0b0f5f4382b5ea947aadfb7a40; No such file or directory | ||
# Let's use standard glob.glob instead. | ||
for filename in glob.glob(path_prefix + '*-[0-9]*' + path_suffix): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would prefer tf.gfile.Glob. why is it crashing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When saving a checkpoint, TensorFlow always creates a temp file first (which unfortunately matches the given wildcard pattern, but this is not important here I think). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the benefit of tf.gfile is that it will support many underlying filesystems, including Google cloud storage and internal filesystems. glob won't know what to do with a path that starts with gs://. would updating the pattern to 'ckpt-[0-9]*.index' work? i'm inclined to advocate for a retrying_glob function to maintain portability across filesystems though i agree it's hacky. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know the benefit of tf.gfile and this was the reason why I tried it first.
If this is not acceptable, I will catch the exception and call again.
No. I use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's have the retrying tf.gfile.Glob with the exception catching. Monkey-patching a lib would be really confusing for people reading the code and not understanding why something is happening. |
||
basename = filename[:-len(path_suffix)] if len(path_suffix) else filename | ||
try: | ||
steps = int(basename.rsplit("-")[-1]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FLAGS defined above. rm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can delete the current line 41 (I hope), but not this one.
FLAGS need to be a local variable because they are used few lines below -- https://github.com/martinpopel/tensor2tensor/blob/61005b0b1e063e3e26cb403cf6e8cdea3509c3b5/tensor2tensor/bin/t2t_translate_all.py#L94
Thanks to
**locals()
, the users can provide e.g--decoder_command='qsub my-decoder.sh {params} --my_param={FLAGS.my_param}'
and I think it is more user-friendly than forcing them to write--my_param={flags.FLAGS.my_param}
.BTW: Last time someone deleted this line during an internal-Google refactor, but didn't noticed me within PR review as you do now. It was one of the reasons why
t2t_translate_all.py
didn't work (another reason was I was lazy to write a test).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. How about this?:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. If you feel it's more elegant, I will do it this way (and test).