Skip to content

Commit ddf0c1b

Browse files
njhill authored and kwang1012 committed
[V1] Use msgpack for core request serialization (vllm-project#12918)
Signed-off-by: Nick Hill <[email protected]>
1 parent afb476f commit ddf0c1b

File tree

4 files changed

+62
-95
lines changed

4 files changed

+62
-95
lines changed

vllm/v1/engine/__init__.py

+14-28
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import enum
4-
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, List, Optional, Union
4+
from typing import List, Optional, Union
65

76
import msgspec
87

8+
from vllm.lora.request import LoRARequest
9+
from vllm.multimodal import MultiModalKwargs
10+
from vllm.multimodal.inputs import PlaceholderRange
11+
from vllm.sampling_params import SamplingParams
912
from vllm.v1.metrics.stats import SchedulerStats
1013
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
1114

12-
if TYPE_CHECKING:
13-
from vllm.lora.request import LoRARequest
14-
from vllm.multimodal import MultiModalKwargs
15-
from vllm.multimodal.inputs import PlaceholderRange
16-
from vllm.sampling_params import SamplingParams
17-
1815
# These are possible values of RequestOutput.finish_reason,
1916
# so form part of the external API.
2017
FINISH_REASON_STRINGS = ("stop", "length", "abort")
@@ -39,8 +36,11 @@ def __str__(self):
3936
return FINISH_REASON_STRINGS[self.value]
4037

4138

42-
@dataclass
43-
class EngineCoreRequest:
39+
class EngineCoreRequest(
40+
msgspec.Struct,
41+
array_like=True, # type: ignore[call-arg]
42+
omit_defaults=True, # type: ignore[call-arg]
43+
gc=False): # type: ignore[call-arg]
4444

4545
# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
4646
# but this object is currently not playing well with msgspec
@@ -51,13 +51,13 @@ class EngineCoreRequest:
5151
# Detokenizer, but set to None when it is added to EngineCoreClient.
5252
prompt: Optional[str]
5353
prompt_token_ids: List[int]
54-
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
54+
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
5555
mm_hashes: Optional[List[str]]
56-
mm_placeholders: Optional[List["PlaceholderRange"]]
57-
sampling_params: "SamplingParams"
56+
mm_placeholders: Optional[List[PlaceholderRange]]
57+
sampling_params: SamplingParams
5858
eos_token_id: Optional[int]
5959
arrival_time: float
60-
lora_request: Optional["LoRARequest"]
60+
lora_request: Optional[LoRARequest]
6161

6262

6363
class EngineCoreOutput(
@@ -94,16 +94,6 @@ class EngineCoreOutputs(
9494
scheduler_stats: SchedulerStats
9595

9696

97-
@dataclass
98-
class EngineCoreProfile:
99-
is_start: bool
100-
101-
102-
@dataclass
103-
class EngineCoreResetPrefixCache:
104-
pass
105-
106-
10797
class EngineCoreRequestType(enum.Enum):
10898
"""
10999
Request types defined as hex byte strings, so it can be sent over sockets
@@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
113103
ABORT = b'\x01'
114104
PROFILE = b'\x02'
115105
RESET_PREFIX_CACHE = b'\x03'
116-
117-
118-
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
119-
EngineCoreResetPrefixCache, List[str]]

vllm/v1/engine/core.py

+26-35
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
import pickle
43
import queue
54
import signal
65
import threading
76
import time
87
from multiprocessing.connection import Connection
9-
from typing import List, Tuple, Type
8+
from typing import Any, List, Tuple, Type
109

1110
import psutil
1211
import zmq
@@ -19,13 +18,12 @@
1918
from vllm.utils import get_exception_traceback, zmq_socket_ctx
2019
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
2120
from vllm.v1.core.scheduler import Scheduler
22-
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
23-
EngineCoreRequest, EngineCoreRequestType,
24-
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
21+
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
22+
EngineCoreRequestType)
2523
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
2624
from vllm.v1.executor.abstract import Executor
2725
from vllm.v1.request import Request, RequestStatus
28-
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
26+
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
2927
from vllm.version import __version__ as VLLM_VERSION
3028

3129
logger = init_logger(__name__)
@@ -171,7 +169,8 @@ def __init__(
171169
# and to overlap some serialization/deserialization with the
172170
# model forward pass.
173171
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
174-
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
172+
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
173+
Any]] = queue.Queue()
175174
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
176175
threading.Thread(target=self.process_input_socket,
177176
args=(input_path, ),
@@ -233,7 +232,7 @@ def run_busy_loop(self):
233232
while True:
234233
try:
235234
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
236-
self._handle_client_request(req)
235+
self._handle_client_request(*req)
237236
break
238237
except queue.Empty:
239238
logger.debug("EngineCore busy loop waiting.")
@@ -243,59 +242,51 @@ def run_busy_loop(self):
243242
except BaseException:
244243
raise
245244

246-
# 2) Handle any new client requests (Abort or Add).
245+
# 2) Handle any new client requests.
247246
while not self.input_queue.empty():
248247
req = self.input_queue.get_nowait()
249-
self._handle_client_request(req)
248+
self._handle_client_request(*req)
250249

251250
# 3) Step the engine core.
252251
outputs = self.step()
253252

254253
# 5) Put EngineCoreOutputs into the output queue.
255254
self.output_queue.put_nowait(outputs)
256255

257-
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
258-
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
256+
def _handle_client_request(self, request_type: EngineCoreRequestType,
257+
request: Any) -> None:
258+
"""Dispatch request from client."""
259259

260-
if isinstance(request, EngineCoreRequest):
260+
if request_type == EngineCoreRequestType.ADD:
261261
self.add_request(request)
262-
elif isinstance(request, EngineCoreProfile):
263-
self.model_executor.profile(request.is_start)
264-
elif isinstance(request, EngineCoreResetPrefixCache):
265-
self.reset_prefix_cache()
266-
else:
267-
# TODO: make an EngineCoreAbort wrapper
268-
assert isinstance(request, list)
262+
elif request_type == EngineCoreRequestType.ABORT:
269263
self.abort_requests(request)
264+
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
265+
self.reset_prefix_cache()
266+
elif request_type == EngineCoreRequestType.PROFILE:
267+
self.model_executor.profile(request)
270268

271269
def process_input_socket(self, input_path: str):
272270
"""Input socket IO thread."""
273271

274272
# Msgpack serialization decoding.
275-
decoder_add_req = PickleEncoder()
276-
decoder_abort_req = PickleEncoder()
273+
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
274+
generic_decoder = MsgpackDecoder()
277275

278276
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
279277
while True:
280278
# (RequestType, RequestData)
281279
type_frame, data_frame = socket.recv_multipart(copy=False)
282-
request_type = type_frame.buffer
283-
request_data = data_frame.buffer
280+
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
284281

285282
# Deserialize the request data.
286-
if request_type == EngineCoreRequestType.ADD.value:
287-
request = decoder_add_req.decode(request_data)
288-
elif request_type == EngineCoreRequestType.ABORT.value:
289-
request = decoder_abort_req.decode(request_data)
290-
elif request_type in (
291-
EngineCoreRequestType.PROFILE.value,
292-
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
293-
request = pickle.loads(request_data)
294-
else:
295-
raise ValueError(f"Unknown RequestType: {request_type}")
283+
decoder = add_request_decoder if (
284+
request_type
285+
== EngineCoreRequestType.ADD) else generic_decoder
286+
request = decoder.decode(data_frame.buffer)
296287

297288
# Push to input queue for core busy loop.
298-
self.input_queue.put_nowait(request)
289+
self.input_queue.put_nowait((request_type, request))
299290

300291
def process_output_socket(self, output_path: str):
301292
"""Output socket IO thread."""

vllm/v1/engine/core_client.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import signal
66
import weakref
77
from abc import ABC, abstractmethod
8-
from typing import List, Optional, Type
8+
from typing import Any, List, Optional, Type
99

1010
import zmq
1111
import zmq.asyncio
@@ -14,12 +14,11 @@
1414
from vllm.logger import init_logger
1515
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
1616
make_zmq_socket)
17-
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
18-
EngineCoreRequest, EngineCoreRequestType,
19-
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
17+
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
18+
EngineCoreRequestType)
2019
from vllm.v1.engine.core import EngineCore, EngineCoreProc
2120
from vllm.v1.executor.abstract import Executor
22-
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
21+
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
2322
from vllm.v1.utils import BackgroundProcHandle
2423

2524
logger = init_logger(__name__)
@@ -161,7 +160,7 @@ def sigusr1_handler(signum, frame):
161160
signal.signal(signal.SIGUSR1, sigusr1_handler)
162161

163162
# Serialization setup.
164-
self.encoder = PickleEncoder()
163+
self.encoder = MsgpackEncoder()
165164
self.decoder = MsgpackDecoder(EngineCoreOutputs)
166165

167166
# ZMQ setup.
@@ -220,7 +219,7 @@ def get_output(self) -> EngineCoreOutputs:
220219
return self.decoder.decode(frame.buffer)
221220

222221
def _send_input(self, request_type: EngineCoreRequestType,
223-
request: EngineCoreRequestUnion) -> None:
222+
request: Any) -> None:
224223

225224
# (RequestType, SerializedRequest)
226225
msg = (request_type.value, self.encoder.encode(request))
@@ -237,12 +236,10 @@ def abort_requests(self, request_ids: List[str]) -> None:
237236
self._send_input(EngineCoreRequestType.ABORT, request_ids)
238237

239238
def profile(self, is_start: bool = True) -> None:
240-
self._send_input(EngineCoreRequestType.PROFILE,
241-
EngineCoreProfile(is_start))
239+
self._send_input(EngineCoreRequestType.PROFILE, is_start)
242240

243241
def reset_prefix_cache(self) -> None:
244-
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
245-
EngineCoreResetPrefixCache())
242+
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
246243

247244

248245
class AsyncMPClient(MPClient):
@@ -277,7 +274,7 @@ async def process_outputs_socket():
277274
return self.decoder.decode(await self.outputs_queue.get())
278275

279276
async def _send_input(self, request_type: EngineCoreRequestType,
280-
request: EngineCoreRequestUnion) -> None:
277+
request: Any) -> None:
281278

282279
msg = (request_type.value, self.encoder.encode(request))
283280
await self.input_socket.send_multipart(msg, copy=False)
@@ -293,9 +290,7 @@ async def abort_requests_async(self, request_ids: List[str]) -> None:
293290
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
294291

295292
async def profile_async(self, is_start: bool = True) -> None:
296-
await self._send_input(EngineCoreRequestType.PROFILE,
297-
EngineCoreProfile(is_start))
293+
await self._send_input(EngineCoreRequestType.PROFILE, is_start)
298294

299295
async def reset_prefix_cache_async(self) -> None:
300-
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
301-
EngineCoreResetPrefixCache())
296+
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)

vllm/v1/serial_utils.py

+11-16
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,13 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import pickle
4-
from typing import Any
4+
from typing import Any, Optional
55

66
import torch
77
from msgspec import msgpack
88

9-
CUSTOM_TYPE_CODE_PICKLE = 1
10-
11-
12-
class PickleEncoder:
13-
14-
def encode(self, obj: Any):
15-
return pickle.dumps(obj)
16-
17-
def decode(self, data: Any):
18-
return pickle.loads(data)
9+
CUSTOM_TYPE_TENSOR = 1
10+
CUSTOM_TYPE_PICKLE = 2
1911

2012

2113
class MsgpackEncoder:
@@ -34,8 +26,9 @@ def encode_into(self, obj: Any, buf: bytearray) -> None:
3426
class MsgpackDecoder:
3527
"""Decoder with custom torch tensor serialization."""
3628

37-
def __init__(self, t: Any):
38-
self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
29+
def __init__(self, t: Optional[Any] = None):
30+
args = () if t is None else (t, )
31+
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
3932

4033
def decode(self, obj: Any):
4134
return self.decoder.decode(obj)
@@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
4639
# NOTE(rob): it is fastest to use numpy + pickle
4740
# when serializing torch tensors.
4841
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
49-
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
42+
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
5043

51-
raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
44+
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
5245

5346

5447
def custom_ext_hook(code: int, data: memoryview) -> Any:
55-
if code == CUSTOM_TYPE_CODE_PICKLE:
48+
if code == CUSTOM_TYPE_TENSOR:
5649
return torch.from_numpy(pickle.loads(data))
50+
if code == CUSTOM_TYPE_PICKLE:
51+
return pickle.loads(data)
5752

5853
raise NotImplementedError(f"Extension type code {code} is not supported")

0 commit comments

Comments (0)