@@ -801,7 +801,7 @@ def combine_first_two_dimensions(x):

@expert_utils.add_name_scope()
def split_heads(x, num_heads):
-  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 2) into multiple heads (becomes dimension 1).

  Args:
    x: a Tensor with shape [batch, length, channels]
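For reference, the corrected docstring describes the usual reshape-and-transpose: for a [batch, length, channels] input, the channel axis (dimension 2) is divided among the heads, and the head axis lands at dimension 1. A minimal NumPy sketch of that shape transformation (illustrative only, not the library implementation):

import numpy as np

def split_heads_sketch(x, num_heads):
  # [batch, length, channels] -> [batch, num_heads, length, channels // num_heads]
  batch, length, channels = x.shape
  x = x.reshape(batch, length, num_heads, channels // num_heads)
  return x.transpose(0, 2, 1, 3)

print(split_heads_sketch(np.zeros((2, 5, 8)), num_heads=4).shape)  # (2, 4, 5, 2)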
@@ -815,7 +815,7 @@ def split_heads(x, num_heads):

@expert_utils.add_name_scope()
def split_heads_2d(x, num_heads):
-  """Split channels (dimension 4) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 3) into multiple heads (becomes dimension 1).

  Args:
    x: a Tensor with shape [batch, height, width, channels]
@@ -968,12 +968,12 @@ def grouped_attention_multihead(query_antecedent,
      name,
      default_name="multihead_attention_sparse",
      values=[query_antecedent, memory_antecedent]):
-    q = common_layers.conv1d(
-        query_antecedent, total_key_depth, 1, name="q_transform")
-    kv = common_layers.conv1d(
+    q = tf.layers.dense(
+        query_antecedent, total_key_depth, use_bias=False, name="q_transform")
+    kv = tf.layers.dense(
        memory_antecedent,
        total_key_depth + total_value_depth,
-        1,
+        use_bias=False,
        name="kv_transform")
    q = split_heads(q, num_heads)
    kv = split_heads(kv, num_heads)
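The substitution made throughout this diff rests on the fact that a kernel-width-1 1-D convolution applies the same position-wise linear map as a dense layer acting on the channel axis; the new calls additionally drop the bias term via use_bias=False. A small NumPy check of that equivalence (names and shapes here are illustrative, not taken from the library):

import numpy as np

batch, length, channels, depth = 2, 5, 4, 3
rng = np.random.RandomState(0)
x = rng.randn(batch, length, channels)
w = rng.randn(channels, depth)          # one projection shared by all positions

dense_out = x @ w                       # position-wise dense projection
kernel = w[np.newaxis, :, :]            # width-1 conv1d kernel: [1, channels, depth]
conv_out = np.einsum("blc,kcd->bld", x, kernel)

assert np.allclose(dense_out, conv_out)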
@@ -982,18 +982,18 @@ def grouped_attention_multihead(query_antecedent,
    # We will train these by auxiliary losses. We use stop_gradient here
    # to keep these losses from back-propagating to the rest of the model.
    # We add biases that help balance the usage of the experts.
-    q_pred = common_layers.conv1d(
+    q_pred = tf.layers.dense(
        tf.stop_gradient(query_antecedent),
        num_heads * num_groups,
-        1,
+        use_bias=False,
        name="q_pred")
    q_pred = split_heads(q_pred, num_heads)
    q_bias = tf.get_variable("q_bias", [1, num_heads, 1, num_groups])
    q_pred_biased = q_pred + q_bias
-    m_pred = common_layers.conv1d(
+    m_pred = tf.layers.dense(
        tf.stop_gradient(memory_antecedent),
        num_heads * num_groups,
-        1,
+        use_bias=False,
        name="m_pred")
    m_pred = split_heads(m_pred, num_heads)
    m_bias = tf.get_variable("m_bias", [1, num_heads, 1, num_groups])
@@ -1059,7 +1059,8 @@ def grouped_attention_multihead(query_antecedent,

    o = tf.reshape(o, [batch, num_heads, length_q, depth_v])
    o = combine_heads(o)
-    o = common_layers.conv1d(o, output_depth, 1, name="output_transform")
+    o = tf.layers.dense(
+        o, output_depth, use_bias=False, name="output_transform")

    m_total = m_dispatcher.combine(m_total)
    q_total = q_dispatcher.combine(q_total)
@@ -2189,86 +2190,19 @@ def compute_qkv(query_antecedent,
  Returns:
    q, k, v : [batch, length, depth] tensors
  """
-  if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
-    # self attention with single position q, k, and v
-    combined = common_layers.conv1d(
-        query_antecedent,
-        total_key_depth * 2 + total_value_depth,
-        1,
-        name="qkv_transform")
-    q, k, v = tf.split(
-        combined, [total_key_depth, total_key_depth, total_value_depth], axis=2)
-    return q, k, v
-
-  if memory_antecedent is None:
-    # self attention
-    q = common_layers.conv1d(
-        query_antecedent,
-        total_key_depth,
-        q_filter_width,
-        padding=q_padding,
-        name="q_transform")
-    kv_combined = common_layers.conv1d(
-        query_antecedent,
-        total_key_depth + total_value_depth,
-        kv_filter_width,
-        padding=kv_padding,
-        name="kv_transform")
-    k, v = tf.split(kv_combined, [total_key_depth, total_value_depth], axis=2)
-    return q, k, v
-
-  # encoder-decoder attention
-  q = common_layers.conv1d(
-      query_antecedent,
-      total_key_depth,
-      q_filter_width,
-      padding=q_padding,
-      name="q_transform")
-  combined = common_layers.conv1d(
-      memory_antecedent,
-      total_key_depth + total_value_depth,
-      1,
-      padding=kv_padding,
-      name="kv_transform")
-  k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
-
-  return q, k, v
-
-
-def compute_qkv_2d(query_antecedent, memory_antecedent, total_key_depth,
-                   total_value_depth):
-  """Computes query, key and value.
-
-  Args:
-    query_antecedent: a Tensor with shape [batch, h, w, depth_k]
-    memory_antecedent: a Tensor with shape [batch, h, w, depth_k]
-    total_key_depth: an integer
-    total_value_depth: and integer
-
-  Returns:
-    q, k, v : [batch, h, w, depth_k] tensors
-  """
-  # self attention with single position q, k, and v
  if memory_antecedent is None:
-    combined = tf.layers.conv2d(
-        query_antecedent,
-        total_key_depth * 2 + total_value_depth, (1, 1),
-        name="qkv_transform")
-    q, k, v = tf.split(
-        combined, [total_key_depth, total_key_depth, total_value_depth],
-        axis=-1)
-    return q, k, v
-
-  # Encoder decoder attention
-  q = common_layers.conv1d(
-      query_antecedent, total_key_depth, 1, name="q_transform")
-  combined = common_layers.conv1d(
-      memory_antecedent,
-      total_key_depth + total_value_depth,
-      1,
-      name="kv_transform")
-  k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
-
+    memory_antecedent = query_antecedent
+  def _compute(inp, depth, filter_width, padding, name):
+    if filter_width == 1:
+      return tf.layers.dense(inp, depth, use_bias=False, name=name)
+    else:
+      return common_layers.conv1d(inp, depth, filter_width, padding, name=name)
+  q = _compute(
+      query_antecedent, total_key_depth, q_filter_width, q_padding, "q")
+  k = _compute(
+      memory_antecedent, total_key_depth, kv_filter_width, kv_padding, "k")
+  v = _compute(
+      memory_antecedent, total_value_depth, kv_filter_width, kv_padding, "v")
  return q, k, v


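The rewritten compute_qkv folds the three old branches into one path: self-attention is simply the case where the memory defaults to the query, and a filter width of 1 degenerates to a bias-free dense projection. A stand-alone sketch of that control flow with NumPy stand-ins (the helper and shapes below are assumptions for illustration, not the library code):

import numpy as np

def _project(inp, depth, rng):
  # stands in for the filter_width == 1 case, i.e. a bias-free dense layer
  return inp @ rng.randn(inp.shape[-1], depth)

def compute_qkv_sketch(query_antecedent, memory_antecedent,
                       total_key_depth, total_value_depth, rng):
  if memory_antecedent is None:          # self-attention
    memory_antecedent = query_antecedent
  q = _project(query_antecedent, total_key_depth, rng)
  k = _project(memory_antecedent, total_key_depth, rng)
  v = _project(memory_antecedent, total_value_depth, rng)
  return q, k, v

rng = np.random.RandomState(0)
x = rng.randn(2, 7, 16)                  # [batch, length, channels]
q, k, v = compute_qkv_sketch(x, None, 32, 48, rng)
print(q.shape, k.shape, v.shape)         # (2, 7, 32) (2, 7, 32) (2, 7, 48)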
@@ -2410,7 +2344,8 @@ def multihead_attention(query_antecedent,
      x = dilated_self_attention_1d(q, k, v, block_length, block_width,
                                    gap_size, num_memory_blocks)
    x = combine_heads(x)
-    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
@@ -2457,8 +2392,8 @@ def multihead_attention_2d(query_antecedent,
      name,
      default_name="multihead_attention_2d",
      values=[query_antecedent, memory_antecedent]):
-    q, k, v = compute_qkv_2d(query_antecedent, memory_antecedent,
-                             total_key_depth, total_value_depth)
+    q, k, v = compute_qkv(query_antecedent, memory_antecedent,
+                          total_key_depth, total_value_depth)
    # after splitting, shape is [batch, heads, h, w, depth]
    q = split_heads_2d(q, num_heads)
    k = split_heads_2d(k, num_heads)
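compute_qkv_2d can be dropped here because a dense projection only touches the last axis, so with the default filter width of 1 the shared compute_qkv path also handles [batch, h, w, channels] inputs, where it matches a bias-free 1x1 conv2d. A quick NumPy check (illustrative shapes, not library code):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(2, 4, 6, 8)                    # [batch, h, w, channels]
w = rng.randn(8, 16)                         # shared channel projection

dense_out = x @ w                            # matmul over the last axis only
conv_out = np.einsum("bhwc,cd->bhwd", x, w)  # what a bias-free 1x1 conv2d computes

assert np.allclose(dense_out, conv_out)
print(dense_out.shape)                       # (2, 4, 6, 16)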
@@ -2473,7 +2408,8 @@ def multihead_attention_2d(query_antecedent,
      x = masked_local_attention_2d(
          q, k, v, query_shape=query_shape, memory_flange=memory_flange)
    x = combine_heads_2d(x)
-    x = tf.layers.conv2d(x, output_depth, (1, 1), name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
    return x


@@ -2512,16 +2448,18 @@ def ffn_self_attention_layer(x,
    x_shape = common_layers.shape_list(x)
    part_depth = filter_depth // num_parts
    if not share_kv:
-      combined = common_layers.conv1d(
-          x, filter_depth * 3, 1, name="qkv_transform")
+      combined = tf.layers.dense(
+          x, filter_depth * 3, use_bias=False, name="qkv_transform")
      combined = tf.expand_dims(combined, axis=2)
      q, k, v = tf.split(combined, 3, axis=3)
    else:
      q = tf.expand_dims(
-          common_layers.conv1d(x, filter_depth, 1, name="q_transform"), axis=2)
+          tf.layers.dense(
+              x, filter_depth, use_bias=False, name="q_transform"), axis=2)
      kv_combined = tf.expand_dims(
-          common_layers.conv1d(
-              tf.concat([x, x], axis=1), filter_depth, 1, name="kv_transform"),
+          tf.layers.dense(
+              tf.concat([x, x], axis=1), filter_depth, use_bias=False,
+              name="kv_transform"),
          axis=2)
      k, v = tf.split(kv_combined, [x_shape[1], x_shape[1]], axis=1)

@@ -2534,7 +2472,8 @@ def ffn_self_attention_layer(x,
      bias = None
    x = dot_product_attention(batch_q, batch_k, batch_v, bias, dropout_rate)
    x = tf.reshape(x, [x_shape[0], x_shape[1], filter_depth])
-    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
    return x


@@ -2585,7 +2524,7 @@ def parameter_attention(x,
        output_depth**0.5)
    batch_size = common_layers.shape_list(x)[0]
    length = common_layers.shape_list(x)[1]
-    q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform")
+    q = tf.layers.dense(x, total_key_depth, use_bias=False, name="q_transform")
    if dropout_rate:
      # This is a cheaper form of attention dropout where we use to use
      # the same dropout decisions across batch elemets and query positions,
@@ -2604,7 +2543,8 @@ def parameter_attention(x,
    y = tf.transpose(y, [1, 2, 0, 3])
    y = tf.reshape(y, [batch_size, length, total_value_depth])
    y.set_shape([None, None, total_value_depth])
-    y = common_layers.conv1d(y, output_depth, 1, name="output_transform")
+    y = tf.layers.dense(
+        y, output_depth, use_bias=False, name="output_transform")
    return y

