Skip to content

Commit 3d2e41a

Browse files
leng-yue and ggerganov
authored and committed
speculative : add heuristic algorithm (ggml-org#3006)
* Add heuristic algo for speculative
* Constrain minimum n_draft to 2
* speculative : improve heuristic impl
* speculative : be more rewarding upon guessing max drafted tokens
* speculative : fix typos

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 86ad534 commit 3d2e41a

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

examples/speculative/speculative.cpp

+23 -1
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
8282
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
8383

8484
// how many tokens to draft each time
85-
const int n_draft = params.n_draft;
85+
int n_draft = params.n_draft;
8686

8787
int n_predict = 0;
8888
int n_drafted = 0;
@@ -131,6 +131,7 @@ int main(int argc, char ** argv) {
131131
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
132132

133133
int i_dft = 0;
134+
134135
while (true) {
135136
// sample from the target model
136137
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
@@ -174,6 +175,27 @@ int main(int argc, char ** argv) {
174175
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
175176
++n_past_dft;
176177

178+
// heuristic for n_draft
179+
{
180+
const int n_draft_cur = (int) drafted.size();
181+
const bool all_accepted = i_dft == n_draft_cur;
182+
183+
LOG("n_draft = %d\n", n_draft);
184+
LOG("n_draft_cur = %d\n", n_draft_cur);
185+
LOG("i_dft = %d\n", i_dft);
186+
LOG("all_accepted = %d\n", all_accepted);
187+
188+
if (all_accepted && n_draft == n_draft_cur) {
189+
LOG(" - max drafted tokens accepted - n_draft += 8\n");
190+
n_draft = std::min(30, n_draft + 8);
191+
} else if (all_accepted) {
192+
LOG(" - partially drafted tokens accepted - no change\n");
193+
} else {
194+
LOG(" - drafted token rejected - n_draft -= 1\n");
195+
n_draft = std::max(2, n_draft - 1);
196+
}
197+
}
198+
177199
drafted.clear();
178200
drafted.push_back(id);
179201

0 commit comments

Comments
 (0)