@@ -388,13 +388,26 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
388
388
389
389
@pytest .mark .skip_global_cleanup
390
390
@pytest .mark .parametrize ("guided_decoding_backend" , GUIDED_DECODING_BACKENDS )
391
- def test_guided_number_range_json_completion (llm , guided_decoding_backend : str ):
391
+ def test_guided_number_range_json_completion (llm ,
392
+ guided_decoding_backend : str ):
392
393
sample_number_range_schema = {
393
394
"type" : "object" ,
394
395
"properties" : {
395
- "age" : {"type" : "integer" , "minimum" : 18 , "maximum" : 99 },
396
- "score" : {"type" : "number" , "minimum" : 0.0 , "maximum" : 100.0 },
397
- "level" : {"type" : "integer" , "minimum" : 1 , "maximum" : 10 },
396
+ "age" : {
397
+ "type" : "integer" ,
398
+ "minimum" : 18 ,
399
+ "maximum" : 99
400
+ },
401
+ "score" : {
402
+ "type" : "number" ,
403
+ "minimum" : 0.0 ,
404
+ "maximum" : 100.0
405
+ },
406
+ "level" : {
407
+ "type" : "integer" ,
408
+ "minimum" : 1 ,
409
+ "maximum" : 10
410
+ },
398
411
},
399
412
"required" : ["age" , "score" , "level" ]
400
413
}
@@ -420,7 +433,8 @@ def test_guided_number_range_json_completion(llm, guided_decoding_backend: str):
420
433
assert generated_text is not None
421
434
print (f"Prompt: { prompt !r} , Generated text: { generated_text !r} " )
422
435
output_json = json .loads (generated_text )
423
- jsonschema .validate (instance = output_json , schema = sample_number_range_schema )
436
+ jsonschema .validate (instance = output_json ,
437
+ schema = sample_number_range_schema )
424
438
assert 18 <= output_json ["age" ] <= 99
425
439
assert 0.0 <= output_json ["score" ] <= 100.0
426
440
assert 1 <= output_json ["level" ] <= 10
@@ -432,7 +446,10 @@ def test_guided_pattern_json_completion(llm, guided_decoding_backend: str):
432
446
sample_pattern_schema = {
433
447
"type" : "object" ,
434
448
"properties" : {
435
- "zipcode" : {"type" : "string" , "pattern" : r"^\\d{5}(-\\d{4})?$" },
449
+ "zipcode" : {
450
+ "type" : "string" ,
451
+ "pattern" : r"^\\d{5}(-\\d{4})?$"
452
+ },
436
453
},
437
454
"required" : ["zipcode" ]
438
455
}
@@ -441,11 +458,10 @@ def test_guided_pattern_json_completion(llm, guided_decoding_backend: str):
441
458
guided_decoding = GuidedDecodingParams (
442
459
json = sample_pattern_schema ,
443
460
backend = guided_decoding_backend ))
444
- outputs = llm .generate (prompts = [
445
- "Create a JSON object for a US zipcode (5 or 9 digits)."
446
- ] * 2 ,
447
- sampling_params = sampling_params ,
448
- use_tqdm = True )
461
+ outputs = llm .generate (
462
+ prompts = ["Create a JSON object for a US zipcode (5 or 9 digits)." ] * 2 ,
463
+ sampling_params = sampling_params ,
464
+ use_tqdm = True )
449
465
450
466
assert outputs is not None
451
467
@@ -459,4 +475,5 @@ def test_guided_pattern_json_completion(llm, guided_decoding_backend: str):
459
475
print (f"Prompt: { prompt !r} , Generated text: { generated_text !r} " )
460
476
output_json = json .loads (generated_text )
461
477
jsonschema .validate (instance = output_json , schema = sample_pattern_schema )
462
- assert re .fullmatch (r"^\d{5}(-\d{4})?$" , output_json ["zipcode" ]) is not None
478
+ assert re .fullmatch (r"^\d{5}(-\d{4})?$" ,
479
+ output_json ["zipcode" ]) is not None
0 commit comments