Check the full vocab for grammar only if necessary #4306

Merged 8 commits on Dec 23, 2023.

Changes from 5 commits
38 changes: 36 additions & 2 deletions common/sampling.cpp
@@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int idx,
bool is_resampling) { // Add a parameter to indicate if we are resampling
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -128,8 +129,17 @@ llama_token llama_sampling_sample(

llama_token id = 0;

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;

if (!is_resampling) {
// Only make a copy of the original logits if we are not in the resampling phase (not sure if this copy is strictly necessary).
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
@@ -165,7 +175,8 @@
}
}

if (ctx_sampling->grammar != NULL) {
// If we are in the resampling phase, apply grammar checks before sampling logic
if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

@@ -212,6 +223,29 @@
}
}

if (ctx_sampling->grammar != NULL && !is_resampling) {
// Create an array with a single token data element for the sampled id
llama_token_data single_token_data = {id, logits[id], 0.0f};
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

// Apply grammar constraints to the single token
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);

// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
Comment on lines +268 to +275 (Collaborator):

Yeah, we could probably expose something more straightforward to check this if we want. It would probably be like llama_grammar_accept_token but returning a bool instead of throwing.
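As a rough illustration of that idea (not part of this PR), such a helper could wrap the same single-token check used above. The name llama_grammar_token_is_valid and its exact signature are assumptions, not an existing llama.cpp API:

#include <cmath>      // INFINITY
#include "llama.h"

// Hypothetical sketch only: a bool-returning validity check built on the same
// single-token llama_sample_grammar trick this PR uses.
static bool llama_grammar_token_is_valid(
        struct llama_context * ctx_main,
        struct llama_grammar * grammar,
        llama_token            id,
        float                  logit) {
    llama_token_data       single_token_data       = { id, logit, 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

    llama_sample_grammar(ctx_main, &single_token_data_array, grammar);

    // The grammar sampler sets the logit of a disallowed token to -INFINITY.
    return single_token_data_array.data[0].logit != -INFINITY;
}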


// If the token is not valid according to the grammar, perform resampling
if (!is_valid) {
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());

// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);

// Recursively call llama_sampling_sample to resample with the grammar checks applied first
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
}
}

return id;
}

10 changes: 6 additions & 4 deletions common/sampling.h
@@ -92,16 +92,18 @@ std::string llama_sampling_print(const llama_sampling_params & params);
// optional:
// - ctx_cfg: context to use for classifier-free guidance
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - is_resampling: determines whether or not this is a repeated sampling operation due to the ID not matching the grammar
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
bool is_resampling = false);
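For reference, a minimal call-site sketch of the updated signature, modeled on the example programs changed later in this diff; ctx, ctx_guidance and ctx_sampling are assumed to be set up as in examples/main. Note that idx no longer has a default value here, so callers must pass it explicitly, while is_resampling defaults to false:

// Call-site sketch (assumes ctx, ctx_guidance and ctx_sampling are initialized
// as in examples/main). idx must now be passed explicitly; is_resampling is
// normally left at its default of false.
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, /*idx=*/0, /*is_resampling=*/false);

// Advance the sampler state (and the grammar) with the accepted token.
llama_sampling_accept(ctx_sampling, ctx, id, /*apply_grammar=*/true);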
Comment (Member):

We should merge this after we fix this:

is_resampling is an implementation detail and should not be exposed in the API, because somebody can start using it or get confused what it is for.

Change llama_sampling_sample to not accept is_resampling and to call llama_sampling_sample_impl, which in turn accepts is_resampling.

Reply (Contributor, Author):

Will look into this soon.
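A rough sketch of the split suggested above; llama_sampling_sample_impl is the name proposed in the comment, and the bodies shown here are assumptions rather than code from this PR:

// Internal implementation, local to common/sampling.cpp: is_resampling stays an
// implementation detail and never appears in the public header.
static llama_token llama_sampling_sample_impl(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        const int idx,
        bool is_resampling);   // re-invokes itself with is_resampling = true when a token fails the grammar

// Public entry point, declared in common/sampling.h without is_resampling.
llama_token llama_sampling_sample(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        const int idx) {
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /*is_resampling=*/false);
}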


void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
2 changes: 1 addition & 1 deletion examples/infill/infill.cpp
@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);

llama_sampling_accept(ctx_sampling, ctx, id, true);

2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}

const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);

llama_sampling_accept(ctx_sampling, ctx, id, true);
