Skip to content

Commit 48efee1

Browse files
committed
llama : improve infill support
ggml-ci
1 parent fa42aa6 commit 48efee1

File tree

10 files changed

+522
-376
lines changed

10 files changed

+522
-376
lines changed

common/arg.cpp

+110-136
Large diffs are not rendered by default.

common/common.cpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <algorithm>
1414
#include <cinttypes>
15+
#include <climits>
1516
#include <cmath>
1617
#include <codecvt>
1718
#include <cstdarg>
@@ -23,10 +24,10 @@
2324
#include <regex>
2425
#include <sstream>
2526
#include <string>
27+
#include <thread>
2628
#include <unordered_map>
2729
#include <unordered_set>
2830
#include <vector>
29-
#include <thread>
3031

3132
#if defined(__APPLE__) && defined(__MACH__)
3233
#include <sys/types.h>
@@ -400,6 +401,21 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
400401
// String utils
401402
//
402403

404+
std::string string_format(const char * fmt, ...) {
405+
va_list ap;
406+
va_list ap2;
407+
va_start(ap, fmt);
408+
va_copy(ap2, ap);
409+
int size = vsnprintf(NULL, 0, fmt, ap);
410+
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
411+
std::vector<char> buf(size + 1);
412+
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
413+
GGML_ASSERT(size2 == size);
414+
va_end(ap2);
415+
va_end(ap);
416+
return std::string(buf.data(), size);
417+
}
418+
403419
std::vector<std::string> string_split(std::string input, char separator) {
404420
std::vector<std::string> parts;
405421
size_t separator_pos = input.find(separator);

common/common.h

+16-3
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,28 @@ void gpt_init();
349349

350350
std::string gpt_params_get_system_info(const gpt_params & params);
351351

352-
bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
353-
bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
354-
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
352+
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
353+
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
354+
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
355355
bool set_process_priority(enum ggml_sched_priority prio);
356356

357357
//
358358
// String utils
359359
//
360360

361+
#ifdef __GNUC__
362+
#ifdef __MINGW32__
363+
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
364+
#else
365+
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
366+
#endif
367+
#else
368+
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
369+
#endif
370+
371+
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
372+
std::string string_format(const char * fmt, ...);
373+
361374
std::vector<std::string> string_split(std::string input, char separator);
362375

363376
std::string string_strip(const std::string & str);

examples/infill/infill.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,11 @@ int main(int argc, char ** argv) {
205205
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
206206
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
207207

208-
GGML_ASSERT(llama_token_prefix(model) >= 0);
209-
GGML_ASSERT(llama_token_suffix(model) >= 0);
208+
GGML_ASSERT(llama_token_fim_pre(model) >= 0);
209+
GGML_ASSERT(llama_token_fim_suf(model) >= 0);
210210

211-
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
212-
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
211+
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
212+
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
213213

214214
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
215215
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
@@ -218,7 +218,7 @@ int main(int argc, char ** argv) {
218218
}
219219
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
220220

221-
const llama_token middle_token = llama_token_middle(model);
221+
const llama_token middle_token = llama_token_fim_mid(model);
222222
if (middle_token >= 0) {
223223
embd_inp.push_back(middle_token);
224224
}
@@ -508,8 +508,8 @@ int main(int argc, char ** argv) {
508508
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
509509
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
510510

511-
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
512-
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
511+
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
512+
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
513513

514514
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
515515
embd_end = params.spm_infill ? inp_pfx : inp_sfx;

examples/server/README.md

-2
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,6 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
525525

526526
`input_suffix`: Set the suffix of the code to infill.
527527

528-
It also accepts all the options of `/completion` except `stream` and `prompt`.
529-
530528
- **GET** `/props`: Return current server settings.
531529

532530
**Response format**

examples/server/server.cpp

+43-36
Original file line numberDiff line numberDiff line change
@@ -753,12 +753,7 @@ struct server_context {
753753
metrics.init();
754754
}
755755

756-
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
757-
// TODO: currently, we tokenize using special tokens by default
758-
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
759-
// but it's better compared to completely ignoring ChatML and other chat templates
760-
const bool TMP_FORCE_SPECIAL = true;
761-
756+
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
762757
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
763758
// or the first element of the json_prompt array is a string.
764759
std::vector<llama_token> prompt_tokens;
@@ -771,10 +766,10 @@ struct server_context {
771766

772767
std::vector<llama_token> p;
773768
if (first) {
774-
p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
769+
p = ::llama_tokenize(ctx, s, add_special, parse_special);
775770
first = false;
776771
} else {
777-
p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
772+
p = ::llama_tokenize(ctx, s, false, parse_special);
778773
}
779774

780775
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -788,7 +783,7 @@ struct server_context {
788783
}
789784
} else {
790785
auto s = json_prompt.template get<std::string>();
791-
prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
786+
prompt_tokens = ::llama_tokenize(ctx, s, add_special, parse_special);
792787
}
793788

794789
return prompt_tokens;
@@ -1220,7 +1215,7 @@ struct server_context {
12201215
slot.params.n_predict, n_ctx_train);
12211216
}
12221217

1223-
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
1218+
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
12241219

12251220
return slot.has_next_token; // continue
12261221
}
@@ -1488,9 +1483,8 @@ struct server_context {
14881483
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
14891484
data["index"] = 0;
14901485
create_task(data, false, nullptr);
1491-
}
1492-
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
1493-
else if (prompt.is_array()) {
1486+
} else if (prompt.is_array()) {
1487+
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
14941488
std::vector<json> prompts = prompt;
14951489
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
14961490
// prompts[0] is the question
@@ -1515,9 +1509,8 @@ struct server_context {
15151509
}
15161510
}
15171511
}
1518-
}
1519-
// invalid case
1520-
else {
1512+
} else {
1513+
// invalid case
15211514
throw std::runtime_error(error_msg);
15221515
}
15231516

@@ -1988,31 +1981,23 @@ struct server_context {
19881981

19891982
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
19901983
const bool add_bos = llama_add_bos_token(model);
1991-
bool suff_rm_leading_spc = true;
1992-
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
1993-
params.input_suffix.erase(0, 1);
1994-
suff_rm_leading_spc = false;
1995-
}
19961984

1997-
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
1998-
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
1985+
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
1986+
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
19991987

2000-
const int space_token = 29871; // TODO: this should not be hardcoded
2001-
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
2002-
suffix_tokens.erase(suffix_tokens.begin());
2003-
}
2004-
2005-
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
2006-
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
1988+
prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
1989+
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
20071990

20081991
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
20091992
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
1993+
20101994
if (add_bos) {
20111995
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
20121996
}
1997+
20131998
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
20141999

2015-
const llama_token middle_token = llama_token_middle(model);
2000+
const llama_token middle_token = llama_token_fim_mid(model);
20162001
if (middle_token >= 0) {
20172002
embd_inp.push_back(middle_token);
20182003
}
@@ -2031,25 +2016,30 @@ struct server_context {
20312016
prompt_tokens.clear();
20322017
prompt_tokens.push_back(llama_token_bos(model));
20332018
{
2034-
const auto part = tokenize(slot.prompt[0], false);
2019+
const auto part = tokenize(slot.prompt[0], false, false);
20352020
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
20362021
}
20372022
prompt_tokens.push_back(llama_token_eos(model));
20382023
prompt_tokens.push_back(llama_token_sep(model));
20392024
{
2040-
const auto part = tokenize(slot.prompt[1], false);
2025+
const auto part = tokenize(slot.prompt[1], false, false);
20412026
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
20422027
}
20432028
prompt_tokens.push_back(llama_token_eos(model));
20442029
} else {
2045-
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
2030+
prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
20462031
}
20472032

20482033
slot.n_past = 0;
20492034
slot.n_prompt_tokens = prompt_tokens.size();
20502035

20512036
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
20522037

2038+
// print prompt tokens:
2039+
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
2040+
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
2041+
}
2042+
20532043
// empty prompt passed -> release the slot and send empty response
20542044
if (prompt_tokens.empty()) {
20552045
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -2942,7 +2932,23 @@ int main(int argc, char ** argv) {
29422932
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
29432933
};
29442934

2945-
const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2935+
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2936+
std::string err;
2937+
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
2938+
err += "prefix token is missing. ";
2939+
}
2940+
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
2941+
err += "suffix token is missing. ";
2942+
}
2943+
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
2944+
err += "middle token is missing. ";
2945+
}
2946+
2947+
if (!err.empty()) {
2948+
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
2949+
return;
2950+
}
2951+
29462952
json data = json::parse(req.body);
29472953
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
29482954
};
@@ -3028,7 +3034,8 @@ int main(int argc, char ** argv) {
30283034
if (body.count("content") != 0) {
30293035
const bool add_special = json_value(body, "add_special", false);
30303036
const bool with_pieces = json_value(body, "with_pieces", false);
3031-
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
3037+
3038+
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
30323039

30333040
if (with_pieces) {
30343041
for (const auto& token : tokens) {

include/llama.h

+12-5
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ extern "C" {
896896
// Special tokens
897897
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
898898
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
899+
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
899900
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
900901
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
901902
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@@ -904,11 +905,17 @@ extern "C" {
904905
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
905906
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
906907

907-
// Codellama infill tokens
908-
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
909-
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
910-
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
911-
LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
908+
// infill tokens
909+
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
910+
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
911+
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
912+
913+
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
914+
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
915+
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
916+
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
917+
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
918+
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
912919

913920
//
914921
// Tokenization

src/llama-vocab.cpp

+31-7
Original file line numberDiff line numberDiff line change
@@ -1663,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
16631663
return vocab.special_eos_id;
16641664
}
16651665

1666+
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1667+
return vocab.special_eot_id;
1668+
}
1669+
1670+
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1671+
return vocab.special_eom_id;
1672+
}
1673+
16661674
llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
16671675
return vocab.special_cls_id;
16681676
}
@@ -1688,23 +1696,39 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
16881696
}
16891697

16901698
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
1691-
return vocab.special_prefix_id;
1699+
return vocab.special_fim_pre_id;
16921700
}
16931701

16941702
llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
1695-
return vocab.special_middle_id;
1703+
return vocab.special_fim_mid_id;
16961704
}
16971705

16981706
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
1699-
return vocab.special_suffix_id;
1707+
return vocab.special_fim_suf_id;
17001708
}
17011709

1702-
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1703-
return vocab.special_eot_id;
1710+
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
1711+
return vocab.special_fim_pre_id;
17041712
}
17051713

1706-
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1707-
return vocab.special_eom_id;
1714+
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
1715+
return vocab.special_fim_suf_id;
1716+
}
1717+
1718+
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
1719+
return vocab.special_fim_mid_id;
1720+
}
1721+
1722+
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
1723+
return vocab.special_fim_pad_id;
1724+
}
1725+
1726+
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
1727+
return vocab.special_fim_rep_id;
1728+
}
1729+
1730+
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
1731+
return vocab.special_fim_sep_id;
17081732
}
17091733

17101734
int32_t llama_tokenize_impl(

0 commit comments

Comments
 (0)