Skip to content

Commit e4255c3

Browse files
phymbert and megha95
authored and committed
model: support arch DbrxForCausalLM (ggml-org#6515)
* model: dbrx convert to gguf ggml-org#6344
* llama: support dbrx ggml-org#6344
* doc: dbrx: add the model as supported
* scripts: get-wikitext-2 add unzip
* llama: increase maximum experts allowed
* llama: factorize moe graph implementation between grok, mixtral and dbrx
---------
Co-authored-by: Megha Agarwal <[email protected]>
1 parent 81c3461 commit e4255c3

File tree

7 files changed

+427
-147
lines changed

7 files changed

+427
-147
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ Typically finetunes of the base models below are supported as well.
9494
- [x] LLaMA 2 🦙🦙
9595
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
9696
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
97+
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
9798
- [X] Falcon
9899
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
99100
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)

Diff for: convert-hf-to-gguf.py

+96
Original file line numberDiff line numberDiff line change
@@ -1427,6 +1427,102 @@ def write_tensors(self):
14271427
self.gguf_writer.add_tensor(new_name, data)
14281428

14291429

1430+
@Model.register("DbrxForCausalLM")
class DbrxModel(Model):
    """Converter for ``DbrxForCausalLM`` checkpoints.

    Reads the hyper-parameters from ``self.hparams`` (including the nested
    ``ffn_config`` and ``attn_config`` sections) and writes metadata and
    tensors through ``self.gguf_writer``.
    """

    model_arch = gguf.MODEL_ARCH.DBRX

    def set_gguf_parameters(self):
        """Write the model hyper-parameters as GGUF metadata.

        Raises:
            KeyError: if an expected key is missing from ``self.hparams``,
                ``ffn_config`` or ``attn_config``.
        """
        ffn_config = self.hparams["ffn_config"]
        attn_config = self.hparams["attn_config"]
        self.gguf_writer.add_name(self.hparams["model_type"])
        self.gguf_writer.add_block_count(self.hparams["n_layers"])

        self.gguf_writer.add_context_length(self.hparams["max_seq_len"])
        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
        self.gguf_writer.add_feed_forward_length(ffn_config["ffn_hidden_size"])

        self.gguf_writer.add_head_count(self.hparams["n_heads"])
        self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])

        self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])

        self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])

        self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
        self.gguf_writer.add_expert_used_count(ffn_config["moe_top_k"])

        # The epsilon is hard-coded here rather than read from hparams.
        self.gguf_writer.add_layer_norm_eps(1e-5)

        # The file type is written exactly once (it used to be added twice).
        self.gguf_writer.add_file_type(self.ftype)
        print(f"gguf: file type = {self.ftype}")

    def write_tensors(self):
        """Convert every checkpoint tensor and hand it to the GGUF writer.

        Expert tensors are viewed as 3D, optionally permuted, and mapped under
        a name suffixed with ``.weight``. All tensors are then converted to the
        dtype implied by ``self.ftype`` (0: F32, 1: F16 for tensors with more
        than one dimension).

        Exits the process with status 1 when a tensor name cannot be mapped or
        when a 1D tensor is not F32.
        """
        block_count = self.hparams.get("n_layers")
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

        # Loop-invariant dimensions, read once instead of once per tensor.
        ffn_config = self.hparams["ffn_config"]
        n_expert = ffn_config["moe_num_experts"]
        n_ff = ffn_config["ffn_hidden_size"]
        n_embd = self.hparams["d_model"]

        # Specific behavior for experts tensors: suffix .weight, view as 3D and transpose.
        # The original implementation expects (n_expert, n_ff, n_embd) for all experts weights,
        # but the llama.cpp moe graph works differently,
        # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions,
        # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor.
        # Each value is the permutation applied after the 3D view (None: keep as is).
        exp_tensor_names = {
            "ffn.experts.mlp.w1": None,       # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert}
            "ffn.experts.mlp.w2": (0, 2, 1),  # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert}
            "ffn.experts.mlp.v1": None,       # LLM_TENSOR_FFN_UP_EXPS   ggml_tensor->ne{n_embd, n_ff, n_expert}
        }

        for name, data_torch in self.get_tensors():
            experts = False
            for exp_tensor_name, permute_tensor in exp_tensor_names.items():
                # Only names that do not already carry a ".weight" suffix are
                # handled as expert tensors.
                if exp_tensor_name in name and ".weight" not in name:
                    experts = True
                    data_torch = data_torch.view(n_expert, n_ff, n_embd)
                    if permute_tensor is not None:
                        data_torch = data_torch.permute(*permute_tensor)
                    break

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            # In MoE models the ffn tensors are typically most of the model weights,
            # and need to be quantizable. Quantize expects tensor names to be suffixed by .weight.
            # Every other model has the weight names ending in .weight,
            # let's assume that is the convention, which is not the case for dbrx:
            # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15
            lookup_name = name + ".weight" if experts else name
            new_name = tensor_map.get_name(lookup_name, try_suffixes=(".weight",))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                # Non-zero status so callers can detect the failed conversion.
                sys.exit(1)

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # Most of the codebase that takes in 1D tensors only handles F32 tensors
            # and most of the outputs tensors are F32.
            if data_dtype != np.float32 and n_dims == 1:
                print(f"Can not map tensor {name!r}: all 1D tensors must be F32")
                sys.exit(1)

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 tensors with more than one dimension to float16
            if self.ftype == 1 and data_dtype == np.float32 and n_dims > 1:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, shape = {data.shape}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)
1524+
1525+
14301526
@Model.register("MiniCPMForCausalLM")
14311527
class MiniCPMModel(Model):
14321528
model_arch = gguf.MODEL_ARCH.MINICPM

Diff for: examples/eval-callback/eval-callback.cpp

+18-8
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,27 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
2828
}
2929

3030
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
31+
GGML_ASSERT(n > 0);
3132
float sum = 0;
3233
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
3334
printf(" [\n");
34-
for (int64_t i2 = 0; i2 < ne[2] && i2 < n; i2++) {
35+
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
36+
if (i2 == n && ne[2] > 2*n) {
37+
printf(" ..., \n");
38+
i2 = ne[2] - n;
39+
}
3540
printf(" [\n");
36-
for (int64_t i1 = 0; i1 < ne[1] && i1 < n; i1++) {
41+
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
42+
if (i1 == n && ne[1] > 2*n) {
43+
printf(" ..., \n");
44+
i1 = ne[1] - n;
45+
}
3746
printf(" [");
38-
for (int64_t i0 = 0; i0 < ne[0] && i0 < n; i0++) {
47+
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
48+
if (i0 == n && ne[0] > 2*n) {
49+
printf("..., ");
50+
i0 = ne[0] - n;
51+
}
3952
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
4053
float v;
4154
if (type == GGML_TYPE_F16) {
@@ -51,17 +64,14 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
5164
} else {
5265
GGML_ASSERT(false);
5366
}
54-
printf("%8.4f", v);
67+
printf("%12.4f", v);
5568
sum += v;
56-
if (i0 < ne[0] - 1 && i0 < n - 1) printf(", ");
69+
if (i0 < ne[0] - 1) printf(", ");
5770
}
58-
if (ne[0] > n) printf(", ...");
5971
printf("],\n");
6072
}
61-
if (ne[1] > n) printf(" ...\n");
6273
printf(" ],\n");
6374
}
64-
if (ne[2] > n) printf(" ...\n");
6575
printf(" ]\n");
6676
printf(" sum = %f\n", sum);
6777
}

Diff for: gguf-py/gguf/constants.py

+15
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class MODEL_ARCH(IntEnum):
126126
MAMBA = auto()
127127
XVERSE = auto()
128128
COMMAND_R = auto()
129+
DBRX = auto()
129130

130131

131132
class MODEL_TENSOR(IntEnum):
@@ -195,6 +196,7 @@ class MODEL_TENSOR(IntEnum):
195196
MODEL_ARCH.MAMBA: "mamba",
196197
MODEL_ARCH.XVERSE: "xverse",
197198
MODEL_ARCH.COMMAND_R: "command-r",
199+
MODEL_ARCH.DBRX: "dbrx",
198200
}
199201

200202
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -642,6 +644,19 @@ class MODEL_TENSOR(IntEnum):
642644
MODEL_TENSOR.ATTN_K_NORM,
643645
MODEL_TENSOR.ATTN_Q_NORM,
644646
],
647+
MODEL_ARCH.DBRX: [
648+
MODEL_TENSOR.TOKEN_EMBD,
649+
MODEL_TENSOR.OUTPUT_NORM,
650+
MODEL_TENSOR.OUTPUT,
651+
MODEL_TENSOR.ATTN_NORM,
652+
MODEL_TENSOR.ATTN_QKV,
653+
MODEL_TENSOR.ATTN_OUT,
654+
MODEL_TENSOR.ATTN_OUT_NORM,
655+
MODEL_TENSOR.FFN_GATE_INP,
656+
MODEL_TENSOR.FFN_GATE_EXP,
657+
MODEL_TENSOR.FFN_DOWN_EXP,
658+
MODEL_TENSOR.FFN_UP_EXP,
659+
],
645660
# TODO
646661
}
647662

Diff for: gguf-py/gguf/tensor_mapping.py

+33-25
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class TensorNameMap:
1010
# Token embeddings
1111
MODEL_TENSOR.TOKEN_EMBD: (
1212
"gpt_neox.embed_in", # gptneox
13-
"transformer.wte", # gpt2 gpt-j mpt refact qwen
13+
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx
1414
"transformer.word_embeddings", # falcon
1515
"word_embeddings", # bloom
1616
"model.embed_tokens", # llama-hf
@@ -48,7 +48,7 @@ class TensorNameMap:
4848
# Output
4949
MODEL_TENSOR.OUTPUT: (
5050
"embed_out", # gptneox
51-
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba
51+
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx
5252
"output", # llama-pth bloom internlm2
5353
"word_embeddings_for_head", # persimmon
5454
"lm_head.linear", # phi2
@@ -60,7 +60,7 @@ class TensorNameMap:
6060
"transformer.ln_f", # gpt2 gpt-j falcon
6161
"model.norm", # llama-hf baichuan internlm2
6262
"norm", # llama-pth
63-
"transformer.norm_f", # mpt
63+
"transformer.norm_f", # mpt dbrx
6464
"ln_f", # refact bloom qwen gpt2
6565
"language_model.encoder.final_layernorm", # persimmon
6666
"model.final_layernorm", # persimmon
@@ -96,6 +96,7 @@ class TensorNameMap:
9696
"model.layers.{bid}.norm", # mamba-qbert
9797
"backbone.layers.{bid}.norm", # mamba
9898
"transformer.decoder_layer.{bid}.rms_norm", # Grok
99+
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
99100
),
100101

101102
# Attention norm 2
@@ -108,6 +109,7 @@ class TensorNameMap:
108109
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
109110
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen
110111
"transformer.blocks.{bid}.attn.Wqkv", # mpt
112+
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
111113
"transformer.h.{bid}.self_attention.query_key_value", # falcon
112114
"h.{bid}.self_attention.query_key_value", # bloom
113115
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
@@ -152,30 +154,32 @@ class TensorNameMap:
152154

153155
# Attention output
154156
MODEL_TENSOR.ATTN_OUT: (
155-
"gpt_neox.layers.{bid}.attention.dense", # gptneox
156-
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
157-
"transformer.blocks.{bid}.attn.out_proj", # mpt
158-
"transformer.h.{bid}.self_attention.dense", # falcon
159-
"h.{bid}.self_attention.dense", # bloom
160-
"model.layers.{bid}.self_attn.o_proj", # llama-hf
161-
"layers.{bid}.attention.wo", # llama-pth
162-
"encoder.layer.{bid}.attention.output.dense", # bert
163-
"transformer.h.{bid}.attn.out_proj", # gpt-j
164-
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
165-
"model.layers.{bid}.self_attn.dense", # persimmon
166-
"h.{bid}.attn.c_proj", # gpt2
167-
"transformer.h.{bid}.mixer.out_proj", # phi2
168-
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
169-
"model.layers.{bid}.attention.wo", # internlm2
170-
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
171-
"transformer.decoder_layer.{bid}.multi_head_attention.linear"# Grok
157+
"gpt_neox.layers.{bid}.attention.dense", # gptneox
158+
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
159+
"transformer.blocks.{bid}.attn.out_proj", # mpt
160+
"transformer.h.{bid}.self_attention.dense", # falcon
161+
"h.{bid}.self_attention.dense", # bloom
162+
"model.layers.{bid}.self_attn.o_proj", # llama-hf
163+
"layers.{bid}.attention.wo", # llama-pth
164+
"encoder.layer.{bid}.attention.output.dense", # bert
165+
"transformer.h.{bid}.attn.out_proj", # gpt-j
166+
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
167+
"model.layers.{bid}.self_attn.dense", # persimmon
168+
"h.{bid}.attn.c_proj", # gpt2
169+
"transformer.h.{bid}.mixer.out_proj", # phi2
170+
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
171+
"model.layers.{bid}.attention.wo", # internlm2
172+
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
173+
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
174+
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
172175
),
173176

174177
# Attention output norm
175178
MODEL_TENSOR.ATTN_OUT_NORM: (
176179
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
177180
"encoder.layers.{bid}.norm1", # nomic-bert
178181
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
182+
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
179183
),
180184

181185
# Rotary embeddings
@@ -202,9 +206,10 @@ class TensorNameMap:
202206
),
203207

204208
MODEL_TENSOR.FFN_GATE_INP: (
205-
"layers.{bid}.feed_forward.gate", # mixtral
206-
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
207-
"transformer.decoder_layer.{bid}.router" # Grok
209+
"layers.{bid}.feed_forward.gate", # mixtral
210+
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
211+
"transformer.decoder_layer.{bid}.router", # Grok
212+
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
208213
),
209214

210215
# Feed-forward up
@@ -233,6 +238,7 @@ class TensorNameMap:
233238
MODEL_TENSOR.FFN_UP_EXP: (
234239
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
235240
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
241+
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
236242
),
237243

238244
# AWQ-activation gate
@@ -251,8 +257,9 @@ class TensorNameMap:
251257
),
252258

253259
MODEL_TENSOR.FFN_GATE_EXP: (
254-
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
255-
"transformer.decoder_layer.{bid}.moe.linear" # Grok (merged)
260+
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
261+
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
262+
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
256263
),
257264

258265
# Feed-forward down
@@ -280,6 +287,7 @@ class TensorNameMap:
280287
MODEL_TENSOR.FFN_DOWN_EXP: (
281288
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
282289
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
290+
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
283291
),
284292

285293
MODEL_TENSOR.ATTN_Q_NORM: (

0 commit comments

Comments
 (0)