
Commit 474d0e6

llama : add infill sampler
ggml-ci
1 parent 25f3b4d commit 474d0e6

File tree

11 files changed: +294 -64 lines


common/arg.cpp

+14
@@ -947,6 +947,20 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.tfs_z = std::stof(value);
         }
     ).set_sparam());
+    add_opt(llama_arg(
+        {"--infill-p"}, "N",
+        string_format("infill p threshold (default: %.1f)", (double)params.sparams.infill_p),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.infill_p = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(llama_arg(
+        {"--infill-p-eog"}, "N",
+        string_format("infill p_eog threshold (default: %.1f)", (double)params.sparams.infill_p_eog),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.infill_p_eog = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(llama_arg(
         {"--typical"}, "N",
         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),

common/common.h

+4 -1
@@ -90,6 +90,7 @@ enum gpt_sampler_type {
     GPT_SAMPLER_TYPE_TFS_Z       = 4,
     GPT_SAMPLER_TYPE_TYPICAL_P   = 5,
     GPT_SAMPLER_TYPE_TEMPERATURE = 6,
+    GPT_SAMPLER_TYPE_INFILL      = 7,
 };
 
 // dimensionality reduction methods, used by cvector-generator

@@ -113,6 +114,8 @@ struct gpt_sampler_params {
     float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float   dynatemp_range    = 0.00f; // 0.0 = disabled
     float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    float   infill_p          = 0.80f;
+    float   infill_p_eog      = 0.01f;
     int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   penalty_repeat    = 1.00f; // 1.0 = disabled
     float   penalty_freq      = 0.00f; // 0.0 = disabled

@@ -130,7 +133,7 @@ struct gpt_sampler_params {
         GPT_SAMPLER_TYPE_TYPICAL_P,
         GPT_SAMPLER_TYPE_TOP_P,
         GPT_SAMPLER_TYPE_MIN_P,
-        GPT_SAMPLER_TYPE_TEMPERATURE
+        GPT_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling

common/sampling.cpp

+8 -1
@@ -193,6 +193,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             case GPT_SAMPLER_TYPE_TEMPERATURE:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
+            case GPT_SAMPLER_TYPE_INFILL:
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model, params.infill_p, params.infill_p_eog));
+                break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
         }

@@ -372,6 +375,7 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
         case GPT_SAMPLER_TYPE_TOP_P:       return 'p';
         case GPT_SAMPLER_TYPE_MIN_P:       return 'm';
         case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case GPT_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }

@@ -384,6 +388,7 @@ std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
         case GPT_SAMPLER_TYPE_TOP_P:       return "top_p";
         case GPT_SAMPLER_TYPE_MIN_P:       return "min_p";
         case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case GPT_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }

@@ -396,6 +401,7 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std
         { "min_p",       GPT_SAMPLER_TYPE_MIN_P },
         { "tfs_z",       GPT_SAMPLER_TYPE_TFS_Z },
         { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
+        { "infill",      GPT_SAMPLER_TYPE_INFILL }
     };
 
     // since samplers names are written multiple ways

@@ -441,7 +447,8 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & c
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P),   GPT_SAMPLER_TYPE_TYPICAL_P },
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P),       GPT_SAMPLER_TYPE_TOP_P },
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P),       GPT_SAMPLER_TYPE_MIN_P },
-        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_INFILL),      GPT_SAMPLER_TYPE_INFILL },
     };
 
     std::vector<gpt_sampler_type> samplers;
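
Taken together, the common/ changes let callers request the new sampler through the usual gpt_sampler_params path. A minimal sketch, assuming the pre-existing top_k field and the gpt_sampler_sample / gpt_sampler_free helpers from common/sampling.h (which are not part of this diff):

// sketch: enable the infill sampler via the common sampling layer
#include "common.h"
#include "sampling.h"

void sample_with_infill(const llama_model * model, llama_context * ctx) {
    gpt_sampler_params sparams;

    sparams.top_k        = 5;      // assumed pre-existing field
    sparams.infill_p     = 0.20f;  // new field added by this commit
    sparams.infill_p_eog = 0.001f; // new field added by this commit

    // run top-k first, then the infill sampler (equivalent to samplers=["top_k","infill"])
    sparams.samplers = { GPT_SAMPLER_TYPE_TOP_K, GPT_SAMPLER_TYPE_INFILL };

    gpt_sampler * smpl = gpt_sampler_init(model, sparams);

    const llama_token id = gpt_sampler_sample(smpl, ctx, -1); // assumed helper signature
    (void) id;

    gpt_sampler_free(smpl);
}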

examples/llama.vim

+12 -10
@@ -11,14 +11,14 @@
 "
 
 let s:default_config = {
+    \ 'endpoint':     'http://127.0.0.1:8012/infill',
     \ 'prefix_lines': 32,
     \ 'suffix_lines': 32,
-    \ 'endpoint':     'http://127.0.0.1:8012/infill',
-    \ 'stop':         ["\n"],
-    \ 'n_predict':    64,
-    \ 'n_probs':      3,
-    \ 'temperature':  0.1
-    \}
+    \ 'n_predict':    64,
+    \ 'n_probs':      3,
+    \ 'temperature':  0.1,
+    \ 'stop':         ["\n"]
+    \ }
 
 let g:llama_config = get(g:, 'llama_config', s:default_config)
 

@@ -45,14 +45,16 @@ function! llama#fim() abort
         \ 'prompt':         "",
         \ 'input_prefix':   l:prefix,
         \ 'input_suffix':   l:suffix,
-       "\ 'stop':           g:llama_config.stop,
+       "\ 'stop':           g:llama_config.stop,
         \ 'n_predict':      g:llama_config.n_predict,
-       "\ 'n_probs':        g:llama_config.n_probs,
+       "\ 'n_probs':        g:llama_config.n_probs,
         \ 'penalty_last_n': 0,
         \ 'temperature':    g:llama_config.temperature,
-        \ 'top_k':          10,
+        \ 'top_k':          5,
+        \ 'infill_p':       0.20,
+        \ 'infill_p_eog':   0.001,
        \ 'stream':         v:false,
-        \ 'samplers':       ["top_k"]
+        \ 'samplers':       ["top_k", "infill"]
         \ })
 
     " request completion from the server

examples/server/server.cpp

+55 -49
@@ -889,6 +889,8 @@ struct server_context {
             slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
             slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
             slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
+            slot.sparams.infill_p          = json_value(data, "infill_p",          default_sparams.infill_p);
+            slot.sparams.infill_p_eog      = json_value(data, "infill_p_eog",      default_sparams.infill_p_eog);
             slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
             slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
             slot.sparams.penalty_last_n    = json_value(data, "repeat_last_n",     default_sparams.penalty_last_n);

@@ -1236,6 +1238,8 @@ struct server_context {
             {"min_p",             slot.sparams.min_p},
             {"tfs_z",             slot.sparams.tfs_z},
             {"typical_p",         slot.sparams.typ_p},
+            {"infill_p",          slot.sparams.infill_p},
+            {"infill_p_eog",      slot.sparams.infill_p_eog},
             {"repeat_last_n",     slot.sparams.penalty_last_n},
             {"repeat_penalty",    slot.sparams.penalty_repeat},
             {"presence_penalty",  slot.sparams.penalty_present},

@@ -1964,55 +1968,57 @@ struct server_context {
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_generation = 0;
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
-                    const bool add_bos = llama_add_bos_token(model);
-
-                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
-
-                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                    suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
-
-                    auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                    auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
-
-                    if (add_bos) {
-                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                    }
-
-                    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
-                    const llama_token middle_token = llama_token_fim_mid(model);
-                    if (middle_token >= 0) {
-                        embd_inp.push_back(middle_token);
-                    }
-
-                    prompt_tokens = embd_inp;
-                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    // require slot.prompt to be array of 2 strings
-                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                        slot.release();
-                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                        continue;
-                    }
-
-                    // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                    prompt_tokens.clear();
-                    prompt_tokens.push_back(llama_token_bos(model));
-                    {
-                        const auto part = tokenize(slot.prompt[0], false, false);
-                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                    }
-                    prompt_tokens.push_back(llama_token_eos(model));
-                    prompt_tokens.push_back(llama_token_sep(model));
-                    {
-                        const auto part = tokenize(slot.prompt[1], false, false);
-                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                    }
-                    prompt_tokens.push_back(llama_token_eos(model));
-                } else {
-                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                switch (slot.cmpl_type) {
+                    case SERVER_TASK_CMPL_TYPE_NORMAL:
+                    case SERVER_TASK_CMPL_TYPE_EMBEDDING:
+                        {
+                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                        } break;
+                    case SERVER_TASK_CMPL_TYPE_RERANK:
+                        {
+                            // require slot.prompt to be array of 2 strings
+                            if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                                SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                                slot.release();
+                                send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                                continue;
+                            }
+
+                            // prompt: [BOS]query[EOS][SEP]doc[EOS]
+                            prompt_tokens.clear();
+                            prompt_tokens.push_back(llama_token_bos(model));
+                            {
+                                const auto part = tokenize(slot.prompt[0], false, false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                            prompt_tokens.push_back(llama_token_sep(model));
+                            {
+                                const auto part = tokenize(slot.prompt[1], false, false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                        } break;
+                    case SERVER_TASK_CMPL_TYPE_INFILL:
+                        {
+                            auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
+                            auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+
+                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
+                            suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+
+                            auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
+                            auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+
+                            if (llama_add_bos_token(model)) {
+                                embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                            }
+
+                            embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+                            embd_inp.push_back(llama_token_fim_mid(model));
+
+                            prompt_tokens = std::move(embd_inp);
+                        } break;
                 }
 
                 slot.n_past = 0;
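
For the default (non-SPM) case, the infill branch above lays the prompt out as [BOS] <FIM_PRE> prefix <FIM_SUF> suffix <FIM_MID>, with generation continuing after the FIM_MID token. A standalone sketch of the same layout, assuming the std::string llama_tokenize overload from common.h (outside this diff):

// sketch of the FIM prompt layout built by the infill branch above
// (with params.spm_infill the prefix and suffix blocks are swapped)
static std::vector<llama_token> build_fim_prompt(const llama_model * model,
                                                 const std::string & prefix,
                                                 const std::string & suffix) {
    auto prefix_tokens = llama_tokenize(model, prefix, false, false); // assumed common.h overload
    auto suffix_tokens = llama_tokenize(model, suffix, false, false);

    prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
    suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

    std::vector<llama_token> out;
    if (llama_add_bos_token(model)) {
        out.push_back(llama_token_bos(model));                          // [BOS]
    }
    out.insert(out.end(), prefix_tokens.begin(), prefix_tokens.end());  // <FIM_PRE> prefix
    out.insert(out.end(), suffix_tokens.begin(), suffix_tokens.end());  // <FIM_SUF> suffix
    out.push_back(llama_token_fim_mid(model));                          // <FIM_MID> -> generation continues here

    return out;
}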

include/llama.h

+26
@@ -952,6 +952,12 @@ extern "C" {
                                int32_t   lstrip,
                                   bool   special);
 
+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+              const struct llama_model * model,
+                           llama_token   token0,
+                           llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.

@@ -1144,6 +1150,26 @@ extern "C" {
                              int32_t   n_logit_bias,
              const llama_logit_bias * logit_bias);
 
+    // 1. if there is a high-prob token (>= 0.9f) - pick it
+    // 2. if sum of EOG probs is larger than p_eog -> mask non-EOG tokens away
+    // 3. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(
+            const struct llama_model * model,
+                               float   p,
+                               float   p_eog);
 
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
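
The header comment documents what the sampler does but not how, and the role of the first threshold p (exposed as --infill-p) is not spelled out here. A rough, non-authoritative sketch of steps 1-3 against the public API, using llama_token_is_eog plus the llama_token_is_prefix helper added above; the real logic lives in src/llama-sampling.cpp and may differ:

// hypothetical illustration of steps 1-3 from the comment above; not the actual implementation
// cur_p is assumed to hold already-softmaxed probabilities in data[i].p
static void infill_filter_sketch(const struct llama_model * model, llama_token_data_array * cur_p, float p_eog) {
    // 1. if a single token is already very likely (>= 0.9), keep only that token
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (cur_p->data[i].p >= 0.9f) {
            cur_p->data[0] = cur_p->data[i];
            cur_p->size    = 1;
            return;
        }
    }

    // 2. if the EOG tokens together exceed p_eog, keep only the EOG tokens
    float sum_eog = 0.0f;
    for (size_t i = 0; i < cur_p->size; ++i) {
        if (llama_token_is_eog(model, cur_p->data[i].id)) {
            sum_eog += cur_p->data[i].p;
        }
    }
    if (sum_eog > p_eog) {
        size_t n = 0;
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (llama_token_is_eog(model, cur_p->data[i].id)) {
                cur_p->data[n++] = cur_p->data[i];
            }
        }
        cur_p->size = n;
        return;
    }

    // 3. fold each token's probability into a token that is a prefix of it,
    //    e.g. "hell" (0.2) and "hello" (0.1) are folded into "hel" (0.5 -> 0.8)
    for (size_t i = 0; i < cur_p->size; ++i) {
        for (size_t j = 0; j < cur_p->size; ++j) {
            if (i != j && llama_token_is_prefix(model, cur_p->data[i].id, cur_p->data[j].id)) {
                cur_p->data[i].p += cur_p->data[j].p;
                cur_p->data[j].p  = 0.0f;
            }
        }
    }
}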
