
Commit 32171db

p88h authored and njhill committed
Serialize tensors using int8 views (vllm-project#16866)
Signed-off-by: Staszek Pasko <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
1 parent 2d2d629 commit 32171db

2 files changed, +48 -13 lines changed

tests/v1/test_serial_utils.py

Lines changed: 10 additions & 6 deletions
@@ -47,6 +47,10 @@ def test_encode_decode():
             torch.rand((1, 10), dtype=torch.float32),
             torch.rand((3, 5, 4000), dtype=torch.float64),
             torch.tensor(1984), # test scalar too
+            # Make sure to test bf16 which numpy doesn't support.
+            torch.rand((3, 5, 1000), dtype=torch.bfloat16),
+            torch.tensor([float("-inf"), float("inf")] * 1024,
+                         dtype=torch.bfloat16),
         ],
         numpy_array=np.arange(512),
         unrecognized=UnrecognizedType(33),

@@ -64,7 +68,7 @@ def test_encode_decode():
     # There should be the main buffer + 4 large tensor buffers
     # + 1 large numpy array. "large" is <= 512 bytes.
     # The two small tensors are encoded inline.
-    assert len(encoded) == 6
+    assert len(encoded) == 8

     decoded: MyType = decoder.decode(encoded)

@@ -76,7 +80,7 @@ def test_encode_decode():

     encoded2 = encoder.encode_into(obj, preallocated)

-    assert len(encoded2) == 6
+    assert len(encoded2) == 8
     assert encoded2[0] is preallocated

     decoded2: MyType = decoder.decode(encoded2)

@@ -114,15 +118,15 @@ def test_multimodal_kwargs():

     total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

-    # expected total encoding length, should be 44536, +-20 for minor changes
-    assert total_len >= 44516 and total_len <= 44556
+    # expected total encoding length, should be 44559, +-20 for minor changes
+    assert total_len >= 44539 and total_len <= 44579
     decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
     assert all(nested_equal(d[k], decoded[k]) for k in d)


 def test_multimodal_items_by_modality():
-    e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
-                                                        dtype=torch.int16),
+    e1 = MultiModalFieldElem("audio", "a0",
+                             torch.zeros(1000, dtype=torch.bfloat16),
                              MultiModalBatchedField())
     e2 = MultiModalFieldElem(
         "video",

vllm/v1/serial_utils.py

Lines changed: 38 additions & 7 deletions
@@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:

     def enc_hook(self, obj: Any) -> Any:
         if isinstance(obj, torch.Tensor):
-            return self._encode_ndarray(obj.numpy())
+            return self._encode_tensor(obj)

         # Fall back to pickle for object or void kind ndarrays.
         if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):

@@ -133,9 +133,27 @@ def _encode_ndarray(
         # backing buffers that we've stashed in `aux_buffers`.
         return obj.dtype.str, obj.shape, data

+    def _encode_tensor(
+            self, obj: torch.Tensor
+    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
+        assert self.aux_buffers is not None
+        # this creates a copy of the tensor if it's not already contiguous
+        obj = obj.contiguous()
+        # view the tensor as a 1D array of bytes
+        arr = obj.view((obj.numel(), )).view(torch.uint8).numpy()
+        if obj.nbytes < self.size_threshold:
+            # Smaller tensors are encoded inline, just like ndarrays.
+            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
+        else:
+            # Otherwise encode index of backing buffer to avoid copy.
+            data = len(self.aux_buffers)
+            self.aux_buffers.append(arr.data)
+        dtype = str(obj.dtype)[6:] # remove 'torch.' prefix
+        return dtype, obj.shape, data
+
     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
         if isinstance(nt, torch.Tensor):
-            return self._encode_ndarray(nt.numpy())
+            return self._encode_tensor(nt)
         if isinstance(nt, (int, float)):
             # Although it violates NestedTensors type, MultiModalKwargs
             # values are sometimes floats.

@@ -186,7 +204,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
         if issubclass(t, np.ndarray):
             return self._decode_ndarray(obj)
         if issubclass(t, torch.Tensor):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
         if issubclass(t, MultiModalKwargs):
             if isinstance(obj, list):
                 return MultiModalKwargs.from_items(

@@ -199,11 +217,24 @@ def dec_hook(self, t: type, obj: Any) -> Any:

     def _decode_ndarray(self, arr: Any) -> np.ndarray:
         dtype, shape, data = arr
-        # Copy from inline representation, otherwise Torch is unhappy since
-        # the returned memory is non-writeable.
+        # zero-copy decode. We assume the ndarray will not be kept around,
+        # as it now locks the whole received message buffer in memory.
+        buffer = self.aux_buffers[data] if isinstance(data, int) else data
+        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+
+    def _decode_tensor(self, arr: Any) -> torch.Tensor:
+        dtype, shape, data = arr
+        # Copy from inline representation, to decouple the memory storage
+        # of the message from the original buffer. And also make Torch
+        # not complain about a readonly memoryview.
         buffer = self.aux_buffers[data] if isinstance(data, int) \
             else bytearray(data)
-        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
+        # Create numpy wrapper around the bytes
+        arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
+        torch_dtype = getattr(torch, dtype)
+        assert isinstance(torch_dtype, torch.dtype)
+        # Convert back to proper shape & type
+        return torch.from_numpy(arr).view(torch_dtype).view(shape)

     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
         decoded_items = []

@@ -228,7 +259,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
         if not isinstance(obj, list):
             raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
         if obj and isinstance(obj[0], str):
-            return torch.from_numpy(self._decode_ndarray(obj))
+            return self._decode_tensor(obj)
         return [self._decode_nested_tensors(x) for x in obj]

     def ext_hook(self, code: int, data: memoryview) -> Any:
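
Taken together, _encode_tensor and _decode_tensor form a byte-view round trip: flatten, reinterpret as uint8, ship the raw bytes, then view them back as the original dtype and shape. A minimal standalone sketch of the same idea (illustrative helper names, not the vLLM classes; the inline-vs-aux-buffer bookkeeping is omitted):

import numpy as np
import torch


def encode_tensor(t: torch.Tensor) -> tuple[str, tuple[int, ...], bytes]:
    # Make the tensor contiguous, flatten it, and reinterpret the storage
    # as uint8 so dtypes numpy can't represent (e.g. bfloat16) still work.
    t = t.contiguous()
    raw = t.view((t.numel(), )).view(torch.uint8).numpy().tobytes()
    return str(t.dtype)[6:], tuple(t.shape), raw  # strip the 'torch.' prefix


def decode_tensor(dtype: str, shape: tuple[int, ...], raw: bytes) -> torch.Tensor:
    # Wrap the bytes in a writable uint8 ndarray, then view it back to the
    # original torch dtype and shape.
    buf = bytearray(raw)
    arr = np.ndarray(buffer=buf, dtype=np.uint8, shape=(len(buf), ))
    return torch.from_numpy(arr).view(getattr(torch, dtype)).view(shape)


t = torch.rand((3, 5, 7), dtype=torch.bfloat16)
assert torch.equal(t, decode_tensor(*encode_tensor(t)))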

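One detail worth noting in _decode_tensor above: the bytearray(data) copy for inline payloads keeps the backing buffer writable, since numpy inherits read-only-ness from the buffer it wraps and torch.from_numpy warns about non-writable arrays. A small illustration of that behavior (not code from this commit):

import numpy as np

raw = memoryview(bytes(8))  # stand-in for a slice of the received message
view = np.ndarray(buffer=raw, dtype=np.uint8, shape=(8, ))
print(view.flags.writeable)  # False -- inherited from the read-only buffer

copy = np.ndarray(buffer=bytearray(raw), dtype=np.uint8, shape=(8, ))
print(copy.flags.writeable)  # True -- safe to hand to torch.from_numpy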