diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 631fc86..a15f809 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,6 +6,8 @@ on: env: MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf MODEL_NAME: codellama-7b.Q2_K.gguf + RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf + RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf jobs: build-and-test-linux: @@ -21,8 +23,12 @@ jobs: run: | mvn compile .github/build.sh -DLLAMA_VERBOSE=ON - - name: Download model + - name: Download text generation model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: List files in models directory + run: ls -l models/ - name: Run tests run: mvn test - if: failure() @@ -53,8 +59,12 @@ jobs: run: | mvn compile .github/build.sh ${{ matrix.target.cmake }} - - name: Download model + - name: Download text generation model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} + - name: List files in models directory + run: ls -l models/ - name: Run tests run: mvn test - if: failure() @@ -79,6 +89,10 @@ jobs: .github\build.bat -DLLAMA_VERBOSE=ON - name: Download model run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME + - name: Download reranking model + run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME + - name: List files in models directory + run: ls -l models/ - name: Run tests run: mvn test - if: failure() diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ff566ad..6403202 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml 
@@ -11,6 +11,8 @@ on: env: MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf" MODEL_NAME: "codellama-7b.Q2_K.gguf" + RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf" + RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf" jobs: # todo: doesn't work with the newest llama.cpp version @@ -144,8 +146,10 @@ jobs: with: name: Linux-x86_64-libraries path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/ - - name: Download model + - name: Download text generation model run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME} + - name: Download reranking model + run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME} - uses: actions/setup-java@v4 with: distribution: 'zulu' diff --git a/CMakeLists.txt b/CMakeLists.txt index 2278d45..8f402fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ set(LLAMA_BUILD_COMMON ON) FetchContent_Declare( llama.cpp GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git - GIT_TAG b4831 + GIT_TAG b4916 ) FetchContent_MakeAvailable(llama.cpp) diff --git a/README.md b/README.md index 32f555e..1990aac 100644 --- a/README.md +++ b/README.md @@ -94,8 +94,8 @@ public class Example { public static void main(String... 
args) throws IOException { ModelParameters modelParams = new ModelParameters() - .setModelFilePath("/path/to/model.gguf") - .setNGpuLayers(43); + .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf") + .setGpuLayers(43); String system = "This is a conversation between User and Llama, a friendly chatbot.\n" + "Llama is helpful, kind, honest, good at writing, and never fails to answer any " + @@ -114,8 +114,8 @@ public class Example { InferenceParameters inferParams = new InferenceParameters(prompt) .setTemperature(0.7f) .setPenalizeNl(true) - .setMirostat(InferenceParameters.MiroStat.V2) - .setAntiPrompt("\n"); + .setMiroStat(MiroStat.V2) + .setStopStrings("User:"); for (LlamaOutput output : model.generate(inferParams)) { System.out.print(output); prompt += output; @@ -135,7 +135,7 @@ model to your prompt in order to extend the context. If there is repeated conten cache this, to improve performance. ```java -ModelParameters modelParams = new ModelParameters().setModelFilePath("/path/to/model.gguf"); +ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf"); InferenceParameters inferParams = new InferenceParameters("Tell me a joke."); try (LlamaModel model = new LlamaModel(modelParams)) { // Stream a response and access more information about each output. @@ -167,9 +167,8 @@ for every inference task. All non-specified options have sensible defaults. ```java ModelParameters modelParams = new ModelParameters() - .setModelFilePath("/path/to/model.gguf") - .setLoraAdapter("/path/to/lora/adapter") - .setLoraBase("/path/to/lora/base"); + .setModel("/path/to/model.gguf") + .addLoraAdapter("/path/to/lora/adapter"); String grammar = """ root ::= (expr "=" term "\\n")+ expr ::= term ([-+*/] term)* diff --git a/pom.xml b/pom.xml index c081e19..4982f40 100644 --- a/pom.xml +++ b/pom.xml @@ -1,4 +1,5 @@ - 4.0.0 @@ -8,7 +9,8 @@ jar ${project.groupId}:${project.artifactId} - Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++. 
+ Java Bindings for llama.cpp - A Port of Facebook's LLaMA model + in C/C++. https://github.com/kherud/java-llama.cpp @@ -39,7 +41,8 @@ ossrh - https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ + + https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ @@ -71,17 +74,21 @@ maven-compiler-plugin 3.13.0 - + gpu compile - compile + + compile + -h src/main/cpp - ${project.build.outputDirectory}_cuda + + ${project.build.outputDirectory}_cuda @@ -98,10 +105,12 @@ copy-resources - ${project.build.outputDirectory}_cuda + + ${project.build.outputDirectory}_cuda - ${basedir}/src/main/resources_linux_cuda/ + + ${basedir}/src/main/resources_linux_cuda/ **/*.* @@ -176,7 +185,8 @@ maven-jar-plugin 3.4.2 - + cuda package @@ -185,7 +195,8 @@ cuda12-linux-x86-64 - ${project.build.outputDirectory}_cuda + + ${project.build.outputDirectory}_cuda diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index 0db026e..ac056b9 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -634,7 +634,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, json error = nullptr; server_task_result_ptr result = ctx_server->queue_results.recv(id_task); - ctx_server->queue_results.remove_waiting_task_id(id_task); json response_str = result->to_json(); if (result->is_error()) { @@ -644,6 +643,10 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, return nullptr; } + if (result->is_stop()) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + const auto out_res = result->to_json(); // Extract "embedding" as a vector of vectors (2D array) @@ -679,6 +682,102 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, return j_embedding; } +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, + jobjectArray documents) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = 
reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + + if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, + "This server does not support reranking. Start it with `--reranking` and without `--embedding`"); + return nullptr; + } + + const std::string prompt = parse_jstring(env, jprompt); + + const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true); + + json responses = json::array(); + + std::vector tasks; + const jsize amount_documents = env->GetArrayLength(documents); + auto *document_array = parse_string_array(env, documents, amount_documents); + auto document_vector = std::vector(document_array, document_array + amount_documents); + free_string_array(document_array, amount_documents); + + std::vector tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true); + + tasks.reserve(tokenized_docs.size()); + for (int i = 0; i < tokenized_docs.size(); i++) { + auto task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server->queue_tasks.get_new_id(); + task.index = i; + task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]); + tasks.push_back(task); + } + ctx_server->queue_results.add_waiting_tasks(tasks); + ctx_server->queue_tasks.post(tasks); + + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); + std::vector results(task_ids.size()); + + // Create a new HashMap instance + jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map); + if (o_probabilities == nullptr) { + env->ThrowNew(c_llama_error, "Failed to create HashMap object."); + return nullptr; + } + + for (int i = 0; i < (int)task_ids.size(); i++) { + server_task_result_ptr result = ctx_server->queue_results.recv(task_ids); + if (result->is_error()) { + auto response = result->to_json()["message"].get(); + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } 
+ env->ThrowNew(c_llama_error, response.c_str()); + return nullptr; + } + + const auto out_res = result->to_json(); + + if (result->is_stop()) { + for (const int id_task : task_ids) { + ctx_server->queue_results.remove_waiting_task_id(id_task); + } + } + + int index = out_res["index"].get(); + float score = out_res["score"].get(); + std::string tok_str = document_vector[index]; + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + jobject jprob = env->NewObject(c_float, cc_float, score); + env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob); + env->DeleteLocalRef(jtok_str); + env->DeleteLocalRef(jprob); + } + jbyteArray jbytes = parse_jbytes(env, prompt); + return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true); +} + +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) { + jlong server_handle = env->GetLongField(obj, f_model_pointer); + auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) + + std::string c_params = parse_jstring(env, jparams); + json data = json::parse(c_params); + + json templateData = + oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, + ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get()); + std::string tok_str = templateData.at("prompt"); + jstring jtok_str = env->NewStringUTF(tok_str.c_str()); + + return jtok_str; +} + JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) { jlong server_handle = env->GetLongField(obj, f_model_pointer); auto *ctx_server = reinterpret_cast(server_handle); // NOLINT(*-no-int-to-ptr) diff --git a/src/main/cpp/jllama.h b/src/main/cpp/jllama.h index 63d95b7..dc17fa8 100644 --- a/src/main/cpp/jllama.h +++ b/src/main/cpp/jllama.h @@ -84,6 +84,20 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *, job */ JNIEXPORT jbyteArray JNICALL 
Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *, jclass, jstring); +/* + * Class: de_kherud_llama_LlamaModel + * Method: rerank + * Signature: (Ljava/lang/String;[Ljava/lang/String;)Lde/kherud/llama/LlamaOutput; + */ +JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray); + +/* + * Class: de_kherud_llama_LlamaModel + * Method: applyTemplate + * Signature: (Ljava/lang/String;)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring); + #ifdef __cplusplus } #endif diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index 0ac1b1d..41f74cc 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -1,6 +1,7 @@ package de.kherud.llama; import java.util.Collection; +import java.util.List; import java.util.Map; import de.kherud.llama.args.MiroStat; @@ -11,6 +12,7 @@ * and * {@link LlamaModel#complete(InferenceParameters)}. */ +@SuppressWarnings("unused") public final class InferenceParameters extends JsonParameters { private static final String PARAM_PROMPT = "prompt"; @@ -47,6 +49,7 @@ public final class InferenceParameters extends JsonParameters { private static final String PARAM_STREAM = "stream"; private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template"; private static final String PARAM_USE_JINJA = "use_jinja"; + private static final String PARAM_MESSAGES = "messages"; public InferenceParameters(String prompt) { // we always need a prompt @@ -480,21 +483,64 @@ public InferenceParameters setSamplers(Sampler... 
samplers) { return this; } - InferenceParameters setStream(boolean stream) { - parameters.put(PARAM_STREAM, String.valueOf(stream)); - return this; - } - /** - * Set whether or not generate should apply a chat template (default: false) + * Set whether generate should apply a chat template (default: false) */ public InferenceParameters setUseChatTemplate(boolean useChatTemplate) { parameters.put(PARAM_USE_JINJA, String.valueOf(useChatTemplate)); return this; } - - - + /** + * Set the messages for chat-based inference. + * - Allows **only one** system message. + * - Allows **one or more** user/assistant messages. + */ + public InferenceParameters setMessages(String systemMessage, List> messages) { + StringBuilder messagesBuilder = new StringBuilder(); + messagesBuilder.append("["); + + // Add system message (if provided) + if (systemMessage != null && !systemMessage.isEmpty()) { + messagesBuilder.append("{\"role\": \"system\", \"content\": ") + .append(toJsonString(systemMessage)) + .append("}"); + if (!messages.isEmpty()) { + messagesBuilder.append(", "); + } + } + + // Add user/assistant messages + for (int i = 0; i < messages.size(); i++) { + Pair message = messages.get(i); + String role = message.getKey(); + String content = message.getValue(); + + if (!role.equals("user") && !role.equals("assistant")) { + throw new IllegalArgumentException("Invalid role: " + role + ". 
Role must be 'user' or 'assistant'."); + } + + messagesBuilder.append("{\"role\":") + .append(toJsonString(role)) + .append(", \"content\": ") + .append(toJsonString(content)) + .append("}"); + + if (i < messages.size() - 1) { + messagesBuilder.append(", "); + } + } + + messagesBuilder.append("]"); + + // Convert ArrayNode to a JSON string and store it in parameters + parameters.put(PARAM_MESSAGES, messagesBuilder.toString()); + return this; + } + + InferenceParameters setStream(boolean stream) { + parameters.put(PARAM_STREAM, String.valueOf(stream)); + return this; + } } diff --git a/src/main/java/de/kherud/llama/LlamaIterator.java b/src/main/java/de/kherud/llama/LlamaIterator.java index fdff993..cb1c5c2 100644 --- a/src/main/java/de/kherud/llama/LlamaIterator.java +++ b/src/main/java/de/kherud/llama/LlamaIterator.java @@ -35,6 +35,9 @@ public LlamaOutput next() { } LlamaOutput output = model.receiveCompletion(taskId); hasNext = !output.stop; + if (output.stop) { + model.releaseTask(taskId); + } return output; } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 7749b32..eab3620 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -5,6 +5,9 @@ import java.lang.annotation.Native; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; import java.util.function.BiConsumer; /** @@ -130,11 +133,39 @@ public void close() { private native void delete(); - private native void releaseTask(int taskId); + native void releaseTask(int taskId); private static native byte[] jsonSchemaToGrammarBytes(String schema); public static String jsonSchemaToGrammar(String schema) { return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8); } + + public List> rerank(boolean reRank, String query, String ... 
documents) { + LlamaOutput output = rerank(query, documents); + + Map scoredDocumentMap = output.probabilities; + + List> rankedDocuments = new ArrayList<>(); + + if (reRank) { + // Sort in descending order based on Float values + scoredDocumentMap.entrySet() + .stream() + .sorted((a, b) -> Float.compare(b.getValue(), a.getValue())) // Descending order + .forEach(entry -> rankedDocuments.add(new Pair<>(entry.getKey(), entry.getValue()))); + } else { + // Copy without sorting + scoredDocumentMap.forEach((key, value) -> rankedDocuments.add(new Pair<>(key, value))); + } + + return rankedDocuments; + } + + public native LlamaOutput rerank(String query, String... documents); + + public String applyTemplate(InferenceParameters parameters) { + return applyTemplate(parameters.toString()); + } + public native String applyTemplate(String parametersJson); } diff --git a/src/main/java/de/kherud/llama/Pair.java b/src/main/java/de/kherud/llama/Pair.java new file mode 100644 index 0000000..48ac648 --- /dev/null +++ b/src/main/java/de/kherud/llama/Pair.java @@ -0,0 +1,48 @@ +package de.kherud.llama; + +import java.util.Objects; + +public class Pair { + + private final K key; + private final V value; + + public Pair(K key, V value) { + this.key = key; + this.value = value; + } + + public K getKey() { + return key; + } + + public V getValue() { + return value; + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Pair other = (Pair) obj; + return Objects.equals(key, other.key) && Objects.equals(value, other.value); + } + + @Override + public String toString() { + return "Pair [key=" + key + ", value=" + value + "]"; + } + + + + +} diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index f2e931b..e3e69d8 100644 --- 
a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -158,6 +158,26 @@ public void testEmbedding() { float[] embedding = model.embed(prefix); Assert.assertEquals(4096, embedding.length); } + + + @Ignore + /** + * To run this test download the model from here https://huggingface.co/mradermacher/jina-reranker-v1-tiny-en-GGUF/tree/main + * remove .enableEmbedding() from model setup and add .enableReRanking() and then enable the test. + */ + public void testReRanking() { + + String query = "Machine learning is"; + String [] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." 
+ }; + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], TEST_DOCUMENTS[3] ); + + System.out.println(llamaOutput); + } @Test public void testTokenization() { @@ -296,4 +316,20 @@ public void testJsonSchemaToGrammar() { String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema); Assert.assertEquals(expectedGrammar, actualGrammar); } + + @Test + public void testTemplate() { + + List> userMessages = new ArrayList<>(); + userMessages.add(new Pair<>("user", "What is the best book?")); + userMessages.add(new Pair<>("assistant", "It depends on your interests. Do you like fiction or non-fiction?")); + + InferenceParameters params = new InferenceParameters("A book recommendation system.") + .setMessages("Book", userMessages) + .setTemperature(0.95f) + .setStopStrings("\"\"\"") + .setNPredict(nPredict) + .setSeed(42); + Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?<|im_end|>\n<|im_start|>assistant\n"); + } } diff --git a/src/test/java/de/kherud/llama/RerankingModelTest.java b/src/test/java/de/kherud/llama/RerankingModelTest.java new file mode 100644 index 0000000..60d32bd --- /dev/null +++ b/src/test/java/de/kherud/llama/RerankingModelTest.java @@ -0,0 +1,83 @@ +package de.kherud.llama; + +import java.util.List; +import java.util.Map; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class RerankingModelTest { + + private static LlamaModel model; + + String query = "Machine learning is"; + String[] TEST_DOCUMENTS = new String[] { + "A machine is a physical system that uses power to apply forces and control movement to perform an action. 
The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", + "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", + "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." }; + + @BeforeClass + public static void setup() { + model = new LlamaModel( + new ModelParameters().setCtxSize(128).setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf") + .setGpuLayers(43).enableReranking().enableLogTimestamps().enableLogPrefix()); + } + + @AfterClass + public static void tearDown() { + if (model != null) { + model.close(); + } + } + + @Test + public void testReRanking() { + + + LlamaOutput llamaOutput = model.rerank(query, TEST_DOCUMENTS[0], TEST_DOCUMENTS[1], TEST_DOCUMENTS[2], + TEST_DOCUMENTS[3]); + + Map rankedDocumentsMap = llamaOutput.probabilities; + Assert.assertTrue(rankedDocumentsMap.size()==TEST_DOCUMENTS.length); + + // Finding the most and least relevant documents + String mostRelevantDoc = null; + String leastRelevantDoc = null; + float maxScore = Float.MIN_VALUE; + float minScore = Float.MAX_VALUE; + + for (Map.Entry entry : rankedDocumentsMap.entrySet()) { + if (entry.getValue() > maxScore) { + maxScore = entry.getValue(); + mostRelevantDoc = entry.getKey(); + } + if (entry.getValue() < minScore) { + minScore = 
entry.getValue(); + leastRelevantDoc = entry.getKey(); + } + } + + // Assertions + Assert.assertTrue(maxScore > minScore); + Assert.assertEquals("Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", mostRelevantDoc); + Assert.assertEquals("Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.", leastRelevantDoc); + + + } + + @Test + public void testSortedReRanking() { + List> rankedDocuments = model.rerank(true, query, TEST_DOCUMENTS); + Assert.assertEquals(rankedDocuments.size(), TEST_DOCUMENTS.length); + + // Check the ranking order: each score should be >= the next one + for (int i = 0; i < rankedDocuments.size() - 1; i++) { + float currentScore = rankedDocuments.get(i).getValue(); + float nextScore = rankedDocuments.get(i + 1).getValue(); + Assert.assertTrue("Ranking order incorrect at index " + i, currentScore >= nextScore); + } + } +}