Commit be8c85c

Factor llama rotary_emb initialization out to protected method.
Override rotary_emb initialization for NemotronNAS Attention.

Signed-off-by: Nave Assaf <[email protected]>

1 parent 23e3405 · commit be8c85c

File tree

2 files changed: +48, -34 lines

vllm/model_executor/models/llama.py

Lines changed: 21 additions & 14 deletions
@@ -162,20 +162,9 @@ def __init__(
             prefix=f"{prefix}.o_proj",
         )
 
-        is_neox_style = True
-        is_gguf = quant_config and quant_config.get_name() == "gguf"
-        if is_gguf and config.model_type == "llama":
-            is_neox_style = False
-
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-            is_neox_style=is_neox_style,
-            partial_rotary_factor=self.partial_rotary_factor,
-        )
+        self._init_rotary_emb(config,
+                              rope_scaling=rope_scaling,
+                              quant_config=quant_config)
 
         if hasattr(config, "interleaved_sliding_window"):
             interleaved_sliding_window = config.interleaved_sliding_window
@@ -214,6 +203,24 @@ def forward(
         output, _ = self.o_proj(attn_output)
         return output
 
+    def _init_rotary_emb(self, config: LlamaConfig,
+                         rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        is_neox_style = True
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and self.config.model_type == "llama":
+            is_neox_style = False
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=is_neox_style,
+            partial_rotary_factor=self.partial_rotary_factor,
+        )
+
 
 class LlamaDecoderLayer(nn.Module):

vllm/model_executor/models/nemotron_nas.py

Lines changed: 27 additions & 20 deletions
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
@@ -66,36 +66,43 @@ def _find_multiple(n: int, k: int) -> int:
 
 class DeciLMAttention(LlamaAttention):
 
-    def __init__(self,
-                 config,
-                 hidden_size,
-                 num_heads,
-                 num_kv_heads,
-                 rope_theta=10000,
-                 rope_scaling=None,
-                 max_position_embeddings=8192,
-                 quant_config=None,
-                 bias=False,
-                 bias_o_proj=False,
-                 cache_config=None,
-                 prefix="",
-                 attn_type=AttentionType.DECODER):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__(config, hidden_size, num_heads, num_kv_heads,
                          rope_theta, rope_scaling, max_position_embeddings,
                          quant_config, bias, bias_o_proj, cache_config, prefix,
                          attn_type)
 
-        # Enable YARN by overriding rope
-        interleaved_rope = config.position_embedding_type in [
-            "mistral_yarn", "rope_llama4"
-        ]
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
             rope_scaling=rope_scaling,
-            is_neox_style=not interleaved_rope,
+            is_neox_style=is_neox_style,
             partial_rotary_factor=self.partial_rotary_factor)
 