@@ -493,6 +493,69 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
    env->SetLongField(obj, f_model_pointer, reinterpret_cast<jlong>(ctx_server));
}

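+// Queues an OpenAI-compatible chat completion request and returns the id of
+// the created task; the result is fetched later via receiveChatCompletion.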
+JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestChat(JNIEnv *env, jobject obj, jstring jparams) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    std::string c_params = parse_jstring(env, jparams);
+    json data = json::parse(c_params);
+    std::cout << "dumping data" << std::endl;
+    std::cout << data.dump(4) << std::endl;
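+    // Translate the OpenAI-style chat payload into the server's internal
+    // completion parameters, applying the loaded chat template.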
+    json oi_params = oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
+    std::cout << "dumping oi_params" << std::endl;
+    std::cout << oi_params.dump(4) << std::endl;
+
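+    // A chat request normally maps to a plain completion task; the presence of
+    // input_prefix/input_suffix marks it as a fill-in-the-middle (infill) request.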
+    server_task_type type = SERVER_TASK_TYPE_COMPLETION;
+
+    if (oi_params.contains("input_prefix") || oi_params.contains("input_suffix")) {
+        type = SERVER_TASK_TYPE_INFILL;
+    }
+
+    auto completion_id = gen_chatcmplid();
+    std::vector<server_task> tasks;
+
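+    // Tokenize the prompt(s) and create one server_task per tokenized prompt,
+    // tagged as OAI-compatible chat so results are rendered in that format.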
+    try {
+        const auto &prompt = oi_params.at("prompt");
+
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);
+
+        tasks.reserve(tokenized_prompts.size());
+        for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+            server_task task = server_task(type);
+
+            task.id = ctx_server->queue_tasks.get_new_id();
+            task.index = i;
+
+            task.prompt_tokens = std::move(tokenized_prompts[i]);
+            task.params = server_task::params_from_json_cmpl(ctx_server->ctx, ctx_server->params_base, oi_params);
+            task.id_selected_slot = json_value(oi_params, "id_slot", -1);
+
+            // OAI-compat
+            task.params.oaicompat = OAICOMPAT_TYPE_CHAT;
+            task.params.oaicompat_cmpl_id = completion_id;
+            // oaicompat_model is already populated by params_from_json_cmpl
+
+            tasks.push_back(task);
+        }
+    } catch (const std::exception &e) {
+        const auto &err = format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST);
+        env->ThrowNew(c_llama_error, err.dump().c_str());
+        return 0;
+    }
+
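+    // Register the tasks as waiting before posting them, so a result produced
+    // immediately after post() cannot be missed by the result queue.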
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
+
+    const auto task_ids = server_task::get_list_id(tasks);
+
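+    // This entry point currently supports exactly one prompt per request;
+    // batched prompts would yield multiple task ids, which cannot be returned
+    // through a single jint.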
+    if (task_ids.size() != 1) {
+        env->ThrowNew(c_llama_error, "multitasking currently not supported");
+        return 0;
+    }
+
+    return *task_ids.begin();
+}
+
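+// A minimal sketch of the intended Java-side call sequence (assuming matching
+// native declarations in de.kherud.llama.LlamaModel; the payload shown is
+// illustrative, not a fixed schema):
+//
+//     int taskId = model.requestChat("{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}]}");
+//     String resultJson = model.receiveChatCompletion(taskId);
+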
JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) {
    jlong server_handle = env->GetLongField(obj, f_model_pointer);
    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
@@ -557,6 +620,31 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env,
    ctx_server->queue_results.remove_waiting_task_id(id_task);
}

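+// Blocks until a result for the given task id is available and returns it to
+// Java as a pretty-printed JSON string; error results are rethrown via c_llama_error.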
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_receiveChatCompletion(JNIEnv *env, jobject obj, jint id_task) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
+
+    if (result->is_error()) {
+        std::string response = result->to_json()["message"].get<std::string>();
+        ctx_server->queue_results.remove_waiting_task_id(id_task);
+        env->ThrowNew(c_llama_error, response.c_str());
+        return nullptr;
+    }
+    const auto out_res = result->to_json();
+    std::cout << out_res.dump(4) << std::endl;
+
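+    // A stop result is the final message for this task; only then is the task
+    // removed from the waiting list, so intermediate streaming chunks keep it registered.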
+    if (result->is_stop()) {
+        ctx_server->queue_results.remove_waiting_task_id(id_task);
+    }
+
+    jstring jtok_str = env->NewStringUTF(out_res.dump(4).c_str());
+
+    return jtok_str;
+}
+
JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task) {
    jlong server_handle = env->GetLongField(obj, f_model_pointer);
    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
@@ -570,6 +658,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE
        return nullptr;
    }
    const auto out_res = result->to_json();
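+    // Debug: dump the full result JSON before extracting "content"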
+    std::cout << out_res.dump(4) << std::endl;

    std::string response = out_res["content"].get<std::string>();
    if (result->is_stop()) {