
Commit 1030368

add stablelm from PR StableLM support ggml-org#3586
1 parent df51a6e commit 1030368

File tree: 3 files changed, +720 -2 lines

convert-stablelm-hf-to-gguf.py (new file, +199 lines)
#!/usr/bin/env python3
# HF stablelm --> gguf conversion

from __future__ import annotations

import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
from transformers import AutoTokenizer  # type: ignore[import]
try:
    from safetensors import safe_open
except ImportError:
    print("Please install `safetensors` python package")
    sys.exit(1)

if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file, or model file itself (*.bin)",
    )
    parser.add_argument(
        "ftype", type=int, choices=[0, 1], default=1, nargs='?',
        help="output format - use 0 for float32, 1 for float16",
    )
    return parser.parse_args()

args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file = sys.stderr)
    sys.exit(1)

# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model "+dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

if hparams["architectures"][0] != "StableLMEpochForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])

    sys.exit()


ARCH=gguf.MODEL_ARCH.STABLELM
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["num_hidden_layers"]

gguf_writer.add_name(dir_model.name)
gguf_writer.add_context_length(hparams["max_position_embeddings"])
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"]*(hparams["hidden_size"] // hparams["num_attention_heads"])))
gguf_writer.add_head_count(hparams["num_attention_heads"])
gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
gguf_writer.add_layer_norm_eps(1e-5)

# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")

print("gguf: get gpt2 tokenizer vocab")

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)

# The number of tokens in tokenizer.json can differ from the expected vocab size.
# This causes downstream issues with mismatched tensor sizes when running inference.
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}

for i in range(vocab_size):
    tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
    scores.append(0.0)  # dummy
    toktypes.append(gguf.TokenType.NORMAL)

gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# tensor info
print("gguf: get tensor metadata")

part_names = iter(("model.safetensors",))

for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
    with ctx as model_part:
        for name in model_part.keys():
            data = model_part.get_tensor(name)

            # we don't need these
            if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"):
                continue

            old_dtype = data.dtype

            # convert any unsupported data types to float32
            if data.dtype != torch.float16 and data.dtype != torch.float32:
                data = data.to(torch.float32)

            data = data.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
            if new_name is None:
                print("Can not map tensor '" + name + "'")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

            gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")

gguf-py/gguf/gguf.py (+17 -2 lines)
@@ -91,7 +91,7 @@ class MODEL_ARCH(IntEnum):
     BERT     : int = auto()
     BLOOM    : int = auto()
     PLAMO    : int = auto()
-
+    STABLELM : int = auto()
 
 class MODEL_TENSOR(IntEnum):
     TOKEN_EMBD : int = auto()
@@ -131,6 +131,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BERT: "bert",
     MODEL_ARCH.BLOOM: "bloom",
     MODEL_ARCH.PLAMO: "plamo",
+    MODEL_ARCH.STABLELM: "stablelm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -330,7 +331,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
-    # TODO
+    MODEL_ARCH.STABLELM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ], # TODO
 }
 
 # tensors that will not be serialized
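The gguf.py additions register the new architecture name and its tensor layout so that the converter's tensor_map lookups resolve. A minimal sketch of how these entries are consumed, assuming llama-style Hugging Face tensor names and an arbitrary block count of 32:

    import gguf

    ARCH = gguf.MODEL_ARCH.STABLELM
    # Build the HF-name -> GGUF-name mapping for a 32-block model (block count is illustrative).
    tensor_map = gguf.get_tensor_name_map(ARCH, 32)

    # Assumed HF naming for a StableLM-Epoch attention projection weight.
    hf_name = "model.layers.0.self_attn.q_proj.weight"
    new_name = tensor_map.get_name(hf_name, try_suffixes=(".weight", ".bias"))
    print(new_name)  # expected to print something like "blk.0.attn_q.weight"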
