
Commit e3cd447 (2 parents: 92983ea + 69701e4)

Merge pull request #449 from rsepassi/push

v1.3


59 files changed: +2656 -2105 lines

setup.py (+1 -1)

@@ -5,7 +5,7 @@
 
 setup(
     name='tensor2tensor',
-    version='1.2.9',
+    version='1.3.0',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',

tensor2tensor/bin/t2t-datagen (+1 -1, mode changed 100755 → 100644)

@@ -112,7 +112,7 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
     lambda: audio.timit_generator(
         FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-        vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
+        vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
 }
 
 # pylint: enable=g-long-lambda

tensor2tensor/bin/t2t-decoder (+4 -2)

@@ -47,8 +47,10 @@ flags = tf.flags
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string("output_dir", "", "Training directory to load from.")
-flags.DEFINE_string("decode_from_file", None, "Path to the source file for decoding")
-flags.DEFINE_string("decode_to_file", None, "Path to the decoded (output) file")
+flags.DEFINE_string("decode_from_file", None,
+                    "Path to the source file for decoding")
+flags.DEFINE_string("decode_to_file", None,
+                    "Path to the decoded (output) file")
 flags.DEFINE_bool("decode_interactive", False,
                   "Interactive local inference mode.")
 flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")

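The reflowed flags above define the file-based decoding interface: decode_from_file names a file of source lines and decode_to_file receives one decoded line per source line. A minimal sketch of that contract, assuming a hypothetical decode_line callable (source string to decoded string) standing in for the real decoding path, which is not shown in this diff:

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("decode_from_file", None,
                    "Path to the source file for decoding")
flags.DEFINE_string("decode_to_file", None,
                    "Path to the decoded (output) file")


def decode_file(decode_line):
  # decode_line is a hypothetical callable (source string -> decoded
  # string); the real decoding logic lives elsewhere in tensor2tensor.
  with tf.gfile.Open(FLAGS.decode_from_file) as fin:
    with tf.gfile.Open(FLAGS.decode_to_file, "w") as fout:
      for source in fin:
        fout.write(decode_line(source.strip()) + "\n")
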
tensor2tensor/bin/t2t-trainer (+10 -11)

@@ -61,6 +61,7 @@ flags.DEFINE_string("schedule", "train_and_evaluate",
                     "Method of tf.contrib.learn.Experiment to run.")
 flags.DEFINE_bool("profile", False, "Profile performance?")
 
+
 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
@@ -85,22 +86,20 @@ def main(_):
   # Run the trainer.
   def run_experiment():
     trainer_utils.run(
-        data_dir=data_dir,
-        model=FLAGS.model,
-        output_dir=output_dir,
-        train_steps=FLAGS.train_steps,
-        eval_steps=FLAGS.eval_steps,
-        schedule=FLAGS.schedule)
-
+        data_dir=data_dir,
+        model=FLAGS.model,
+        output_dir=output_dir,
+        train_steps=FLAGS.train_steps,
+        eval_steps=FLAGS.eval_steps,
+        schedule=FLAGS.schedule)
+
   if FLAGS.profile:
-    with tf.contrib.tfprof.ProfileContext('t2tprof',
+    with tf.contrib.tfprof.ProfileContext("t2tprof",
                                           trace_steps=range(100),
                                           dump_steps=range(100)) as pctx:
       opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
-      pctx.add_auto_profiling('op', opts, range(100))
-
+      pctx.add_auto_profiling("op", opts, range(100))
     run_experiment()
-
   else:
     run_experiment()

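The profiling branch wraps the experiment in tf.contrib.tfprof.ProfileContext so the first 100 steps are traced and dumped, with per-op time and memory collected automatically. The same pattern in isolation, assuming a stand-in train_fn where the trainer calls run_experiment:

import tensorflow as tf


def profiled_run(train_fn, profile_dir="t2tprof"):
  # Trace and dump the first 100 steps, collecting per-op time and memory.
  # train_fn stands in for run_experiment above; profile_dir is where
  # ProfileContext writes its dump files.
  with tf.contrib.tfprof.ProfileContext(profile_dir,
                                        trace_steps=range(100),
                                        dump_steps=range(100)) as pctx:
    opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
    pctx.add_auto_profiling("op", opts, range(100))
    train_fn()
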
tensor2tensor/data_generators/cnn_dailymail.py (+19 -9)

@@ -19,10 +19,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import hashlib
 import io
 import os
 import tarfile
-import hashlib
 
 # Dependency imports
 
@@ -129,7 +129,9 @@ def generate_hash(inp):
 
   return filelist
 
+
 def example_generator(all_files, urls_path, sum_token):
+  """Generate examples."""
   def fix_run_on_sents(line):
     if u"@highlight" in line:
       return line
@@ -168,30 +170,37 @@ def fix_run_on_sents(line):
 
   yield " ".join(story) + story_summary_split_token + " ".join(summary)
 
+
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
   return story[:split_pos], story[split_pos+split_str_len:]  # story, summary
 
-def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
+
+def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir,
+                            is_training):
+  """Write text to files."""
   def write_to_file(all_files, urls_path, data_dir, filename):
-    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
-      for example in example_generator(all_files, urls_path, sum_token=True):
-        story, summary = _story_summary_split(example)
-        fstory.write(story+"\n")
-        fsummary.write(summary+"\n")
+    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory:
+      with io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
+        for example in example_generator(all_files, urls_path, sum_token=True):
+          story, summary = _story_summary_split(example)
+          fstory.write(story+"\n")
+          fsummary.write(summary+"\n")
 
   filename = "cnndm.train" if is_training else "cnndm.dev"
   tf.logging.info("Writing %s" % filename)
   write_to_file(all_files, urls_path, data_dir, filename)
 
   if not is_training:
-    test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
+    test_urls_path = generator_utils.maybe_download(
+        tmp_dir, "all_test.txt", _TEST_URLS)
     filename = "cnndm.test"
     tf.logging.info("Writing %s" % filename)
     write_to_file(all_files, test_urls_path, data_dir, filename)
 
+
 @registry.register_problem
 class SummarizeCnnDailymail32k(problem.Text2TextProblem):
   """Summarize CNN and Daily Mail articles to their summary highlights."""
@@ -237,7 +246,8 @@ def generator(self, data_dir, tmp_dir, is_training):
     encoder = generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
         example_generator(all_files, urls_path, sum_token=False))
-    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
+    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir,
+                            is_training)
     for example in example_generator(all_files, urls_path, sum_token=True):
       story, summary = _story_summary_split(example)
       encoded_summary = encoder.encode(summary) + [EOS]

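The generator emits each example as a single string with the story and summary joined by the " <summary> " token, and _story_summary_split recovers the two halves by searching for that token. A toy round trip, with the function copied from the file above and an invented example string:

def _story_summary_split(story):
  split_str = u" <summary> "
  split_str_len = len(split_str)
  split_pos = story.find(split_str)
  return story[:split_pos], story[split_pos+split_str_len:]  # story, summary


# Invented example in the generator's joined format.
example = u"The article text. <summary> The highlight sentence."
story, summary = _story_summary_split(example)
assert story == u"The article text."
assert summary == u"The highlight sentence."
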
tensor2tensor/data_generators/generator_utils.py (+65)

@@ -447,3 +447,68 @@ def shuffle_dataset(filenames):
     out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
     write_records(records, out_fname)
     tf.gfile.Remove(fname)
+
+
+def combine_examples_no_inputs(examples, max_length):
+  """Combine examples into longer examples.
+
+  Concatenate targets to form target sequences with length up to max_length.
+  Target sequences longer than max_length are chopped into multiple sequences.
+
+  Args:
+    examples: a generator returning feature dictionaries.
+    max_length: an integer.
+
+  Yields:
+    feature dictionaries.
+  """
+  partial = []
+  for example in examples:
+    x = example["targets"]
+    if len(x) + len(partial) > max_length:
+      if partial:
+        yield {"inputs": [0], "targets": partial}
+        partial = []
+    if len(x) > max_length:
+      num_fragments = len(x) // max_length
+      for i in xrange(num_fragments):
+        yield {"inputs": [0], "targets": x[max_length * i:max_length * (i + 1)]}
+      partial = x[max_length * num_fragments:]
+    else:
+      partial += x
+  if partial:
+    yield {"inputs": [0], "targets": partial}
+
+
+def combine_examples_with_inputs(examples, max_length):
+  """Combine examples into longer examples.
+
+  We combine multiple examples by concatenating the inputs and concatenating
+  the targets. Sequences where the inputs or the targets are too long are
+  emitted as singletons (not chopped).
+
+  Args:
+    examples: a generator returning feature dictionaries.
+    max_length: an integer.
+
+  Yields:
+    feature dictionaries.
+  """
+  partial_a = []
+  partial_b = []
+  for example in examples:
+    a = example["inputs"]
+    b = example["targets"]
+    if (len(a) + len(partial_a) > max_length or
+        len(b) + len(partial_b) > max_length):
+      if partial_a or partial_b:
+        yield {"inputs": partial_a, "targets": partial_b}
+        partial_a = []
+        partial_b = []
+    if len(a) > max_length or len(b) > max_length:
+      yield {"inputs": a, "targets": b}
+    else:
+      partial_a += a
+      partial_b += b
+  if partial_a or partial_b:
+    yield {"inputs": partial_a, "targets": partial_b}

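Both combiners pack feature dictionaries greedily, flushing the partial example whenever the next one would push it past max_length. A toy run of combine_examples_no_inputs with invented integer token lists, assuming this version of the module is importable (the code is Python 2 era, matching the xrange in the diff):

from tensor2tensor.data_generators import generator_utils

examples = [
    {"inputs": [0], "targets": [1, 2, 3]},
    {"inputs": [0], "targets": [4, 5]},
    {"inputs": [0], "targets": [6, 7, 8, 9, 10, 11, 12]},
]

# With max_length=6: the first two target lists pack into [1, 2, 3, 4, 5];
# the 7-token example forces a flush, is chopped into [6, 7, 8, 9, 10, 11],
# and its leftover [12] is emitted last.
for packed in generator_utils.combine_examples_no_inputs(
    iter(examples), max_length=6):
  print(packed)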