@@ -188,7 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = None,
     ):
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -206,15 +210,17 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
 
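The pattern introduced by this diff is a pass-through kwargs dict: forward() gains an optional joint_attention_kwargs argument whose entries are unpacked verbatim into the attention calls. The toy module below is a minimal, self-contained sketch of that pattern only; it is not the class being modified here, and the "scale" keyword it forwards is purely illustrative (real keys depend on whatever the configured attention processor accepts).

# Toy stand-in for the kwargs pass-through shown in this diff.
# Assumptions: the attention module accepts the forwarded keywords;
# "scale" is a hypothetical key used only for illustration.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Processor-specific keywords arrive here via **joint_attention_kwargs.
        return hidden_states * scale


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        # Normalize None to an empty dict before unpacking with **.
        joint_attention_kwargs = joint_attention_kwargs or {}
        return self.attn(hidden_states=hidden_states, **joint_attention_kwargs)


block = ToyBlock()
sample = torch.randn(2, 4, 8)
out = block(sample, joint_attention_kwargs={"scale": 0.5})
print(out.shape)  # torch.Size([2, 4, 8])

One caveat for callers of the patched forward(): the new parameter defaults to None, while the hunk shown here unpacks it directly with **joint_attention_kwargs, so either the surrounding code (not visible in this hunk) normalizes None to an empty dict before this point or callers are expected to always pass a dict; the sketch above uses an `or {}` guard for that reason.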