Commit b673c16

Cyrilvallez authored and ArthurZucker committed
Fix mask slicing for models with HybridCache (#35681)
* correctly slice
* check mask
* Update modular_gemma2.py
* fix
* add tests
* fix typo
* finally fix mask slicing
* Finally correctly slice in all cases!!
* add test for all attention functions
* small fix in tests
* trick around dynamo tracing issue
* last update
* more robust
* kwargs propagation
* make it explicit for checkpointing
* apply modular
1 parent aa3e590 commit b673c16

File tree

6 files changed, +232 −22 lines changed


src/transformers/models/cohere2/modeling_cohere2.py

+38 −5
@@ -255,6 +255,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -318,6 +323,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -338,21 +344,30 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
+            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """
 
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
+            # In prefill, we may be larger than sliding window
+            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
+            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
+            # thus we must slice from the right (at most `effective_seq_len` elements)
             if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
+                attention_mask = attention_mask[:, -effective_seq_len:]
+            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
+            # from the left, with an offset if we are beyond the sliding window
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
+                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
+                offset = last_cache_position - effective_seq_len
+                # Should only be used when beyond the sliding window (i.e. offset > 0)
+                offset = max(0, offset)
+                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
 
         residual = hidden_states
 
@@ -551,6 +566,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -590,9 +606,20 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        if last_cache_position is None:
+            last_cache_position = 0
+            if attention_mask is not None:
+                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
+                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
+                last_cache_position = (
+                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
+                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
+
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -616,6 +643,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -626,6 +654,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -908,6 +937,10 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
+
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
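
To make the new 4D-mask branch concrete, here is a minimal standalone sketch. It is not taken from the commit: the helper name `slice_sliding_mask` and all shapes below are toy values chosen for illustration. It shows how `offset = max(0, last_cache_position - effective_seq_len)` selects the right cache columns both during prefill and once decoding has moved past the sliding window.

import torch

def slice_sliding_mask(attention_mask, cache_position, last_cache_position, sliding_window):
    # Mirrors the non-FA2 (4D mask) slicing added in the hunks above.
    effective_seq_len = max(cache_position.shape[0], sliding_window)
    offset = max(0, last_cache_position - effective_seq_len)
    return attention_mask[:, :, :, offset : offset + effective_seq_len]

bs, sliding_window, max_cache_len = 1, 4, 16

# Decoding: 10 tokens processed so far, a single query token this step.
decode_mask = torch.zeros(bs, 1, 1, max_cache_len)   # stand-in for the static-cache mask
cache_position = torch.tensor([9])
sliced = slice_sliding_mask(decode_mask, cache_position, last_cache_position=10, sliding_window=sliding_window)
print(sliced.shape)  # torch.Size([1, 1, 1, 4]) -> cache columns 6..9, the last `sliding_window` slots

# Prefill: 6 prompt tokens, so the slice must cover all 6, not just the window of 4.
prefill_mask = torch.zeros(bs, 1, 6, max_cache_len)
cache_position = torch.arange(6)
sliced = slice_sliding_mask(prefill_mask, cache_position, last_cache_position=6, sliding_window=sliding_window)
print(sliced.shape)  # torch.Size([1, 1, 6, 6]) -> cache columns 0..5

During prefill the offset stays at 0 and the slice simply covers the prompt; once more than `sliding_window` tokens have been written, the slice shifts right so it always ends at the most recently filled cache slot.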

src/transformers/models/cohere2/modular_cohere2.py

+38 −5
@@ -296,6 +296,11 @@ def forward(
             }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -340,6 +345,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: int = 0,
         **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -360,21 +366,30 @@ def forward(
                 (see `past_key_values`).
             cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                 Indices depicting the position of the input sequence tokens in the sequence
+            last_cache_position (`int`): equivalent to `cache_position[-1]` but allow indexing without breaking dynamo tracing
         """
 
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
+            # In prefill, we may be larger than sliding window
+            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
+            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
+            # thus we must slice from the right (at most `effective_seq_len` elements)
            if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
+                attention_mask = attention_mask[:, -effective_seq_len:]
+            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
+            # from the left, with an offset if we are beyond the sliding window
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
+                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
+                offset = last_cache_position - effective_seq_len
+                # Should only be used when beyond the sliding window (i.e. offset > 0)
+                offset = max(0, offset)
+                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
 
         residual = hidden_states
 
@@ -434,6 +449,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -473,9 +489,20 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        if last_cache_position is None:
+            last_cache_position = 0
+            if attention_mask is not None:
+                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
+                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
+                last_cache_position = (
+                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
+                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
+
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
@@ -499,6 +526,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -509,6 +537,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -578,6 +607,10 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
+
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
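
The modular file mirrors the modeling file above. As a side note, here is a small hedged sketch of why the FA2 branch slices `key_states`/`value_states` in the first place; the shapes are toy values and the tensors are stand-ins, not the actual HybridCache implementation. A static cache pre-allocates `max_cache_len` slots, while flash attention only receives a 2D padding mask covering the tokens processed so far, so the cached tensors are cut down to that length before the kernel call.

import torch

bs, num_kv_heads, head_dim = 1, 2, 8
max_cache_len = 32        # a HybridCache pre-allocates this many key/value slots
processed_tokens = 5      # tokens actually written to the cache so far

# Stand-ins for what a static cache returns from `update(...)`: full-length tensors.
key_states = torch.randn(bs, num_kv_heads, max_cache_len, head_dim)
value_states = torch.randn(bs, num_kv_heads, max_cache_len, head_dim)

# FA2 receives a 2D mask of shape [bs, processed_tokens] (not [bs, max_cache_len]),
# so keys/values are cut down to the same length before the attention kernel runs.
attention_mask = torch.ones(bs, processed_tokens)
seq_len = attention_mask.shape[-1]
key_states = key_states[:, :, :seq_len, :]
value_states = value_states[:, :, :seq_len, :]
print(key_states.shape, value_states.shape)  # torch.Size([1, 2, 5, 8]) twice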

src/transformers/models/gemma2/modeling_gemma2.py

+45 −6
@@ -220,9 +220,19 @@ def forward(
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "cache_position": cache_position,
+                "sliding_window": self.sliding_window,
+            }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # Here we need to slice as we use a static cache by default, but FA2 does not support it
+        if attention_mask is not None and self.config._attn_implementation == "flash_attention_2":
+            seq_len = attention_mask.shape[-1]
+            key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :]
+
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
@@ -276,20 +286,30 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: int = 0,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
-            # Flash-attn is a 2D tensor
+            # In prefill, we may be larger than sliding window
+            effective_seq_len = max(cache_position.shape[0], self.sliding_window)
+            # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]),
+            # thus we must slice from the right (at most `effective_seq_len` elements)
             if self.config._attn_implementation == "flash_attention_2":
-                if past_key_value is not None:  # when decoding
-                    attention_mask = attention_mask[:, -self.sliding_window :]
+                attention_mask = attention_mask[:, -effective_seq_len:]
+            # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice
+            # from the left, with an offset if we are beyond the sliding window
             else:
                 min_dtype = torch.finfo(hidden_states.dtype).min
                 sliding_window_mask = torch.tril(
                     torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                 )
                 attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
-                if attention_mask.shape[-1] <= 1:  # when decoding
-                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]
+                # In case we are beyond the sliding window, we need to correctly offset the mask slicing
+                # `last_cache_position` is equivalent to `cache_position[-1]` but without breaking dynamo
+                offset = last_cache_position - effective_seq_len
+                # Should only be used when beyond the sliding window (i.e. offset > 0)
+                offset = max(0, offset)
+                attention_mask = attention_mask[:, :, :, offset : offset + effective_seq_len]
 
         residual = hidden_states
 
@@ -305,6 +325,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
            cache_position=cache_position,
+            **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = residual + hidden_states
@@ -549,6 +570,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        last_cache_position: Optional[int] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -589,6 +611,16 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        if last_cache_position is None:
+            last_cache_position = 0
+            if attention_mask is not None:
+                # In case a 4d mask is passed directly without using `generate`, we have to rely on cache_position
+                # It will break dynamo tracing but there are no way around it (and it should never happen in practice)
+                last_cache_position = (
+                    attention_mask.shape[-1] if attention_mask.dim() == 2 else cache_position[-1].item()
+                )
         causal_mask = self._update_causal_mask(
             attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
         )
@@ -624,6 +656,7 @@ def forward(
                     output_attentions,
                     use_cache,
                     cache_position,
+                    last_cache_position,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -635,6 +668,7 @@ def forward(
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     cache_position=cache_position,
+                    last_cache_position=last_cache_position,
                     **flash_attn_kwargs,
                 )
 
@@ -850,6 +884,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **loss_kwargs,
         )
 
         hidden_states = outputs[0]
@@ -918,6 +953,10 @@ def prepare_inputs_for_generation(
             # The clone here is for the same reason as for `position_ids`.
             model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
 
+        # This is needed to correctly slice the mask without data-dependent slicing later on if using dynamo tracing
+        # (retrieving the same value from `cache_position` later on would crash dynamo)
+        model_inputs["last_cache_position"] = attention_mask.shape[-1] if attention_mask is not None else 0
+
         if (
             isinstance(past_key_values, HybridCache)
             and attention_mask.ndim == 2
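
For context, here is a rough sketch of how `last_cache_position` can be maintained as a plain Python int on the caller side. The loop below is made up for illustration and is not the actual `generate()` code; `prepare_inputs_for_generation` achieves the same thing via `attention_mask.shape[-1]`, so the later mask slicing needs no data-dependent `.item()` call under dynamo.

import torch

attention_mask = torch.ones(1, 4, dtype=torch.long)   # 4 prompt tokens
last_cache_position = attention_mask.shape[-1]        # plain int, so no graph break later

for _ in range(3):                                    # pretend to decode 3 tokens
    # ... the model forward would run here and slice the sliding-window mask using
    # offset = max(0, last_cache_position - effective_seq_len) ...
    attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=torch.long)], dim=-1)
    last_cache_position = attention_mask.shape[-1]

print(last_cache_position)  # 7 -> tokens seen so far, recomputed outside the compiled forward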
