Skip to content

Commit f5bfea0

Browse files
authored
Allow passing grammar to completion endpoint (ggml-org#2532)
* Allow passing grammar to completion endpoint
1 parent acfc547 commit f5bfea0

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml.
380380
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
381381
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
382382

383-
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
383+
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
384384
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
385385

386386
$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)

examples/server/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ node .
151151

152152
`mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).
153153

154+
`grammar`: Set the grammar for grammar-based sampling (default: no grammar).
155+
154156
`seed`: Set the random number generator (RNG) seed (default: -1, -1 = random seed).
155157

156158
`ignore_eos`: Ignore end of stream token and continue generating (default: false).

examples/server/server.cpp

+58-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "common.h"
22
#include "llama.h"
33
#include "build-info.h"
4+
#include "grammar-parser.h"
45

56
#ifndef NDEBUG
67
// crash the server in debug mode, otherwise send an http 500 error
@@ -195,6 +196,8 @@ struct llama_server_context
195196
llama_context *ctx = nullptr;
196197
gpt_params params;
197198

199+
llama_grammar *grammar = nullptr;
200+
198201
bool truncated = false;
199202
bool stopped_eos = false;
200203
bool stopped_word = false;
@@ -226,6 +229,7 @@ struct llama_server_context
226229
void rewind()
227230
{
228231
params.antiprompt.clear();
232+
params.grammar.clear();
229233
num_prompt_tokens = 0;
230234
num_tokens_predicted = 0;
231235
generated_text = "";
@@ -237,6 +241,7 @@ struct llama_server_context
237241
stopped_limit = false;
238242
stopping_word = "";
239243
multibyte_pending = 0;
244+
grammar = nullptr;
240245

241246
n_remain = 0;
242247
n_past = 0;
@@ -257,6 +262,33 @@ struct llama_server_context
257262
return true;
258263
}
259264

265+
bool loadGrammar()
266+
{
267+
if (!params.grammar.empty()) {
268+
grammar_parser::parse_state parsed_grammar;
269+
270+
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
271+
// will be empty (default) if there are parse errors
272+
if (parsed_grammar.rules.empty()) {
273+
LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
274+
return false;
275+
}
276+
grammar_parser::print_grammar(stderr, parsed_grammar);
277+
278+
{
279+
auto it = params.logit_bias.find(llama_token_eos());
280+
if (it != params.logit_bias.end() && it->second == -INFINITY) {
281+
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
282+
}
283+
}
284+
285+
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
286+
grammar = llama_grammar_init(
287+
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
288+
}
289+
return true;
290+
}
291+
260292
void loadPrompt()
261293
{
262294
params.prompt.insert(0, 1, ' '); // always add a first space
@@ -420,6 +452,10 @@ struct llama_server_context
420452
logits[llama_token_nl()] = nl_logit;
421453
}
422454

455+
if (grammar != nullptr) {
456+
llama_sample_grammar(ctx, &candidates_p, grammar);
457+
}
458+
423459
if (temp <= 0)
424460
{
425461
// Greedy sampling
@@ -457,10 +493,15 @@ struct llama_server_context
457493
}
458494
}
459495

496+
if (grammar != nullptr) {
497+
llama_grammar_accept_token(ctx, grammar, result.tok);
498+
}
499+
460500
for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
461501
{
462502
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
463503
}
504+
464505
last_n_tokens.erase(last_n_tokens.begin());
465506
last_n_tokens.push_back(result.tok);
466507
num_tokens_predicted++;
@@ -947,6 +988,7 @@ static json format_generation_settings(llama_server_context &llama)
947988
{"stream", llama.stream},
948989
{"logit_bias", llama.params.logit_bias},
949990
{"n_probs", llama.params.n_probs},
991+
{"grammar", llama.params.grammar},
950992
};
951993
}
952994

@@ -1048,6 +1090,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
10481090
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
10491091
llama.params.seed = body.value("seed", default_params.seed);
10501092
llama.params.prompt = body.value("prompt", default_params.prompt);
1093+
llama.params.grammar = body.value("grammar", default_params.grammar);
10511094
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
10521095

10531096
llama.params.logit_bias.clear();
@@ -1179,6 +1222,12 @@ int main(int argc, char **argv)
11791222

11801223
parse_options_completion(json::parse(req.body), llama);
11811224

1225+
if (!llama.loadGrammar())
1226+
{
1227+
res.status = 400;
1228+
return;
1229+
}
1230+
11821231
llama.loadPrompt();
11831232
llama.beginCompletion();
11841233

@@ -1334,8 +1383,12 @@ int main(int argc, char **argv)
13341383

13351384
svr.set_error_handler([](const Request &, Response &res)
13361385
{
1337-
res.set_content("File Not Found", "text/plain");
1338-
res.status = 404; });
1386+
if (res.status == 400) {
1387+
res.set_content("Invalid request", "text/plain");
1388+
} else {
1389+
res.set_content("File Not Found", "text/plain");
1390+
res.status = 404;
1391+
} });
13391392

13401393
// set timeouts and change hostname and port
13411394
svr.set_read_timeout(sparams.read_timeout);
@@ -1363,6 +1416,9 @@ int main(int argc, char **argv)
13631416
return 1;
13641417
}
13651418

1419+
if (llama.grammar != nullptr) {
1420+
llama_grammar_free(llama.grammar);
1421+
}
13661422
llama_backend_free();
13671423

13681424
return 0;

0 commit comments

Comments (0)