Skip to content

Commit c67f2a8

Browse files
russellbwuisawesome
authored and committed
[V1] Add structural_tag support using xgrammar (vllm-project#17085)
1 parent 810fd2b commit c67f2a8

File tree

10 files changed

+270
-15
lines changed

10 files changed

+270
-15
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from openai import OpenAI
3+
4+
# This example demonstrates the `structural_tag` response format.
5+
# It can be used to specify a structured output format that occurs between
6+
# specific tags in the response. This example shows how it could be used
7+
# to enforce the format of a tool call response, but it could be used for
8+
# any structured output within a subset of the response.
9+
10+
11+
def main():
    """Demo of the ``structural_tag`` response format.

    Sends a tool-call style prompt to a locally running OpenAI-compatible
    server and constrains any text after the ``<function=`` trigger to a
    JSON object matching the declared schema, closed by ``</function>``.
    """
    api = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="-",
    )

    # Prompt instructing the model to emit tool calls in the tagged format.
    user_prompt = """
You have access to the following function to retrieve the weather in a city:

{
    "name": "get_weather",
    "parameters": {
        "city": {
            "param_type": "string",
            "description": "The city to get the weather for",
            "required": True
        }
    }
}

If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
where

start_tag => `<function`
parameters => a JSON dict with the function argument name as key and function
argument value as value.
end_tag => `</function>`

Here is an example,
<function=example_function_name>{"example_name": "example_value"}</function>

Reminder:
- Function calls MUST follow the specified format
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- Always add your sources when using search results to answer the user query

You are a helpful assistant.

Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco?
"""

    chat_messages = [{"role": "user", "content": user_prompt}]

    # The structural tag: text between "begin" and "end" must satisfy the
    # JSON schema; generation of the constrained region is activated by
    # any of the "triggers" prefixes.
    structural_tag_format = {
        "type": "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            "schema": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string"
                    }
                }
            },
            "end": "</function>"
        }],
        "triggers": ["<function="],
    }

    completion = api.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=chat_messages,
        response_format=structural_tag_format,
    )
    print(completion)
82+
83+
84+
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

tests/v1/entrypoints/llm/test_struct_output_generate.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def test_structured_output(
350350
temperature=1.0,
351351
max_tokens=1000,
352352
guided_decoding=GuidedDecodingParams(json=json_schema))
353+
353354
outputs = llm.generate(
354355
prompts="Generate a description of a frog using 50 characters.",
355356
sampling_params=sampling_params,
@@ -368,6 +369,106 @@ def test_structured_output(
368369
output_json = json.loads(generated_text)
369370
jsonschema.validate(instance=output_json, schema=json_schema)
370371

372+
#
373+
# Test 11: Generate structured output using structural_tag format
374+
#
375+
structural_tag_config = {
376+
"type":
377+
"structural_tag",
378+
"structures": [{
379+
"begin": "<function=get_weather>",
380+
"schema": {
381+
"type": "object",
382+
"properties": {
383+
"city": {
384+
"type": "string"
385+
}
386+
}
387+
},
388+
"end": "</function>"
389+
}],
390+
"triggers": ["<function="]
391+
}
392+
393+
sampling_params = SamplingParams(
394+
temperature=0.0,
395+
max_tokens=100,
396+
guided_decoding=GuidedDecodingParams(
397+
structural_tag=json.dumps(structural_tag_config)))
398+
399+
prompt = """
400+
You have access to the following function to retrieve the weather in a city:
401+
402+
{
403+
"name": "get_weather",
404+
"parameters": {
405+
"city": {
406+
"param_type": "string",
407+
"description": "The city to get the weather for",
408+
"required": True
409+
}
410+
}
411+
}
412+
413+
If a you choose to call a function ONLY reply in the following format:
414+
<{start_tag}={function_name}>{parameters}{end_tag}
415+
where
416+
417+
start_tag => `<function`
418+
parameters => a JSON dict with the function argument name
419+
as key and function argument value as value.
420+
end_tag => `</function>`
421+
422+
Here is an example,
423+
<function=example_function_name>{"example_name": "example_value"}</function>
424+
425+
Reminder:
426+
- Function calls MUST follow the specified format
427+
- Required parameters MUST be specified
428+
- Only call one function at a time
429+
- Put the entire function call reply on one line
430+
- Always add your sources when using search results to answer the user query
431+
432+
You are a helpful assistant.
433+
434+
Given the previous instructions, what is the weather in New York City?
435+
"""
436+
437+
# Change this once other backends support structural_tag
438+
if guided_decoding_backend.startswith("xgrammar"):
439+
outputs = llm.generate(prompts=prompt,
440+
sampling_params=sampling_params,
441+
use_tqdm=True)
442+
assert outputs is not None
443+
else:
444+
outputs = []
445+
446+
for output in outputs:
447+
assert output is not None
448+
assert isinstance(output, RequestOutput)
449+
generated_text = output.outputs[0].text
450+
assert generated_text is not None
451+
452+
# Search for function call pattern in the response
453+
function_call_pattern = r'<function=get_weather>(.*?)</function>'
454+
matches = re.findall(function_call_pattern, generated_text)
455+
456+
if not matches:
457+
print(f"Warning: No function calls found in response: "
458+
f"{generated_text!r}")
459+
continue
460+
461+
# Take the first function call if multiple are found
462+
json_str = matches[0]
463+
try:
464+
json_content = json.loads(json_str)
465+
assert "city" in json_content
466+
assert isinstance(json_content["city"], str)
467+
print(f"Found valid function call: {generated_text!r}")
468+
except (json.JSONDecodeError, AssertionError) as e:
469+
pytest.fail("Invalid function call format: "
470+
f"{generated_text!r}\nError: {str(e)}")
471+
371472

372473
@pytest.mark.skip_global_cleanup
373474
@pytest.mark.parametrize("model_name, tokenizer_mode",

vllm/entrypoints/llm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1396,7 +1396,9 @@ def _add_guided_params(
13961396
grammar=guided_options.guided_grammar,
13971397
json_object=guided_options.guided_json_object,
13981398
backend=guided_options.guided_decoding_backend,
1399-
whitespace_pattern=guided_options.guided_whitespace_pattern)
1399+
whitespace_pattern=guided_options.guided_whitespace_pattern,
1400+
structural_tag=guided_options.structural_tag,
1401+
)
14001402
return params
14011403

14021404
def _run_engine(

vllm/entrypoints/openai/protocol.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# Adapted from
44
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
5+
import json
56
import re
67
import time
78
from argparse import Namespace
@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
139140
strict: Optional[bool] = None
140141

141142

143+
class StructuralTag(OpenAIBaseModel):
    """One tagged region of the response whose content is constrained.

    Text appearing between ``begin`` and ``end`` must conform to the JSON
    schema carried in the (aliased) ``schema`` field.
    """
    # Literal text opening the region, e.g. "<function=get_weather>".
    begin: str
    # schema is the field, but that causes conflicts with pydantic so
    # instead use structural_tag_schema with an alias
    structural_tag_schema: Optional[dict[str, Any]] = Field(default=None,
                                                            alias="schema")
    # Literal text closing the region, e.g. "</function>".
    end: str
150+
151+
152+
class StructuralTagResponseFormat(OpenAIBaseModel):
    """``response_format`` variant for structural-tag guided decoding.

    ``structures`` lists the tag definitions; ``triggers`` are the literal
    prefixes (e.g. "<function=") that activate constrained generation.
    """
    type: Literal["structural_tag"]
    structures: list[StructuralTag]
    triggers: list[str]
157+
142158
class ResponseFormat(OpenAIBaseModel):
    """OpenAI-style ``response_format`` object for plain/JSON output modes."""
    # type must be "json_schema", "json_object", or "text"
    type: Literal["text", "json_object", "json_schema"]
    # Schema details; used together with type == "json_schema".
    json_schema: Optional[JsonSchemaResponseFormat] = None
146162

147163

164+
AnyResponseFormat = Union[ResponseFormat, StructuralTagResponseFormat]
165+
166+
148167
class StreamOptions(OpenAIBaseModel):
149168
include_usage: Optional[bool] = True
150169
continuous_usage_stats: Optional[bool] = False
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
227246
max_completion_tokens: Optional[int] = None
228247
n: Optional[int] = 1
229248
presence_penalty: Optional[float] = 0.0
230-
response_format: Optional[ResponseFormat] = None
249+
response_format: Optional[AnyResponseFormat] = None
231250
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
232251
stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
233252
stream: Optional[bool] = False
@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
340359
description=(
341360
"If specified, the output will follow the context free grammar."),
342361
)
362+
structural_tag: Optional[str] = Field(
363+
default=None,
364+
description=(
365+
"If specified, the output will follow the structural tag schema."),
366+
)
343367
guided_decoding_backend: Optional[str] = Field(
344368
default=None,
345369
description=(
@@ -476,6 +500,12 @@ def to_sampling_params(
476500
json_schema = self.response_format.json_schema
477501
assert json_schema is not None
478502
self.guided_json = json_schema.json_schema
503+
elif self.response_format.type == "structural_tag":
504+
structural_tag = self.response_format
505+
assert structural_tag is not None and isinstance(
506+
structural_tag, StructuralTagResponseFormat)
507+
s_tag_obj = structural_tag.model_dump(by_alias=True)
508+
self.structural_tag = json.dumps(s_tag_obj)
479509

480510
guided_decoding = GuidedDecodingParams.from_optional(
481511
json=self._get_guided_json_from_tool() or self.guided_json,
@@ -485,6 +515,7 @@ def to_sampling_params(
485515
json_object=guided_json_object,
486516
backend=self.guided_decoding_backend,
487517
whitespace_pattern=self.guided_whitespace_pattern,
518+
structural_tag=self.structural_tag,
488519
)
489520

490521
return SamplingParams.from_optional(
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
742773
"If true (the default), special tokens (e.g. BOS) will be added to "
743774
"the prompt."),
744775
)
745-
response_format: Optional[ResponseFormat] = Field(
776+
response_format: Optional[AnyResponseFormat] = Field(
746777
default=None,
747-
description=
748-
("Similar to chat completion, this parameter specifies the format of "
749-
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
750-
"{'type': 'text' } is supported."),
778+
description=(
779+
"Similar to chat completion, this parameter specifies the format "
780+
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
781+
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
782+
),
751783
)
752784
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
753785
default=None,

vllm/model_executor/guided_decoding/guided_fields.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ class GuidedDecodingRequest:
2727
guided_decoding_backend: Optional[str] = None
2828
guided_whitespace_pattern: Optional[str] = None
2929
guided_json_object: Optional[bool] = None
30+
structural_tag: Optional[str] = None
3031

3132
def __post_init__(self):
3233
"""Validate that some fields are mutually exclusive."""
33-
guide_count = sum([
34-
self.guided_json is not None, self.guided_regex is not None,
35-
self.guided_choice is not None, self.guided_grammar is not None,
36-
self.guided_json_object is not None
37-
])
34+
guide_count = sum(x is not None
35+
for x in (self.guided_json, self.guided_regex,
36+
self.guided_choice, self.guided_grammar,
37+
self.guided_json_object,
38+
self.structural_tag))
3839
if guide_count > 1:
3940
raise ValueError(
4041
"You can only use one kind of guided decoding but multiple are "

vllm/sampling_params.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class GuidedDecodingParams:
3838
"""These are other options that can be set"""
3939
backend: Optional[str] = None
4040
whitespace_pattern: Optional[str] = None
41+
structural_tag: Optional[str] = None
4142

4243
@staticmethod
4344
def from_optional(
@@ -48,9 +49,10 @@ def from_optional(
4849
json_object: Optional[bool] = None,
4950
backend: Optional[str] = None,
5051
whitespace_pattern: Optional[str] = None,
52+
structural_tag: Optional[str] = None,
5153
) -> Optional["GuidedDecodingParams"]:
52-
if all(arg is None
53-
for arg in (json, regex, choice, grammar, json_object)):
54+
if all(arg is None for arg in (json, regex, choice, grammar,
55+
json_object, structural_tag)):
5456
return None
5557
# Extract json schemas from pydantic models
5658
if isinstance(json, (BaseModel, type(BaseModel))):
@@ -63,6 +65,7 @@ def from_optional(
6365
json_object=json_object,
6466
backend=backend,
6567
whitespace_pattern=whitespace_pattern,
68+
structural_tag=structural_tag,
6669
)
6770

6871
@property

vllm/v1/structured_output/backend_guidance.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def serialize_guidance_grammar(
194194
tp = "grammar"
195195
elif request_type == StructuredOutputOptions.CHOICE:
196196
tp = "choice"
197+
elif request_type == StructuredOutputOptions.STRUCTURAL_TAG:
198+
raise ValueError("Structural tag is not supported "
199+
"for guidance backend yet")
197200
else:
198201
logger.error("Validation should have already occurred. "
199202
"Please file an issue.")

vllm/v1/structured_output/backend_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class StructuredOutputOptions(enum.Enum):
1212
REGEX = enum.auto()
1313
GRAMMAR = enum.auto()
1414
CHOICE = enum.auto()
15+
STRUCTURAL_TAG = enum.auto()
1516

1617

1718
StructuredOutputKey = tuple[StructuredOutputOptions, str]

0 commit comments

Comments
 (0)