@@ -801,7 +801,7 @@ def combine_first_two_dimensions(x):
 
 @expert_utils.add_name_scope()
 def split_heads(x, num_heads):
-  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 2) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, length, channels]
@@ -815,7 +815,7 @@ def split_heads(x, num_heads):
 
 @expert_utils.add_name_scope()
 def split_heads_2d(x, num_heads):
-  """Split channels (dimension 4) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, height, width, channels]
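The two docstring fixes above use zero-based axis indices: for a 3-D tensor `[batch, length, channels]` the channel axis is 2, and for a 4-D tensor `[batch, height, width, channels]` it is 3. Below is a minimal NumPy sketch of the reshape-and-transpose that `split_heads` describes; it is illustrative only (the helper name `split_heads_sketch` is made up here), not the library code.

```python
import numpy as np

def split_heads_sketch(x, num_heads):
  # x: [batch, length, channels]; channels is axis 2.
  batch, length, channels = x.shape
  assert channels % num_heads == 0
  # Split the channel axis, then move heads to axis 1:
  # [batch, length, num_heads, depth] -> [batch, num_heads, length, depth]
  x = x.reshape(batch, length, num_heads, channels // num_heads)
  return x.transpose(0, 2, 1, 3)

x = np.zeros((2, 5, 8))
print(split_heads_sketch(x, num_heads=4).shape)  # (2, 4, 5, 2)
```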
@@ -2191,10 +2191,10 @@ def compute_qkv(query_antecedent,
   """
   if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
     # self attention with single position q, k, and v
-    combined = common_layers.conv1d(
+    combined = tf.layers.dense(
         query_antecedent,
         total_key_depth * 2 + total_value_depth,
-        1,
+        use_bias=False,
         name="qkv_transform")
     q, k, v = tf.split(
         combined, [total_key_depth, total_key_depth, total_value_depth], axis=2)
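This hunk swaps `common_layers.conv1d` with kernel width 1 for `tf.layers.dense`: both apply one weight matrix independently at every position of the last axis, so the computation is unchanged (the bias simply moves to an explicit `use_bias=False`). A hedged NumPy sketch of that equivalence, with made-up shapes, not the library code:

```python
import numpy as np

batch, length, channels, out_depth = 2, 5, 8, 12
x = np.random.randn(batch, length, channels)
w = np.random.randn(channels, out_depth)  # one matrix shared across positions

# dense on the last axis: a single matmul over [batch, length, channels]
dense_out = x @ w

# width-1 convolution: apply the same matrix at each position separately
conv_out = np.stack([x[:, t, :] @ w for t in range(length)], axis=1)

print(np.allclose(dense_out, conv_out))  # True
```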
@@ -2250,22 +2250,19 @@ def compute_qkv_2d(query_antecedent, memory_antecedent, total_key_depth,
   """
   # self attention with single position q, k, and v
   if memory_antecedent is None:
-    combined = tf.layers.conv2d(
-        query_antecedent,
-        total_key_depth * 2 + total_value_depth, (1, 1),
-        name="qkv_transform")
+    combined = tf.layers.dense(
+        query_antecedent, total_key_depth * 2 + total_value_depth,
+        use_bias=False, name="qkv_transform")
     q, k, v = tf.split(
         combined, [total_key_depth, total_key_depth, total_value_depth],
         axis=-1)
     return q, k, v
 
   # Encoder decoder attention
-  q = common_layers.conv1d(
-      query_antecedent, total_key_depth, 1, name="q_transform")
-  combined = common_layers.conv1d(
-      memory_antecedent,
-      total_key_depth + total_value_depth,
-      1,
+  q = tf.layers.dense(
+      query_antecedent, total_key_depth, use_bias=False, name="q_transform")
+  combined = tf.layers.dense(
+      memory_antecedent, total_key_depth + total_value_depth, use_bias=False,
       name="kv_transform")
   k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
 
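As in the 1-D case, the self-attention path now produces q, k, and v with one fused projection and then slices the last axis apart. A small sketch of that fused projection and split, with assumed depths and names, not the library code:

```python
import numpy as np

batch, length, channels = 2, 5, 16
total_key_depth, total_value_depth = 8, 12
x = np.random.randn(batch, length, channels)
w_qkv = np.random.randn(channels, 2 * total_key_depth + total_value_depth)

combined = x @ w_qkv                                # [batch, length, 2*key + value]
q = combined[..., :total_key_depth]                 # [batch, length, key]
k = combined[..., total_key_depth:2 * total_key_depth]
v = combined[..., 2 * total_key_depth:]             # [batch, length, value]
print(q.shape, k.shape, v.shape)                    # (2, 5, 8) (2, 5, 8) (2, 5, 12)
```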
@@ -2410,7 +2407,8 @@ def multihead_attention(query_antecedent,
       x = dilated_self_attention_1d(q, k, v, block_length, block_width,
                                     gap_size, num_memory_blocks)
     x = combine_heads(x)
-    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
+    x = tf.layers.dense(
+        x, output_depth, use_bias=False, name="output_transform")
     if additional_returned_value is not None:
       return x, additional_returned_value
     return x