
Commit dff8cbe

convert : support Mixtral as LLAMA arch
1 parent fe680e3 commit dff8cbe

File tree: 3 files changed (+52, -10 lines)

  convert.py
  gguf-py/gguf/constants.py
  gguf-py/gguf/tensor_mapping.py

Diff for: convert.py (+12, -1)

@@ -266,12 +266,23 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
             # LLaMA v1
             n_ctx = 2048
 
+        # print model keys
+        for k in model.keys():
+            print(k)
+
+        # check if MoE
+        if "layers.0.feed_forward.experts.0.w1.weight" in model:
+            n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
+            n_ctx = 32768
+        else:
+            n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
+
         return Params(
             n_vocab    = model["tok_embeddings.weight"].shape[0],
             n_embd     = config["dim"],
             n_layer    = config["n_layers"],
             n_ctx      = n_ctx,
-            n_ff       = model["layers.0.feed_forward.w1.weight"].shape[0],
+            n_ff       = n_ff,
             n_head     = (n_head := config["n_heads"]),
             n_head_kv  = config.get("n_kv_heads", n_head),
             f_norm_eps = config["norm_eps"],

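The MoE check above keys off a single per-expert tensor name: if "layers.0.feed_forward.experts.0.w1.weight" is present in the checkpoint, the model is treated as Mixtral (n_ctx is forced to 32768 and n_ff is read from the expert tensor); otherwise the usual dense feed-forward tensor is used. A minimal sketch of the same detection, run against stand-in tensor dicts instead of a real LazyModel; the tensor names follow the diff, while the shapes are placeholders picked for illustration only:

import numpy as np

# Stand-ins for a loaded checkpoint: tensor name -> array (placeholder shapes).
mixtral_like = {
    "tok_embeddings.weight":                     np.zeros((32000, 4096)),
    "layers.0.feed_forward.experts.0.w1.weight": np.zeros((14336, 4096)),
}
llama_like = {
    "tok_embeddings.weight":           np.zeros((32000, 4096)),
    "layers.0.feed_forward.w1.weight": np.zeros((11008, 4096)),
}

def detect_ff(model):
    # Mirrors the convert.py logic: a per-expert tensor marks the model as MoE.
    if "layers.0.feed_forward.experts.0.w1.weight" in model:
        n_ff  = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
        n_ctx = 32768   # Mixtral context length, as hard-coded in the diff
        return n_ff, n_ctx
    n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
    return n_ff, None   # keep whatever n_ctx was inferred earlier

print(detect_ff(mixtral_like))  # (14336, 32768)
print(detect_ff(llama_like))    # (11008, None)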
Diff for: gguf-py/gguf/constants.py (+13, -1)

@@ -111,10 +111,14 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM     = auto()
     ATTN_NORM_2   = auto()
     ATTN_ROT_EMBD = auto()
+    FFN_GATE_INP  = auto()
+    FFN_NORM      = auto()
     FFN_GATE      = auto()
     FFN_DOWN      = auto()
     FFN_UP        = auto()
-    FFN_NORM      = auto()
+    FFN_GATE_EXP  = auto()
+    FFN_DOWN_EXP  = auto()
+    FFN_UP_EXP    = auto()
     ATTN_Q_NORM   = auto()
     ATTN_K_NORM   = auto()
 
@@ -154,10 +158,14 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
     MODEL_TENSOR.ATTN_Q_NORM:   "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM:   "blk.{bid}.attn_k_norm",
+    MODEL_TENSOR.FFN_GATE_INP:  "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_NORM:      "blk.{bid}.ffn_norm",
     MODEL_TENSOR.FFN_GATE:      "blk.{bid}.ffn_gate",
     MODEL_TENSOR.FFN_DOWN:      "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP:        "blk.{bid}.ffn_up",
+    MODEL_TENSOR.FFN_GATE_EXP:  "blk.{bid}.ffn_gate.{xid}",
+    MODEL_TENSOR.FFN_DOWN_EXP:  "blk.{bid}.ffn_down.{xid}",
+    MODEL_TENSOR.FFN_UP_EXP:    "blk.{bid}.ffn_up.{xid}",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -172,10 +180,14 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_NORM,
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
     ],
     MODEL_ARCH.GPTNEOX: [
         MODEL_TENSOR.TOKEN_EMBD,

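The three new *_EXP tensors carry an extra {xid} placeholder on top of the usual {bid}, so every expert in a block gets its own GGUF tensor name (blk.<bid>.ffn_gate.<xid> and so on). A quick illustration of how the templates expand; the templates are copied from TENSOR_NAMES above, and the loop bounds are only for the example:

# Per-expert GGUF name templates, as added to TENSOR_NAMES in this commit.
EXPERT_TEMPLATES = {
    "FFN_GATE_EXP": "blk.{bid}.ffn_gate.{xid}",
    "FFN_DOWN_EXP": "blk.{bid}.ffn_down.{xid}",
    "FFN_UP_EXP":   "blk.{bid}.ffn_up.{xid}",
}

bid = 0                  # block index
for xid in range(2):     # first two of Mixtral's 8 experts, for brevity
    for kind, template in EXPERT_TEMPLATES.items():
        print(kind, "->", template.format(bid=bid, xid=xid))
# FFN_GATE_EXP -> blk.0.ffn_gate.0
# FFN_DOWN_EXP -> blk.0.ffn_down.0
# FFN_UP_EXP   -> blk.0.ffn_up.0
# FFN_GATE_EXP -> blk.0.ffn_gate.1
# ... and so on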
Diff for: gguf-py/gguf/tensor_mapping.py (+27, -8)

@@ -149,6 +149,10 @@ class TensorNameMap:
            "model.layers.{bid}.ln2", # yi
        ),
 
+        MODEL_TENSOR.FFN_GATE_INP: (
+            "layers.{bid}.feed_forward.gate", # mixtral
+        ),
+
        # Feed-forward up
        MODEL_TENSOR.FFN_UP: (
            "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
@@ -164,11 +168,19 @@ class TensorNameMap:
            "transformer.h.{bid}.mlp.w1", # qwen
        ),
 
+        MODEL_TENSOR.FFN_UP_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w3", # mixtral
+        ),
+
        # Feed-forward gate
        MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
-            "layers.{bid}.feed_forward.w1",     # llama-pth
-            "transformer.h.{bid}.mlp.w2",       # qwen
+            "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
+            "layers.{bid}.feed_forward.w1",     # llama-pth
+            "transformer.h.{bid}.mlp.w2",       # qwen
+        ),
+
+        MODEL_TENSOR.FFN_GATE_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w1", # mixtral
        ),
 
        # Feed-forward down
@@ -185,6 +197,10 @@ class TensorNameMap:
            "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
        ),
 
+        MODEL_TENSOR.FFN_DOWN_EXP: (
+            "layers.{bid}.feed_forward.experts.{xid}.w2", # mixtral
+        ),
+
        MODEL_TENSOR.ATTN_Q_NORM: (
            "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
        ),
@@ -213,11 +229,14 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
            for tensor, keys in self.block_mappings_cfg.items():
                if tensor not in MODEL_TENSORS[arch]:
                    continue
-                tensor_name = TENSOR_NAMES[tensor].format(bid = bid)
-                self.mapping[tensor_name] = (tensor, tensor_name)
-                for key in keys:
-                    key = key.format(bid = bid)
-                    self.mapping[key] = (tensor, tensor_name)
+                # TODO: make this configurable
+                n_experts = 8
+                for xid in range(n_experts):
+                    tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
+                    self.mapping[tensor_name] = (tensor, tensor_name)
+                    for key in keys:
+                        key = key.format(bid = bid, xid = xid)
+                        self.mapping[key] = (tensor, tensor_name)
 
    def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
        result = self.mapping.get(key)

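With the new block mappings and the hard-coded n_experts = 8 loop in __init__, per-expert checkpoint tensors resolve through TensorNameMap just like any other block tensor. A small usage sketch, assuming the gguf-py package from this tree is on the Python path (module and class names as in the files above):

from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import TensorNameMap

# Build the name map for two blocks; the expert count is fixed at 8 internally.
tmap = TensorNameMap(MODEL_ARCH.LLAMA, 2)

# A Mixtral-style per-expert tensor name from the original checkpoint ...
src = "layers.1.feed_forward.experts.3.w1"

# ... looks up its MODEL_TENSOR type and GGUF name via the mapping built above.
print(tmap.get_type_and_name(src))
# expected: (MODEL_TENSOR.FFN_GATE_EXP, 'blk.1.ffn_gate.3')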