Commit de3c909

support glm-4-9b-chat

Signed-off-by: XingXing Qiao <[email protected]>

1 parent 35c6887 commit de3c909

File tree

6 files changed: +151 -2 lines changed

convert-hf-to-gguf.py (+93 -1)

@@ -483,6 +483,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
             # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
             res = "jina-v2-code"
+        if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
+            # ref: https://huggingface.co/THUDM/glm-4-9b-chat
+            res = "chatglm-bpe"
 
         if res is None:
             logger.warning("\n")
@@ -2729,7 +2732,7 @@ def write_tensors(self):
 class ChatGLMModel(Model):
     model_arch = gguf.MODEL_ARCH.CHATGLM
 
-    def set_vocab(self):
+    def set_vocab_chatglm3(self):
         dir_model = self.dir_model
         hparams = self.hparams
         tokens: list[bytearray] = []
@@ -2789,6 +2792,95 @@ def set_vocab(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    @staticmethod
+    def token_bytes_to_string(b):
+        from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+        byte_encoder = bytes_to_unicode()
+        return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+    @staticmethod
+    def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
+        parts = [bytes([b]) for b in token]
+        while True:
+            min_idx = None
+            min_rank = None
+            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+                rank = mergeable_ranks.get(pair[0] + pair[1])
+                if rank is not None and (min_rank is None or rank < min_rank):
+                    min_idx = i
+                    min_rank = rank
+            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+                break
+            assert min_idx is not None
+            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+        return parts
+
+    def set_vocab(self):
+        if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
+            self.set_vocab_chatglm3()
+            return
+
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["padded_vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) >= 2 and len(merged) <= 7
+            merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.get_added_vocab()
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                if tokenizer.added_tokens_decoder[i].special:
+                    toktypes.append(gguf.TokenType.CONTROL)
+                else:
+                    toktypes.append(gguf.TokenType.USER_DEFINED)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.chat_template = "ChatGLM4"
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|observation|>"])
+        special_vocab._set_special_token("eot", 151336)
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_name(self.dir_model.name)
         n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
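
Note (not part of the patch): the bpe() helper above re-runs byte-pair merging over a token's bytes, but stops before applying any merge whose rank reaches the token's own rank, so the final two parts are exactly the pair that merges into the token; joined with a space they become one entry of the GGUF merges list. A minimal standalone sketch of the same logic, using a made-up rank table in place of tokenizer.mergeable_ranks:

from __future__ import annotations

def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
    # Mirrors ChatGLMModel.bpe() from the patch: greedily merge the lowest-ranked
    # adjacent pair, stopping before any merge ranked >= max_rank (i.e. before
    # recreating `token` itself).
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts

# Hypothetical rank table for illustration only; the converter reads the real
# one from the GLM-4 tokenizer's mergeable_ranks.
toy_ranks = {b"h": 0, b"e": 1, b"l": 2, b"o": 3, b"he": 4, b"ll": 5, b"hell": 6, b"hello": 7}

# Capping at the rank of b"hello" leaves the two parts that merge into it.
assert bpe(toy_ranks, b"hello", max_rank=toy_ranks[b"hello"]) == [b"hell", b"o"]
print(b" ".join(bpe(toy_ranks, b"hello", max_rank=toy_ranks[b"hello"])))  # b'hell o'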

examples/server/public/index-new.html (+2 -0)

@@ -717,6 +717,8 @@
           <option value="vicuna">Tess</option>
           <option value="yi34b">Yi-6/9/34B-Chat</option>
           <option value="zephyr">Zephyr</option>
+          <option value="chatglm3">ChatGLM3-6B</option>
+          <option value="chatglm4">ChatGLM4-9B</option>
           <option value=""></option>
         </optgroup>
       </select>

examples/server/public/prompt-formats.js (+37 -1)

@@ -327,5 +327,41 @@ export const promptFormats = {
     userMsgSuffix: "",
 
     stops: ""
+  },
+
+  // ----------------------------
+
+  "chatglm3": {
+    template: `[gMASK]sop<|system|>\n {{prompt}}{{history}}<|{{char}}|>`,
+
+    historyTemplate: `<|{{name}}|>\n {{message}}`,
+
+    char: "assistant",
+    charMsgPrefix: "",
+    charMsgSuffix: "",
+
+    user: "user",
+    userMsgPrefix: "",
+    userMsgSuffix: "",
+
+    stops: ""
+  },
+
+  // ----------------------------
+
+  "chatglm4": {
+    template: `[gMASK]<sop><|system|>\n{{prompt}}{{history}}<|{{char}}|>`,
+
+    historyTemplate: `<|{{name}}|>\n{{message}}`,
+
+    char: "assistant",
+    charMsgPrefix: "",
+    charMsgSuffix: "",
+
+    user: "user",
+    userMsgPrefix: "",
+    userMsgSuffix: "",
+
+    stops: ""
   }
-};
+};

llama.cpp (+14 -0)

@@ -4730,6 +4730,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 28: model.type = e_model::MODEL_7B; break;
+                    case 40: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
@@ -4922,6 +4923,9 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "poro-chat") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -13369,6 +13373,7 @@ struct llm_tokenizer_bpe {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_DBRX:
             case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+            case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
                 regex_exprs = {
                     // same as llama3
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
@@ -18914,6 +18919,15 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "<|assistant|>";
         }
+    } else if (tmpl.find("ChatGLM4") != std::string::npos) {
+        ss << "[gMASK]" << "<sop>";
+        for (auto message : chat) {
+            std::string role(message->role);
+            ss << "<|" << role << "|>" << "\n" << message->content;
+        }
+        if (add_ass) {
+            ss << "<|assistant|>";
+        }
     } else {
         // template not supported
         return -1;
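
Note (not part of the patch): a minimal Python sketch of what the new ChatGLM4 branch of llama_chat_apply_template_internal writes to the stream, using hypothetical sample messages; the rendered string follows the same shape the ChatGLM4 case in tests/test-chat-template.cpp below expects.

def chatglm4_prompt(chat: list, add_ass: bool = True) -> str:
    # Mirror of the C++ branch above: "[gMASK]<sop>" prefix, then
    # "<|role|>\n<content>" per message, then an optional generation prompt.
    out = "[gMASK]" + "<sop>"
    for message in chat:
        out += "<|" + message["role"] + "|>" + "\n" + message["content"]
    if add_ass:
        out += "<|assistant|>"
    return out

# Hypothetical conversation for illustration only.
chat = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there"},
]
print(chatglm4_prompt(chat))
# [gMASK]<sop><|system|>
# You are a helpful assistant<|user|>
# Hello<|assistant|>
# Hi there<|assistant|>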

llama.h (+1 -0)

@@ -87,6 +87,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
         LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
         LLAMA_VOCAB_PRE_TYPE_PORO = 15,
+        LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 16,
     };
 
     // note: these values should be synchronized with ggml_rope

tests/test-chat-template.cpp (+4 -0)

@@ -59,6 +59,8 @@ int main(void) {
        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
        // ChatGLM3
        "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+       // ChatGLM4
+       "ChatGLM4",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -97,6 +99,8 @@ int main(void) {
        "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
        // ChatGLM3
        "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+       // ChatGLM4
+       "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
