Commit 8173f62

Author: Judd
Message: support CodeGeeX4
Parent: 493846f

File tree: 7 files changed (+107, -43 lines)


README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@ pure C++ implementation based on [@ggerganov](https://github.com/ggerganov)'s [g
 
 **What's New:**
 
+* 2024-07-05: CodeGeeX4
 * 2024-07-04: InternLM 2.5 with tool calling
 * 2024-07-03: Phi3 mini (June 2024 Update)
 * 2024-07-01: LLM-Compiler
```

convert.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -47,6 +47,7 @@ class ModelType(Enum):
     CODEGEEX2 = 4
     CharacterGLM = 5
     CHATGLM4 = 6
+    CODEGEEX4 = 7
 
     InternLM = 0x100
     InternLM2 = 0x101
@@ -3392,7 +3393,7 @@ def main():
     vocab_dir = Path(args.model_name_or_path) if args.vocab_dir == '' else Path(args.vocab_dir)
     tokenizer_model_file_exists = False
 
-    if (config._name_or_path == 'THUDM/glm-4-9b-chat') or (config._name_or_path == 'THUDM/glm4-9b-chat'):
+    if config._name_or_path in ['THUDM/glm-4-9b-chat', 'THUDM/glm4-9b-chat', 'THUDM/codegeex4-all-9b']:
         vocab = load_vocab_from_tiktok_mergeable_ranks(vocab_dir / 'tokenizer.model')
     else:
         tokenizer_model_file_exists = (vocab_dir / 'tokenizer.model').exists()
@@ -3429,6 +3430,9 @@ def main():
             ChatGLM2Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
         else:
             ChatGLMConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'codegeex4':
+        ChatGLM4Converter.MODEL_TYPE = ModelType.CODEGEEX4
+        ChatGLM4Converter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'characterglm':
         CharacterGLMConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'InternLMForCausalLM':
```
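With this in place, conversion presumably follows the project's usual flow, e.g. something like `python convert.py -i THUDM/codegeex4-all-9b -o codegeex4.bin -t q8_0 -a CodeGeeX4` (the `-a CodeGeeX4` architecture name comes from docs/models.md below; the remaining flag spellings are assumptions, not part of this commit).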

docs/models.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -38,6 +38,8 @@
     Note: Use additional key-value pair arguments to specify characters, `--kv user_name "..." bot_name "..." user_info "..." bot_info "..."`.
 
 * [x] GLM-4: [Chat-9B-128k](https://huggingface.co/THUDM/glm-4-9b-chat), [Chat-9B-1M](https://huggingface.co/THUDM/glm-4-9b-chat-1m)
+* [x] CodeGeeX4: [9B](https://huggingface.co/THUDM/codegeex4-all-9b) (`-a CodeGeeX4`)
+
 
 * InternLM (`InternLMForCausalLM`, `InternLM2ForCausalLM`)
     * [x] v1: [Chat-7B](https://huggingface.co/internlm/internlm-chat-7b), [Chat-7B v1.1](https://huggingface.co/internlm/internlm-chat-7b-v1_1), [Chat-20B](https://huggingface.co/internlm/internlm-chat-20b)
```

models/chatglm.cpp

Lines changed: 10 additions & 2 deletions

```diff
@@ -587,7 +587,15 @@ class GLMInterceptor : public ChunkInterceptor
         }
 
         if (find_meta)
+        {
             oss << chunk;
+            if (oss.str().find(' ') != std::string::npos)
+            {
+                streamer->put_chunk(true, oss.str());
+                oss.str("");
+                find_meta = false;
+            }
+        }
         else
             streamer->put_chunk(first, chunk);
     }
@@ -613,8 +621,8 @@ class ConditionalGeneration : public v2::ConditionalGeneration
 {
 public:
     ConditionalGeneration() = default;
-    ConditionalGeneration(const Config &config)
-        : v2::ConditionalGeneration(config, MODEL_TYPE_GLM4)
+    ConditionalGeneration(const Config &config, ModelType type = MODEL_TYPE_GLM4)
+        : v2::ConditionalGeneration(config, type)
     {
         for (int i = 0; i < config.num_hidden_layers; i++)
         {
```
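The `GLMInterceptor` change above makes the interceptor buffer chunks only until a space appears: tool-call metadata markers contain no spaces, so a space proves the buffered text is ordinary output, which is then flushed in one piece. A minimal self-contained sketch of that pattern (`MetaBuffer` and `emit` are illustrative names, not the project's API):

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

class MetaBuffer
{
public:
    void put_chunk(const std::string &chunk)
    {
        if (find_meta)
        {
            oss << chunk;
            // a space means this cannot be a metadata marker:
            // flush the accumulated text and stop intercepting
            if (oss.str().find(' ') != std::string::npos)
            {
                emit(oss.str());
                oss.str("");
                find_meta = false;
            }
        }
        else
            emit(chunk); // pass-through once metadata is ruled out
    }

private:
    void emit(const std::string &s) { std::cout << s; }

    std::ostringstream oss;
    bool find_meta = true;
};

int main()
{
    MetaBuffer buf;
    std::vector<std::string> chunks = {"Hello", ",", " wor", "ld"};
    for (const auto &c : chunks)
        buf.put_chunk(c);
    std::cout << '\n'; // prints "Hello, world"
}
```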

models/codegeex.cpp

Lines changed: 84 additions & 0 deletions

New file:

```cpp
namespace v2
{
struct Config : public glm::v2::Config
{
};

class ChatHistoryEncoder : public BaseHistoryEncoder
{
public:
    void do_append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
};

static ChatHistoryEncoder _chat_encoder;

class Tokenizer : public glm::v2::Tokenizer
{
public:
    Tokenizer(const Config &config) : glm::v2::Tokenizer::Tokenizer(config, &_chat_encoder)
    {
        sys_prompt = "# language: Python";
    }
};

class ConditionalGeneration : public glm::v2::ConditionalGeneration
{
public:
    ConditionalGeneration() = default;
    ConditionalGeneration(const Config &config)
        : glm::v2::ConditionalGeneration(config, MODEL_TYPE_CODEGEEX2)
    {
    }
};

void ChatHistoryEncoder::do_append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
{
    std::string combined = tokenizer->get_system_prompt() + "\n" + user + "\n";
    tokenizer->encode(combined, ids);
}
}

namespace v4
{
typedef glm::v4::Config Config;

class Tokenizer : public glm::v4::Tokenizer
{
public:
    Tokenizer(const Config &config) : glm::v4::Tokenizer(config)
    {}

    size_t load(tokenizer::DataReader *buffer, int n_vocab) override
    {
        size_t r = glm::v4::Tokenizer::load(buffer, n_vocab);
        int special_id = observation_token_id + 5;
        code_prefix_token_id = special_id++;
        code_middle_token_id = special_id++;
        code_suffix_token_id = special_id++;
        cursor_token_id = special_id++;
        tp->AddAddedToken("<|code_prefix|>", code_prefix_token_id);
        tp->AddAddedToken("<|code_middle|>", code_middle_token_id);
        tp->AddAddedToken("<|code_suffix|>", code_suffix_token_id);
        tp->AddAddedToken("<|cursor|>", cursor_token_id);
        return r;
    }

public:
    int code_prefix_token_id;
    int code_middle_token_id;
    int code_suffix_token_id;
    int cursor_token_id;
};

class ConditionalGeneration : public glm::v4::ConditionalGeneration
{
public:
    ConditionalGeneration(const Config &config)
        : glm::v4::ConditionalGeneration(config, MODEL_TYPE_CODEGEEX4)
    {
    }

    // FIXME: this model does not actually seem to support tool calling
    // https://github.com/THUDM/CodeGeeX4/issues/8
    ChunkInterceptor *get_interceptor(void) override { return nullptr; }
};
}
```
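A note on the four added tokens: `<|code_prefix|>`, `<|code_middle|>`, `<|code_suffix|>`, and `<|cursor|>` are CodeGeeX4's fill-in-the-middle (FIM) special tokens, and this commit only registers their ids. As a rough, assumption-laden sketch (the token ordering and all names below are illustrative, not taken from this commit; consult the CodeGeeX4 documentation for the real template), an infilling prompt could be assembled like so:

```cpp
#include <string>
#include <vector>

// Hypothetical stand-in for the pieces the real code gets from glm::v4::Tokenizer.
struct FimTokenizer
{
    int code_prefix_token_id;
    int code_middle_token_id;
    int code_suffix_token_id;

    void encode(const std::string &text, std::vector<int> &ids) const
    {
        for (char c : text)
            ids.push_back(static_cast<int>(c)); // toy stand-in for real BPE encoding
    }

    std::vector<int> build_fim_prompt(const std::string &prefix,
                                      const std::string &suffix) const
    {
        std::vector<int> ids;
        ids.push_back(code_prefix_token_id);
        encode(prefix, ids);
        ids.push_back(code_suffix_token_id);
        encode(suffix, ids);
        ids.push_back(code_middle_token_id); // generation then fills in the middle
        return ids;
    }
};
```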

models/codegeex_v2.cpp

Lines changed: 0 additions & 36 deletions
This file was deleted.

src/models.cpp

Lines changed: 5 additions & 4 deletions

```diff
@@ -84,6 +84,7 @@ namespace chatllm
         MODEL_TYPE_CODEGEEX2 = 4,
         MODEL_TYPE_CHARACTERGLM = 5,
         MODEL_TYPE_GLM4 = 6,
+        MODEL_TYPE_CODEGEEX4 = 7,
 
         MODEL_TYPE_INTERNLM = 0x100,
         MODEL_TYPE_INTERNLM2= 0x101, // extended model, supporting 7B & 20B
@@ -212,6 +213,8 @@ namespace chatllm
             return "GLM-4";
         case MODEL_TYPE_CODEGEEX2:
             return "CodeGeeX2";
+        case MODEL_TYPE_CODEGEEX4:
+            return "CodeGeeX4";
         case MODEL_TYPE_CHARACTERGLM:
             return "CharacterGLM";
         case MODEL_TYPE_INTERNLM:
@@ -1146,10 +1149,7 @@ namespace chatllm
 
     namespace codegeex
     {
-        namespace v2
-        {
-        #include "../models/codegeex_v2.cpp"
-        }
+        #include "../models/codegeex.cpp"
    }
 
    namespace internlm
@@ -1504,6 +1504,7 @@ namespace chatllm
            CASE(CODEGEEX2, codegeex::v2, 1) \
            CASE(CHARACTERGLM, characterglm, 1) \
            CASE(GLM4, glm::v4, 1) \
+           CASE(CODEGEEX4, codegeex::v4, 1) \
            \
            CASE(INTERNLM, internlm::v1, 1) \
            CASE(INTERNLM2, internlm::v2, 1) \
```
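Taken together, registering a model in src/models.cpp follows one fixed pattern: a `MODEL_TYPE_*` enum value, a display name in the `switch`, an `#include` of the model file into its namespace, and a `CASE(...)` entry binding the enum to that namespace. A simplified sketch of the dispatch this macro list expands to (the real `CASE` also carries a version argument and passes model config; this reduction is an assumption):

```cpp
#include <memory>
#include <stdexcept>

enum ModelType { MODEL_TYPE_GLM4 = 6, MODEL_TYPE_CODEGEEX4 = 7 };

struct Model { virtual ~Model() = default; };

// Each model lives in its own namespace, mirroring the #include-into-namespace scheme.
namespace glm::v4      { struct ConditionalGeneration : Model {}; }
namespace codegeex::v4 { struct ConditionalGeneration : Model {}; }

std::unique_ptr<Model> create(ModelType type)
{
    switch (type)
    {
// Each CASE line maps an enum value to the namespace holding the model's classes.
#define CASE(TYPE, ns) \
    case MODEL_TYPE_##TYPE: return std::make_unique<ns::ConditionalGeneration>();
        CASE(GLM4, glm::v4)
        CASE(CODEGEEX4, codegeex::v4)
#undef CASE
    default:
        throw std::invalid_argument("unknown model type");
    }
}
```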
