@@ -120,8 +120,10 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.size(1)
+        attention_kwargs = attention_kwargs or {}

         # norm & modulate
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -133,6 +135,7 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **attention_kwargs,
         )

         hidden_states = hidden_states + gate_msa * attn_hidden_states
@@ -498,6 +501,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states,
                     emb,
                     image_rotary_emb,
+                    attention_kwargs,
                     **ckpt_kwargs,
                 )
             else:
@@ -506,6 +510,7 @@ def custom_forward(*inputs):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=emb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
                 )

         if not self.config.use_rotary_positional_embeddings:
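For context, a minimal self-contained sketch of the pattern this diff introduces: the block's forward accepts an optional attention_kwargs dict, normalizes None to an empty dict, and unpacks it into the attention call so extra per-call arguments (for example a LoRA-style scale) reach the attention module. The class and argument names below are illustrative assumptions, not the diffusers implementation.

# Illustrative sketch only; ToyAttention/ToyBlock are hypothetical, not diffusers classes.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Extra keyword arguments arrive here via **attention_kwargs.
        return hidden_states * scale


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn = ToyAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        # Normalize None to {} so the unpacking below is always valid.
        attention_kwargs = attention_kwargs or {}
        return hidden_states + self.attn(hidden_states, **attention_kwargs)


block = ToyBlock()
x = torch.randn(1, 4, 8)
print(block(x, attention_kwargs={"scale": 0.5}).shape)  # torch.Size([1, 4, 8])

Passing attention_kwargs=None behaves the same as passing an empty dict, so existing callers continue to work unchanged.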