@@ -121,6 +121,8 @@ def schedule(self) -> "SchedulerOutput":
121
121
encoder_budget = self .max_num_encoder_input_tokens
122
122
# Spec decode-related.
123
123
scheduled_spec_decode_tokens : Dict [str , List [int ]] = {}
124
+
125
+ # Timestamp of this scheduling step, used for logging.
124
126
scheduled_timestamp = time .monotonic ()
125
127
126
128
# First, schedule the RUNNING requests.
@@ -187,6 +189,15 @@ def schedule(self) -> "SchedulerOutput":
187
189
token_budget -= num_new_tokens
188
190
req_index += 1
189
191
192
+ # Speculative decoding: record only the spec tokens scheduled in this step.
193
+ if request .spec_token_ids :
194
+ num_scheduled_spec_tokens = (num_new_tokens +
195
+ request .num_computed_tokens -
196
+ request .num_tokens )
197
+ if num_scheduled_spec_tokens > 0 :
198
+ scheduled_spec_decode_tokens [request .request_id ] = (
199
+ request .spec_token_ids [:num_scheduled_spec_tokens ])
200
+
190
201
# Encoder-related.
191
202
if encoder_inputs_to_schedule :
192
203
scheduled_encoder_inputs [request .request_id ] = (
@@ -196,11 +207,6 @@ def schedule(self) -> "SchedulerOutput":
196
207
self .encoder_cache_manager .allocate (request , i )
197
208
encoder_budget = new_encoder_budget
198
209
199
- # Speculative decode related.
200
- if request .spec_token_ids :
201
- scheduled_spec_decode_tokens [
202
- request .request_id ] = request .spec_token_ids
203
-
204
210
# Record the LoRAs in scheduled_running_reqs
205
211
requested_loras : Set [int ] = set ()
206
212
if self .lora_config :
@@ -324,23 +330,24 @@ def schedule(self) -> "SchedulerOutput":
324
330
# Construct the scheduler output.
325
331
new_reqs_data = [
326
332
NewRequestData .from_request (req ,
327
- req_to_new_block_ids [req .request_id ],
328
- req .num_computed_tokens )
333
+ req_to_new_block_ids [req .request_id ])
329
334
for req in scheduled_new_reqs
330
335
]
331
336
resumed_reqs_data = [
332
337
self ._make_cached_request_data (
333
338
req ,
339
+ num_scheduled_tokens [req .request_id ],
340
+ len (scheduled_spec_decode_tokens .get (req .request_id , ())),
334
341
req_to_new_block_ids [req .request_id ],
335
- req .num_computed_tokens ,
336
342
resumed_from_preemption = True ,
337
343
) for req in scheduled_resumed_reqs
338
344
]
339
345
running_reqs_data = [
340
346
self ._make_cached_request_data (
341
347
req ,
348
+ num_scheduled_tokens [req .request_id ],
349
+ len (scheduled_spec_decode_tokens .get (req .request_id , ())),
342
350
req_to_new_block_ids [req .request_id ],
343
- req .num_computed_tokens ,
344
351
resumed_from_preemption = False ,
345
352
) for req in scheduled_running_reqs
346
353
]
@@ -349,8 +356,8 @@ def schedule(self) -> "SchedulerOutput":
349
356
scheduled_cached_reqs = resumed_reqs_data + running_reqs_data ,
350
357
num_scheduled_tokens = num_scheduled_tokens ,
351
358
total_num_scheduled_tokens = total_num_scheduled_tokens ,
352
- scheduled_encoder_inputs = scheduled_encoder_inputs ,
353
359
scheduled_spec_decode_tokens = scheduled_spec_decode_tokens ,
360
+ scheduled_encoder_inputs = scheduled_encoder_inputs ,
354
361
num_common_prefix_blocks = num_common_prefix_blocks ,
355
362
# finished_req_ids is an existing state in the scheduler,
356
363
# instead of being newly scheduled in this step.
@@ -366,22 +373,28 @@ def schedule(self) -> "SchedulerOutput":
366
373
def _make_cached_request_data(
    self,
    request: "Request",
    num_scheduled_tokens: int,
    num_scheduled_spec_tokens: int,
    new_block_ids: List[int],
    resumed_from_preemption: bool,
) -> "CachedRequestData":
    """Build or refresh the ``CachedRequestData`` entry for ``request``.

    Args:
        request: The request being scheduled. Its ``request_id``,
            ``num_computed_tokens`` and ``all_token_ids`` are read.
        num_scheduled_tokens: Total number of tokens scheduled for this
            request in the current step.
        num_scheduled_spec_tokens: How many of those scheduled tokens are
            speculative; they are excluded from ``new_token_ids``.
        new_block_ids: Block ids to store on the returned object.
        resumed_from_preemption: Flag to store on the returned object.

    Returns:
        The ``CachedRequestData`` kept in ``self._cached_reqs_data`` for this
        request id, updated in place when it already exists and created
        (then cached) otherwise.
    """
    num_computed_tokens = request.num_computed_tokens
    # Only the non-speculative part of the scheduled tokens is sliced out of
    # the request's own token list.
    num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
    new_token_ids = request.all_token_ids[
        num_computed_tokens:num_computed_tokens + num_regular_tokens]

    # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
    # them at each scheduling step.
    req_data = self._cached_reqs_data.get(request.request_id)
    if req_data is not None:
        req_data.resumed_from_preemption = resumed_from_preemption
        req_data.new_token_ids = new_token_ids
        req_data.new_block_ids = new_block_ids
        req_data.num_computed_tokens = num_computed_tokens
    else:
        req_data = CachedRequestData.from_request(request,
                                                  resumed_from_preemption,
                                                  new_token_ids,
                                                  new_block_ids)
        self._cached_reqs_data[request.request_id] = req_data
    return req_data
387
400
0 commit comments