Commit 7f14ee5

tjohnson31415 authored and lk-chen committed

[Frontend] Support guidance:no-additional-properties for compatibility with xgrammar (vllm-project#15949)

Signed-off-by: Travis Johnson <[email protected]>

1 parent 0d94169 commit 7f14ee5
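
This commit adds a new backend option string. As a minimal usage sketch (not part of this diff; the model name and schema are placeholders, and the option strings are the ones exercised by the tests below):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# 'no-additional-properties' tells the guidance backend to reject object keys
# that the schema does not declare, matching xgrammar's default behavior.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="guidance:no-additional-properties")

schema = {
    "type": "object",
    "properties": {"a1": {"type": "string"}},
    "required": ["a1"],
}
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(
        json=schema, backend="guidance:no-additional-properties"))
outputs = llm.generate(prompts="Return a JSON object.",
                       sampling_params=sampling_params)
print(outputs[0].outputs[0].text)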

File tree

8 files changed: +201 -34 lines changed

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 59 additions & 1 deletion
@@ -383,4 +383,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
         assert generated_text is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
-        jsonschema.validate(instance=output_json, schema=json_schema)
+        jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guidance_no_additional_properties(llm):
+    schema = {
+        'type': 'object',
+        'properties': {
+            'a1': {
+                'type': 'string'
+            },
+            'a2': {
+                'type': 'string'
+            },
+            'a3': {
+                'type': 'string'
+            }
+        },
+        'required': ['a1', 'a2', 'a3'],
+    }
+
+    prompt = (
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
+        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "<|im_end|>\n<|im_start|>assistant\n")
+
+    def generate_with_backend(backend):
+        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         guided_decoding=guided_params)
+
+        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+        assert outputs is not None
+        generated_text = outputs[0].outputs[0].text
+        assert generated_text is not None
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
+        return parsed_json
+
+    base_generated = generate_with_backend('guidance:disable-any-whitespace')
+    assert "a1" in base_generated
+    assert "a2" in base_generated
+    assert "a3" in base_generated
+    # by default additional keys are generated
+    assert "a4" in base_generated
+    assert "a5" in base_generated
+    assert "a6" in base_generated
+
+    generated = generate_with_backend(
+        'guidance:no-additional-properties,disable-any-whitespace')
+    assert "a1" in generated
+    assert "a2" in generated
+    assert "a3" in generated
+    assert "a4" not in generated
+    assert "a5" not in generated
+    assert "a6" not in generated

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 56 additions & 0 deletions
@@ -412,3 +412,59 @@ def test_structured_output_auto_mode(
     # Parse to verify it is valid JSON
     parsed_json = json.loads(generated_text)
     assert isinstance(parsed_json, dict)
+
+
+@pytest.mark.skip_global_cleanup
+def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    backend = 'guidance:no-additional-properties,disable-any-whitespace'
+    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
+              max_model_len=1024,
+              guided_decoding_backend=backend)
+
+    schema = {
+        'type': 'object',
+        'properties': {
+            'a1': {
+                'type': 'string'
+            },
+            'a2': {
+                'type': 'string'
+            },
+            'a3': {
+                'type': 'string'
+            }
+        },
+        'required': ['a1', 'a2', 'a3'],
+    }
+
+    prompt = (
+        "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a "
+        "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a "
+        "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20"
+        "<|im_end|>\n<|im_start|>assistant\n")
+
+    def generate_with_backend(backend):
+        guided_params = GuidedDecodingParams(json=schema, backend=backend)
+        sampling_params = SamplingParams(temperature=0,
+                                         max_tokens=256,
+                                         guided_decoding=guided_params)
+
+        outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
+        assert outputs is not None
+        generated_text = outputs[0].outputs[0].text
+        assert generated_text is not None
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
+        jsonschema.validate(instance=parsed_json, schema=schema)
+        return parsed_json
+
+    generated = generate_with_backend(
+        'guidance:no-additional-properties,disable-any-whitespace')
+    assert "a1" in generated
+    assert "a2" in generated
+    assert "a3" in generated
+    assert "a4" not in generated
+    assert "a5" not in generated
+    assert "a6" not in generated

vllm/config.py

Lines changed: 1 addition & 1 deletion
@@ -3107,7 +3107,7 @@ def get_served_model_name(model: str,
 
 
 GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
-                                  "xgrammar"]
+                                  "xgrammar", "guidance"]
 GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
 
 

vllm/engine/arg_utils.py

Lines changed: 9 additions & 9 deletions
@@ -18,7 +18,8 @@
 from vllm import version
 from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
-                         DeviceConfig, DistributedExecutorBackend, HfOverrides,
+                         DeviceConfig, DistributedExecutorBackend,
+                         GuidedDecodingBackendV1, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
@@ -1370,14 +1371,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
                 recommend_to_remove=True)
             return False
 
-        # Xgrammar and Guidance are supported.
-        SUPPORTED_GUIDED_DECODING = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-        if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
-            _raise_or_fallback(feature_name="--guided-decoding-backend",
-                               recommend_to_remove=False)
+        # remove backend options when doing this check
+        if self.guided_decoding_backend.split(':')[0] \
+                not in get_args(GuidedDecodingBackendV1):
+            _raise_or_fallback(
+                feature_name=
+                f"--guided-decoding-backend={self.guided_decoding_backend}",
+                recommend_to_remove=False)
             return False
 
         # Need at least Ampere for now (FA support required).
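
The reworked check accepts any option suffix as long as the base backend name is one of the V1 literals, so new option combinations no longer need to be enumerated. A small self-contained illustration of the new comparison (the literal is copied from vllm/config.py above):

from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

backend = "guidance:no-additional-properties,disable-any-whitespace"
# Everything after ':' is stripped before the supported-backend check.
assert backend.split(':')[0] in get_args(GuidedDecodingBackendV1)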

vllm/model_executor/guided_decoding/guidance_decoding.py

Lines changed: 13 additions & 2 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 from re import escape as regex_escape
 
 import llguidance
@@ -7,6 +8,8 @@
 from vllm.model_executor.guided_decoding.guidance_logits_processors import (
     GuidanceLogitsProcessor)
 from vllm.sampling_params import GuidedDecodingParams
+from vllm.v1.structured_output.backend_guidance import (
+    process_for_additional_properties)
 
 
 def get_local_guidance_guided_decoding_logits_processor(
@@ -20,9 +23,17 @@ def get_local_guidance_guided_decoding_logits_processor(
     grm = ""
     any_whitespace = 'disable-any-whitespace' not in \
         guided_params.backend_options()
-    if guided_params.json:
+    if (guide_json := guided_params.json) is not None:
+        # Optionally set additionalProperties to False at the top-level
+        # By default, other backends do not allow additional top-level
+        # properties, so this makes guidance more similar to other backends
+        if 'no-additional-properties' in guided_params.backend_options():
+            if not isinstance(guide_json, str):
+                guide_json = json.dumps(guide_json)
+            guide_json = process_for_additional_properties(guide_json)
+
         grm = llguidance.LLMatcher.grammar_from_json_schema(
-            guided_params.json,
+            guide_json,
             overrides={"whitespace_pattern": guided_params.whitespace_pattern},
             defaults={
                 "whitespace_flexible": any_whitespace,

vllm/v1/engine/processor.py

Lines changed: 0 additions & 8 deletions
@@ -145,15 +145,7 @@ def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
 
-        supported_backends = [
-            "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
-            "guidance:disable-any-whitespace", "auto"
-        ]
-
         engine_level_backend = self.decoding_config.guided_decoding_backend
-        if engine_level_backend not in supported_backends:
-            raise ValueError(f"Only {supported_backends} structured output is "
-                             "supported in V1.")
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
             # The values may differ if `params` is reused and was set

vllm/v1/structured_output/backend_guidance.py

Lines changed: 51 additions & 9 deletions
@@ -1,14 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
+import json
 import os
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.utils import LazyLoader
 from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
@@ -29,6 +31,29 @@
 logger = init_logger(__name__)
 
 
+def _walk_json_for_additional_properties(data: object):
+    if isinstance(data, dict):
+        for value in data.values():
+            _walk_json_for_additional_properties(value)
+        if 'additionalProperties' not in data and \
+           ('properties' in data or 'patternProperties' in data):
+            data['additionalProperties'] = False
+    elif isinstance(data, list):
+        for item in data:
+            _walk_json_for_additional_properties(item)
+
+
+def process_for_additional_properties(
+        guide_json: Union[str, dict[str, Any]]) -> dict[str, Any]:
+    if isinstance(guide_json, str):
+        guide_json_obj = json.loads(guide_json)
+    else:
+        # copy for modifications
+        guide_json_obj = copy.deepcopy(guide_json)
+    _walk_json_for_additional_properties(guide_json_obj)
+    return guide_json_obj
+
+
 class GuidanceBackend(StructuredOutputBackend):
 
     def __init__(self, vllm_config: VllmConfig):
@@ -41,9 +66,20 @@ def __init__(self, vllm_config: VllmConfig):
         tokenizer_group.ping()
         self.vllm_config = vllm_config
         self.vocab_size = vllm_config.model_config.get_vocab_size()
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
+
+        self.disable_any_whitespace = False
+        self.no_additional_properties = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            elif option == "no-additional-properties":
+                self.no_additional_properties = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the guidance backend: {option}")
 
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.ll_tokenizer = llguidance_hf.from_tokenizer(
@@ -52,7 +88,8 @@ def __init__(self, vllm_config: VllmConfig):
     def compile_grammar(self, request_type: StructuredOutputOptions,
                         grammar_spec: str) -> StructuredOutputGrammar:
         self.serialized_grammar = serialize_guidance_grammar(
-            request_type, grammar_spec, self.disable_any_whitespace)
+            request_type, grammar_spec, self.disable_any_whitespace,
+            self.no_additional_properties)
 
         ll_matcher = llguidance.LLMatcher(
             self.ll_tokenizer,
@@ -129,10 +166,15 @@ def reset(self):
         self.ll_matcher.reset()
 
 
-def serialize_guidance_grammar(request_type: StructuredOutputOptions,
-                               grammar_spec: str,
-                               disable_any_whitespace: bool = False) -> str:
+def serialize_guidance_grammar(
+    request_type: StructuredOutputOptions,
+    grammar_spec: Union[str, dict[str, Any]],
+    disable_any_whitespace: bool = False,
+    no_additional_properties: bool = False,
+) -> str:
     if request_type == StructuredOutputOptions.JSON:
+        if no_additional_properties:
+            grammar_spec = process_for_additional_properties(grammar_spec)
         return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
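
To make the new helper concrete: every schema object that declares properties (or patternProperties) and does not already set additionalProperties gets additionalProperties: false, recursively. A minimal standalone sketch of the same walk (mirroring _walk_json_for_additional_properties above rather than importing vLLM):

import copy
import json

def walk(data):
    # Recursively pin down objects that declare properties.
    if isinstance(data, dict):
        for value in data.values():
            walk(value)
        if 'additionalProperties' not in data and \
           ('properties' in data or 'patternProperties' in data):
            data['additionalProperties'] = False
    elif isinstance(data, list):
        for item in data:
            walk(item)

schema = {'type': 'object', 'properties': {'a1': {'type': 'string'}}}
pinned = copy.deepcopy(schema)
walk(pinned)
assert pinned['additionalProperties'] is False
print(json.dumps(pinned))  # ... "additionalProperties": false}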

vllm/v1/structured_output/backend_xgrammar.py

Lines changed: 12 additions & 4 deletions
@@ -9,7 +9,7 @@
 import vllm.envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import LazyLoader
@@ -32,16 +32,24 @@ class XgrammarBackend(StructuredOutputBackend):
 
     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config
-        self.disable_any_whitespace = (
-            "disable-any-whitespace"
-            in vllm_config.decoding_config.guided_decoding_backend)
         tokenizer_group = init_tokenizer_from_configs(
             model_config=vllm_config.model_config,
             scheduler_config=vllm_config.scheduler_config,
             parallel_config=vllm_config.parallel_config,
             lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
         tokenizer_group.ping()
 
+        self.disable_any_whitespace = False
+        backend_options = GuidedDecodingParams(
+            backend=vllm_config.decoding_config.guided_decoding_backend
+        ).backend_options()
+        for option in backend_options:
+            if option == "disable-any-whitespace":
+                self.disable_any_whitespace = True
+            else:
+                raise ValueError(
+                    f"Unsupported option for the xgrammar backend: {option}")
+
         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         if isinstance(tokenizer, MistralTokenizer):
