server : accept extra_context for the infill endpoint #9874

Merged · 3 commits · Oct 13, 2024
21 changes: 21 additions & 0 deletions examples/server/README.md
@@ -524,9 +524,30 @@ Takes a prefix and a suffix and returns the predicted completion as a stream.

- `input_prefix`: Set the prefix of the code to infill.
- `input_suffix`: Set the suffix of the code to infill.
- `prompt`: Added after the `FIM_MID` token.
- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874

It also accepts all the options of `/completion`.
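
As a sketch (the values here are hypothetical), a request combining these fields might look like the following. Per the validation in `server.cpp`, each `extra_context` chunk must contain a string `text` field, while `filename` is optional:

```json
{
  "input_prefix": "def helper():\n    return ",
  "input_suffix": "\n\nprint(helper())\n",
  "extra_context": [
    {
      "filename": "utils.py",
      "text": "def square(x):\n    return x * x\n"
    }
  ]
}
```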

If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:

```txt
<FIM_REP>myproject
<FIM_SEP>{chunk 0 filename}
{chunk 0 text}
<FIM_SEP>{chunk 1 filename}
{chunk 1 text}
...
<FIM_SEP>filename
<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```

If the tokens are missing, then the extra context is simply prefixed at the start:

```txt
[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```
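
For example, a single chunk `{"filename": "utils.py", "text": "def square(x): ..."}` would, when the repo-level tokens are present, produce a prompt shaped roughly like this (a sketch; the project name is currently hard-coded as `myproject` in the implementation):

```txt
<FIM_REP>myproject
<FIM_SEP>utils.py
def square(x): ...
<FIM_SEP>filename
<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
```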

### **GET** `/props`: Get server global properties.

This endpoint is public (no API key check). By default, it is read-only. To make a POST request to change global properties, you need to start the server with `--props`.
102 changes: 94 additions & 8 deletions examples/server/server.cpp
@@ -139,6 +139,7 @@ struct slot_params {

json input_prefix;
json input_suffix;
json extra_context;
};

struct server_slot {
@@ -170,6 +171,7 @@ struct server_slot {

// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
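// tokens of the extra_context chunks, assembled at infill time and placed before the FIM prefix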
std::vector<llama_token> extra_tokens;

std::string generated_text;
std::vector<llama_token> cache_tokens;
@@ -906,8 +908,26 @@ struct server_context {
}

// infill
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);

SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
for (const auto & chunk : slot.params.extra_context) {
// { "text": string, "filename": string }
if (!chunk.contains("text") || !chunk["text"].is_string()) {
send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
return false;
}

// filename is optional
if (chunk.contains("filename") && !chunk["filename"].is_string()) {
send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
return false;
}

SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
}

// get prompt
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
@@ -1934,13 +1954,66 @@ struct server_context {
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
//
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

slot.extra_tokens.clear();
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
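// the repo name is currently a hard-coded placeholder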
static const auto k_fim_repo = tokenize("myproject\n", false, false);

slot.extra_tokens.push_back(llama_token_fim_rep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}

for (const auto & chunk : slot.params.extra_context) {
// { "text": string, "filename": string }
const std::string text = chunk.value("text", "");
const std::string filename = chunk.value("filename", "tmp");

if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = tokenize(filename + "\n", false, false);

slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
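// the bytes decode to "\n\n--- snippet ---\n\n"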
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);

slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}

const auto chunk_tokens = tokenize(text, false, false);
slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}

if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = tokenize("filename\n", false, false);

slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}

// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
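// note: the "- 3" leaves room for the FIM_PRE/FIM_SUF/FIM_MID special tokens added below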

// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
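// i.e. the context that remains after reserving one batch for the prompt and 2*n_predict of generation headroom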

prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
suffix_tokens.resize(n_suffix_take);

@@ -1954,6 +2027,11 @@ struct server_context {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}

SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());

// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());

embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));

@@ -2058,11 +2136,15 @@ struct server_context {

while (head_c < slot.cache_tokens.size() &&
head_p < prompt_tokens.size()) {
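// do not break the match on FIM_REP/FIM_SEP: with repo-level FIM these
// control tokens legitimately appear in both the cache and the new prompt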
if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
break;
}

if (llama_token_is_control(model, prompt_tokens[head_p]) &&
prompt_tokens[head_p] != llama_token_fim_rep(model) &&
prompt_tokens[head_p] != llama_token_fim_sep(model)) {
break;
}

@@ -2071,11 +2153,15 @@ struct server_context {
while (head_c + n_match < slot.cache_tokens.size() &&
head_p + n_match < prompt_tokens.size() &&
slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
break;
}

if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
break;
}
