Skip to content

Commit 01882b7

Browse files
csy1204조상연[플레이스 AI]
authored andcommitted
[Bugfix] remove fallback in guided_json (int range, patterns) (vllm-project#16725)
Signed-off-by: csy1204 <[email protected]> Co-authored-by: 조상연[플레이스 AI] <[email protected]>
1 parent c03ec7f commit 01882b7

File tree

6 files changed

+94
-72
lines changed

6 files changed

+94
-72
lines changed

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
305305
with pytest.raises(
306306
ValueError,
307307
match="xgrammar does not support advanced JSON schema features "
308-
"like enums, patterns or numeric ranges."):
308+
"like string length, item limits, or property bounds."):
309309
llm.generate(prompts="This should fail",
310310
sampling_params=sampling_params,
311311
use_tqdm=True)
@@ -386,6 +386,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
386386
jsonschema.validate(instance=output_json, schema=json_schema)
387387

388388

389+
@pytest.mark.skip_global_cleanup
390+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
391+
def test_guided_number_range_json_completion(llm,
392+
guided_decoding_backend: str):
393+
sample_output_schema = {
394+
"type": "object",
395+
"properties": {
396+
"age": {
397+
"type": "integer",
398+
"minimum": 18,
399+
"maximum": 99
400+
},
401+
"score": {
402+
"type": "number",
403+
"minimum": 0.0,
404+
"maximum": 100.0
405+
},
406+
"zipcode": {
407+
"type": "string",
408+
"pattern": r"^\d{5}(-\d{4})?$"
409+
},
410+
},
411+
"required": ["age", "score", "zipcode"],
412+
}
413+
sampling_params = SamplingParams(
414+
temperature=1.0,
415+
max_tokens=1000,
416+
guided_decoding=GuidedDecodingParams(json=sample_output_schema,
417+
backend=guided_decoding_backend),
418+
)
419+
outputs = llm.generate(
420+
prompts=[
421+
"Create a JSON object for a user with age, score, and zipcode."
422+
] * 2,
423+
sampling_params=sampling_params,
424+
use_tqdm=True,
425+
)
426+
427+
assert outputs is not None
428+
429+
for output in outputs:
430+
assert output is not None
431+
assert isinstance(output, RequestOutput)
432+
prompt = output.prompt
433+
434+
generated_text = output.outputs[0].text
435+
assert generated_text is not None
436+
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
437+
output_json = json.loads(generated_text)
438+
jsonschema.validate(instance=output_json, schema=sample_output_schema)
439+
assert 18 <= output_json["age"] <= 99
440+
assert 0.0 <= output_json["score"] <= 100.0
441+
assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
442+
is not None)
443+
444+
389445
@pytest.mark.skip_global_cleanup
390446
def test_guidance_no_additional_properties(llm):
391447
schema = {

tests/v1/entrypoints/conftest.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ def sample_json_schema():
4747
"type": "string",
4848
}
4949
},
50+
"grade": {
51+
"type": "string",
52+
"pattern": "^[A-D]$" # Regex pattern
53+
},
54+
"email": {
55+
"type": "string",
56+
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
57+
},
5058
"work_history": {
5159
"type": "array",
5260
"items": {
@@ -56,17 +64,20 @@ def sample_json_schema():
5664
"type": "string"
5765
},
5866
"duration": {
59-
"type": "number"
67+
"type": "number",
68+
"minimum": 0.0,
69+
"maximum": 100.0, # Numeric range
6070
},
6171
"position": {
6272
"type": "string"
6373
}
6474
},
65-
"required": ["company", "position"]
75+
"required": ["company", "duration", "position"]
6676
}
6777
}
6878
},
69-
"required": ["name", "age", "skills", "work_history"]
79+
"required":
80+
["name", "age", "skills", "grade", "email", "work_history"]
7081
}
7182

7283

@@ -78,27 +89,18 @@ def unsupported_json_schema():
7889
"properties": {
7990
"score": {
8091
"type": "integer",
81-
"minimum": 0,
82-
"maximum": 100 # Numeric range
83-
},
84-
"grade": {
85-
"type": "string",
86-
"pattern": "^[A-D]$" # Regex pattern
87-
},
88-
"email": {
89-
"type": "string",
90-
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
92+
"multipleOf": 5 # Numeric multiple
9193
},
9294
"tags": {
9395
"type": "array",
9496
"items": {
9597
"type": "string",
96-
"pattern":
97-
"^[a-z]{1,10}$" # Combining length and pattern restrictions
98+
"minLength": 10,
99+
"maxLength": 20
98100
}
99101
}
100102
},
101-
"required": ["score", "grade", "email", "tags"]
103+
"required": ["score", "tags"]
102104
}
103105

104106

tests/v1/structured_output/test_utils.py

Lines changed: 16 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
@pytest.fixture
1010
def unsupported_string_schemas():
1111
return [
12-
{
13-
"type": "string",
14-
"pattern": "^[a-zA-Z]+$"
15-
},
1612
{
1713
"type": "string",
1814
"format": "email"
@@ -23,22 +19,6 @@ def unsupported_string_schemas():
2319
@pytest.fixture
2420
def unsupported_integer_schemas():
2521
return [
26-
{
27-
"type": "integer",
28-
"minimum": 0
29-
},
30-
{
31-
"type": "integer",
32-
"maximum": 120
33-
},
34-
{
35-
"type": "integer",
36-
"exclusiveMinimum": 120
37-
},
38-
{
39-
"type": "integer",
40-
"exclusiveMaximum": 120
41-
},
4222
{
4323
"type": "integer",
4424
"multipleOf": 120
@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
4929
@pytest.fixture
5030
def unsupported_number_schemas():
5131
return [
52-
{
53-
"type": "number",
54-
"minimum": 0
55-
},
56-
{
57-
"type": "number",
58-
"maximum": 120
59-
},
60-
{
61-
"type": "number",
62-
"exclusiveMinimum": 120
63-
},
64-
{
65-
"type": "number",
66-
"exclusiveMaximum": 120
67-
},
6832
{
6933
"type": "number",
7034
"multipleOf": 120
@@ -156,13 +120,28 @@ def supported_schema():
156120
"type": "string",
157121
"enum": ["sedan", "suv", "truck"]
158122
},
123+
"car_brand": {
124+
"type": "string",
125+
"pattern": "^[a-zA-Z]+$"
126+
},
159127
"short_description": {
160128
"type": "string",
161129
"maxLength": 50
162130
},
131+
"mileage": {
132+
"type": "number",
133+
"minimum": 0,
134+
"maximum": 1000000
135+
},
136+
"model_year": {
137+
"type": "integer",
138+
"exclusiveMinimum": 1900,
139+
"exclusiveMaximum": 2100
140+
},
163141
"long_description": {
164142
"type": "string",
165-
"minLength": 50
143+
"minLength": 50,
144+
"maxLength": 2000
166145
},
167146
"address": {
168147
"type": "object",

vllm/model_executor/guided_decoding/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
6565
fallback_or_error(
6666
guided_params,
6767
"xgrammar does not support advanced JSON schema features like "
68-
"enums, patterns or numeric ranges.", "outlines")
68+
"string length, item limits, or property bounds.", "outlines")
6969

7070
# xgrammar only supports GBNF grammars, so we must convert Lark.
7171
# We must check if the grammar is likely Lark and if that

vllm/model_executor/guided_decoding/utils.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,8 @@ def check_object(obj: dict) -> bool:
1010
if not isinstance(obj, dict):
1111
return False
1212

13-
# Check for pattern restrictions
14-
if "pattern" in obj:
15-
return True
16-
1713
# Check for numeric ranges
18-
if obj.get("type") in ("integer", "number") and any(
19-
key in obj for key in [
20-
"minimum", "maximum", "exclusiveMinimum",
21-
"exclusiveMaximum", "multipleOf"
22-
]):
14+
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
2315
return True
2416

2517
# Check for array unsupported keywords

vllm/v1/structured_output/backend_xgrammar.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,8 @@ def check_object(obj: dict[str, Any]) -> bool:
179179
if not isinstance(obj, dict):
180180
return False
181181

182-
# Check for pattern restrictions
183-
if "pattern" in obj:
184-
return True
185-
186182
# Check for numeric ranges
187-
if obj.get("type") in ("integer", "number") and any(
188-
key in obj
189-
for key in ("minimum", "maximum", "exclusiveMinimum",
190-
"exclusiveMaximum", "multipleOf")):
183+
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
191184
return True
192185

193186
# Check for array unsupported keywords

0 commit comments

Comments
 (0)