@@ -136,10 +136,6 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-    json extra_context;
 };
 
 struct server_slot {
@@ -169,6 +165,10 @@ struct server_slot {
 
     json prompt; // can be either a string, array of strings or array of token ids
 
+    json input_prefix;
+    json input_suffix;
+    json input_extra;
+
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
@@ -910,12 +910,12 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
-        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
-        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+        slot.input_prefix = json_value(data, "input_prefix", json());
+        slot.input_suffix = json_value(data, "input_suffix", json());
+        slot.input_extra  = json_value(data, "input_extra",  json());
 
-        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
-        for (const auto & chunk : slot.params.extra_context) {
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
+        for (const auto & chunk : slot.input_extra) {
             // { "text": string, "filename": string }
             if (!chunk.contains("text") || !chunk["text"].is_string()) {
                 send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
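
For reference, a request body that exercises these fields might look like the standalone sketch below. It is built with nlohmann::json (the `json` alias used by the server); the field names come from this diff, while the concrete values are made up for illustration.

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // illustrative infill request body; "prompt", "input_prefix", "input_suffix" and
    // "input_extra" are the fields read above via json_value(), the values are invented
    const json data = {
        {"prompt",       "int add(int a, int b) {"},   // "prompt" is now required for infill too (see the "get prompt" hunk below)
        {"input_prefix", "#include \"math.h\"\n\n"},   // code before the cursor
        {"input_suffix", "\n}\n"},                     // code after the cursor
        {"input_extra",  json::array({
            // each chunk follows the { "text": string, "filename": string } shape checked above
            {{"text", "int add(int a, int b);\n"}, {"filename", "math.h"}},
        })},
    };

    std::cout << data.dump(2) << std::endl;
    return 0;
}
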
@@ -932,7 +932,7 @@ struct server_context {
         }
 
         // get prompt
-        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
+        {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
                 send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -1958,6 +1958,8 @@ struct server_context {
                         } break;
                     case SERVER_TASK_CMPL_TYPE_INFILL:
                         {
+                            // TODO: optimize this block by reducing memory allocations and movement
+
                             // use FIM repo-level pattern:
                             // ref: https://arxiv.org/pdf/2409.12186
                             //
@@ -1968,10 +1970,11 @@ struct server_context {
                             // extra chunk 1
                             // ...
                             // [FIM_SEP]filename
-                            // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                            // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
                             //
-                            auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                            auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+                            auto tokens_prefix = tokenize(slot.input_prefix, false, false);
+                            auto tokens_suffix = tokenize(slot.input_suffix, false, false);
+                            auto tokens_prompt = tokenize(slot.prompt,       false, false);
 
                             slot.extra_tokens.clear();
                             if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
@@ -1981,7 +1984,7 @@ struct server_context {
                                 slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
                             }
 
-                            for (const auto & chunk : slot.params.extra_context) {
+                            for (const auto & chunk : slot.input_extra) {
                                 // { "text": string, "filename": string }
                                 const std::string text = chunk.value("text", "");
                                 const std::string filename = chunk.value("filename", "tmp");
@@ -2012,20 +2015,21 @@ struct server_context {
                             }
 
                             // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                            const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
-                            const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+                            const int n_suffix_take = std::min<int>(tokens_suffix.size(), (n_batch/4));
+                            const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
 
                             // fill the rest of the context with extra chunks
                             const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
 
-                            prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
-                            suffix_tokens.resize(n_suffix_take);
+                            tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+                            tokens_suffix.resize(n_suffix_take);
 
-                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                            suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+                            tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+                            tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
+                            tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
 
-                            auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                            auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+                            auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
+                            auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
 
                             if (llama_add_bos_token(model)) {
                                 embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
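
The prefix/suffix/extra budget above is plain integer arithmetic, so the standalone sketch below shows how the 3:1 split works out for one set of sizes. The values of n_batch, n_ctx, n_predict and the raw token counts are made up for illustration and are not taken from the patch.

#include <algorithm>
#include <cstdio>

// Sketch of the budget split used above: the suffix gets at most n_batch/4 tokens,
// the prefix gets the remaining 3/4 of the batch minus 3 slots reserved for the
// FIM special tokens, and extra chunks fill the context left outside the batch.
int main() {
    const int n_batch   = 2048;   // assumed batch size
    const int n_ctx     = 8192;   // assumed slot context size
    const int n_predict = 512;    // assumed generation budget

    const int n_suffix = 900;     // tokens in the raw suffix
    const int n_prefix = 5000;    // tokens in the raw prefix
    const int n_extra  = 30000;   // tokens gathered from input_extra chunks

    const int n_suffix_take = std::min(n_suffix, n_batch/4);
    const int n_prefix_take = std::min(n_prefix, 3*(n_batch/4) - 3);
    const int n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra);

    printf("suffix: %d, prefix: %d, extra: %d\n", n_suffix_take, n_prefix_take, n_extra_take);
    // -> suffix: 512, prefix: 1533, extra: 5120
    return 0;
}
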
@@ -2140,40 +2144,17 @@ struct server_context {
 
                         while (head_c < slot.cache_tokens.size() &&
                                head_p < prompt_tokens.size()) {
-                            if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
-                                slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
-                                slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
-                                break;
-                            }
-
-                            if (llama_token_is_control(model, prompt_tokens[head_p]) &&
-                                prompt_tokens[head_p] != llama_token_fim_rep(model) &&
-                                prompt_tokens[head_p] != llama_token_fim_sep(model)) {
-                                break;
-                            }
 
                             size_t n_match = 0;
-
                             while (head_c + n_match < slot.cache_tokens.size() &&
                                    head_p + n_match < prompt_tokens.size() &&
                                    slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                                if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
-                                    slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
-                                    slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
-                                    break;
-                                }
-
-                                if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
-                                    prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
-                                    prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
-                                    break;
-                                }
 
                                 n_match++;
                             }
 
                             if (n_match >= (size_t) params.n_cache_reuse) {
-                                SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                 // for (size_t i = head_p; i < head_p + n_match; i++) {
                                 //     SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
                                 // }
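
With the special-token checks removed, the reuse scan reduces to finding identical runs of at least n_cache_reuse tokens between the cached sequence and the new prompt. The standalone sketch below replays that matching on plain integers; the token values and n_cache_reuse are made up, and the way head_c/head_p advance after a match or a mismatch is not visible in this hunk, so that part is an assumption.

#include <cstdio>
#include <vector>

// Sketch of the simplified reuse scan: walk the cache and the new prompt in
// parallel and report every identical run of at least n_cache_reuse tokens.
int main() {
    const std::vector<int> cache_tokens  = {1, 2, 3, 4, 9, 9, 5, 6, 7, 8};
    const std::vector<int> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
    const size_t n_cache_reuse = 3;

    size_t head_c = 0; // position in the cached tokens
    size_t head_p = 0; // position in the new prompt

    while (head_c < cache_tokens.size() && head_p < prompt_tokens.size()) {
        size_t n_match = 0;
        while (head_c + n_match < cache_tokens.size() &&
               head_p + n_match < prompt_tokens.size() &&
               cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
            n_match++;
        }

        if (n_match >= n_cache_reuse) {
            // the real server shifts this KV range from [head_c, head_c + n_match) to [head_p, head_p + n_match)
            printf("reuse %zu tokens: cache [%zu, %zu) -> prompt [%zu, %zu)\n",
                   n_match, head_c, head_c + n_match, head_p, head_p + n_match);
            head_c += n_match; // assumed advancement, not shown in the hunk
            head_p += n_match;
        } else {
            head_c += 1;       // assumed advancement, not shown in the hunk
        }
    }
    return 0;
}
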