
Commit 221f4f9

Merge branch 'master' into patch-1
2 parents 338f606 + e06ff42

5 files changed: 49 additions, 28 deletions

Makefile

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ runomp: run.c
 
 .PHONY: win64
 win64:
-	x86_64-w64-mingw32-gcc-win32 -Ofast -D_WIN32 -o run.exe -I. run.c win.c
+	x86_64-w64-mingw32-gcc -Ofast -D_WIN32 -o run.exe -I. run.c win.c
 
 # compiles with gnu99 standard flags for amazon linux, coreos, etc. compatibility
 .PHONY: rungnu

README.md

Lines changed: 4 additions & 0 deletions
@@ -200,10 +200,14 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
 - [llama2.go](https://github.com/haormj/llama2.go) by @haormj: a Go port of this project
 - [llama2.go](https://github.com/saracen/llama2.go) by @saracen: a Go port of this project
 - [llama2.c-android](https://github.com/Manuel030/llama2.c-android): by @Manuel030: adds Android binaries of this project
+- [llama2.c-android-wrapper](https://github.com/celikin/llama2.c-android-wrapper): by @celikin: added JNI wrapper, PoC
 - [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @leloykun: a C++ port of this project
 - [llama2.js](https://github.com/epicure/llama2.js) by @epicure: a JavaScript port of this project
 - [llama2.zig](https://github.com/cgbur/llama2.zig) by @cgbur: A Zig port of this project
 - [llama2.jl](https://github.com/juvi21/llama2.jl) by @juvi21: a Julia port of this project
+- [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @trholding: Standalone, Bootable & Portable Binary Llama 2
+- [llama2.rs](https://github.com/leo-du/llama2.rs) by @leo-du: A Rust port of this project
+- [llama2.scala](https://github.com/jrudolph/llama2.scala) by @jrudolph: a Scala port of this project
 
 ## unsorted todos
 

build_msvc.bat

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-cl.exe /Ox /openmp /I. run.c win.c
+cl.exe /fp:fast /Ox /openmp /I. run.c win.c

export_meta_llama_bin.py

Lines changed: 6 additions & 6 deletions
@@ -56,12 +56,12 @@ def serialize(key):
 
 # final rmsnorm
 serialize('norm.weight')
-# freqs_cis
-freqs_cis = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
-state_dict['freqs_cis.real'] = freqs_cis.real[:p['max_seq_len']]
-state_dict['freqs_cis.imag'] = freqs_cis.imag[:p['max_seq_len']]
-serialize('freqs_cis.real')
-serialize('freqs_cis.imag')
+# freqs_cos, freqs_sin
+freqs_cos, freqs_sin = precompute_freqs_cis(p['dim'] // p['n_heads'], p['max_seq_len'] * 2)
+state_dict['freqs_cos'] = freqs_cos[:p['max_seq_len']]
+state_dict['freqs_sin'] = freqs_sin[:p['max_seq_len']]
+serialize('freqs_cos')
+serialize('freqs_sin')
 
 # finally write the output weights
 serialize('output.weight')
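The net effect of this hunk is a renaming: the two tensors written out are the real and imaginary parts of the old complex freqs_cis, now computed directly as cosines and sines, so the values that end up in the .bin file are unchanged. A rough sketch of the idea, assuming a helper that flattens a tensor and writes raw float32 bytes; the script's actual serialize(key) may differ, this is illustrative only:

import struct
import torch

def write_fp32(t, f):
    # hypothetical stand-in for the script's serialize(): flatten to float32, write raw bytes
    vals = t.detach().float().contiguous().view(-1).tolist()
    f.write(struct.pack(f"{len(vals)}f", *vals))

freqs = torch.outer(torch.arange(16).float(), torch.arange(1, 5).float())
with open("rope.bin", "wb") as f:
    write_fp32(torch.cos(freqs), f)  # same values freqs_cis.real would have held
    write_fp32(torch.sin(freqs), f)  # same values freqs_cis.imag would have held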

model.py

Lines changed: 37 additions & 20 deletions
@@ -40,9 +40,9 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device) # type: ignore
     freqs = torch.outer(t, freqs).float() # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
-    return freqs_cis
-
+    freqs_cos = torch.cos(freqs) # real part
+    freqs_sin = torch.sin(freqs) # imaginary part
+    return freqs_cos, freqs_sin
 
 def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     ndim = x.ndim
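Nothing numeric changes in this hunk, only the representation: torch.polar(torch.ones_like(freqs), freqs) is exactly cos(freqs) + i*sin(freqs), so the new freqs_cos/freqs_sin are the real and imaginary parts of the old complex freqs_cis. A minimal check, written here for illustration (the sizes are arbitrary, not taken from the commit):

import torch

# arbitrary stand-in sizes; the model uses dim = head_dim and end = 2 * max_seq_len
dim, end, theta = 8, 16, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(end).float(), freqs)

freqs_cis = torch.polar(torch.ones_like(freqs), freqs)     # old: one complex64 tensor
freqs_cos, freqs_sin = torch.cos(freqs), torch.sin(freqs)  # new: two float32 tensors

assert torch.allclose(freqs_cis.real, freqs_cos)
assert torch.allclose(freqs_cis.imag, freqs_sin)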
@@ -51,17 +51,31 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
-
 def apply_rotary_emb(
     xq: torch.Tensor,
     xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
+    freqs_cos: torch.Tensor,
+    freqs_sin: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+
+    # reshape xq and xk to match the complex representation
+    xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
+    xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
+
+    # reshape freqs_cos and freqs_sin for broadcasting
+    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
+    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
+
+    # apply rotation using real numbers
+    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
+    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
+    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
+    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
+
+    # flatten last two dimensions
+    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
+    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
+
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
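The rewritten apply_rotary_emb performs the same rotation as the old complex multiply, spelled out in real arithmetic: (a + bi)(cos t + i*sin t) has real part a*cos t - b*sin t and imaginary part a*sin t + b*cos t, which is what the four product lines above compute. A small equivalence check against the previous view_as_complex formulation; the shapes and names below are illustrative, not taken from the repo:

import torch

bsz, seqlen, n_heads, head_dim = 2, 16, 4, 8
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
freqs = torch.randn(seqlen, head_dim // 2)

# old formulation: rotate feature pairs via complex multiplication
xq_c = torch.view_as_complex(xq.reshape(bsz, seqlen, n_heads, -1, 2))
rot = torch.polar(torch.ones_like(freqs), freqs).view(1, seqlen, 1, -1)
old = torch.view_as_real(xq_c * rot).flatten(3)

# new formulation: the same rotation written with two real tensors
cos = torch.cos(freqs).view(1, seqlen, 1, -1)
sin = torch.sin(freqs).view(1, seqlen, 1, -1)
xq_r, xq_i = xq.reshape(bsz, seqlen, n_heads, -1, 2).unbind(-1)
new = torch.stack([xq_r * cos - xq_i * sin, xq_r * sin + xq_i * cos], dim=-1).flatten(3)

assert torch.allclose(old, new, atol=1e-5)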
@@ -103,7 +117,8 @@ def __init__(self, args: ModelArgs):
     def forward(
         self,
         x: torch.Tensor,
-        freqs_cis: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
     ):
         bsz, seqlen, _ = x.shape
 
@@ -114,7 +129,7 @@ def forward(
         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         # RoPE relative positional embeddings
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
 
         # grouped multiquery attention: expand out keys and values
         xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
@@ -176,8 +191,8 @@ def __init__(self, layer_id: int, args: ModelArgs):
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
-    def forward(self, x, freqs_cis):
-        h = x + self.attention.forward(self.attention_norm(x), freqs_cis)
+    def forward(self, x, freqs_cos, freqs_sin):
+        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
         out = h + self.feed_forward.forward(self.ffn_norm(h))
         return out
 

@@ -201,8 +216,9 @@ def __init__(self, params: ModelArgs):
         self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
 
         # some useful precompute for the RoPE relative positional embeddings. TODO why * 2 here? confuse
-        freqs_cis = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
-        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
+        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len * 2)
+        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
+        self.register_buffer("freqs_sin", freqs_sin, persistent=False)
 
         # init all weights
         self.apply(self._init_weights)
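One note on the register_buffer calls above: persistent=False keeps the cos/sin tables attached to the module, so they follow .to(device) like any other buffer, but leaves them out of state_dict; they are cheap to recompute in __init__, so nothing extra ends up in checkpoints. A tiny illustration of that PyTorch behavior (the Toy module is made up for this example):

import torch
import torch.nn as nn

class Toy(nn.Module):
    # made-up module, only to show what persistent=False does
    def __init__(self):
        super().__init__()
        self.register_buffer("freqs_cos", torch.zeros(4), persistent=False)
        self.register_buffer("freqs_sin", torch.zeros(4), persistent=False)

m = Toy()
print(list(m.state_dict().keys()))  # [] -> non-persistent buffers are not checkpointed
print(m.freqs_cos.shape)            # still module attributes, and moved by m.to(device)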
@@ -223,10 +239,11 @@ def forward(self, tokens, targets=None):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         h = self.dropout(h)
-        freqs_cis = self.freqs_cis[:seqlen]
+        freqs_cos = self.freqs_cos[:seqlen]
+        freqs_sin = self.freqs_sin[:seqlen]
 
         for layer in self.layers:
-            h = layer(h, freqs_cis)
+            h = layer(h, freqs_cos, freqs_sin)
         h = self.norm(h)
 
         if targets is not None:
@@ -359,8 +376,8 @@ def serialize(t):
         serialize(self.norm.weight)
         # note: no need to write final classifier weights due to weight sharing
         # freqs_cis
-        serialize(self.freqs_cis.real[:p.max_seq_len])
-        serialize(self.freqs_cis.imag[:p.max_seq_len])
+        serialize(self.freqs_cos[:p.max_seq_len])
+        serialize(self.freqs_sin[:p.max_seq_len])
 
         # write to binary file
         f.close()
