
Commit ce32060

llama : support InternLM2 (ggml-org#5184)
* support InternLM2 inference
* add add_space_prefix KV pair
1 parent 1cfb537 commit ce32060

5 files changed: +387, -5 lines
convert-hf-to-gguf.py (+152)
@@ -203,6 +203,8 @@ def from_model_architecture(model_architecture):
             return CodeShellModel
         if model_architecture == "OrionForCausalLM":
             return OrionModel
+        if model_architecture == "InternLM2ForCausalLM":
+            return InternLM2Model
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -254,6 +256,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.CODESHELL
         if arch == "OrionForCausalLM":
             return gguf.MODEL_ARCH.ORION
+        if arch == "InternLM2ForCausalLM":
+            return gguf.MODEL_ARCH.INTERNLM2
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
@@ -1344,6 +1348,154 @@ def write_tensors(self):
             self.gguf_writer.add_tensor("output.weight", data)
             print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")
 
+
+class InternLM2Model(Model):
+    def set_vocab(self):
+        # (TODO): Is there a better way?
+        # Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
+        # \x00 specially and convert it into an emoji character to prevent it from being mistakenly
+        # recognized as an empty string in C++.
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        tokens: list[bytes] = []
+        scores: list[float] = []
+        toktypes: list[int] = []
+
+        if not tokenizer_path.is_file():
+            print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
+            sys.exit(1)
+
+        sentencepiece_model = model.ModelProto()
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+
+        tokenizer = SentencePieceProcessor(str(tokenizer_path))
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        for token_id in range(vocab_size):
+            piece = tokenizer.id_to_piece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.get_score(token_id)
+            if text == b"\x00":
+                # (TODO): fixme
+                # Hack here and replace the \x00 characters.
+                print(f"InternLM2 convert token '{text}' to '🐉'!")
+                text = "🐉"
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.is_unknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.is_control(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.is_unused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.is_byte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                added_tokens_json = json.load(f)
+
+            for key in added_tokens_json:
+                tokens.append(key.encode("utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
+
+        self.gguf_writer.add_tokenizer_model("llama")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name("InternLM2")
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+
+    def post_write_tensors(self, tensor_map, name, data_torch):
+        old_dtype = data_torch.dtype
+
+        # convert any unsupported data types to float32
+        if data_torch.dtype not in (torch.float16, torch.float32):
+            data_torch = data_torch.to(torch.float32)
+
+        data = data_torch.squeeze().numpy()
+
+        # map tensor names
+        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+        if new_name is None:
+            print(f"Can not map tensor {name!r}")
+            sys.exit()
+
+        n_dims = len(data.shape)
+        data_dtype = data.dtype
+
+        # if f32 desired, convert any float16 to float32
+        if self.ftype == 0 and data_dtype == np.float16:
+            data = data.astype(np.float32)
+
+        # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+        if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+            data = data.astype(np.float32)
+
+        # if f16 desired, convert any float32 2-dim weight tensors to float16
+        if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+            data = data.astype(np.float16)
+
+        print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+        self.gguf_writer.add_tensor(new_name, data)
+
+    def write_tensors(self):
+        from einops import rearrange
+
+        num_heads = self.hparams.get("num_attention_heads")
+        num_kv_heads = self.hparams.get("num_key_value_heads")
+        hidden_size = self.hparams.get("hidden_size")
+        q_per_kv = num_heads // num_kv_heads
+        head_dim = hidden_size // num_heads
+        num_groups = num_heads // q_per_kv
+
+        block_count = self.hparams["num_hidden_layers"]
+        model_kv = dict(self.get_tensors())
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
+        for name, data_torch in model_kv.items():
+            # we don't need these
+            if name.endswith(".rotary_emb.inv_freq"):
+                continue
+
+            if re.match(qkv_pattern, name):
+                bid = re.findall(qkv_pattern, name)[0]
+                qkv = data_torch
+                qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
+                q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
+                q = rearrange(q, " o g n i -> o (g n i)").T
+                k = rearrange(k, " o g n i -> o (g n i)").T
+                v = rearrange(v, " o g n i -> o (g n i)").T
+                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
+                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
+                self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
+            else:
+                self.post_write_tensors(tensor_map, name, data_torch)
+
+
 ###### CONVERSION LOGIC ######
 

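Note (editor, not part of the commit): the wqkv handling in InternLM2Model.write_tensors packs q_per_kv query heads plus one key and one value head per KV group into a single projection, then slices them apart with einops. The following standalone sketch reproduces that split on a dummy tensor with made-up head counts and sizes, so the resulting Q/K/V shapes are easy to check.

# Illustrative only: toy dimensions, not taken from any real InternLM2 config.
import torch
from einops import rearrange

num_heads = 8
num_kv_heads = 2
head_dim = 4
hidden_size = num_heads * head_dim        # 32
q_per_kv = num_heads // num_kv_heads      # 4 query heads share one K/V head
num_groups = num_heads // q_per_kv        # equals num_kv_heads

# Packed projection: each group holds q_per_kv Q heads followed by one K and one V head.
wqkv = torch.randn(num_groups * (q_per_kv + 2) * head_dim, hidden_size)

qkv = rearrange(wqkv.T, "o (g n i) -> o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
q = rearrange(qkv[..., :q_per_kv, :], "o g n i -> o (g n i)").T
k = rearrange(qkv[..., q_per_kv:q_per_kv + 1, :], "o g n i -> o (g n i)").T
v = rearrange(qkv[..., q_per_kv + 1:, :], "o g n i -> o (g n i)").T

print(q.shape)  # (32, 32): num_heads * head_dim rows, written as model.layers.{bid}.attention.wq.weight
print(k.shape)  # (8, 32):  num_kv_heads * head_dim rows
print(v.shape)  # (8, 32)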
gguf-py/gguf/constants.py (+18)
@@ -72,6 +72,7 @@ class Tokenizer:
         PAD_ID = "tokenizer.ggml.padding_token_id"
         ADD_BOS = "tokenizer.ggml.add_bos_token"
         ADD_EOS = "tokenizer.ggml.add_eos_token"
+        ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
         HF_JSON = "tokenizer.huggingface.json"
         RWKV = "tokenizer.rwkv.world"
         CHAT_TEMPLATE = "tokenizer.chat_template"
@@ -102,6 +103,7 @@ class MODEL_ARCH(IntEnum):
     PLAMO = auto()
     CODESHELL = auto()
     ORION = auto()
+    INTERNLM2 = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -153,6 +155,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.PLAMO: "plamo",
     MODEL_ARCH.CODESHELL: "codeshell",
     MODEL_ARCH.ORION: "orion",
+    MODEL_ARCH.INTERNLM2: "internlm2",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -446,6 +449,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.INTERNLM2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

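Note (editor, illustrative): the new enum value and table entries can be sanity-checked directly from gguf-py's constants module; the loop below simply prints the GGUF-side names registered for the architecture.

# Quick check of the new INTERNLM2 entries; imports follow gguf-py's package layout.
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES, MODEL_TENSORS, TENSOR_NAMES

arch = MODEL_ARCH.INTERNLM2
print(MODEL_ARCH_NAMES[arch])       # "internlm2"
for tensor in MODEL_TENSORS[arch]:
    print(TENSOR_NAMES[tensor])     # "token_embd", "output_norm", "output", "blk.{bid}.attn_norm", ...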
gguf-py/gguf/gguf_writer.py (+3)
@@ -411,6 +411,9 @@ def add_add_bos_token(self, value: bool) -> None:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)
 
+    def add_add_space_prefix(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
+
     def add_chat_template(self, value: str) -> None:
         self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

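Note (editor): a minimal sketch of the new setter in use, mirroring how the convert script drives GGUFWriter; the output filename and the boolean value are illustrative and not taken from the commit.

import gguf

# Write a tiny GGUF file that carries only tokenizer metadata, including the new KV pair.
writer = gguf.GGUFWriter("internlm2-metadata-only.gguf", gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.INTERNLM2])
writer.add_tokenizer_model("llama")
writer.add_add_space_prefix(False)   # stored under "tokenizer.ggml.add_space_prefix"

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()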
gguf-py/gguf/tensor_mapping.py (+12, -2)
@@ -19,6 +19,7 @@ class TensorNameMap:
             "language_model.embedding.word_embeddings", # persimmon
             "wte", # gpt2
             "transformer.embd.wte", # phi2
+            "model.tok_embeddings", # internlm2
         ),
 
         # Token type embeddings
@@ -42,7 +43,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
             "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
-            "output", # llama-pth bloom
+            "output", # llama-pth bloom internlm2
             "word_embeddings_for_head", # persimmon
             "lm_head.linear", # phi2
         ),
@@ -51,7 +52,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
             "transformer.ln_f", # gpt2 gpt-j falcon
-            "model.norm", # llama-hf baichuan
+            "model.norm", # llama-hf baichuan internlm2
             "norm", # llama-pth
             "embeddings.LayerNorm", # bert
             "transformer.norm_f", # mpt
@@ -84,6 +85,7 @@ class TensorNameMap:
             "h.{bid}.ln_1", # gpt2
             "transformer.h.{bid}.ln", # phi2
             "model.layers.layers.{bid}.norm", # plamo
+            "model.layers.{bid}.attention_norm", # internlm2
         ),
 
         # Attention norm 2
@@ -111,6 +113,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.self.query", # bert
             "transformer.h.{bid}.attn.q_proj", # gpt-j
             "model.layers.layers.{bid}.self_attn.q_proj", # plamo
+            "model.layers.{bid}.attention.wq" # internlm2
         ),
 
         # Attention key
@@ -120,6 +123,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.self.key", # bert
             "transformer.h.{bid}.attn.k_proj", # gpt-j
             "model.layers.layers.{bid}.self_attn.k_proj", # plamo
+            "model.layers.{bid}.attention.wk" # internlm2
         ),
 
         # Attention value
@@ -129,6 +133,7 @@ class TensorNameMap:
             "encoder.layer.{bid}.attention.self.value", # bert
             "transformer.h.{bid}.attn.v_proj", # gpt-j
             "model.layers.layers.{bid}.self_attn.v_proj", # plamo
+            "model.layers.{bid}.attention.wv" # internlm2
         ),
 
         # Attention output
@@ -147,6 +152,7 @@ class TensorNameMap:
             "h.{bid}.attn.c_proj", # gpt2
             "transformer.h.{bid}.mixer.out_proj", # phi2
             "model.layers.layers.{bid}.self_attn.o_proj", # plamo
+            "model.layers.{bid}.attention.wo", # internlm2
         ),
 
         # Rotary embeddings
@@ -169,6 +175,7 @@ class TensorNameMap:
             "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
             "model.layers.{bid}.ln2", # yi
             "h.{bid}.ln_2", # gpt2
+            "model.layers.{bid}.ffn_norm", # internlm2
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -194,6 +201,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.fc1", # phi2
             "model.layers.{bid}.mlp.fc1", # phi2
             "model.layers.layers.{bid}.mlp.up_proj", # plamo
+            "model.layers.{bid}.feed_forward.w3", # internlm2
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -212,6 +220,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.w1", # llama-pth
             "transformer.h.{bid}.mlp.w2", # qwen
             "model.layers.layers.{bid}.mlp.gate_proj", # plamo
+            "model.layers.{bid}.feed_forward.w1", # internlm2
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -236,6 +245,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.fc2", # phi2
             "model.layers.{bid}.mlp.fc2", # phi2
             "model.layers.layers.{bid}.mlp.down_proj", # plamo
+            "model.layers.{bid}.feed_forward.w2", # internlm2
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (

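Note (editor): to illustrate how these mapping entries are consumed, the short sketch below resolves a few of the HF-style names emitted by InternLM2Model.write_tensors into their GGUF names; the block count and tensor names are examples only.

import gguf

# 32 blocks is just an example; the converter passes num_hidden_layers from config.json.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.INTERNLM2, 32)

print(tmap.get_name("model.layers.0.attention.wq.weight", try_suffixes=(".weight", ".bias")))
# -> "blk.0.attn_q.weight"
print(tmap.get_name("model.layers.0.feed_forward.w2.weight", try_suffixes=(".weight", ".bias")))
# -> "blk.0.ffn_down.weight"
print(tmap.get_name("model.tok_embeddings.weight", try_suffixes=(".weight", ".bias")))
# -> "token_embd.weight"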