File tree 3 files changed +6
-5
lines changed
3 files changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -2566,6 +2566,9 @@ def num_lookahead_slots(self) -> int:
2566
2566
"""
2567
2567
return self .num_speculative_tokens
2568
2568
2569
+ def use_eagle (self ) -> bool :
2570
+ return self .method in ("eagle" , "eagle3" )
2571
+
2569
2572
def __repr__ (self ) -> str :
2570
2573
method = self .method
2571
2574
model = None if method == "ngram" else self .draft_model_config .model
Original file line number Diff line number Diff line change @@ -126,7 +126,7 @@ def __init__(
126
126
self .num_spec_tokens = self .num_lookahead_tokens = 0
127
127
if speculative_config :
128
128
self .num_spec_tokens = speculative_config .num_speculative_tokens
129
- if speculative_config .method in ( "eagle" , "eagle3" ):
129
+ if speculative_config .use_eagle ( ):
130
130
self .num_lookahead_tokens = self .num_spec_tokens
131
131
132
132
def schedule (self ) -> SchedulerOutput :
Original file line number Diff line number Diff line change @@ -171,8 +171,7 @@ def __init__(
171
171
if get_pp_group ().is_last_rank :
172
172
if self .speculative_config .method == "ngram" :
173
173
self .drafter = NgramProposer (self .vllm_config )
174
- elif self .speculative_config .method == "eagle" or \
175
- self .speculative_config .method == "eagle3" :
174
+ elif self .speculative_config .use_eagle ():
176
175
self .drafter = EagleProposer (self .vllm_config ,
177
176
self .device ) # type: ignore
178
177
if self .speculative_config .method == "eagle3" :
@@ -1192,8 +1191,7 @@ def execute_model(
1192
1191
assert isinstance (self .drafter , NgramProposer )
1193
1192
spec_token_ids = self .generate_draft_token_ids (
1194
1193
valid_sampled_token_ids , sampling_metadata )
1195
- elif self .speculative_config .method == "eagle" or \
1196
- self .speculative_config .method == "eagle3" :
1194
+ elif self .speculative_config .use_eagle ():
1197
1195
assert isinstance (self .drafter , EagleProposer )
1198
1196
# TODO(woosuk): Refactor the loop.
1199
1197
next_token_ids : list [int ] = []
You can’t perform that action at this time.
0 commit comments