Commit e8e6b42

zhangningmofanke authored and zhangning committed
[Bugfix] add qwen3 reasoning-parser fix content is None when disable thinking (#17357)
1 parent 97cc872 commit e8e6b42
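
For context, the failure mode this commit addresses: with Qwen3's thinking switch turned off, the model emits no <think>...</think> block, and without a parser branch for that case the response content could come back as None. A minimal client-side sketch of the scenario, assuming an OpenAI-compatible vLLM server launched with this parser selected (--reasoning-parser qwen3); the model name, port, and the chat_template_kwargs passthrough are illustrative, not taken from this commit:

from openai import OpenAI

# Hypothetical local deployment; adjust base_url and model to yours.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    # Forwarded to the chat template; turns the <think> block off.
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)

# With this fix the answer lands in content instead of None.
print(response.choices[0].message.content)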

2 files changed: +139 -0 lines changed

vllm/reasoning/__init__.py
Lines changed: 3 additions & 0 deletions

@@ -3,10 +3,13 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
+from .qwen3_reasoning_parser import Qwen3ReasoningParser
+
 
 __all__ = [
     "ReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
     "GraniteReasoningParser",
+    "Qwen3ReasoningParser",
 ]
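
The new import and __all__ entry pair with the @ReasoningParserManager.register_module("qwen3") decorator in the file below, which registers the class under the name users pass on the command line. A minimal sketch of resolving it by name, assuming the manager's get_reasoning_parser lookup that the other registered parsers go through; the tokenizer line is illustrative:

from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParserManager

# Resolve the class registered under "qwen3" (the same name used
# for the --reasoning-parser CLI value).
parser_cls = ReasoningParserManager.get_reasoning_parser("qwen3")

# Construction needs a tokenizer whose vocab contains the
# <think>/</think> tokens, e.g. a Qwen3 tokenizer (illustrative):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
parser = parser_cls(tokenizer)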
vllm/reasoning/qwen3_reasoning_parser.py
Lines changed: 136 additions & 0 deletions

@@ -0,0 +1,136 @@
import re
from typing import Optional, Union
from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.logger import init_logger

logger = init_logger(__name__)


@ReasoningParserManager.register_module("qwen3")
class Qwen3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Qwen3 model.

    The Qwen3 model uses <think>...</think> tokens to denote reasoning text
    within its output. The model provides a strict switch to disable reasoning
    output via the 'enable_thinking=False' parameter. This parser extracts the
    reasoning content enclosed by <think> and </think> tokens from the model's
    output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(
            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "Qwen3 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous:
                # reasoning has ended, this delta is regular content
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            logger.info(delta_text)
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # thinking is disabled, just content
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:

        # Check if the model output contains the <think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output
        else:
            # Use a regex to find the reasoning content
            reasoning_content = self.reasoning_regex.findall(model_output)[0]

            # Remove the reasoning content from the model output. Although
            # the <think> token is usually at the beginning of the output,
            # we cannot guarantee that every model follows this convention,
            # so we locate it with find() and slice around it.
            start_index = model_output.find(self.think_start_token)
            if start_index != -1:
                end_index = start_index + len(
                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
                )
                model_output = model_output[:start_index] + \
                    model_output[end_index:]

                if len(model_output) == 0:
                    return reasoning_content, None

            return reasoning_content, model_output
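
To make the non-streaming contract concrete, here is a self-contained sketch that mirrors extract_reasoning_content's slicing without a tokenizer or a vLLM install; the last case is the one this commit exists for, where thinking was disabled and the whole output must become content rather than None:

import re
from typing import Optional

THINK_START, THINK_END = "<think>", "</think>"
REASONING_RE = re.compile(rf"{THINK_START}(.*?){THINK_END}", re.DOTALL)


def split_reasoning(output: str) -> tuple[Optional[str], Optional[str]]:
    """Return (reasoning_content, content), mirroring the parser."""
    if THINK_START not in output or THINK_END not in output:
        # No think block (e.g. enable_thinking=False): everything is
        # ordinary content -- the bugfix case.
        return None, output
    reasoning = REASONING_RE.findall(output)[0]
    start = output.find(THINK_START)
    end = start + len(f"{THINK_START}{reasoning}{THINK_END}")
    remainder = output[:start] + output[end:]
    return reasoning, remainder if remainder else None


print(split_reasoning("<think>abc</think>xyz"))  # ('abc', 'xyz')
print(split_reasoning("<think>abc</think>"))     # ('abc', None)
print(split_reasoning("4"))                      # (None, '4')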

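For the streaming side, a compact simulation of how the branches route deltas, with made-up integer token ids standing in for the tokenizer vocab and the split-inside-one-delta cases omitted for brevity:

THINK_START_ID, THINK_END_ID = 1, 2  # hypothetical vocab ids


def route_delta(previous_ids: list[int], delta_ids: list[int],
                delta_text: str) -> dict:
    """Simplified mirror of extract_reasoning_content_streaming."""
    if len(delta_ids) == 1 and delta_ids[0] in (THINK_START_ID, THINK_END_ID):
        return {}  # a lone special token produces no delta message
    if THINK_START_ID in previous_ids:
        if THINK_END_ID in previous_ids:
            return {"content": delta_text}        # reasoning already closed
        return {"reasoning_content": delta_text}  # still inside <think>
    return {"content": delta_text}                # thinking disabled


# Stream: <think> | "abc" | </think> | "xyz"
steps = [([], [1], "<think>"),
         ([1], [10], "abc"),
         ([1, 10], [2], "</think>"),
         ([1, 10, 2], [11], "xyz")]
for prev, delta, text in steps:
    print(route_delta(prev, delta, text))
# {}  {'reasoning_content': 'abc'}  {}  {'content': 'xyz'}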