
Commit dc07dc4

convert : various script cleanups/fixes + merges and special token handling (#2842)
* convert: Fix permute calls and method/func definitions
* Cleanups for gguf-py
* Minor types cleanups.
* Initial implementation of handling merges and special tokens
* convert: Handle special tokens and merges in vocab only mode

  convert: Vocab only mode no longer requires loading model tensors
* gguf: Refactor tensor name mapping
* convert: Fix type hint for special_token_types in SpecialVocab
* Use common special vocab handling in various conversion scripts
* First pass at implementing suggested changes
* Second pass
* gguf: SpecialVocab: Fix issue with special token content not in a dict

  gguf: SpecialVocab: Allow skipping handling of merges
* convert-falcon-hf-to-gguf: Support --vocab-only option, bail out if no tokenizer.json
* convert-gptneox-hf-to-gguf and convert: Only handle merges for BPE tokenizer
* gguf: SpecialVocab: Actually set load_merges in object
* Uniform args parsing and vocab only mode for convert examples
* convert.py: Set gpt2 as tokenizer model when using BPE
* Squish last type warning in gguf.py - yay!
1 parent ad9ddcf commit dc07dc4

10 files changed (+738, -758 lines)

convert-falcon-hf-to-gguf.py (+75, -93)
```diff
@@ -8,6 +8,7 @@
 import json
 import numpy as np
 import torch
+import argparse
 
 from typing import Any, List
 from pathlib import Path
@@ -32,11 +33,10 @@ def bytes_to_unicode():
             bs.append(b)
             cs.append(2**8+n)
             n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
+    return dict(zip(bs, (chr(n) for n in cs)))
 
 
-def count_model_parts(dir_model: str) -> int:
+def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
         if filename.startswith("pytorch_model-"):
```
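For context on the function being tightened here: `bytes_to_unicode` builds the standard reversible GPT-2 byte-to-unicode table, and the vocab loop later in this diff inverts it. A minimal round-trip check, assuming the function as defined in this script (the sample string is illustrative, not from the commit):

```python
# Sketch, not part of the commit: every byte 0..255 maps to a printable
# unicode character, and inverting the table recovers the bytes exactly.
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

sample = "héllo".encode("utf-8")
encoded = "".join(byte_encoder[b] for b in sample)
assert bytes(byte_decoder[ch] for ch in encoded) == sample
```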
```diff
@@ -47,16 +47,21 @@ def count_model_parts(dir_model: str) -> int:
     return num_parts
 
 
-if len(sys.argv) < 3:
-    print(f"Usage: python {sys.argv[0]} dir-model ftype\n")
-    print("  ftype == 0 -> float32")
-    print("  ftype == 1 -> float16")
-    sys.exit(1)
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Convert a Falcon model to a GGML compatible file")
+    parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
+    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
+    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
+    parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default = 1)
+    return parser.parse_args()
 
+args = parse_args()
 
-# output in the same directory as the model
-dir_model = sys.argv[1]
-last_dir = os.path.basename(os.path.normpath(dir_model))
+dir_model = args.model
+ftype = args.ftype
+if not dir_model.is_dir():
+    print(f'Error: {args.model} is not a directory', file = sys.stderr)
+    sys.exit(1)
 
 # possible tensor data types
 # ftype == 0 -> float32
```
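The new `parse_args` can be exercised in isolation; a small sketch assuming the same argument definitions (the model path is a hypothetical example):

```python
# Sketch: the parser mirrors parse_args() in the diff above; the path is
# a hypothetical example, not from the commit.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="Convert a Falcon model to a GGML compatible file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, choices=[0, 1], help="output format - use 0 for float32, 1 for float16", default=1)

# equivalent to: python convert-falcon-hf-to-gguf.py --vocab-only models/falcon-7b 1
args = parser.parse_args(["--vocab-only", "models/falcon-7b", "1"])
assert args.vocab_only and args.ftype == 1 and args.outfile is None
```

Note that `ftype` remains a required positional argument; the `default = 1` would only take effect if it were made optional (e.g. with `nargs='?'`).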
```diff
@@ -65,25 +70,21 @@ def count_model_parts(dir_model: str) -> int:
 # map from ftype to string
 ftype_str = ["f32", "f16"]
 
-ftype = 1
-if len(sys.argv) > 2:
-    ftype = int(sys.argv[2])
-    if ftype < 0 or ftype > 1:
-        print("Invalid ftype: " + str(ftype))
-
-        sys.exit(1)
-
-fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf"
+if args.outfile is not None:
+    fname_out = args.outfile
+else:
+    # output in the same directory as the model by default
+    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
 
-print("gguf: loading model "+last_dir)
+print("gguf: loading model "+dir_model.name)
 
-with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
+with open(dir_model / "config.json", "r", encoding="utf-8") as f:
     hparams = json.load(f)
 
 if hparams["architectures"][0] != "RWForCausalLM":
     print("Model architecture not supported: " + hparams["architectures"][0])
 
-    sys.exit()
+    sys.exit(1)
 
 # get number of model parts
 num_parts = count_model_parts(dir_model)
@@ -113,77 +114,58 @@ def count_model_parts(dir_model: str) -> int:
 
 print("gguf: get tokenizer metadata")
 
-tokens: List[str] = []
+tokens: List[bytearray] = []
 scores: List[float] = []
 toktypes: List[int] = []
-merges: List[str] = []
-
-
-if Path(dir_model + "/tokenizer.json").is_file():
-    # gpt2 tokenizer
-    gguf_writer.add_tokenizer_model("gpt2")
 
-    print("gguf: get gpt2 tokenizer merges")
-
-    with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f:
-        tokenizer_json = json.load(f)
-    merges = tokenizer_json["model"]["merges"]
-
-    gguf_writer.add_token_merges(merges)
-
-    print("gguf: get gpt2 tokenizer vocab")
-
-    vocab_size = len(tokenizer_json["model"]["vocab"])
-
-    # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
-    tokenizer = AutoTokenizer.from_pretrained(dir_model)
-
-    reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
+tokenizer_json_file = dir_model / 'tokenizer.json'
+if not tokenizer_json_file.is_file():
+    print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
+    sys.exit(1)
 
-    for i in range(vocab_size):
-        if i in reverse_vocab:
-            try:
-                text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
-            except KeyError:
-                text = bytearray()
-                for c in reverse_vocab[i]:
-                    if ord(c) < 256: # single byte character
-                        text.append(byte_decoder[ord(c)])
-                    else: # multibyte special token character
-                        text.extend(c.encode('utf-8'))
-        else:
-            print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
-            pad_token = f"[PAD{i}]".encode("utf8")
-            text = bytearray(pad_token)
+# gpt2 tokenizer
+gguf_writer.add_tokenizer_model("gpt2")
 
-        tokens.append(text)
-        scores.append(0.0) # dymmy
-        toktypes.append(gguf.TokenType.NORMAL) # dummy
+with open(tokenizer_json_file, "r", encoding="utf-8") as f:
+    tokenizer_json = json.load(f)
 
-    gguf_writer.add_token_list(tokens)
-    gguf_writer.add_token_scores(scores)
-    gguf_writer.add_token_types(toktypes)
+print("gguf: get gpt2 tokenizer vocab")
 
-    print("gguf: get special token ids")
-    # Look for special tokens in config.json
+vocab_size = len(tokenizer_json["model"]["vocab"])
 
-    if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
-        gguf_writer.add_bos_token_id(hparams["bos_token_id"])
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
 
-    if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
-        gguf_writer.add_eos_token_id(hparams["eos_token_id"])
+reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v: k for k, v in byte_encoder.items()}
 
-    if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
-        gguf_writer.add_unk_token_id(hparams["unk_token_id"])
+for i in range(vocab_size):
+    if i in reverse_vocab:
+        try:
+            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
+        except KeyError:
+            text = bytearray()
+            for c in reverse_vocab[i]:
+                if ord(c) < 256: # single byte character
+                    text.append(byte_decoder[ord(c)])
+                else: # multibyte special token character
+                    text.extend(c.encode('utf-8'))
+    else:
+        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
+        pad_token = f"[PAD{i}]".encode("utf8")
+        text = bytearray(pad_token)
 
-    if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
-        gguf_writer.add_sep_token_id(hparams["sep_token_id"])
+    tokens.append(text)
+    scores.append(0.0) # dymmy
+    toktypes.append(gguf.TokenType.NORMAL) # dummy
 
-    if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
-        gguf_writer.add_pad_token_id(hparams["pad_token_id"])
+gguf_writer.add_token_list(tokens)
+gguf_writer.add_token_scores(scores)
+gguf_writer.add_token_types(toktypes)
 
+special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
+special_vocab.add_to_gguf(gguf_writer)
 
 # TENSORS
 
```
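This hunk is the heart of the special-token change: the five per-type `config.json` lookups removed above collapse into one `gguf.SpecialVocab` call. A simplified sketch of what that consolidates, derived from the removed code (the helper's real implementation lives in gguf-py and also handles merges):

```python
# Simplified sketch of the logic SpecialVocab replaces: read config.json
# once and register whichever special token ids are present.
import json
from pathlib import Path

def add_special_token_ids(gguf_writer, dir_model: Path) -> None:
    with open(dir_model / "config.json", encoding="utf-8") as f:
        hparams = json.load(f)
    for typ in ("bos", "eos", "unk", "sep", "pad"):
        token_id = hparams.get(f"{typ}_token_id")
        if token_id is not None:
            # dispatches to add_bos_token_id(), add_eos_token_id(), ...
            getattr(gguf_writer, f"add_{typ}_token_id")(token_id)
```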
```diff
@@ -199,15 +181,17 @@ def count_model_parts(dir_model: str) -> int:
 print("gguf: get tensor metadata")
 
 if num_parts == 0:
-    part_names = ("pytorch_model.bin",)
+    part_names = iter(("pytorch_model.bin",))
 else:
     part_names = (
         f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
     )
 
 for part_name in part_names:
+    if args.vocab_only:
+        break
     print("gguf: loading model part '" + part_name + "'")
-    model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu")
+    model_part = torch.load(dir_model / part_name, map_location="cpu")
 
     for name in model_part.keys():
         data = model_part[name]
```
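Wrapping the single-file case in `iter()` looks small but fits the commit's typing cleanups: both branches then assign the same iterator shape rather than a tuple in one and a generator in the other. A minimal illustration, assuming an example `num_parts` value:

```python
# Sketch: with iter(), both branches satisfy one Iterator[str] annotation;
# a bare tuple in the first branch would give the variable a union type.
from typing import Iterator

num_parts = 0  # illustrative value
part_names: Iterator[str]
if num_parts == 0:
    part_names = iter(("pytorch_model.bin",))
else:
    part_names = (
        f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
    )
```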
```diff
@@ -238,11 +222,8 @@ def count_model_parts(dir_model: str) -> int:
         data = data.squeeze().numpy()
 
         # map tensor names
-        if name.endswith(".weight") and name[:-7] in tensor_map:
-            name = tensor_map[name[:-7]] + ".weight"
-        elif name.endswith(".bias") and name[:-5] in tensor_map:
-            name = tensor_map[name[:-5]] + ".bias"
-        else:
+        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
+        if new_name is None:
             print("Can not map tensor '" + name + "'")
             sys.exit()
 
```
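The suffix-matching branches removed here become a single `tensor_map.get_name(...)` call. For this call site its behaviour is roughly the following sketch (a simplification based on the removed inline logic, not the actual gguf-py implementation):

```python
# Sketch: try each suffix, map the stem through the name table, and
# re-attach the suffix; None means the tensor cannot be mapped.
from typing import Dict, Iterable, Optional

def get_name_sketch(mapping: Dict[str, str], name: str,
                    try_suffixes: Iterable[str] = ()) -> Optional[str]:
    for suffix in try_suffixes:
        if name.endswith(suffix) and name[: -len(suffix)] in mapping:
            return mapping[name[: -len(suffix)]] + suffix
    return None  # the caller prints "Can not map tensor" and exits
```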
```diff
@@ -261,19 +242,20 @@ def count_model_parts(dir_model: str) -> int:
         if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
             data = data.astype(np.float16)
 
-        print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
+        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
 
-        gguf_writer.add_tensor(name, data)
+        gguf_writer.add_tensor(new_name, data)
 
 
 print("gguf: write header")
 gguf_writer.write_header_to_file()
 print("gguf: write metadata")
 gguf_writer.write_kv_data_to_file()
-print("gguf: write tensors")
-gguf_writer.write_tensors_to_file()
+if not args.vocab_only:
+    print("gguf: write tensors")
+    gguf_writer.write_tensors_to_file()
 
 gguf_writer.close()
 
-print("gguf: model successfully exported to '" + fname_out + "'")
+print(f"gguf: model successfully exported to '{fname_out}'")
 print("")
```