[Bugfix] add qwen3 reasoning-parser fix content is None when disable … #17369

Merged · 1 commit · Apr 29, 2025
docs/source/features/reasoning_outputs.md (1 addition, 0 deletions)
@@ -15,6 +15,7 @@ vLLM currently supports the following reasoning models:
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ |

- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
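Not part of this diff, but for orientation: a minimal request sketch against a server using the new parser. It assumes the server was launched with the reasoning parser selected as described elsewhere on this page (for example `vllm serve Qwen/Qwen3-0.6B --enable-reasoning --reasoning-parser qwen3`; flag spelling may vary by vLLM version), and it uses Qwen3's `enable_thinking` chat-template switch mentioned in the parser docstring below.

```python
from openai import OpenAI

# Hypothetical local endpoint; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen3-0.6B",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    # Qwen3's chat template exposes an enable_thinking switch; with it off,
    # the qwen3 parser should leave reasoning_content as None and keep the
    # whole answer in content (the symptom this PR fixes).
    extra_body={"chat_template_kwargs": {"enable_thinking": False}},
)

message = response.choices[0].message
print(message.reasoning_content)  # None when thinking is disabled
print(message.content)            # the answer text
```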

tests/reasoning/test_qwen3_reasoning_parser.py (141 additions, 0 deletions)
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
from transformers import AutoTokenizer

from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager

parser_name = "qwen3"
start_token = "<think>"
end_token = "</think>"

REASONING_MODEL_NAME = "Qwen/Qwen3-0.6B"


@pytest.fixture(scope="module")
def qwen3_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


# With <think></think>, non-streaming
WITH_THINK = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
# With <think></think>, streaming
WITH_THINK_STREAM = {
    "output": "<think>This is a reasoning section</think>This is the rest",
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
# Without <think></think>, non-streaming
WITHOUT_THINK = {
    "output": "This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
}
# Without <think></think>, streaming
WITHOUT_THINK_STREAM = {
    "output": "This is the rest",
    "reasoning_content": None,
    "content": "This is the rest",
}

COMPLETE_REASONING = {
    "output": "<think>This is a reasoning section</think>",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
MULTILINE_REASONING = {
    "output":
    "<think>This is a reasoning\nsection</think>This is the rest\nThat",
    "reasoning_content": "This is a reasoning\nsection",
    "content": "This is the rest\nThat",
}
ONLY_OPEN_TAG = {
    "output": "<think>This is a reasoning section",
    "reasoning_content": None,
    "content": "<think>This is a reasoning section",
}

ONLY_OPEN_TAG_STREAM = {
    "output": "<think>This is a reasoning section",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}

TEST_CASES = [
    pytest.param(
        False,
        WITH_THINK,
        id="with_think",
    ),
    pytest.param(
        True,
        WITH_THINK_STREAM,
        id="with_think_stream",
    ),
    pytest.param(
        False,
        WITHOUT_THINK,
        id="without_think",
    ),
    pytest.param(
        True,
        WITHOUT_THINK_STREAM,
        id="without_think_stream",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_stream",
    ),
    pytest.param(
        False,
        MULTILINE_REASONING,
        id="multiline_reasoning",
    ),
    pytest.param(
        True,
        MULTILINE_REASONING,
        id="multiline_reasoning_stream",
    ),
    pytest.param(
        False,
        ONLY_OPEN_TAG,
        id="only_open_tag",
    ),
    pytest.param(
        True,
        ONLY_OPEN_TAG_STREAM,
        id="only_open_tag_stream",
    ),
]


@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
    qwen3_tokenizer,
):
    output = qwen3_tokenizer.tokenize(param_dict["output"])
    output_tokens: list[str] = [
        qwen3_tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(qwen3_tokenizer)

    reasoning, content = run_reasoning_extraction(parser,
                                                  output_tokens,
                                                  streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]
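`run_reasoning_extraction` lives in `tests/reasoning/utils.py` and is not shown in this diff. As a rough sketch of what its streaming path presumably does (names and details below are illustrative assumptions, not the actual helper): feed each token string as a delta, keep the running text and token ids, and accumulate whatever the parser returns.

```python
from typing import Optional

from vllm.reasoning import ReasoningParser


def run_reasoning_extraction_sketch(
    parser: ReasoningParser,
    token_strings: list[str],
) -> tuple[Optional[str], Optional[str]]:
    """Illustrative stand-in only; not the real test helper."""
    previous_text = ""
    previous_ids: list[int] = []
    reasoning, content = "", ""
    for delta_text in token_strings:
        # Map the delta back to token ids via the parser's tokenizer/vocab.
        delta_ids = [
            parser.vocab[tok]
            for tok in parser.model_tokenizer.tokenize(delta_text)
            if tok in parser.vocab
        ]
        current_text = previous_text + delta_text
        current_ids = previous_ids + delta_ids
        delta = parser.extract_reasoning_content_streaming(
            previous_text, current_text, delta_text, previous_ids,
            current_ids, delta_ids)
        if delta is not None:
            reasoning += delta.reasoning_content or ""
            content += delta.content or ""
        previous_text, previous_ids = current_text, current_ids
    return reasoning or None, content or None
```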
vllm/reasoning/__init__.py (2 additions, 0 deletions)
@@ -3,10 +3,12 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser

__all__ = [
"ReasoningParser",
"ReasoningParserManager",
"DeepSeekR1ReasoningParser",
"GraniteReasoningParser",
"Qwen3ReasoningParser",
]
vllm/reasoning/qwen3_reasoning_parser.py (138 additions, 0 deletions)
@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0

import re
from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("qwen3")
class Qwen3ReasoningParser(ReasoningParser):
"""
Reasoning parser for the Qwen3 model.

The Qwen3 model uses <think>...</think> tokens to denote reasoning text
within its output. The model provides a strict switch to disable reasoning
output via the 'enable_thinking=False' parameter. This parser extracts the
reasoning content enclosed by <think> and </think> tokens from the model's
output.
"""

def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"

self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)

if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")

self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
raise RuntimeError(
"Qwen3 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.
        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content
        """
        # Skip single special tokens
        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
                self.think_start_token_id, self.think_end_token_id
        ]):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta,
                # extract reasoning content
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            elif self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous:
                # reasoning is already finished, this is response content
                return DeltaMessage(content=delta_text)
            else:
                # <think> in previous, no </think> in previous or delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        elif self.think_start_token_id in delta_token_ids:
            logger.info(delta_text)
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta, extract reasoning content
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[start_index +
                                               len(self.think_start_token
                                                   ):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            else:
                # <think> in delta, no </think> in delta,
                # reasoning content continues
                return DeltaMessage(reasoning_content=delta_text)
        else:
            # thinking is disabled, just content
            return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:

        # Check if the model output contains the <think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output
        else:
            # Use a regex to find the reasoning content
            reasoning_content = self.reasoning_regex.findall(model_output)[0]

            # Remove the reasoning content from the model output.
            # Although the <think> token is usually at the start of the
            # output, we cannot guarantee that every model follows this
            # convention, so slice from :start_index instead of index 0.
            start_index = model_output.find(self.think_start_token)
            if start_index != -1:
                end_index = start_index + len(
                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
                )
                model_output = model_output[:start_index] + \
                    model_output[end_index:]

            if len(model_output) == 0:
                return reasoning_content, None

            return reasoning_content, model_output
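To close the loop on the PR title (content coming back as None when thinking is disabled): a small non-streaming sketch, not part of the diff. The `ChatCompletionRequest` is not read by `extract_reasoning_content`, so the field values below are only placeholders; with no `<think>` block present, the parser keeps the whole text as content instead of dropping it (the WITHOUT_THINK case in the tests above).

```python
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning import ReasoningParserManager

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
parser = ReasoningParserManager.get_reasoning_parser("qwen3")(tokenizer)

# Placeholder request; this method does not inspect it.
request = ChatCompletionRequest(
    model="Qwen/Qwen3-0.6B",
    messages=[{"role": "user", "content": "hi"}],
)

# Thinking enabled: text is split at </think>.
print(parser.extract_reasoning_content(
    "<think>step by step</think>The answer is 4.", request))
# -> ('step by step', 'The answer is 4.')

# Thinking disabled (no <think> block): the answer stays in content.
print(parser.extract_reasoning_content("The answer is 4.", request))
# -> (None, 'The answer is 4.')
```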