[Bugfix]: Reasoning output bug according to the chat template change #13025

Merged: 2 commits, Feb 11, 2025

Diff 1/3: example script that prints reasoning_content and content per round

@@ -36,8 +36,8 @@
 reasoning_content = response.choices[0].message.reasoning_content
 content = response.choices[0].message.content
 
-print("reasoning_content:", reasoning_content)
-print("content:", content)
+print("reasoning_content for Round 1:", reasoning_content)
+print("content for Round 1:", content)
 
 # Round 2
 messages.append({"role": "assistant", "content": content})
@@ -50,5 +50,5 @@
 reasoning_content = response.choices[0].message.reasoning_content
 content = response.choices[0].message.content
 
-print("reasoning_content:", reasoning_content)
-print("content:", content)
+print("reasoning_content for Round 2:", reasoning_content)
+print("content for Round 2:", content)

Diff 2/3: DeepSeek R1 reasoning parser tests

@@ -15,32 +15,62 @@
 end_token = "</think>"
 
 SIMPLE_REASONING = {
-    "output": "<think>This is a reasoning section</think>This is the rest",
+    "output": "This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
 }
 COMPLETE_REASONING = {
-    "output": "<think>This is a reasoning section</think>",
+    "output": "This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
 }
 NO_REASONING = {
-    "output": "This is a reasoning section",
+    "output": "This is content",
     "reasoning_content": None,
-    "content": "This is a reasoning section",
+    "content": "This is content",
 }
+NO_REASONING_STREAMING = {
+    "output": "This is a reasoning section",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
 MULTIPLE_LINES = {
-    "output": "<think>This\nThat</think>This is the rest\nThat",
+    "output": "This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
 }
 SHORTEST_REASONING_NO_STREAMING = {
-    "output": "<think></think>This is the rest",
+    "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
 }
 SHORTEST_REASONING = {
-    "output": "<think></think>This is the rest",
+    "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
 }
+REASONING_WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+}
+COMPLETE_REASONING_WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
+MULTIPLE_LINES_WITH_THINK = {
+    "output": "<think>This\nThat</think>This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
+}
+SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
+    "output": "</think>This is the rest",
+    "reasoning_content": "",
+    "content": "This is the rest",
+}
+SHORTEST_REASONING_WITH_THINK = {
+    "output": "</think>This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+}
@@ -49,37 +79,37 @@
     pytest.param(
         False,
         SIMPLE_REASONING,
-        id="simple_streaming",
+        id="simple_reasoning",
     ),
     pytest.param(
         True,
         SIMPLE_REASONING,
-        id="simple_streaming",
+        id="simple_reasoning_streaming",
     ),
     pytest.param(
         False,
         COMPLETE_REASONING,
-        id="complete_streaming",
+        id="complete_reasoning",
     ),
     pytest.param(
         True,
         COMPLETE_REASONING,
-        id="complete_streaming",
+        id="complete_reasoning_streaming",
     ),
     pytest.param(
         False,
         NO_REASONING,
-        id="no_streaming",
+        id="no_reasoning_token",
     ),
     pytest.param(
         True,
-        NO_REASONING,
-        id="no_streaming",
+        NO_REASONING_STREAMING,
+        id="no_reasoning_token_streaming",
     ),
     pytest.param(
         False,
         MULTIPLE_LINES,
-        id="multiple_lines_streaming",
+        id="multiple_lines",
     ),
     pytest.param(
         True,
@@ -89,23 +119,65 @@
     pytest.param(
         True,
         SHORTEST_REASONING,
-        id="shortest_streaming",
+        id="shortest",
     ),
     pytest.param(
         False,
         SHORTEST_REASONING_NO_STREAMING,
         id="shortest_streaming",
     ),
+    pytest.param(
+        False,
+        REASONING_WITH_THINK,
+        id="reasoning_with_think",
+    ),
+    pytest.param(
+        True,
+        REASONING_WITH_THINK,
+        id="reasoning_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        COMPLETE_REASONING_WITH_THINK,
+        id="complete_reasoning_with_think",
+    ),
+    pytest.param(
+        True,
+        COMPLETE_REASONING_WITH_THINK,
+        id="complete_reasoning_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        MULTIPLE_LINES_WITH_THINK,
+        id="multiple_lines_with_think",
+    ),
+    pytest.param(
+        True,
+        MULTIPLE_LINES_WITH_THINK,
+        id="multiple_lines_with_think_streaming",
+    ),
+    pytest.param(
+        False,
+        SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
+        id="shortest_with_think",
+    ),
+    pytest.param(
+        True,
+        SHORTEST_REASONING_WITH_THINK,
+        id="shortest_with_think_streaming",
+    ),
 ]
 
+# Global tokenizer initialization to avoid repeated loading
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+tokenizer.add_tokens([start_token, end_token])
 
 
 @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
 def test_reasoning(
     streaming: bool,
     param_dict: dict,
 ):
-    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-    tokenizer.add_tokens([start_token, end_token])
     output = tokenizer.tokenize(param_dict["output"])
     # decode everything to tokens
     output_tokens: List[str] = [
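
To pin down the rule these cases encode, here is a self-contained sketch of the non-streaming extraction, a re-implementation for illustration rather than the vLLM parser class itself: text before </think> counts as reasoning even when <think> was never emitted, and output with no </think> at all is plain content.

import re
from typing import Optional, Tuple

START, END = "<think>", "</think>"
REASONING_RE = re.compile(rf"{START}(.*?){END}", re.DOTALL)

def extract(output: str) -> Tuple[Optional[str], Optional[str]]:
    # No </think> at all: the whole output is plain content.
    if END not in output:
        return None, output
    # Prepend <think> when the chat template swallowed it, then split.
    if START not in output:
        output = f"{START}{output}"
    reasoning = REASONING_RE.findall(output)[0]
    rest = output[len(f"{START}{reasoning}{END}"):]
    return reasoning, (rest or None)

assert extract("This is a reasoning section</think>This is the rest") == \
    ("This is a reasoning section", "This is the rest")
assert extract("This is content") == (None, "This is content")
assert extract("<think>This\nThat</think>This is the rest\nThat") == \
    ("This\nThat", "This is the rest\nThat")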

Diff 3/3: DeepSeek R1 reasoning parser

@@ -67,6 +67,8 @@ def extract_reasoning_content_streaming(
         ]):
             return None
 
+        # Check if <think> is present in previous or delta.
+        # Keep compatibility with models that don't generate <think> tokens.
         if self.think_start_token_id in previous_token_ids:
             if self.think_end_token_id in delta_token_ids:
                 # <think> in previous, </think> in delta,
@@ -85,7 +87,6 @@
                 # reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
         elif self.think_start_token_id in delta_token_ids:
-            logger.info(delta_text)
             if self.think_end_token_id in delta_token_ids:
                 # <think> in delta, </think> in delta, extract reasoning content
                 start_index = delta_text.find(self.think_start_token)
@@ -101,35 +102,46 @@
                 # reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
         else:
-            # No <think> in previous or delta, reasoning content continues.
-            return DeltaMessage(content=delta_text)
+            # No <think> in previous or delta, also need to check for </think>.
+            # Because the model may have generated </think> without <think>
+            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
+            if self.think_end_token_id in delta_token_ids:
+                # </think> in delta with more tokens,
+                # extract reasoning content and content
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            elif self.think_end_token_id in previous_token_ids:
+                # </think> in previous, thinking content ends
+                return DeltaMessage(content=delta_text)
+            else:
+                # no </think> in previous or delta, reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
 
     def extract_reasoning_content(
             self, model_output: str, request: ChatCompletionRequest
     ) -> Tuple[Optional[str], Optional[str]]:
 
-        # Check if the model output contains the <think> tokens.
-        if (self.think_start_token not in model_output
-                or self.think_end_token not in model_output):
+        # DeepSeek R1 doesn't generate <think> now.
+        # Thus we assume the reasoning content is always at the start.
+        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
+        if self.think_end_token not in model_output:
             return None, model_output
         else:
+            # Add a start token if it's missing to keep compatibility.
+            if self.think_start_token not in model_output:
+                model_output = f"{self.think_start_token}{model_output}"
             # Use a regex to find the reasoning content
             reasoning_content = self.reasoning_regex.findall(model_output)[0]
 
-            # Remove the reasoning content from the model output
-            # Although deepseek's <think> token is always at the
-            # beginning of the line, we cannot guarantee that the
-            # other models will follow this convention.
-            # Therefore, we need to add :start_index.
-            start_index = model_output.find(self.think_start_token)
-            if start_index != -1:
-                end_index = start_index + len(
-                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
-                )
-                model_output = model_output[:start_index] + \
-                    model_output[end_index:]
-
-                if len(model_output) == 0:
-                    return reasoning_content, None
-
-            return reasoning_content, model_output
+            end_index = len(
+                f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
+            )
+            final_output = model_output[end_index:]
+
+            if len(final_output) == 0:
+                return reasoning_content, None
+
+            return reasoning_content, final_output
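
The behavioral core of the streaming change is the new else branch: when neither the previous tokens nor the delta contain <think>, the parser now checks for a bare </think> instead of assuming plain content. A condensed sketch of just that routing rule, with token-id checks simplified to substring checks for illustration:

from typing import Optional, Tuple

END = "</think>"

def route_delta(previous_text: str,
                delta_text: str) -> Tuple[Optional[str], Optional[str]]:
    """Return (reasoning_delta, content_delta) for one streamed chunk."""
    if END in delta_text:
        # </think> arrives inside this delta: split it in place.
        idx = delta_text.find(END)
        reasoning = delta_text[:idx]
        content = delta_text[idx + len(END):]
        return reasoning or None, content or None
    if END in previous_text:
        # Reasoning was already closed earlier in the stream.
        return None, delta_text
    # No </think> seen anywhere yet: still inside the implicit reasoning block.
    return delta_text, None

# Chunks "This is a reasoning", " section</think>This", " is the rest" route
# to reasoning only, reasoning plus content, and content only, respectively.
print(route_delta("", "This is a reasoning"))
print(route_delta("This is a reasoning", " section</think>This"))
print(route_delta("This is a reasoning section</think>This", " is the rest"))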