@@ -1003,7 +1003,7 @@ def _update_causal_mask(
        ):
            return None

-        dtype, device = input_tensor.dtype, input_tensor.device
+        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
@@ -1020,7 +1020,6 @@ def _update_causal_mask(
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
-            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
@@ -1045,7 +1044,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
-        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
@@ -1065,8 +1063,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
-            device (`torch.device`):
-                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
@@ -1078,11 +1074,11 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
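
For reference, a minimal standalone sketch of the mask construction after this change: the mask and the comparison range are allocated on `cache_position.device` instead of being given a separate `device` argument. The helper name `build_causal_mask_sketch` and the toy values below are illustrative, not part of the library code.

```python
import torch


def build_causal_mask_sketch(
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    cache_position: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    """Illustrative sketch mirroring the diff above; not the library implementation."""
    min_dtype = torch.finfo(dtype).min
    # Allocate the mask on the same device as cache_position (no separate `device` argument).
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
    )
    if sequence_length != 1:
        # Mask out future positions within the current query block.
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the mask only for cache slots beyond each token's position; earlier slots become 0 (visible).
    causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
    # Expand to the (batch, heads, query, key) layout expected by attention.
    return causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)


# Example: a 3-token prompt with a static cache of length 8.
mask = build_causal_mask_sketch(
    sequence_length=3,
    target_length=8,
    dtype=torch.float32,
    cache_position=torch.arange(3),
    batch_size=1,
)
print(mask.shape)  # torch.Size([1, 1, 3, 8])
```

The printed shape matches the 4D layout the function documents; the mask lands on whatever device `cache_position` lives on, which is what makes the removed `device` parameter redundant.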