from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

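+# EOS id assumed by these tests: GPT-2's "<|endoftext|>" token. The exact
+# value is arbitrary; it just has to match what the fake model "samples".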
+EOS_TOKEN_ID = 50256
+

def create_scheduler(
    model: str = "facebook/opt-125m",
@@ -38,6 +40,7 @@ def create_scheduler(
    return Scheduler(scheduler_config,
                     model_config,
                     cache_config,
+                     speculative_config=None,
                     lora_config=None,
                     log_stats=True)

@@ -46,8 +49,12 @@ def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[List[PlaceholderRange]] = None,
+    max_tokens: int = 16,
+    stop_token_ids: Optional[List[int]] = None,
):
-    sampling_params = SamplingParams()
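+    # Honor EOS and any custom stop tokens so the scheduler's stop checks
+    # (EOS, stop ids, max_tokens) can actually trigger in these tests.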
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
@@ -64,7 +71,7 @@ def create_requests(
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
-            eos_token_id=None,
+            eos_token_id=EOS_TOKEN_ID,
            arrival_time=0,
        )
        requests.append(request)
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
-        sampled_token_ids=[0] * len(requests),
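+        # One list of sampled token ids per request: a single step can now
+        # yield several tokens (e.g. with speculative decoding).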
+        sampled_token_ids=[[0] for _ in range(len(requests))],
        logprobs=None,
        prompt_logprobs_dict={},
    )
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
    assert requests[2].request_id not in output.num_scheduled_tokens


+def test_stop_via_update_from_output():
+    """Test stopping behavior through update_from_output"""
+    scheduler = create_scheduler()
+
+    # Test case 1: Stop on EOS token
+    requests = create_requests(num_requests=2, max_tokens=10)
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
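+    # Request 0 is a plain one-token decode; request 1 also carries one
+    # speculative token ([10]), so it is charged two scheduled tokens.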
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 1,
+                                           requests[1].request_id: 2
+                                       },
+                                       total_num_scheduled_tokens=3,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [],
+                                           requests[1].request_id: [10]
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[EOS_TOKEN_ID],
+                           [10, 11]],  # First request hits EOS, second continues
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped, second continues
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_STOPPED
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
+    assert list(requests[1].output_token_ids) == [10, 11]
+
+    # Test case 2: Stop on custom stop token
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=2,
+                               max_tokens=10,
+                               stop_token_ids=[42, 43])
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
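+    # Request 0's speculative tokens [10, 42] include the stop token 42,
+    # so its output should be truncated at 42.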
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 3,
+                                           requests[1].request_id: 2
+                                       },
+                                       total_num_scheduled_tokens=5,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [10, 42],
+                                           requests[1].request_id: [13]
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[10, 42, 12],
+                           [13, 14]],  # First request hits stop token
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped on custom token
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_STOPPED
+    assert requests[0].stop_reason == 42
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [10, 42]
+    assert list(requests[1].output_token_ids) == [13, 14]
+
+    # Test case 3: Stop on max tokens
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=2, max_tokens=2)
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
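+    # With max_tokens=2, the three tokens scheduled for request 0 (one
+    # decode plus two speculative) are guaranteed to hit the length cap.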
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 3,
+                                           requests[1].request_id: 1
+                                       },
+                                       total_num_scheduled_tokens=4,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [10, 11],
+                                           requests[1].request_id: []
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        sampled_token_ids=[[10, 11, 12],
+                           [13]],  # First request exceeds max_tokens
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped due to length
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
+    assert list(requests[1].output_token_ids) == [13]
+
+    # Test case 4: Ignore EOS flag
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=1, max_tokens=10)
+    requests[0].sampling_params.ignore_eos = True
+    requests[0].num_computed_tokens = requests[0].num_tokens
+    scheduler.requests[requests[0].request_id] = requests[0]
+    scheduler.running.append(requests[0])
+    scheduler.scheduled_req_ids.add(requests[0].request_id)
+
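+    # EOS appears among both the speculative and the sampled tokens, but
+    # ignore_eos=True means the request must keep running.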
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={requests[0].request_id: 3},
+        total_num_scheduled_tokens=3,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [EOS_TOKEN_ID, 10]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify request continues past EOS
+    assert len(scheduler.running) == 1
+    assert not requests[0].is_finished()
+    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
+
+
def test_schedule_concurrent_batches():
    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[0],
+        sampled_token_ids=[[0]],
        logprobs=None,
        prompt_logprobs_dict={},
    )
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[0],
+        sampled_token_ids=[[0]],
        logprobs=None,
        prompt_logprobs_dict={},
    )