Skip to content

Commit 538fab9

Browse files
authored
1 parent ce26b16 commit 538fab9

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

vllm/model_executor/layers/rotary_embedding.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -509,15 +509,12 @@ def __init__(
509509
):
510510
super().__init__()
511511

512-
if rotary_dim != head_size:
513-
raise ValueError(
514-
f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
515-
rotary_dim != head_size ({rotary_dim}!={head_size}).")
516512
if is_neox_style is False:
517513
raise ValueError(
518514
"`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
519515
)
520516

517+
self.rotary_dim = rotary_dim
521518
self.head_size = head_size
522519
self.max_position_embeddings = max_position_embeddings
523520
self.original_max_position_embeddings = original_max_position_embeddings
@@ -557,7 +554,7 @@ def __init__(
557554
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
558555
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
559556
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
560-
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
557+
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)))
561558
return inv_freq
562559

563560
def _compute_cos_sin_cache(
@@ -596,8 +593,15 @@ def forward(
596593
cos = cos.repeat(1, 2).unsqueeze(-2)
597594
sin = sin.repeat(1, 2).unsqueeze(-2)
598595

599-
query = query * cos + _rotate_neox(query) * sin
600-
key = key * cos + _rotate_neox(key) * sin
596+
query_rot = query[..., :self.rotary_dim]
597+
query_pass = query[..., self.rotary_dim:]
598+
query_rot = query_rot * cos + _rotate_neox(query_rot) * sin
599+
query = torch.cat((query_rot, query_pass), dim=-1)
600+
601+
key_rot = key[..., :self.rotary_dim]
602+
key_pass = key[..., self.rotary_dim:]
603+
key_rot = key_rot * cos + _rotate_neox(key_rot) * sin
604+
key = torch.cat((key_rot, key_pass), dim=-1)
601605

602606
return query.flatten(-2), key.flatten(-2)
603607

vllm/model_executor/models/llama.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,9 @@ def __init__(self,
128128
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
129129
self.head_dim = getattr(config, "head_dim",
130130
self.hidden_size // self.total_num_heads)
131+
# Phi models introduced a partial_rotary_factor parameter in the config
132+
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
133+
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
131134
self.q_size = self.num_heads * self.head_dim
132135
self.kv_size = self.num_kv_heads * self.head_dim
133136
self.scaling = self.head_dim**-0.5
@@ -159,7 +162,7 @@ def __init__(self,
159162

160163
self.rotary_emb = get_rope(
161164
self.head_dim,
162-
rotary_dim=self.head_dim,
165+
rotary_dim=self.rotary_dim,
163166
max_position=max_position_embeddings,
164167
base=rope_theta,
165168
rope_scaling=rope_scaling,

0 commit comments

Comments
 (0)