Skip to content

Commit f957027

Browse files
nopperlcompilade
authored andcommitted
llama : add support for Chameleon (ggml-org#8543)
* convert chameleon hf to gguf * add chameleon tokenizer tests * fix lint * implement chameleon graph * add swin norm param * return qk norm weights and biases to original format * implement swin norm * suppress image token output * rem tabs * add comment to conversion * fix ci * check for k norm separately * adapt to new lora implementation * fix layer input for swin norm * move swin_norm in gguf writer * add comment regarding special token regex in chameleon pre-tokenizer * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * fix punctuation regex in chameleon pre-tokenizer (@compilade) Co-authored-by: compilade <[email protected]> * fix lint * trigger ci --------- Co-authored-by: compilade <[email protected]>
1 parent 374357d commit f957027

10 files changed

+505
-2
lines changed

convert_hf_to_gguf.py

+44
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
640640
if chkhsh == "fcace8b9cac38ce847670c970cd5892031a753a1ef381abd1d9af00f713da085":
641641
# ref: https://huggingface.co/microsoft/phi-2
642642
res = "phi-2"
643+
if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450":
644+
# ref: https://huggingface.co/facebook/chameleon-7b
645+
res = "chameleon"
643646

644647
if res is None:
645648
logger.warning("\n")
@@ -4138,6 +4141,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
41384141
return super().modify_tensors(data_torch, name, bid)
41394142

41404143

4144+
@Model.register("ChameleonForCausalLM")
4145+
class ChameleonModel(Model):
4146+
model_arch = gguf.MODEL_ARCH.CHAMELEON
4147+
4148+
def set_gguf_parameters(self):
4149+
super().set_gguf_parameters()
4150+
self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False))
4151+
4152+
def set_vocab(self):
4153+
self._set_vocab_gpt2()
4154+
4155+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
4156+
# ignore image tokenizer for now
4157+
# TODO: remove this once image support is implemented for Chameleon
4158+
if name.startswith("model.vqmodel"):
4159+
return []
4160+
4161+
n_head = self.hparams["num_attention_heads"]
4162+
n_kv_head = self.hparams.get("num_key_value_heads")
4163+
hidden_dim = self.hparams.get("hidden_size")
4164+
4165+
if name.endswith(("q_proj.weight", "q_proj.bias")):
4166+
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
4167+
if name.endswith(("k_proj.weight", "k_proj.bias")):
4168+
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
4169+
if name.endswith(("q_norm.weight", "q_norm.bias")):
4170+
data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim)
4171+
if name.endswith(("k_norm.weight", "k_norm.bias")):
4172+
data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim)
4173+
4174+
return [(self.map_tensor_name(name), data_torch)]
4175+
4176+
# see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203
4177+
@staticmethod
4178+
def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
4179+
head_dim = hidden_dim // n_heads
4180+
data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1)
4181+
data_torch = data_torch.repeat_interleave(n_heads, 0)
4182+
return data_torch
4183+
4184+
41414185
###### CONVERSION LOGIC ######
41424186

41434187

convert_hf_to_gguf_update.py

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ class TOKENIZER_TYPE(IntEnum):
9999
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
100100
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
101101
{"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", },
102+
{"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
102103
]
103104

104105

gguf-py/gguf/constants.py

+19
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class LLM:
9494
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
9595
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
9696
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
97+
SWIN_NORM = "{arch}.swin_norm"
9798
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
9899
TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
99100
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
@@ -236,6 +237,7 @@ class MODEL_ARCH(IntEnum):
236237
EXAONE = auto()
237238
GRANITE = auto()
238239
GRANITE_MOE = auto()
240+
CHAMELEON = auto()
239241

240242

241243
class MODEL_TENSOR(IntEnum):
@@ -394,6 +396,7 @@ class MODEL_TENSOR(IntEnum):
394396
MODEL_ARCH.EXAONE: "exaone",
395397
MODEL_ARCH.GRANITE: "granite",
396398
MODEL_ARCH.GRANITE_MOE: "granitemoe",
399+
MODEL_ARCH.CHAMELEON: "chameleon",
397400
}
398401

399402
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1260,6 +1263,22 @@ class MODEL_TENSOR(IntEnum):
12601263
MODEL_TENSOR.FFN_DOWN_EXP,
12611264
MODEL_TENSOR.FFN_UP_EXP,
12621265
],
1266+
MODEL_ARCH.CHAMELEON: [
1267+
MODEL_TENSOR.TOKEN_EMBD,
1268+
MODEL_TENSOR.OUTPUT_NORM,
1269+
MODEL_TENSOR.OUTPUT,
1270+
MODEL_TENSOR.ATTN_NORM,
1271+
MODEL_TENSOR.ATTN_Q,
1272+
MODEL_TENSOR.ATTN_Q_NORM,
1273+
MODEL_TENSOR.ATTN_K,
1274+
MODEL_TENSOR.ATTN_K_NORM,
1275+
MODEL_TENSOR.ATTN_V,
1276+
MODEL_TENSOR.ATTN_OUT,
1277+
MODEL_TENSOR.FFN_NORM,
1278+
MODEL_TENSOR.FFN_GATE,
1279+
MODEL_TENSOR.FFN_DOWN,
1280+
MODEL_TENSOR.FFN_UP,
1281+
],
12631282
# TODO
12641283
}
12651284

gguf-py/gguf/gguf_writer.py

+3
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,9 @@ def add_expert_shared_count(self, count: int) -> None:
670670
def add_expert_weights_scale(self, value: float) -> None:
671671
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
672672

673+
def add_swin_norm(self, value: bool) -> None:
674+
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
675+
673676
def add_rescale_every_n_layers(self, count: int) -> None:
674677
self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
675678

gguf-py/gguf/tensor_mapping.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ class TensorNameMap:
380380
MODEL_TENSOR.ATTN_Q_NORM: (
381381
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
382382
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
383-
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe
383+
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon
384384
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
385385
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
386386
"transformer.layers.{bid}.attn.q_norm", # openelm
@@ -389,7 +389,7 @@ class TensorNameMap:
389389
MODEL_TENSOR.ATTN_K_NORM: (
390390
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
391391
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
392-
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe
392+
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon
393393
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
394394
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
395395
"transformer.layers.{bid}.attn.k_norm", # openelm

include/llama.h

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ extern "C" {
102102
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
103103
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
104104
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
105+
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
105106
};
106107

107108
enum llama_rope_type {

models/ggml-vocab-chameleon.gguf.inp

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
ied 4 ½ months
2+
__ggml_vocab_test__
3+
Führer
4+
__ggml_vocab_test__
5+
6+
__ggml_vocab_test__
7+
8+
__ggml_vocab_test__
9+
10+
__ggml_vocab_test__
11+
12+
__ggml_vocab_test__
13+
14+
__ggml_vocab_test__
15+
16+
17+
__ggml_vocab_test__
18+
19+
20+
21+
__ggml_vocab_test__
22+
23+
24+
25+
26+
__ggml_vocab_test__
27+
28+
29+
__ggml_vocab_test__
30+
Hello world
31+
__ggml_vocab_test__
32+
Hello world
33+
__ggml_vocab_test__
34+
Hello World
35+
__ggml_vocab_test__
36+
Hello World
37+
__ggml_vocab_test__
38+
Hello World!
39+
__ggml_vocab_test__
40+
Hello, world!
41+
__ggml_vocab_test__
42+
Hello, world!
43+
__ggml_vocab_test__
44+
this is 🦙.cpp
45+
__ggml_vocab_test__
46+
w048 7tuijk dsdfhu
47+
__ggml_vocab_test__
48+
нещо на Български
49+
__ggml_vocab_test__
50+
កាន់តែពិសេសអាចខលចេញ
51+
__ggml_vocab_test__
52+
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
53+
__ggml_vocab_test__
54+
Hello
55+
__ggml_vocab_test__
56+
Hello
57+
__ggml_vocab_test__
58+
Hello
59+
__ggml_vocab_test__
60+
Hello
61+
__ggml_vocab_test__
62+
Hello
63+
__ggml_vocab_test__
64+
Hello
65+
Hello
66+
__ggml_vocab_test__
67+
(
68+
__ggml_vocab_test__
69+
70+
=
71+
__ggml_vocab_test__
72+
' era
73+
__ggml_vocab_test__
74+
Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
75+
__ggml_vocab_test__
76+
!!!!!!
77+
__ggml_vocab_test__
78+
3
79+
__ggml_vocab_test__
80+
33
81+
__ggml_vocab_test__
82+
333
83+
__ggml_vocab_test__
84+
3333
85+
__ggml_vocab_test__
86+
33333
87+
__ggml_vocab_test__
88+
333333
89+
__ggml_vocab_test__
90+
3333333
91+
__ggml_vocab_test__
92+
33333333
93+
__ggml_vocab_test__
94+
333333333
95+
__ggml_vocab_test__
96+
Cửa Việt
97+
__ggml_vocab_test__
98+
discards
99+
__ggml_vocab_test__
100+
101+
102+
103+
104+
105+
106+
107+
108+
109+
110+
111+
🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
112+
__ggml_vocab_test__

models/ggml-vocab-chameleon.gguf.out

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
17245 16604 16403 16604 33583 18355
2+
16421 51153
3+
4+
16604
5+
16650
6+
16650 16604
7+
16581
8+
16582
9+
16582 16582
10+
16582 16582 16582
11+
16581 16582
12+
31596 17394
13+
34926 17394
14+
31596 18671
15+
34926 18671
16+
34926 18671 16384
17+
31596 16395 17394 16384
18+
34926 16395 17394 16384
19+
16811 16704 20410 16483 16631 16397 52854
20+
16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639
21+
29479 23955 17012 20103 25527 27670 17408 19005 21473 24774
22+
54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615
23+
52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392
24+
31596
25+
34926
26+
16650 31596
27+
16650 34926
28+
16696 31596
29+
16696 31596 16582 16696 31596
30+
16604 16391
31+
16582 16604 16412
32+
16390 22623
33+
31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636
34+
16384 16384 16384 16384 16384 16384
35+
16402
36+
16402 16402
37+
16402 16402 16402
38+
16402 16402 16402 16402
39+
16402 16402 16402 16402 16402
40+
16402 16402 16402 16402 16402 16402
41+
16402 16402 16402 16402 16402 16402 16402
42+
16402 16402 16402 16402 16402 16402 16402 16402
43+
16402 16402 16402 16402 16402 16402 16402 16402 16402
44+
16418 19038 16639 16448 24315 33727 16467
45+
18765 17981
46+
16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427

src/llama-vocab.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,20 @@ struct llm_tokenizer_bpe {
450450
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
451451
};
452452
break;
453+
case LLAMA_VOCAB_PRE_TYPE_CHAMELEON:
454+
// Note: in theory, the special token (sentinel and image token) regex_exprs below
455+
// are unnecessary, as they are split in `tokenizer_st_partition` anyway.
456+
// However, since the upstream pre-tokenizer uses them, they are also
457+
// included here (see https://huggingface.co/facebook/chameleon-7b).
458+
regex_exprs = {
459+
"<sentinel:[0-9]+>", // Sentinel tokens
460+
"(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z", // Image tokens
461+
"([\\t\\n]| | )", // directly from tokenizer.json
462+
"\\p{N}", // Individual digits
463+
"[\\p{P}!-/:-@\\[-`{-~]", // Punctuation, Isolated
464+
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
465+
};
466+
break;
453467
default:
454468
// default regex for BPE tokenization pre-processing
455469
regex_exprs = {

0 commit comments

Comments
 (0)