
Commit afe9ff8

Nexesenex and nopperl committed
Add support for Chameleon ggml-org#8543
Co-Authored-By: nopperl <[email protected]>
1 parent cf1662d commit afe9ff8

8 files changed: +346 -2 lines changed

convert_hf_to_gguf.py

+45
@@ -617,6 +617,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
             # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
             res = "exaone"
+        if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450":
+            # ref: https://huggingface.co/facebook/chameleon-7b
+            res = "chameleon"
 
         if res is None:
             logger.warning("\n")
@@ -3872,6 +3875,48 @@ def prepare_tensors(self):
 
         super().prepare_tensors()
 
+
+@Model.register("ChameleonForCausalLM")
+class ChameleonModel(Model):
+    model_arch = gguf.MODEL_ARCH.CHAMELEON
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_swin_norm(self.hparams.get("swin_norm", False))
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # ignore image tokenizer for now
+        # TODO: remove this once image support is implemented for Chameleon
+        if name.startswith("model.vqmodel"):
+            return []
+
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+        hidden_dim = self.hparams.get("hidden_size")
+
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+        if name.endswith(("q_norm.weight", "q_norm.bias")):
+            data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim)
+        if name.endswith(("k_norm.weight", "k_norm.bias")):
+            data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    # see: https://github.com/huggingface/transformers/blob/72fb02c47dbbe1999ae105319f24631cad6e2e00/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py#L176-L203
+    @staticmethod
+    def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
+        head_dim = hidden_dim // n_heads
+        data_torch = data_torch[0].view(2, head_dim // 2).t().reshape(1, -1)
+        data_torch = data_torch.repeat_interleave(n_heads, 0)
+        return data_torch
+
+
 ###### CONVERSION LOGIC ######
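Judging from its name and the linked Hugging Face conversion script, `_reverse_hf_permute` undoes the per-head reordering of the q_norm/k_norm weights: it takes the first head's weight, re-interleaves its two halves, and broadcasts the result to every head. A minimal standalone sketch of that index shuffle, using made-up toy sizes (head_dim = 4, 3 heads; not part of this commit):

    import torch

    head_dim, n_heads = 4, 3                                       # toy values for illustration only
    w = torch.arange(head_dim, dtype=torch.float32).unsqueeze(0)   # [[0., 1., 2., 3.]]

    # the same three ops used in _reverse_hf_permute
    out = w[0].view(2, head_dim // 2).t().reshape(1, -1)   # halves interleaved -> [[0., 2., 1., 3.]]
    out = out.repeat_interleave(n_heads, 0)                # one identical row per head, shape (3, 4)
    print(out)

With this class registered, a Chameleon checkpoint should convert through the usual flow, e.g. `python convert_hf_to_gguf.py <model-dir> --outfile chameleon-7b.gguf` (paths hypothetical).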

convert_hf_to_gguf_update.py

+1
@@ -97,6 +97,7 @@ class TOKENIZER_TYPE(IntEnum):
     {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
     {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
     {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
+    {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", },
 ]
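For reference (not part of the diff): the chkhsh value matched in get_vocab_base_pre above is, as I understand the update script, the SHA-256 of the token IDs the Hugging Face tokenizer produces for a fixed probe string, so registering the repo here lets the script regenerate that hash automatically. A rough sketch of the hashing step, with chktxt standing in for the real probe string defined in the convert scripts:

    from hashlib import sha256
    from transformers import AutoTokenizer

    chktxt = "..."  # placeholder; the real probe string lives in the convert scripts
    tokenizer = AutoTokenizer.from_pretrained("facebook/chameleon-7b")
    chktok = tokenizer.encode(chktxt)
    chkhsh = sha256(str(chktok).encode()).hexdigest()
    # with the real chktxt this should reproduce the hash checked in get_vocab_base_pre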

gguf-py/gguf/constants.py

+19
@@ -94,6 +94,7 @@ class LLM:
         DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+        SWIN_NORM = "{arch}.swin_norm"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -221,6 +222,7 @@ class MODEL_ARCH(IntEnum):
     JAIS = auto()
     NEMOTRON = auto()
     EXAONE = auto()
+    CHAMELEON = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -351,6 +353,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
     MODEL_ARCH.EXAONE: "exaone",
+    MODEL_ARCH.CHAMELEON: "chameleon",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1100,6 +1103,22 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.CHAMELEON: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

gguf-py/gguf/gguf_writer.py

+3
@@ -670,6 +670,9 @@ def add_expert_shared_count(self, count: int) -> None:
     def add_expert_weights_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
 
+    def add_swin_norm(self, value: bool) -> None:
+        self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
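Since MODEL_ARCH_NAMES maps MODEL_ARCH.CHAMELEON to "chameleon", this should write a boolean chameleon.swin_norm key into the converted file's metadata, mirroring the swin_norm field of the Hugging Face config. Not part of the diff, but a quick way to confirm the key landed, using gguf-py's reader (output path hypothetical):

    from gguf import GGUFReader

    reader = GGUFReader("chameleon-7b.gguf")        # hypothetical converted model
    print("chameleon.swin_norm" in reader.fields)   # key written by add_swin_norm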

gguf-py/gguf/tensor_mapping.py

+2 -2
@@ -372,7 +372,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_Q_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
             "model.layers.{bid}.self_attn.q_layernorm",  # persimmon
-            "model.layers.{bid}.self_attn.q_norm",  # cohere
+            "model.layers.{bid}.self_attn.q_norm",  # cohere chameleon
             "transformer.blocks.{bid}.attn.q_ln",  # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_q",  # jina-bert-v2
             "transformer.layers.{bid}.attn.q_norm",  # openelm
@@ -381,7 +381,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_K_NORM: (
             "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
             "model.layers.{bid}.self_attn.k_layernorm",  # persimmon
-            "model.layers.{bid}.self_attn.k_norm",  # cohere
+            "model.layers.{bid}.self_attn.k_norm",  # cohere chameleon
             "transformer.blocks.{bid}.attn.k_ln",  # sea-lion
             "encoder.layer.{bid}.attention.self.layer_norm_k",  # jina-bert-v2
             "transformer.layers.{bid}.attn.k_norm",  # openelm

include/llama.h

+1
@@ -96,6 +96,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_BLOOM        = 23,
         LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
         LLAMA_VOCAB_PRE_TYPE_EXAONE       = 25,
+        LLAMA_VOCAB_PRE_TYPE_CHAMELEON    = 26,
     };
 
     enum llama_rope_type {

src/llama-vocab.cpp

+14
@@ -663,6 +663,20 @@ struct llm_tokenizer_bpe {
                     "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_CHAMELEON:
+                // Note: in theory, the special token (sentinel and image token) regex_exprs below
+                // are unnecessary, as they are split in `tokenizer_st_partition` anyway.
+                // However, since the upstream pre-tokenizer uses them, they are also
+                // included here (see https://huggingface.co/facebook/chameleon-7b).
+                regex_exprs = {
+                    "<sentinel:[0-9]+>",  // Sentinel tokens
+                    "(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z",  // Image tokens
+                    "([\\t\\n]| | )",  // directly from tokenizer.json
+                    "\\p{N}",  // Individual digits
+                    "[\\p{P}!-/:-@\\[-`{-~]",  // Punctuation, Isolated
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
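Not part of the commit, but to illustrate what the first two Chameleon-specific patterns match, here is a small standalone sketch using Python's third-party `regex` package (chosen because it also handles the \p{...} classes in the later patterns; llama.cpp applies these expressions with its own unicode-aware matcher):

    import regex  # pip install regex

    sentinel = regex.compile(r"<sentinel:[0-9]+>")
    image    = regex.compile(r"(IMGIMG)((A|B|C|D|E|F|G|H|I){1,4})Z")

    print(sentinel.findall("foo<sentinel:12>bar"))  # ['<sentinel:12>']
    print(bool(image.match("IMGIMGABDZ")))          # True: IMGIMG + 1-4 letters A-I + Z
    print(bool(image.match("IMGIMGXZ")))            # False: X is outside A-I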
