 import os
 import time
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
+from functools import wraps
+from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
+                    TypeVar, Union)

 import cloudpickle
 import torch
 logger = init_logger(__name__)


+def check_implementation():
+    """
+    A decorator that checks if all abstract methods from the base class
+    are implemented in the subclass and gives warnings for unimplemented
+    methods.
+    """
+
+    def decorator(cls: Type):
+        original_init = cls.__init__
+
+        @wraps(original_init)
+        def wrapped_init(self, *args, **kwargs):
+            original_init(self, *args, **kwargs)
+            unimplemented_methods = []
+            for attr_name in dir(self):
+                # bypass inner method
+                if attr_name.startswith('_'):
+                    continue
+                base_method = getattr(self, attr_name)
+                # bypass method already defined
+                if getattr(base_method, '_avoid_check', False):
+                    continue
+                # get the func of callable method
+                if callable(base_method):
+                    base_method_name = base_method.__func__
+                else:
+                    continue
+                class_method = getattr(cls, attr_name, False)
+                # bypass method defined in sub class
+                if not class_method:
+                    continue
+                if class_method == base_method_name:
+                    unimplemented_methods.append(attr_name)
+            if unimplemented_methods:
+                method_names = ','.join(unimplemented_methods)
+                msg = (f"Methods {method_names} not implemented in {self}")
+                logger.warning(msg)
+
+        cls.__init__ = wrapped_init
+        return cls
+
+    return decorator
+
+
+T = TypeVar('T')
+
+
+def avoid_check(func: Callable[..., T]) -> Callable[..., T]:
+
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> T:
+        return func(*args, **kwargs)
+
+    wrapper._avoid_check = True  # type: ignore
+    return wrapper
+
+
+@check_implementation()
 class WorkerBase(ABC):
     """Worker interface that allows vLLM to cleanly separate implementations for
     different hardware. Also abstracts control plane communication, e.g., to
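To make the new behavior concrete, here is a minimal sketch (not part of the commit; FakeWorkerBase, IncompleteWorker and their methods are hypothetical) of how check_implementation and avoid_check interact with a subclass: overridden methods pass silently, methods left as base-class stubs only produce a warning at construction time, and @avoid_check exempts methods the base class already implements for real.

# Hypothetical illustration only; it reuses the decorators added above.
@check_implementation()
class FakeWorkerBase:

    def load_model(self) -> None:
        raise NotImplementedError

    def execute_model(self) -> None:
        raise NotImplementedError

    @avoid_check
    def start_loop(self) -> None:
        # Concrete in the base class, so the check skips it.
        print("looping")


class IncompleteWorker(FakeWorkerBase):

    # Overrides load_model, but not execute_model.
    def load_model(self) -> None:
        print("loading")


IncompleteWorker()
# Logs roughly: "Methods execute_model not implemented in <IncompleteWorker ...>"
# Unlike @abstractmethod, instantiation still succeeds; the missing method only
# fails if it is actually called.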
@@ -60,28 +120,26 @@ def init_device(self) -> None:
         """
         raise NotImplementedError

-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks.
         """
         raise NotImplementedError

+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError
+
+    @avoid_check
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.

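As a point of reference, a rough sketch (hypothetical; ToyWorker and all of its numbers are invented, and real workers such as the GPU worker are far more involved) of what a concrete subclass looks like under the relaxed interface: it overrides the methods it actually supports and simply inherits the remaining stubs, which now trigger a construction-time warning instead of an ABC instantiation error.

# Illustrative only; not part of this commit.
class ToyWorker(WorkerBase):

    def init_device(self) -> None:
        self.device = torch.device("cuda:0")

    def get_cache_block_size_bytes(self) -> int:
        # Made-up fixed block size of 2 MiB.
        return 2 * 1024 * 1024

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        # A real worker would profile a forward pass; here we pretend
        # 8 GiB of GPU memory is free and allow 512 CPU swap blocks.
        free_gpu_bytes = 8 * 1024 ** 3
        num_gpu_blocks = free_gpu_bytes // self.get_cache_block_size_bytes()
        num_cpu_blocks = 512
        return num_gpu_blocks, num_cpu_blocks

# Creating a ToyWorker (with whatever arguments WorkerBase.__init__ expects)
# would log a warning listing the methods it still inherits unimplemented
# (load_model, execute_model, ...), but the object is otherwise usable.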
@@ -94,40 +152,43 @@ def start_worker_execution_loop(self) -> None:
         if output is None:
             return None

-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.

-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
         raise NotImplementedError

-    @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
         raise NotImplementedError

-    @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError

+    @property
+    def vocab_size(self) -> int:
+        """Get vocabulary size from model configuration."""
+        return self.model_config.get_vocab_size()
+

 class DelegateWorkerBase(WorkerBase):
     """