Skip to content

Commit 0c3f56f

Browse files
committed Apr 5, 2025 ·
remove enable_gqa and use repeat_interleave instead
1 parent 2c2b658 commit 0c3f56f

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed
 

Diff for: ‎src/diffusers/models/transformers/transformer_cosmos.py

+7-3
Original file line numberDiff line numberDiff line change
```diff
@@ -173,13 +173,17 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
         key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)

-        # 4. Attention
+        # 4. Prepare for GQA
+        key = key.repeat_interleave(query.size(3) // key.size(3), dim=3)
+        value = value.repeat_interleave(query.size(3) // value.size(3), dim=3)
+
+        # 5. Attention
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, enable_gqa=True
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)

-        # 5. Output projection
+        # 6. Output projection
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
```

0 commit comments

Comments (0)