Commit 7116fd2

Support pass kwargs to cogvideox custom attention processor (#10456)
* Support pass kwargs to cogvideox custom attention processor
* remove args in cogvideox attn processor
* remove unused kwargs
1 parent: 553b138

1 file changed, 5 insertions(+), 0 deletions(-)

src/diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -120,8 +120,10 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}
 
         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )
 
         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -498,6 +501,7 @@ def custom_forward(*inputs):
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
+                    attention_kwargs,
                    **ckpt_kwargs,
                )
            else:
@@ -506,6 +510,7 @@ def custom_forward(*inputs):
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                )
 
        if not self.config.use_rotary_positional_embeddings:
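
For context, the change above threads an attention_kwargs dict from the block's forward into whatever processor is installed on attn1, so a custom attention processor can receive extra keyword arguments at call time. Below is a minimal sketch of how that hook might be used; the ScaledCogVideoXAttnProcessor class, the attn_scale keyword, and the way it is registered are illustrative assumptions, not part of this commit.

# Hypothetical sketch (not from this commit): a custom CogVideoX attention
# processor that accepts an extra keyword forwarded through attention_kwargs.
from typing import Optional, Tuple

import torch
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0


class ScaledCogVideoXAttnProcessor(CogVideoXAttnProcessor2_0):
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attn_scale: float = 1.0,  # hypothetical knob, arrives via **attention_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Run the stock CogVideoX attention, then rescale the returned
        # hidden_states as a toy example of a per-call parameter.
        hidden_states, encoder_hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
        )
        return hidden_states * attn_scale, encoder_hidden_states

Assuming such a processor is installed on the transformer's attention layers (for example with the model's set_attn_processor helper) and that the model-level forward accepts and forwards attention_kwargs as the call sites in this diff suggest, a caller could then pass attention_kwargs={"attn_scale": 0.9} to reach the processor.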
