
Commit eebbab0

okdshin authored and ggerganov committed
llama : add PLaMo model (ggml-org#3557)
* add plamo mock
* add tensor loading
* plamo convert
* update norm
* able to compile
* fix norm_rms_eps hparam
* runnable
* use inp_pos
* seems ok
* update kqv code
* remove develop code
* update README
* shuffle attn_q.weight and attn_output.weight for broadcasting
* remove plamo_llm_build_kqv and use llm_build_kqv
* fix style
* update
* llama : remove obsolete KQ_scale
* plamo : fix tensor names for correct GPU offload

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 1fc2dcc commit eebbab0

File tree: 5 files changed, +307 −15 lines


README.md (+1)

@@ -102,6 +102,7 @@ as the main playground for developing new features for the [ggml](https://github
 - [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
 - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
+- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
 
 **Multimodal models:**

convert-hf-to-gguf.py (+85, −1)

@@ -184,6 +184,8 @@ def from_model_architecture(model_architecture):
             return MixtralModel
         if model_architecture == "PhiForCausalLM":
             return Phi2Model
+        if model_architecture == "PlamoForCausalLM":
+            return PlamoModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -225,6 +227,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.LLAMA
         if arch == "PhiForCausalLM":
             return gguf.MODEL_ARCH.PHI2
+        if arch == "PlamoForCausalLM":
+            return gguf.MODEL_ARCH.PLAMO
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1002,11 +1006,91 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_add_bos_token(False)
 
 
+class PlamoModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name("PLaMo")
+        self.gguf_writer.add_context_length(4096)  # not in config.json
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(5)  # hparams["num_key_value_heads"]) is wrong
+        self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
+
+    def shuffle_attn_q_weight(self, data_torch):
+        assert data_torch.size() == (5120, 5120)
+        data_torch = data_torch.reshape(8, 5, 128, 5120)
+        data_torch = torch.permute(data_torch, (1, 0, 2, 3))
+        data_torch = torch.reshape(data_torch, (5120, 5120))
+        return data_torch
+
+    def shuffle_attn_output_weight(self, data_torch):
+        assert data_torch.size() == (5120, 5120)
+        data_torch = data_torch.reshape(5120, 8, 5, 128)
+        data_torch = torch.permute(data_torch, (0, 2, 1, 3))
+        data_torch = torch.reshape(data_torch, (5120, 5120))
+        return data_torch
+
+    def write_tensors(self):
+        block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+
+        for name, data_torch in self.get_tensors():
+            if "self_attn.rotary_emb.inv_freq" in name:
+                continue
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            # shuffle for broadcasting of gqa in ggml_mul_mat
+            if new_name.endswith("attn_q.weight"):
+                data_torch = self.shuffle_attn_q_weight(data_torch)
+            elif new_name.endswith("attn_output.weight"):
+                data_torch = self.shuffle_attn_output_weight(data_torch)
+
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
+
 ###### CONVERSION LOGIC ######
 
 
 def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
+    parser = argparse.ArgumentParser(
+        description="Convert a huggingface model to a GGML compatible file")
     parser.add_argument(
         "--vocab-only", action="store_true",
         help="extract only the vocab",

gguf-py/gguf/constants.py (+17)

@@ -96,6 +96,7 @@ class MODEL_ARCH(IntEnum):
     STABLELM = auto()
     QWEN     = auto()
     PHI2     = auto()
+    PLAMO    = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -142,6 +143,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.STABLELM: "stablelm",
     MODEL_ARCH.QWEN:     "qwen",
     MODEL_ARCH.PHI2:     "phi2",
+    MODEL_ARCH.PLAMO:    "plamo",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -349,6 +351,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PLAMO: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.GPT2: [
         # TODO
     ],
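A small sketch for inspecting the new registration from Python. It assumes the two dicts touched above are named `MODEL_ARCH_NAMES` and `MODEL_TENSORS` (their names sit outside the visible hunks) and that `TENSOR_NAMES` holds `blk.{bid}.*` name templates, so treat it as illustrative rather than exact:

```python
# Sketch: print the architecture string and GGUF-side tensor names for PLaMo.
# Assumed dict names: MODEL_ARCH_NAMES, MODEL_TENSORS (not shown in the hunks).
import gguf

print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.PLAMO])  # expected: "plamo"

for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.PLAMO]:
    # TENSOR_NAMES maps each MODEL_TENSOR to a template such as "blk.{bid}.attn_q"
    print(gguf.TENSOR_NAMES[t].format(bid=0))
```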

gguf-py/gguf/tensor_mapping.py (+23, −14)

@@ -79,6 +79,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
             "model.layers.{bid}.ln1",                               # yi
             "transformer.h.{bid}.ln",                               # phi2
+            "model.layers.layers.{bid}.norm",                       # plamo
         ),
 
         # Attention norm 2
@@ -99,26 +100,29 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",       # llama-hf
-            "layers.{bid}.attention.wq",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.query",  # bert
-            "transformer.h.{bid}.attn.q_proj",           # gpt-j
+            "model.layers.{bid}.self_attn.q_proj",        # llama-hf
+            "layers.{bid}.attention.wq",                  # llama-pth
+            "encoder.layer.{bid}.attention.self.query",   # bert
+            "transformer.h.{bid}.attn.q_proj",            # gpt-j
+            "model.layers.layers.{bid}.self_attn.q_proj", # plamo
         ),
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",     # llama-hf
-            "layers.{bid}.attention.wk",               # llama-pth
-            "encoder.layer.{bid}.attention.self.key",  # bert
-            "transformer.h.{bid}.attn.k_proj",         # gpt-j
+            "model.layers.{bid}.self_attn.k_proj",        # llama-hf
+            "layers.{bid}.attention.wk",                  # llama-pth
+            "encoder.layer.{bid}.attention.self.key",     # bert
+            "transformer.h.{bid}.attn.k_proj",            # gpt-j
+            "model.layers.layers.{bid}.self_attn.k_proj", # plamo
         ),
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",       # llama-hf
-            "layers.{bid}.attention.wv",                 # llama-pth
-            "encoder.layer.{bid}.attention.self.value",  # bert
-            "transformer.h.{bid}.attn.v_proj",           # gpt-j
+            "model.layers.{bid}.self_attn.v_proj",        # llama-hf
+            "layers.{bid}.attention.wv",                  # llama-pth
+            "encoder.layer.{bid}.attention.self.value",   # bert
+            "transformer.h.{bid}.attn.v_proj",            # gpt-j
+            "model.layers.layers.{bid}.self_attn.v_proj", # plamo
         ),
 
         # Attention output
@@ -134,12 +138,14 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.out_proj",                         # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense",  # persimmon
             "transformer.h.{bid}.mixer.out_proj",                        # phi2
+            "model.layers.layers.{bid}.self_attn.o_proj",                # plamo
         ),
 
         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
-            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",  # llama-hf
-            "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
+            "model.layers.{bid}.self_attn.rotary_emb.inv_freq",        # llama-hf
+            "layers.{bid}.attention.inner_attention.rope.freqs",       # llama-pth
+            "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
         ),
 
         # Feed-forward norm
@@ -174,6 +180,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
             "transformer.h.{bid}.mlp.w1",                             # qwen
             "transformer.h.{bid}.mlp.fc1",                            # phi2
+            "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -186,6 +193,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.gate_proj",  # llama-hf refact
             "layers.{bid}.feed_forward.w1",      # llama-pth
             "transformer.h.{bid}.mlp.w2",        # qwen
+            "model.layers.layers.{bid}.mlp.gate_proj", # plamo
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -206,6 +214,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
             "transformer.h.{bid}.mlp.fc2",                            # phi2
+            "model.layers.layers.{bid}.mlp.down_proj",                # plamo
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
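These entries are what `write_tensors` in the converter resolves through `gguf.get_tensor_name_map(...)` and `get_name(..., try_suffixes=(".weight", ".bias"))`, as shown in the convert-hf-to-gguf.py hunk above. A usage sketch; the block count of 40 and the exact mapped name are assumptions about PLaMo-13B rather than facts stated in this diff:

```python
# Sketch of the name lookup the convert script performs for one PLaMo tensor.
import gguf

block_count = 40  # assumed num_hidden_layers for PLaMo-13B
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.PLAMO, block_count)

hf_name = "model.layers.layers.0.self_attn.q_proj.weight"
gguf_name = tmap.get_name(hf_name, try_suffixes=(".weight", ".bias"))
print(gguf_name)  # expected to be the block-0 attn_q name, e.g. "blk.0.attn_q.weight"
```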
