Commit daccbe7

Merge pull request #6 from pcuenca/add_palma_shift_mask
Shift mask from `1:`
2 parents 404abd8 + 60ad9c5

File tree

1 file changed (+1, -1)

src/transformers/models/paligemma/modeling_paligemma.py (+1, -1)
```diff
@@ -463,7 +463,7 @@ def forward(
         if attention_mask.dim() == 4:
             # take top or bottom row of the 4d mask.
             # this should only be used in the initial pass with full attention on prefix.
-            shift_attention_mask = attention_mask[:, 0, 0, :-1].squeeze(1) if not left_padding else attention_mask[:, 0, -1, :-1].squeeze(1)
+            shift_attention_mask = attention_mask[:, 0, 0, 1:].squeeze(1) if not left_padding else attention_mask[:, 0, -1, 1:].squeeze(1)
         elif attention_mask.dim() == 2:
             # take normal slice of the attn mask
             shift_attention_mask = attention_mask[..., 1:]
```
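
Why `1:` rather than `:-1`: the shifted mask has to line up with the shifted labels, which select the *target* token of each prediction step. Below is a minimal sketch with toy tensors (illustrative only, not PaliGemma's actual loss code; the variable names just mirror the diff):

```python
import torch

# Toy setup: batch of 1, sequence of 4 tokens, last token is padding (id 0).
input_ids      = torch.tensor([[5, 6, 7, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])   # 2d mask, right padding
logits         = torch.randn(1, 4, 10)          # (batch, seq, vocab)

# Causal-LM alignment: logits at step t predict the token at step t+1.
shift_logits = logits[..., :-1, :]              # predictions for positions 1..3
shift_labels = input_ids[..., 1:]               # targets at positions 1..3

# The mask must select the same positions as the targets, i.e. mask[1:].
# Slicing mask[:-1] instead would keep the prediction whose target is the
# padding token (and mis-drop a valid one), which is what this commit fixes.
shift_attention_mask = attention_mask[..., 1:]

valid_logits = shift_logits[shift_attention_mask != 0]  # shape (2, 10)
valid_labels = shift_labels[shift_attention_mask != 0]
assert valid_labels.tolist() == [6, 7]          # padding target excluded
```

The 4d branch in the diff first reduces the mask to this same 1d-per-batch shape (taking the first or last query row depending on padding side, per the comments in the hunk), so it needs the identical `1:` slice to stay consistent with the 2d branch.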
