Skip to content

Commit b07bf83

Browse files
authored
[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)
Signed-off-by: Nick Hill <[email protected]>
1 parent 53e8cf5 commit b07bf83

File tree

3 files changed

+77
-12
lines changed

3 files changed

+77
-12
lines changed

tests/v1/test_serial_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class MyType:
3232
large_f_contig_tensor: torch.Tensor
3333
small_non_contig_tensor: torch.Tensor
3434
large_non_contig_tensor: torch.Tensor
35+
empty_tensor: torch.Tensor
3536

3637

3738
def test_encode_decode():
@@ -58,6 +59,7 @@ def test_encode_decode():
5859
large_f_contig_tensor=torch.rand(1024, 4).t(),
5960
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
6061
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
62+
empty_tensor=torch.empty(0),
6163
)
6264

6365
encoder = MsgpackEncoder(size_threshold=256)
@@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
193195
obj2.small_non_contig_tensor)
194196
assert torch.equal(obj1.large_non_contig_tensor,
195197
obj2.large_non_contig_tensor)
198+
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)

vllm/v1/engine/core.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import threading
77
import time
8+
from collections import deque
89
from concurrent.futures import Future
910
from inspect import isclass, signature
1011
from logging import DEBUG
@@ -527,8 +528,12 @@ def process_output_socket(self, output_path: str, engine_index: int):
527528

528529
# Msgpack serialization encoding.
529530
encoder = MsgpackEncoder()
530-
# Reuse send buffer.
531-
buffer = bytearray()
531+
# Send buffers to reuse.
532+
reuse_buffers: list[bytearray] = []
533+
# Keep references to outputs and buffers until zmq is finished
534+
# with them (outputs may contain tensors/np arrays whose
535+
# backing buffers were extracted for zero-copy send).
536+
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
532537

533538
# We must set linger to ensure the ENGINE_CORE_DEAD
534539
# message is sent prior to closing the socket.
@@ -541,8 +546,22 @@ def process_output_socket(self, output_path: str, engine_index: int):
541546
break
542547
assert not isinstance(outputs, bytes)
543548
outputs.engine_index = engine_index
549+
550+
# Reclaim buffers that zmq is finished with.
551+
while pending and pending[-1][0].done:
552+
reuse_buffers.append(pending.pop()[2])
553+
554+
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
544555
buffers = encoder.encode_into(outputs, buffer)
545-
socket.send_multipart(buffers, copy=False)
556+
tracker = socket.send_multipart(buffers,
557+
copy=False,
558+
track=True)
559+
if not tracker.done:
560+
ref = outputs if len(buffers) > 1 else None
561+
pending.appendleft((tracker, ref, buffer))
562+
elif len(reuse_buffers) < 2:
563+
# Keep at most 2 buffers to reuse.
564+
reuse_buffers.append(buffer)
546565

547566

548567
class DPEngineCoreProc(EngineCoreProc):

vllm/v1/engine/core_client.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# SPDX-License-Identifier: Apache-2.0
22
import asyncio
3+
import contextlib
34
import queue
45
import uuid
56
import weakref
67
from abc import ABC, abstractmethod
8+
from collections import deque
79
from collections.abc import Awaitable, Sequence
810
from concurrent.futures import Future
911
from dataclasses import dataclass, field
@@ -396,6 +398,12 @@ def __init__(
396398
self._wait_for_engine_startup()
397399

398400
self.utility_results: dict[int, AnyFuture] = {}
401+
402+
# Request objects which may contain pytorch-allocated tensors
403+
# that we need to keep references to until zmq is done with the
404+
# underlying data.
405+
self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
406+
399407
success = True
400408
finally:
401409
if not success:
@@ -459,6 +467,14 @@ def ensure_alive(self):
459467
if self.resources.engine_dead:
460468
raise EngineDeadError()
461469

470+
def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
471+
if not tracker.done:
472+
self.pending_messages.appendleft((tracker, msg))
473+
474+
def free_pending_messages(self):
475+
while self.pending_messages and self.pending_messages[-1][0].done:
476+
self.pending_messages.pop()
477+
462478

463479
def _process_utility_output(output: UtilityOutput,
464480
utility_results: dict[int, AnyFuture]):
@@ -544,10 +560,18 @@ def get_output(self) -> EngineCoreOutputs:
544560

545561
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
546562
self.ensure_alive()
563+
self.free_pending_messages()
547564
# (Identity, RequestType, SerializedRequest)
548565
msg = (self.core_engine.identity, request_type.value,
549566
*self.encoder.encode(request))
550-
self.input_socket.send_multipart(msg, copy=False)
567+
568+
if len(msg) <= 3:
569+
# No auxiliary buffers => no tensor backing buffers in request.
570+
self.input_socket.send_multipart(msg, copy=False)
571+
return
572+
573+
tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
574+
self.add_pending_message(tracker, request)
551575

552576
def call_utility(self, method: str, *args) -> Any:
553577
call_id = uuid.uuid1().int >> 64
@@ -698,19 +722,38 @@ async def get_output_async(self) -> EngineCoreOutputs:
698722
def _send_input(self,
699723
request_type: EngineCoreRequestType,
700724
request: Any,
701-
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
725+
engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
702726
self.ensure_alive()
703727
if engine is None:
704728
engine = self.core_engine
705729

706730
message = (request_type.value, *self.encoder.encode(request))
707-
return self._send_input_message(message, engine)
708-
709-
def _send_input_message(self, message: tuple[bytestr, ...],
710-
engine: CoreEngine) -> Awaitable[None]:
731+
return self._send_input_message(message, engine, request)
732+
733+
def _send_input_message(self, message: tuple[bytestr,
734+
...], engine: CoreEngine,
735+
objects: Any) -> Awaitable[Any]:
736+
"""
737+
objects is a reference to retain until zmq is finished with the
738+
buffers, in case they were extracted from tensors in the request.
739+
"""
711740
self.ensure_alive()
712-
message = (engine.identity, ) + message
713-
return self.input_socket.send_multipart(message, copy=False)
741+
self.free_pending_messages()
742+
743+
msg = (engine.identity, ) + message
744+
if not objects or len(msg) <= 3:
745+
# No auxiliary buffers => no tensor backing buffers in request.
746+
return self.input_socket.send_multipart(msg, copy=False)
747+
748+
future: asyncio.Future[zmq.MessageTracker]
749+
future = self.input_socket.send_multipart(msg, copy=False, track=True)
750+
751+
def add_pending(f: asyncio.Future[zmq.MessageTracker]):
752+
with contextlib.suppress(BaseException):
753+
self.add_pending_message(f.result(), objects)
754+
755+
future.add_done_callback(add_pending)
756+
return future
714757

715758
async def call_utility_async(self, method: str, *args) -> Any:
716759
return await self._call_utility_async(method,
@@ -724,7 +767,7 @@ async def _call_utility_async(self, method: str, *args,
724767
self.utility_results[call_id] = future
725768
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
726769
(call_id, method, args)))
727-
await self._send_input_message(message, engine)
770+
await self._send_input_message(message, engine, args)
728771
self._ensure_output_queue_task()
729772
return await future
730773

0 commit comments

Comments
 (0)