@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only deci model compatible with HuggingFace weights."""
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
 import torch
 from torch import nn
@@ -66,36 +66,43 @@ def _find_multiple(n: int, k: int) -> int:
 
 class DeciLMAttention(LlamaAttention):
 
-    def __init__(self,
-                 config,
-                 hidden_size,
-                 num_heads,
-                 num_kv_heads,
-                 rope_theta=10000,
-                 rope_scaling=None,
-                 max_position_embeddings=8192,
-                 quant_config=None,
-                 bias=False,
-                 bias_o_proj=False,
-                 cache_config=None,
-                 prefix="",
-                 attn_type=AttentionType.DECODER):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        bias_o_proj: bool = False,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+        attn_type: str = AttentionType.DECODER,
+    ) -> None:
         super().__init__(config, hidden_size, num_heads, num_kv_heads,
                          rope_theta, rope_scaling, max_position_embeddings,
                          quant_config, bias, bias_o_proj, cache_config, prefix,
                          attn_type)
 
-        # Enable YARN by overriding rope
-        interleaved_rope = config.position_embedding_type in [
-            "mistral_yarn", "rope_llama4"
-        ]
+    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
+                         quant_config: Optional[QuantizationConfig]) -> None:
+        # Enables YARN for Mistral and LLaMA4 derivatives.
+        is_neox_style = True
+        if hasattr(config, "position_embedding_type"):
+            is_neox_style = config.position_embedding_type not in [
+                "mistral_yarn", "rope_llama4"
+            ]
+
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
             rope_scaling=rope_scaling,
-            is_neox_style=not interleaved_rope,
+            is_neox_style=is_neox_style,
             partial_rotary_factor=self.partial_rotary_factor)
 
 
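For reference, here is a minimal standalone sketch of the rope-style resolution that the new _init_rotary_emb hook performs. The helper name resolve_is_neox_style and the SimpleNamespace configs are illustrative stand-ins, not part of vLLM: a config that declares an interleaved position embedding type ("mistral_yarn" or "rope_llama4") turns NeoX-style rotation off, and a config without the attribute keeps the NeoX default.

from types import SimpleNamespace
from typing import Any


def resolve_is_neox_style(config: Any) -> bool:
    # Mirrors the patched branch: interleaved-rope configs disable
    # NeoX-style rotation; configs lacking the attribute keep the default.
    if hasattr(config, "position_embedding_type"):
        return config.position_embedding_type not in [
            "mistral_yarn", "rope_llama4"
        ]
    return True


# A config without position_embedding_type keeps NeoX-style rope.
assert resolve_is_neox_style(SimpleNamespace()) is True
# An interleaved-rope (YARN) config switches it off.
assert resolve_is_neox_style(
    SimpleNamespace(position_embedding_type="mistral_yarn")) is False

The hasattr guard is what lets this path also serve configs that never define position_embedding_type; the removed code read the attribute unconditionally.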