[V1] Use msgpack for core request serialization #12918

Merged · 5 commits · Feb 10, 2025
Changes from 4 commits
42 changes: 14 additions & 28 deletions vllm/v1/engine/__init__.py
@@ -1,20 +1,17 @@
# SPDX-License-Identifier: Apache-2.0

import enum
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional, Union

import msgspec

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors

if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams

# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort")
@@ -39,8 +36,11 @@ def __str__(self):
return FINISH_REASON_STRINGS[self.value]


@dataclass
class EngineCoreRequest:
class EngineCoreRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False): # type: ignore[call-arg]

# NOTE: prompt and prompt_token_ids should be DecoderOnlyInput,
# but this object is currently not playing well with msgspec
@@ -51,13 +51,13 @@ class EngineCoreRequest:
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: List[int]
mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
mm_hashes: Optional[List[str]]
mm_placeholders: Optional[List["PlaceholderRange"]]
sampling_params: "SamplingParams"
mm_placeholders: Optional[List[PlaceholderRange]]
sampling_params: SamplingParams
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional["LoRARequest"]
lora_request: Optional[LoRARequest]


class EngineCoreOutput(
@@ -94,16 +94,6 @@ class EngineCoreOutputs(
scheduler_stats: SchedulerStats


@dataclass
class EngineCoreProfile:
is_start: bool


@dataclass
class EngineCoreResetPrefixCache:
pass


class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
@@ -113,7 +103,3 @@ class EngineCoreRequestType(enum.Enum):
ABORT = b'\x01'
PROFILE = b'\x02'
RESET_PREFIX_CACHE = b'\x03'


EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
EngineCoreResetPrefixCache, List[str]]
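
The change above turns `EngineCoreRequest` from a `@dataclass` into a `msgspec.Struct` with `array_like=True` (fields are encoded positionally as a msgpack array, so field names never go over the wire), `omit_defaults=True` (trailing default values are omitted from the payload), and `gc=False` (instances are excluded from garbage-collector tracking, which is safe for structs that cannot form reference cycles). A minimal sketch of what `array_like=True` buys, using a hypothetical `Point` struct rather than the real request type:

```python
# Hypothetical Point structs, illustrating map vs. positional-array encoding.
import msgspec
from msgspec import msgpack


class PointMap(msgspec.Struct):
    x: int
    y: int


class PointArray(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    x: int
    y: int


map_bytes = msgpack.encode(PointMap(x=1, y=2))      # encodes as {"x": 1, "y": 2}
array_bytes = msgpack.encode(PointArray(x=1, y=2))  # encodes as [1, 2]
assert len(array_bytes) < len(map_bytes)

# Field names are gone from the wire format, so decoding needs the target type.
assert msgpack.Decoder(PointArray).decode(array_bytes) == PointArray(x=1, y=2)
```

Because the field names are absent from the payload, the receiving side has to know the target type up front; that is why `process_input_socket` in `core.py` below builds a dedicated `MsgpackDecoder(EngineCoreRequest)` for ADD requests.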
61 changes: 26 additions & 35 deletions vllm/v1/engine/core.py
@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

import pickle
import queue
import signal
import threading
import time
from multiprocessing.connection import Connection
from typing import List, Tuple, Type
from typing import Any, List, Tuple, Type

import psutil
import zmq
@@ -19,13 +18,12 @@
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)
@@ -161,7 +159,8 @@ def __init__(
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
@@ -223,7 +222,7 @@ def run_busy_loop(self):
while True:
try:
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
self._handle_client_request(req)
self._handle_client_request(*req)
break
except queue.Empty:
logger.debug("EngineCore busy loop waiting.")
@@ -233,59 +232,51 @@ def run_busy_loop(self):
except BaseException:
raise

# 2) Handle any new client requests (Abort or Add).
# 2) Handle any new client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(req)
self._handle_client_request(*req)

# 3) Step the engine core.
outputs = self.step()

# 5) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)

def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
"""Dispatch request from client."""

if isinstance(request, EngineCoreRequest):
if request_type == EngineCoreRequestType.ADD:
self.add_request(request)
elif isinstance(request, EngineCoreProfile):
self.model_executor.profile(request.is_start)
elif isinstance(request, EngineCoreResetPrefixCache):
self.reset_prefix_cache()
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
self.reset_prefix_cache()
elif request_type == EngineCoreRequestType.PROFILE:
self.model_executor.profile(request)

def process_input_socket(self, input_path: str):
"""Input socket IO thread."""

# Msgpack serialization decoding.
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()

with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
request_type = type_frame.buffer
request_data = data_frame.buffer
request_type = EngineCoreRequestType(bytes(type_frame.buffer))

# Deserialize the request data.
if request_type == EngineCoreRequestType.ADD.value:
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type in (
EngineCoreRequestType.PROFILE.value,
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
request = pickle.loads(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)

# Push to input queue for core busy loop.
self.input_queue.put_nowait(request)
self.input_queue.put_nowait((request_type, request))

def process_output_socket(self, output_path: str):
"""Output socket IO thread."""
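
With `EngineCoreRequestUnion` gone, the input thread now tags each decoded payload with its `EngineCoreRequestType` and the busy loop dispatches on that tag instead of on `isinstance` checks. A condensed, self-contained sketch of the decoder selection in `process_input_socket` (stand-in types, not the real vLLM classes):

```python
import enum

import msgspec
from msgspec import msgpack


class RequestType(enum.Enum):  # stand-in for EngineCoreRequestType
    ADD = b'\x00'
    ABORT = b'\x01'


class AddRequest(msgspec.Struct, array_like=True):  # stand-in for EngineCoreRequest
    request_id: str


add_decoder = msgpack.Decoder(AddRequest)  # typed: yields AddRequest instances
generic_decoder = msgpack.Decoder()        # untyped: yields plain Python objects

# An ABORT message as its two frames would arrive off the socket.
type_bytes, payload = b'\x01', msgpack.encode(["req-1", "req-2"])
request_type = RequestType(type_bytes)  # enum lookup by byte value

decoder = add_decoder if request_type == RequestType.ADD else generic_decoder
assert decoder.decode(payload) == ["req-1", "req-2"]
```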
27 changes: 11 additions & 16 deletions vllm/v1/engine/core_client.py
@@ -5,7 +5,7 @@
import signal
import weakref
from abc import ABC, abstractmethod
from typing import List, Optional, Type
from typing import Any, List, Optional, Type

import zmq
import zmq.asyncio
@@ -14,12 +14,11 @@
from vllm.logger import init_logger
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.utils import BackgroundProcHandle

logger = init_logger(__name__)
@@ -161,7 +160,7 @@ def sigusr1_handler(signum, frame):
signal.signal(signal.SIGUSR1, sigusr1_handler)

# Serialization setup.
self.encoder = PickleEncoder()
self.encoder = MsgpackEncoder()
self.decoder = MsgpackDecoder(EngineCoreOutputs)

# ZMQ setup.
@@ -220,7 +219,7 @@ def get_output(self) -> EngineCoreOutputs:
return self.decoder.decode(frame.buffer)

def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:
request: Any) -> None:

# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
@@ -237,12 +236,10 @@ def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)

def profile(self, is_start: bool = True) -> None:
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
self._send_input(EngineCoreRequestType.PROFILE, is_start)

def reset_prefix_cache(self) -> None:
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
EngineCoreResetPrefixCache())
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)


class AsyncMPClient(MPClient):
@@ -277,7 +274,7 @@ async def process_outputs_socket():
return self.decoder.decode(await self.outputs_queue.get())

async def _send_input(self, request_type: EngineCoreRequestType,
request: EngineCoreRequestUnion) -> None:
request: Any) -> None:

msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
@@ -293,9 +290,7 @@ async def abort_requests_async(self, request_ids: List[str]) -> None:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile_async(self, is_start: bool = True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
await self._send_input(EngineCoreRequestType.PROFILE, is_start)

async def reset_prefix_cache_async(self) -> None:
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
EngineCoreResetPrefixCache())
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
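
On the client side the wrapper classes disappear as well: `profile()` and `reset_prefix_cache()` now send a bare `bool` and `None`, framed as a two-part message of (request-type byte, msgpack payload). A sketch of that framing, with a print standing in for the ZMQ `send_multipart` call:

```python
import enum
from typing import Any, List

from msgspec import msgpack


class RequestType(enum.Enum):  # stand-in for EngineCoreRequestType
    PROFILE = b'\x02'
    RESET_PREFIX_CACHE = b'\x03'


encoder = msgpack.Encoder()  # stand-in for vLLM's MsgpackEncoder


def send_multipart(frames: List[bytes]) -> None:
    print([bytes(f) for f in frames])  # placeholder for socket.send_multipart


def send_input(request_type: RequestType, request: Any) -> None:
    # (RequestType, SerializedRequest) -- no wrapper objects needed.
    send_multipart([request_type.value, encoder.encode(request)])


send_input(RequestType.PROFILE, True)             # payload is a bare bool
send_input(RequestType.RESET_PREFIX_CACHE, None)  # payload carries no data
```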
27 changes: 11 additions & 16 deletions vllm/v1/serial_utils.py
@@ -1,21 +1,13 @@
# SPDX-License-Identifier: Apache-2.0

import pickle
from typing import Any
from typing import Any, Optional

import torch
from msgspec import msgpack

CUSTOM_TYPE_CODE_PICKLE = 1


class PickleEncoder:

def encode(self, obj: Any):
return pickle.dumps(obj)

def decode(self, data: Any):
return pickle.loads(data)
CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2


class MsgpackEncoder:
@@ -34,8 +26,9 @@ def encode_into(self, obj: Any, buf: bytearray) -> None:
class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""

def __init__(self, t: Any):
self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)
def __init__(self, t: Optional[Any] = None):
args = () if t is None else (t, )
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)

def decode(self, obj: Any):
return self.decoder.decode(obj)
@@ -46,13 +39,15 @@ def custom_enc_hook(obj: Any) -> Any:
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))

raise NotImplementedError(f"Objects of type {type(obj)} are not supported")
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))


def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_CODE_PICKLE:
if code == CUSTOM_TYPE_TENSOR:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)

raise NotImplementedError(f"Extension type code {code} is not supported")
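
The single pickle extension code is split in two here: tensors travel as pickled numpy arrays under `CUSTOM_TYPE_TENSOR`, while any other otherwise-unsupported type falls back to plain pickle under `CUSTOM_TYPE_PICKLE`, so `custom_enc_hook` no longer raises. A minimal round-trip sketch of the same scheme, with the hooks re-implemented for a self-contained example (assumes torch is installed):

```python
import pickle
from typing import Any

import torch
from msgspec import msgpack

CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2


def enc_hook(obj: Any) -> Any:
    if isinstance(obj, torch.Tensor):
        # numpy + pickle is the fast path for tensor payloads.
        return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))


def ext_hook(code: int, data: memoryview) -> Any:
    if code == CUSTOM_TYPE_TENSOR:
        return torch.from_numpy(pickle.loads(data))
    if code == CUSTOM_TYPE_PICKLE:
        return pickle.loads(data)
    raise NotImplementedError(f"Extension type code {code} is not supported")


encoder = msgpack.Encoder(enc_hook=enc_hook)
decoder = msgpack.Decoder(ext_hook=ext_hook)

original = {"weights": torch.ones(2, 3)}
restored = decoder.decode(encoder.encode(original))
assert torch.equal(restored["weights"], original["weights"])
```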