@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm,
                                             guided_decoding_backend: str):
    """Guided JSON decoding honors numeric ranges and string patterns.

    The schema constrains ``age`` (bounded integer), ``score`` (bounded
    number) and ``zipcode`` (regex-patterned string).  Every generated
    completion must parse as JSON, validate against the schema, and satisfy
    each constraint individually.
    """
    sample_output_schema = {
        "type": "object",
        "properties": {
            "age": {
                "type": "integer",
                # NOTE(review): these bounds are reconstructed from the
                # `18 <= age <= 99` assertion below (this span sat between
                # diff hunks) — confirm against the full file.
                "minimum": 18,
                "maximum": 99
            },
            "score": {
                "type": "number",
                "minimum": 0.0,
                "maximum": 100.0
            },
            "zipcode": {
                "type": "string",
                # 5-digit US ZIP code with an optional 4-digit extension.
                "pattern": r"^\d{5}(-\d{4})?$"
            },
        },
        # Property names must match the keys declared above exactly
        # (no stray whitespace), or jsonschema's `required` check would
        # demand a key the model can never emit.
        "required": ["age", "score", "zipcode"],
    }
    sampling_params = SamplingParams(
        # High temperature on purpose: guided decoding must keep the output
        # valid even when sampling is maximally permissive.
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=sample_output_schema,
                                             backend=guided_decoding_backend),
    )
    outputs = llm.generate(
        prompts=[
            "Create a JSON object for a user with age, score, and zipcode."
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_output_schema)
        # Re-check each constraint directly so a failure pinpoints which
        # guided-decoding rule was violated, not just "schema mismatch".
        assert 18 <= output_json["age"] <= 99
        assert 0.0 <= output_json["score"] <= 100.0
        assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
                is not None)
0 commit comments