Skip to content

[V1] Zero-copy tensor/ndarray serialization/transmission #13790

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass

import numpy as np
import torch

from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


@dataclass
class MyType:
tensor1: torch.Tensor
a_string: str
list_of_tensors: list[torch.Tensor]
numpy_array: np.ndarray


def test_encode_decode():
"""Test encode/decode loop with zero-copy tensors."""

obj = MyType(
tensor1=torch.randint(low=0, high=100, size=(10, ), dtype=torch.int32),
a_string="hello",
list_of_tensors=[
torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4), dtype=torch.float64)
],
numpy_array=np.arange(20),
)

encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MyType)

encoded = encoder.encode(obj)

# There should be the main buffer + 3 tensor buffers + one ndarray buffer
assert len(encoded) == 5

decoded: MyType = decoder.decode(encoded)

assert_equal(decoded, obj)

# Test encode_into case

preallocated = bytearray()

encoded2 = encoder.encode_into(obj, preallocated)

assert len(encoded2) == 5
assert encoded2[0] is preallocated

decoded2: MyType = decoder.decode(encoded2)

assert_equal(decoded2, obj)


def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string
assert all(
torch.equal(a, b)
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
8 changes: 4 additions & 4 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,14 @@ def process_input_socket(self, input_path: str, engine_index: int):

while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))

# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
request = decoder.decode(data_frames)

# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
Expand All @@ -514,8 +514,8 @@ def process_output_socket(self, output_path: str, engine_index: int):
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send(buffer, copy=False)
buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False)


ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
Expand Down
26 changes: 13 additions & 13 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
from vllm.v1.utils import BackgroundProcHandle

logger = init_logger(__name__)
Expand Down Expand Up @@ -500,8 +500,8 @@ def process_outputs_socket():
# shutdown signal, exit thread.
break

frame = out_socket.recv(copy=False)
outputs = decoder.decode(frame.buffer)
frames = out_socket.recv_multipart(copy=False)
outputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
Expand All @@ -524,7 +524,7 @@ def get_output(self) -> EngineCoreOutputs:
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine.identity, request_type.value,
self.encoder.encode(request))
*self.encoder.encode(request))
self.input_socket.send_multipart(msg, copy=False)

def call_utility(self, method: str, *args) -> Any:
Expand Down Expand Up @@ -628,8 +628,8 @@ def _ensure_output_queue_task(self):

async def process_outputs_socket():
while True:
(frame, ) = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
frames = await output_socket.recv_multipart(copy=False)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output,
utility_results)
Expand Down Expand Up @@ -661,12 +661,12 @@ def _send_input(self,
if engine is None:
engine = self.core_engine

message = (request_type.value, self.encoder.encode(request))
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine)

def _send_input_message(self, message: tuple[bytes, bytes],
def _send_input_message(self, message: tuple[bytestr, ...],
engine: CoreEngine) -> Awaitable[None]:
message = (engine.identity, ) + message # type: ignore[assignment]
message = (engine.identity, ) + message
return self.input_socket.send_multipart(message, copy=False)

async def call_utility_async(self, method: str, *args) -> Any:
Expand All @@ -679,8 +679,8 @@ async def _call_utility_async(self, method: str, *args,
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (EngineCoreRequestType.UTILITY.value,
self.encoder.encode((call_id, method, args)))
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
(call_id, method, args)))
await self._send_input_message(message, engine)
self._ensure_output_queue_task()
return await future
Expand Down Expand Up @@ -755,7 +755,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],

# Control message used for triggering dp idle mode loop.
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
self.encoder.encode(None))
*self.encoder.encode(None))

self.num_engines_running = 0
self.reqs_in_flight: dict[str, CoreEngine] = {}
Expand Down Expand Up @@ -789,7 +789,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
# tokenized.
request.prompt = None

msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))

chosen_engine = self.get_core_engine_for_request()
self.reqs_in_flight[request.request_id] = chosen_engine
Expand Down
130 changes: 89 additions & 41 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,109 @@
# SPDX-License-Identifier: Apache-2.0

import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional
from typing import Any, Optional, Union

import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack

CUSTOM_TYPE_TENSOR = 1
CUSTOM_TYPE_PICKLE = 2
CUSTOM_TYPE_CLOUDPICKLE = 3
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2

bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]


class MsgpackEncoder:
"""Encoder with custom torch tensor serialization."""
"""Encoder with custom torch tensor and numpy array serialization."""

def __init__(self):
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)

def encode(self, obj: Any) -> bytes:
return self.encoder.encode(obj)

def encode_into(self, obj: Any, buf: bytearray) -> None:
self.encoder.encode_into(obj, buf)
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
self.aux_buffers: Optional[list[bytestr]] = None

def encode(self, obj: Any) -> Sequence[bytestr]:
try:
self.aux_buffers = bufs = [b'']
bufs[0] = self.encoder.encode(obj)
return bufs
finally:
self.aux_buffers = None

def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
try:
self.aux_buffers = [buf]
bufs = self.aux_buffers
self.encoder.encode_into(obj, buf)
return bufs
finally:
self.aux_buffers = None

def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return self._encode_ndarray(obj.numpy())

# Fall back to pickle for object or void kind ndarrays.
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
return self._encode_ndarray(obj)

if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

return msgpack.Ext(CUSTOM_TYPE_PICKLE,
pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

def _encode_ndarray(self, obj: np.ndarray) -> Any:
assert self.aux_buffers is not None
obj = np.ascontiguousarray(obj)
index = len(self.aux_buffers)
self.aux_buffers.append(obj.data)
return obj.dtype.str, obj.shape, index


class MsgpackDecoder:
"""Decoder with custom torch tensor serialization."""
"""Decoder with custom torch tensor and numpy array serialization."""

def __init__(self, t: Optional[Any] = None):
args = () if t is None else (t, )
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)

def decode(self, obj: Any):
return self.decoder.decode(obj)


def custom_enc_hook(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
# NOTE(rob): it is fastest to use numpy + pickle
# when serializing torch tensors.
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))

if isinstance(obj, FunctionType):
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))


def custom_ext_hook(code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_TENSOR:
return torch.from_numpy(pickle.loads(data))
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)

raise NotImplementedError(f"Extension type code {code} is not supported")
self.decoder = msgpack.Decoder(*args,
ext_hook=self.ext_hook,
dec_hook=self.dec_hook)
self.aux_buffers: Sequence[bytestr] = ()

def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
return self.decoder.decode(bufs)

self.aux_buffers = bufs
try:
return self.decoder.decode(bufs[0])
finally:
self.aux_buffers = ()

def dec_hook(self, t: type, obj: Any) -> Any:
if isclass(t):
if issubclass(t, np.ndarray):
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))

raise NotImplementedError(f"Type {t} is not supported")

def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, index = arr
return np.ndarray(buffer=self.aux_buffers[index],
dtype=np.dtype(dtype),
shape=shape)

def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE:
return cloudpickle.loads(data)

raise NotImplementedError(
f"Extension type code {code} is not supported")