This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit c10e016

nshazeer authored and Ryan Sepassi committed
This change breaks previous checkpoints. Make Transformer fast on TPU.
PiperOrigin-RevId: 176747359
1 parent cc80721 commit c10e016

File tree: 7 files changed, +88 −30 lines changed


tensor2tensor/layers/common_attention.py

+13 −15

@@ -801,7 +801,7 @@ def combine_first_two_dimensions(x):
 
 @expert_utils.add_name_scope()
 def split_heads(x, num_heads):
-  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 2) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, length, channels]
@@ -815,7 +815,7 @@ def split_heads(x, num_heads):
 
 @expert_utils.add_name_scope()
 def split_heads_2d(x, num_heads):
-  """Split channels (dimension 4) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, height, width, channels]
@@ -2191,10 +2191,10 @@ def compute_qkv(query_antecedent,
   """
   if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
     # self attention with single position q, k, and v
-    combined = common_layers.conv1d(
+    combined = tf.layers.dense(
        query_antecedent,
        total_key_depth * 2 + total_value_depth,
-        1,
+        use_bias=False,
        name="qkv_transform")
    q, k, v = tf.split(
        combined, [total_key_depth, total_key_depth, total_value_depth], axis=2)
@@ -2250,22 +2250,19 @@ def compute_qkv_2d(query_antecedent, memory_antecedent, total_key_depth,
   """
   # self attention with single position q, k, and v
   if memory_antecedent is None:
-    combined = tf.layers.conv2d(
-        query_antecedent,
-        total_key_depth * 2 + total_value_depth, (1, 1),
-        name="qkv_transform")
+    combined = tf.layers.dense(
+        query_antecedent, total_key_depth * 2 + total_value_depth,
+        use_bias=False, name="qkv_transform")
    q, k, v = tf.split(
        combined, [total_key_depth, total_key_depth, total_value_depth],
        axis=-1)
    return q, k, v
 
   # Encoder decoder attention
-  q = common_layers.conv1d(
-      query_antecedent, total_key_depth, 1, name="q_transform")
-  combined = common_layers.conv1d(
-      memory_antecedent,
-      total_key_depth + total_value_depth,
-      1,
+  q = tf.layers.dense(
+      query_antecedent, total_key_depth, use_bias=False, name="q_transform")
+  combined = tf.layers.dense(
+      memory_antecedent, total_key_depth + total_value_depth, use_bias=False,
      name="kv_transform")
   k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
 
@@ -2410,7 +2407,8 @@ def multihead_attention(query_antecedent,
       x = dilated_self_attention_1d(q, k, v, block_length, block_width,
                                     gap_size, num_memory_blocks)
    x = combine_heads(x)
-    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
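
The recurring edit in this file swaps kernel-size-1 convolutions for tf.layers.dense. As a minimal sketch, not part of the commit, the snippet below shows why the two compute the same per-position projection while creating differently shaped kernel variables ([1, in, out] vs. [in, out]), which is why this change breaks previous checkpoints. The layer names, shapes, and the use of tf.layers.conv1d here are illustrative assumptions.

# Illustrative sketch, not part of the commit.
import numpy as np
import tensorflow as tf  # TF 1.x API, as used by tensor2tensor at the time

x = tf.placeholder(tf.float32, [None, 10, 64])        # [batch, length, channels]
y_conv = tf.layers.conv1d(x, 128, 1, use_bias=False, name="conv_proj")
y_dense = tf.layers.dense(x, 128, use_bias=False, name="dense_proj")

conv_kernel = [v for v in tf.trainable_variables()
               if v.name == "conv_proj/kernel:0"][0]    # shape [1, 64, 128]
dense_kernel = [v for v in tf.trainable_variables()
                if v.name == "dense_proj/kernel:0"][0]  # shape [64, 128]
copy_weights = tf.assign(dense_kernel, tf.squeeze(conv_kernel, 0))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(copy_weights)
  feed = {x: np.random.randn(2, 10, 64).astype(np.float32)}
  out_conv, out_dense = sess.run([y_conv, y_dense], feed)
  print(np.allclose(out_conv, out_dense, atol=1e-5))    # True: same math, different variables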

tensor2tensor/layers/common_hparams.py

+3

@@ -179,6 +179,9 @@ def basic_params1():
       # This is the actual batch size, *not* tokens per batch (i.e. for
       # language models this is the number of sentences in the batch)
       tpu_batch_size_per_shard=24,
+      # Set by tpu_trainer to let the model know whether we are on TPU.
+      # Switching on/off tpu should not invalidate checkpoints.
+      use_tpu=False,
   )
 
 

tensor2tensor/layers/common_layers.py

+21 −4

@@ -1229,6 +1229,15 @@ def relu_density_logit(x, reduce_dims):
   return scaled
 
 
+def conv_hidden_relu_simple(inputs, hidden_size, output_size, dropout=0.0):
+  h = tf.layers.dense(
+      inputs, hidden_size, use_bias=False, activation=tf.nn.relu, name="conv1")
+  if dropout != 0.0:
+    h = tf.nn.dropout(h, 1.0 - dropout)
+  o = tf.layers.dense(h, output_size, use_bias=False, name="conv2")
+  return o
+
+
 def conv_hidden_relu(inputs,
                      hidden_size,
                      output_size,
@@ -1239,6 +1248,9 @@ def conv_hidden_relu(inputs,
   """Hidden layer with RELU activation followed by linear projection."""
   name = kwargs.pop("name") if "name" in kwargs else None
   with tf.variable_scope(name, "conv_hidden_relu", [inputs]):
+    if kernel_size == (1, 1) and second_kernel_size == (1, 1):
+      return conv_hidden_relu_simple(
+          inputs, hidden_size, output_size, dropout=dropout)
    if inputs.get_shape().ndims == 3:
      is_3d = True
      inputs = tf.expand_dims(inputs, 2)
@@ -1487,10 +1499,15 @@ def padded_cross_entropy(logits,
   confidence = 1.0 - label_smoothing
   vocab_size = shape_list(logits)[-1]
   with tf.name_scope("padded_cross_entropy", [logits, labels]):
-    pad_logits, pad_labels = pad_with_zeros(logits, labels)
-    xent = smoothing_cross_entropy(pad_logits, pad_labels, vocab_size,
-                                   confidence)
-    weights = weights_fn(pad_labels)
+    if len(logits.get_shape().as_list()) == 2:
+      # Deal with the case where we did not insert extra dimensions due to
+      # TPU issues. No pad-to-same-length happens in this case.
+      # TODO(noam): remove this logic once TPU can handle extra dimensions.
+      labels = tf.reshape(labels, [-1])
+    else:
+      logits, labels = pad_with_zeros(logits, labels)
+    xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence)
+    weights = weights_fn(labels)
    if not reduce_sum:
      return xent * weights, weights
    return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
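
conv_hidden_relu_simple is the dense-only fast path that conv_hidden_relu now takes when both kernel sizes are (1, 1). Below is a minimal usage sketch of the same two-dense-layer feed-forward shape flow; all sizes and names are made-up example values, not taken from the commit.

# Illustrative sketch, not part of the commit.
import tensorflow as tf  # TF 1.x API

def ffn_sketch(x, filter_size=2048, hidden_size=512, dropout=0.1):
  # x: [batch, length, hidden_size]. tf.layers.dense projects the last axis,
  # so every position is transformed independently, exactly what the old
  # (1, 1)-kernel convolution did.
  h = tf.layers.dense(x, filter_size, use_bias=False, activation=tf.nn.relu,
                      name="ffn_in")
  h = tf.nn.dropout(h, 1.0 - dropout)  # TF 1.x keep_prob convention
  return tf.layers.dense(h, hidden_size, use_bias=False, name="ffn_out")

x = tf.random_normal([8, 16, 512])   # [batch, length, hidden_size]
y = ffn_sketch(x)                    # shape [8, 16, 512]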

tensor2tensor/layers/modalities.py

+20 −6

@@ -30,6 +30,15 @@
 import tensorflow as tf
 
 
+# TODO(noam): remove this function after TPUs do gather faster.
+def tpu_gather(params, indices):
+  vocab_size = params.get_shape().as_list()[0]
+  indices_flat = tf.reshape(indices, [-1])
+  out = tf.matmul(tf.one_hot(indices_flat, vocab_size), params)
+  out = eu.reshape_like(out, tf.expand_dims(indices, -1))
+  return out
+
+
 @registry.register_symbol_modality("default")
 class SymbolModality(modality.Modality):
   """Modality for sets of discrete symbols.
@@ -94,7 +103,8 @@ def bottom_simple(self, x, name, reuse):
       # Squeeze out the channels dimension.
       x = tf.squeeze(x, axis=3)
       var = self._get_weights()
-      ret = tf.gather(var, x)
+      ret = (tpu_gather(var, x) if self._model_hparams.use_tpu
+             else tf.gather(var, x))
      if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
        ret *= self._body_input_depth**0.5
      ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
@@ -142,14 +152,18 @@ def top(self, body_output, _):
         self._model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
       # insert channels dimension
       body_output = tf.expand_dims(body_output, 3)
-      logits = common_layers.FactoredTensor(body_output, var)
+      return common_layers.FactoredTensor(body_output, var)
    else:
      body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
      logits = tf.matmul(body_output, var, transpose_b=True)
-
-    out_shape = body_output_shape[:-1] + [1, self._vocab_size]
-    logits = tf.reshape(logits, out_shape)
-    return logits
+      if (self._model_hparams.use_tpu and
+          self._model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
+        # TPU does not react kindly to extra dimensions.
+        # TODO(noam): remove this once TPU is more forgiving of extra dims.
+        return logits
+      else:
+        return tf.reshape(
+            logits, body_output_shape[:-1] + [1, self._vocab_size])
 
 
 @registry.register_symbol_modality("ctc")
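
tpu_gather re-expresses an embedding lookup as a one-hot matrix multiply, which at the time was faster on TPU than tf.gather. The minimal sketch below shows that equivalence with made-up vocabulary and depth sizes; the real function additionally restores the indices' shape via eu.reshape_like.

# Illustrative sketch, not part of the commit.
import numpy as np
import tensorflow as tf  # TF 1.x API

vocab_size, depth = 16, 4
params = tf.constant(np.random.randn(vocab_size, depth).astype(np.float32))
indices = tf.constant([[1, 3], [0, 15]])                      # [batch, length]

gathered = tf.gather(params, indices)                         # [2, 2, depth]
one_hot = tf.one_hot(tf.reshape(indices, [-1]), vocab_size)   # [4, vocab_size]
matmul = tf.reshape(tf.matmul(one_hot, params), [2, 2, depth])

with tf.Session() as sess:
  a, b = sess.run([gathered, matmul])
  print(np.allclose(a, b))                                     # True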

tensor2tensor/layers/modalities_test.py

+6 −3

@@ -42,7 +42,8 @@ def testSymbolModalityInputs(self):
         multiply_embedding_mode="sqrt_depth",
         symbol_modality_skip_top=0,
         shared_embedding_and_softmax_weights=0,
-        prepend_mode="none")
+        prepend_mode="none",
+        use_tpu=False)
    x = -1 + np.random.random_integers(
        vocab_size, size=(batch_size, length, 1, 1))
    m = modalities.SymbolModality(model_hparams, vocab_size)
@@ -71,7 +72,8 @@ def testSymbolModalityTargets(self):
         shared_embedding_and_softmax_weights=0,
         factored_logits=0,
         mode=tf.estimator.ModeKeys.TRAIN,
-        prepend_mode="none")
+        prepend_mode="none",
+        use_tpu=False)
    body_output = -1 + np.random.random_integers(
        100, size=(batch_size, length, height, hidden_size))
    targets = -1 + np.random.random_integers(
@@ -107,7 +109,8 @@ def testSymbolModalityTargetsFactored(self):
         shared_embedding_and_softmax_weights=0,
         factored_logits=1,
         mode=tf.estimator.ModeKeys.TRAIN,
-        prepend_mode="none")
+        prepend_mode="none",
+        use_tpu=False)
    body_output = -1 + np.random.random_integers(
        100, size=(batch_size, length, height, hidden_size))
    targets = -1 + np.random.random_integers(

tensor2tensor/models/transformer.py

+24 −2

@@ -108,8 +108,13 @@ def decode(self,
         hparams,
         cache=cache)
 
-    # Expand since t2t expects 4d tensors.
-    return tf.expand_dims(decoder_output, axis=2)
+    if hparams.use_tpu and hparams.mode == tf.estimator.ModeKeys.TRAIN:
+      # TPU does not react kindly to extra dimensions.
+      # TODO(noam): remove this once TPU is more forgiving of extra dims.
+      return decoder_output
+    else:
+      # Expand since t2t expects 4d tensors.
+      return tf.expand_dims(decoder_output, axis=2)
 
   def model_fn_body(self, features):
     """Transformer main model_fn.
@@ -1113,3 +1118,20 @@ def transformer_clean_big():
   hparams.hidden_size = 1024
   hparams.filter_size = 4096
   return hparams
+
+
+@registry.register_hparams
+def transformer_tpu_lm1b():
+  """Hparams for training languagemodel_lm1b8k_concat on tpu."""
+  hparams = transformer_clean()
+  update_hparams_for_tpu(hparams)
+  hparams.max_length = 512
+  hparams.tpu_batch_size_per_shard = 8
+  hparams.hidden_size = 1024
+  hparams.filter_size = 4096
+  hparams.num_heads = 4
+  hparams.label_smoothing = 0.0
+  hparams.layer_prepostprocess_dropout = 0.0
+  hparams.attention_dropout = 0.0
+  hparams.relu_dropout = 0.0
+  return hparams
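
The new transformer_tpu_lm1b hparams set layers TPU-friendly settings on top of transformer_clean. A minimal sketch of inspecting it directly in Python follows; in practice the trainer looks the set up by its registered name, but calling the function is enough to see the values. This usage is an assumption for illustration, not code from the commit.

# Illustrative sketch, not part of the commit: build the registered hparams set
# directly and read a few of the values this commit sets.
from tensor2tensor.models import transformer

hparams = transformer.transformer_tpu_lm1b()
print(hparams.hidden_size)        # 1024
print(hparams.num_heads)          # 4
print(hparams.label_smoothing)    # 0.0
print(hparams.attention_dropout)  # 0.0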

tensor2tensor/tpu/tpu_trainer_lib.py

+1

@@ -212,6 +212,7 @@ def t2t_model_fn(model_name,
   hparams = copy.deepcopy(hparams)
   problem = hparams.problem_instances[0]
   problem_hp = hparams.problems[0]
+  hparams.use_tpu = use_tpu
 
   features["problem_choice"] = tf.constant(0)
   features["input_space_id"] = tf.constant(problem_hp.input_space_id)
