Skip to content

Commit 1cf0719

Browse files
authored
[Minor][Spec Decode] Add use_eagle to SpeculativeConfig (#17213)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 537d5ee commit 1cf0719

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

vllm/config.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2566,6 +2566,9 @@ def num_lookahead_slots(self) -> int:
25662566
"""
25672567
return self.num_speculative_tokens
25682568

2569+
def use_eagle(self) -> bool:
2570+
return self.method in ("eagle", "eagle3")
2571+
25692572
def __repr__(self) -> str:
25702573
method = self.method
25712574
model = None if method == "ngram" else self.draft_model_config.model

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def __init__(
126126
self.num_spec_tokens = self.num_lookahead_tokens = 0
127127
if speculative_config:
128128
self.num_spec_tokens = speculative_config.num_speculative_tokens
129-
if speculative_config.method in ("eagle", "eagle3"):
129+
if speculative_config.use_eagle():
130130
self.num_lookahead_tokens = self.num_spec_tokens
131131

132132
def schedule(self) -> SchedulerOutput:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -171,8 +171,7 @@ def __init__(
171171
if get_pp_group().is_last_rank:
172172
if self.speculative_config.method == "ngram":
173173
self.drafter = NgramProposer(self.vllm_config)
174-
elif self.speculative_config.method == "eagle" or \
175-
self.speculative_config.method == "eagle3":
174+
elif self.speculative_config.use_eagle():
176175
self.drafter = EagleProposer(self.vllm_config,
177176
self.device) # type: ignore
178177
if self.speculative_config.method == "eagle3":
@@ -1192,8 +1191,7 @@ def execute_model(
11921191
assert isinstance(self.drafter, NgramProposer)
11931192
spec_token_ids = self.generate_draft_token_ids(
11941193
valid_sampled_token_ids, sampling_metadata)
1195-
elif self.speculative_config.method == "eagle" or \
1196-
self.speculative_config.method == "eagle3":
1194+
elif self.speculative_config.use_eagle():
11971195
assert isinstance(self.drafter, EagleProposer)
11981196
# TODO(woosuk): Refactor the loop.
11991197
next_token_ids: list[int] = []

0 commit comments

Comments
 (0)