Skip to content

Commit 7e177b7

Browse files
committed
model: support phimoe
#9119
1 parent d79d8f3 commit 7e177b7

File tree

5 files changed

+203
-30
lines changed

5 files changed

+203
-30
lines changed

convert_hf_to_gguf.py

+56
Original file line numberDiff line numberDiff line change
@@ -2564,6 +2564,62 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
25642564
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
25652565
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
25662566

2567+
@Model.register("PhiMoEForCausalLM")
2568+
class PhiMoeModel(Phi3MiniModel):
2569+
model_arch = gguf.MODEL_ARCH.PHIMOE
2570+
2571+
_experts: list[dict[str, Tensor]] | None = None
2572+
2573+
def set_gguf_parameters(self):
2574+
super().set_gguf_parameters()
2575+
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
2576+
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
2577+
2578+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
2579+
# process the experts separately
2580+
if name.find("block_sparse_moe.experts") != -1:
2581+
n_experts = self.hparams["num_local_experts"]
2582+
assert bid is not None
2583+
2584+
if self._experts is None:
2585+
self._experts = [{} for _ in range(self.block_count)]
2586+
2587+
self._experts[bid][name] = data_torch
2588+
2589+
if len(self._experts[bid]) >= n_experts * 3:
2590+
tensors: list[tuple[str, Tensor]] = []
2591+
2592+
# merge the experts into a single 3d tensor
2593+
for w_name in ["w1", "w2", "w3"]:
2594+
datas: list[Tensor] = []
2595+
2596+
for xid in range(n_experts):
2597+
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
2598+
datas.append(self._experts[bid][ename])
2599+
del self._experts[bid][ename]
2600+
2601+
data_torch = torch.stack(datas, dim=0)
2602+
2603+
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
2604+
2605+
new_name = self.map_tensor_name(merged_name)
2606+
2607+
tensors.append((new_name, data_torch))
2608+
return tensors
2609+
else:
2610+
return []
2611+
2612+
return [(self.map_tensor_name(name), data_torch)]
2613+
2614+
def prepare_tensors(self):
2615+
super().prepare_tensors()
2616+
2617+
if self._experts is not None:
2618+
# flatten `list[dict[str, Tensor]]` into `list[str]`
2619+
experts = [k for d in self._experts for k in d.keys()]
2620+
if len(experts) > 0:
2621+
raise ValueError(f"Unprocessed experts: {experts}")
2622+
25672623

25682624
@Model.register("PlamoForCausalLM")
25692625
class PlamoModel(Model):

docs/development/HOWTO-add-model.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
2828
```python
2929
@Model.register("MyModelForCausalLM")
3030
class MyModel(Model):
31-
model_arch = gguf.MODEL_ARCH.GROK
31+
model_arch = gguf.MODEL_ARCH.MYMODEL
3232
```
3333

3434
2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
@@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
7979
- `Model#set_vocab`
8080
- `Model#write_tensors`
8181

82-
NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
82+
NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to precede the weights.
8383

8484
### 2. Define the model architecture in `llama.cpp`
8585

8686
The model params and tensors layout must be defined in `llama.cpp`:
8787
1. Define a new `llm_arch`
8888
2. Define the tensors layout in `LLM_TENSOR_NAMES`
89-
3. Add any non standard metadata in `llm_load_hparams`
89+
3. Add any non-standard metadata in `llm_load_hparams`
9090
4. Create the tensors for inference in `llm_load_tensors`
9191
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
9292

@@ -98,7 +98,7 @@ This is the funniest part, you have to provide the inference graph implementatio
9898

9999
Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
100100

101-
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
101+
Some `ggml` backends do not support all operations, backend implementation can be added in a separate PR.
102102

103103
Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
104104

gguf-py/gguf/constants.py

+20
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class MODEL_ARCH(IntEnum):
242242
QWEN2VL = auto()
243243
PHI2 = auto()
244244
PHI3 = auto()
245+
PHIMOE = auto()
245246
PLAMO = auto()
246247
CODESHELL = auto()
247248
ORION = auto()
@@ -424,6 +425,7 @@ class MODEL_TENSOR(IntEnum):
424425
MODEL_ARCH.QWEN2VL: "qwen2vl",
425426
MODEL_ARCH.PHI2: "phi2",
426427
MODEL_ARCH.PHI3: "phi3",
428+
MODEL_ARCH.PHIMOE: "phimoe",
427429
MODEL_ARCH.PLAMO: "plamo",
428430
MODEL_ARCH.CODESHELL: "codeshell",
429431
MODEL_ARCH.ORION: "orion",
@@ -934,6 +936,24 @@ class MODEL_TENSOR(IntEnum):
934936
MODEL_TENSOR.FFN_DOWN,
935937
MODEL_TENSOR.FFN_UP,
936938
],
939+
MODEL_ARCH.PHIMOE: [
940+
MODEL_TENSOR.TOKEN_EMBD,
941+
MODEL_TENSOR.OUTPUT_NORM,
942+
MODEL_TENSOR.OUTPUT,
943+
MODEL_TENSOR.ROPE_FACTORS_LONG,
944+
MODEL_TENSOR.ROPE_FACTORS_SHORT,
945+
MODEL_TENSOR.ATTN_NORM,
946+
MODEL_TENSOR.ATTN_QKV,
947+
MODEL_TENSOR.ATTN_Q,
948+
MODEL_TENSOR.ATTN_K,
949+
MODEL_TENSOR.ATTN_V,
950+
MODEL_TENSOR.ATTN_OUT,
951+
MODEL_TENSOR.FFN_NORM,
952+
MODEL_TENSOR.FFN_GATE_INP,
953+
MODEL_TENSOR.FFN_GATE_EXP,
954+
MODEL_TENSOR.FFN_DOWN_EXP,
955+
MODEL_TENSOR.FFN_UP_EXP,
956+
],
937957
MODEL_ARCH.CODESHELL: [
938958
MODEL_TENSOR.TOKEN_EMBD,
939959
MODEL_TENSOR.POS_EMBD,

gguf-py/gguf/tensor_mapping.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class TensorNameMap:
5555
# Output
5656
MODEL_TENSOR.OUTPUT: (
5757
"embed_out", # gptneox
58-
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
58+
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
5959
"output", # llama-pth bloom internlm2
6060
"word_embeddings_for_head", # persimmon
6161
"lm_head.linear", # phi2
@@ -68,7 +68,7 @@ class TensorNameMap:
6868
MODEL_TENSOR.OUTPUT_NORM: (
6969
"gpt_neox.final_layer_norm", # gptneox
7070
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
71-
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2
71+
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
7272
"norm", # llama-pth
7373
"transformer.norm_f", # mpt dbrx
7474
"ln_f", # refact bloom qwen gpt2
@@ -108,7 +108,7 @@ class TensorNameMap:
108108
"transformer.h.{bid}.input_layernorm", # falcon7b
109109
"h.{bid}.input_layernorm", # bloom
110110
"transformer.h.{bid}.ln_mlp", # falcon40b
111-
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
111+
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
112112
"layers.{bid}.attention_norm", # llama-pth
113113
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
114114
"model.layers.{bid}.ln1", # yi
@@ -152,7 +152,7 @@ class TensorNameMap:
152152

153153
# Attention query
154154
MODEL_TENSOR.ATTN_Q: (
155-
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
155+
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
156156
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
157157
"layers.{bid}.attention.wq", # llama-pth
158158
"encoder.layer.{bid}.attention.self.query", # bert
@@ -165,7 +165,7 @@ class TensorNameMap:
165165

166166
# Attention key
167167
MODEL_TENSOR.ATTN_K: (
168-
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
168+
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
169169
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
170170
"layers.{bid}.attention.wk", # llama-pth
171171
"encoder.layer.{bid}.attention.self.key", # bert
@@ -179,7 +179,7 @@ class TensorNameMap:
179179

180180
# Attention value
181181
MODEL_TENSOR.ATTN_V: (
182-
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
182+
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
183183
"layers.{bid}.attention.wv", # llama-pth
184184
"encoder.layer.{bid}.attention.self.value", # bert
185185
"transformer.h.{bid}.attn.v_proj", # gpt-j
@@ -197,7 +197,7 @@ class TensorNameMap:
197197
"transformer.blocks.{bid}.attn.out_proj", # mpt
198198
"transformer.h.{bid}.self_attention.dense", # falcon
199199
"h.{bid}.self_attention.dense", # bloom
200-
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
200+
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
201201
"model.layers.{bid}.self_attn.linear_attn", # deci
202202
"layers.{bid}.attention.wo", # llama-pth
203203
"encoder.layer.{bid}.attention.output.dense", # bert
@@ -242,7 +242,7 @@ class TensorNameMap:
242242
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
243243
"h.{bid}.post_attention_layernorm", # bloom
244244
"transformer.blocks.{bid}.norm_2", # mpt
245-
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
245+
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
246246
"layers.{bid}.ffn_norm", # llama-pth
247247
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
248248
"model.layers.{bid}.ln2", # yi
@@ -265,7 +265,7 @@ class TensorNameMap:
265265

266266
MODEL_TENSOR.FFN_GATE_INP: (
267267
"layers.{bid}.feed_forward.gate", # mixtral
268-
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
268+
"model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
269269
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
270270
"transformer.decoder_layer.{bid}.router", # Grok
271271
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
@@ -306,10 +306,11 @@ class TensorNameMap:
306306
),
307307

308308
MODEL_TENSOR.FFN_UP_EXP: (
309-
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
310-
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
311-
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
312-
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
309+
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
310+
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
311+
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
312+
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
313+
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
313314
),
314315

315316
MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -338,10 +339,11 @@ class TensorNameMap:
338339
),
339340

340341
MODEL_TENSOR.FFN_GATE_EXP: (
341-
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
342-
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
343-
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
344-
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
342+
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
343+
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
344+
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
345+
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
346+
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
345347
),
346348

347349
MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -383,6 +385,7 @@ class TensorNameMap:
383385
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
384386
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
385387
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
388+
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
386389
),
387390

388391
MODEL_TENSOR.FFN_DOWN_SHEXP: (

0 commit comments

Comments
 (0)