Skip to content

Commit 764e7a2

Browse files
njhillyangw-dev
authored and committed
[V1] Zero-copy tensor/ndarray serialization/transmission (vllm-project#13790)
Signed-off-by: Nick Hill <[email protected]> Signed-off-by: Yang Wang <[email protected]>
1 parent 8f503df commit 764e7a2

File tree

4 files changed

+217
-58
lines changed

4 files changed

+217
-58
lines changed

tests/v1/test_serial_utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from collections import UserDict
3+
from dataclasses import dataclass
4+
5+
import numpy as np
6+
import torch
7+
8+
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
9+
10+
11+
class UnrecognizedType(UserDict):
    """`UserDict` subclass carrying one extra integer attribute.

    The test uses it as a field value that is neither a tensor, an ndarray
    nor a plain function, and checks that `an_int` survives a round trip.
    """

    def __init__(self, an_int: int):
        # Initialise the (empty) mapping part first.
        UserDict.__init__(self)
        # Payload lives on the instance, not inside the mapping.
        self.an_int = an_int
16+
17+
18+
@dataclass
class MyType:
    """Round-trip target type used by `test_encode_decode`.

    Mixes tensors, a list of tensors, an ndarray, a string and an
    `UnrecognizedType` instance in one object.
    """
    # A single tensor field.
    tensor1: torch.Tensor
    a_string: str
    # Tensors nested inside a list.
    list_of_tensors: list[torch.Tensor]
    numpy_array: np.ndarray
    # Value compared only through its `an_int` attribute in `assert_equal`.
    unrecognized: UnrecognizedType
25+
26+
27+
def test_encode_decode():
    """Test encode/decode loop with zero-copy tensors."""

    obj = MyType(
        tensor1=torch.randint(low=0,
                              high=100,
                              size=(1024, ),
                              dtype=torch.int32),
        a_string="hello",
        list_of_tensors=[
            torch.rand((1, 10), dtype=torch.float32),
            torch.rand((3, 5, 4000), dtype=torch.float64),
            torch.tensor(1984),  # test scalar too
        ],
        numpy_array=np.arange(512),
        unrecognized=UnrecognizedType(33),
    )

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyType)

    encoded = encoder.encode(obj)

    # There should be the main buffer + 2 large tensor buffers
    # + 1 large numpy array. "large" is >= 256 bytes
    # (arrays with nbytes < INLINE_BUF_SIZE_THRESHOLD are inlined).
    # The two small tensors are encoded inline.
    assert len(encoded) == 4

    decoded: MyType = decoder.decode(encoded)

    assert_equal(decoded, obj)

    # Test encode_into case

    preallocated = bytearray()

    encoded2 = encoder.encode_into(obj, preallocated)

    assert len(encoded2) == 4
    # The caller-supplied bytearray must be reused as the main buffer.
    assert encoded2[0] is preallocated

    decoded2: MyType = decoder.decode(encoded2)

    assert_equal(decoded2, obj)
71+
72+
73+
def assert_equal(obj1: "MyType", obj2: "MyType"):
    """Assert that two `MyType` instances hold equal contents.

    Tensors are compared with `torch.equal`, the ndarray with
    `np.array_equal`, and the unrecognized object via its `an_int`.

    Raises:
        AssertionError: if any field differs, including when the two
            `list_of_tensors` have different lengths.
    """
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # `zip` stops at the shorter list, so a truncated list would otherwise
    # compare equal; check the lengths explicitly first.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        torch.equal(a, b)
        for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int

vllm/v1/engine/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,14 +490,14 @@ def process_input_socket(self, input_path: str, engine_index: int):
490490

491491
while True:
492492
# (RequestType, RequestData)
493-
type_frame, data_frame = socket.recv_multipart(copy=False)
493+
type_frame, *data_frames = socket.recv_multipart(copy=False)
494494
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
495495

496496
# Deserialize the request data.
497497
decoder = add_request_decoder if (
498498
request_type
499499
== EngineCoreRequestType.ADD) else generic_decoder
500-
request = decoder.decode(data_frame.buffer)
500+
request = decoder.decode(data_frames)
501501

502502
# Push to input queue for core busy loop.
503503
self.input_queue.put_nowait((request_type, request))
@@ -514,8 +514,8 @@ def process_output_socket(self, output_path: str, engine_index: int):
514514
while True:
515515
outputs = self.output_queue.get()
516516
outputs.engine_index = engine_index
517-
encoder.encode_into(outputs, buffer)
518-
socket.send(buffer, copy=False)
517+
buffers = encoder.encode_into(outputs, buffer)
518+
socket.send_multipart(buffers, copy=False)
519519

520520

521521
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)

vllm/v1/engine/core_client.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
EngineCoreRequestType, UtilityOutput)
2727
from vllm.v1.engine.core import EngineCore, EngineCoreProc
2828
from vllm.v1.executor.abstract import Executor
29-
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
29+
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
3030
from vllm.v1.utils import BackgroundProcHandle
3131

3232
logger = init_logger(__name__)
@@ -505,8 +505,8 @@ def process_outputs_socket():
505505
# shutdown signal, exit thread.
506506
break
507507

508-
frame = out_socket.recv(copy=False)
509-
outputs = decoder.decode(frame.buffer)
508+
frames = out_socket.recv_multipart(copy=False)
509+
outputs = decoder.decode(frames)
510510
if outputs.utility_output:
511511
_process_utility_output(outputs.utility_output,
512512
utility_results)
@@ -529,7 +529,7 @@ def get_output(self) -> EngineCoreOutputs:
529529
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
530530
# (Identity, RequestType, SerializedRequest)
531531
msg = (self.core_engine.identity, request_type.value,
532-
self.encoder.encode(request))
532+
*self.encoder.encode(request))
533533
self.input_socket.send_multipart(msg, copy=False)
534534

535535
def call_utility(self, method: str, *args) -> Any:
@@ -633,8 +633,8 @@ def _ensure_output_queue_task(self):
633633

634634
async def process_outputs_socket():
635635
while True:
636-
(frame, ) = await output_socket.recv_multipart(copy=False)
637-
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
636+
frames = await output_socket.recv_multipart(copy=False)
637+
outputs: EngineCoreOutputs = decoder.decode(frames)
638638
if outputs.utility_output:
639639
_process_utility_output(outputs.utility_output,
640640
utility_results)
@@ -666,12 +666,12 @@ def _send_input(self,
666666
if engine is None:
667667
engine = self.core_engine
668668

669-
message = (request_type.value, self.encoder.encode(request))
669+
message = (request_type.value, *self.encoder.encode(request))
670670
return self._send_input_message(message, engine)
671671

672-
def _send_input_message(self, message: tuple[bytes, bytes],
672+
def _send_input_message(self, message: tuple[bytestr, ...],
673673
engine: CoreEngine) -> Awaitable[None]:
674-
message = (engine.identity, ) + message # type: ignore[assignment]
674+
message = (engine.identity, ) + message
675675
return self.input_socket.send_multipart(message, copy=False)
676676

677677
async def call_utility_async(self, method: str, *args) -> Any:
@@ -684,8 +684,8 @@ async def _call_utility_async(self, method: str, *args,
684684
call_id = uuid.uuid1().int >> 64
685685
future = asyncio.get_running_loop().create_future()
686686
self.utility_results[call_id] = future
687-
message = (EngineCoreRequestType.UTILITY.value,
688-
self.encoder.encode((call_id, method, args)))
687+
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
688+
(call_id, method, args)))
689689
await self._send_input_message(message, engine)
690690
self._ensure_output_queue_task()
691691
return await future
@@ -760,7 +760,7 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
760760

761761
# Control message used for triggering dp idle mode loop.
762762
self.start_dp_msg = (EngineCoreRequestType.START_DP.value,
763-
self.encoder.encode(None))
763+
*self.encoder.encode(None))
764764

765765
self.num_engines_running = 0
766766
self.reqs_in_flight: dict[str, CoreEngine] = {}
@@ -794,7 +794,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
794794
# tokenized.
795795
request.prompt = None
796796

797-
msg = (EngineCoreRequestType.ADD.value, self.encoder.encode(request))
797+
msg = (EngineCoreRequestType.ADD.value, *self.encoder.encode(request))
798798

799799
chosen_engine = self.get_core_engine_for_request()
800800
self.reqs_in_flight[request.request_id] = chosen_engine

vllm/v1/serial_utils.py

Lines changed: 120 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,140 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import pickle
4+
from collections.abc import Sequence
5+
from inspect import isclass
46
from types import FunctionType
5-
from typing import Any, Optional
7+
from typing import Any, Optional, Union
68

79
import cloudpickle
10+
import numpy as np
811
import torch
12+
import zmq
913
from msgspec import msgpack
1014

11-
CUSTOM_TYPE_TENSOR = 1
12-
CUSTOM_TYPE_PICKLE = 2
13-
CUSTOM_TYPE_CLOUDPICKLE = 3
15+
CUSTOM_TYPE_PICKLE = 1
16+
CUSTOM_TYPE_CLOUDPICKLE = 2
1417

18+
# TODO calibrate this size
19+
INLINE_BUF_SIZE_THRESHOLD = 256
1520

16-
class MsgpackEncoder:
17-
"""Encoder with custom torch tensor serialization."""
21+
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
1822

19-
def __init__(self):
20-
self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)
2123

22-
def encode(self, obj: Any) -> bytes:
23-
return self.encoder.encode(obj)
24+
class MsgpackEncoder:
25+
"""Encoder with custom torch tensor and numpy array serialization.
2426
25-
def encode_into(self, obj: Any, buf: bytearray) -> None:
26-
self.encoder.encode_into(obj, buf)
27+
Note that unlike vanilla `msgspec` Encoders, this interface is generally
28+
not thread-safe when encoding tensors / numpy arrays.
29+
"""
30+
31+
def __init__(self):
    """Create the underlying msgspec encoder wired to `enc_hook`."""
    # Per-call stash of buffers that `enc_hook` can reach while an encode
    # is in progress; there is no other way to hand custom data to the
    # `msgspec` hook. It is `None` whenever no encode call is running.
    self.aux_buffers: Optional[list[bytestr]] = None
    self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
37+
38+
def encode(self, obj: Any) -> Sequence[bytestr]:
    """Encode `obj` and return the list of resulting buffers.

    Element 0 is the top-level msgpack buffer. Any further elements are
    the backing buffers of tensors / ndarrays collected by the encode
    hook, returned by reference instead of being copied into element 0.
    """
    # Placeholder at index 0 so that indices handed out by the hook during
    # encoding already account for the main buffer.
    frames: list[bytestr] = [b'']
    self.aux_buffers = frames
    try:
        frames[0] = self.encoder.encode(obj)
    finally:
        # Always clear the stash, even if encoding raised.
        self.aux_buffers = None
    return frames
49+
50+
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
    """Encode `obj` into `buf` and return the list of resulting buffers.

    Element 0 of the result is `buf` itself; any further elements are
    backing buffers collected by the encode hook.
    """
    frames: list[bytestr] = [buf]
    self.aux_buffers = frames
    try:
        self.encoder.encode_into(obj, buf)
    finally:
        # Always clear the stash, even if encoding raised.
        self.aux_buffers = None
    return frames
58+
59+
def enc_hook(self, obj: Any) -> Any:
    """Convert an object msgspec cannot encode natively.

    Tensors and ndarrays become a (dtype, shape, data) tuple via
    `_encode_ndarray`; plain functions are cloudpickled; everything else
    is pickled into a msgpack extension value.
    """
    if isinstance(obj, torch.Tensor):
        as_array = obj.numpy()
        return self._encode_ndarray(as_array)

    # Object ('O') and void ('V') kind ndarrays are not handled here and
    # fall through to pickle below.
    handled_ndarray = (isinstance(obj, np.ndarray)
                       and obj.dtype.kind not in ('O', 'V'))
    if handled_ndarray:
        return self._encode_ndarray(obj)

    if isinstance(obj, FunctionType):
        # `pickle` is generally faster than cloudpickle, but can have
        # problems serializing methods.
        payload = cloudpickle.dumps(obj)
        return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, payload)

    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return msgpack.Ext(CUSTOM_TYPE_PICKLE, payload)
74+
75+
def _encode_ndarray(
    self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
    """Serialize an ndarray as a tuple of native types.

    Returns:
        `(dtype_str, shape, data)` where `data` is either the array's
        own buffer (scalars and arrays smaller than
        `INLINE_BUF_SIZE_THRESHOLD` bytes) or an integer index into
        `self.aux_buffers`, to which the array's buffer was appended.

    Must only be called while an encode call has set `self.aux_buffers`.
    """
    assert self.aux_buffers is not None
    # The decoder rebuilds the array from a flat buffer plus shape, i.e.
    # it assumes C-ordered bytes. Normalise strided / Fortran-ordered
    # input for BOTH branches; previously only large arrays were made
    # contiguous, so small non-contiguous arrays were inlined via a
    # non-contiguous buffer. 0-d arrays are always contiguous, so they
    # never reach `ascontiguousarray` (which would promote them to 1-d).
    if not obj.flags.c_contiguous:
        obj = np.ascontiguousarray(obj)
    if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
        # Encode small arrays and scalars inline.
        data = obj.data
    else:
        # Otherwise encode index of backing buffer.
        data = len(self.aux_buffers)
        self.aux_buffers.append(obj.data)
    # The data is either inlined if small, or an index into the list of
    # backing buffers stashed in `aux_buffers`.
    return obj.dtype.str, obj.shape, data
2791

2892

2993
class MsgpackDecoder:
30-
"""Decoder with custom torch tensor serialization."""
94+
"""Decoder with custom torch tensor and numpy array serialization.
95+
96+
Note that unlike vanilla `msgspec` Decoders, this interface is generally
97+
not thread-safe when decoding tensors / numpy arrays.
98+
"""
3199

32100
def __init__(self, t: Optional[Any] = None):
33101
args = () if t is None else (t, )
34-
self.decoder = msgpack.Decoder(*args, ext_hook=custom_ext_hook)
35-
36-
def decode(self, obj: Any):
37-
return self.decoder.decode(obj)
38-
39-
40-
def custom_enc_hook(obj: Any) -> Any:
41-
if isinstance(obj, torch.Tensor):
42-
# NOTE(rob): it is fastest to use numpy + pickle
43-
# when serializing torch tensors.
44-
# https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
45-
return msgpack.Ext(CUSTOM_TYPE_TENSOR, pickle.dumps(obj.numpy()))
46-
47-
if isinstance(obj, FunctionType):
48-
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
49-
50-
return msgpack.Ext(CUSTOM_TYPE_PICKLE, pickle.dumps(obj))
51-
52-
53-
def custom_ext_hook(code: int, data: memoryview) -> Any:
54-
if code == CUSTOM_TYPE_TENSOR:
55-
return torch.from_numpy(pickle.loads(data))
56-
if code == CUSTOM_TYPE_PICKLE:
57-
return pickle.loads(data)
58-
if code == CUSTOM_TYPE_CLOUDPICKLE:
59-
return cloudpickle.loads(data)
60-
61-
raise NotImplementedError(f"Extension type code {code} is not supported")
102+
self.decoder = msgpack.Decoder(*args,
103+
ext_hook=self.ext_hook,
104+
dec_hook=self.dec_hook)
105+
self.aux_buffers: Sequence[bytestr] = ()
106+
107+
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
    """Decode either a single buffer or a sequence of buffers.

    For a sequence, element 0 is the main msgpack buffer and the whole
    sequence is exposed to the decode hooks through `self.aux_buffers`
    for the duration of the call.
    """
    # TODO - This check can become `isinstance(bufs, bytestr)`
    # as of Python 3.10.
    single_buffer_types = (bytes, bytearray, memoryview, zmq.Frame)
    if not isinstance(bufs, single_buffer_types):
        self.aux_buffers = bufs
        try:
            result = self.decoder.decode(bufs[0])
        finally:
            # Drop the references again, even if decoding raised.
            self.aux_buffers = ()
        return result

    return self.decoder.decode(bufs)
118+
119+
def dec_hook(self, t: type, obj: Any) -> Any:
    """Convert the native value `obj` to the requested type `t`.

    ndarray and tensor classes (and their subclasses) are rebuilt via
    `_decode_ndarray`; any other `t` gets `obj` back unchanged.
    """
    # `issubclass` needs a real class, so anything else passes through.
    if not isclass(t):
        return obj
    if issubclass(t, np.ndarray):
        return self._decode_ndarray(obj)
    if issubclass(t, torch.Tensor):
        array = self._decode_ndarray(obj)
        return torch.from_numpy(array)
    return obj
127+
128+
def _decode_ndarray(self, arr: Any) -> np.ndarray:
129+
dtype, shape, data = arr
130+
buffer = self.aux_buffers[data] if isinstance(data, int) else data
131+
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
132+
133+
def ext_hook(self, code: int, data: memoryview) -> Any:
    """Decode a msgpack extension value by its type code.

    Raises:
        NotImplementedError: if `code` is not one of the known codes.
    """
    # NOTE(review): both loaders execute arbitrary code from `data`; this
    # presumably relies on the peer being trusted - confirm.
    if code == CUSTOM_TYPE_PICKLE:
        loader = pickle.loads
    elif code == CUSTOM_TYPE_CLOUDPICKLE:
        loader = cloudpickle.loads
    else:
        raise NotImplementedError(
            f"Extension type code {code} is not supported")
    return loader(data)

0 commit comments

Comments
 (0)