
Commit 01b8c31

Lukasz Kaiser authored and Ryan Sepassi committed
CHECKPOINT BREAKING: make T2TModel a subclass of Layer so it can be called; all variables are now in model-name scope.
PiperOrigin-RevId: 176407831
1 parent c0ce3dd commit 01b8c31

18 files changed: +155 −91 lines
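In short: T2TModel now extends tf.layers.Layer, so a model instance is called directly and returns a dense logits tensor plus a losses dict, instead of per-shard logits the caller had to tf.concat. A minimal before/after sketch, assuming a transformer whose hparams were already wired to a problem (the feature shapes below are illustrative, not from this commit):

import numpy as np
import tensorflow as tf
from tensor2tensor.models import transformer

hparams = transformer.transformer_small()
p_hparams = hparams.problems[0]  # assumes a problem was attached to hparams

features = {
    "inputs": tf.constant(np.zeros((3, 5, 1, 1), np.int32)),
    "targets": tf.constant(np.zeros((3, 6, 1, 1), np.int32)),
    "target_space_id": tf.constant(1, dtype=tf.int32),
}
model = transformer.Transformer(
    hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)

# Before: sharded_logits, losses = model.model_fn(features)
#         logits = tf.concat(sharded_logits, 0)
# After (this commit): the model is a Layer and is simply called; its
# variables are created under the model-name scope ("transformer").
logits, losses = model(features)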

tensor2tensor/models/bluenet_test.py (+1 −2)

@@ -45,8 +45,7 @@ def testBlueNet(self):
       }
       model = bluenet.BlueNet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))

tensor2tensor/models/bytenet_test.py (+1 −2)

@@ -44,8 +44,7 @@ def testByteNet(self):
       }
       model = bytenet.ByteNet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 50, 1, 1, vocab_size))

tensor2tensor/models/gene_expression_test.py (+2 −3)

@@ -55,9 +55,8 @@ def _testModel(self, hparams, model_cls):
         "targets": tf.constant(targets, dtype=tf.float32),
     }
     p_hparams, = hparams.problems
-    sharded_logits, _ = model_cls(hparams, tf.estimator.ModeKeys.TRAIN,
-                                  p_hparams).model_fn(features)
-    logits = tf.concat(sharded_logits, 0)
+    logits, _ = model_cls(
+        hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)(features)

     with self.test_session() as sess:
       sess.run(tf.global_variables_initializer())

tensor2tensor/models/lstm_test.py (+2 −4)

@@ -44,8 +44,7 @@ def testLSTMSeq2Seq(self):
       }
       model = lstm.LSTMSeq2seq(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
@@ -67,8 +66,7 @@ def testLSTMSeq2SeqAttention(self):
       }
       model = lstm.LSTMSeq2seqAttention(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))

tensor2tensor/models/multimodel_test.py (+1 −2)

@@ -48,8 +48,7 @@ def testMultiModel(self):
       }
       model = multimodel.MultiModel(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 1, 1, 1, 10))

tensor2tensor/models/neural_gpu_test.py (+1 −2)

@@ -52,8 +52,7 @@ def testNeuralGPU(self):
       }
       model = neural_gpu.NeuralGPU(hparams, tf.estimator.ModeKeys.TRAIN,
                                    p_hparams)
-      shadred_logits, _ = model.model_fn(features)
-      logits = tf.concat(shadred_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (batch_size, target_length, 1, 1,

tensor2tensor/models/resnet_test.py (+1 −2)

@@ -56,8 +56,7 @@ def _testResnet(self, img_size, output_size):
           "targets": tf.constant(y, dtype=tf.int32),
       }
       model = resnet.Resnet50(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (batch_size,) + output_size + (1, vocab_size))

tensor2tensor/models/slicenet_test.py (+1 −2)

@@ -49,8 +49,7 @@ def testSliceNet(self):
       }
       model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                 p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 1, 1, 1, 10))

tensor2tensor/models/transformer.py (+5 −3)

@@ -158,7 +158,8 @@ def _greedy_infer(self, features, decode_length):
     Raises:
       NotImplementedError: If there are multiple data shards.
     """
-    decoded_ids, _ = self._fast_decode(features, decode_length)
+    with tf.variable_scope(self.name):
+      decoded_ids, _ = self._fast_decode(features, decode_length)
     return decoded_ids, None, None

   def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
@@ -175,8 +176,9 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
     Returns:
       samples: an integer `Tensor`. Top samples from the beam search
     """
-    decoded_ids, scores = self._fast_decode(features, decode_length, beam_size,
-                                            top_beams, alpha)
+    with tf.variable_scope(self.name):
+      decoded_ids, scores = self._fast_decode(
+          features, decode_length, beam_size, top_beams, alpha)
     return {"outputs": decoded_ids, "scores": scores}

   def _fast_decode(self,
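Why the inference paths gain a variable_scope: variables now live under the model's Layer name, so decode code that builds or reuses them outside __call__ must re-enter that scope or the lookup misses the prefix. A hedged sketch of the idea (FakeModel is illustrative, not the real class):

import tensorflow as tf

class FakeModel(object):
  """Stand-in for T2TModel; `name` mimics tf.layers.Layer.name."""
  name = "transformer"

  def _fast_decode(self, features, decode_length):
    # Resolves to "transformer/w" only because of the scope entered below.
    w = tf.get_variable("w", shape=[4])
    return w[:decode_length], None

  def greedy_infer(self, features, decode_length):
    with tf.variable_scope(self.name):  # same prefix as training variables
      decoded_ids, _ = self._fast_decode(features, decode_length)
    return decoded_ids, None, None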

tensor2tensor/models/transformer_revnet_test.py (+1 −2)

@@ -59,8 +59,7 @@ def testTransformer(self):
     }
     model = transformer_revnet.TransformerRevnet(
         hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-    sharded_logits, _ = model.model_fn(features)
-    logits = tf.concat(sharded_logits, 0)
+    logits, _ = model(features)
     grads = tf.gradients(
         tf.reduce_mean(logits), [features["inputs"]] + tf.global_variables())
     grads = [g for g in grads if g is not None]

tensor2tensor/models/transformer_test.py (+11 −15)

@@ -51,26 +51,24 @@ def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN):
     targets = -1 + np.random.random_integers(
         VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))
     features = {
-        "inputs": tf.constant(inputs, dtype=tf.int32),
-        "targets": tf.constant(targets, dtype=tf.int32),
-        "target_space_id": tf.constant(1, dtype=tf.int32),
+        "inputs": tf.constant(inputs, dtype=tf.int32, name="inputs"),
+        "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
+        "target_space_id": tf.constant(1, dtype=tf.int32)
     }

     return transformer.Transformer(hparams, mode, p_hparams), features

   def testTransformer(self):
     model, features = self.getModel(transformer.transformer_small())
-    shadred_logits, _ = model.model_fn(features)
-    logits = tf.concat(shadred_logits, 0)
+    logits, _ = model(features)
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE))

   def testTransformerRelative(self):
     model, features = self.getModel(transformer.transformer_relative_tiny())
-    shadred_logits, _ = model.model_fn(features)
-    logits = tf.concat(shadred_logits, 0)
+    logits, _ = model(features)
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
@@ -81,8 +79,8 @@ def testGreedyVsFast(self):

     decode_length = 2

-    out_logits, _ = model.model_fn(features)
-    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    out_logits, _ = model(features)
+    out_logits = tf.squeeze(out_logits, axis=[2, 3])
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
         labels=tf.reshape(features["targets"], [-1]))
@@ -94,8 +92,7 @@ def testGreedyVsFast(self):
     for _ in range(100):
       apply_grad.run()

-    model, _ = self.getModel(transformer.transformer_small(),
-                             mode=tf.estimator.ModeKeys.PREDICT)
+    model.set_mode(tf.estimator.ModeKeys.PREDICT)

     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       greedy_result, _, _ = model._slow_greedy_infer(features, decode_length)
@@ -115,8 +112,8 @@ def testBeamVsFast(self):

     decode_length = 2

-    out_logits, _ = model.model_fn(features)
-    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    out_logits, _ = model(features)
+    out_logits = tf.squeeze(out_logits, axis=[2, 3])
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
         labels=tf.reshape(features["targets"], [-1]))
@@ -128,8 +125,7 @@ def testBeamVsFast(self):
     for _ in range(100):
       apply_grad.run()

-    model, _ = self.getModel(transformer.transformer_small(),
-                             mode=tf.estimator.ModeKeys.PREDICT)
+    model.set_mode(tf.estimator.ModeKeys.PREDICT)

     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       beam_result = model._beam_decode_slow(
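The greedy/beam tests no longer build a second Transformer for prediction. A sketch of the new pattern, assuming set_mode simply swaps the stored mode on the existing instance so its now name-scoped variables are reused (`model`, `features`, and `decode_length` as in the test above):

# Switch the trained instance to inference instead of re-instantiating it.
model.set_mode(tf.estimator.ModeKeys.PREDICT)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
  # Both decode paths share the variables trained a few lines earlier.
  slow_out, _, _ = model._slow_greedy_infer(features, decode_length)
  fast_out, _, _ = model._greedy_infer(features, decode_length)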

tensor2tensor/models/transformer_vae.py (+3 −3)

@@ -654,9 +654,9 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
         dtype=tf.int64)

     features["targets"] = initial_output
-    sharded_logits, _ = self.model_fn(features, False, force_full_predict=True)
-    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-    samples = tf.concat(sharded_samples, 0)
+    logits, _ = self.__call__(
+        features, skip=False, force_full_predict=True)
+    samples = tf.argmax(logits, axis=-1)

     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old

tensor2tensor/models/xception_test.py (+1 −2)

@@ -48,8 +48,7 @@ def _testXception(self, img_size, output_size):
           "targets": tf.constant(y, dtype=tf.int32),
       }
       model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, output_size + (1, vocab_size))

tensor2tensor/tpu/tpu_trainer_lib.py (+1 −3)

@@ -209,7 +209,6 @@ def t2t_model_fn(model_name,
     EstimatorSpec or TPUEstimatorSpec
   """
   _create_dummy_vars()
-
   hparams = copy.deepcopy(hparams)
   problem = hparams.problem_instances[0]
   problem_hp = hparams.problems[0]
@@ -224,10 +223,9 @@
       if use_tpu else create_data_parallelism(**config.t2t_device_info))
   model = registry.model(model_name)(
       hparams, mode, problem_hp, data_parallelism=data_parallelism)
-  sharded_logits, losses_dict = model.model_fn(features)
+  logits, losses_dict = model(features)

   # Set known shapes
-  logits = tf.concat(sharded_logits, 0)
   shape = logits.get_shape().as_list()
   if shape[0] is None:
     shape[0] = _get_batch_size(params, hparams, config)
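The "set known shapes" block this hunk feeds exists because TPU compilation requires fully static shapes, so a None batch dimension gets patched to the computed batch size. A hedged sketch of the idea; the set_shape call is a presumed continuation not shown in the hunk, and `pin_batch_dim` is an illustrative helper:

import tensorflow as tf

def pin_batch_dim(logits, batch_size):
  """Pin a dynamic batch dimension to a static size for TPU compilation."""
  shape = logits.get_shape().as_list()
  if shape[0] is None:
    shape[0] = batch_size
    logits.set_shape(shape)  # presumed follow-up to the hunk above
  return logits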

tensor2tensor/utils/model_builder.py (+1 −1)

@@ -127,7 +127,7 @@ def nth_model(n):
     if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
       sharded_logits, losses_dict = model_class.eval_autoregressive(features)
     else:
-      sharded_logits, losses_dict = model_class.model_fn(
+      sharded_logits, losses_dict = model_class(
           features, skip=(skipping_is_on and skip_this_one))
     with tf.variable_scope("losses_avg"):
       total_loss, ops = 0.0, []
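Calling model_class(features, skip=...) works because tf.layers.Layer.__call__ forwards extra keyword arguments to the subclass's call(). A minimal sketch under that assumption (FakeT2TModel and its outputs are illustrative, not the real class):

import tensorflow as tf

class FakeT2TModel(tf.layers.Layer):

  def call(self, features, skip=False):
    if skip:  # placeholder outputs when this shard is skipped
      return tf.zeros([1, 1]), {"training": tf.constant(0.0)}
    return tf.ones([1, 1]), {"training": tf.constant(0.0)}

model = FakeT2TModel(name="fake_t2t_model")
logits, losses_dict = model({"inputs": tf.zeros([1])}, skip=False)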

tensor2tensor/utils/registry.py (+13 −14)

@@ -90,7 +90,7 @@ def _reset():
     ctr.clear()


-def _default_name(obj_class):
+def default_name(obj_class):
   """Convert a class name to the registry's default name for the class.

   Args:
@@ -99,7 +99,6 @@ def _default_name(obj_class):
   Returns:
     The registry's default name for the class.
   """
-
   return _convert_camel_to_snake(obj_class.__name__)


@@ -112,25 +111,25 @@ def default_object_name(obj):
   Returns:
     The registry's default name for the class of the object.
   """
-
-  return _default_name(obj.__class__)
+  return default_name(obj.__class__)


 def register_model(name=None):
   """Register a model. name defaults to class name snake-cased."""

   def decorator(model_cls, registration_name=None):
     """Registers & returns model_cls with registration_name or default name."""
-    model_name = registration_name or _default_name(model_cls)
+    model_name = registration_name or default_name(model_cls)
     if model_name in _MODELS:
       raise LookupError("Model %s already registered." % model_name)
+    model_cls.REGISTERED_NAME = property(lambda _: model_name)
     _MODELS[model_name] = model_cls
     return model_cls

   # Handle if decorator was used without parens
   if callable(name):
     model_cls = name
-    return decorator(model_cls, registration_name=_default_name(model_cls))
+    return decorator(model_cls, registration_name=default_name(model_cls))

   return lambda model_cls: decorator(model_cls, name)

@@ -150,7 +149,7 @@ def register_hparams(name=None):

   def decorator(hp_fn, registration_name=None):
     """Registers & returns hp_fn with registration_name or default name."""
-    hp_name = registration_name or _default_name(hp_fn)
+    hp_name = registration_name or default_name(hp_fn)
     if hp_name in _HPARAMS:
       raise LookupError("HParams set %s already registered." % hp_name)
     _HPARAMS[hp_name] = hp_fn
@@ -159,7 +158,7 @@ def decorator(hp_fn, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     hp_fn = name
-    return decorator(hp_fn, registration_name=_default_name(hp_fn))
+    return decorator(hp_fn, registration_name=default_name(hp_fn))

   return lambda hp_fn: decorator(hp_fn, name)

@@ -182,7 +181,7 @@ def register_ranged_hparams(name=None):

   def decorator(rhp_fn, registration_name=None):
     """Registers & returns hp_fn with registration_name or default name."""
-    rhp_name = registration_name or _default_name(rhp_fn)
+    rhp_name = registration_name or default_name(rhp_fn)
     if rhp_name in _RANGED_HPARAMS:
       raise LookupError("RangedHParams set %s already registered." % rhp_name)
     # Check that the fn takes a single argument
@@ -197,7 +196,7 @@ def decorator(rhp_fn, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     rhp_fn = name
-    return decorator(rhp_fn, registration_name=_default_name(rhp_fn))
+    return decorator(rhp_fn, registration_name=default_name(rhp_fn))

   return lambda rhp_fn: decorator(rhp_fn, name)

@@ -217,7 +216,7 @@ def register_problem(name=None):

   def decorator(p_cls, registration_name=None):
     """Registers & returns p_cls with registration_name or default name."""
-    p_name = registration_name or _default_name(p_cls)
+    p_name = registration_name or default_name(p_cls)
     if p_name in _PROBLEMS:
       raise LookupError("Problem %s already registered." % p_name)

@@ -228,7 +227,7 @@ def decorator(p_cls, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     p_cls = name
-    return decorator(p_cls, registration_name=_default_name(p_cls))
+    return decorator(p_cls, registration_name=default_name(p_cls))

   return lambda p_cls: decorator(p_cls, name)

@@ -313,7 +312,7 @@ def _internal_register_modality(name, mod_collection, collection_str):

   def decorator(mod_cls, registration_name=None):
     """Registers & returns mod_cls with registration_name or default name."""
-    mod_name = registration_name or _default_name(mod_cls)
+    mod_name = registration_name or default_name(mod_cls)
     if mod_name in mod_collection:
       raise LookupError("%s modality %s already registered." % (collection_str,
                                                                 mod_name))
@@ -323,7 +322,7 @@ def decorator(mod_cls, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     mod_cls = name
-    return decorator(mod_cls, registration_name=_default_name(mod_cls))
+    return decorator(mod_cls, registration_name=default_name(mod_cls))

   return lambda mod_cls: decorator(mod_cls, name)
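The rename from _default_name to default_name makes the snake-casing helper public (the model's Layer name now depends on it), and registration additionally pins the chosen name on the class via REGISTERED_NAME. A sketch with illustrative names:

from tensor2tensor.utils import registry

@registry.register_model
class MyShinyModel(object):  # real models subclass T2TModel
  pass

registry.default_name(MyShinyModel)  # -> "my_shiny_model"

# REGISTERED_NAME is set as a property on the class, so an *instance*
# reports the name it was registered under:
MyShinyModel().REGISTERED_NAME  # -> "my_shiny_model"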