
Commit 79d4bc7

[causal mask] fix preparation with multi-gpu (#37612)
* fix multi-gpu
* forgot non-copied models
* fixup
1 parent: 7bb619d · commit: 79d4bc7

File tree

67 files changed: +278, -481 lines
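Every file in this commit applies the same two-part change: the `device` argument is removed from `_prepare_4d_causal_attention_mask_with_cache_position` (and from the code that computed it), and the causal-mask tensors are allocated on `cache_position.device` instead of the device of the input embeddings. When a model is sharded across several GPUs (for example with `device_map="auto"`), the embeddings and the static cache can live on different devices, so deriving the mask's device from the embeddings could place it on the wrong GPU. Below is a minimal, self-contained illustration of the changed allocation site; the variable values are made up for the example, and in the real code they come from the model inputs and the cache.

import torch

# Stand-in values; in transformers these come from the inputs and the static cache.
sequence_length, target_length = 4, 8
dtype = torch.float16
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(sequence_length)  # in the multi-GPU case this sits on the cache's GPU

# Before this commit: the mask was allocated on an explicitly passed `device`
# (derived from inputs_embeds / input_ids in prepare_inputs_for_generation).
device = torch.device("cpu")  # stand-in for inputs_embeds.device
mask_before = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)

# After this commit: the mask is allocated directly where cache_position lives,
# so no separate `device` argument needs to be threaded through.
mask_after = torch.full(
    (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
)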


src/transformers/generation/utils.py (-3)

@@ -557,10 +557,8 @@ def prepare_inputs_for_generation(
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
-               device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs[input_ids_key].shape
-               device = model_inputs[input_ids_key].device

            # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
            # the 4D causal mask exists, it should be present in the base model (XXXModel class) or in its decoder.
@@ -586,7 +584,6 @@ def prepare_inputs_for_generation(
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.dtype,
-               device=device,
                cache_position=cache_position,
                batch_size=batch_size,
                config=self.config,

src/transformers/models/aria/modeling_aria.py (+3, -7)

@@ -1003,7 +1003,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -1020,7 +1020,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -1045,7 +1044,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -1065,8 +1063,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -1078,11 +1074,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
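For reference, here is a minimal, self-contained sketch of what the helper looks like after this commit, reconstructed from the aria hunks above. The name `prepare_4d_causal_mask` and the simplified handling of an already-4D `attention_mask` are assumptions made for the sketch; the real static method also merges a 2D padding mask into the result, which is omitted here.

import torch

def prepare_4d_causal_mask(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    # Note: there is no `device` parameter any more; every tensor is created on cache_position.device.
    if attention_mask is not None and attention_mask.dim() == 4:
        # Assumed shortcut for the sketch: a 4D mask is taken as already prepared.
        return attention_mask
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Key positions strictly past each token's cache position stay masked (min_dtype);
    # positions at or before it are zeroed out, i.e. attendable.
    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    # The padding-mask merge from the original method is omitted in this sketch.
    return causal_mask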

src/transformers/models/bamba/modeling_bamba.py (+3, -7)

@@ -1313,7 +1313,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        target_length = (
            attention_mask.shape[-1]
@@ -1327,7 +1327,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -1352,7 +1351,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -1372,8 +1370,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -1385,11 +1381,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

src/transformers/models/bamba/modular_bamba.py (+3, -7)

@@ -1081,7 +1081,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        target_length = (
            attention_mask.shape[-1]
@@ -1095,7 +1095,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -1120,7 +1119,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -1140,8 +1138,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -1153,11 +1149,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

src/transformers/models/bloom/modeling_bloom.py (+3, -7)

@@ -773,7 +773,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -790,7 +790,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -816,7 +815,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -836,8 +834,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -849,11 +845,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

src/transformers/models/chameleon/modeling_chameleon.py (+3, -7)

@@ -1406,7 +1406,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -1423,7 +1423,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -1449,7 +1448,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -1469,8 +1467,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -1482,11 +1478,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

src/transformers/models/codegen/modeling_codegen.py (+3, -7)

@@ -619,7 +619,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -636,7 +636,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -662,7 +661,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -682,8 +680,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -695,11 +691,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

src/transformers/models/cohere/modeling_cohere.py (+3, -7)

@@ -652,7 +652,7 @@ def _update_causal_mask(
        ):
            return None

-       dtype, device = input_tensor.dtype, input_tensor.device
+       dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -669,7 +669,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-           device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -694,7 +693,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -714,8 +712,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -727,11 +723,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit

src/transformers/models/cohere2/modeling_cohere2.py (+2, -5)

@@ -686,7 +686,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-       device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -706,8 +705,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-           device (`torch.device`):
-               The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -719,11 +716,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+               (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-           causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+           causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
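As a quick illustrative usage of the sketch given after the aria diff (again hypothetical standalone code, not the transformers API): whatever device `cache_position` is on is where the mask ends up, regardless of where the hidden states came from, which is the property the multi-GPU fix relies on.

cache_position = torch.arange(4)  # on a sharded model this would sit on e.g. cuda:1, alongside the cache
mask = prepare_4d_causal_mask(
    attention_mask=None,
    sequence_length=4,
    target_length=8,
    dtype=torch.float16,
    cache_position=cache_position,
    batch_size=2,
)
print(mask.shape, mask.device)  # torch.Size([2, 1, 4, 8]) on cache_position.device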
