
Commit 002848a

[Bugfix][Structured Output] Support outlines with reasoning outputs
Signed-off-by: Ce Gao <[email protected]>
1 parent 7f89a59 commit 002848a

File tree

- examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
- vllm/model_executor/guided_decoding/__init__.py
- vllm/model_executor/guided_decoding/outlines_logits_processors.py
- vllm/model_executor/guided_decoding/reasoner/__init__.py
- vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py
- vllm/model_executor/guided_decoding/reasoner/reasoner.py
- vllm/model_executor/guided_decoding/xgrammar_decoding.py

7 files changed: +102 -9 lines changed


examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py

Lines changed: 67 additions & 2 deletions
@@ -33,6 +33,42 @@
 models = client.models.list()
 model = models.data[0].id
 
+# Guided decoding by Regex
+prompt = ("What is the capital of France?")
+
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={
+        "guided_regex": "(Paris|London)",
+    },
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+
+
+class People(BaseModel):
+    name: str
+    age: int
+
+
+json_schema = People.model_json_schema()
+
+prompt = ("Generate a JSON with the name and age of one random person.")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_json": json_schema},
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+
 
 # Guided decoding by JSON using Pydantic schema
 class CarType(str, Enum):
@@ -51,7 +87,7 @@ class CarDescription(BaseModel):
 json_schema = CarDescription.model_json_schema()
 
 prompt = ("Generate a JSON with the brand, model and car_type of"
-          "the most iconic car from the 90's, think in 100 tokens")
+          "the most iconic car from the 90's")
 completion = client.chat.completions.create(
     model=model,
     messages=[{
@@ -60,5 +96,34 @@ class CarDescription(BaseModel):
     }],
     extra_body={"guided_json": json_schema},
 )
-print("content", completion.choices[0].message.content)
 print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+
+# Guided decoding by Grammar
+simplified_sql_grammar = """
+?start: select_statement
+
+?select_statement: "SELECT " column_list " FROM " table_name
+
+?column_list: column_name ("," column_name)*
+
+?table_name: identifier
+
+?column_name: identifier
+
+?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
+"""
+
+# This may be very slow https://github.com/vllm-project/vllm/issues/12122
+prompt = ("Generate an SQL query to show the 'username' and 'email'"
+          "from the 'users' table.")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_grammar": simplified_sql_grammar},
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)

vllm/model_executor/guided_decoding/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ async def get_guided_decoding_logits_processor(
     reasoner = get_reasoner(tokenizer, reasoning_backend)
 
     guided_params = maybe_backend_fallback(guided_params)
+
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193

vllm/model_executor/guided_decoding/outlines_logits_processors.py

Lines changed: 9 additions & 5 deletions
@@ -43,7 +43,7 @@ class BaseLogitsProcessor:
 
     def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
         self._guide: Guide = guide
-        self._reasoner = reasoner
+        self._reasoner: Optional[Reasoner] = reasoner
         # CFGState is used for the FSM state for CFGGuide
         self._fsm_state: DefaultDict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
@@ -54,10 +54,14 @@ def __call__(self, input_ids: List[int],
 
         # Skip the structured logits processing if reasoning is not finished.
         # reasoner is not None only when `--enable-reasoning` is set.
-        if self._reasoner is not None and \
-            not self._reasoner.is_reasoning_end(
-                input_ids):
-            return scores
+        if self._reasoner is not None:
+            if not self._reasoner.is_reasoning_end(input_ids):
+                return scores
+            else:
+                # Remove the reasoning tokens from the input_ids
+                # We need this because our implementation relies on the
+                # hash of the input_ids to store the FSM state.
+                input_ids = self._reasoner.extract_content(input_ids)
 
         seq_id = hash(tuple(input_ids))
 
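
Why the stripping matters: BaseLogitsProcessor keys its FSM state on hash(tuple(input_ids)), and the state dict is a defaultdict(int). Stripping the reasoning prefix means the first token generated after reasoning hashes the empty tuple, which the defaultdict maps to the guide's initial state, exactly as if no reasoning had happened. A minimal sketch of that mechanism, with 7 standing in for the real end-of-reasoning token id:

    # Illustrative only: 7 stands in for the reasoner's end-of-reasoning token.
    from collections import defaultdict

    fsm_state = defaultdict(int)    # mirrors BaseLogitsProcessor._fsm_state
    END = 7

    input_ids = [5, 9, 8, END]      # reasoning just finished, no content yet
    content = input_ids[input_ids.index(END) + 1:]  # what extract_content returns: []
    seq_id = hash(tuple(content))   # hash(()) on the first structured step
    assert fsm_state[seq_id] == 0   # defaultdict yields the FSM's initial state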

vllm/model_executor/guided_decoding/reasoner/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -4,10 +4,13 @@
 
 from transformers import PreTrainedTokenizer
 
+from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501
     DeepSeekReasoner)
 from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
 
+logger = init_logger(__name__)
+
 
 def get_reasoner(tokenizer: PreTrainedTokenizer,
                  reasoning_backend: str | None) -> Reasoner | None:
@@ -17,7 +20,12 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
     elif reasoning_backend == "deepseek_r1":
         return DeepSeekReasoner.from_tokenizer(tokenizer)
     else:
-        raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
+        # Raise a warning for unknown reasoning backend and return None
+        # We cannot raise an error here because some reasoning models
+        # may not have a corresponding Reasoner class.
+        logger.warning("Unknown reasoning backend %s for structured outputs ",
+                       reasoning_backend)
+        return None
 
 
 __all__ = ["Reasoner", "get_reasoner"]

vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py

Lines changed: 11 additions & 0 deletions
@@ -26,3 +26,14 @@ def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.end_token_id in input_ids
+
+    def extract_content(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract the content after the end tokens
+        """
+        if self.end_token_id not in input_ids:
+            return input_ids
+        elif input_ids.index(self.end_token_id) + 1 == len(input_ids):
+            return []
+        else:
+            return input_ids[input_ids.index(self.end_token_id) + 1:]
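
extract_content covers three cases: no end token yet, end token as the final id, and content following the end token. A standalone replica of those branches, with 7 standing in for the real end-of-reasoning token id (DeepSeekReasoner resolves the actual id from the tokenizer):

    # Standalone replica of the three branches above; illustrative values only.
    def extract_content(input_ids: list[int], end_token_id: int = 7) -> list[int]:
        if end_token_id not in input_ids:
            return input_ids               # still reasoning: pass through
        cut = input_ids.index(end_token_id) + 1
        return input_ids[cut:]             # [] when the end token is last

    assert extract_content([1, 2, 3]) == [1, 2, 3]
    assert extract_content([1, 2, 7]) == []
    assert extract_content([1, 2, 7, 4, 5]) == [4, 5]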

vllm/model_executor/guided_decoding/reasoner/reasoner.py

Lines changed: 4 additions & 0 deletions
@@ -17,3 +17,7 @@ def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
     @abstractmethod
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         pass
+
+    @abstractmethod
+    def extract_content(self, input_ids: list[int]) -> list[int]:
+        pass
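
These hooks are the whole contract a new reasoning model has to satisfy. A hypothetical implementation, not part of this commit, for a model that closes its reasoning block with a single "</think>" token (shown duck-typed so it does not depend on the base class's exact fields):

    # Hypothetical reasoner; a real one would subclass Reasoner as
    # DeepSeekReasoner does.
    from dataclasses import dataclass

    from transformers import PreTrainedTokenizer


    @dataclass
    class ToyReasoner:
        end_token_id: int

        @classmethod
        def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> "ToyReasoner":
            # Assumes the closing tag is a single token in the vocabulary.
            return cls(end_token_id=tokenizer.convert_tokens_to_ids("</think>"))

        def is_reasoning_end(self, input_ids: list[int]) -> bool:
            return self.end_token_id in input_ids

        def extract_content(self, input_ids: list[int]) -> list[int]:
            if self.end_token_id not in input_ids:
                return input_ids
            return input_ids[input_ids.index(self.end_token_id) + 1:]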

vllm/model_executor/guided_decoding/xgrammar_decoding.py

Lines changed: 1 addition & 1 deletion
@@ -392,7 +392,7 @@ def __call__(self, input_ids: list[int],
     def clone(self) -> XGrammarLogitsProcessor:
         """Create a new instance with shared compiled grammar
         but separate state"""
-        new_processor = XGrammarLogitsProcessor(self.config)
+        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)
 
         # Share the compiled grammar context (immutable after compilation)
         new_processor.ctx = self.ctx
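
The fix is a single argument: clone() produces per-sequence copies, and before this change a clone was built without the reasoner, losing the reasoning-aware skip. A runnable toy model of the clone semantics the fix restores (class and attribute names here are stand-ins, not vLLM's):

    # Toy stand-ins to make the invariant checkable without a real grammar.
    class _Ctx: ...
    class _Reasoner: ...

    class ToyProcessor:
        def __init__(self, ctx, reasoner):
            self.ctx, self.reasoner = ctx, reasoner
            self.state = []                # per-sequence mutable state

        def clone(self) -> "ToyProcessor":
            # The fixed line: pass the reasoner through to the copy too.
            return ToyProcessor(self.ctx, self.reasoner)

    original = ToyProcessor(_Ctx(), _Reasoner())
    clone = original.clone()
    assert clone.ctx is original.ctx            # shared compiled grammar context
    assert clone.reasoner is original.reasoner  # previously dropped, now carried over
    assert clone.state is not original.state    # independent per-sequence state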
