10
10
stop the prefill instance when the decode instance is slow.
11
11
"""
12
12
import threading
13
- import time
14
13
from collections import deque
15
14
from typing import Deque , List , Optional , Union
16
15
@@ -29,21 +28,21 @@ class SimpleBuffer(KVLookupBufferBase):
29
28
def __init__ (self , signal_pipe : KVPipeBase , data_pipe : KVPipeBase ,
30
29
buffer_size_thresh : float ):
31
30
"""
32
- signal_pipe: on CPU
33
-
34
- NOTE: on-device recv will block all threads in the process, making the
35
- KV cache producer unable to listen to new request while transmitting
36
- KV cache. Luckily CPU recv only blocks the current thread so we use
31
+ signal_pipe: on CPU
32
+
33
+ NOTE: on-device recv will block all threads in the process, making the
34
+ KV cache producer unable to listen to new request while transmitting
35
+ KV cache. Luckily CPU recv only blocks the current thread so we use
37
36
CPU recv to listen to new request.
38
-
37
+
39
38
data_pipe: on device (e.g. GPU)
40
39
"""
41
40
42
41
self .buffer : Deque [List [torch .Tensor ]] = deque ()
43
42
44
43
self .buffer_size = 0
45
44
self .buffer_size_threshold = buffer_size_thresh
46
- self .buffer_lock = threading .Lock ()
45
+ self .buffer_cv = threading .Condition ()
47
46
self .signal_pipe = signal_pipe
48
47
self .data_pipe = data_pipe
49
48
self .request_handling_thread : Optional [threading .Thread ] = None
@@ -116,11 +115,19 @@ def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
116
115
hidden = hidden .clone ()
117
116
118
117
buffer_item = [input_tokens , roi , key , value , hidden ]
118
+ data_size = sum ([self ._get_element_size (data ) for data in buffer_item ])
119
+
120
+ with self .buffer_cv :
121
+ if self .buffer_size + data_size > self .buffer_size_threshold :
122
+ # log outside the while loop to avoid this message being logged
123
+ # repeatedly.
124
+ logger .debug ("KV transfer buffer is full. Handling..." )
125
+ while self .buffer_size + data_size > self .buffer_size_threshold :
126
+ self .buffer_cv .wait ()
119
127
120
- with self .buffer_lock :
121
- for data in buffer_item :
122
- self .buffer_size += self ._get_element_size (data )
128
+ self .buffer_size += data_size
123
129
self .buffer .append (buffer_item )
130
+ self .buffer_cv .notify ()
124
131
125
132
def _is_end_signal (self , signal ):
126
133
return signal is None
@@ -143,35 +150,31 @@ def drop_select_handler(self):
143
150
roi = (roi > 0.5 )
144
151
tokens_roi_recver = [input_tokens , roi ]
145
152
146
- matched_length = 0
147
-
148
- # perform input tokens and roi matching
149
- # FIXME: this matching is O(n), ideally it should be O(1)
150
- # but this buffer size won't (and shouldn't) be too large so
151
- # the fix is not urgent.
152
- with self .buffer_lock :
153
-
153
+ def is_buffer_available (
154
+ tokens_roi_recver : List [torch .Tensor ], ) -> bool :
155
+ # perform input tokens and roi matching
156
+ # FIXME: this matching is O(n), ideally it should be O(1)
157
+ # but this buffer size won't (and shouldn't) be too large so
158
+ # the fix is not urgent.
154
159
for _ in range (len (self .buffer )):
155
-
156
- temp_length = self ._matches (self .buffer [0 ],
157
- tokens_roi_recver )
158
- if temp_length > 0 :
159
- matched_length = temp_length
160
- break
160
+ if self ._matches (self .buffer [0 ],
161
+ tokens_roi_recver ) > 0 :
162
+ return True
161
163
# rotate the element we just accessed to the end
162
164
self .buffer .rotate (- 1 )
163
-
164
- if matched_length > 0 :
165
- # need to clone the tensor
166
- # in case the tensor is freed before sending finishes
167
- matched_item = self .buffer .popleft ()
168
- for tensor in matched_item :
169
- self ._send_tensor_and_dec_size (tensor )
170
-
171
- else :
172
- # no match, just send None
173
- for _ in range (5 ):
174
- self .data_pipe .send_tensor (None )
165
+ return False
166
+
167
+ with self .buffer_cv :
168
+ while not is_buffer_available (tokens_roi_recver ):
169
+ logger .debug (
170
+ "KV transfer buffer is not available. Waiting..." )
171
+ self .buffer_cv .wait ()
172
+ # need to clone the tensor
173
+ # in case the tensor is freed before sending finishes
174
+ matched_item = self .buffer .popleft ()
175
+ for tensor in matched_item :
176
+ self ._send_tensor_and_dec_size (tensor )
177
+ self .buffer_cv .notify ()
175
178
176
179
except RuntimeError as e :
177
180
if 'Connection closed by peer' not in str (e ):
@@ -208,20 +211,10 @@ def drop_select(
208
211
209
212
return [input_tokens , roi , key , value , hidden ]
210
213
211
- def full_handler (self ):
212
- time .sleep (0.001 )
213
-
214
214
def insert (self , input_tokens : torch .Tensor , roi : torch .Tensor ,
215
215
key : torch .Tensor , value : torch .Tensor ,
216
216
hidden : torch .Tensor ) -> None :
217
217
218
- if self .buffer_size > self .buffer_size_threshold :
219
- # log outside the while loop to avoid this message being logged
220
- # repeatedly.
221
- logger .debug ("KV transfer buffer is full. Handling..." )
222
- while self .buffer_size > self .buffer_size_threshold :
223
- self .full_handler ()
224
-
225
218
self ._add_to_buffer (input_tokens , roi , key , value , hidden )
226
219
227
220
# when calling the insert, the current process is a sender
0 commit comments