Commit 4b6f069

Add support for CodeLlama (#854)
1 parent 791d79d commit 4b6f069
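
Once rope_theta is picked up from the Hugging Face config (transformers > 4.32.0, per the comment in the diff below), a CodeLlama checkpoint should be servable through the standard vLLM entry point with no extra flags. A minimal usage sketch; the checkpoint name and generation settings here are illustrative, not part of this commit:

from vllm import LLM, SamplingParams

# Example checkpoint name; any CodeLlama variant whose config carries
# rope_theta should go through the same code path added in this commit.
llm = LLM(model="codellama/CodeLlama-7b-hf")
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["def fibonacci(n):"], params)
print(outputs[0].outputs[0].text)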

File tree

1 file changed, +6 −0 lines


vllm/model_executor/models/llama.py (+6)
@@ -85,6 +85,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        rope_theta: float = 10000,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ def __init__(
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -118,6 +120,7 @@ def __init__(
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
                                            self.scaling,
+                                           base=self.rope_theta,
                                            rotary_dim=self.head_dim,
                                            num_kv_heads=self.num_kv_heads)
 
@@ -143,10 +146,13 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
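
Functionally, the only model-level change is making the rotary embedding base configurable: standard Llama checkpoints use theta = 10000 (the new default), while CodeLlama configs set a much larger rope_theta (1e6), which lowers the rotary frequencies so positions remain distinguishable at longer contexts. A minimal sketch of how the base enters the RoPE inverse frequencies (standalone PyTorch for illustration, not the PagedAttentionWithRoPE kernel itself):

import torch

def rope_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Rotary inverse frequencies: inv_freq[i] = base ** (-2i / head_dim),
    # one frequency per pair of channels in the head dimension.
    return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

llama_freqs = rope_inv_freq(128, base=10000.0)  # default base, as in the new rope_theta default
codellama_freqs = rope_inv_freq(128, base=1e6)  # larger rope_theta found in CodeLlama configs

# A larger base shrinks the low-end frequencies, so the positional phase wraps
# more slowly and long-range positions stay separable.
print(llama_freqs[-1].item(), codellama_freqs[-1].item())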
