
Commit fda1a97

HanClinto authored and ggerganov committed

Updated Swift and Android bindings to use the new llama_sampling_* refactor from #8643 (#8651)

1 parent 32a7dc9 commit fda1a97
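
The change is mechanical: each llama_sample_* entry point that previously took the llama_context now has a llama_sampling_* counterpart that takes a sampler handle fetched from the context. A minimal sketch of the migration pattern in Swift, using the names that appear in the diffs below (candidates_p construction elided):

// Old API: sampling calls took the context directly.
// let new_token_id = llama_sample_token(context, &candidates_p)

// New API (#8643): fetch the sampler handle once from the context...
let smpl = llama_get_sampling(context)

// ...then route every sampling call through it.
llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
llama_sampling_temp(smpl, &candidates_p, temp)
let new_token_id = llama_sampling_sample(smpl, &candidates_p)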

File tree

4 files changed: +15 -9 lines changed

Diff for: examples/batched.swift/Sources/main.swift (+8 -6)

@@ -44,6 +44,8 @@ context_params.n_threads = 8
 context_params.n_threads_batch = 8
 
 let context = llama_new_context_with_model(model, context_params)
+let smpl = llama_get_sampling(context)
+
 guard context != nil else {
     print("Failed to initialize context")
     exit(1)
@@ -144,13 +146,13 @@ while n_cur <= n_len {
     let top_p: Float = 0.9
     let temp: Float = 0.4
 
-    llama_sample_top_k(context, &candidates_p, top_k, 1)
-    llama_sample_top_p(context, &candidates_p, top_p, 1)
-    llama_sample_temp(context, &candidates_p, temp)
+    llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
+    llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
+    llama_sampling_temp(smpl, &candidates_p, temp)
 
-    let new_token_id = llama_sample_token(context, &candidates_p)
+    let new_token_id = llama_sampling_sample(smpl, &candidates_p)
 
-    // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+    // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
 
     // is it an end of stream? -> mark the stream as finished
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()
 
 print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
 
-llama_print_timings(context)
+llama_print_timings(context, smpl, nil)
 
 private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
     let utf8Count = text.utf8.count
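
A side note on the last hunk: llama_print_timings now takes the sampler in addition to the context. The nil third argument is reproduced exactly as it appears in the diff; its meaning is not shown in this commit.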

Diff for: examples/llama.android/llama/src/main/cpp/llama-android.cpp (+2 -1)

@@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
+    const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
     const auto model = llama_get_model(context);
 
@@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
 
     // sample the most likely token
-    const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+    const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
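
Design note: unlike the Swift actor in the next file, the JNI binding does not cache a sampler handle; completion_loop re-derives it from the context on every call via llama_get_sampling, which keeps the change confined to this single function.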

Diff for: examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (+4 -1)

@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
+    private var sampling: OpaquePointer
     private var batch: llama_batch
     private var tokens_list: [llama_token]
     var is_done: Bool = false
@@ -42,12 +43,14 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
+        self.sampling = llama_get_sampling(context)
     }
 
     deinit {
         llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
+        llama_sampling_free(sampling)
         llama_backend_free()
     }
 
@@ -156,7 +159,7 @@ actor LlamaContext {
         candidates.withUnsafeMutableBufferPointer() { buffer in
             var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
 
-            new_token_id = llama_sample_token_greedy(context, &candidates_p)
+            new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
        }
 
        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
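
Storing the sampler as a third OpaquePointer keeps the actor's resource handling symmetric: every handle acquired in init has a matching release in deinit, with llama_sampling_free joining the existing llama_free and llama_free_model calls.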

Diff for: include/llama.h (+1 -1)

@@ -1148,7 +1148,7 @@ extern "C" {
                              float * mu);
 
     /// @details Selects the token with the highest probability.
-    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
+    ///          Does not compute the token probabilities. Use llama_sampling_softmax() instead.
     LLAMA_API llama_token llama_sampling_sample_greedy(
               struct llama_sampling * smpl,
               llama_token_data_array * candidates);
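
The doc comment points callers that do need probabilities at llama_sampling_softmax(). A minimal Swift sketch of the two paths, assuming llama_sampling_softmax takes the same (smpl, &candidates_p) arguments as the other llama_sampling_* functions in this commit:

// Greedy pick: fast, but leaves the probability fields of candidates_p uncomputed.
let greedy_id = llama_sampling_sample_greedy(smpl, &candidates_p)

// Probability-aware path: normalize the logits into probabilities first,
// then draw a token from the resulting distribution.
llama_sampling_softmax(smpl, &candidates_p) // signature assumed from the doc comment above
let sampled_id = llama_sampling_sample(smpl, &candidates_p)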
