llama : update Swift and Android bindings for refactor sampling #8651

Merged
14 changes: 8 additions & 6 deletions examples/batched.swift/Sources/main.swift
@@ -44,6 +44,8 @@ context_params.n_threads = 8
context_params.n_threads_batch = 8

let context = llama_new_context_with_model(model, context_params)
+let smpl = llama_get_sampling(context)
+
guard context != nil else {
print("Failed to initialize context")
exit(1)
@@ -144,13 +146,13 @@ while n_cur <= n_len {
let top_p: Float = 0.9
let temp: Float = 0.4

-llama_sample_top_k(context, &candidates_p, top_k, 1)
-llama_sample_top_p(context, &candidates_p, top_p, 1)
-llama_sample_temp(context, &candidates_p, temp)
+llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
+llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
+llama_sampling_temp(smpl, &candidates_p, temp)

-let new_token_id = llama_sample_token(context, &candidates_p)
+let new_token_id = llama_sampling_sample(smpl, &candidates_p)

-// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);

// is it an end of stream? -> mark the stream as finished
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()

print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")

-llama_print_timings(context)
+llama_print_timings(context, smpl, nil)

private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
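Taken together, the batched.swift changes replace the old context-based llama_sample_* calls with the llama_sampling_* family, driven by a handle obtained via llama_get_sampling(). Below is a minimal sketch of the resulting per-token sampling step; sampleNextToken and the logits parameter are illustrative names rather than part of this PR, and the top_k/top_p/temp values are assumed to be set up as in the example above.

// Minimal sketch of the refactored sampling path; `sampleNextToken` and the
// `logits` parameter are assumed names, not part of the PR.
func sampleNextToken(model: OpaquePointer, context: OpaquePointer,
                     logits: UnsafeMutablePointer<Float>,
                     top_k: Int32, top_p: Float, temp: Float) -> llama_token {
    // the sampling handle now comes from the context
    let smpl = llama_get_sampling(context)

    // one candidate entry per vocabulary token
    var candidates = (0 ..< llama_n_vocab(model)).map { token_id in
        llama_token_data(id: token_id, logit: logits[Int(token_id)], p: 0.0)
    }

    return candidates.withUnsafeMutableBufferPointer { buffer in
        var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)

        // the samplers take the llama_sampling handle instead of the context
        llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
        llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
        llama_sampling_temp(smpl, &candidates_p, temp)

        return llama_sampling_sample(smpl, &candidates_p)
    }
}
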
3 changes: 2 additions & 1 deletion examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
+const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto model = llama_get_model(context);

@@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// sample the most likely token
-const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);

const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
5 changes: 4 additions & 1 deletion examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
+private var sampling: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
var is_done: Bool = false
@@ -42,12 +43,14 @@ actor LlamaContext {
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
+self.sampling = llama_get_sampling(context)
}

deinit {
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
+llama_sampling_free(sampling)
llama_backend_free()
}

@@ -156,7 +159,7 @@ actor LlamaContext {
candidates.withUnsafeMutableBufferPointer() { buffer in
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)

-new_token_id = llama_sample_token_greedy(context, &candidates_p)
+new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
}

if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
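In LibLlama.swift the handle also becomes part of the wrapper's lifecycle: it is fetched once in init, released in deinit together with the other llama.cpp resources, and the greedy sampling call goes through it. A condensed sketch of that ownership pattern, with MiniLlamaContext as a hypothetical stand-in for the actor above:

// Condensed sketch of the ownership pattern; MiniLlamaContext is a hypothetical
// name, the llama_* calls are the ones used in the diff above.
final class MiniLlamaContext {
    private let model: OpaquePointer
    private let context: OpaquePointer
    private let sampling: OpaquePointer

    init(model: OpaquePointer, context: OpaquePointer) {
        self.model = model
        self.context = context
        // the sampling handle is obtained from (and tied to) the context
        self.sampling = llama_get_sampling(context)
    }

    deinit {
        // release the sampling handle along with the other resources
        llama_sampling_free(sampling)
        llama_free(context)
        llama_free_model(model)
    }

    // greedy sampling now takes the sampling handle instead of the context
    func greedyToken(_ candidates_p: inout llama_token_data_array) -> llama_token {
        return llama_sampling_sample_greedy(sampling, &candidates_p)
    }
}
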
2 changes: 1 addition & 1 deletion include/llama.h
@@ -1144,7 +1144,7 @@ extern "C" {
float * mu);

/// @details Selects the token with the highest probability.
-/// Does not compute the token probabilities. Use llama_sample_softmax() instead.
+/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
LLAMA_API llama_token llama_sampling_sample_greedy(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
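The updated header comment keeps the old contract under the new name: llama_sampling_sample_greedy() picks the most likely token without computing the probabilities, so callers that also want the probability of the chosen token have to run the softmax first. A hedged Swift sketch of that pattern, assuming llama_sampling_softmax() takes the same (handle, candidates) arguments as the other llama_sampling_* functions and, like the pre-refactor llama_sample_softmax(), fills in p and sorts the candidates by probability:

// Sketch only: the exact signature and sorting behaviour of
// llama_sampling_softmax() are assumptions based on the pre-refactor API.
func greedyTokenWithProbability(_ smpl: OpaquePointer,
                                _ candidates_p: inout llama_token_data_array) -> (llama_token, Float) {
    // compute the probabilities (and sort the candidates by them) first ...
    llama_sampling_softmax(smpl, &candidates_p)
    // ... then the greedy choice is the first entry, with its probability filled in
    let best = candidates_p.data[0]
    return (best.id, best.p)
}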