@@ -185,17 +185,23 @@ def get_attention_scores(self, query, key, attention_mask=None):
         query = query.float()
         key = key.float()
 
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            baddbmm_input,
             query,
             key.transpose(-1, -2),
-            beta=0,
+            beta=beta,
             alpha=self.scale,
         )
 
-        if attention_mask is not None:
-            attention_scores = attention_scores + attention_mask
-
         if self.upcast_softmax:
             attention_scores = attention_scores.float()
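The hunk above works because `torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha)` computes `beta * input + alpha * (batch1 @ batch2)`: passing the attention mask as `input` with `beta=1` folds the previous separate `attention_scores + attention_mask` addition into the same fused call, while `beta=0` with a throwaway `torch.empty` buffer reproduces the unmasked path. A minimal standalone sketch (not part of the commit; the shapes and `scale` value are made up for illustration) checking that the two paths agree:

```python
# Sketch (not from the diff): verify that folding the mask into baddbmm's
# beta term matches the old "compute scores, then add mask" path.
import torch

batch_heads, q_len, k_len, head_dim = 2, 4, 6, 8
scale = head_dim ** -0.5  # assumed value of self.scale for this example

query = torch.randn(batch_heads, q_len, head_dim)
key = torch.randn(batch_heads, k_len, head_dim)
attention_mask = torch.randn(batch_heads, q_len, k_len)

# Old path: beta=0 ignores the uninitialized buffer, mask is added afterwards.
old_scores = torch.baddbmm(
    torch.empty(batch_heads, q_len, k_len, dtype=query.dtype, device=query.device),
    query,
    key.transpose(-1, -2),
    beta=0,
    alpha=scale,
) + attention_mask

# New path: the mask is the baddbmm input and beta=1 adds it in the same call.
new_scores = torch.baddbmm(attention_mask, query, key.transpose(-1, -2), beta=1, alpha=scale)

assert torch.allclose(old_scores, new_scores, atol=1e-6)
```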
@@ -228,11 +234,12 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
 
         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
         value = attn.head_to_batch_dim(value)
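This second hunk only moves `query = attn.head_to_batch_dim(query)` below the key/value projections so the three head reshapes sit together; `to_q`, `to_k`, and `to_v` still act on the unreshaped `[batch, seq_len, dim]` tensors, so only the order of the reshapes changes, not their inputs or the result. A rough sketch of the kind of reshape `head_to_batch_dim` performs (assumed shapes and a hypothetical standalone helper, not the library's exact implementation):

```python
# Sketch (assumed shapes): a head_to_batch_dim-style reshape that folds the
# head dimension into the batch dimension before the attention matmuls.
import torch

def head_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    batch_size, seq_len, dim = tensor.shape
    head_dim = dim // heads
    tensor = tensor.reshape(batch_size, seq_len, heads, head_dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, head_dim)

x = torch.randn(2, 16, 64)                   # [batch, seq_len, heads * head_dim]
print(head_to_batch_dim(x, heads=8).shape)   # torch.Size([16, 16, 8])
```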