Commit df51a6e

fix: plamo from PR ggml-org#3557
1 parent: b9fdfbd

2 files changed (+256, -13 lines)


convert-plamo-hf-to-gguf.py (new file, +236 lines)
import argparse
import json
import sys
import os
import torch
import numpy as np
from pathlib import Path
import gguf
from sentencepiece import SentencePieceProcessor  # type: ignore[import]

try:
    from safetensors import safe_open
except ImportError:
    print("Please install `safetensors` python package")
    sys.exit(1)


def count_model_parts(dir_model: Path) -> int:
    # get number of model parts
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("model-00"):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")
    return num_parts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a PLaMo model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file, or model file itself (*.bin)",
    )
    parser.add_argument(
        "ftype", type=int, choices=[0, 1], default=1, nargs='?',
        help="output format - use 0 for float32, 1 for float16",
    )
    return parser.parse_args()


args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file=sys.stderr)
    sys.exit(1)


# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model " + dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

if hparams["architectures"][0] != "PlamoForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])

    sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)

# from "add PLaMo model" #3557
# https://github.com/ggerganov/llama.cpp/pull/3557/files

ARCH = gguf.MODEL_ARCH.PLAMO
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["num_hidden_layers"]

gguf_writer.add_name("PLaMo")
gguf_writer.add_context_length(4096)  # not in config.json
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["num_attention_heads"])
gguf_writer.add_head_count_kv(hparams["num_attention_heads"] // hparams["n_shared_head"])
gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
gguf_writer.add_file_type(ftype)


# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []

tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
    print(f'Error: Missing {tokenizer_model_file}', file=sys.stderr)
    sys.exit(1)

# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")

tokenizer = SentencePieceProcessor(str(tokenizer_model_file))

for i in range(tokenizer.vocab_size()):
    text: bytes
    score: float

    piece = tokenizer.id_to_piece(i)
    text = piece.encode("utf-8")
    score = tokenizer.get_score(i)

    toktype = 1  # default to normal token type
    if tokenizer.is_unknown(i):
        toktype = 2
    if tokenizer.is_control(i):
        toktype = 3

    # toktype = 4 is user-defined = tokens from added_tokens.json

    if tokenizer.is_unused(i):
        toktype = 5
    if tokenizer.is_byte(i):
        toktype = 6

    tokens.append(text)
    scores.append(score)
    toktypes.append(toktype)

gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
gguf_writer.add_sep_token_id(5)
gguf_writer.add_pad_token_id(3)

special_vocab = gguf.SpecialVocab(dir_model)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH, block_count)

# params for qkv transform
n_head = hparams["num_attention_heads"]
n_head_kv = hparams["num_key_value_heads"]

head_dim = hparams["hidden_size"] // n_head

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
    part_names = iter(("model.safetensors",))
else:
    part_names = (
        f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
    )

for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    model_part = safe_open(dir_model / part_name, framework="pt")

    for name in model_part.keys():
        if "self_attn.rotary_emb.inv_freq" in name:
            continue
        data = model_part.get_tensor(name)

        old_dtype = data.dtype

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

        gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")
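A note on the metadata section above: the only non-obvious mapping is add_head_count_kv(). PLaMo uses grouped-query attention, and n_shared_head in config.json is the number of query heads that share one key/value head, so the KV head count is the integer quotient of the two. A minimal sketch with made-up numbers (illustration only, not values from a real PLaMo config.json):

    # hypothetical PLaMo-style hyperparameters, for illustration only
    hparams = {"num_attention_heads": 32, "n_shared_head": 8}

    n_head = hparams["num_attention_heads"]
    n_head_kv = n_head // hparams["n_shared_head"]  # value passed to add_head_count_kv()

    print(n_head, n_head_kv)  # 32 4 -> groups of 8 query heads share each of the 4 kv heads

With the argparse definitions above, a typical invocation would look like "python convert-plamo-hf-to-gguf.py ./plamo-13b 1 --outfile ggml-model-f16.gguf", where ./plamo-13b is a placeholder for a local model directory containing config.json, tokenizer.model and the safetensors shards.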

llama.cpp (+20, -13 lines)
@@ -7232,7 +7232,7 @@ static struct ggml_cgraph * llm_build_plamo(
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                    K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
+                    K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
             offload_func_kq(tmp);
             ggml_build_forward_expand(gf, tmp);
         }
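This hunk, and the Kcur/Qcur hunk that follows, only change the RoPE mode argument of ggml_rope_custom from 0 to 2. In ggml, mode 0 rotates interleaved pairs (x[2i], x[2i+1]) as in the original LLaMA layout, while mode 2 selects the GPT-NeoX convention that pairs element i with element i + head_dim/2, which is the layout the PLaMo weights appear to expect. A rough NumPy sketch of the two conventions, as an illustration of the idea rather than of the actual ggml kernel:

    import numpy as np

    # illustrative sizes only
    head_dim, pos, theta_base = 8, 3, 10000.0
    x = np.arange(head_dim, dtype=np.float64)

    inv_freq = theta_base ** (-np.arange(0, head_dim, 2) / head_dim)
    angles = pos * inv_freq
    cos, sin = np.cos(angles), np.sin(angles)

    # mode 0: rotate interleaved pairs (x[2i], x[2i+1])
    pairs = x.reshape(-1, 2)
    rot0 = np.stack([pairs[:, 0] * cos - pairs[:, 1] * sin,
                     pairs[:, 0] * sin + pairs[:, 1] * cos], axis=1).reshape(-1)

    # mode 2 (NeoX): rotate half-split pairs (x[i], x[i + head_dim//2])
    lo, hi = x[:head_dim // 2], x[head_dim // 2:]
    rot2 = np.concatenate([lo * cos - hi * sin, lo * sin + hi * cos])

    # same rotation angles, different pairing of dimensions
    print(rot0)
    print(rot2)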
@@ -7274,11 +7274,11 @@ static struct ggml_cgraph * llm_build_plamo(
         offload_func_kq(tmpq);
         ggml_set_name(tmpq, "tmpq");

-        struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+        struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
         offload_func_kq(Kcur);
         ggml_set_name(Kcur, "Kcur");

-        struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+        struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
         offload_func_kq(Qcur);
         ggml_set_name(Qcur, "Qcur");

@@ -7322,8 +7322,17 @@ static struct ggml_cgraph * llm_build_plamo(
         offload_func_kq(K);
         ggml_set_name(K, "K");

+        // from this PR
+        // https://github.com/ggerganov/llama.cpp/pull/3557
+
         // K * Q
-        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+        //struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+        // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
+        struct ggml_tensor * K_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K->ne[0], K->ne[1], Q->ne[2]);
+        offload_func_kq(K_repeated);
+        ggml_set_name(K_repeated, "K_repeated");
+
+        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, ggml_repeat(ctx0, K, K_repeated), Q);
         offload_func_kq(KQ);
         ggml_set_name(KQ, "KQ");

@@ -7353,17 +7362,15 @@ static struct ggml_cgraph * llm_build_plamo(
         offload_func_v(V);
         ggml_set_name(V, "V");

-#if 1
-        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+        //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+        // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
+        struct ggml_tensor * V_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, V->ne[0], V->ne[1], Q->ne[2]);
+        offload_func_v(V_repeated);
+        ggml_set_name(V_repeated, "V_repeated");
+
+        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_repeat(ctx0, V, V_repeated), KQ_soft_max);
         offload_func_v(KQV);
         ggml_set_name(KQV, "KQV");
-#else
-        // make V contiguous in memory to speed up the matmul, however we waste time on the copy
-        // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
-        // is there a better way?
-        struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
-        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
-#endif

         // KQV_merged = KQV.permute(0, 2, 1, 3)
         struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
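The K_repeated and V_repeated hunks replace the direct ggml_mul_mat(ctx0, K, Q) and ggml_mul_mat(ctx0, V, KQ_soft_max) calls: K and V are first broadcast with ggml_repeat up to the full query-head count because, per the in-code comments, the grouped-query case (n_head_kv < n_head) produced wrong values with the existing matmul. A rough NumPy sketch of what the repeat achieves, with illustrative shapes rather than ggml's actual memory layout, and using the usual GQA convention that consecutive query heads share a KV head:

    import numpy as np

    # illustrative grouped-query attention sizes, not PLaMo's real ones
    head_dim, n_ctx, n_tokens = 8, 16, 4
    n_head, n_head_kv = 8, 2
    n_rep = n_head // n_head_kv          # query heads per shared kv head

    rng = np.random.default_rng(0)
    K = rng.standard_normal((n_head_kv, n_ctx, head_dim)).astype(np.float32)
    Q = rng.standard_normal((n_head, n_tokens, head_dim)).astype(np.float32)

    # what the patch does: materialize one K slice per query head before the batched matmul
    K_repeated = np.repeat(K, n_rep, axis=0)       # (n_head, n_ctx, head_dim)
    KQ = np.einsum("hcd,htd->htc", K_repeated, Q)  # attention scores per query head

    # what the comments say should eventually happen: no copy, just index the shared kv head
    KQ_direct = np.einsum("hcd,htd->htc", K[np.arange(n_head) // n_rep], Q)
    assert np.allclose(KQ, KQ_direct)

As the comments note, the repeat is a workaround: each query head only needs a view of its shared KV head (the KQ_direct line), so the extra copies can be dropped once the matmul handles the broadcast correctly.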
