@@ -509,15 +509,12 @@ def __init__(
509
509
):
510
510
super ().__init__ ()
511
511
512
- if rotary_dim != head_size :
513
- raise ValueError (
514
- f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
515
- rotary_dim != head_size ({ rotary_dim } !={ head_size } )." )
516
512
if is_neox_style is False :
517
513
raise ValueError (
518
514
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
519
515
)
520
516
517
+ self .rotary_dim = rotary_dim
521
518
self .head_size = head_size
522
519
self .max_position_embeddings = max_position_embeddings
523
520
self .original_max_position_embeddings = original_max_position_embeddings
@@ -557,7 +554,7 @@ def __init__(
557
554
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
    """Return the rescaled inverse rotary frequencies.

    The frequencies are derived from ``self.rotary_dim`` (the number of
    channels that actually get rotated), not from ``self.head_size``, so
    the result stays correct when only part of each head is rotated.

    Args:
        rescale_factors: Per-frequency scaling factors. They are multiplied
            element-wise with a tensor of ``len(range(0, rotary_dim, 2))``
            entries, so the list must broadcast against that length.

    Returns:
        A float32 tensor holding
        ``1 / (rescale_factors * base ** (2i / rotary_dim))`` for each even
        index ``2i`` in ``[0, rotary_dim)``.
    """
    rescale = torch.tensor(rescale_factors, dtype=torch.float32)
    # Even channel indices normalised by the rotated width: 0, 2/d, 4/d, ...
    exponents = torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
    inv_freq = 1.0 / (rescale * (self.base ** exponents))
    return inv_freq
562
559
563
560
def _compute_cos_sin_cache (
@@ -596,8 +593,15 @@ def forward(
596
593
cos = cos .repeat (1 , 2 ).unsqueeze (- 2 )
597
594
sin = sin .repeat (1 , 2 ).unsqueeze (- 2 )
598
595
599
- query = query * cos + _rotate_neox (query ) * sin
600
- key = key * cos + _rotate_neox (key ) * sin
596
+ query_rot = query [..., :self .rotary_dim ]
597
+ query_pass = query [..., self .rotary_dim :]
598
+ query_rot = query_rot * cos + _rotate_neox (query_rot ) * sin
599
+ query = torch .cat ((query_rot , query_pass ), dim = - 1 )
600
+
601
+ key_rot = key [..., :self .rotary_dim ]
602
+ key_pass = key [..., self .rotary_dim :]
603
+ key_rot = key_rot * cos + _rotate_neox (key_rot ) * sin
604
+ key = torch .cat ((key_rot , key_pass ), dim = - 1 )
601
605
602
606
return query .flatten (- 2 ), key .flatten (- 2 )
603
607
0 commit comments