
Commit 5c34026

Merge pull request #96 from vaiju1981/re_rank

- update to llama.cpp b4916
- support re-ranking
- expose applying chat-template

2 parents: ca148c8 + 1ad2bf6

File tree

13 files changed: +421 -33 lines

.github/workflows/ci.yml

Lines changed: 16 additions & 2 deletions
@@ -6,6 +6,8 @@ on:
 env:
   MODEL_URL: https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf
   MODEL_NAME: codellama-7b.Q2_K.gguf
+  RERANKING_MODEL_URL: https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf
+  RERANKING_MODEL_NAME: jina-reranker-v1-tiny-en-Q4_0.gguf
 jobs:
 
   build-and-test-linux:
@@ -21,8 +23,12 @@ jobs:
         run: |
           mvn compile
           .github/build.sh -DLLAMA_VERBOSE=ON
-      - name: Download model
+      - name: Download text generation model
        run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+      - name: Download reranking model
+        run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+      - name: List files in models directory
+        run: ls -l models/
       - name: Run tests
         run: mvn test
       - if: failure()
@@ -53,8 +59,12 @@ jobs:
         run: |
           mvn compile
           .github/build.sh ${{ matrix.target.cmake }}
-      - name: Download model
+      - name: Download text generation model
        run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+      - name: Download reranking model
+        run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
+      - name: List files in models directory
+        run: ls -l models/
       - name: Run tests
         run: mvn test
       - if: failure()
@@ -79,6 +89,10 @@ jobs:
           .github\build.bat -DLLAMA_VERBOSE=ON
       - name: Download model
         run: curl -L $env:MODEL_URL --create-dirs -o models/$env:MODEL_NAME
+      - name: Download reranking model
+        run: curl -L $env:RERANKING_MODEL_URL --create-dirs -o models/$env:RERANKING_MODEL_NAME
+      - name: List files in models directory
+        run: ls -l models/
       - name: Run tests
         run: mvn test
       - if: failure()

.github/workflows/release.yaml

Lines changed: 5 additions & 1 deletion
@@ -11,6 +11,8 @@ on:
 env:
   MODEL_URL: "https://huggingface.co/TheBloke/CodeLlama-7B-GGUF/resolve/main/codellama-7b.Q2_K.gguf"
   MODEL_NAME: "codellama-7b.Q2_K.gguf"
+  RERANKING_MODEL_URL: "https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-Q4_0.gguf"
+  RERANKING_MODEL_NAME: "jina-reranker-v1-tiny-en-Q4_0.gguf"
 jobs:
 
   # todo: doesn't work with the newest llama.cpp version
@@ -144,8 +146,10 @@ jobs:
         with:
           name: Linux-x86_64-libraries
           path: ${{ github.workspace }}/src/main/resources/de/kherud/llama/
-      - name: Download model
+      - name: Download text generation model
        run: curl -L ${MODEL_URL} --create-dirs -o models/${MODEL_NAME}
+      - name: Download reranking model
+        run: curl -L ${RERANKING_MODEL_URL} --create-dirs -o models/${RERANKING_MODEL_NAME}
       - uses: actions/setup-java@v4
         with:
           distribution: 'zulu'

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ set(LLAMA_BUILD_COMMON ON)
 FetchContent_Declare(
     llama.cpp
     GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-    GIT_TAG b4831
+    GIT_TAG b4916
 )
 FetchContent_MakeAvailable(llama.cpp)

README.md

Lines changed: 7 additions & 8 deletions
@@ -94,8 +94,8 @@ public class Example {
 
     public static void main(String... args) throws IOException {
         ModelParameters modelParams = new ModelParameters()
-                .setModelFilePath("/path/to/model.gguf")
-                .setNGpuLayers(43);
+                .setModel("models/mistral-7b-instruct-v0.2.Q2_K.gguf")
+                .setGpuLayers(43);
 
         String system = "This is a conversation between User and Llama, a friendly chatbot.\n" +
                 "Llama is helpful, kind, honest, good at writing, and never fails to answer any " +
@@ -114,8 +114,8 @@ public class Example {
         InferenceParameters inferParams = new InferenceParameters(prompt)
                 .setTemperature(0.7f)
                 .setPenalizeNl(true)
-                .setMirostat(InferenceParameters.MiroStat.V2)
-                .setAntiPrompt("\n");
+                .setMiroStat(MiroStat.V2)
+                .setStopStrings("User:");
         for (LlamaOutput output : model.generate(inferParams)) {
             System.out.print(output);
             prompt += output;
@@ -135,7 +135,7 @@ model to your prompt in order to extend the context. If there is repeated content,
 cache this, to improve performance.
 
 ```java
-ModelParameters modelParams = new ModelParameters().setModelFilePath("/path/to/model.gguf");
+ModelParameters modelParams = new ModelParameters().setModel("/path/to/model.gguf");
 InferenceParameters inferParams = new InferenceParameters("Tell me a joke.");
 try (LlamaModel model = new LlamaModel(modelParams)) {
     // Stream a response and access more information about each output.
@@ -167,9 +167,8 @@ for every inference task. All non-specified options have sensible defaults.
 
 ```java
 ModelParameters modelParams = new ModelParameters()
-        .setModelFilePath("/path/to/model.gguf")
-        .setLoraAdapter("/path/to/lora/adapter")
-        .setLoraBase("/path/to/lora/base");
+        .setModel("/path/to/model.gguf")
+        .addLoraAdapter("/path/to/lora/adapter");
 String grammar = """
     root ::= (expr "=" term "\\n")+
     expr ::= term ([-+*/] term)*
pom.xml

Lines changed: 21 additions & 10 deletions
@@ -1,4 +1,5 @@
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <modelVersion>4.0.0</modelVersion>
 
@@ -8,7 +9,8 @@
     <packaging>jar</packaging>
 
     <name>${project.groupId}:${project.artifactId}</name>
-    <description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model in C/C++.</description>
+    <description>Java Bindings for llama.cpp - A Port of Facebook's LLaMA model
+        in C/C++.</description>
     <url>https://github.com/kherud/java-llama.cpp</url>
 
     <licenses>
@@ -39,7 +41,8 @@
         </snapshotRepository>
         <repository>
             <id>ossrh</id>
-            <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+            <url>
+                https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
         </repository>
     </distributionManagement>
 
@@ -71,17 +74,21 @@
                 <artifactId>maven-compiler-plugin</artifactId>
                 <version>3.13.0</version>
                 <executions>
-                    <!-- We have to perform a separate build pass for cuda classifier -->
+                    <!-- We have to perform a separate build pass for cuda
+                        classifier -->
                     <execution>
                         <id>gpu</id>
                         <phase>compile</phase>
-                        <goals><goal>compile</goal></goals>
+                        <goals>
+                            <goal>compile</goal>
+                        </goals>
                         <configuration>
                             <compilerArgs>
                                 <arg>-h</arg>
                                 <arg>src/main/cpp</arg>
                             </compilerArgs>
-                            <outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
+                            <outputDirectory>
+                                ${project.build.outputDirectory}_cuda</outputDirectory>
                         </configuration>
                     </execution>
                 </executions>
@@ -98,10 +105,12 @@
                             <goal>copy-resources</goal>
                         </goals>
                         <configuration>
-                            <outputDirectory>${project.build.outputDirectory}_cuda</outputDirectory>
+                            <outputDirectory>
+                                ${project.build.outputDirectory}_cuda</outputDirectory>
                             <resources>
                                 <resource>
-                                    <directory>${basedir}/src/main/resources_linux_cuda/</directory>
+                                    <directory>
+                                        ${basedir}/src/main/resources_linux_cuda/</directory>
                                     <includes>
                                         <include>**/*.*</include>
                                     </includes>
@@ -176,7 +185,8 @@
                 <artifactId>maven-jar-plugin</artifactId>
                 <version>3.4.2</version>
                 <executions>
-                    <!-- Pick class files AND libs from custom output directory -->
+                    <!-- Pick class files AND libs from custom output
+                        directory -->
                     <execution>
                         <id>cuda</id>
                         <phase>package</phase>
@@ -185,7 +195,8 @@
                         </goals>
                         <configuration>
                             <classifier>cuda12-linux-x86-64</classifier>
-                            <classesDirectory>${project.build.outputDirectory}_cuda</classesDirectory>
+                            <classesDirectory>
+                                ${project.build.outputDirectory}_cuda</classesDirectory>
                         </configuration>
                     </execution>
                 </executions>

src/main/cpp/jllama.cpp

Lines changed: 100 additions & 1 deletion
@@ -634,7 +634,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
     json error = nullptr;
 
     server_task_result_ptr result = ctx_server->queue_results.recv(id_task);
-    ctx_server->queue_results.remove_waiting_task_id(id_task);
 
     json response_str = result->to_json();
     if (result->is_error()) {
@@ -644,6 +643,10 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
         return nullptr;
     }
 
+    if (result->is_stop()) {
+        ctx_server->queue_results.remove_waiting_task_id(id_task);
+    }
+
     const auto out_res = result->to_json();
 
     // Extract "embedding" as a vector of vectors (2D array)
@@ -679,6 +682,102 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
     return j_embedding;
 }
 
+JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt,
+                                                                 jobjectArray documents) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
+        env->ThrowNew(c_llama_error,
+                      "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
+        return nullptr;
+    }
+
+    const std::string prompt = parse_jstring(env, jprompt);
+
+    const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);
+
+    json responses = json::array();
+
+    std::vector<server_task> tasks;
+    const jsize amount_documents = env->GetArrayLength(documents);
+    auto *document_array = parse_string_array(env, documents, amount_documents);
+    auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
+    free_string_array(document_array, amount_documents);
+
+    std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true);
+
+    tasks.reserve(tokenized_docs.size());
+    for (int i = 0; i < tokenized_docs.size(); i++) {
+        auto task = server_task(SERVER_TASK_TYPE_RERANK);
+        task.id = ctx_server->queue_tasks.get_new_id();
+        task.index = i;
+        task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
+        tasks.push_back(task);
+    }
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
+
+    // get the result
+    std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+    std::vector<server_task_result_ptr> results(task_ids.size());
+
+    // Create a new HashMap instance
+    jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
+    if (o_probabilities == nullptr) {
+        env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
+        return nullptr;
+    }
+
+    for (int i = 0; i < (int)task_ids.size(); i++) {
+        server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
+        if (result->is_error()) {
+            auto response = result->to_json()["message"].get<std::string>();
+            for (const int id_task : task_ids) {
+                ctx_server->queue_results.remove_waiting_task_id(id_task);
+            }
+            env->ThrowNew(c_llama_error, response.c_str());
+            return nullptr;
+        }
+
+        const auto out_res = result->to_json();
+
+        if (result->is_stop()) {
+            for (const int id_task : task_ids) {
+                ctx_server->queue_results.remove_waiting_task_id(id_task);
+            }
+        }
+
+        int index = out_res["index"].get<int>();
+        float score = out_res["score"].get<float>();
+        std::string tok_str = document_vector[index];
+        jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+        jobject jprob = env->NewObject(c_float, cc_float, score);
+        env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
+        env->DeleteLocalRef(jtok_str);
+        env->DeleteLocalRef(jprob);
+    }
+    jbyteArray jbytes = parse_jbytes(env, prompt);
+    return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
+}
+
+JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
+    jlong server_handle = env->GetLongField(obj, f_model_pointer);
+    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
+
+    std::string c_params = parse_jstring(env, jparams);
+    json data = json::parse(c_params);
+
+    json templateData =
+        oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
+                                          ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
+    std::string tok_str = templateData.at("prompt");
+    jstring jtok_str = env->NewStringUTF(tok_str.c_str());
+
+    return jtok_str;
+}
+
 JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
     jlong server_handle = env->GetLongField(obj, f_model_pointer);
     auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
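
For orientation, a minimal Java-side sketch of how the two new natives might be called. The method names follow the JNI symbols above (Java_de_kherud_llama_LlamaModel_rerank and ..._applyTemplate); the wrapper signatures, the LlamaOutput return type with its score map, the reranking toggle, and the JSON shape passed to applyTemplate are assumptions, not shown in this diff:

```java
import de.kherud.llama.LlamaModel;
import de.kherud.llama.LlamaOutput;
import de.kherud.llama.ModelParameters;

public class RerankSketch {
    public static void main(String[] args) {
        // The native code above rejects rerank calls unless the server was
        // started with reranking enabled and embedding disabled.
        ModelParameters params = new ModelParameters()
                .setModel("models/jina-reranker-v1-tiny-en-Q4_0.gguf") // model from the CI config above
                .enableReranking(); // hypothetical setter for the `--reranking` flag
        try (LlamaModel model = new LlamaModel(params)) {
            // rerank scores every document against the query; the native side
            // returns the prompt bytes plus a document -> score map.
            LlamaOutput result = model.rerank("Which planet is closest to the sun?",
                    "Mercury is the closest planet to the sun.",
                    "Jupiter is the largest planet.");
            result.probabilities.forEach((doc, score) -> System.out.println(score + "  " + doc));

            // applyTemplate receives the inference parameters serialized as JSON
            // (a single jstring argument) and returns the rendered chat prompt.
            String prompt = model.applyTemplate(
                    "{\"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}");
            System.out.println(prompt);
        }
    }
}
```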

src/main/cpp/jllama.h

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default.
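
The 14 added lines are not shown, but they correspond to the generated prototypes for the two natives defined in jllama.cpp above. On the Java side, the matching declarations would plausibly look like this (a sketch derived from the JNI symbol names and argument types; the parameter names and the LlamaOutput return type are assumptions):

```java
// Sketch of the Java natives behind the new JNI entry points; not the
// verbatim declarations from this commit.
public class LlamaModel implements AutoCloseable {
    // Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *, jobject, jstring, jobjectArray)
    native LlamaOutput rerank(String query, String... documents);

    // Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *, jobject, jstring)
    native String applyTemplate(String parametersJson);

    @Override
    public void close() { /* releases the native server_context */ }
}
```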
