diff --git a/vllm/config.py b/vllm/config.py index 0ac3cc46b06..d8f880d26e9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2566,6 +2566,9 @@ def num_lookahead_slots(self) -> int: """ return self.num_speculative_tokens + def use_eagle(self) -> bool: + return self.method in ("eagle", "eagle3") + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index adec4462963..2833e330e6a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -126,7 +126,7 @@ def __init__( self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens - if speculative_config.method in ("eagle", "eagle3"): + if speculative_config.use_eagle(): self.num_lookahead_tokens = self.num_spec_tokens def schedule(self) -> SchedulerOutput: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7910481762e..19c2a33349c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -171,8 +171,7 @@ def __init__( if get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) - elif self.speculative_config.method == "eagle" or \ - self.speculative_config.method == "eagle3": + elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device) # type: ignore if self.speculative_config.method == "eagle3": @@ -1192,8 +1191,7 @@ def execute_model( assert isinstance(self.drafter, NgramProposer) spec_token_ids = self.generate_draft_token_ids( valid_sampled_token_ids, sampling_metadata) - elif self.speculative_config.method == "eagle" or \ - self.speculative_config.method == "eagle3": + elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = []