Skip to content

Commit b12f085

Browse files
houseroad and SzymonOzog
authored and committed
[Bugfix] Fix disagg hang caused by the prefill and decode communication issues (vllm-project#12723)
Signed-off-by: Lu Fang <[email protected]> Signed-off-by: SzymonOzog <[email protected]>
1 parent d8aec46 commit b12f085

File tree

1 file changed

+40
-47
lines changed

1 file changed

+40
-47
lines changed

vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py

+40-47
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
stop the prefill instance when the decode instance is slow.
1111
"""
1212
import threading
13-
import time
1413
from collections import deque
1514
from typing import Deque, List, Optional, Union
1615

@@ -29,21 +28,21 @@ class SimpleBuffer(KVLookupBufferBase):
2928
def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase,
3029
buffer_size_thresh: float):
3130
"""
32-
signal_pipe: on CPU
33-
34-
NOTE: on-device recv will block all threads in the process, making the
35-
KV cache producer unable to listen to new request while transmitting
36-
KV cache. Luckily CPU recv only blocks the current thread so we use
31+
signal_pipe: on CPU
32+
33+
NOTE: on-device recv will block all threads in the process, making the
34+
KV cache producer unable to listen to new request while transmitting
35+
KV cache. Luckily CPU recv only blocks the current thread so we use
3736
CPU recv to listen to new request.
38-
37+
3938
data_pipe: on device (e.g. GPU)
4039
"""
4140

4241
self.buffer: Deque[List[torch.Tensor]] = deque()
4342

4443
self.buffer_size = 0
4544
self.buffer_size_threshold = buffer_size_thresh
46-
self.buffer_lock = threading.Lock()
45+
self.buffer_cv = threading.Condition()
4746
self.signal_pipe = signal_pipe
4847
self.data_pipe = data_pipe
4948
self.request_handling_thread: Optional[threading.Thread] = None
@@ -116,11 +115,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
116115
hidden = hidden.clone()
117116

118117
buffer_item = [input_tokens, roi, key, value, hidden]
118+
data_size = sum([self._get_element_size(data) for data in buffer_item])
119+
120+
with self.buffer_cv:
121+
if self.buffer_size + data_size > self.buffer_size_threshold:
122+
# log outside the while loop to avoid this message being logged
123+
# repeatedly.
124+
logger.debug("KV transfer buffer is full. Handling...")
125+
while self.buffer_size + data_size > self.buffer_size_threshold:
126+
self.buffer_cv.wait()
119127

120-
with self.buffer_lock:
121-
for data in buffer_item:
122-
self.buffer_size += self._get_element_size(data)
128+
self.buffer_size += data_size
123129
self.buffer.append(buffer_item)
130+
self.buffer_cv.notify()
124131

125132
def _is_end_signal(self, signal):
126133
return signal is None
@@ -143,35 +150,31 @@ def drop_select_handler(self):
143150
roi = (roi > 0.5)
144151
tokens_roi_recver = [input_tokens, roi]
145152

146-
matched_length = 0
147-
148-
# perform input tokens and roi matching
149-
# FIXME: this matching is O(n), ideally it should be O(1)
150-
# but this buffer size won't (and shouldn't) be too large so
151-
# the fix is not urgent.
152-
with self.buffer_lock:
153-
153+
def is_buffer_available(
154+
tokens_roi_recver: List[torch.Tensor], ) -> bool:
155+
# perform input tokens and roi matching
156+
# FIXME: this matching is O(n), ideally it should be O(1)
157+
# but this buffer size won't (and shouldn't) be too large so
158+
# the fix is not urgent.
154159
for _ in range(len(self.buffer)):
155-
156-
temp_length = self._matches(self.buffer[0],
157-
tokens_roi_recver)
158-
if temp_length > 0:
159-
matched_length = temp_length
160-
break
160+
if self._matches(self.buffer[0],
161+
tokens_roi_recver) > 0:
162+
return True
161163
# rotate the element we just accessed to the end
162164
self.buffer.rotate(-1)
163-
164-
if matched_length > 0:
165-
# need to clone the tensor
166-
# in case the tensor is freed before sending finishes
167-
matched_item = self.buffer.popleft()
168-
for tensor in matched_item:
169-
self._send_tensor_and_dec_size(tensor)
170-
171-
else:
172-
# no match, just send None
173-
for _ in range(5):
174-
self.data_pipe.send_tensor(None)
165+
return False
166+
167+
with self.buffer_cv:
168+
while not is_buffer_available(tokens_roi_recver):
169+
logger.debug(
170+
"KV transfer buffer is not available. Waiting...")
171+
self.buffer_cv.wait()
172+
# need to clone the tensor
173+
# in case the tensor is freed before sending finishes
174+
matched_item = self.buffer.popleft()
175+
for tensor in matched_item:
176+
self._send_tensor_and_dec_size(tensor)
177+
self.buffer_cv.notify()
175178

176179
except RuntimeError as e:
177180
if 'Connection closed by peer' not in str(e):
@@ -208,20 +211,10 @@ def drop_select(
208211

209212
return [input_tokens, roi, key, value, hidden]
210213

211-
def full_handler(self):
212-
time.sleep(0.001)
213-
214214
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
215215
key: torch.Tensor, value: torch.Tensor,
216216
hidden: torch.Tensor) -> None:
217217

218-
if self.buffer_size > self.buffer_size_threshold:
219-
# log outside the while loop to avoid this message being logged
220-
# repeatedly.
221-
logger.debug("KV transfer buffer is full. Handling...")
222-
while self.buffer_size > self.buffer_size_threshold:
223-
self.full_handler()
224-
225218
self._add_to_buffer(input_tokens, roi, key, value, hidden)
226219

227220
# when calling the insert, the current process is a sender

0 commit comments

Comments (0)