@@ -303,19 +303,20 @@ def __call__(
         value = attn.to_v(encoder_hidden_states)
 
         query = query.unflatten(-1, (attn.heads, -1)).transpose(1, 2).flatten(end_dim=1)
-        key = key.unflatten(-1, (attn.heads, -1)).transpose(1, 2).flatten(end_dim=1)
+        key_t = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).flatten(end_dim=1)
+        del key
         value = value.unflatten(-1, (attn.heads, -1)).transpose(1, 2).flatten(end_dim=1)
 
         dtype = query.dtype
         # TODO: do we still need to do *everything* in float32, given how we delay the division?
         # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
         if attn.upcast_attention:
             query = query.float()
-            key = key.float()
+            key_t = key_t.float()
 
         bytes_per_token = torch.finfo(query.dtype).bits // 8
         batch_x_heads, q_tokens, _ = query.shape
-        _, k_tokens, _ = key.shape
+        _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
         query_chunk_size = self.query_chunk_size
@@ -329,7 +330,7 @@ def __call__(
 
         hidden_states = efficient_dot_product_attention(
             query,
-            key,
+            key_t,
             value,
             query_chunk_size=query_chunk_size,
             kv_chunk_size=kv_chunk_size,
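
For reference, a quick shape check of the new layout (the sizes below are made up, purely for illustration): query and value still come out as (batch*heads, tokens, head_dim), while key_t is transposed before the head split, so it lands pre-transposed as (batch*heads, head_dim, k_tokens) and the chunked attention can matmul against it directly without a per-chunk transpose.

import torch

# hypothetical sizes, only to sanity-check the reshapes in the diff above
batch, heads, q_tokens, k_tokens, inner_dim = 2, 8, 4096, 77, 320
head_dim = inner_dim // heads

query = torch.randn(batch, q_tokens, inner_dim)
key = torch.randn(batch, k_tokens, inner_dim)

# query/value path: (batch, tokens, heads*head_dim) -> (batch*heads, tokens, head_dim)
q = query.unflatten(-1, (heads, -1)).transpose(1, 2).flatten(end_dim=1)

# key path: transpose first, then split heads -> (batch*heads, head_dim, k_tokens)
key_t = key.transpose(1, 2).unflatten(1, (heads, -1)).flatten(end_dim=1)

assert q.shape == (batch * heads, q_tokens, head_dim)
assert key_t.shape == (batch * heads, head_dim, k_tokens)

# a QK^T block is then a plain bmm, with no transpose inside the chunk loop
scores = torch.bmm(q, key_t)  # (batch*heads, q_tokens, k_tokens)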
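
The size estimate then reads k_tokens off key_t's last dimension instead of key's middle one; it is the same number, just taken from the new layout. A small worked example of the byte count (again with made-up sizes):

import torch

batch_x_heads, q_tokens, k_tokens = 16, 4096, 77  # hypothetical
bytes_per_token = torch.finfo(torch.float32).bits // 8  # 4 bytes per element in float32
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
print(qk_matmul_size_bytes)  # 20_185_088 bytes, roughly 19 MiB for the full QK^T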