@@ -40,9 +40,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
     freqs = torch.outer(t, freqs).float()  # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
+    freqs_cos = torch.cos(freqs)  # real part
+    freqs_sin = torch.sin(freqs)  # imaginary part
+    return freqs_cos, freqs_sin

 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     ndim = x.ndim
@@ -51,17 +51,31 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)

-
 def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
+    freqs_cos: torch.Tensor,
+    freqs_sin: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+
+    # reshape xq and xk to match the complex representation
+    xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
+    xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
+
+    # reshape freqs_cos and freqs_sin for broadcasting
+    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
+
+    # apply rotation using real numbers
+    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    # flatten last two dimensions
+    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
     return xq_out.type_as(xq), xk_out.type_as(xk)

 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -103,7 +117,8 @@ def __init__(self, args: ModelArgs):
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
     ):
         bsz, seqlen, _ = x.shape

@@ -114,7 +129,7 @@ def forward(
         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         # RoPE relative positional embeddings
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

         # grouped multiquery attention: expand out keys and values
         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
@@ -176,8 +191,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

-    def forward(self, x, freqs_cis):
-        h = x + self.attention.forward(self.attention_norm(x), freqs_cis)
+    def forward(self, x, freqs_cos, freqs_sin):
+        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
         out = h + self.feed_forward.forward(self.ffn_norm(h))
         return out

@@ -201,8 +216,9 @@ def __init__(self, params: ModelArgs):
         self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying

         # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
-        freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
-        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
+        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

         # init all weights
         self.apply(self._init_weights)
@@ -223,10 +239,11 @@ def forward(self, tokens, targets=None):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        freqs_cis = self.freqs_cis[:seqlen]
+        freqs_cos = self.freqs_cos[:seqlen]
+        freqs_sin = self.freqs_sin[:seqlen]

         for layer in self.layers:
-            h = layer(h, freqs_cis)
+            h = layer(h, freqs_cos, freqs_sin)
         h = self.norm(h)

         if targets is not None:
@@ -359,8 +376,8 @@ def serialize(t):
         serialize(self.norm.weight)
         # note: no need to write final classifier weights due to weight sharing
         # freqs_cis
-        serialize(self.freqs_cis.real[:p.max_seq_len])
-        serialize(self.freqs_cis.imag[:p.max_seq_len])
+        serialize(self.freqs_cos[:p.max_seq_len])
+        serialize(self.freqs_sin[:p.max_seq_len])

         # write to binary file
         f.close()
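
Below is a minimal standalone sketch, not part of the commit, checking that the new real-valued cos/sin rotation reproduces the old complex-valued torch.polar formulation. The helper names, tensor shapes, and tolerance are illustrative assumptions.

# hypothetical equivalence check; shapes chosen only for illustration
import torch

def precompute_complex(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    freqs = torch.outer(torch.arange(end), freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)  # old path: complex64 cis

def precompute_real(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    freqs = torch.outer(torch.arange(end), freqs).float()
    return torch.cos(freqs), torch.sin(freqs)  # new path: real cos/sin pair

bsz, seqlen, n_heads, head_dim = 2, 8, 4, 16
xq = torch.randn(bsz, seqlen, n_heads, head_dim)

# old path: rotate (real, imag) pairs by complex multiplication
cis = precompute_complex(head_dim, seqlen).view(1, seqlen, 1, head_dim // 2)
xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
out_old = torch.view_as_real(xq_c * cis).flatten(3)

# new path: the same rotation written out with cos and sin
cos, sin = precompute_real(head_dim, seqlen)
cos = cos.view(1, seqlen, 1, head_dim // 2)
sin = sin.view(1, seqlen, 1, head_dim // 2)
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
out_new = torch.stack([xq_r * cos - xq_i * sin, xq_r * sin + xq_i * cos], dim=-1).flatten(3)

print(torch.allclose(out_old, out_new, atol=1e-5))  # expected: True

The two paths agree because (x_r + i*x_i)(cos t + i*sin t) = (x_r*cos t - x_i*sin t) + i*(x_r*sin t + x_i*cos t), which is exactly the four real products the new apply_rotary_emb computes; a common motivation for this kind of rewrite is that complex dtypes are poorly supported by some compile and export paths.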