Skip to content

Commit 39fbaf9

Browse files
committed
llama : redirect external API to internal APIs
ggml-ci
1 parent 66ac80f commit 39fbaf9

9 files changed

+838
-519
lines changed

common/sampling.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
330330
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
331331

332332
// Apply grammar constraints to the single token
333-
llama_grammar_sample(ctx_main, &single_token_data_array, ctx_sampling->grammar);
333+
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);
334334

335335
// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
336336
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
@@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
421421

422422
// apply grammar checks before sampling logic
423423
if (apply_grammar && ctx_sampling->grammar != NULL) {
424-
llama_grammar_sample(ctx_main, &cur_p, ctx_sampling->grammar);
424+
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
425425
}
426426

427427
return cur_p;
@@ -455,6 +455,6 @@ void llama_sampling_accept(
455455
ctx_sampling->prev.push_back(id);
456456

457457
if (ctx_sampling->grammar != NULL && apply_grammar) {
458-
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
458+
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
459459
}
460460
}

include/llama.h

+9-5
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,10 @@ extern "C" {
965965
bool remove_special,
966966
bool unparse_special);
967967

968+
//
969+
// Chat templates
970+
//
971+
968972
/// Apply chat template. Inspired by hf apply_chat_template() on python.
969973
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
970974
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
@@ -1005,19 +1009,19 @@ extern "C" {
10051009

10061010
/// @details Apply constraints from grammar
10071011
LLAMA_API void llama_grammar_sample(
1008-
struct llama_context * ctx,
1009-
llama_token_data_array * candidates,
1010-
const struct llama_grammar * grammar);
1011-
LLAMA_API DEPRECATED(bool llama_sample_grammar(
1012+
const struct llama_grammar * grammar,
1013+
const struct llama_context * ctx,
1014+
llama_token_data_array * candidates);
1015+
LLAMA_API DEPRECATED(void llama_sample_grammar(
10121016
struct llama_context * ctx,
10131017
llama_token_data_array * candidates,
10141018
const struct llama_grammar * grammar),
10151019
"use llama_grammar_sample instead");
10161020

10171021
/// @details Accepts the sampled token into the grammar
10181022
LLAMA_API void llama_grammar_accept_token(
1019-
struct llama_context * ctx,
10201023
struct llama_grammar * grammar,
1024+
struct llama_context * ctx,
10211025
llama_token token);
10221026

10231027
//

src/llama-grammar.cpp

+14-12
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ static bool llama_grammar_detect_left_recursion(
384384
// grammar - external
385385
//
386386

387-
struct llama_grammar * llama_grammar_init(
387+
struct llama_grammar * llama_grammar_init_impl(
388388
const llama_grammar_element ** rules,
389389
size_t n_rules,
390390
size_t start_rule_index) {
@@ -441,11 +441,11 @@ struct llama_grammar * llama_grammar_init(
441441
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
442442
}
443443

444-
void llama_grammar_free(struct llama_grammar * grammar) {
444+
void llama_grammar_free_impl(struct llama_grammar * grammar) {
445445
delete grammar;
446446
}
447447

448-
struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
448+
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
449449
llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
450450

451451
// redirect elements in stacks to point to new rules
@@ -464,8 +464,10 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar)
464464
return result;
465465
}
466466

467-
void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
468-
GGML_ASSERT(ctx);
467+
void llama_grammar_sample(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
468+
GGML_ASSERT(grammar);
469+
GGML_ASSERT(vocab);
470+
469471
int64_t t_start_sample_us = ggml_time_us();
470472

471473
bool allow_eog = false;
@@ -484,9 +486,9 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c
484486

485487
for (size_t i = 0; i < candidates->size; ++i) {
486488
const llama_token id = candidates->data[i].id;
487-
const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(id);
489+
const std::string & piece = vocab->cache_token_to_piece.at(id);
488490

489-
if (llama_token_is_eog(llama_get_model(ctx), id)) {
491+
if (llama_token_is_eog(*vocab, id)) {
490492
if (!allow_eog) {
491493
candidates->data[i].logit = -INFINITY;
492494
}
@@ -503,13 +505,13 @@ void llama_grammar_sample(struct llama_context * ctx, llama_token_data_array * c
503505
candidates->data[reject.index].logit = -INFINITY;
504506
}
505507

506-
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
508+
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
507509
}
508510

509-
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
511+
void llama_grammar_accept_token(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
510512
const int64_t t_start_sample_us = ggml_time_us();
511513

512-
if (llama_token_is_eog(llama_get_model(ctx), token)) {
514+
if (llama_token_is_eog(*vocab, token)) {
513515
for (const auto & stack : grammar->stacks) {
514516
if (stack.empty()) {
515517
return;
@@ -518,7 +520,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
518520
GGML_ASSERT(false);
519521
}
520522

521-
const std::string & piece = llama_get_vocab(ctx)->cache_token_to_piece.at(token);
523+
const std::string & piece = vocab->cache_token_to_piece.at(token);
522524

523525
// Note terminating 0 in decoded string
524526
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -533,5 +535,5 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
533535
grammar->partial_utf8 = decoded.second;
534536
GGML_ASSERT(!grammar->stacks.empty());
535537

536-
llama_get_sampling(ctx)->t_sample_us += ggml_time_us() - t_start_sample_us;
538+
smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
537539
}

src/llama-grammar.h

+22
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "llama-impl.h"
44

55
struct llama_vocab;
6+
struct llama_sampling;
67

78
struct llama_grammar {
89
const llama_grammar_rules rules;
@@ -13,3 +14,24 @@ struct llama_grammar {
1314
};
1415

1516
struct llama_grammar * llama_get_grammar(struct llama_context * ctx);
17+
18+
struct llama_grammar * llama_grammar_init_impl(
19+
const llama_grammar_element ** rules,
20+
size_t n_rules,
21+
size_t start_rule_index);
22+
23+
void llama_grammar_free_impl(struct llama_grammar * grammar);
24+
25+
struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar);
26+
27+
void llama_grammar_sample(
28+
const struct llama_grammar * grammar,
29+
const struct llama_vocab * vocab,
30+
const struct llama_sampling * smpl,
31+
llama_token_data_array * candidates);
32+
33+
void llama_grammar_accept_token(
34+
struct llama_grammar * grammar,
35+
const struct llama_vocab * vocab,
36+
const struct llama_sampling * smpl,
37+
llama_token token);

0 commit comments

Comments
 (0)