Commit 8bcfc55

server : return token ids only if requested

ggml-ci

1 parent d58f8a1 · commit 8bcfc55

3 files changed, 24 insertions(+), 12 deletions(-)

examples/server/README.md (+4 -2)

@@ -438,6 +438,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `cache_prompt`: Re-use KV cache from a previous request if possible. This way the common prefix does not have to be re-processed, only the suffix that differs between the requests. Because (depending on the backend) the logits are **not** guaranteed to be bit-for-bit identical for different batch sizes (prompt processing vs. token generation) enabling this option can cause nondeterministic results. Default: `true`
 
+`return_tokens`: Return the raw generated token ids in the `tokens` field. Otherwise `tokens` remains empty. Default: `false`
+
 `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
 
 `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
@@ -451,7 +453,7 @@ These words will not be included in the completion, so make sure to add them to
 ```json
 {
   "content": "<the token generated by the model>",
-  "tokens": [ generated token ids ],
+  "tokens": [ generated token ids if requested ],
   "probs": [
     {
       "prob": float,
@@ -469,7 +471,7 @@ These words will not be included in the completion, so make sure to add them to
 Notice that each `probs` is an array of length `n_probs`.
 
 - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string.
-- `tokens`: Same as `content` but represented as raw token ids.
+- `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request.
 - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options)
 - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.).
 - `model`: The path to the model loaded with `-m`

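For reference, a minimal client-side sketch of the documented behavior above. It is not part of the commit; the host/port (`http://localhost:8080`) and the use of the `requests` library are assumptions.

```python
import requests  # assumed HTTP client; any other client works the same way

# Non-streaming /completion request that opts into raw token ids via the new
# "return_tokens" flag; without it the "tokens" field stays an empty array.
res = requests.post("http://localhost:8080/completion", json={
    "prompt": "I believe the meaning of life is",
    "n_predict": 8,
    "return_tokens": True,
})
body = res.json()
print(body["content"])  # completion text
print(body["tokens"])   # raw token ids, populated only because return_tokens was set
```
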
examples/server/server.cpp (+10 -6)

@@ -79,8 +79,9 @@ enum error_type {
 };
 
 struct slot_params {
-    bool stream = true;
-    bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
+    bool stream        = true;
+    bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens = false;
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -199,6 +200,7 @@ struct server_task {
 
         params.stream        = json_value(data, "stream",        false);
         params.cache_prompt  = json_value(data, "cache_prompt",  true);
+        params.return_tokens = json_value(data, "return_tokens", false);
         params.n_predict     = json_value(data, "n_predict",     json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent      = json_value(data, "n_indent",      defaults.n_indent);
         params.n_keep        = json_value(data, "n_keep",        defaults.n_keep);
@@ -543,7 +545,7 @@ struct server_task_result_cmpl_final : server_task_result {
         json choices = json::array({json{
             {"finish_reason", finish_reason},
             {"index", 0},
-            {"message", json{
+            {"message", json {
                 {"content", content},
                 {"tokens", tokens},
                 {"role", "assistant"}
@@ -998,7 +1000,6 @@ struct server_slot {
         n_prompt_tokens = 0;
         last_nl_pos = 0;
         generated_text = "";
-        generated_tokens = {};
         has_new_line = false;
         truncated = false;
         stop = STOP_TYPE_NONE;
@@ -1008,6 +1009,7 @@
         n_sent_token_probs = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;
 
+        generated_tokens.clear();
         generated_token_probs.clear();
     }
 
@@ -1748,9 +1750,10 @@ struct server_context {
         const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;
 
-        // search stop word and delete it
         slot.generated_text += token_str;
-        slot.generated_tokens.push_back(result.tok);
+        if (slot.params.return_tokens) {
+            slot.generated_tokens.push_back(result.tok);
+        }
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
@@ -1775,6 +1778,7 @@
             break;
         }
 
+        // search stop word and delete it
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
examples/server/tests/unit/test_completion.py (+10 -4)

@@ -10,23 +10,28 @@ def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
-@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
-    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
-    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
+@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
+    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
+    ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
 ])
-def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
+def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
     global server
     server.start()
     res = server.make_request("POST", "/completion", data={
         "n_predict": n_predict,
         "prompt": prompt,
+        "return_tokens": return_tokens,
     })
     assert res.status_code == 200
     assert res.body["timings"]["prompt_n"] == n_prompt
     assert res.body["timings"]["predicted_n"] == n_predicted
     assert res.body["truncated"] == truncated
     assert type(res.body["has_new_line"]) == bool
     assert match_regex(re_content, res.body["content"])
+    if return_tokens:
+        assert res.body["tokens"] != []
+    else:
+        assert res.body["tokens"] == []
 
 
 @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
@@ -56,6 +61,7 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
             assert data["generation_settings"]["seed"] == server.seed
             assert match_regex(re_content, content)
         else:
+            assert data["tokens"] != []
            content += data["content"]
 

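The streaming test above exercises the other half of the README note: when `"stream": true`, each non-final chunk carries its token ids even without `return_tokens`. A rough sketch of consuming that stream follows; it assumes a local llama-server on `http://localhost:8080` and SSE-style `data:` framing, both assumptions of this sketch rather than something the commit guarantees.

```python
import json
import requests  # assumed HTTP client; the "data: " line framing is also an assumption

# Streaming /completion request: per the diff above, non-final chunks expose
# their token ids in "tokens" without needing "return_tokens".
with requests.post("http://localhost:8080/completion", json={
    "prompt": "Write a joke about AI",
    "n_predict": 16,
    "stream": True,
}, stream=True) as res:
    for line in res.iter_lines():
        if not line.startswith(b"data: "):
            continue  # skip keep-alive / blank lines
        chunk = json.loads(line[len(b"data: "):])
        if chunk.get("stop"):
            break
        print(chunk["content"], chunk["tokens"])  # per-chunk text and token ids
```
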