[V1][Core] Add worker_base for v1 worker #12816
Changes from all commits: cf26f2c, b400e14, c43e6bc, 8e7906d, 052cd6b, 4f86570, 86d0705, cb8a099, 889c72a, a68c0fe, d3c7075, 47ceccf, c5b222c, 7c54cd4, 84b2dab, 793fe42
vllm/v1/worker/worker_base.py (new file)

@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

logger = init_logger(__name__)


class WorkerBase(WorkerBaseV0):
    """
    Abstract class for the v1 worker, which mainly defines v1-specific
    methods. Methods shared by v0 and v1 are defined in the v0 WorkerBase.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Initialize common worker components.

        Args:
            vllm_config: Complete vLLM configuration
            local_rank: Local device index
            rank: Global rank in distributed setup
            distributed_init_method: Distributed initialization method
            is_driver_worker: Whether this worker handles driver
                responsibilities
        """
        # Configuration storage
        super().__init__(vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        # Device and model state
        self.device: Optional[torch.device] = None
        self.model_runner: Optional[nn.Module] = None

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Get specifications for KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Prepare model for execution through compilation/warmup."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        return
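For orientation, here is a minimal sketch of how a concrete v1 worker might build on this base class. The class name, device choice, and method bodies below are illustrative assumptions, not part of this diff.

```python
# Hypothetical subclass, for illustration only: shows which hooks a
# device-specific v1 worker is expected to fill in.
import torch

from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.worker.worker_base import WorkerBase


class ExampleWorker(WorkerBase):

    def init_device(self) -> None:
        # Bind the worker to its device; real workers also set up the
        # distributed environment here.
        self.device = torch.device(f"cuda:{self.local_rank}")

    def get_kv_cache_spec(self) -> KVCacheSpec:
        # Typically delegated to the model runner, which knows the
        # per-layer attention layout.
        return self.model_runner.get_kv_cache_spec()

    def compile_or_warm_up_model(self) -> None:
        # Real workers run dummy batches here to trigger compilation and
        # CUDA graph capture; a no-op keeps the sketch minimal.
        pass
```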
vllm/worker/worker_base.py

@@ -3,7 +3,7 @@
 import dataclasses
 import os
 import time
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

 import cloudpickle
@@ -19,15 +19,17 @@
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
-                        update_environment_variables)
+                        update_environment_variables,
+                        warn_for_unimplemented_methods)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)

 logger = init_logger(__name__)


-class WorkerBase(ABC):
+@warn_for_unimplemented_methods
+class WorkerBase:
     """Worker interface that allows vLLM to cleanly separate implementations for
     different hardware. Also abstracts control plane communication, e.g., to
     communicate request metadata to other workers.
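To make the interface contract concrete, here is the call order an executor typically drives a worker through. The wrapper function and the bare `execute_model` call are illustrative, not code from this PR.

```python
# Illustrative driver loop: the order in which a vLLM executor takes a
# concrete WorkerBase implementation through its lifecycle.
def run_worker_lifecycle(worker: "WorkerBase") -> None:
    worker.init_device()        # bind device, set up distributed state
    worker.load_model()         # move weights onto the target device
    num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
    worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
    worker.execute_model(execute_model_req=None)  # one engine step
```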
@@ -53,35 +55,31 @@ def __init__(
         from vllm.platforms import current_platform
         self.current_platform = current_platform

-    @abstractmethod
     def init_device(self) -> None:
         """Initialize device state, such as loading the model or other on-device
         memory allocations.
         """
         raise NotImplementedError

-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks.
         """
         raise NotImplementedError

+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
Inline review discussion (attached to `get_model`):

> since we have …

> Yes, doing so prevents subclasses from being instantiated if they haven't implemented the abstract method, instead of only erroring out when the method is called.

> I think that's too restrictive, and it hurts development. Erroring out when the method is called looks sufficient; we can add the implementations step by step.

> I agree with removing the unnecessary `@abstractmethod` decorators.

> To avoid forgetting to implement them at the end, can we add a test that prints out which abstract methods remain to be implemented for each worker subclass?

> This confuses me. What's the difference between throwing an error when the method is called (…)?

> We can make the test print out warnings instead of failing outright.

> The difference is that users won't get unnecessary warnings while developers can remain aware of this.

> Hi @DarkLight1337 and @youkaichao, I propose a decorator to throw warnings for methods not implemented in a subclass. It will show something like this. I still put one `@abstractmethod` …

> Sounds good to me. You can keep the `@abstractmethod`.
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError
+
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.
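The `warn_for_unimplemented_methods` decorator proposed in the discussion above lives in `vllm.utils`. The following is a self-contained sketch of the idea (warn at instantiation instead of refusing to instantiate, as `abc.ABC` would), not the exact vLLM implementation.

```python
# Sketch of a warn-for-unimplemented-methods class decorator; the real
# helper in vllm.utils may differ in detail.
import inspect
import warnings
from functools import wraps


def warn_for_unimplemented_methods(cls):
    """Replacement for abc.ABC: subclasses stay instantiable, but
    constructing one warns about methods whose body still raises
    NotImplementedError."""
    original_init = cls.__init__

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        unimplemented = []
        for name in dir(self):
            if name.startswith("_"):
                continue
            func = getattr(getattr(self, name, None), "__func__", None)
            if func is None:
                continue
            try:
                source = inspect.getsource(func)
            except (OSError, TypeError):
                continue
            # Heuristic: the base-class stubs all raise NotImplementedError.
            if "NotImplementedError" in source:
                unimplemented.append(name)
        if unimplemented:
            warnings.warn(f"Methods not implemented in {type(self).__name__}: "
                          f"{', '.join(sorted(unimplemented))}")

    cls.__init__ = wrapped_init
    return cls


@warn_for_unimplemented_methods
class Base:
    def run(self) -> None:
        raise NotImplementedError


class Incomplete(Base):
    pass


Incomplete()  # warns: Methods not implemented in Incomplete: run
```

Compared with `abc.ABC`, this trades a hard failure at instantiation for a developer-facing warning, which is the trade-off the thread above converges on.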
@@ -94,40 +92,43 @@ def start_worker_execution_loop(self) -> None:
             if output is None:
                 return None

-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
-
-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.
+
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
         raise NotImplementedError

     @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
         raise NotImplementedError

     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError

     @abstractmethod
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

     @abstractmethod
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError

     @property
     def vocab_size(self) -> int:
         """Get vocabulary size from model configuration."""
         return self.model_config.get_vocab_size()
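As a worked example of the `determine_num_available_blocks` contract restated above, each block count is just available memory divided by the per-block size. The memory figures and attribute names here are hypothetical; real workers measure free memory by profiling a forward pass first.

```python
# Hypothetical arithmetic behind the (num_gpu_blocks, num_cpu_blocks)
# contract; all sizes below are illustrative placeholders.
from typing import Tuple


class ExampleCacheSizing:
    gpu_memory_for_kv = 8 * 1024**3   # assume 8 GiB left for the KV cache
    cpu_swap_space = 4 * 1024**3      # assume 4 GiB of CPU swap space
    cache_block_size = 2 * 1024**2    # assume 2 MiB per block (model-dependent)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_gpu_blocks = self.gpu_memory_for_kv // self.cache_block_size
        num_cpu_blocks = self.cpu_swap_space // self.cache_block_size
        return num_gpu_blocks, num_cpu_blocks


sizing = ExampleCacheSizing()
print(sizing.determine_num_available_blocks())  # (4096, 2048)
```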
 class DelegateWorkerBase(WorkerBase):
     """
@@ -156,6 +157,10 @@ def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

+    def load_model(self) -> None:
+        """Load model onto target device."""
+        self.worker.load_model()
+
     def get_model(self) -> nn.Module:
         return self.worker.get_model()
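The PR adds the explicit `load_model` pass-through so `DelegateWorkerBase` keeps matching the enlarged `WorkerBase` interface. Because every interface method is forwarded to an inner `self.worker`, thin wrappers can intercept a single method; the wrapper below is a hypothetical example, not part of this PR.

```python
# Hypothetical wrapper built on DelegateWorkerBase: intercepts one
# forwarded call and delegates everything else unchanged.
import time

from vllm.worker.worker_base import DelegateWorkerBase


class TimingWorker(DelegateWorkerBase):

    def load_model(self) -> None:
        # Time the wrapped worker's model load; all other methods fall
        # through to DelegateWorkerBase's explicit delegation.
        start = time.perf_counter()
        self.worker.load_model()
        print(f"load_model took {time.perf_counter() - start:.1f}s")
```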