Skip to content

Commit 02d2875

Browse files
llm : add bloom models (#3553)
* feat: Support bloom models * fix(bloom): fix model size --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 0aa6595 commit 02d2875

File tree

3 files changed

+678
-55
lines changed

3 files changed

+678
-55
lines changed

Diff for: convert-bloom-hf-to-gguf.py

+238
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
#!/usr/bin/env python3
2+
# HF bloom --> gguf conversion
3+
4+
from __future__ import annotations
5+
6+
import argparse
7+
import json
8+
import os
9+
import re
10+
import struct
11+
import sys
12+
from pathlib import Path
13+
from typing import Any
14+
15+
import numpy as np
16+
import torch
17+
from transformers import AutoTokenizer # type: ignore[import]
18+
19+
if 'NO_LOCAL_GGUF' not in os.environ:
20+
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
21+
import gguf
22+
23+
24+
def count_model_parts(dir_model: Path) -> int:
25+
num_parts = 0
26+
for filename in os.listdir(dir_model):
27+
if filename.startswith("pytorch_model-"):
28+
num_parts += 1
29+
30+
if num_parts > 0:
31+
print("gguf: found " + str(num_parts) + " model parts")
32+
return num_parts
33+
34+
35+
# Supported Models:
36+
# https://huggingface.co/bigscience/bloom-1b7
37+
# https://huggingface.co/bigscience/bloom-3b
38+
# https://huggingface.co/bigscience/bloom-7b1
39+
# https://huggingface.co/Langboat/bloom-1b4-zh
40+
def parse_args() -> argparse.Namespace:
41+
parser = argparse.ArgumentParser(description="Convert a Bloom model to a GGML compatible file")
42+
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
43+
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
44+
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
45+
parser.add_argument("ftype", type=int, help="output format - use 0 for float32, 1 for float16", choices=[0, 1], default = 1)
46+
return parser.parse_args()
47+
48+
args = parse_args()
49+
50+
dir_model = args.model
51+
ftype = args.ftype
52+
if not dir_model.is_dir():
53+
print(f'Error: {args.model} is not a directory', file = sys.stderr)
54+
sys.exit(1)
55+
56+
# possible tensor data types
57+
# ftype == 0 -> float32
58+
# ftype == 1 -> float16
59+
60+
# map from ftype to string
61+
ftype_str = ["f32", "f16"]
62+
63+
if args.outfile is not None:
64+
fname_out = args.outfile
65+
else:
66+
# output in the same directory as the model by default
67+
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
68+
69+
print("gguf: loading model "+dir_model.name)
70+
71+
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
72+
hparams = json.load(f)
73+
74+
if hparams["architectures"][0] != "BloomForCausalLM":
75+
print("Model architecture not supported: " + hparams["architectures"][0])
76+
sys.exit(1)
77+
78+
# get number of model parts
79+
num_parts = count_model_parts(dir_model)
80+
81+
ARCH=gguf.MODEL_ARCH.BLOOM
82+
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
83+
84+
print("gguf: get model metadata")
85+
86+
block_count = hparams["n_layer"]
87+
88+
gguf_writer.add_name("Bloom")
89+
n_embed = hparams.get("hidden_size", hparams.get("n_embed"))
90+
n_head = hparams.get("n_head", hparams.get("num_attention_heads"))
91+
gguf_writer.add_context_length(hparams.get("seq_length", n_embed))
92+
gguf_writer.add_embedding_length(n_embed)
93+
gguf_writer.add_feed_forward_length(4 * n_embed)
94+
gguf_writer.add_block_count(block_count)
95+
gguf_writer.add_head_count(n_head)
96+
gguf_writer.add_head_count_kv(n_head)
97+
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
98+
gguf_writer.add_file_type(ftype)
99+
100+
# TOKENIZATION
101+
102+
print("gguf: get tokenizer metadata")
103+
104+
tokens: list[bytearray] = []
105+
scores: list[float] = []
106+
toktypes: list[int] = []
107+
108+
# gpt2 tokenizer
109+
gguf_writer.add_tokenizer_model("gpt2")
110+
111+
print("gguf: get gpt2 tokenizer vocab")
112+
113+
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
114+
tokenizer = AutoTokenizer.from_pretrained(dir_model)
115+
116+
# The number of tokens in tokenizer.json can differ from the expected vocab size.
117+
# This causes downstream issues with mismatched tensor sizes when running the inference
118+
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
119+
assert max(tokenizer.vocab.values()) < vocab_size
120+
121+
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
122+
123+
for i in range(vocab_size):
124+
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
125+
scores.append(0.0) # dummy
126+
toktypes.append(gguf.TokenType.NORMAL)
127+
128+
gguf_writer.add_token_list(tokens)
129+
gguf_writer.add_token_scores(scores)
130+
gguf_writer.add_token_types(toktypes)
131+
132+
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
133+
special_vocab.add_to_gguf(gguf_writer)
134+
135+
# TENSORS
136+
137+
tensor_map = gguf.get_tensor_name_map(ARCH, block_count)
138+
139+
# params for qkv transform
140+
n_head_kv = hparams.get("n_head_kv", n_head)
141+
head_dim = n_embed // n_head
142+
143+
# tensor info
144+
print("gguf: get tensor metadata")
145+
146+
if num_parts == 0:
147+
part_names = iter(("pytorch_model.bin",))
148+
else:
149+
part_names = (
150+
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
151+
)
152+
153+
for part_name in part_names:
154+
if args.vocab_only:
155+
break
156+
print("gguf: loading model part '" + part_name + "'")
157+
model_part = torch.load(dir_model / part_name, map_location="cpu")
158+
159+
has_lm_head = True
160+
if "lm_head.weight" not in model_part.keys() and "output.weight" not in model_part.keys():
161+
has_lm_head = False
162+
163+
for original_name in model_part.keys():
164+
data = model_part[original_name]
165+
name = re.sub(r'transformer\.', '', original_name)
166+
167+
old_dtype = data.dtype
168+
169+
# convert any unsupported data types to float32
170+
if data.dtype != torch.float16 and data.dtype != torch.float32:
171+
data = data.to(torch.float32)
172+
173+
data = data.squeeze().numpy()
174+
175+
if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name):
176+
# Map bloom-style qkv_linear to gpt-style qkv_linear
177+
# bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa
178+
# gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa
179+
qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed))
180+
data = np.concatenate(
181+
(qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
182+
qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
183+
qkv_weights[:, 2, :, :].reshape((-1, n_embed))),
184+
axis=0
185+
)
186+
print("re-format attention.linear_qkv.weight")
187+
elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name):
188+
qkv_bias = data.reshape((n_head, 3, n_embed // n_head))
189+
data = np.concatenate(
190+
(qkv_bias[:, 0, :].reshape((n_embed,)),
191+
qkv_bias[:, 1, :].reshape((n_embed,)),
192+
qkv_bias[:, 2, :].reshape((n_embed,))),
193+
axis=0
194+
)
195+
print("re-format attention.linear_qkv.bias")
196+
197+
# map tensor names
198+
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
199+
if new_name is None:
200+
print("Can not map tensor '" + name + "'")
201+
sys.exit()
202+
203+
n_dims = len(data.shape)
204+
data_dtype = data.dtype
205+
206+
# if f32 desired, convert any float16 to float32
207+
if ftype == 0 and data_dtype == np.float16:
208+
data = data.astype(np.float32)
209+
210+
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
211+
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
212+
data = data.astype(np.float32)
213+
214+
# if f16 desired, convert any float32 2-dim weight tensors to float16
215+
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
216+
data = data.astype(np.float16)
217+
218+
print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))
219+
220+
gguf_writer.add_tensor(new_name, data)
221+
222+
if not has_lm_head and name == "word_embeddings.weight":
223+
gguf_writer.add_tensor("output.weight", data)
224+
print(name, "=>", "output.weight" + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype)) # noqa
225+
226+
227+
print("gguf: write header")
228+
gguf_writer.write_header_to_file()
229+
print("gguf: write metadata")
230+
gguf_writer.write_kv_data_to_file()
231+
if not args.vocab_only:
232+
print("gguf: write tensors")
233+
gguf_writer.write_tensors_to_file()
234+
235+
gguf_writer.close()
236+
237+
print(f"gguf: model successfully exported to '{fname_out}'")
238+
print("")

0 commit comments

Comments
 (0)