From ef7e1ec4c92c4f2d811ada5efa37cf377496136f Mon Sep 17 00:00:00 2001
From: Ce Gao
Date: Tue, 11 Feb 2025 12:00:56 +0800
Subject: [PATCH 1/2] chore: Keep compatibility

Signed-off-by: Ce Gao
---
 .../openai_chat_completion_with_reasoning.py  |  8 +-
 .../test_deepseekr1_reasoning_parser.py       | 82 +++++++++++++++----
 .../deepseek_r1_reasoning_parser.py           | 58 +++++++------
 3 files changed, 105 insertions(+), 43 deletions(-)

diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py
index a88c8adb55c..b5dbed1205d 100644
--- a/examples/online_serving/openai_chat_completion_with_reasoning.py
+++ b/examples/online_serving/openai_chat_completion_with_reasoning.py
@@ -36,8 +36,8 @@
 reasoning_content = response.choices[0].message.reasoning_content
 content = response.choices[0].message.content
 
-print("reasoning_content:", reasoning_content)
-print("content:", content)
+print("reasoning_content for Round 1:", reasoning_content)
+print("content for Round 1:", content)
 
 # Round 2
 messages.append({"role": "assistant", "content": content})
@@ -50,5 +50,5 @@
 reasoning_content = response.choices[0].message.reasoning_content
 content = response.choices[0].message.content
 
-print("reasoning_content:", reasoning_content)
-print("content:", content)
+print("reasoning_content for Round 2:", reasoning_content)
+print("content for Round 2:", content)
diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
index f7b81be48bd..b0b8739dbe1 100644
--- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
+++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
@@ -15,71 +15,91 @@ end_token = "</think>"
 SIMPLE_REASONING = {
-    "output": "<think>This is a reasoning section</think>This is the rest",
+    "output": "This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
 }
 COMPLETE_REASONING = {
-    "output": "<think>This is a reasoning section</think>",
+    "output": "This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
 }
 NO_REASONING = {
-    "output": "This is a reasoning section",
+    "output": "This is content",
     "reasoning_content": None,
-    "content": "This is a reasoning section",
+    "content": "This is content",
+}
+NO_REASONING_STREAMING = {
+    "output": "This is a reasoning section",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
 }
 MULTIPLE_LINES = {
-    "output": "<think>This\nThat</think>This is the rest\nThat",
+    "output": "This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
 }
 SHORTEST_REASONING_NO_STREAMING = {
-    "output": "<think></think>This is the rest",
+    "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
 }
 SHORTEST_REASONING = {
-    "output": "<think></think>This is the rest",
+    "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
 }
+REASONING_WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>This is the rest",
+    "reasoning_content": "This is a reasoning section",
+    "content": "This is the rest",
+}
+COMPLETE_REASONING_WITH_THINK = {
+    "output": "<think>This is a reasoning section</think>",
+    "reasoning_content": "This is a reasoning section",
+    "content": None,
+}
+MULTIPLE_LINES_WITH_THINK = {
+    "output": "<think>This\nThat</think>This is the rest\nThat",
+    "reasoning_content": "This\nThat",
+    "content": "This is the rest\nThat",
rest\nThat", +} TEST_CASES = [ pytest.param( False, SIMPLE_REASONING, - id="simple_streaming", + id="simple_reasoning", ), pytest.param( True, SIMPLE_REASONING, - id="simple_streaming", + id="simple_reasoning_streaming", ), pytest.param( False, COMPLETE_REASONING, - id="complete_streaming", + id="complete_reasoning", ), pytest.param( True, COMPLETE_REASONING, - id="complete_streaming", + id="complete_reasoning_streaming", ), pytest.param( False, NO_REASONING, - id="no_streaming", + id="no_reasoning_token", ), pytest.param( True, - NO_REASONING, - id="no_streaming", + NO_REASONING_STREAMING, + id="no_reasoning_token_streaming", ), pytest.param( False, MULTIPLE_LINES, - id="multiple_lines_streaming", + id="multiple_lines", ), pytest.param( True, @@ -89,13 +109,43 @@ pytest.param( True, SHORTEST_REASONING, - id="shortest_streaming", + id="shortest", ), pytest.param( False, SHORTEST_REASONING_NO_STREAMING, id="shortest_streaming", ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), ] diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py index 5c19888d454..33bba04882b 100644 --- a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py +++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py @@ -67,6 +67,8 @@ def extract_reasoning_content_streaming( ]): return None + # Check if is present in previous or delta. + # Keep compatibility with models that don't generate tokens. if self.think_start_token_id in previous_token_ids: if self.think_end_token_id in delta_token_ids: # in previous, in delta, @@ -85,7 +87,6 @@ def extract_reasoning_content_streaming( # reasoning content continues return DeltaMessage(reasoning_content=delta_text) elif self.think_start_token_id in delta_token_ids: - logger.info(delta_text) if self.think_end_token_id in delta_token_ids: # in delta, in delta, extract reasoning content start_index = delta_text.find(self.think_start_token) @@ -101,35 +102,46 @@ def extract_reasoning_content_streaming( # reasoning content continues return DeltaMessage(reasoning_content=delta_text) else: - # No in previous or delta, reasoning content continues. - return DeltaMessage(content=delta_text) + # No in previous or delta, also need to check for . 
+            # Because the model may have generated </think> without <think>
+            # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
+            if self.think_end_token_id in delta_token_ids:
+                # </think> in delta with more tokens,
+                # extract reasoning content and content
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            elif self.think_end_token_id in previous_token_ids:
+                # </think> in previous, thinking content ends
+                return DeltaMessage(content=delta_text)
+            else:
+                # no </think> in previous or delta, reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
 
     def extract_reasoning_content(
             self, model_output: str, request: ChatCompletionRequest
     ) -> Tuple[Optional[str], Optional[str]]:
-        # Check if the model output contains the <think> tokens.
-        if (self.think_start_token not in model_output
-                or self.think_end_token not in model_output):
+        # DeepSeek R1 doesn't generate <think> now.
+        # Thus we assume the reasoning content is always at the start.
+        # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
+        if self.think_end_token not in model_output:
             return None, model_output
         else:
+            # Add a start token if it's missing to keep compatibility.
+            if self.think_start_token not in model_output:
+                model_output = f"{self.think_start_token}{model_output}"
             # Use a regex to find the reasoning content
             reasoning_content = self.reasoning_regex.findall(model_output)[0]
 
-            # Remove the reasoning content from the model output
-            # Although deepseek's <think> token is always at the
-            # beginning of the line, we cannot guarantee that the
-            # other models will follow this convention.
-            # Therefore, we need to add :start_index.
-            start_index = model_output.find(self.think_start_token)
-            if start_index != -1:
-                end_index = start_index + len(
-                    f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
-                )
-                model_output = model_output[:start_index] + \
-                    model_output[end_index:]
-
-            if len(model_output) == 0:
-                return reasoning_content, None
-
-            return reasoning_content, model_output
+            end_index = len(
+                f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
+            )
+            final_output = model_output[end_index:]
+
+            if len(final_output) == 0:
+                return reasoning_content, None
+
+            return reasoning_content, final_output
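
For reference, the non-streaming fallback that PATCH 1/2 gives extract_reasoning_content can be
reproduced in a short standalone sketch. This is illustrative only: the helper name
extract_reasoning and the THINK_START/THINK_END constants are assumptions made for the example,
not vLLM API, and the regex is a simplified stand-in for the parser's reasoning_regex.

    import re
    from typing import Optional, Tuple

    THINK_START, THINK_END = "<think>", "</think>"
    REASONING_REGEX = re.compile(rf"{THINK_START}(.*?){THINK_END}", re.DOTALL)

    def extract_reasoning(model_output: str) -> Tuple[Optional[str], Optional[str]]:
        # No </think> at all: treat the whole output as plain content.
        if THINK_END not in model_output:
            return None, model_output
        # DeepSeek R1 may omit <think>, so prepend it before matching.
        if THINK_START not in model_output:
            model_output = THINK_START + model_output
        reasoning = REASONING_REGEX.findall(model_output)[0]
        rest = model_output[len(f"{THINK_START}{reasoning}{THINK_END}"):]
        return reasoning, rest or None

    assert extract_reasoning("abc</think>xyz") == ("abc", "xyz")
    assert extract_reasoning("<think>abc</think>") == ("abc", None)
    assert extract_reasoning("no think tokens") == (None, "no think tokens")

The streaming branch added above follows the same rule: a </think> in the delta splits it into
reasoning and content, and a delta containing neither token is now treated as reasoning rather
than content.
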
From 11b6b4b61c2da26d42cdd7b6cf9161d10a70b04d Mon Sep 17 00:00:00 2001
From: Ce Gao
Date: Tue, 11 Feb 2025 12:15:27 +0800
Subject: [PATCH 2/2] chore: Add two more test cases and speed it up

Signed-off-by: Ce Gao
---
 .../test_deepseekr1_reasoning_parser.py | 26 +++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
index b0b8739dbe1..fdadb2e21ff 100644
--- a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
+++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py
@@ -64,6 +64,16 @@
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
 }
+SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
+    "output": "<think></think>This is the rest",
+    "reasoning_content": "",
+    "content": "This is the rest",
+}
+SHORTEST_REASONING_WITH_THINK = {
+    "output": "<think></think>This is the rest",
+    "reasoning_content": None,
+    "content": "This is the rest",
+}
 
 TEST_CASES = [
     pytest.param(
@@ -146,16 +156,28 @@
         MULTIPLE_LINES_WITH_THINK,
         id="multiple_lines_with_think_streaming",
     ),
+    pytest.param(
+        False,
+        SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
+        id="shortest_with_think",
+    ),
+    pytest.param(
+        True,
+        SHORTEST_REASONING_WITH_THINK,
+        id="shortest_with_think_streaming",
+    ),
 ]
 
+# Global tokenizer initialization to avoid repeated loading
+tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+tokenizer.add_tokens([start_token, end_token])
+
 
 @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
 def test_reasoning(
     streaming: bool,
     param_dict: dict,
 ):
-    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-    tokenizer.add_tokens([start_token, end_token])
     output = tokenizer.tokenize(param_dict["output"])
     # decode everything to tokens
     output_tokens: List[str] = [