Commit ad3a050

Server: clean up OAI params parsing function (ggml-org#6284)
* server: clean up oai parsing function
* fix response_format
* fix empty response_format
* minor fixes
* add TODO for logprobs
* update docs
1 parent 95ad616 commit ad3a050

File tree

3 files changed, +63 -38 lines changed


examples/server/README.md

+1 -1

```diff
@@ -360,7 +360,7 @@ Notice that each `probs` is an array of length `n_probs`.
 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)

-- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only ChatML-tuned models, such as Dolphin, OpenOrca, OpenHermes, OpenChat-3.5, etc can be used with this endpoint.
+- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.

 *Options:*
```
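For reference, a minimal sketch of the kind of request body this endpoint accepts, built with nlohmann::json (the JSON library already used by the server code in this commit). The field names follow the OpenAI Chat Completions spec; the model name, values, and the local URL/port are placeholders, not part of this commit.

```cpp
// Illustrative sketch only: the general shape of a /v1/chat/completions request body.
// Built with nlohmann::json; all values and the URL below are placeholders.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    json body = {
        {"model", "gpt-3.5-turbo"},  // accepted for compatibility; the server defaults it to "unknown" if omitted
        {"messages", json::array({
            json{{"role", "system"}, {"content", "You are a helpful assistant."}},
            json{{"role", "user"},   {"content", "Hello!"}}
        })},
        {"temperature", 0.7},
        {"max_tokens", 128},  // mapped to llama.cpp's n_predict by the server
        {"stream", false}
    };

    // This body would be POSTed to e.g. http://localhost:8080/v1/chat/completions.
    std::cout << body.dump(2) << std::endl;
    return 0;
}
```

The mapping of `max_tokens` to `n_predict` happens in the `utils.hpp` changes below.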

examples/server/server.cpp

+9 -4

```diff
@@ -847,9 +847,16 @@ struct server_context {
         slot.sparams.penalize_nl      = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
         slot.params.n_keep            = json_value(data, "n_keep",            slot.params.n_keep);
         slot.params.seed              = json_value(data, "seed",              default_params.seed);
-        if (data.contains("json_schema") && !data.contains("grammar")) {
+        slot.sparams.n_probs          = json_value(data, "n_probs",           default_sparams.n_probs);
+        slot.sparams.min_keep         = json_value(data, "min_keep",          default_sparams.min_keep);
+
+        // process "json_schema" and "grammar"
+        if (data.contains("json_schema") && data.contains("grammar")) {
+            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
+            return false;
+        } else if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
-                auto schema           = json_value(data, "json_schema", json::object());
+                auto schema                = json_value(data, "json_schema", json::object());
                 slot.sparams.grammar       = json_schema_to_grammar(schema);
             } catch (const std::exception & e) {
                 send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
@@ -858,8 +865,6 @@
         } else {
             slot.sparams.grammar       = json_value(data, "grammar", default_sparams.grammar);
         }
-        slot.sparams.n_probs          = json_value(data, "n_probs",           default_sparams.n_probs);
-        slot.sparams.min_keep         = json_value(data, "min_keep",          default_sparams.min_keep);

         if (slot.params.cache_prompt && slot.ga_n != 1) {
             LOG_WARNING("cache_prompt is not supported with group-attention", {});
```
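The key behavioral change in this hunk is the new early-out when both `json_schema` and `grammar` are supplied. Below is a standalone sketch of that validation order; `resolve_grammar` is a hypothetical helper name used only for illustration, and it depends on nlohmann::json alone, whereas the real code reports failures via `send_error()` and converts the schema with `json_schema_to_grammar()`.

```cpp
// Standalone sketch of the "json_schema" vs "grammar" validation order, assuming
// only nlohmann::json. Not the server's implementation: the real code uses
// send_error() and json_schema_to_grammar() instead of the placeholders here.
#include <iostream>
#include <optional>
#include <string>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Returns the grammar string to use, or std::nullopt on an invalid request.
static std::optional<std::string> resolve_grammar(const json & data) {
    const bool has_schema  = data.contains("json_schema");
    const bool has_grammar = data.contains("grammar");

    if (has_schema && has_grammar) {
        // Mirrors the new early-out: the two options are mutually exclusive.
        std::cerr << "Either \"json_schema\" or \"grammar\" can be specified, but not both\n";
        return std::nullopt;
    }
    if (has_schema) {
        // Placeholder for json_schema_to_grammar(); here we just echo the schema.
        return data.at("json_schema").dump();
    }
    // Fall back to an explicit grammar, or an empty default.
    return data.value("grammar", std::string());
}

int main() {
    std::cout << resolve_grammar(json::parse(R"({"grammar":"root ::= \"yes\""})")).value() << "\n";
    std::cout << resolve_grammar(json::parse(R"({"json_schema":{"type":"object"}})")).value() << "\n";
    return resolve_grammar(json::parse(R"({"json_schema":{},"grammar":""})")) ? 0 : 1;
}
```

Checking the conflicting case first keeps the remaining branches simple: by the time a schema is converted to a grammar, it is guaranteed to be the only grammar source.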

examples/server/utils.hpp

+53 -33

```diff
@@ -352,51 +352,71 @@ static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"]             = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"]            = format_chat(model, chat_template, body["messages"]);
-    llama_params["cache_prompt"]      = json_value(body, "cache_prompt", false);
-    llama_params["temperature"]       = json_value(body, "temperature", 0.0);
-    llama_params["top_k"]             = json_value(body, "top_k", default_sparams.top_k);
-    llama_params["top_p"]             = json_value(body, "top_p", 1.0);
-    llama_params["n_predict"]         = json_value(body, "max_tokens", -1);
-    llama_params["logit_bias"]        = json_value(body, "logit_bias", json::object());
     llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
+    llama_params["logit_bias"]        = json_value(body, "logit_bias", json::object());
+    llama_params["n_predict"]         = json_value(body, "max_tokens", -1);
     llama_params["presence_penalty"]  = json_value(body, "presence_penalty", 0.0);
     llama_params["seed"]              = json_value(body, "seed", LLAMA_DEFAULT_SEED);
     llama_params["stream"]            = json_value(body, "stream", false);
-    llama_params["mirostat"]          = json_value(body, "mirostat", default_sparams.mirostat);
-    llama_params["mirostat_tau"]      = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    llama_params["mirostat_eta"]      = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    llama_params["penalize_nl"]       = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    llama_params["typical_p"]         = json_value(body, "typical_p", default_sparams.typical_p);
-    llama_params["repeat_last_n"]     = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
-    llama_params["ignore_eos"]        = json_value(body, "ignore_eos", false);
-    llama_params["tfs_z"]             = json_value(body, "tfs_z", default_sparams.tfs_z);
-    llama_params["n_keep"]            = json_value(body, "n_keep", 0);
-
-    if (body.contains("grammar")) {
-        llama_params["grammar"] = json_value(body, "grammar", json::object());
-    }
+    llama_params["temperature"]       = json_value(body, "temperature", 0.0);
+    llama_params["top_p"]             = json_value(body, "top_p", 1.0);

-    if (body.contains("response_format")) {
-        auto response_format = json_value(body, "response_format", json::object());
-        if (response_format.contains("type")) {
-            if (response_format["type"] == "json_object") {
-                llama_params["json_schema"] = json_value(response_format, "schema", json::object());
-            } else {
-                throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
-            }
-        }
-    }
+    // Apply chat template to the list of messages
+    llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);

-    // Handle 'stop' field
+    // Handle "stop" field
     if (body.contains("stop") && body["stop"].is_string()) {
         llama_params["stop"] = json::array({body["stop"].get<std::string>()});
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
+    // Some chat templates don't use EOS token to stop generation
+    // We must add their end sequences to list of stop words
+    llama_params["stop"].push_back("<|im_end|>"); // chatml
+    llama_params["stop"].push_back("<end_of_turn>"); // gemma

-    // Ensure there is ChatML-specific end sequence among stop words
-    llama_params["stop"].push_back("<|im_end|>");
+    // Handle "response_format" field
+    if (body.contains("response_format")) {
+        json response_format      = json_value(body, "response_format", json::object());
+        std::string response_type = json_value(response_format, "type", std::string());
+        if (response_type == "json_object") {
+            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+        } else if (!response_type.empty() && response_type != "text") {
+            throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
+        }
+    }
+
+    // Handle "n" field
+    int n_choices = json_value(body, "n", 1);
+    if (n_choices != 1) {
+        throw std::runtime_error("Only one completion choice is allowed");
+    }
+
+    // Handle "logprobs" field
+    // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
+    if (body.contains("logprobs")) {
+        llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
+    } else if (body.contains("top_logprobs")) {
+        throw std::runtime_error("top_logprobs requires logprobs to be set to true");
+    }
+
+    // Params supported by OAI but unsupported by llama.cpp
+    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+    for (auto & param : unsupported_params) {
+        if (body.contains(param)) {
+            throw std::runtime_error("Unsupported param: " + param);
+        }
+    }
+
+    // Copy remaining properties to llama_params
+    // This allows user to use llama.cpp-specific params like "mirostat", "tfs_z",... via OAI endpoint.
+    // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
+    for (const auto & item : body.items()) {
+        // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
+        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
+            llama_params[item.key()] = item.value();
+        }
+    }

     return llama_params;
 }
```
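The subtlest part of the new parser is the final pass-through loop: any key not already set is forwarded as-is, and `n_predict` is the one key allowed to overwrite the value derived earlier from `max_tokens`. A minimal, self-contained sketch of just that behavior follows; `copy_passthrough` is a hypothetical name for illustration, and only nlohmann::json is assumed, not the real server code.

```cpp
// Minimal sketch (nlohmann::json only, not the real server code) of the pass-through copy
// at the end of the new parser: unknown keys are forwarded, and an explicit "n_predict"
// overrides the value already derived from OAI's "max_tokens".
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

static json copy_passthrough(const json & body, json llama_params) {
    for (const auto & item : body.items()) {
        // "n_predict" is the one key allowed to overwrite an earlier mapping.
        if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
            llama_params[item.key()] = item.value();
        }
    }
    return llama_params;
}

int main() {
    json body = {
        {"max_tokens", 128},  // OAI name, already mapped to n_predict below
        {"n_predict", 64},    // llama.cpp name, takes precedence
        {"mirostat", 2}       // llama.cpp-specific, forwarded untouched
    };
    json llama_params = { {"n_predict", body.value("max_tokens", -1)} };

    std::cout << copy_passthrough(body, llama_params).dump(2) << std::endl;
    // -> n_predict ends up 64 and mirostat is forwarded.
    return 0;
}
```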
