Skip to content

Commit 472926d

Browse files
author
Your Name
committed
重构代码
1 parent 2b91a31 commit 472926d

File tree

8 files changed

+17
-1508
lines changed

8 files changed

+17
-1508
lines changed

math_operations_dataset.jsonl

-1,000
This file was deleted.

math_tool.py

-179
This file was deleted.

swift/plugin/orm.py

+1-61
Original file line numberDiff line numberDiff line change
@@ -286,65 +286,6 @@ def __call__(self, completions, **kwargs) -> List[float]:
286286
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
287287
return [1.0 if match else 0.0 for match in matches]
288288

289-
import re
290-
from typing import List
291-
292-
class SimpleReward(ORM):
293-
294-
def __call__(self, completions, **kwargs) -> List[float]:
295-
"""
296-
Reward function that checks if the completion has a specific format
297-
and compares the extracted answer with the expected answer.
298-
299-
Args:
300-
completions: List of completion strings to evaluate
301-
kwargs: Additional arguments, should include 'answer' key with expected float answer
302-
303-
Returns:
304-
List of scores: 0.0 for incorrect format, 0.2 for correct format but wrong answer,
305-
1.2 for correct format and correct answer
306-
"""
307-
# Format pattern to match <think>...</think><answer>...</answer>
308-
format_pattern = r'^<think>(.*?)</think>\n<answer>(.*?)</answer>$'
309-
scores = []
310-
expected_answer = kwargs.get("response", [0])[0] # Default to 0.0 if not provided
311-
312-
for content in completions:
313-
# Check if the format matches
314-
match = re.match(format_pattern, content, re.DOTALL | re.MULTILINE)
315-
316-
if not match:
317-
# Incorrect format
318-
scores.append(0.0)
319-
else:
320-
# Extract the answer
321-
extracted_answer_str = match.group(1).strip()
322-
323-
# Try to convert the extracted answer to float
324-
try:
325-
# Extract numeric part from the answer
326-
# This regex finds all numbers (including decimals) in the string
327-
number_matches = re.findall(r'-?\d+\.?\d*', extracted_answer_str)
328-
329-
if number_matches:
330-
# Take the first number found in the answer
331-
extracted_answer = float(number_matches[0])
332-
333-
# Compare with expected answer with some tolerance for floating point
334-
if abs(extracted_answer - expected_answer) < 1e-6:
335-
# Correct format and correct answer
336-
scores.append(1.2)
337-
else:
338-
# Correct format but wrong answer
339-
scores.append(0.2)
340-
else:
341-
# No numeric value found in the answer
342-
scores.append(0.2)
343-
except (ValueError, TypeError):
344-
# Failed to convert to float
345-
scores.append(0.2)
346-
347-
return scores
348289

349290
class ReActFormat(ORM):
350291

@@ -442,6 +383,5 @@ def __call__(self, completions, **kwargs) -> List[float]:
442383
'format': Format,
443384
'react_format': ReActFormat,
444385
'cosine': CosineReward,
445-
'repetition': RepetitionPenalty,
446-
'simplereward':SimpleReward
386+
'repetition': RepetitionPenalty
447387
}

swift/plugin/tool_call.py

+4-97
Original file line numberDiff line numberDiff line change
@@ -1,102 +1,9 @@
1-
from typing import Union, Tuple, Optional
2-
1+
from typing import Tuple,Any, Optional
32
class TOOL_CALL:
4-
5-
def __call__(self, completion:str) -> Tuple[str, bool, Optional[int]]:
3+
def __call__(self, completion: str) -> Tuple[Any, bool, Optional[float]]:
64
raise NotImplementedError
75

86

9-
"""
10-
Search module for RL training loop.
11-
This module provides functions to search through vectorized documents and retrieve question-answer pairs.
12-
"""
13-
14-
import json
15-
import re
16-
from typing import Tuple, Optional
17-
import traceback
18-
19-
# Load the vectorstore when module is imported
20-
try:
21-
vectorstore = load_vectorstore()
22-
if vectorstore is None:
23-
print("Warning: FAISS vectorstore could not be loaded.")
24-
except Exception as e:
25-
print(f"Error loading vectorstore: {e}")
26-
vectorstore = None
27-
28-
def search(query: str, results: int = 5):
29-
"""
30-
Search for relevant chunks using similarity search.
31-
32-
Args:
33-
query: The search query
34-
return_type: Return as string or list (default: str)
35-
results: Number of results to return (default: 5)
36-
37-
Returns:
38-
Results as string or list depending on return_type
39-
"""
40-
if vectorstore is None:
41-
raise ValueError("Vectorstore not loaded. Please ensure FAISS index exists.")
42-
43-
search_results = vectorstore.similarity_search(query, k=results)
7+
tools = {
448

45-
result_dict = {}
46-
for idx, result in enumerate(search_results, start=1):
47-
result_dict[idx] = result.page_content
48-
49-
result_json = json.dumps(result_dict,indent=2,ensure_ascii=False)
50-
return f"<result>\n{result_json}\n</result>"
51-
52-
class TOOL_CALL:
53-
def __call__(self, completion: str) -> Tuple[str, bool, Optional[float]]:
54-
raise NotImplementedError
55-
56-
class Search_Tool(TOOL_CALL):
57-
def __call__(self, completion: str) -> Tuple[str, bool, Optional[float]]:
58-
"""
59-
Checks if the completion strictly follows the format <think>xxx</think><tool_call>xxx</tool_call>
60-
and if the tool_call contains valid JSON with "tool" and "arg" fields.
61-
62-
Args:
63-
completion: The text completion to check
64-
65-
Returns:
66-
Tuple containing:
67-
- search result or empty string
68-
- boolean indicating if there was an error
69-
- score (0.2 if successful, 0 if error)
70-
"""
71-
try:
72-
# Check for required strict format using regex
73-
pattern = r'^<think>(.*?)</think><tool_call>(.*?)</tool_call>$'
74-
match = re.match(pattern, completion.strip(), re.DOTALL)
75-
76-
if not match:
77-
return "", True, 0
78-
79-
tool_content = match.group(2).strip()
80-
81-
# Parse JSON from tool_call content
82-
try:
83-
tool_data = json.loads(tool_content)
84-
except json.JSONDecodeError:
85-
return "", True, 0
86-
87-
# Check if JSON has required fields
88-
if not isinstance(tool_data, dict) or "tool" not in tool_data or "arg" not in tool_data:
89-
return "", True, 0
90-
91-
# Check if the tool is "search"
92-
if tool_data["tool"] != "search":
93-
return "", True, 0
94-
95-
# Execute search with the provided argument
96-
search_result = search(tool_data["arg"])
97-
return search_result, False, 0.2
98-
99-
except Exception as e:
100-
print(f"Error in Search_Tool: {e}")
101-
traceback.print_exc()
102-
return "", True, 0
9+
}

swift/trainers/arguments.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from dataclasses import dataclass
44
from functools import wraps
5-
from typing import Any, Dict, Literal, Optional, Union
5+
from typing import Any, Dict, Literal, Optional, Union, Callable
66

77
import torch
88
import torch.utils.checkpoint
@@ -104,6 +104,9 @@ class GRPOArgumentsMixin:
104104
offload_optimizer: bool = False
105105
offload_model: bool = False
106106
gc_collect_after_offload: bool = False
107+
is_reward_tool_call:bool = True #是否额外单独计算每个tool call的format得分
108+
tool_call_weight:float = 1.0
109+
tool_call:str = None
107110

108111

109112
@dataclass

swift/trainers/rlhf_arguments.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import List, Optional
2+
from typing import List, Optional,Callable
33

44
from trl import CPOConfig as HfCPOConfig
55
from trl import DPOConfig as HfDPOConfig
@@ -45,7 +45,6 @@ class PPOConfig(SwiftArgumentsMixin, HfPPOConfig):
4545
@dataclass
4646
class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
4747
stop_words: List[str] = field(default_factory=list)
48-
is_reward_tool_call = True #是否额外单独计算每个tool call的format得分
4948

5049
def __post_init__(self):
5150
from swift.llm.argument.base_args.model_args import ModelArguments

0 commit comments

Comments
 (0)