@@ -82,7 +82,7 @@ int main(int argc, char ** argv) {
82
82
// GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
83
83
84
84
// how many tokens to draft each time
85
- const int n_draft = params.n_draft ;
85
+ int n_draft = params.n_draft ;
86
86
87
87
int n_predict = 0 ;
88
88
int n_drafted = 0 ;
@@ -131,6 +131,7 @@ int main(int argc, char ** argv) {
131
131
LOG (" drafted: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx_dft, drafted));
132
132
133
133
int i_dft = 0 ;
134
+
134
135
while (true ) {
135
136
// sample from the target model
136
137
const llama_token id = llama_sample_token (ctx_tgt, NULL , grammar_tgt, params, last_tokens, candidates, i_dft);
@@ -174,6 +175,27 @@ int main(int argc, char ** argv) {
174
175
llama_eval (ctx_dft, &id, 1 , n_past_dft, params.n_threads );
175
176
++n_past_dft;
176
177
178
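+ // heuristic for n_draft: grow by 8 (capped at 30) when a full-length draft is entirely accepted, keep it unchanged when a shorter draft is entirely accepted, and shrink by 1 (never below 2) when a drafted token is rejected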
179
+ {
180
+ const int n_draft_cur = (int ) drafted.size ();
181
+ const bool all_accepted = i_dft == n_draft_cur;
182
+
183
+ LOG (" n_draft = %d\n " , n_draft);
184
+ LOG (" n_draft_cur = %d\n " , n_draft_cur);
185
+ LOG (" i_dft = %d\n " , i_dft);
186
+ LOG (" all_accepted = %d\n " , all_accepted);
187
+
188
+ if (all_accepted && n_draft == n_draft_cur) {
189
+ LOG (" - max drafted tokens accepted - n_draft += 8\n " );
190
+ n_draft = std::min (30 , n_draft + 8 );
191
+ } else if (all_accepted) {
192
+ LOG (" - partially drafted tokens accepted - no change\n " );
193
+ } else {
194
+ LOG (" - drafted token rejected - n_draft -= 1\n " );
195
+ n_draft = std::max (2 , n_draft - 1 );
196
+ }
197
+ }
198
+
177
199
drafted.clear ();
178
200
drafted.push_back (id);
179
201
0 commit comments