
Commit dec5f21

jameswu2014 authored and pkrmf committed

feature : support Baichuan serial models (ggml-org#3009)

1 parent 3d2e41a · commit dec5f21

File tree

4 files changed: +781 -3 lines changed

Diff for: convert-baichuan-hf-to-gguf.py

+292
@@ -0,0 +1,292 @@
#!/usr/bin/env python3
# HF baichuan --> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any

import gguf
import numpy as np
import torch
from sentencepiece import SentencePieceProcessor  # type: ignore[import]


if TYPE_CHECKING:
    from typing import TypeAlias

NDArray: TypeAlias = 'np.ndarray[Any, Any]'


# reverse the HF permute back to the original pth layout
def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: int | None = None) -> NDArray:
    if n_kv_head is not None and n_head != n_kv_head:
        n_head //= n_kv_head

    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))


def reverse_hf_permute_part(weights: NDArray, n_part: int, n_head: int, n_head_kv: int | None = None) -> NDArray:
    r = weights.shape[0] // 3
    return reverse_hf_permute(weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv)


def reverse_hf_part(weights: NDArray, n_part: int) -> NDArray:
    r = weights.shape[0] // 3
    return weights[r * n_part : r * n_part + r, ...]
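
# Example (illustrative, not part of the conversion flow): the HF export script
# permutes each head's rotary rows from (head_dim // 2, 2) order to
# (2, head_dim // 2); reverse_hf_permute undoes exactly that reshape/swap,
# e.g. for n_head = 2, dim = 12:
#
#     w  = np.arange(12 * 12, dtype=np.float32).reshape(12, 12)
#     hf = w.reshape(2, 3, 2, 12).swapaxes(1, 2).reshape(12, 12)  # HF permute
#     assert np.array_equal(reverse_hf_permute(hf, 2), w)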


def count_model_parts(dir_model: Path) -> int:
    num_parts = 0

    for filename in os.listdir(dir_model):
        if filename.startswith("pytorch_model-"):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")

    return num_parts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a HuggingFace Baichuan model to a GGML compatible file")
    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
    parser.add_argument("ftype", type=int, choices=[0, 1], nargs="?", default=1, help="output format - use 0 for float32, 1 for float16")
    return parser.parse_args()
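
# Typical invocation (the model path is illustrative):
#   python convert-baichuan-hf-to-gguf.py models/Baichuan-7B 1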


args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file=sys.stderr)
    sys.exit(1)

# possible tensor data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model " + dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

print("gguf: model architecture: " + hparams["architectures"][0])
if hparams["architectures"][0] != "BaichuanForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])
    sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)
print(f"num_parts: {num_parts}\n")

ARCH = gguf.MODEL_ARCH.BAICHUAN
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["num_hidden_layers"]
head_count = hparams["num_attention_heads"]

if "num_key_value_heads" in hparams:
    head_count_kv = hparams["num_key_value_heads"]
else:
    head_count_kv = head_count

if "_name_or_path" in hparams:
    hf_repo = hparams["_name_or_path"]
else:
    hf_repo = ""

if "max_sequence_length" in hparams:
    ctx_length = hparams["max_sequence_length"]
elif "max_position_embeddings" in hparams:
    ctx_length = hparams["max_position_embeddings"]
elif "model_max_length" in hparams:
    ctx_length = hparams["model_max_length"]
else:
    print("gguf: can not find ctx length parameter.")
    sys.exit(1)


gguf_writer.add_name(dir_model.name)
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_tensor_data_layout("Meta AI original pth")
gguf_writer.add_context_length(ctx_length)
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
gguf_writer.add_head_count(head_count)
gguf_writer.add_head_count_kv(head_count_kv)
gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])

if "rope_scaling" in hparams and hparams["rope_scaling"] is not None and "factor" in hparams["rope_scaling"]:
    if "type" in hparams["rope_scaling"]:
        if hparams["rope_scaling"]["type"] == "linear":
            gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"])
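
# e.g. a config.json entry of the form used by HF transformers:
#   "rope_scaling": {"type": "linear", "factor": 2.0}
# would be recorded as a linear RoPE scale of 2.0 in the GGUF metadata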


# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []

tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
    print(f'Error: Missing {tokenizer_model_file}', file=sys.stderr)
    sys.exit(1)

# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")

tokenizer = SentencePieceProcessor(str(tokenizer_model_file))

for i in range(tokenizer.vocab_size()):
    text: bytes
    score: float

    piece = tokenizer.id_to_piece(i)
    text = piece.encode("utf-8")
    score = tokenizer.get_score(i)

    toktype = 1  # default to normal token type
    if tokenizer.is_unknown(i):
        toktype = 2
    if tokenizer.is_control(i):
        toktype = 3

    # toktype = 4 is user-defined = tokens from added_tokens.json

    if tokenizer.is_unused(i):
        toktype = 5
    if tokenizer.is_byte(i):
        toktype = 6

    tokens.append(text)
    scores.append(score)
    toktypes.append(toktype)
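
# the integer token types above follow gguf's TokenType enum:
#   NORMAL = 1, UNKNOWN = 2, CONTROL = 3, USER_DEFINED = 4, UNUSED = 5, BYTE = 6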

added_tokens_file = dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
    with open(added_tokens_file, "r", encoding="utf-8") as f:
        addtokens_json = json.load(f)

        print("gguf: get added tokens")

        for key in addtokens_json:
            tokens.append(key.encode("utf-8"))
            scores.append(-1000.0)
            toktypes.append(4)  # user-defined token type


gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model)
special_vocab.add_to_gguf(gguf_writer)


# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH, block_count)

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
    part_names = iter(("pytorch_model.bin",))
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )


for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")

    tmp = model_part
    for i in range(block_count):
        if f"model.layers.{i}.self_attn.W_pack.weight" in model_part:
            print(f"Unpacking and permuting layer {i}")
            tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"], 0, head_count, head_count)
            tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = reverse_hf_permute_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"], 1, head_count, head_count_kv)
            tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = reverse_hf_part(model_part[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
            del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
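
    # W_pack stacks Q, K and V along dim 0, so for hidden size d its weight has
    # shape (3*d, d); each part above is one third of the rows, e.g. (illustrative):
    #   packed = torch.arange(12 * 4).reshape(12, 4).numpy()
    #   v = reverse_hf_part(packed, 2)   # rows 8..11, shape (4, 4)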

    for name in model_part.keys():
        data = model_part[name]
        # we don't need these
        if name.endswith(".rotary_emb.inv_freq"):
            continue

        old_dtype = data.dtype

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit(1)

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(name + " -> " + new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
        gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")

Diff for: gguf-py/gguf/gguf.py

+24 -2
@@ -79,6 +79,7 @@
 class MODEL_ARCH(IntEnum):
     LLAMA  : int = auto()
     FALCON : int = auto()
+    BAICHUAN: int = auto()
     GPT2   : int = auto()
     GPTJ   : int = auto()
     GPTNEOX: int = auto()

@@ -108,6 +109,7 @@ class MODEL_TENSOR(IntEnum):
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LLAMA: "llama",
     MODEL_ARCH.FALCON: "falcon",
+    MODEL_ARCH.BAICHUAN: "baichuan",
     MODEL_ARCH.GPT2: "gpt2",
     MODEL_ARCH.GPTJ: "gptj",
     MODEL_ARCH.GPTNEOX: "gptneox",

@@ -153,6 +155,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
         MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
     },
+    MODEL_ARCH.BAICHUAN: {
+        MODEL_TENSOR.TOKEN_EMBD: "token_embd",
+        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+        MODEL_TENSOR.OUTPUT: "output",
+        MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
+        MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
+        MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q",
+        MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k",
+        MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
+        MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
+        MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+        MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+        MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
+        MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
+        MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
+    },
     MODEL_ARCH.GPT2: {
         # TODO
     },

@@ -165,6 +183,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.BAICHUAN: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }

@@ -187,15 +209,15 @@ class TensorNameMap:
     # Output
     MODEL_TENSOR.OUTPUT: (
         "embed_out",  # gptneox
-        "lm_head",    # gpt2 mpt falcon llama-hf
+        "lm_head",    # gpt2 mpt falcon llama-hf baichuan
         "output",     # llama-pth
     ),

     # Output norm
     MODEL_TENSOR.OUTPUT_NORM: (
         "gpt_neox.final_layer_norm",  # gptneox
         "transformer.ln_f",           # gpt2 falcon
-        "model.norm",                 # llama-hf
+        "model.norm",                 # llama-hf baichuan
         "norm",                       # llama-pth
     ),
0 commit comments