
Commit 4c21ce9

[V1] Get input tokens from scheduler (#13339)

Signed-off-by: Woosuk Kwon <[email protected]>

1 parent ce77eb9 commit 4c21ce9

File tree: 4 files changed, +139 / -139 lines changed


tests/v1/worker/test_gpu_model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
     cached_req_data = CachedRequestData(
         req_id=req_id,
         resumed_from_preemption=False,
+        new_token_ids=[],
         new_block_ids=[],
         num_computed_tokens=0,
     )

vllm/v1/core/scheduler.py

Lines changed: 28 additions & 15 deletions
@@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput":
         encoder_budget = self.max_num_encoder_input_tokens
         # Spec decode-related.
         scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
+
+        # For logging.
         scheduled_timestamp = time.monotonic()

         # First, schedule the RUNNING requests.

@@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput":
             token_budget -= num_new_tokens
             req_index += 1

+            # Speculative decode related.
+            if request.spec_token_ids:
+                num_scheduled_spec_tokens = (num_new_tokens +
+                                             request.num_computed_tokens -
+                                             request.num_tokens)
+                if num_scheduled_spec_tokens > 0:
+                    scheduled_spec_decode_tokens[request.request_id] = (
+                        request.spec_token_ids[:num_scheduled_spec_tokens])
+
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request.request_id] = (
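
Note on the arithmetic in this hunk: request.num_tokens counts the prompt plus
all tokens sampled so far, while draft (speculative) tokens lie beyond that, so
num_new_tokens + request.num_computed_tokens - request.num_tokens is the number
of scheduled tokens that reach into the draft region. A minimal standalone
sketch with made-up numbers (no vLLM objects, just the identity):

# Hypothetical numbers illustrating the spec-token arithmetic above.
num_computed_tokens = 10          # tokens the model has already processed
num_tokens = 12                   # prompt + sampled tokens (no draft tokens)
spec_token_ids = [101, 102, 103]  # draft tokens proposed for this request
num_new_tokens = 4                # tokens scheduled this step (budget-capped)

# 12 - 10 = 2 "real" tokens remain, so 4 - 2 = 2 scheduled tokens are drafts.
num_scheduled_spec_tokens = num_new_tokens + num_computed_tokens - num_tokens
assert num_scheduled_spec_tokens == 2

# Only the draft tokens that actually fit this step are published.
assert spec_token_ids[:num_scheduled_spec_tokens] == [101, 102]
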
@@ -196,11 +207,6 @@ def schedule(self) -> "SchedulerOutput":
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_budget = new_encoder_budget

-            # Speculative decode related.
-            if request.spec_token_ids:
-                scheduled_spec_decode_tokens[
-                    request.request_id] = request.spec_token_ids
-
         # Record the LoRAs in scheduled_running_reqs
         requested_loras: Set[int] = set()
         if self.lora_config:

@@ -324,23 +330,24 @@ def schedule(self) -> "SchedulerOutput":
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
-                                        req_to_new_block_ids[req.request_id],
-                                        req.num_computed_tokens)
+                                        req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=True,
             ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
             self._make_cached_request_data(
                 req,
+                num_scheduled_tokens[req.request_id],
+                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
                 req_to_new_block_ids[req.request_id],
-                req.num_computed_tokens,
                 resumed_from_preemption=False,
             ) for req in scheduled_running_reqs
         ]

@@ -349,8 +356,8 @@ def schedule(self) -> "SchedulerOutput":
             scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
-            scheduled_encoder_inputs=scheduled_encoder_inputs,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.

@@ -366,22 +373,28 @@ def schedule(self) -> "SchedulerOutput":
     def _make_cached_request_data(
         self,
         request: Request,
+        num_scheduled_tokens: int,
+        num_scheduled_spec_tokens: int,
         new_block_ids: List[int],
-        num_computed_tokens: int,
         resumed_from_preemption: bool,
     ) -> "CachedRequestData":
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
-        if request.request_id in self._cached_reqs_data:
-            req_data = self._cached_reqs_data[request.request_id]
+        num_computed_tokens = request.num_computed_tokens
+        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
+        new_token_ids = request.all_token_ids[
+            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_data = self._cached_reqs_data.get(request.request_id)
+        if req_data is not None:
             req_data.resumed_from_preemption = resumed_from_preemption
+            req_data.new_token_ids = new_token_ids
             req_data.new_block_ids = new_block_ids
             req_data.num_computed_tokens = num_computed_tokens
         else:
             req_data = CachedRequestData.from_request(request,
                                                       resumed_from_preemption,
-                                                      new_block_ids,
-                                                      num_computed_tokens)
+                                                      new_token_ids,
+                                                      new_block_ids)
         self._cached_reqs_data[request.request_id] = req_data
         return req_data
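
The core of the rewritten _make_cached_request_data is the new_token_ids
slice: the scheduler now ships the newly scheduled non-draft token IDs to the
worker, cut out of request.all_token_ids (prompt plus previously sampled
tokens). A standalone sketch of that slice, with illustrative values:

# Illustrative values only, mirroring the slice computed above.
all_token_ids = [7, 8, 9, 10, 11, 12]  # prompt + previously sampled tokens
num_computed_tokens = 4        # already fed through the model
num_scheduled_tokens = 3       # total tokens scheduled this step
num_scheduled_spec_tokens = 1  # of which one is a draft token

# Regular (non-draft) tokens start right after the computed prefix.
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = all_token_ids[
    num_computed_tokens:num_computed_tokens + num_regular_tokens]
assert new_token_ids == [11, 12]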

vllm/v1/core/scheduler_output.py

Lines changed: 8 additions & 7 deletions
@@ -30,7 +30,6 @@ def from_request(
         cls,
         request: "Request",
         block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "NewRequestData":
         return cls(
             req_id=request.request_id,

@@ -41,7 +40,7 @@ def from_request(
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
             block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
         )

@@ -54,6 +53,7 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: bool
+    new_token_ids: List[int]
     new_block_ids: List[int]
     num_computed_tokens: int

@@ -62,14 +62,15 @@ def from_request(
         cls,
         request: "Request",
         resumed_from_preemption: bool,
+        new_token_ids: List[int],
         new_block_ids: List[int],
-        num_computed_tokens: int,
     ) -> "CachedRequestData":
         return cls(
             req_id=request.request_id,
             resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
             new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
+            num_computed_tokens=request.num_computed_tokens,
         )
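
With new_token_ids carried on CachedRequestData, a worker can extend its local
copy of a request's tokens purely from scheduler output. The following is a
hypothetical consumer-side sketch, not vLLM's actual GPU model runner; the
block-ID handling follows the resumed_from_preemption comment above (replace
on resume, append otherwise):

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class WorkerRequestState:
    # Hypothetical worker-side state; illustrative names, not vLLM classes.
    token_ids: List[int] = field(default_factory=list)
    block_ids: List[int] = field(default_factory=list)
    num_computed_tokens: int = 0

def apply_cached_request_data(states: Dict[str, WorkerRequestState],
                              req_data) -> None:
    """Fold one CachedRequestData into the worker's local state."""
    state = states[req_data.req_id]
    # Input tokens now arrive from the scheduler itself.
    state.token_ids.extend(req_data.new_token_ids)
    if req_data.resumed_from_preemption:
        # On resume, new_block_ids replaces the request's block IDs.
        state.block_ids = list(req_data.new_block_ids)
    else:
        # Otherwise new_block_ids is appended to the existing block IDs.
        state.block_ids.extend(req_data.new_block_ids)
    state.num_computed_tokens = req_data.num_computed_tokens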

@@ -91,9 +92,9 @@ class SchedulerOutput:
     # Total number of tokens scheduled for all requests.
     # Equal to sum(num_scheduled_tokens.values())
     total_num_scheduled_tokens: int
-    # req_id -> spec_decode_tokens
-    # If a request does not have any spec decode tokens, it will
-    # not be included in the dictionary.
+    # req_id -> spec_token_ids
+    # If a request does not have any spec decode tokens, it will not be
+    # included in the dictionary.
     scheduled_spec_decode_tokens: Dict[str, List[int]]
     # req_id -> encoder input indices that need processing.
     # E.g., if a request has [0, 1], it could mean the vision encoder needs
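
Because requests without draft tokens are omitted from the dictionary
entirely, consumers should read it defensively, mirroring the
.get(req_id, ()) pattern the scheduler itself uses above. A hedged one-liner
(scheduler_output and req_id are assumed to be in scope):

spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])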
