@@ -220,9 +220,19 @@ def forward(
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "cache_position": cache_position,
+                "sliding_window": self.sliding_window,
+            }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
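
For intuition, here is a minimal sketch of the FA2 slicing introduced above, with toy shapes. The cache size, head count, and token count below are illustrative assumptions, not values from this patch.

import torch

# Toy static cache: pre-allocated for max_cache_len positions, but only 5 tokens written so far.
bs, num_kv_heads, max_cache_len, head_dim = 1, 2, 16, 8
key_states = torch.zeros(bs, num_kv_heads, max_cache_len, head_dim)
value_states = torch.zeros(bs, num_kv_heads, max_cache_len, head_dim)

# With flash_attention_2 the mask is 2D: [bs, processed_tokens].
attention_mask = torch.ones(bs, 5)

# Same slicing as in the hunk: drop the unwritten tail of the static cache,
# since FA2 cannot skip it via a 4D additive mask the way SDPA/eager can.
seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]
print(key_states.shape, value_states.shape)  # torch.Size([1, 2, 5, 8]) torch.Size([1, 2, 5, 8])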
@@ -276,20 +286,30 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: int = 0,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
+            # In prefill, we may be larger than sliding window
+            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
+            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
+            # thus we must slice from the right (at most `effective_seq_len` elements)
             if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
+                attention_mask = attention_mask[:, -effective_seq_len:]
+            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
+            # from the left, with an offset if we are beyond the sliding window
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
+                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
+                offset = last_cache_position - effective_seq_len
+                # Should only be used when beyond the sliding window (i.e. offset > 0)
+                offset = max(0, offset)
+                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
 
         residual = hidden_states
 
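
To see how the offset slicing above behaves once decoding moves past the sliding window, here is a toy walk-through. The window size and positions are made-up example values, not taken from the patch.

# Made-up decoding step: sliding_window = 4 and the token currently being decoded sits at
# cache position 7, so generate() passes last_cache_position = 8 (the 2D mask has 8 columns).
sliding_window = 4
query_len = 1               # cache_position.shape[0] during single-token decoding
last_cache_position = 8

effective_seq_len = max(query_len, sliding_window)        # -> 4
offset = max(0, last_cache_position - effective_seq_len)  # -> 4

# The 4D mask [bs, 1, query_len, max_cache_len] is then sliced from the left, keeping cache
# columns offset .. offset + effective_seq_len - 1: the current position plus the 3 before it.
print(list(range(offset, offset + effective_seq_len)))    # [4, 5, 6, 7]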
@@ -305,6 +325,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
@@ -549,6 +570,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -589,6 +611,16 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        if last_cache_position is None:
+            last_cache_position = 0
+            if attention_mask is not None:
+                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
+                # It will break dynamo tracing but there is no way around it (and it should never happen in practice)
+                last_cache_position = (
+                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
+                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
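
The fallback above can be exercised in isolation to see which branch produces the value. The helper name below is hypothetical and used only for this sketch; it simply restates the patch's logic outside the model.

import torch

def resolve_last_cache_position(last_cache_position, attention_mask, cache_position):
    if last_cache_position is None:
        last_cache_position = 0
        if attention_mask is not None:
            # A 2D padding mask yields a plain Python int (dynamo-friendly); a 4D mask forces
            # a tensor read, which is the case the comment in the patch warns about.
            last_cache_position = (
                attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
            )
    return last_cache_position

mask_2d = torch.ones(1, 9, dtype=torch.long)  # 9 tokens seen so far
cache_position = torch.tensor([8])            # decoding the token at position 8
print(resolve_last_cache_position(None, mask_2d, cache_position))  # 9
print(resolve_last_cache_position(12, mask_2d, cache_position))    # 12 (explicit value wins)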
@@ -624,6 +656,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -635,6 +668,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -850,6 +884,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **loss_kwargs,
         )
 
         hidden_states = outputs[0]
@@ -918,6 +953,10 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
+
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
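
As a rough illustration of why the 2D mask length works as `last_cache_position` during generation: the mask gains one column per generated token, so its length tracks how many tokens have been processed. The prompt length and loop below are assumed for the example only.

import torch

attention_mask = torch.ones(1, 5, dtype=torch.long)  # 5 prompt tokens
for step in range(3):
    last_cache_position = attention_mask.shape[-1]
    print(step, last_cache_position)                 # 0 5, then 1 6, then 2 7
    # generate() appends a column of ones for the token produced at this step
    attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=torch.long)], dim=-1)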