Skip to content

Commit a8e56e1

Browse files
committed
fix: zip test cases
Signed-off-by: csy1204 <[email protected]>
1 parent a7d9b7b commit a8e56e1

File tree

1 file changed

+8
-50
lines changed

1 file changed

+8
-50
lines changed

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 8 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
390390
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
391391
def test_guided_number_range_json_completion(llm,
392392
guided_decoding_backend: str):
393-
sample_number_range_schema = {
393+
sample_output_schema = {
394394
"type": "object",
395395
"properties": {
396396
"age": {
@@ -403,23 +403,22 @@ def test_guided_number_range_json_completion(llm,
403403
"minimum": 0.0,
404404
"maximum": 100.0
405405
},
406-
"level": {
407-
"type": "integer",
408-
"minimum": 1,
409-
"maximum": 10
406+
"zipcode": {
407+
"type": "string",
408+
"pattern": r"^\d{5}(-\d{4})?$"
410409
},
411410
},
412-
"required": ["age", "score", "level"],
411+
"required": ["age", "score", "zipcode"],
413412
}
414413
sampling_params = SamplingParams(
415414
temperature=1.0,
416415
max_tokens=1000,
417-
guided_decoding=GuidedDecodingParams(json=sample_number_range_schema,
416+
guided_decoding=GuidedDecodingParams(json=sample_output_schema,
418417
backend=guided_decoding_backend),
419418
)
420419
outputs = llm.generate(
421420
prompts=[
422-
"Create a JSON object for a user with age, score, and level."
421+
"Create a JSON object for a user with age, score, and zipcode."
423422
] * 2,
424423
sampling_params=sampling_params,
425424
use_tqdm=True,
@@ -436,49 +435,8 @@ def test_guided_number_range_json_completion(llm,
436435
assert generated_text is not None
437436
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
438437
output_json = json.loads(generated_text)
439-
jsonschema.validate(instance=output_json,
440-
schema=sample_number_range_schema)
438+
jsonschema.validate(instance=output_json, schema=sample_output_schema)
441439
assert 18 <= output_json["age"] <= 99
442440
assert 0.0 <= output_json["score"] <= 100.0
443-
assert 1 <= output_json["level"] <= 10
444-
445-
446-
@pytest.mark.skip_global_cleanup
447-
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
448-
def test_guided_pattern_json_completion(llm, guided_decoding_backend: str):
449-
sample_pattern_schema = {
450-
"type": "object",
451-
"properties": {
452-
"zipcode": {
453-
"type": "string",
454-
"pattern": r"^\d{5}(-\d{4})?$"
455-
},
456-
},
457-
"required": ["zipcode"],
458-
}
459-
sampling_params = SamplingParams(
460-
temperature=1.0,
461-
max_tokens=1000,
462-
guided_decoding=GuidedDecodingParams(json=sample_pattern_schema,
463-
backend=guided_decoding_backend),
464-
)
465-
outputs = llm.generate(
466-
prompts=["Create a JSON object for a US zipcode (5 or 9 digits)."] * 2,
467-
sampling_params=sampling_params,
468-
use_tqdm=True,
469-
)
470-
471-
assert outputs is not None
472-
473-
for output in outputs:
474-
assert output is not None
475-
assert isinstance(output, RequestOutput)
476-
prompt = output.prompt
477-
478-
generated_text = output.outputs[0].text
479-
assert generated_text is not None
480-
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
481-
output_json = json.loads(generated_text)
482-
jsonschema.validate(instance=output_json, schema=sample_pattern_schema)
483441
assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
484442
is not None)

0 commit comments

Comments (0)