@@ -131,6 +131,7 @@ struct slot_params {
131
131
int32_t n_keep = 0 ; // number of tokens to keep from initial prompt
132
132
int32_t n_discard = 0 ; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
133
133
int32_t n_predict = -1 ; // new tokens to predict
134
+ int32_t n_indent = 0 ; // mininum line indentation for the generated text in number of whitespace characters
134
135
135
136
int64_t t_max_prompt_ms = -1 ; // TODO: implement
136
137
int64_t t_max_predict_ms = -1 ; // if positive, limit the generation phase to this time limit
@@ -173,6 +174,8 @@ struct server_slot {
173
174
std::vector<llama_token> prompt_tokens;
174
175
std::vector<llama_token> extra_tokens;
175
176
177
+ size_t last_nl_pos = 0 ;
178
+
176
179
std::string generated_text;
177
180
std::vector<llama_token> cache_tokens;
178
181
std::vector<completion_token_output> generated_token_probs;
@@ -215,6 +218,7 @@ struct server_slot {
215
218
SLT_DBG (*this , " %s" , " \n " );
216
219
217
220
n_prompt_tokens = 0 ;
221
+ last_nl_pos = 0 ;
218
222
generated_text = " " ;
219
223
has_new_line = false ;
220
224
truncated = false ;
@@ -860,6 +864,7 @@ struct server_context {
860
864
slot.params .stream = json_value (data, " stream" , false );
861
865
slot.params .cache_prompt = json_value (data, " cache_prompt" , false );
862
866
slot.params .n_predict = json_value (data, " n_predict" , json_value (data, " max_tokens" , default_params.n_predict ));
867
+ slot.params .n_indent = json_value (data, " n_indent" , default_params.n_indent );
863
868
slot.sparams .top_k = json_value (data, " top_k" , default_sparams.top_k );
864
869
slot.sparams .top_p = json_value (data, " top_p" , default_sparams.top_p );
865
870
slot.sparams .min_p = json_value (data, " min_p" , default_sparams.min_p );
@@ -878,7 +883,7 @@ struct server_context {
878
883
slot.sparams .mirostat_tau = json_value (data, " mirostat_tau" , default_sparams.mirostat_tau );
879
884
slot.sparams .mirostat_eta = json_value (data, " mirostat_eta" , default_sparams.mirostat_eta );
880
885
slot.sparams .penalize_nl = json_value (data, " penalize_nl" , default_sparams.penalize_nl );
881
- slot.params .n_keep = json_value (data, " n_keep" , slot. params .n_keep );
886
+ slot.params .n_keep = json_value (data, " n_keep" , default_params .n_keep );
882
887
slot.params .n_discard = json_value (data, " n_discard" , default_params.n_discard );
883
888
slot.sparams .seed = json_value (data, " seed" , default_sparams.seed );
884
889
slot.sparams .n_probs = json_value (data, " n_probs" , default_sparams.n_probs );
@@ -1129,13 +1134,48 @@ struct server_context {
1129
1134
SLT_DBG (slot, " stopped by limit, n_decoded = %d, n_predict = %d\n " , slot.n_decoded , slot.params .n_predict );
1130
1135
}
1131
1136
1132
- // if we have already seen a new line, we stop after a certain time limit
1133
- if (slot.has_new_line && slot.params .t_max_predict_ms > 0 &&
1134
- (ggml_time_us () - slot.t_start_generation > 1000 .0f *slot.params .t_max_predict_ms )) {
1135
- slot.stopped_limit = true ;
1136
- slot.has_next_token = false ;
1137
+ if (slot.has_new_line ) {
1138
+ // if we have already seen a new line, we stop after a certain time limit
1139
+ if (slot.params .t_max_predict_ms > 0 && (ggml_time_us () - slot.t_start_generation > 1000 .0f *slot.params .t_max_predict_ms )) {
1140
+ slot.stopped_limit = true ;
1141
+ slot.has_next_token = false ;
1142
+
1143
+ SLT_DBG (slot, " stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n " , slot.n_decoded , (int ) slot.params .t_max_predict_ms );
1144
+ }
1145
+
1146
+ // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
1147
+ if (slot.params .n_indent > 0 ) {
1148
+ // check the current indentation
1149
+ // TODO: improve by not doing it more than once for each new line
1150
+ if (slot.last_nl_pos > 0 ) {
1151
+ size_t pos = slot.last_nl_pos ;
1152
+
1153
+ int n_indent = 0 ;
1154
+ while (pos < slot.generated_text .size () && (slot.generated_text [pos] == ' ' || slot.generated_text [pos] == ' \t ' )) {
1155
+ n_indent++;
1156
+ pos++;
1157
+ }
1158
+
1159
+ if (pos < slot.generated_text .size () && n_indent < slot.params .n_indent ) {
1160
+ slot.stopped_limit = true ;
1161
+ slot.has_next_token = false ;
1162
+
1163
+ // cut the last line
1164
+ slot.generated_text .erase (pos, std::string::npos);
1137
1165
1138
- SLT_DBG (slot, " stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n " , slot.n_decoded , (int ) slot.params .t_max_predict_ms );
1166
+ SLT_DBG (slot, " stopped by indentation limit, n_decoded = %d, n_indent = %d\n " , slot.n_decoded , n_indent);
1167
+ }
1168
+ }
1169
+
1170
+ // find the next new line
1171
+ {
1172
+ const size_t pos = slot.generated_text .find (' \n ' , slot.last_nl_pos );
1173
+
1174
+ if (pos != std::string::npos) {
1175
+ slot.last_nl_pos = pos + 1 ;
1176
+ }
1177
+ }
1178
+ }
1139
1179
}
1140
1180
1141
1181
// check if there is a new line in the generated text
0 commit comments