Commit cf26f2c

[V1][Core] Add worker_base for v1 worker
1. Reuse WorkerBase in vllm.worker.worker_base.
2. Remove unnecessary abstract methods and only give warnings for unimplemented methods.

Signed-off-by: Aoyu <[email protected]>
1 parent 022bcc7 commit cf26f2c

File tree

3 files changed: +166 -49 lines changed

vllm/v1/worker/gpu_worker.py

Lines changed: 9 additions & 19 deletions
@@ -22,14 +22,15 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler import SchedulerOutput


-class Worker:
+class Worker(WorkerBase):

     def __init__(
         self,
@@ -40,23 +41,11 @@ def __init__(
         is_driver_worker: bool = False,
     ):

-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
+        super().__init__(vllm_config=vllm_config,
+                         local_rank=local_rank,
+                         rank=rank,
+                         distributed_init_method=distributed_init_method,
+                         is_driver_worker=is_driver_worker)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -127,7 +116,8 @@ def init_device(self):
         set_random_seed(self.model_config.seed)

         # Construct the model runner
-        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
+        self.model_runner: GPUModelRunner = GPUModelRunner(
+            self.vllm_config, self.device)

     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:

vllm/v1/worker/worker_base.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from abc import abstractmethod
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
+
+logger = init_logger(__name__)
+
+
+class WorkerBase(WorkerBaseV0):
+    """
+    Abstract class for v1 worker, mainly define some methods for v1.
+    For methods shared by v0 and v1, define them in v0 WorkerBase
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+    ):
+        """
+        Initialize common worker components.
+
+        Args:
+            vllm_config: Complete vLLM configuration
+            local_rank: Local device index
+            rank: Global rank in distributed setup
+            distributed_init_method: Distributed initialization method
+            is_driver_worker: Whether this worker handles driver
+                responsibilities
+        """
+        # Configuration storage
+        super().__init__(vllm_config=vllm_config)
+
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+
+        # Device and model state
+        self.device: Optional[torch.device] = None
+        self.model_runner: Optional[nn.Module] = None
+
+    @abstractmethod
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """Get specifications for KV cache implementation."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def compile_or_warm_up_model(self) -> None:
+        """Prepare model for execution through compilation/warmup."""
+        raise NotImplementedError
+
+    def check_health(self) -> None:
+        """Basic health check (override for device-specific checks)."""
+        return
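
For context, here is a minimal sketch (not part of this commit) of how a hypothetical device worker could build on the new v1 WorkerBase, mirroring what the gpu_worker.py change above does. The class name MyDeviceWorker and the trivial method bodies are illustrative only.

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.worker.worker_base import WorkerBase


class MyDeviceWorker(WorkerBase):
    """Hypothetical v1 worker, used only to illustrate the new base class."""

    def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int,
                 distributed_init_method: str, is_driver_worker: bool = False):
        # Config unpacking and rank bookkeeping now live in the base classes,
        # so the subclass only forwards its constructor arguments.
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

    # The only methods that remain abstract in the v1 base:
    def get_kv_cache_spec(self) -> KVCacheSpec:
        raise NotImplementedError  # return the per-layer KV cache spec here

    def compile_or_warm_up_model(self) -> None:
        pass  # compile / warm up the loaded model here

Only get_kv_cache_spec and compile_or_warm_up_model are still abstract in the v1 base; the remaining v0 WorkerBase methods that a subclass leaves unimplemented now only produce a warning at construction time via the check_implementation wrapper in the next file, instead of failing.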

vllm/worker/worker_base.py

Lines changed: 91 additions & 30 deletions
@@ -4,7 +4,9 @@
 import os
 import time
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from functools import wraps
+from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
+                    TypeVar, Union)

 import cloudpickle
 import torch
@@ -27,6 +29,64 @@
 logger = init_logger(__name__)


+def check_implementation():
+    """
+    A decorator that checks if all abstract methods from the base class
+    are implemented in the subclass and gives warnings for unimplemented
+    methods.
+    """
+
+    def decorator(cls: Type):
+        original_init = cls.__init__
+
+        @wraps(original_init)
+        def wrapped_init(self, *args, **kwargs):
+            original_init(self, *args, **kwargs)
+            unimplemented_methods = []
+            for attr_name in dir(self):
+                # bypass inner method
+                if attr_name.startswith('_'):
+                    continue
+                base_method = getattr(self, attr_name)
+                # bypass method already defined
+                if getattr(base_method, '_avoid_check', False):
+                    continue
+                # get the func of callable method
+                if callable(base_method):
+                    base_method_name = base_method.__func__
+                else:
+                    continue
+                class_method = getattr(cls, attr_name, False)
+                # bypass method defined in sub class
+                if not class_method:
+                    continue
+                if class_method == base_method_name:
+                    unimplemented_methods.append(attr_name)
+            if unimplemented_methods:
+                method_names = ','.join(unimplemented_methods)
+                msg = (f"Methods {method_names} not implemented in {self}")
+                logger.warning(msg)
+
+        cls.__init__ = wrapped_init
+        return cls
+
+    return decorator
+
+
+T = TypeVar('T')
+
+
+def avoid_check(func: Callable[..., T]) -> Callable[..., T]:
+
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> T:
+        return func(*args, **kwargs)
+
+    wrapper._avoid_check = True  # type: ignore
+    return wrapper
+
+
+@check_implementation()
 class WorkerBase(ABC):
     """Worker interface that allows vLLM to cleanly separate implementations for
     different hardware. Also abstracts control plane communication, e.g., to
@@ -60,28 +120,26 @@ def init_device(self) -> None:
         """
         raise NotImplementedError

-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks.
         """
         raise NotImplementedError

+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError
+
+    @avoid_check
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.

@@ -94,40 +152,43 @@ def start_worker_execution_loop(self) -> None:
             if output is None:
                 return None

-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.

-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
         raise NotImplementedError

-    @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
         raise NotImplementedError

-    @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError

+    @property
+    def vocab_size(self) -> int:
+        """Get vocabulary size from model configuration."""
+        return self.model_config.get_vocab_size()
+

 class DelegateWorkerBase(WorkerBase):
     """