Skip to content

Commit 2b91a31

Browse files
author
Your Name
committed
增加grpo多次工具调用训练
1 parent 6e982d7 commit 2b91a31

File tree

9 files changed

+2025
-70
lines changed

9 files changed

+2025
-70
lines changed

gen_data.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
Dataset Generator for Custom Mathematical Operations
3+
This module generates a dataset of custom mathematical expressions and their results.
4+
"""
5+
6+
import random
7+
import json
8+
import re
9+
from typing import Dict, List
10+
from math_tool import parse_expression, SYMBOL_TO_OPERATION, OPERATION_DEFINITIONS
11+
12+
def generate_safe_expression():
13+
"""Generate an expression that won't cause overflow errors"""
14+
# Start with a moderate number
15+
expression = str(random.randint(1, 20))
16+
17+
# Add 3-5 operations with safe numbers
18+
num_ops = random.randint(1, 6)
19+
20+
for i in range(num_ops):
21+
# Choose operation
22+
op = random.choice(['@', '&', '$', '^'])
23+
24+
# For @ operation, use smaller numbers to avoid overflow
25+
if op == '@':
26+
# For exponentiation, keep the exponent small
27+
num = random.randint(1, 3)
28+
else:
29+
num = random.randint(1, 10)
30+
31+
expression += op + str(num)
32+
33+
return expression
34+
35+
def generate_dataset(num_samples: int = 1000, output_file: str = "math_operations_dataset.jsonl") -> None:
36+
"""
37+
Generates a dataset of custom mathematical expressions and their results.
38+
Saves the dataset as a JSONL file.
39+
40+
Args:
41+
num_samples: Number of samples to generate
42+
output_file: Path to save the JSONL file
43+
"""
44+
with open(output_file, 'w') as f:
45+
for _ in range(num_samples):
46+
# Generate a safe expression
47+
expression = generate_safe_expression()
48+
49+
# Calculate the result
50+
try:
51+
result = parse_expression(expression)
52+
53+
# Create the data entry
54+
# data_entry = {
55+
# "query": f"Calculate the result of the expression: {expression}",
56+
# "answer": result
57+
# }
58+
data_entry = {"messages": [{"role": "user", "content": f"Calculate the result of the expression: {expression}"}],"response":result}
59+
60+
# Write to JSONL file
61+
f.write(json.dumps(data_entry) + '\n')
62+
except Exception as e:
63+
print(f"Skipping problematic expression {expression}: {e}")
64+
continue
65+
66+
print(f"Generated dataset with {num_samples} samples and saved to {output_file}")
67+
68+
generate_dataset()

math_operations_dataset.jsonl

+1,000
Large diffs are not rendered by default.

math_tool.py

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""
2+
Custom Mathematical Operations Tool Module
3+
This module provides a unified tool call interface for various mathematical operations.
4+
"""
5+
6+
import json
7+
import re
8+
import math
9+
from typing import Tuple, Optional, Any, Dict
10+
11+
# Define the custom mathematical operations with overflow protection
12+
def operation_at(a: int, b: int) -> float:
13+
"""@ operation: Returns a raised to the power of b, then adds a*b
14+
15+
Includes overflow protection for large numbers
16+
"""
17+
try:
18+
# For very large exponents, use log approximation
19+
if b > 100 or (a > 20 and b > 10):
20+
# Fall back to a simpler calculation for very large values
21+
return a * b * 2 # Simplified approximation
22+
power_result = a ** b
23+
product_result = a * b
24+
return power_result + product_result
25+
except OverflowError:
26+
# If overflow occurs, return a simplified approximation
27+
return a * b * 2 # Simplified approximation
28+
29+
def operation_amp(a: int, b: int) -> float:
30+
"""& operation: Returns the average of a and b, multiplied by their absolute difference"""
31+
avg = (a + b) / 2
32+
diff = abs(a - b)
33+
return avg * diff
34+
35+
def operation_dollar(a: int, b: int) -> float:
36+
"""$ operation: Returns a factorial-like sum of a repeated b times: a + (a-1) + (a-2) + ... + (a-b+1)"""
37+
if b <= 0 or b > a:
38+
return a
39+
40+
# For large values, use arithmetic sequence sum formula
41+
if b > 1000:
42+
# Sum of arithmetic sequence: n/2 * (first + last)
43+
n = min(b, a)
44+
first = a
45+
last = a - n + 1
46+
return n * (first + last) / 2
47+
48+
return sum(a - i for i in range(int(min(b, a))))
49+
50+
def operation_caret(a: int, b: int) -> float:
51+
"""^ operation: Returns a * b if both are even, a + b if both are odd, a - b otherwise"""
52+
if a % 2 == 0 and b % 2 == 0:
53+
return a * b
54+
elif a % 2 == 1 and b % 2 == 1:
55+
return a + b
56+
else:
57+
return a - b
58+
59+
class TOOL_CALL:
60+
def __call__(self, completion: str) -> Tuple[Any, bool, Optional[float]]:
61+
raise NotImplementedError
62+
63+
class MathOperation_Tool(TOOL_CALL):
64+
"""Unified tool for handling all mathematical operations"""
65+
66+
def __init__(self):
67+
self.operations = {
68+
"at_operation": operation_at,
69+
"amp_operation": operation_amp,
70+
"dollar_operation": operation_dollar,
71+
"caret_operation": operation_caret
72+
}
73+
74+
def __call__(self, completion: str) -> Tuple[float, bool, float]:
75+
try:
76+
# Check for required strict format
77+
pattern = r'^<think>(.*?)</think>\n<tool>(.*?)</tool>$'
78+
match = re.match(pattern, completion.strip(), re.DOTALL)
79+
80+
if not match:
81+
return "", True, 0
82+
83+
tool_content = match.group(2).strip()
84+
85+
# Parse JSON from tool content
86+
try:
87+
tool_data = json.loads(tool_content)
88+
except json.JSONDecodeError:
89+
return "", True, 0
90+
91+
# Check if JSON has required fields
92+
if not isinstance(tool_data, dict) or "tool" not in tool_data or "a" not in tool_data or "b" not in tool_data:
93+
return "", True, 0
94+
95+
tool_name = tool_data["tool"]
96+
97+
# Check if the requested operation exists
98+
if tool_name not in self.operations:
99+
return "", True, 0
100+
101+
# Get the operation function
102+
operation_func = self.operations[tool_name]
103+
104+
# Execute operation
105+
try:
106+
a, b = float(tool_data["a"]), float(tool_data["b"])
107+
result = operation_func(a, b)
108+
return f"<result>\n{result}\n</reuslt>", False, 0.2
109+
except (ValueError, TypeError):
110+
return "", True, 0
111+
112+
except Exception as e:
113+
print(f"Error in MathOperation_Tool: {e}")
114+
return "", True, 0
115+
116+
# Parser for expressions with overflow protection
117+
def parse_expression(expression: str) -> float:
118+
"""
119+
Parses and evaluates a custom mathematical expression.
120+
Supports operations: @, &, $, ^
121+
Example: "11@2&1$44^2"
122+
123+
Includes overflow protection for large numbers
124+
"""
125+
# Tokenize the expression - find all numbers and operators
126+
tokens = re.findall(r'(\d+|\@|\&|\$|\^)', expression)
127+
128+
# Process tokens
129+
result = None
130+
current_op = None
131+
132+
for token in tokens:
133+
if token in ['@', '&', '$', '^']:
134+
current_op = token
135+
else:
136+
try:
137+
num = int(token)
138+
if result is None:
139+
result = num
140+
elif current_op == '@':
141+
# Limit very large inputs for @ operation
142+
if result > 10000 or num > 100:
143+
result = result * num * 2 # Simplified approximation
144+
else:
145+
result = operation_at(result, num)
146+
elif current_op == '&':
147+
result = operation_amp(result, num)
148+
elif current_op == '$':
149+
result = operation_dollar(result, num)
150+
elif current_op == '^':
151+
result = operation_caret(result, num)
152+
except (OverflowError, ValueError):
153+
# Handle overflow by using a simplified calculation
154+
if current_op == '@':
155+
result = result * num * 2 # Simplified approximation
156+
elif current_op == '&':
157+
result = result * num # Simplified approximation
158+
elif current_op == '$':
159+
result = result + num # Simplified approximation
160+
elif current_op == '^':
161+
result = max(result, num) # Simplified approximation
162+
163+
return result
164+
165+
# Map symbols to operation names
166+
SYMBOL_TO_OPERATION = {
167+
'@': 'at_operation',
168+
'&': 'amp_operation',
169+
'$': 'dollar_operation',
170+
'^': 'caret_operation'
171+
}
172+
173+
# Operation definitions for reference
174+
OPERATION_DEFINITIONS = {
175+
"@": "a@b = (a^b) + (a*b)",
176+
"&": "a&b = ((a+b)/2) * |a-b|",
177+
"$": "a$b = a + (a-1) + (a-2) + ... + (a-b+1)",
178+
"^": "a^b = a*b if both even, a+b if both odd, a-b otherwise"
179+
}

swift/plugin/orm.py

+60
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,65 @@ def __call__(self, completions, **kwargs) -> List[float]:
286286
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]
287287
return [1.0 if match else 0.0 for match in matches]
288288

289+
import re
290+
from typing import List
291+
292+
class SimpleReward(ORM):
293+
294+
def __call__(self, completions, **kwargs) -> List[float]:
295+
"""
296+
Reward function that checks if the completion has a specific format
297+
and compares the extracted answer with the expected answer.
298+
299+
Args:
300+
completions: List of completion strings to evaluate
301+
kwargs: Additional arguments, should include 'answer' key with expected float answer
302+
303+
Returns:
304+
List of scores: 0.0 for incorrect format, 0.2 for correct format but wrong answer,
305+
1.2 for correct format and correct answer
306+
"""
307+
# Format pattern to match <think>...</think><answer>...</answer>
308+
format_pattern = r'^<think>(.*?)</think>\n<answer>(.*?)</answer>$'
309+
scores = []
310+
expected_answer = kwargs.get("response", [0])[0] # Default to 0.0 if not provided
311+
312+
for content in completions:
313+
# Check if the format matches
314+
match = re.match(format_pattern, content, re.DOTALL | re.MULTILINE)
315+
316+
if not match:
317+
# Incorrect format
318+
scores.append(0.0)
319+
else:
320+
# Extract the answer
321+
extracted_answer_str = match.group(1).strip()
322+
323+
# Try to convert the extracted answer to float
324+
try:
325+
# Extract numeric part from the answer
326+
# This regex finds all numbers (including decimals) in the string
327+
number_matches = re.findall(r'-?\d+\.?\d*', extracted_answer_str)
328+
329+
if number_matches:
330+
# Take the first number found in the answer
331+
extracted_answer = float(number_matches[0])
332+
333+
# Compare with expected answer with some tolerance for floating point
334+
if abs(extracted_answer - expected_answer) < 1e-6:
335+
# Correct format and correct answer
336+
scores.append(1.2)
337+
else:
338+
# Correct format but wrong answer
339+
scores.append(0.2)
340+
else:
341+
# No numeric value found in the answer
342+
scores.append(0.2)
343+
except (ValueError, TypeError):
344+
# Failed to convert to float
345+
scores.append(0.2)
346+
347+
return scores
289348

290349
class ReActFormat(ORM):
291350

@@ -384,4 +443,5 @@ def __call__(self, completions, **kwargs) -> List[float]:
384443
'react_format': ReActFormat,
385444
'cosine': CosineReward,
386445
'repetition': RepetitionPenalty,
446+
'simplereward':SimpleReward
387447
}

0 commit comments

Comments
 (0)