
Commit e3bed52

Support passing kwargs to the SD3 custom attention processor

1 parent 9a92b81 · commit e3bed52

File tree

2 files changed (+14 −4 lines changed)


src/diffusers/models/attention.py (+9 −3)
```diff
@@ -188,7 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = None,
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -206,15 +210,17 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
```
src/diffusers/models/transformers/transformer_sd3.py (+5 −1)
```diff
@@ -334,12 +334,16 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
 
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
```
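End to end, the kwargs travel pipeline → transformer → block → processor, in both the gradient-checkpointing and the plain paths. A usage sketch, assuming the SD3 medium checkpoint and the hypothetical processor from the previous example:

```python
# Usage sketch (assumptions: the SD3 medium checkpoint is available, a CUDA
# device is present, and ScaledJointAttnProcessor from above is in scope).
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Install the custom processor on every attention layer of the SD3 transformer.
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

# With this commit, the dict below reaches each block's attn/attn2 call
# instead of stopping at the transformer level.
image = pipe(
    "a photo of an astronaut riding a horse",
    joint_attention_kwargs={"attn_scale": 0.8},  # hypothetical kwarg
).images[0]
```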

0 commit comments
