diff --git a/vllm/utils.py b/vllm/utils.py
index 6a41afff8f0..cf960140eca 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2253,3 +2253,48 @@ def import_pynvml():
     """
     import vllm.third_party.pynvml as pynvml
     return pynvml
+
+
+def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
+    """
+    A replacement for `abc.ABC`.
+    With `abc.ABC`, subclasses fail to instantiate if they do not
+    implement all of the base class's abstract methods.
+    Here, we only require `raise NotImplementedError` in the
+    base class, and log a warning if a method is left
+    unimplemented in the subclass.
+    """
+
+    original_init = cls.__init__
+
+    def find_unimplemented_methods(self: object):
+        unimplemented_methods = []
+        for attr_name in dir(self):
+            # skip private and dunder attributes
+            if attr_name.startswith('_'):
+                continue
+
+            try:
+                attr = getattr(self, attr_name)
+                # skip non-callable attributes, then unwrap the
+                # underlying function of the bound method
+                if not callable(attr):
+                    continue
+                attr_func = attr.__func__
+            except AttributeError:
+                continue
+            src = inspect.getsource(attr_func)
+            if "NotImplementedError" in src:
+                unimplemented_methods.append(attr_name)
+        if unimplemented_methods:
+            method_names = ','.join(unimplemented_methods)
+            msg = f"Methods {method_names} not implemented in {self}"
+            logger.warning(msg)
+
+    @wraps(original_init)
+    def wrapped_init(self, *args, **kwargs) -> None:
+        original_init(self, *args, **kwargs)
+        find_unimplemented_methods(self)
+
+    type.__setattr__(cls, '__init__', wrapped_init)
+    return cls
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index ad53f90b866..dbd97c97936 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -21,6 +21,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
 
@@ -28,7 +29,7 @@
     from vllm.v1.core.scheduler_output import SchedulerOutput
 
 
-class Worker:
+class Worker(WorkerBase):
 
     def __init__(
         self,
@@ -39,23 +40,11 @@ def __init__(
         is_driver_worker: bool = False,
     ):
 
-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
+        super().__init__(vllm_config=vllm_config,
+                         local_rank=local_rank,
+                         rank=rank,
+                         distributed_init_method=distributed_init_method,
+                         is_driver_worker=is_driver_worker)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -126,7 +115,8 @@ def init_device(self):
         set_random_seed(self.model_config.seed)
 
         # Construct the model runner
-        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
+        self.model_runner: GPUModelRunner = GPUModelRunner(
+            self.vllm_config, self.device)
 
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
new file mode 100644
index 00000000000..bc7e76c38ae
--- /dev/null
+++ b/vllm/v1/worker/worker_base.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
+
+logger = init_logger(__name__)
+
+
+class WorkerBase(WorkerBaseV0):
+    """
+    Abstract base class for v1 workers; it mainly defines v1-specific
+    methods. Methods shared by v0 and v1 belong in the v0 WorkerBase.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+    ):
+        """
+        Initialize common worker components.
+
+        Args:
+            vllm_config: Complete vLLM configuration
+            local_rank: Local device index
+            rank: Global rank in distributed setup
+            distributed_init_method: Distributed initialization method
+            is_driver_worker: Whether this worker handles driver
+                responsibilities
+        """
+        # Configuration storage
+        super().__init__(vllm_config=vllm_config)
+
+        self.parallel_config.rank = rank
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+
+        # Device and model state
+        self.device: Optional[torch.device] = None
+        self.model_runner: Optional[nn.Module] = None
+
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """Get specifications for KV cache implementation."""
+        raise NotImplementedError
+
+    def compile_or_warm_up_model(self) -> None:
+        """Prepare model for execution through compilation/warmup."""
+        raise NotImplementedError
+
+    def check_health(self) -> None:
+        """Basic health check (override for device-specific checks)."""
+        return
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 819b81fbfdb..83fcf0865ae 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -3,7 +3,7 @@
 import dataclasses
 import os
 import time
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
 
 import cloudpickle
@@ -19,7 +19,8 @@
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
-                        update_environment_variables)
+                        update_environment_variables,
+                        warn_for_unimplemented_methods)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)
@@ -27,7 +28,8 @@
 logger = init_logger(__name__)
 
 
-class WorkerBase(ABC):
+@warn_for_unimplemented_methods
+class WorkerBase:
     """Worker interface that allows vLLM to cleanly separate implementations
     for different hardware. Also abstracts control plane communication, e.g.,
     to communicate request metadata to other workers.
@@ -53,35 +55,31 @@ def __init__(
         from vllm.platforms import current_platform
         self.current_platform = current_platform
 
-    @abstractmethod
     def init_device(self) -> None:
         """Initialize device state, such as loading the model or other
         on-device memory allocations.
         """
         raise NotImplementedError
 
-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks.
         """
         raise NotImplementedError
 
+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError
+
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.
 
@@ -94,40 +92,43 @@ def start_worker_execution_loop(self) -> None:
         if output is None:
             return None
 
-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.
 
-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
         raise NotImplementedError
 
-    @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
         raise NotImplementedError
 
-    @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
 
-    @abstractmethod
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
-    @abstractmethod
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
-    @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
 
+    @property
+    def vocab_size(self) -> int:
+        """Get vocabulary size from model configuration."""
+        return self.model_config.get_vocab_size()
+
 
 class DelegateWorkerBase(WorkerBase):
     """
@@ -156,6 +157,10 @@ def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        self.worker.load_model()
+
     def get_model(self) -> nn.Module:
         return self.worker.get_model()
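As a reviewer aid, here is a minimal standalone sketch of how the new decorator behaves. `Base` and `PartialImpl` are hypothetical names used only for illustration, not classes in this diff:

```python
# Hypothetical demo of warn_for_unimplemented_methods; assumes a vLLM
# checkout with this patch is importable.
from vllm.utils import warn_for_unimplemented_methods


@warn_for_unimplemented_methods
class Base:

    def init_device(self) -> None:
        raise NotImplementedError

    def load_model(self) -> None:
        raise NotImplementedError


class PartialImpl(Base):

    def init_device(self) -> None:
        print("device initialized")


# With abc.ABC this would raise TypeError at instantiation. Here the
# instance is created, and a warning is logged naming every public method
# whose source still contains "NotImplementedError", roughly:
#   Methods load_model not implemented in <PartialImpl object at 0x...>
worker = PartialImpl()
worker.init_device()  # overridden methods behave normally
```

The trade-off is deliberate: hardware backends can adopt the worker interface incrementally, paying with a log warning instead of a hard failure at instantiation.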
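For context on the new v1 hierarchy, here is a sketch of what a backend worker would now override, assuming the class layout introduced in `vllm/v1/worker/worker_base.py` above; `MyBackendWorker` is purely illustrative:

```python
# Illustrative only: a minimal v1 backend worker under the new hierarchy.
import torch

from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.worker.worker_base import WorkerBase


class MyBackendWorker(WorkerBase):

    def init_device(self) -> None:
        # local_rank is stored by the shared WorkerBase.__init__.
        self.device = torch.device(f"cuda:{self.local_rank}")

    def get_kv_cache_spec(self) -> KVCacheSpec:
        # Left unimplemented: instantiation still succeeds, but the
        # decorator on the v0 base logs a warning naming this method.
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        # A real backend would capture CUDA graphs or run warmup here.
        pass


# Construction mirrors the GPU Worker in this diff (arguments elided):
# worker = MyBackendWorker(vllm_config=..., local_rank=0, rank=0,
#                          distributed_init_method="env://")
```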