From 71373681d3b460bb384750d3e6fd9f17e6055089 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 12 Mar 2025 12:19:56 -0700 Subject: [PATCH 01/13] adding re-ranking --- pom.xml | 34 ++++-- src/main/cpp/jllama.cpp | 111 +++++++++++++++++- src/main/cpp/jllama.h | 7 ++ src/main/java/de/kherud/llama/LlamaModel.java | 3 + .../java/de/kherud/llama/LlamaModelTest.java | 20 ++++ 5 files changed, 163 insertions(+), 12 deletions(-) diff --git a/pom.xml b/pom.xml index c081e19..fba7eb4 100644 --- a/pom.xml +++ b/pom.xml @@ -1,14 +1,16 @@ - 4.0.0 de.kherud llama - 4.0.0 + 4.0.1 jar ${project.groupId}:${project.artifactId} - Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++. + Java Bindings for llama.cpp - A Port of Facebook's LLaMA model + in C/C++. https://github.com/kherud/java-llama.cpp @@ -39,7 +41,8 @@ ossrh - https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ + + https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ @@ -62,6 +65,7 @@ 24.1.0 compile + @@ -71,17 +75,21 @@ maven-compiler-plugin 3.13.0 - + gpu compile - compile + + compile + -h src/main/cpp - ${project.build.outputDirectory}_cuda + + ${project.build.outputDirectory}_cuda @@ -98,10 +106,12 @@ copy-resources - ${project.build.outputDirectory}_cuda + + ${project.build.outputDirectory}_cuda - ${basedir}/src/main/resources_linux_cuda/ + + ${basedir}/src/main/resources_linux_cuda/ **/*.* @@ -176,7 +186,8 @@ maven-jar-plugin 3.4.2 - + cuda package @@ -185,7 +196,8 @@ cuda12-linux-x86-64 - ${project.build.outputDirectory}_cuda + + ${project.build.outputDirectory}_cuda diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 0db026e..9fafb6f 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -112,6 +112,26 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js return result; } +std::vector parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array, const jsize length) { + std::vector result; + result.reserve(length); // Reserve memory for efficiency + + for (jsize i = 0; i < length; i++) { + jstring javaString = static_cast(env->GetObjectArrayElement(string_array, i)); + if (javaString == nullptr) continue; + + const char *cString = env->GetStringUTFChars(javaString, nullptr); + if (cString != nullptr) { + result.emplace_back(cString); // Add to vector + env->ReleaseStringUTFChars(javaString, cString); + } + + env->DeleteLocalRef(javaString); // Avoid memory leaks + } + + return result; +} + void free_string_array(char **array, jsize length) { if (array != nullptr) { for (jsize i = 0; i < length; i++) { @@ -239,6 +259,7 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { cc_integer = env->GetMethodID(c_integer, "", "(I)V"); cc_float = env->GetMethodID(c_float, "", "(F)V"); + if (!(cc_output && cc_hash_map && cc_integer && cc_float)) { goto error; } @@ -634,7 +655,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, json error = nullptr; server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - ctx_server->queue_results.remove_waiting_task_id(id_task); json response_str = result->to_json(); if (result->is_error()) { @@ -643,6 +663,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, env->ThrowNew(c_llama_error, response.c_str()); return nullptr; } + + if (result->is_stop()) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + const auto out_res = result->to_json(); @@ -679,6 +704,90 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, return j_embedding; } +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, jobjectArray documents) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + + if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, + "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); + return nullptr; + } + + + const std::string prompt = parse_jstring(env, jprompt); + + + + const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true); + + json responses = json::array(); + bool error = false; + + std::vector tasks; + const jsize argc = env->GetArrayLength(documents); + std::vector documentsArray = parse_string_array_for_rerank(env, documents, argc); + + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true); + + tasks.reserve(tokenized_docs.size()); + for (size_t i = 0; i < tokenized_docs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + std::vector results(task_ids.size()); + + // Create a new HashMap instance + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + if (o_probabilities == nullptr) { + env->ThrowNew(c_llama_error, "Failed to create HashMap object."); + return nullptr; + } + + for (int i = 0; i < (int)task_ids.size(); i++) { + server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); + if (result->is_error()) { + std::string response = result->to_json()["message"].get(); + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } + + const auto out_res = result->to_json(); + + std::cout << out_res.dump(4) << std::endl; + + if (result->is_stop()) { + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + } + + int index = out_res["index"].get(); + float score = out_res["score"].get(); + std::string tok_str = documentsArray[index]; + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + jobject jprob = env->NewObject(c_float, cc_float, score); + env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob); + env->DeleteLocalRef(jtok_str); + env->DeleteLocalRef(jprob); + } + jbyteArray jbytes = parse_jbytes(env, prompt); + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true); + +} + JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index 63d95b7..01e4d20 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -84,6 +84,13 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *, job */ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *, jclass, jstring); +/* + * Class: de_kherud_llama_LlamaModel + * Method: rerank + * Signature: (Ljava/lang/String;[Ljava/lang/String;)Lde/kherud/llama/LlamaOutput; + */ +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray); + #ifdef __cplusplus } #endif diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 7749b32..ffa9675 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -5,6 +5,7 @@ import java.lang.annotation.Native; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.function.BiConsumer; /** @@ -137,4 +138,6 @@ public void close() { public static String jsonSchemaToGrammar(String schema) { return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8); } + + public native LlamaOutput rerank(String query, String... documents); } diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index f2e931b..6481f09 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -158,6 +158,26 @@ public void testEmbedding() { float[] embedding = model.embed(prefix); Assert.assertEquals(4096, embedding.length); } + + + @Ignore + /** + * To run this test download the model from here https://huggingface.co/mradermacher/jina-reranker-v1-tiny-en-GGUF/tree/main + * remove .enableEmbedding() from model setup and add .enableReRanking() and then enable the test. + */ + public void testReRanking() { + + String query = "Machine learning is"; + String [] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." + }; + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3] ); + + System.out.println(llamaOutput); + } @Test public void testTokenization() { From e9c3de7ef5918c86fd8cca03efb58f8852339212 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 12 Mar 2025 19:08:02 -0700 Subject: [PATCH 02/13] moving reranking to it's own test. --- .github/workflows/ci.yml | 14 +++++- .../de/kherud/llama/RerankingModelTest.java | 47 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 src/test/java/de/kherud/llama/RerankingModelTest.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 631fc86..9e913a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,8 @@ on: env: MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf MODEL_NAME: codellama-7b.Q2_K.gguf + RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf + RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf jobs: build-and-test-linux: @@ -21,8 +23,10 @@ jobs: run: | mvn compile .github/build.sh -DLLAMA_VERBOSE=ON - - name: Download model + - name: Download text generation model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - name: Run tests run: mvn test - if: failure() @@ -53,8 +57,11 @@ jobs: run: | mvn compile .github/build.sh ${{ matrix.target.cmake }} - - name: Download model + - name: Download text generaton model model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: Run tests run: mvn test - if: failure() @@ -79,6 +86,9 @@ jobs: .github\build.bat -DLLAMA_VERBOSE=ON - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: Run tests run: mvn test - if: failure() diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java new file mode 100644 index 0000000..38ca7e2 --- /dev/null +++ b/src/test/java/de/kherud/llama/RerankingModelTest.java @@ -0,0 +1,47 @@ +package de.kherud.llama; + +import java.io.*; +import java.util.*; +import java.util.regex.Pattern; + +import de.kherud.llama.args.LogFormat; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Test; + +public class RerankingModelTest { + + private static LlamaModel model; + + @BeforeClass + public static void setup() { + model = new LlamaModel( + new ModelParameters().setCtxSize(128).setModel("models/jina-reranker-v1-tiny-en.Q4_K_M.gguf") + .setGpuLayers(43).enableReranking().enableLogTimestamps().enableLogPrefix()); + } + + @AfterClass + public static void tearDown() { + if (model != null) { + model.close(); + } + } + + @Test + public void testReRanking() { + + String query = "Machine learning is"; + String[] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." }; + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], + TEST_DOCUMENTS[3]); + + System.out.println(llamaOutput); + } + +} From 01a6f83726cbae097fb282e6095f12e1dc10da4b Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 12 Mar 2025 20:46:15 -0700 Subject: [PATCH 03/13] updating the workflow and reranking --- .github/workflows/ci.yml | 8 ++++++-- src/test/java/de/kherud/llama/RerankingModelTest.java | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e913a9..9ff9dfb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,6 +27,8 @@ jobs: run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Download reranking model run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: List files in models directory + run: ls -l models/ - name: Run tests run: mvn test - if: failure() @@ -61,7 +63,8 @@ jobs: run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} - name: Download reranking model run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - + - name: List files in models directory + run: ls -l models/ - name: Run tests run: mvn test - if: failure() @@ -88,7 +91,8 @@ jobs: run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Download reranking model run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - + - name: List files in models directory + run: ls -l models/ - name: Run tests run: mvn test - if: failure() diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java index 38ca7e2..69adb7f 100644 --- a/src/test/java/de/kherud/llama/RerankingModelTest.java +++ b/src/test/java/de/kherud/llama/RerankingModelTest.java @@ -18,7 +18,7 @@ public class RerankingModelTest { @BeforeClass public static void setup() { model = new LlamaModel( - new ModelParameters().setCtxSize(128).setModel("models/jina-reranker-v1-tiny-en.Q4_K_M.gguf") + new ModelParameters().setCtxSize(128).setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf") .setGpuLayers(43).enableReranking().enableLogTimestamps().enableLogPrefix()); } From 1685c3e5044fa4012595d5b7ea113da41f6c0ee8 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 12 Mar 2025 20:57:02 -0700 Subject: [PATCH 04/13] updating windows build --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9ff9dfb..a15f809 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -90,7 +90,7 @@ jobs: - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME - name: Download reranking model - run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME - name: List files in models directory run: ls -l models/ - name: Run tests From 06b11a705669ac09864338b9c55364cf886b7e1e Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Wed, 12 Mar 2025 21:36:17 -0700 Subject: [PATCH 05/13] updated the test. --- .../de/kherud/llama/RerankingModelTest.java | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java index 69adb7f..8145829 100644 --- a/src/test/java/de/kherud/llama/RerankingModelTest.java +++ b/src/test/java/de/kherud/llama/RerankingModelTest.java @@ -1,14 +1,10 @@ package de.kherud.llama; -import java.io.*; -import java.util.*; -import java.util.regex.Pattern; +import java.util.Map; -import de.kherud.llama.args.LogFormat; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; public class RerankingModelTest { @@ -41,7 +37,32 @@ public void testReRanking() { LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3]); - System.out.println(llamaOutput); + Map rankedDocumentsMap = llamaOutput.probabilities; + Assert.assertTrue(rankedDocumentsMap.size()==TEST_DOCUMENTS.length); + + // Finding the most and least relevant documents + String mostRelevantDoc = null; + String leastRelevantDoc = null; + float maxScore = Float.MIN_VALUE; + float minScore = Float.MAX_VALUE; + + for (Map.Entry entry : rankedDocumentsMap.entrySet()) { + if (entry.getValue() > maxScore) { + maxScore = entry.getValue(); + mostRelevantDoc = entry.getKey(); + } + if (entry.getValue() < minScore) { + minScore = entry.getValue(); + leastRelevantDoc = entry.getKey(); + } + } + + // Assertions + Assert.assertTrue(maxScore > minScore); + Assert.assertEquals("Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", mostRelevantDoc); + Assert.assertEquals("Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.", leastRelevantDoc); + + } } From faa494e886824a888ea12cf388c9f45229ff35e7 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 13 Mar 2025 15:41:56 -0700 Subject: [PATCH 06/13] removed std print and adding ranking test. --- src/main/cpp/jllama.cpp | 2 - .../java/de/kherud/llama/LlamaIterator.java | 3 ++ src/main/java/de/kherud/llama/LlamaModel.java | 25 +++++++++- src/main/java/de/kherud/llama/Pair.java | 48 +++++++++++++++++++ .../de/kherud/llama/RerankingModelTest.java | 29 ++++++++--- 5 files changed, 97 insertions(+), 10 deletions(-) create mode 100644 src/main/java/de/kherud/llama/Pair.java diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 9fafb6f..b0242c3 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -765,8 +765,6 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo const auto out_res = result->to_json(); - std::cout << out_res.dump(4) << std::endl; - if (result->is_stop()) { for (const int id_task : task_ids) { ctx_server->queue_results.remove_waiting_task_id(id_task); diff --git a/src/main/java/de/kherud/llama/LlamaIterator.java b/src/main/java/de/kherud/llama/LlamaIterator.java index fdff993..cb1c5c2 100644 --- a/src/main/java/de/kherud/llama/LlamaIterator.java +++ b/src/main/java/de/kherud/llama/LlamaIterator.java @@ -35,6 +35,9 @@ public LlamaOutput next() { } LlamaOutput output = model.receiveCompletion(taskId); hasNext = !output.stop; + if (output.stop) { + model.releaseTask(taskId); + } return output; } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index ffa9675..9ed86d0 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -5,7 +5,9 @@ import java.lang.annotation.Native; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.function.BiConsumer; /** @@ -131,7 +133,7 @@ public void close() { private native void delete(); - private native void releaseTask(int taskId); + native void releaseTask(int taskId); private static native byte[] jsonSchemaToGrammarBytes(String schema); @@ -139,5 +141,26 @@ public static String jsonSchemaToGrammar(String schema) { return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8); } + public List> rerank(boolean reRank, String query, String ... documents) { + LlamaOutput output = rerank(query, documents); + + Map scoredDocumentMap = output.probabilities; + + List> rankedDocuments = new ArrayList<>(); + + if (reRank) { + // Sort in descending order based on Float values + scoredDocumentMap.entrySet() + .stream() + .sorted((a, b) -> Float.compare(b.getValue(), a.getValue())) // Descending order + .forEach(entry -> rankedDocuments.add(new Pair<>(entry.getKey(), entry.getValue()))); + } else { + // Copy without sorting + scoredDocumentMap.forEach((key, value) -> rankedDocuments.add(new Pair<>(key, value))); + } + + return rankedDocuments; + } + public native LlamaOutput rerank(String query, String... documents); } diff --git a/src/main/java/de/kherud/llama/Pair.java b/src/main/java/de/kherud/llama/Pair.java new file mode 100644 index 0000000..48ac648 --- /dev/null +++ b/src/main/java/de/kherud/llama/Pair.java @@ -0,0 +1,48 @@ +package de.kherud.llama; + +import java.util.Objects; + +public class Pair { + + private final K key; + private final V value; + + public Pair(K key, V value) { + this.key = key; + this.value = value; + } + + public K getKey() { + return key; + } + + public V getValue() { + return value; + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Pair other = (Pair) obj; + return Objects.equals(key, other.key) && Objects.equals(value, other.value); + } + + @Override + public String toString() { + return "Pair [key=" + key + ", value=" + value + "]"; + } + + + + +} diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java index 8145829..60d32bd 100644 --- a/src/test/java/de/kherud/llama/RerankingModelTest.java +++ b/src/test/java/de/kherud/llama/RerankingModelTest.java @@ -1,5 +1,6 @@ package de.kherud.llama; +import java.util.List; import java.util.Map; import org.junit.AfterClass; @@ -10,6 +11,13 @@ public class RerankingModelTest { private static LlamaModel model; + + String query = "Machine learning is"; + String[] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." }; @BeforeClass public static void setup() { @@ -28,12 +36,7 @@ public static void tearDown() { @Test public void testReRanking() { - String query = "Machine learning is"; - String[] TEST_DOCUMENTS = new String[] { - "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", - "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", - "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", - "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." }; + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3]); @@ -64,5 +67,17 @@ public void testReRanking() { } - + + @Test + public void testSortedReRanking() { + List> rankedDocuments = model.rerank(true, query, TEST_DOCUMENTS); + Assert.assertEquals(rankedDocuments.size(), TEST_DOCUMENTS.length); + + // Check the ranking order: each score should be >= the next one + for (int i = 0; i < rankedDocuments.size() - 1; i++) { + float currentScore = rankedDocuments.get(i).getValue(); + float nextScore = rankedDocuments.get(i + 1).getValue(); + Assert.assertTrue("Ranking order incorrect at index " + i, currentScore >= nextScore); + } + } } From fe7c337a76f498f2fb7b7e1c501386554554235c Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Thu, 13 Mar 2025 16:57:46 -0700 Subject: [PATCH 07/13] updating release.yaml file for reranking --- .github/workflows/release.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ff566ad..6403202 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -11,6 +11,8 @@ on: env: MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf" MODEL_NAME: "codellama-7b.Q2_K.gguf" + RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf" + RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf" jobs: # todo: doesn't work with the newest llama.cpp version @@ -144,8 +146,10 @@ jobs: with: name: Linux-x86_64-libraries path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ - - name: Download model + - name: Download text generation model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - uses: actions/setup-java@v4 with: distribution: 'zulu' From 3d28a989ee7741715d1c593ab3282363185a72e4 Mon Sep 17 00:00:00 2001 From: Vaijanath Rao Date: Fri, 14 Mar 2025 02:36:21 -0700 Subject: [PATCH 08/13] adding support for messages. --- pom.xml | 8 +++- src/main/cpp/jllama.cpp | 14 ++++++ src/main/cpp/jllama.h | 7 +++ .../de/kherud/llama/InferenceParameters.java | 45 ++++++++++++++++++- src/main/java/de/kherud/llama/LlamaModel.java | 5 +++ .../java/de/kherud/llama/LlamaModelTest.java | 16 +++++++ 6 files changed, 92 insertions(+), 3 deletions(-) diff --git a/pom.xml b/pom.xml index fba7eb4..f4e1e45 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ de.kherud llama - 4.0.1 + 4.0.0 jar ${project.groupId}:${project.artifactId} @@ -65,7 +65,11 @@ 24.1.0 compile - + + com.fasterxml.jackson.core + jackson-databind + 2.16.0 + diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index b0242c3..a0aca71 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -786,6 +786,20 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo } +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams){ + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + + std::string c_params = parse_jstring(env, jparams); + json data = json::parse(c_params); + + json templateData = oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get()); + std::string tok_str = templateData.at("prompt"); + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + return jtok_str; +} + JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index 01e4d20..dc17fa8 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -91,6 +91,13 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammar */ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray); +/* + * Class: de_kherud_llama_LlamaModel + * Method: applyTemplate + * Signature: (Ljava/lang/String;)Ljava/lang/String;; + */ +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring); + #ifdef __cplusplus } #endif diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index 0ac1b1d..e868be0 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -1,8 +1,13 @@ package de.kherud.llama; import java.util.Collection; +import java.util.List; import java.util.Map; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; + import de.kherud.llama.args.MiroStat; import de.kherud.llama.args.Sampler; @@ -12,6 +17,9 @@ * {@link LlamaModel#complete(InferenceParameters)}. */ public final class InferenceParameters extends JsonParameters { + + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Reusable ObjectMapper + private static final String PARAM_PROMPT = "prompt"; private static final String PARAM_INPUT_PREFIX = "input_prefix"; @@ -47,6 +55,7 @@ public final class InferenceParameters extends JsonParameters { private static final String PARAM_STREAM = "stream"; private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template"; private static final String PARAM_USE_JINJA = "use_jinja"; + private static final String PARAM_MESSAGES = "messages"; public InferenceParameters(String prompt) { // we always need a prompt @@ -493,7 +502,41 @@ public InferenceParameters setUseChatTemplate(boolean useChatTemplate) { return this; } - + /** + * Set the messages for chat-based inference. + * - Allows **only one** system message. + * - Allows **one or more** user/assistant messages. + */ + public InferenceParameters setMessages(String systemMessage, List> messages) { + ArrayNode messagesArray = OBJECT_MAPPER.createArrayNode(); + + // Add system message (if provided) + if (systemMessage != null && !systemMessage.isEmpty()) { + ObjectNode systemObj = OBJECT_MAPPER.createObjectNode(); + systemObj.put("role", "system"); + systemObj.put("content", systemMessage); + messagesArray.add(systemObj); + } + + // Add user/assistant messages + for (Pair message : messages) { + String role = message.getKey(); + String content = message.getValue(); + + if (!role.equals("user") && !role.equals("assistant")) { + throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'."); + } + + ObjectNode messageObj = OBJECT_MAPPER.createObjectNode(); + messageObj.put("role", role); + messageObj.put("content", content); + messagesArray.add(messageObj); + } + + // Convert ArrayNode to a JSON string and store it in parameters + parameters.put(PARAM_MESSAGES, messagesArray.toString()); + return this; + } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 9ed86d0..eab3620 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -163,4 +163,9 @@ public List> rerank(boolean reRank, String query, String ... } public native LlamaOutput rerank(String query, String... documents); + + public String applyTemplate(InferenceParameters parameters) { + return applyTemplate(parameters.toString()); + } + public native String applyTemplate(String parametersJson); } diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index 6481f09..e3e69d8 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -316,4 +316,20 @@ public void testJsonSchemaToGrammar() { String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema); Assert.assertEquals(expectedGrammar, actualGrammar); } + + @Test + public void testTemplate() { + + List> userMessages = new ArrayList<>(); + userMessages.add(new Pair<>("user", "What is the best book?")); + userMessages.add(new Pair<>("assistant", "It depends on your interests. Do you like fiction or non-fiction?")); + + InferenceParameters params = new InferenceParameters("A book recommendation system.") + .setMessages("Book", userMessages) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setSeed(42); + Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?<|im_end|>\n<|im_start|>assistant\n"); + } } From 6e95f61d51afa629b8a998d34f3cc3c4eb623709 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Tue, 18 Mar 2025 21:01:25 +0100 Subject: [PATCH 09/13] reformat c++ code --- src/main/cpp/jllama.cpp | 159 ++++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 80 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index a0aca71..b9436b7 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -112,13 +112,15 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js return result; } -std::vector parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array, const jsize length) { +std::vector parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array, + const jsize length) { std::vector result; result.reserve(length); // Reserve memory for efficiency for (jsize i = 0; i < length; i++) { jstring javaString = static_cast(env->GetObjectArrayElement(string_array, i)); - if (javaString == nullptr) continue; + if (javaString == nullptr) + continue; const char *cString = env->GetStringUTFChars(javaString, nullptr); if (cString != nullptr) { @@ -259,7 +261,6 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { cc_integer = env->GetMethodID(c_integer, "", "(I)V"); cc_float = env->GetMethodID(c_float, "", "(F)V"); - if (!(cc_output && cc_hash_map && cc_integer && cc_float)) { goto error; } @@ -663,12 +664,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, env->ThrowNew(c_llama_error, response.c_str()); return nullptr; } - + if (result->is_stop()) { ctx_server->queue_results.remove_waiting_task_id(id_task); } - const auto out_res = result->to_json(); // Extract "embedding" as a vector of vectors (2D array) @@ -704,100 +704,99 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, return j_embedding; } -JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, jobjectArray documents) { +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, + jobjectArray documents) { jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) - if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { - env->ThrowNew(c_llama_error, + if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); - return nullptr; + return nullptr; } - const std::string prompt = parse_jstring(env, jprompt); - - const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true); - + json responses = json::array(); bool error = false; - - std::vector tasks; - const jsize argc = env->GetArrayLength(documents); - std::vector documentsArray = parse_string_array_for_rerank(env, documents, argc); - - std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true); - - tasks.reserve(tokenized_docs.size()); - for (size_t i = 0; i < tokenized_docs.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server->queue_tasks.get_new_id(); - task.index = i; - task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); - tasks.push_back(task); - } - ctx_server->queue_results.add_waiting_tasks(tasks); - ctx_server->queue_tasks.post(tasks); - - // get the result - std::unordered_set task_ids = server_task::get_list_id(tasks); - std::vector results(task_ids.size()); - - // Create a new HashMap instance - jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); - if (o_probabilities == nullptr) { - env->ThrowNew(c_llama_error, "Failed to create HashMap object."); - return nullptr; - } - - for (int i = 0; i < (int)task_ids.size(); i++) { - server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); - if (result->is_error()) { - std::string response = result->to_json()["message"].get(); - for (const int id_task : task_ids) { - ctx_server->queue_results.remove_waiting_task_id(id_task); - } - env->ThrowNew(c_llama_error, response.c_str()); - return nullptr; - } - - const auto out_res = result->to_json(); - - if (result->is_stop()) { - for (const int id_task : task_ids) { - ctx_server->queue_results.remove_waiting_task_id(id_task); - } - } - - int index = out_res["index"].get(); - float score = out_res["score"].get(); - std::string tok_str = documentsArray[index]; - jstring jtok_str = env->NewStringUTF(tok_str.c_str()); - - jobject jprob = env->NewObject(c_float, cc_float, score); - env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob); - env->DeleteLocalRef(jtok_str); - env->DeleteLocalRef(jprob); - } + + std::vector tasks; + const jsize argc = env->GetArrayLength(documents); + std::vector documentsArray = parse_string_array_for_rerank(env, documents, argc); + + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true); + + tasks.reserve(tokenized_docs.size()); + for (size_t i = 0; i < tokenized_docs.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + std::vector results(task_ids.size()); + + // Create a new HashMap instance + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + if (o_probabilities == nullptr) { + env->ThrowNew(c_llama_error, "Failed to create HashMap object."); + return nullptr; + } + + for (int i = 0; i < (int)task_ids.size(); i++) { + server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); + if (result->is_error()) { + std::string response = result->to_json()["message"].get(); + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } + + const auto out_res = result->to_json(); + + if (result->is_stop()) { + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + } + + int index = out_res["index"].get(); + float score = out_res["score"].get(); + std::string tok_str = documentsArray[index]; + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + jobject jprob = env->NewObject(c_float, cc_float, score); + env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob); + env->DeleteLocalRef(jtok_str); + env->DeleteLocalRef(jprob); + } jbyteArray jbytes = parse_jbytes(env, prompt); - return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true); - + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true); } -JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams){ - jlong server_handle = env->GetLongField(obj, f_model_pointer); +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) std::string c_params = parse_jstring(env, jparams); json data = json::parse(c_params); - - json templateData = oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get()); + + json templateData = + oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, + ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get()); std::string tok_str = templateData.at("prompt"); - jstring jtok_str = env->NewStringUTF(tok_str.c_str()); - - return jtok_str; + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + return jtok_str; } JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { From 986bddf63bd294c37d903d14906bed25ba95d6e9 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Tue, 18 Mar 2025 21:25:18 +0100 Subject: [PATCH 10/13] re-use parse_string_array for re-ranking --- src/main/cpp/jllama.cpp | 39 +++++++++------------------------------ 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index b9436b7..ac056b9 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -112,28 +112,6 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js return result; } -std::vector parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array, - const jsize length) { - std::vector result; - result.reserve(length); // Reserve memory for efficiency - - for (jsize i = 0; i < length; i++) { - jstring javaString = static_cast(env->GetObjectArrayElement(string_array, i)); - if (javaString == nullptr) - continue; - - const char *cString = env->GetStringUTFChars(javaString, nullptr); - if (cString != nullptr) { - result.emplace_back(cString); // Add to vector - env->ReleaseStringUTFChars(javaString, cString); - } - - env->DeleteLocalRef(javaString); // Avoid memory leaks - } - - return result; -} - void free_string_array(char **array, jsize length) { if (array != nullptr) { for (jsize i = 0; i < length; i++) { @@ -720,17 +698,18 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true); json responses = json::array(); - bool error = false; std::vector tasks; - const jsize argc = env->GetArrayLength(documents); - std::vector documentsArray = parse_string_array_for_rerank(env, documents, argc); + const jsize amount_documents = env->GetArrayLength(documents); + auto *document_array = parse_string_array(env, documents, amount_documents); + auto document_vector = std::vector(document_array, document_array + amount_documents); + free_string_array(document_array, amount_documents); - std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true); + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true); tasks.reserve(tokenized_docs.size()); - for (size_t i = 0; i < tokenized_docs.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_RERANK); + for (int i = 0; i < tokenized_docs.size(); i++) { + auto task = server_task(SERVER_TASK_TYPE_RERANK); task.id = ctx_server->queue_tasks.get_new_id(); task.index = i; task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); @@ -753,7 +732,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo for (int i = 0; i < (int)task_ids.size(); i++) { server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); if (result->is_error()) { - std::string response = result->to_json()["message"].get(); + auto response = result->to_json()["message"].get(); for (const int id_task : task_ids) { ctx_server->queue_results.remove_waiting_task_id(id_task); } @@ -771,7 +750,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo int index = out_res["index"].get(); float score = out_res["score"].get(); - std::string tok_str = documentsArray[index]; + std::string tok_str = document_vector[index]; jstring jtok_str = env->NewStringUTF(tok_str.c_str()); jobject jprob = env->NewObject(c_float, cc_float, score); From 62cc40eff9e322815b2c750b95215b78597dc099 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Tue, 18 Mar 2025 21:25:39 +0100 Subject: [PATCH 11/13] replace jackson with string builder --- pom.xml | 5 -- .../de/kherud/llama/InferenceParameters.java | 55 ++++++++++--------- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/pom.xml b/pom.xml index f4e1e45..4982f40 100644 --- a/pom.xml +++ b/pom.xml @@ -65,11 +65,6 @@ 24.1.0 compile - - com.fasterxml.jackson.core - jackson-databind - 2.16.0 - diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index e868be0..41f74cc 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -4,10 +4,6 @@ import java.util.List; import java.util.Map; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; - import de.kherud.llama.args.MiroStat; import de.kherud.llama.args.Sampler; @@ -16,10 +12,8 @@ * and * {@link LlamaModel#complete(InferenceParameters)}. */ +@SuppressWarnings("unused") public final class InferenceParameters extends JsonParameters { - - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Reusable ObjectMapper - private static final String PARAM_PROMPT = "prompt"; private static final String PARAM_INPUT_PREFIX = "input_prefix"; @@ -489,13 +483,8 @@ public InferenceParameters setSamplers(Sampler... samplers) { return this; } - InferenceParameters setStream(boolean stream) { - parameters.put(PARAM_STREAM, String.valueOf(stream)); - return this; - } - /** - * Set whether or not generate should apply a chat template (default: false) + * Set whether generate should apply a chat template (default: false) */ public InferenceParameters setUseChatTemplate(boolean useChatTemplate) { parameters.put(PARAM_USE_JINJA, String.valueOf(useChatTemplate)); @@ -508,18 +497,22 @@ public InferenceParameters setUseChatTemplate(boolean useChatTemplate) { * - Allows **one or more** user/assistant messages. */ public InferenceParameters setMessages(String systemMessage, List> messages) { - ArrayNode messagesArray = OBJECT_MAPPER.createArrayNode(); + StringBuilder messagesBuilder = new StringBuilder(); + messagesBuilder.append("["); // Add system message (if provided) if (systemMessage != null && !systemMessage.isEmpty()) { - ObjectNode systemObj = OBJECT_MAPPER.createObjectNode(); - systemObj.put("role", "system"); - systemObj.put("content", systemMessage); - messagesArray.add(systemObj); + messagesBuilder.append("{\"role\": \"system\", \"content\": ") + .append(toJsonString(systemMessage)) + .append("}"); + if (!messages.isEmpty()) { + messagesBuilder.append(", "); + } } // Add user/assistant messages - for (Pair message : messages) { + for (int i = 0; i < messages.size(); i++) { + Pair message = messages.get(i); String role = message.getKey(); String content = message.getValue(); @@ -527,17 +520,27 @@ public InferenceParameters setMessages(String systemMessage, List Date: Tue, 18 Mar 2025 21:29:57 +0100 Subject: [PATCH 12/13] update readme code examples --- README.md | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 32f555e..1990aac 100644 --- a/README.md +++ b/README.md @@ -94,8 +94,8 @@ public class Example { public static void main(String... args) throws IOException { ModelParameters modelParams = new ModelParameters() - .setModelFilePath("/path/to/model.gguf") - .setNGpuLayers(43); + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf") + .setGpuLayers(43); String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + @@ -114,8 +114,8 @@ public class Example { InferenceParameters inferParams = new InferenceParameters(prompt) .setTemperature(0.7f) .setPenalizeNl(true) - .setMirostat(InferenceParameters.MiroStat.V2) - .setAntiPrompt("\n"); + .setMiroStat(MiroStat.V2) + .setStopStrings("User:"); for (LlamaOutput output : model.generate(inferParams)) { System.out.print(output); prompt += output; @@ -135,7 +135,7 @@ model to your prompt in order to extend the context. If there is repeated conten cache this, to improve performance. ```java -ModelParameters modelParams = new ModelParameters().setModelFilePath("/path/to/model.gguf"); +ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf"); InferenceParameters inferParams = new InferenceParameters("Tell me a joke."); try (LlamaModel model = new LlamaModel(modelParams)) { // Stream a response and access more information about each output. @@ -167,9 +167,8 @@ for every inference task. All non-specified options have sensible defaults. ```java ModelParameters modelParams = new ModelParameters() - .setModelFilePath("/path/to/model.gguf") - .setLoraAdapter("/path/to/lora/adapter") - .setLoraBase("/path/to/lora/base"); + .setModel("/path/to/model.gguf") + .addLoraAdapter("/path/to/lora/adapter"); String grammar = """ root ::= (expr "=" term "\\n")+ expr ::= term ([-+*/] term)* From 1ad2bf6840fb6a2033f9b9a717031d7ca0e26259 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Tue, 18 Mar 2025 21:32:14 +0100 Subject: [PATCH 13/13] update to latest llama.cpp version --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2278d45..8f402fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4831 + GIT_TAG b4916 ) FetchContent_MakeAvailable(llama.cpp)