This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0ffe0e6

nshazeer authored and Ryan Sepassi committed
This change breaks previous checkpoints. Make Transformer fast on TPU.
PiperOrigin-RevId: 177255666
1 parent 398e85b commit 0ffe0e6
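
The pattern repeated throughout this diff is the replacement of width-1 convolutions (common_layers.conv1d with kernel size 1) by tf.layers.dense with use_bias=False, which lowers to a plain matmul and runs faster on TPU. Below is a minimal sketch of why the two are interchangeable, using stock tf.layers.conv1d as a stand-in for the t2t wrapper (toy shapes; TF 1.x, as used by tensor2tensor at this commit):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10, 64])  # [batch, length, channels]

# Width-1 convolution: the kernel variable has shape [1, 64, 128].
y_conv = tf.layers.conv1d(x, filters=128, kernel_size=1, use_bias=False,
                          name="conv_proj")

# Dense equivalent: the kernel variable has shape [64, 128], applied at
# every position along the length axis.
y_dense = tf.layers.dense(x, 128, use_bias=False, name="dense_proj")

# Both compute x @ W per position; only the kernel's shape (and thus the
# variable stored in the checkpoint) differs, which is why this change
# breaks previous checkpoints.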

7 files changed: +275 -133 lines

tensor2tensor/layers/common_attention.py (+42 -102)
@@ -801,7 +801,7 @@ def combine_first_two_dimensions(x):
 
 @expert_utils.add_name_scope()
 def split_heads(x, num_heads):
-  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 2) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, length, channels]
@@ -815,7 +815,7 @@ def split_heads(x, num_heads):
 
 @expert_utils.add_name_scope()
 def split_heads_2d(x, num_heads):
-  """Split channels (dimension 4) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, height, width, channels]
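
Both docstring edits fix the same off-by-one: with 0-based indexing, channels sit at dimension 2 of a [batch, length, channels] tensor and at dimension 3 of a [batch, height, width, channels] tensor. A hedged sketch of the shape manipulation split_heads performs (shapes only; the real implementation lives elsewhere in common_attention.py):

import tensorflow as tf

def split_heads_sketch(x, num_heads):
  # x: [batch, length, channels]; channels is dimension 2.
  batch, length, channels = tf.unstack(tf.shape(x))
  # Carve channels into num_heads groups of depth channels // num_heads ...
  x = tf.reshape(x, [batch, length, num_heads, channels // num_heads])
  # ... then move heads up to dimension 1: [batch, num_heads, length, depth].
  return tf.transpose(x, [0, 2, 1, 3])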
@@ -968,12 +968,12 @@ def grouped_attention_multihead(query_antecedent,
       name,
       default_name="multihead_attention_sparse",
       values=[query_antecedent, memory_antecedent]):
-    q = common_layers.conv1d(
-        query_antecedent, total_key_depth, 1, name="q_transform")
-    kv = common_layers.conv1d(
+    q = tf.layers.dense(
+        query_antecedent, total_key_depth, use_bias=False, name="q_transform")
+    kv = tf.layers.dense(
         memory_antecedent,
         total_key_depth + total_value_depth,
-        1,
+        use_bias=False,
         name="kv_transform")
     q = split_heads(q, num_heads)
     kv = split_heads(kv, num_heads)
@@ -982,18 +982,18 @@ def grouped_attention_multihead(query_antecedent,
     # We will train these by auxiliary losses. We use stop_gradient here
     # to keep these losses from back-propagating to the rest of the model.
     # We add biases that help balance the usage of the experts.
-    q_pred = common_layers.conv1d(
+    q_pred = tf.layers.dense(
         tf.stop_gradient(query_antecedent),
         num_heads * num_groups,
-        1,
+        use_bias=False,
         name="q_pred")
     q_pred = split_heads(q_pred, num_heads)
     q_bias = tf.get_variable("q_bias", [1, num_heads, 1, num_groups])
     q_pred_biased = q_pred + q_bias
-    m_pred = common_layers.conv1d(
+    m_pred = tf.layers.dense(
         tf.stop_gradient(memory_antecedent),
         num_heads * num_groups,
-        1,
+        use_bias=False,
         name="m_pred")
     m_pred = split_heads(m_pred, num_heads)
     m_bias = tf.get_variable("m_bias", [1, num_heads, 1, num_groups])
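
The comment in this hunk explains the trick: the group predictions are trained only by auxiliary losses, and tf.stop_gradient keeps those losses from reaching the rest of the model through the predictions' inputs. A self-contained sketch of the general pattern, with stand-in tensors (not the t2t code):

import tensorflow as tf  # TF 1.x

num_groups = 4
hidden = tf.random_normal([8, 16, 64])       # stand-in for model activations
target_groups = tf.zeros([8, 16], tf.int32)  # stand-in for auxiliary targets

# The auxiliary loss below trains only this dense layer's weights:
# tf.stop_gradient blocks gradients from flowing back into hidden.
logits = tf.layers.dense(tf.stop_gradient(hidden), num_groups,
                         use_bias=False, name="group_pred")
aux_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=target_groups, logits=logits))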
@@ -1059,7 +1059,8 @@ def grouped_attention_multihead(query_antecedent,
 
     o = tf.reshape(o, [batch, num_heads, length_q, depth_v])
     o = combine_heads(o)
-    o = common_layers.conv1d(o, output_depth, 1, name="output_transform")
+    o = tf.layers.dense(
+        o, output_depth, use_bias=False, name="output_transform")
 
     m_total = m_dispatcher.combine(m_total)
     q_total = q_dispatcher.combine(q_total)
@@ -2189,86 +2190,19 @@ def compute_qkv(query_antecedent,
   Returns:
     q, k, v : [batch, length, depth] tensors
   """
-  if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
-    # self attention with single position q, k, and v
-    combined = common_layers.conv1d(
-        query_antecedent,
-        total_key_depth * 2 + total_value_depth,
-        1,
-        name="qkv_transform")
-    q, k, v = tf.split(
-        combined, [total_key_depth, total_key_depth, total_value_depth], axis=2)
-    return q, k, v
-
-  if memory_antecedent is None:
-    # self attention
-    q = common_layers.conv1d(
-        query_antecedent,
-        total_key_depth,
-        q_filter_width,
-        padding=q_padding,
-        name="q_transform")
-    kv_combined = common_layers.conv1d(
-        query_antecedent,
-        total_key_depth + total_value_depth,
-        kv_filter_width,
-        padding=kv_padding,
-        name="kv_transform")
-    k, v = tf.split(kv_combined, [total_key_depth, total_value_depth], axis=2)
-    return q, k, v
-
-  # encoder-decoder attention
-  q = common_layers.conv1d(
-      query_antecedent,
-      total_key_depth,
-      q_filter_width,
-      padding=q_padding,
-      name="q_transform")
-  combined = common_layers.conv1d(
-      memory_antecedent,
-      total_key_depth + total_value_depth,
-      1,
-      padding=kv_padding,
-      name="kv_transform")
-  k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
-
-  return q, k, v
-
-
-def compute_qkv_2d(query_antecedent, memory_antecedent, total_key_depth,
-                   total_value_depth):
-  """Computes query, key and value.
-
-  Args:
-    query_antecedent: a Tensor with shape [batch, h, w, depth_k]
-    memory_antecedent: a Tensor with shape [batch, h, w, depth_k]
-    total_key_depth: an integer
-    total_value_depth: and integer
-
-  Returns:
-    q, k, v : [batch, h, w, depth_k] tensors
-  """
-  # self attention with single position q, k, and v
   if memory_antecedent is None:
-    combined = tf.layers.conv2d(
-        query_antecedent,
-        total_key_depth * 2 + total_value_depth, (1, 1),
-        name="qkv_transform")
-    q, k, v = tf.split(
-        combined, [total_key_depth, total_key_depth, total_value_depth],
-        axis=-1)
-    return q, k, v
-
-  # Encoder decoder attention
-  q = common_layers.conv1d(
-      query_antecedent, total_key_depth, 1, name="q_transform")
-  combined = common_layers.conv1d(
-      memory_antecedent,
-      total_key_depth + total_value_depth,
-      1,
-      name="kv_transform")
-  k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
-
+    memory_antecedent = query_antecedent
+  def _compute(inp, depth, filter_width, padding, name):
+    if filter_width == 1:
+      return tf.layers.dense(inp, depth, use_bias=False, name=name)
+    else:
+      return common_layers.conv1d(inp, depth, filter_width, padding, name=name)
+  q = _compute(
+      query_antecedent, total_key_depth, q_filter_width, q_padding, "q")
+  k = _compute(
+      memory_antecedent, total_key_depth, kv_filter_width, kv_padding, "k")
+  v = _compute(
+      memory_antecedent, total_value_depth, kv_filter_width, kv_padding, "v")
   return q, k, v
 
 
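This rewrite collapses three near-duplicate code paths (fused self-attention, filtered self-attention, encoder-decoder attention) into one: a missing memory defaults to the query, and each of q, k, v is produced by a dense layer when its filter width is 1, or by a conv1d otherwise. Because tf.layers.dense acts on the last axis of a tensor of any rank, the same function also covers 4-D inputs, which is what lets compute_qkv_2d be deleted below. A hedged usage sketch (placeholder tensors; the filter-width and padding arguments take the defaults implied by the four-argument multihead_attention_2d call in a later hunk):

# Self-attention: with memory_antecedent=None, k and v are computed
# from the query sequence itself.
q, k, v = compute_qkv(encoder_input, None, 512, 512)

# Encoder-decoder attention: k and v are computed from the encoder output.
q, k, v = compute_qkv(decoder_input, encoder_output, 512, 512)
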
@@ -2410,7 +2344,8 @@ def multihead_attention(query_antecedent,
       x = dilated_self_attention_1d(q, k, v, block_length, block_width,
                                     gap_size, num_memory_blocks)
     x = combine_heads(x)
-    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
     if additional_returned_value is not None:
       return x, additional_returned_value
     return x
@@ -2457,8 +2392,8 @@ def multihead_attention_2d(query_antecedent,
       name,
       default_name="multihead_attention_2d",
       values=[query_antecedent, memory_antecedent]):
-    q, k, v = compute_qkv_2d(query_antecedent, memory_antecedent,
-                             total_key_depth, total_value_depth)
+    q, k, v = compute_qkv(query_antecedent, memory_antecedent,
+                          total_key_depth, total_value_depth)
     # after splitting, shape is [batch, heads, h, w, depth]
     q = split_heads_2d(q, num_heads)
     k = split_heads_2d(k, num_heads)
@@ -2473,7 +2408,8 @@ def multihead_attention_2d(query_antecedent,
       x = masked_local_attention_2d(
           q, k, v, query_shape=query_shape, memory_flange=memory_flange)
     x = combine_heads_2d(x)
-    x = tf.layers.conv2d(x, output_depth, (1, 1), name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
     return x
 
 
@@ -2512,16 +2448,18 @@ def ffn_self_attention_layer(x,
     x_shape = common_layers.shape_list(x)
     part_depth = filter_depth // num_parts
     if not share_kv:
-      combined = common_layers.conv1d(
-          x, filter_depth * 3, 1, name="qkv_transform")
+      combined = tf.layers.dense(
+          x, filter_depth * 3, use_bias=False, name="qkv_transform")
       combined = tf.expand_dims(combined, axis=2)
       q, k, v = tf.split(combined, 3, axis=3)
     else:
       q = tf.expand_dims(
-          common_layers.conv1d(x, filter_depth, 1, name="q_transform"), axis=2)
+          tf.layers.dense(
+              x, filter_depth, use_bias=False, name="q_transform"), axis=2)
       kv_combined = tf.expand_dims(
-          common_layers.conv1d(
-              tf.concat([x, x], axis=1), filter_depth, 1, name="kv_transform"),
+          tf.layers.dense(
+              tf.concat([x, x], axis=1), filter_depth, use_bias=False,
+              name="kv_transform"),
           axis=2)
       k, v = tf.split(kv_combined, [x_shape[1], x_shape[1]], axis=1)
 
@@ -2534,7 +2472,8 @@ def ffn_self_attention_layer(x,
       bias = None
     x = dot_product_attention(batch_q, batch_k, batch_v, bias, dropout_rate)
     x = tf.reshape(x, [x_shape[0], x_shape[1], filter_depth])
-    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
     return x
 
 
@@ -2585,7 +2524,7 @@ def parameter_attention(x,
                          output_depth**0.5)
     batch_size = common_layers.shape_list(x)[0]
     length = common_layers.shape_list(x)[1]
-    q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform")
+    q = tf.layers.dense(x, total_key_depth, use_bias=False, name="q_transform")
     if dropout_rate:
       # This is a cheaper form of attention dropout where we use to use
       # the same dropout decisions across batch elemets and query positions,
@@ -2604,7 +2543,8 @@ def parameter_attention(x,
     y = tf.transpose(y, [1, 2, 0, 3])
     y = tf.reshape(y, [batch_size, length, total_value_depth])
     y.set_shape([None, None, total_value_depth])
-    y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
+    y = tf.layers.dense(
+        y, output_depth, use_bias=False, name="output_transform")
     return y
 
 
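The context lines in the hunk above describe a cheaper attention dropout that shares dropout decisions across batch elements and query positions. A speculative sketch of that general technique via tf.nn.dropout's noise_shape argument (illustrative shapes, not parameter_attention's actual internals):

import tensorflow as tf  # TF 1.x

q = tf.random_normal([8, 4, 100, 64])  # [batch, heads, length_q, depth]

# Dimensions of noise_shape set to 1 broadcast the dropout mask: batch
# (dim 0) and query positions (dim 2) share one decision, while heads
# and depth get independent decisions.
q = tf.nn.dropout(q, keep_prob=0.9, noise_shape=[1, 4, 1, 64])
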
tensor2tensor/layers/common_hparams.py (+3 -0)
@@ -187,6 +187,9 @@ def basic_params1():
       # Things not compatible with eager mode use this flag to implement
       # alternative functionality. We expect this to go away soon.
       use_eager_mode=False,
+      # Set by tpu_trainer to let the model know whether we are on TPU.
+      # Switching on/off tpu should not invalidate checkpoints.
+      use_tpu=False,
   )
 
 
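Per the new comments, use_tpu is meant to be set by the trainer rather than by hand, and flipping it must not change which variables a model creates. A hedged sketch of trainer-side wiring (hypothetical; the real logic lives in tpu_trainer):

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
hparams.use_tpu = True  # the TPU trainer would set this before building the model

# Model code can branch on hparams.use_tpu to pick a TPU-friendly computation
# path; because both paths create the same variables, toggling the flag does
# not invalidate existing checkpoints.
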