|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 |
|
3 |
| -import re |
4 | 3 | from collections.abc import Sequence
|
5 | 4 | from typing import Optional, Union
|
6 | 5 |
|
@@ -31,9 +30,6 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
31 | 30 | self.think_start_token = "<think>"
|
32 | 31 | self.think_end_token = "</think>"
|
33 | 32 |
|
34 |
| - self.reasoning_regex = re.compile( |
35 |
| - rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL) |
36 |
| - |
37 | 33 | if not self.model_tokenizer:
|
38 | 34 | raise ValueError(
|
39 | 35 | "The model tokenizer must be passed to the ReasoningParser "
|
@@ -121,29 +117,34 @@ def extract_reasoning_content_streaming(
|
def extract_reasoning_content(
        self, model_output: str, request: "ChatCompletionRequest"
) -> tuple[Optional[str], Optional[str]]:
    """
    Extract reasoning content from the model output.

    For text <think>abc</think>xyz:
    - 'abc' goes to reasoning_content
    - 'xyz' goes to content

    The <think> start token is treated as optional: DeepSeek-R1-style
    models may emit the reasoning immediately, terminated only by
    </think>, without ever producing the start token. Requiring the
    start token (as before) made the partition fallback below dead
    code and silently dropped the reasoning for such outputs.

    Returns:
        tuple[Optional[str], Optional[str]]: reasoning content and content
    """

    # Without a closing </think> there is no complete reasoning
    # section; return the model output untouched as plain content.
    if self.think_end_token not in model_output:
        return None, model_output

    # Drop the optional <think> start token when present; otherwise
    # keep the whole output (partition's [0] part) as the candidate.
    model_output_parts = model_output.partition(self.think_start_token)
    model_output = model_output_parts[2] if model_output_parts[
        1] else model_output_parts[0]

    # Everything before the first </think> is reasoning content;
    # whatever follows it is the user-visible answer.
    reasoning_content, _, content = model_output.partition(
        self.think_end_token)

    # Normalize an empty trailing answer to None.
    final_content = content or None
    return reasoning_content, final_content
0 commit comments