@@ -1,6 +1,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -195,6 +196,8 @@ struct llama_server_context
     llama_context *ctx = nullptr;
     gpt_params params;
 
+    llama_grammar *grammar = nullptr;
+
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -226,6 +229,7 @@ struct llama_server_context
     void rewind()
     {
         params.antiprompt.clear();
+        params.grammar.clear();
         num_prompt_tokens = 0;
         num_tokens_predicted = 0;
         generated_text = "";
@@ -237,6 +241,7 @@ struct llama_server_context
         stopped_limit = false;
         stopping_word = "";
         multibyte_pending = 0;
+        grammar = nullptr;
 
         n_remain = 0;
         n_past = 0;
@@ -257,6 +262,33 @@ struct llama_server_context
         return true;
     }
 
+    bool loadGrammar()
+    {
+        if (!params.grammar.empty()) {
+            grammar_parser::parse_state parsed_grammar;
+
+            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+            // will be empty (default) if there are parse errors
+            if (parsed_grammar.rules.empty()) {
+                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
+                return false;
+            }
+            grammar_parser::print_grammar(stderr, parsed_grammar);
+
+            {
+                auto it = params.logit_bias.find(llama_token_eos());
+                if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
+                }
+            }
+
+            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+            grammar = llama_grammar_init(
+                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+        }
+        return true;
+    }
+
     void loadPrompt()
     {
         params.prompt.insert(0, 1, ' '); // always add a first space
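For context on the new `loadGrammar()` path: it parses the GBNF text supplied with the request, prints the parsed rules to stderr, warns if the EOS token has been disabled via `logit_bias` (most grammars need EOS to terminate), and builds a `llama_grammar` rooted at the `root` symbol. A minimal standalone sketch of the same call sequence; the GBNF string in the usage comment is an illustrative assumption, not taken from this diff:

```cpp
// Sketch of the parse/init sequence used by loadGrammar().
#include <vector>
#include "llama.h"
#include "grammar-parser.h"

static llama_grammar *grammar_from_gbnf(const char *gbnf)
{
    grammar_parser::parse_state parsed = grammar_parser::parse(gbnf);
    if (parsed.rules.empty()) {
        return nullptr; // parse error (same check as loadGrammar)
    }
    std::vector<const llama_grammar_element *> rules(parsed.c_rules());
    return llama_grammar_init(rules.data(), rules.size(), parsed.symbol_ids.at("root"));
}

// Usage (hypothetical grammar constraining output to yes/no):
// llama_grammar *g = grammar_from_gbnf("root ::= \"yes\" | \"no\"");
```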
@@ -420,6 +452,10 @@ struct llama_server_context
                 logits[llama_token_nl()] = nl_logit;
             }
 
+            if (grammar != nullptr) {
+                llama_sample_grammar(ctx, &candidates_p, grammar);
+            }
+
             if (temp <= 0)
             {
                 // Greedy sampling
@@ -457,10 +493,15 @@ struct llama_server_context
                 }
             }
 
+            if (grammar != nullptr) {
+                llama_grammar_accept_token(ctx, grammar, result.tok);
+            }
+
             for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
             {
                 result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
             }
+
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(result.tok);
             num_tokens_predicted++;
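The two grammar hooks above work as a pair on every sampling step: `llama_sample_grammar()` filters the candidate list down to tokens the grammar can currently accept, and `llama_grammar_accept_token()` advances the grammar state once a token has been chosen. A condensed sketch of that per-token pattern, with `llama_sample_token()` standing in for the penalty/temperature logic elided here:

```cpp
// Sketch of the per-token pattern added above; sampling details elided.
static llama_token sample_with_grammar(llama_context *ctx,
                                       llama_token_data_array *candidates,
                                       llama_grammar *grammar)
{
    if (grammar != nullptr) {
        llama_sample_grammar(ctx, candidates, grammar); // mask tokens the grammar rejects
    }
    llama_token tok = llama_sample_token(ctx, candidates);
    if (grammar != nullptr) {
        llama_grammar_accept_token(ctx, grammar, tok); // advance the grammar state
    }
    return tok;
}
```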
@@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
         {"stream", llama.stream},
         {"logit_bias", llama.params.logit_bias},
         {"n_probs", llama.params.n_probs},
+        {"grammar", llama.params.grammar},
     };
 }
 
@@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
     llama.params.n_keep = body.value("n_keep", default_params.n_keep);
     llama.params.seed = body.value("seed", default_params.seed);
     llama.params.prompt = body.value("prompt", default_params.prompt);
+    llama.params.grammar = body.value("grammar", default_params.grammar);
     llama.params.n_probs = body.value("n_probs", default_params.n_probs);
 
     llama.params.logit_bias.clear();
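With the new `grammar` field parsed here, a completion request can constrain generation by sending GBNF text alongside the prompt. A hypothetical request body, built with the same `json` type the server already uses (the grammar string is illustrative):

```cpp
// Hypothetical /completion request body; the grammar string is illustrative.
json body = {
    {"prompt",    "Is the sky blue? Answer: "},
    {"grammar",   "root ::= \"yes\" | \"no\""},
    {"n_predict", 4},
};
```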
@@ -1179,6 +1222,12 @@ int main(int argc, char **argv)
 
         parse_options_completion(json::parse(req.body), llama);
 
+        if (!llama.loadGrammar())
+        {
+            res.status = 400;
+            return;
+        }
+
         llama.loadPrompt();
         llama.beginCompletion();
 
@@ -1334,8 +1383,12 @@ int main(int argc, char **argv)
 
     svr.set_error_handler([](const Request &, Response &res)
                           {
-        res.set_content("File Not Found", "text/plain");
-        res.status = 404; });
+        if (res.status == 400) {
+            res.set_content("Invalid request", "text/plain");
+        } else {
+            res.set_content("File Not Found", "text/plain");
+            res.status = 404;
+        } });
 
     // set timeouts and change hostname and port
     svr.set_read_timeout(sparams.read_timeout);
@@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
         return 1;
     }
 
+    if (llama.grammar != nullptr) {
+        llama_grammar_free(llama.grammar);
+    }
     llama_backend_free();
 
     return 0;