adding re-ranking #96


Merged · 14 commits · Mar 18, 2025
18 changes: 16 additions & 2 deletions .github/workflows/ci.yml
@@ -6,6 +6,8 @@ on:
env:
MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf
MODEL_NAME: codellama-7b.Q2_K.gguf
RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf
RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf
jobs:

build-and-test-linux:
@@ -21,8 +23,12 @@ jobs:
run: |
mvn compile
.github/build.sh -DLLAMA_VERBOSE=ON
- name: Download model
- name: Download text generation model
run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
- name: Download reranking model
run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
- name: List files in models directory
run: ls -l models/
- name: Run tests
run: mvn test
- if: failure()
@@ -53,8 +59,12 @@ jobs:
run: |
mvn compile
.github/build.sh ${{ matrix.target.cmake }}
- name: Download model
- name: Download text generation model
run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
- name: Download reranking model
run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
- name: List files in models directory
run: ls -l models/
- name: Run tests
run: mvn test
- if: failure()
@@ -79,6 +89,10 @@ jobs:
.github\build.bat -DLLAMA_VERBOSE=ON
- name: Download model
run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME
- name: Download reranking model
run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME
- name: List files in models directory
run: ls -l models/
- name: Run tests
run: mvn test
- if: failure()
6 changes: 5 additions & 1 deletion .github/workflows/release.yaml
@@ -11,6 +11,8 @@ on:
env:
MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf"
MODEL_NAME: "codellama-7b.Q2_K.gguf"
RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf"
RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf"
jobs:

# todo: doesn't work with the newest llama.cpp version
@@ -144,8 +146,10 @@ jobs:
with:
name: Linux-x86_64-libraries
path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
- name: Download model
- name: Download text generation model
run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
- name: Download reranking model
run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
- uses: actions/setup-java@v4
with:
distribution: 'zulu'
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -25,7 +25,7 @@ set(LLAMA_BUILD_COMMON ON)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
GIT_TAG b4831
GIT_TAG b4916
)
FetchContent_MakeAvailable(llama.cpp)

15 changes: 7 additions & 8 deletions README.md
@@ -94,8 +94,8 @@ public class Example {

public static void main(String... args) throws IOException {
ModelParameters modelParams = new ModelParameters()
.setModelFilePath("/path/to/model.gguf")
.setNGpuLayers(43);
.setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf")
.setGpuLayers(43);

String system = "This is a conversation between User and Llama, a friendly chatbot.\n" +
"Llama is helpful, kind, honest, good at writing, and never fails to answer any " +
@@ -114,8 +114,8 @@ public class Example {
InferenceParameters inferParams = new InferenceParameters(prompt)
.setTemperature(0.7f)
.setPenalizeNl(true)
.setMirostat(InferenceParameters.MiroStat.V2)
.setAntiPrompt("\n");
.setMiroStat(MiroStat.V2)
.setStopStrings("User:");
for (LlamaOutput output : model.generate(inferParams)) {
System.out.print(output);
prompt += output;
@@ -135,7 +135,7 @@ model to your prompt in order to extend the context. If there is repeated content
cache this, to improve performance.

```java
ModelParameters modelParams = new ModelParameters().setModelFilePath("/path/to/model.gguf");
ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf");
InferenceParameters inferParams = new InferenceParameters("Tell me a joke.");
try (LlamaModel model = new LlamaModel(modelParams)) {
// Stream a response and access more information about each output.
@@ -167,9 +167,8 @@ for every inference task. All non-specified options have sensible defaults.

```java
ModelParameters modelParams = new ModelParameters()
.setModelFilePath("/path/to/model.gguf")
.setLoraAdapter("/path/to/lora/adapter")
.setLoraBase("/path/to/lora/base");
.setModel("/path/to/model.gguf")
.addLoraAdapter("/path/to/lora/adapter");
String grammar = """
root ::= (expr "=" term "\\n")+
expr ::= term ([-+*/] term)*
31 changes: 21 additions & 10 deletions pom.xml
@@ -1,4 +1,5 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

@@ -8,7 +9,8 @@
<packaging>jar</packaging>

<name>${project.groupId}:${project.artifactId}</name>
<description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++.</description>
<description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model
in C/C++.</description>
<url>https://github.com/kherud/java-llama.cpp</url>

<licenses>
@@ -39,7 +41,8 @@
</snapshotRepository>
<repository>
<id>ossrh</id>
<url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
<url>
https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
</repository>
</distributionManagement>

@@ -71,17 +74,21 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>3.13.0</version>
<executions>
<!-- We have to perform a separate build pass for cuda classifier -->
<!-- We have to perform a separate build pass for cuda
classifier -->
<execution>
<id>gpu</id>
<phase>compile</phase>
<goals><goal>compile</goal></goals>
<goals>
<goal>compile</goal>
</goals>
<configuration>
<compilerArgs>
<arg>-h</arg>
<arg>src/main/cpp</arg>
</compilerArgs>
<outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
<outputDirectory>
${project.build.outputDirectory}_cuda</outputDirectory>
</configuration>
</execution>
</executions>
@@ -98,10 +105,12 @@
<goal>copy-resources</goal>
</goals>
<configuration>
<outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
<outputDirectory>
${project.build.outputDirectory}_cuda</outputDirectory>
<resources>
<resource>
<directory>${basedir}/src/main/resources_linux_cuda/</directory>
<directory>
${basedir}/src/main/resources_linux_cuda/</directory>
<includes>
<include>**/*.*</include>
</includes>
@@ -176,7 +185,8 @@
<artifactId>maven-jar-plugin</artifactId>
<version>3.4.2</version>
<executions>
<!-- Pick class files AND libs from custom output directory -->
<!-- Pick class files AND libs from custom output
directory -->
<execution>
<id>cuda</id>
<phase>package</phase>
@@ -185,7 +195,8 @@
</goals>
<configuration>
<classifier>cuda12-linux-x86-64</classifier>
<classesDirectory>${project.build.outputDirectory}_cuda</classesDirectory>
<classesDirectory>
${project.build.outputDirectory}_cuda</classesDirectory>
</configuration>
</execution>
</executions>
101 changes: 100 additions & 1 deletion src/main/cpp/jllama.cpp
@@ -634,7 +634,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
json error = nullptr;

server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
ctx_server->queue_results.remove_waiting_task_id(id_task);

json response_str = result->to_json();
if (result->is_error()) {
@@ -644,6 +643,10 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
return nullptr;
}

if (result->is_stop()) {
ctx_server->queue_results.remove_waiting_task_id(id_task);
}

const auto out_res = result->to_json();

// Extract "embedding" as a vector of vectors (2D array)
@@ -679,6 +682,102 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
return j_embedding;
}

JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt,
jobjectArray documents) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
env->ThrowNew(c_llama_error,
"This server does not support reranking. Start it with `--reranking` and without `--embedding`");
return nullptr;
}

const std::string prompt = parse_jstring(env, jprompt);

const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);

json responses = json::array();

std::vector<server_task> tasks;
const jsize amount_documents = env->GetArrayLength(documents);
auto *document_array = parse_string_array(env, documents, amount_documents);
auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
free_string_array(document_array, amount_documents);

std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true);

tasks.reserve(tokenized_docs.size());
for (int i = 0; i < tokenized_docs.size(); i++) {
auto task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server->queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
tasks.push_back(task);
}
ctx_server->queue_results.add_waiting_tasks(tasks);
ctx_server->queue_tasks.post(tasks);

// get the result
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
std::vector<server_task_result_ptr> results(task_ids.size());

// Create a new HashMap instance
jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
if (o_probabilities == nullptr) {
env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
return nullptr;
}

for (int i = 0; i < (int)task_ids.size(); i++) {
server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
if (result->is_error()) {
auto response = result->to_json()["message"].get<std::string>();
for (const int id_task : task_ids) {
ctx_server->queue_results.remove_waiting_task_id(id_task);
}
env->ThrowNew(c_llama_error, response.c_str());
return nullptr;
}

const auto out_res = result->to_json();

if (result->is_stop()) {
for (const int id_task : task_ids) {
ctx_server->queue_results.remove_waiting_task_id(id_task);
}
}

int index = out_res["index"].get<int>();
float score = out_res["score"].get<float>();
std::string tok_str = document_vector[index];
jstring jtok_str = env->NewStringUTF(tok_str.c_str());

jobject jprob = env->NewObject(c_float, cc_float, score);
env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
env->DeleteLocalRef(jtok_str);
env->DeleteLocalRef(jprob);
}
jbyteArray jbytes = parse_jbytes(env, prompt);
return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
}

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

std::string c_params = parse_jstring(env, jparams);
json data = json::parse(c_params);

json templateData =
oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
std::string tok_str = templateData.at("prompt");
jstring jtok_str = env->NewStringUTF(tok_str.c_str());

return jtok_str;
}

JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
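The new `rerank` binding above queues one `SERVER_TASK_TYPE_RERANK` task per document and returns the scores as a map keyed by document text. For readers wiring this up from Java, here is a minimal sketch of the intended call pattern. It is hedged: the `rerank(String, String[])` wrapper and the public `probabilities` field on `LlamaOutput` are inferred from the JNI signature and the `c_output` constructor call above, and `enableReranking()` is a hypothetical `ModelParameters` setter standing in for llama.cpp's `--reranking` flag.

```java
import java.util.Map;

import de.kherud.llama.LlamaModel;
import de.kherud.llama.LlamaOutput;
import de.kherud.llama.ModelParameters;

public class RerankExample {

    public static void main(String... args) {
        // setModel(...) matches the README changes in this PR;
        // enableReranking() is assumed and mirrors the `--reranking` server flag.
        ModelParameters modelParams = new ModelParameters()
                .setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf")
                .enableReranking();

        String query = "Machine learning is";
        String[] documents = {
                "A machine is a physical system that uses power to apply forces.",
                "Machine learning is a field of study in artificial intelligence."
        };

        try (LlamaModel model = new LlamaModel(modelParams)) {
            // The native code fills the probabilities map with (document, score)
            // pairs, one relevance score per input document.
            LlamaOutput output = model.rerank(query, documents);
            for (Map.Entry<String, Float> entry : output.probabilities.entrySet()) {
                System.out.println(entry.getValue() + ": " + entry.getKey());
            }
        }
    }
}
```

Note that scores come back keyed by document string rather than by index, so duplicate documents would collide in the map, an inherent consequence of the `HashMap` keying in the native code.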
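`applyTemplate` is simpler: it parses an OpenAI-style chat-completion request via `oaicompat_completion_params_parse` and returns the prompt string rendered by the model's chat template. A hedged sketch, continuing the example above and assuming the Java wrapper is a plain `String applyTemplate(String json)` mirroring the JNI symbol:

```java
// Field names follow the OpenAI chat-completion shape that
// oaicompat_completion_params_parse expects.
String request = "{\"messages\": ["
        + "{\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},"
        + "{\"role\": \"user\", \"content\": \"What time is it?\"}"
        + "]}";
// Returns the flat prompt produced by the model's chat template.
String prompt = model.applyTemplate(request);
```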
14 changes: 14 additions & 0 deletions src/main/cpp/jllama.h
