
grammars: fix resampling logic regression #7424

Merged
merged 2 commits on May 21, 2024
14 changes: 8 additions & 6 deletions common/sampling.cpp
@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
         const int idx,
-        bool is_resampling) { // Add a parameter to indicate if we are resampling
+        bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;

     const float temp = params.temp;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
     const float mirostat_eta = params.mirostat_eta;

     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
-    if (!is_resampling) {
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
@@ -252,7 +252,7 @@
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);

-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }

@@ -285,8 +285,10 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);

-    if (apply_grammar && original_logits != NULL) {
+    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+        GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        // TODO: if idx >= 0 then use ctx->output_ids.size() as upper bound?
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }

@@ -342,7 +344,7 @@ llama_token llama_sampling_sample(
         struct llama_context * ctx_cfg,
         const int idx) {
     // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }

 llama_token_data_array llama_sampling_prepare(
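To see what the restored logic does end to end: the first pass samples with the grammar left out and keeps an untouched copy of the logits; only if the sampled token violates the grammar are the logits restored and a second, grammar-constrained pass run. A minimal self-contained sketch of that two-pass idea (illustrative only; `grammar_allows`, `apply_grammar`, and `sample_token` are hypothetical stand-ins, not llama.cpp APIs):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical grammar: only even token ids are allowed.
static bool grammar_allows(int token) { return token % 2 == 0; }

// Zero out tokens the grammar forbids (stand-in for grammar penalties).
static void apply_grammar(std::vector<float> & probs) {
    for (size_t i = 0; i < probs.size(); ++i) {
        if (!grammar_allows((int) i)) probs[i] = 0.0f;
    }
}

// Greedy "sampling" for simplicity: pick the highest-scoring token.
static int sample_token(const std::vector<float> & probs) {
    return (int) (std::max_element(probs.begin(), probs.end()) - probs.begin());
}

// Two-pass flow: sample unconstrained first, and only fall back to a
// grammar-constrained resample (on restored logits) when the candidate
// token is rejected by the grammar.
static int sample(std::vector<float> & probs, bool is_resampling) {
    std::vector<float> original_logits;
    if (!is_resampling) {
        original_logits = probs;   // keep an untouched copy for a possible retry
    } else {
        apply_grammar(probs);      // second pass: constrain before sampling
    }

    const int id = sample_token(probs);

    if (!is_resampling && !grammar_allows(id)) {
        probs = original_logits;   // restore logits from the copy
        return sample(probs, /* is_resampling= */ true);
    }
    return id;
}

int main() {
    std::vector<float> probs = {0.1f, 0.6f, 0.3f};  // token 1 scores highest but is odd
    printf("sampled token: %d\n", sample(probs, /* is_resampling= */ false));
    // prints "sampled token: 2": token 1 was rejected, the logits were
    // restored, and the grammar-constrained resample picked token 2
    return 0;
}
```

This is also why the flipped condition in `llama_sampling_prepare_impl` matters: the copy of the original logits is only needed on the first, unconstrained pass (`!apply_grammar`), which is exactly the case the new `GGML_ASSERT` guards.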
4 changes: 2 additions & 2 deletions examples/main/main.cpp
@@ -707,7 +707,7 @@ int main(int argc, char ** argv) {

             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);

             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());

@@ -728,7 +728,7 @@ int main(int argc, char ** argv) {

             // push the prompt in the sampling context in order to apply repetition penalties later
             // for the prompt, we don't apply grammar rules
-            llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+            llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);

             ++n_consumed;
             if ((int) embd.size() >= params.n_batch) {
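The `main.cpp` hunks above only annotate the existing boolean arguments, but they document the caller-side contract: the grammar state is advanced for generated tokens and skipped for prompt tokens. Condensed from the two call sites (a sketch, not compilable on its own; it assumes the surrounding `main.cpp` setup, and `prompt_tokens` stands in for `embd_inp`):

```cpp
// Prompt tokens: recorded only so repetition penalties see them later;
// the grammar state is NOT advanced, since the prompt need not match it.
for (const llama_token tok : prompt_tokens) {
    llama_sampling_accept(ctx_sampling, ctx, tok, /* apply_grammar= */ false);
}

// Generated tokens: the grammar state IS advanced, so the next call to
// llama_sampling_sample stays in sync with what has already been emitted.
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
```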