Commit f0ded27

server : add n_indent parameter for line indentation requirement
ggml-ci
1 parent 99bd4ac commit f0ded27

2 files changed: +49 −7 lines
examples/server/README.md (+2)

```diff
@@ -333,6 +333,8 @@ node index.js
 
 `n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
 
+`n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0`
+
 `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
 By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.
 
```
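
The new field is read from the completion request body alongside `n_predict` and the sampling parameters (see the `n_indent` line added to `server.cpp` below). As a purely illustrative sketch, a request enabling the check might carry a body like the following; the prompt and the values are made up for the example:

```json
{
  "prompt": "def fibonacci(n):\n    ",
  "n_predict": 128,
  "n_indent": 4
}
```

With `n_indent` set to `4`, generation is cut off as soon as a newly started line begins with fewer than four spaces or tabs, which keeps a code completion from running past the current block.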

examples/server/server.cpp (+47 −7)

```diff
@@ -131,6 +131,7 @@ struct slot_params {
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
+    int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
 
     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
```

```diff
@@ -173,6 +174,8 @@ struct server_slot {
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
 
+    size_t last_nl_pos = 0;
+
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
```

```diff
@@ -215,6 +218,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens = 0;
+        last_nl_pos = 0;
         generated_text = "";
         has_new_line = false;
         truncated = false;
```

```diff
@@ -860,6 +864,7 @@ struct server_context {
         slot.params.stream = json_value(data, "stream", false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", false);
         slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
+        slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
         slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
         slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
         slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
```

```diff
@@ -878,7 +883,7 @@ struct server_context {
         slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
         slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
-        slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
+        slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
         slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
         slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
```

```diff
@@ -1129,13 +1134,48 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }
 
-        // if we have already seen a new line, we stop after a certain time limit
-        if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
-            (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
-            slot.stopped_limit = true;
-            slot.has_next_token = false;
+        if (slot.has_new_line) {
+            // if we have already seen a new line, we stop after a certain time limit
+            if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+                slot.stopped_limit = true;
+                slot.has_next_token = false;
+
+                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+            }
+
+            // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+            if (slot.params.n_indent > 0) {
+                // check the current indentation
+                // TODO: improve by not doing it more than once for each new line
+                if (slot.last_nl_pos > 0) {
+                    size_t pos = slot.last_nl_pos;
+
+                    int n_indent = 0;
+                    while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+                        n_indent++;
+                        pos++;
+                    }
+
+                    if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+                        slot.stopped_limit = true;
+                        slot.has_next_token = false;
+
+                        // cut the last line
+                        slot.generated_text.erase(pos, std::string::npos);
 
-            SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+                        SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
+                    }
+                }
+
+                // find the next new line
+                {
+                    const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+                    if (pos != std::string::npos) {
+                        slot.last_nl_pos = pos + 1;
+                    }
+                }
+            }
         }
 
         // check if there is a new line in the generated text
```
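
To make the stopping rule easier to follow outside the server plumbing, here is a minimal, self-contained sketch of the same check using plain standard-library types; the function name, variable names, and sample text are illustrative only and do not appear in the server code.

```cpp
#include <cstddef>
#include <iostream>
#include <string>

// Illustrative sketch of the indentation check in the diff above: scan the text
// that follows the most recent newline, count leading spaces/tabs, and report
// whether generation should stop (and where to truncate).
static bool should_stop_for_indent(const std::string & text, size_t last_nl_pos, int n_indent_min, size_t & cut_pos) {
    size_t pos = last_nl_pos;

    int n_indent = 0;
    while (pos < text.size() && (text[pos] == ' ' || text[pos] == '\t')) {
        n_indent++;
        pos++;
    }

    // only decide once a non-whitespace character has appeared on the new line
    if (pos < text.size() && n_indent < n_indent_min) {
        cut_pos = pos; // everything from here on belongs to the under-indented line
        return true;
    }

    return false;
}

int main() {
    // pretend the model has generated an indented block and then started a new,
    // unindented top-level line
    std::string generated = "    a = 1\n    b = 2\nprint(a + b)";

    const size_t last_nl_pos = generated.rfind('\n') + 1; // start of the last line
    size_t cut_pos = 0;

    if (should_stop_for_indent(generated, last_nl_pos, /*n_indent_min=*/4, cut_pos)) {
        generated.erase(cut_pos, std::string::npos);
        std::cout << "stopped, kept:\n" << generated << "\n";
    }
}
```

In the server itself the same scan runs against `slot.generated_text` starting at `slot.last_nl_pos`, which is advanced past each newline as it appears, so the check always looks at the most recently started line.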
