Commit 2a5a1b1

Author: Vaijanath Rao
Message: adding tool support and chat completions
Parent: f41fc8c

5 files changed, +136 −7 lines

pom.xml

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
     <groupId>de.kherud</groupId>
     <artifactId>llama</artifactId>
-    <version>4.1.0</version>
+    <version>4.1.1</version>
     <packaging>jar</packaging>

     <name>${project.groupId}:${project.artifactId}</name>

src/main/cpp/jllama.cpp

Lines changed: 89 additions & 0 deletions

@@ -493,6 +493,69 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
     env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
 }

+JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestChat(JNIEnv *env, jobject obj, jstring jparams) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    std::string c_params = parse_jstring(env, jparams);
+    json data = json::parse(c_params);
+    std::cout << "dumping data" << std::endl;
+    std::cout << data.dump(4) << std::endl;
+    json oi_params = oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
+    std::cout << "dumping oi_params" << std::endl;
+    std::cout << oi_params.dump(4) << std::endl;
+
+    server_task_type type = SERVER_TASK_TYPE_COMPLETION;
+
+    if (oi_params.contains("input_prefix") || oi_params.contains("input_suffix")) {
+        type = SERVER_TASK_TYPE_INFILL;
+    }
+
+    auto completion_id = gen_chatcmplid();
+    std::vector<server_task> tasks;
+
+    try {
+        const auto &prompt = oi_params.at("prompt");
+
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+
+        tasks.reserve(tokenized_prompts.size());
+        for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+            server_task task = server_task(type);
+
+            task.id = ctx_server->queue_tasks.get_new_id();
+            task.index = i;
+
+            task.prompt_tokens = std::move(tokenized_prompts[i]);
+            task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, oi_params);
+            task.id_selected_slot = json_value(oi_params, "id_slot", -1);
+
+            // OAI-compat
+            task.params.oaicompat = OAICOMPAT_TYPE_CHAT;
+            task.params.oaicompat_cmpl_id = completion_id;
+            // oaicompat_model is already populated by params_from_json_cmpl
+
+            tasks.push_back(task);
+        }
+    } catch (const std::exception &e) {
+        const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+        env->ThrowNew(c_llama_error, err.dump().c_str());
+        return 0;
+    }
+
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
+
+    const auto task_ids = server_task::get_list_id(tasks);
+
+    if (task_ids.size() != 1) {
+        env->ThrowNew(c_llama_error, "multitasking currently not supported");
+        return 0;
+    }
+
+    return *task_ids.begin();
+}
+
 JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) {
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
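For orientation: the JSON string handed to requestChat is the same OpenAI-style chat payload that InferenceParameters serializes on the Java side (see below). The sketch that follows illustrates the rough shape being parsed here; the field values are examples, not taken from this commit, and the full accepted field set is whatever oaicompat_completion_params_parse supports in the bundled llama.cpp server sources.

    // Illustrative requestChat payload, written as a Java text block (JDK 15+).
    // Only the fields exercised by this commit are shown: messages, tools,
    // tool_choice, and stream.
    String request = """
            {
              "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What's the weather in Berlin?"}
              ],
              "tools": [{"type": "function",
                         "function": {"name": "get_weather",
                                      "parameters": {"type": "object"}}}],
              "tool_choice": "required",
              "stream": false
            }""";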
@@ -557,6 +620,31 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env,
     ctx_server->queue_results.remove_waiting_task_id(id_task);
 }

+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_receiveChatCompletion(JNIEnv *env, jobject obj, jint id_task) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
+
+    if (result->is_error()) {
+        std::string response = result->to_json()["message"].get<std::string>();
+        ctx_server->queue_results.remove_waiting_task_id(id_task);
+        env->ThrowNew(c_llama_error, response.c_str());
+        return nullptr;
+    }
+    const auto out_res = result->to_json();
+    std::cout << out_res.dump(4) << std::endl;
+
+    if (result->is_stop()) {
+        ctx_server->queue_results.remove_waiting_task_id(id_task);
+    }
+
+    jstring jtok_str = env->NewStringUTF(out_res.dump(4).c_str());
+
+    return jtok_str;
+}
+
 JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) {
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
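Unlike receiveCompletion below, receiveChatCompletion hands the raw result JSON back as a string and leaves parsing to the caller. A minimal sketch of that post-processing, using Jackson purely for illustration (it is not a dependency of this project) and assuming the OpenAI-compatible chat completion layout that OAICOMPAT_TYPE_CHAT produces:

    import com.fasterxml.jackson.databind.JsonNode;
    import com.fasterxml.jackson.databind.ObjectMapper;

    // Pull the assistant message out of the JSON string returned above.
    static String extractContent(String chatJson) throws Exception {
        JsonNode message = new ObjectMapper().readTree(chatJson)
                .path("choices").path(0).path("message");
        // message.path("tool_calls") carries tool invocations whenever the
        // model decided to call one of the declared tools.
        return message.path("content").asText();
    }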
@@ -570,6 +658,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE
         return nullptr;
     }
     const auto out_res = result->to_json();
+    std::cout << out_res.dump(4) << std::endl;

     std::string response = out_res["content"].get<std::string>();
     if (result->is_stop()) {

src/main/cpp/jllama.h

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default.

src/main/java/de/kherud/llama/InferenceParameters.java

Lines changed: 25 additions & 0 deletions

@@ -50,6 +50,9 @@ public final class InferenceParameters extends JsonParameters {
     private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template";
     private static final String PARAM_USE_JINJA = "use_jinja";
     private static final String PARAM_MESSAGES = "messages";
+    private static final String PARAM_TOOLS = "tools";
+    private static final String PARAM_TOOL_CHOICE = "tool_choice";
+    private static final String PARAM_PARALLEL_TOOL_CALLS = "parallel_tool_calls";

     public InferenceParameters(String prompt) {
         // we always need a prompt
@@ -537,11 +540,33 @@ public InferenceParameters setMessages(String systemMessage, List<Pair<String, S
         parameters.put(PARAM_MESSAGES, messagesBuilder.toString());
         return this;
     }
+

     InferenceParameters setStream(boolean stream) {
         parameters.put(PARAM_STREAM, String.valueOf(stream));
         return this;
     }
+
+    /**
+     * Set Tools
+     */
+    public InferenceParameters setTools(String... tools) {
+        StringBuilder toolBuilder = new StringBuilder();
+
+        for (String tool : tools) {
+            if (toolBuilder.length() > 0) {
+                toolBuilder.append(",");
+            }
+            toolBuilder.append(tool);
+        }
+
+        parameters.put(PARAM_TOOLS, "[" + toolBuilder.toString() + "]");
+        parameters.put(PARAM_TOOL_CHOICE, toJsonString("required"));
+        // parameters.put(PARAM_PARALLEL_TOOL_CALLS, String.valueOf(false));
+        return this;
+    }

     public String get(String field) {
         return parameters.get(field);
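setTools expects each argument to be an already-serialized JSON object: the strings are joined into a JSON array by plain concatenation, so malformed tool JSON only surfaces once the native side parses the request. Note also that tool_choice is hard-coded to "required", forcing a tool call on every request built this way. A usage sketch with an OpenAI-style function declaration (the schema is illustrative, not part of this commit; the text block needs JDK 15+, and Pair here is the project's own key/value helper):

    // One tool definition in the OpenAI function-calling format.
    String getWeather = """
            {
              "type": "function",
              "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                  "type": "object",
                  "properties": {"city": {"type": "string"}},
                  "required": ["city"]
                }
              }
            }""";

    InferenceParameters params = new InferenceParameters("")
            .setMessages("You are a helpful assistant.",
                    List.of(new Pair<>("user", "What's the weather in Berlin?")))
            .setTools(getWeather);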

src/main/java/de/kherud/llama/LlamaModel.java

Lines changed: 9 additions & 6 deletions

@@ -68,11 +68,10 @@ public String complete(InferenceParameters parameters) {
      */
     public String completeChat(InferenceParameters parameters) {
         parameters.setStream(false);
-        String prompt = applyTemplate(parameters);
-        parameters.setPrompt(prompt);
-        int taskId = requestCompletion(parameters.toString());
-        LlamaOutput output = receiveCompletion(taskId);
-        return output.text;
+
+        int taskId = requestChat(parameters.toString());
+        String output = receiveChatCompletion(taskId);
+        return output;
     }

     /**
@@ -148,9 +147,13 @@ public void close() {

     // don't overload native methods since the C++ function names get nasty
     native int requestCompletion(String params) throws LlamaException;
+
+    native int requestChat(String params) throws LlamaException;

     native LlamaOutput receiveCompletion(int taskId) throws LlamaException;
-
+
+    native String receiveChatCompletion(int taskId) throws LlamaException;
+
     native void cancelCompletion(int taskId);

     native byte[] decodeBytes(int[] tokens);
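Put together, the new chat path looks like this from a caller's perspective. A sketch under the same assumptions as the snippets above — the model path and the ModelParameters setter name are illustrative, not guaranteed by this commit:

    // Hypothetical end-to-end use of the new chat-completions path.
    try (LlamaModel model = new LlamaModel(
            new ModelParameters().setModel("models/model.gguf"))) { // setter name assumed
        InferenceParameters params = new InferenceParameters("")
                .setMessages("You are a helpful assistant.",
                        List.of(new Pair<>("user", "What's the weather in Berlin?")))
                .setTools(getWeather);                // tool JSON from the previous sketch
        String chatJson = model.completeChat(params); // raw OAI-compat completion JSON
        System.out.println(chatJson);                 // inspect choices[0].message / tool_calls
    }

Note that completeChat no longer applies the chat template in Java; templating now happens natively inside requestChat via oaicompat_completion_params_parse, which is why the prompt passed to the InferenceParameters constructor can stay empty here.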
