Commit f137065

Merge pull request huggingface#9 from RyanMullins/gemma3attention

Gemma3attention is now lower triangular

2 parents 48bca47 + 576f065, commit f137065

2 files changed (+5, -7 lines)

src/transformers/models/gemma3/modeling_gemma3.py (+2, -3)

@@ -730,8 +730,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             batch_size (`torch.Tensor`):
                 Batch size.
         """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        if attention_mask is not None and attention_mask.dim() == 4 and attention_mask.shape[2] == 1:
+            # In this case, the mask comes already in inverted form and requires no inversion or slicing.
             causal_mask = attention_mask
         else:
             min_dtype = torch.finfo(dtype).min
@@ -753,7 +753,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                     padding_mask, min_dtype
                 )
-
         return causal_mask
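Context for the tightened condition above: a caller-provided 4D mask is now reused only when its query dimension is a singleton (`attention_mask.shape[2] == 1`, the single-token decoding case); any other input falls through to the branch that builds a fresh lower-triangular mask. Below is a minimal sketch of that check; `reuse_provided_mask` is a hypothetical helper used only for illustration and is not part of the diff.

import torch

def reuse_provided_mask(attention_mask):
    # Mirrors the updated condition: only a 4D mask whose query dimension
    # is 1 (a single decoding step) is treated as already inverted and
    # correctly shaped, so it can be passed through unchanged.
    return (
        attention_mask is not None
        and attention_mask.dim() == 4
        and attention_mask.shape[2] == 1
    )

# A decode-step mask of shape (batch, 1, 1, kv_len) is passed through ...
print(reuse_provided_mask(torch.zeros(2, 1, 1, 16)))  # True
# ... while a full (batch, 1, q_len, kv_len) mask is rebuilt as lower triangular.
print(reuse_provided_mask(torch.zeros(2, 1, 8, 16)))  # False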

src/transformers/models/gemma3/modular_gemma3.py (+3, -4)

@@ -1337,7 +1337,7 @@ def _update_causal_mask(
             else input_tensor.shape[1]
         )

-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+        # We generate a lower triangular causal mask here (4D).
         causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
             attention_mask,
             sequence_length=sequence_length,
@@ -1382,8 +1382,8 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             batch_size (`torch.Tensor`):
                 Batch size.
         """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+        if attention_mask is not None and attention_mask.dim() == 4 and attention_mask.shape[2] == 1:
+            # In this case, the mask comes already in inverted form and requires no inversion or slicing.
             causal_mask = attention_mask
         else:
             min_dtype = torch.finfo(dtype).min
@@ -1412,7 +1412,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
             causal_mask[:, :, :, :mask_length] = causal_mask[
                 :, :, :, :mask_length
             ].masked_fill(padding_mask, min_dtype)
-
        return causal_mask
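For readers unfamiliar with the "inverted" lower-triangular layout the hunks refer to, here is a minimal, self-contained sketch of how such a 4D causal mask is typically constructed from a cache position, in the spirit of `_prepare_4d_causal_attention_mask_with_cache_position`. The helper name `build_causal_mask` and its exact signature are illustrative assumptions, not the library implementation.

import torch

def build_causal_mask(sequence_length, target_length, cache_position, batch_size, dtype=torch.float32):
    # Illustrative sketch only: start from a mask filled with the most negative
    # representable value, so disallowed positions carry a large negative score.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    # Allow (zero out) every key position at or before each query's cache position,
    # which produces the lower-triangular pattern along the query/key diagonal.
    allowed = torch.arange(target_length) <= cache_position.reshape(-1, 1)
    causal_mask = causal_mask.masked_fill(allowed, 0.0)
    # Expand to the (batch, 1, q_len, kv_len) layout the attention layers expect.
    return causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

mask = build_causal_mask(sequence_length=4, target_length=4,
                         cache_position=torch.arange(4), batch_size=1)
print(mask[0, 0] == 0)  # Boolean lower-triangular pattern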
