Commit 852ca12
[Bugfix] Add qwen3 reasoning parser; fix content being None when thinking is disabled (#17357)
Signed-off-by: mofanke <[email protected]>
Parent: 97cc872

4 files changed: +284 −0

docs/source/features/reasoning_outputs.md

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ vLLM currently supports the following reasoning models:
 | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` ||
 | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` ||
 | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` |||
+| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` ||

 - IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
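To make the new table row concrete, here is a minimal usage sketch against the OpenAI-compatible server. The serve command, flags, and the `enable_thinking` chat-template switch follow vLLM's general reasoning-outputs docs and are assumptions here, not part of this diff:

```python
# Hypothetical usage sketch, assuming a server started with something like:
#   vllm serve Qwen/Qwen3-0.6B --enable-reasoning --reasoning-parser qwen3
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# With thinking enabled (the default), reasoning_content carries the
# text inside <think>...</think> and content carries the rest.
resp = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
)
print(resp.choices[0].message.reasoning_content)
print(resp.choices[0].message.content)

# With thinking disabled, the fix in this commit ensures content holds the
# full answer instead of being None.
resp = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)
print(resp.choices[0].message.reasoning_content)  # None
print(resp.choices[0].message.content)            # the answer
```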

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager

parser_name = "qwen3"
start_token = "<think>"
end_token = "</think>"

REASONING_MODEL_NAME = "Qwen/Qwen3-0.6B"


@pytest.fixture(scope="module")
def qwen3_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


# with <think></think>, non-streaming
WITH_THINK = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
# with <think></think>, streaming
WITH_THINK_STREAM = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
# without <think></think>, non-streaming
WITHOUT_THINK = {
    "output": "This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
}
# without <think></think>, streaming
WITHOUT_THINK_STREAM = {
    "output": "This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
}

COMPLETE_REASONING = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
MULTILINE_REASONING = {
    "output": "<think>This is a reasoning\nsection</think>This is the rest\nThat",
    "reasoning_content": "This is a reasoning\nsection",
    "content": "This is the rest\nThat",
}
ONLY_OPEN_TAG = {
    "output": "<think>This is a reasoning section",
    "reasoning_content": None,
    "content": "<think>This is a reasoning section",
}


ONLY_OPEN_TAG_STREAM = {
    "output": "<think>This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}


TEST_CASES = [
    pytest.param(
        False,
        WITH_THINK,
        id="with_think",
    ),
    pytest.param(
        True,
        WITH_THINK_STREAM,
        id="with_think_stream",
    ),
    pytest.param(
        False,
        WITHOUT_THINK,
        id="without_think",
    ),
    pytest.param(
        True,
        WITHOUT_THINK_STREAM,
        id="without_think_stream",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_stream",
    ),
    pytest.param(
        False,
        MULTILINE_REASONING,
        id="multiline_reasoning",
    ),
    pytest.param(
        True,
        MULTILINE_REASONING,
        id="multiline_reasoning_stream",
    ),
    pytest.param(
        False,
        ONLY_OPEN_TAG,
        id="only_open_tag",
    ),
    pytest.param(
        True,
        ONLY_OPEN_TAG_STREAM,
        id="only_open_tag_stream",
    ),
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    qwen3_tokenizer,
):
    output = qwen3_tokenizer.tokenize(param_dict["output"])
    output_tokens: list[str] = [
        qwen3_tokenizer.convert_tokens_to_string([token])
        for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(qwen3_tokenizer)

    reasoning, content = run_reasoning_extraction(parser,
                                                  output_tokens,
                                                  streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]
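For orientation, here is a rough sketch of what driving the streaming path by hand looks like, roughly what `run_reasoning_extraction` does for the test above. This is an approximation written against the parser code in this commit, not the actual helper in `tests/reasoning/utils.py`:

```python
# Feed the parser one token at a time, maintaining previous/current state,
# and accumulate the reasoning_content and content fields of each delta.
from transformers import AutoTokenizer
from vllm.reasoning import ReasoningParserManager

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
parser = ReasoningParserManager.get_reasoning_parser("qwen3")(tokenizer)

text = "<think>2+2=4</think>The answer is 4."
token_ids = tokenizer.encode(text, add_special_tokens=False)

reasoning_parts, content_parts = [], []
previous_text, previous_ids = "", []
for token_id in token_ids:
    delta_text = tokenizer.decode([token_id])
    current_text = previous_text + delta_text
    current_ids = previous_ids + [token_id]
    delta = parser.extract_reasoning_content_streaming(
        previous_text, current_text, delta_text,
        previous_ids, current_ids, [token_id])
    if delta is not None:  # None is returned for lone <think>/</think> tokens
        if delta.reasoning_content:
            reasoning_parts.append(delta.reasoning_content)
        if delta.content:
            content_parts.append(delta.content)
    previous_text, previous_ids = current_text, current_ids

print("".join(reasoning_parts))  # 2+2=4
print("".join(content_parts))    # The answer is 4.
```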

vllm/reasoning/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,10 +3,12 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
+from .qwen3_reasoning_parser import Qwen3ReasoningParser

 __all__ = [
     "ReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
     "GraniteReasoningParser",
+    "Qwen3ReasoningParser",
 ]
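With this export in place, the class is reachable two ways: by direct import, and through the registry name set by the decorator in the parser module below. A small sketch, assuming the package imports cleanly:

```python
from vllm.reasoning import Qwen3ReasoningParser, ReasoningParserManager

# The "qwen3" name comes from @ReasoningParserManager.register_module("qwen3")
# in the parser module; the direct import comes from the __all__ entry above.
assert ReasoningParserManager.get_reasoning_parser("qwen3") is Qwen3ReasoningParser
```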
vllm/reasoning/qwen3_reasoning_parser.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0

import re
from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("qwen3")
class Qwen3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Qwen3 model.

    The Qwen3 model uses <think>...</think> tokens to denote reasoning text
    within its output. The model provides a strict switch to disable reasoning
    output via the 'enable_thinking=False' parameter. This parser extracts the
    reasoning content enclosed by <think> and </think> tokens from the model's
    output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(
            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "Qwen3 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous:
                # reasoning is finished, the delta is normal content
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            logger.info(delta_text)
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # thinking is disabled, just content
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:

        # Check if the model output contains the <think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output
        else:
            # Use a regex to find the reasoning content
            reasoning_content = self.reasoning_regex.findall(model_output)[0]

            # Remove the reasoning content from the model output. Although
            # <think> usually appears at the start of the output, other models
            # may not follow this convention, so we locate it explicitly
            # rather than assuming index 0.
            start_index = model_output.find(self.think_start_token)
            if start_index != -1:
                end_index = start_index + len(
                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
                )
                model_output = model_output[:start_index] + \
                    model_output[end_index:]

                if len(model_output) == 0:
                    return reasoning_content, None

            return reasoning_content, model_output
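A quick sketch of the non-streaming path, including the disabled-thinking case this commit fixes. The tokenizer download is assumed; `request` is unused by this parser's `extract_reasoning_content`, so `None` is passed for brevity:

```python
# Minimal sketch of the non-streaming path of the parser above.
from transformers import AutoTokenizer
from vllm.reasoning import ReasoningParserManager

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
parser = ReasoningParserManager.get_reasoning_parser("qwen3")(tokenizer)

# Thinking enabled: both parts are recovered.
reasoning, content = parser.extract_reasoning_content(
    "<think>2+2=4</think>The answer is 4.", request=None)
assert reasoning == "2+2=4" and content == "The answer is 4."

# Thinking disabled (no <think> block): the whole output is returned as
# content, instead of content coming back as None.
reasoning, content = parser.extract_reasoning_content(
    "The answer is 4.", request=None)
assert reasoning is None and content == "The answer is 4."
```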
