1
- from typing import Union , Tuple , Optional
2
-
1
+ from typing import Tuple ,Any , Optional
3
2
class TOOL_CALL :
4
-
5
- def __call__ (self , completion :str ) -> Tuple [str , bool , Optional [int ]]:
3
+ def __call__ (self , completion : str ) -> Tuple [Any , bool , Optional [float ]]:
6
4
raise NotImplementedError
7
5
8
6
9
- """
10
- Search module for RL training loop.
11
- This module provides functions to search through vectorized documents and retrieve question-answer pairs.
12
- """
13
-
14
- import json
15
- import re
16
- from typing import Tuple , Optional
17
- import traceback
18
-
19
- # Load the vectorstore when module is imported
20
- try :
21
- vectorstore = load_vectorstore ()
22
- if vectorstore is None :
23
- print ("Warning: FAISS vectorstore could not be loaded." )
24
- except Exception as e :
25
- print (f"Error loading vectorstore: { e } " )
26
- vectorstore = None
27
-
28
- def search (query : str , results : int = 5 ):
29
- """
30
- Search for relevant chunks using similarity search.
31
-
32
- Args:
33
- query: The search query
34
- return_type: Return as string or list (default: str)
35
- results: Number of results to return (default: 5)
36
-
37
- Returns:
38
- Results as string or list depending on return_type
39
- """
40
- if vectorstore is None :
41
- raise ValueError ("Vectorstore not loaded. Please ensure FAISS index exists." )
42
-
43
- search_results = vectorstore .similarity_search (query , k = results )
7
+ tools = {
44
8
45
- result_dict = {}
46
- for idx , result in enumerate (search_results , start = 1 ):
47
- result_dict [idx ] = result .page_content
48
-
49
- result_json = json .dumps (result_dict ,indent = 2 ,ensure_ascii = False )
50
- return f"<result>\n { result_json } \n </result>"
51
-
52
- class TOOL_CALL :
53
- def __call__ (self , completion : str ) -> Tuple [str , bool , Optional [float ]]:
54
- raise NotImplementedError
55
-
56
- class Search_Tool (TOOL_CALL ):
57
- def __call__ (self , completion : str ) -> Tuple [str , bool , Optional [float ]]:
58
- """
59
- Checks if the completion strictly follows the format <think>xxx</think><tool_call>xxx</tool_call>
60
- and if the tool_call contains valid JSON with "tool" and "arg" fields.
61
-
62
- Args:
63
- completion: The text completion to check
64
-
65
- Returns:
66
- Tuple containing:
67
- - search result or empty string
68
- - boolean indicating if there was an error
69
- - score (0.2 if successful, 0 if error)
70
- """
71
- try :
72
- # Check for required strict format using regex
73
- pattern = r'^<think>(.*?)</think><tool_call>(.*?)</tool_call>$'
74
- match = re .match (pattern , completion .strip (), re .DOTALL )
75
-
76
- if not match :
77
- return "" , True , 0
78
-
79
- tool_content = match .group (2 ).strip ()
80
-
81
- # Parse JSON from tool_call content
82
- try :
83
- tool_data = json .loads (tool_content )
84
- except json .JSONDecodeError :
85
- return "" , True , 0
86
-
87
- # Check if JSON has required fields
88
- if not isinstance (tool_data , dict ) or "tool" not in tool_data or "arg" not in tool_data :
89
- return "" , True , 0
90
-
91
- # Check if the tool is "search"
92
- if tool_data ["tool" ] != "search" :
93
- return "" , True , 0
94
-
95
- # Execute search with the provided argument
96
- search_result = search (tool_data ["arg" ])
97
- return search_result , False , 0.2
98
-
99
- except Exception as e :
100
- print (f"Error in Search_Tool: { e } " )
101
- traceback .print_exc ()
102
- return "" , True , 0
9
+ }
0 commit comments