@@ -85,6 +85,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        rope_theta: float = 10000,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ def __init__(
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -118,6 +120,7 @@ def __init__(
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
                                            self.scaling,
+                                           base=self.rope_theta,
                                            rotary_dim=self.head_dim,
                                            num_kv_heads=self.num_kv_heads)
 
@@ -143,10 +146,13 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
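
The `base` value threaded through this diff is the RoPE theta: each pair of head dimensions rotates with inverse frequency base ** (-2i / head_dim), so a larger rope_theta (read from the Hugging Face config when present, falling back to Llama's default of 10000) stretches the rotation period. The following is a minimal, purely illustrative NumPy sketch of that frequency computation; the `rope_inv_freq` helper is not part of vLLM, whose actual rotation happens inside PagedAttentionWithRoPE.

import numpy as np

def rope_inv_freq(head_dim: int, base: float = 10000.0) -> np.ndarray:
    # Inverse frequencies for rotary embeddings: base ** (-2i / head_dim)
    # for i = 0, 1, ..., head_dim / 2 - 1. Illustrative only.
    exponents = np.arange(0, head_dim, 2, dtype=np.float64) / head_dim
    return base ** -exponents

# A larger base slows the rotation of every dimension pair, so positional
# phases repeat less often over long sequences.
print(rope_inv_freq(8, base=10000.0))
print(rope_inv_freq(8, base=1000000.0))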