Skip to content

Commit f903054

Browse files
committed
Inital supertools (TM) Moving all of the tool related support functions to their own modules
1 parent b3caf7b commit f903054

File tree

2 files changed

+212
-52
lines changed

2 files changed

+212
-52
lines changed

Diff for: endpoints/OAI/utils/chat_completion.py

+62-52
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import pathlib
55
from asyncio import CancelledError
66
from copy import deepcopy
7-
from typing import List, Optional, Type
7+
from typing import List, Optional
88
import json
99

1010
from fastapi import HTTPException, Request
1111
from jinja2 import TemplateError
1212
from loguru import logger
13-
from pydantic import BaseModel, create_model
1413

1514
from common import model
1615
from common.networking import (
@@ -32,7 +31,11 @@
3231
)
3332
from endpoints.OAI.types.common import UsageStats
3433
from endpoints.OAI.utils.completion import _stream_collector
35-
from endpoints.OAI.types.tools import ToolCall
34+
35+
from endpoints.OAI.utils.tools import(
36+
postprocess_tool_call,
37+
generate_strict_schemas
38+
)
3639

3740

3841
def _create_response(
@@ -434,9 +437,12 @@ async def generate_tool_calls(
434437

435438
# Copy to make sure the parent JSON schema doesn't get modified
436439
# FIXME: May not be necessary depending on how the codebase evolves
437-
create_tool_call_model(data)
440+
if data.tools:
441+
strict_schema = generate_strict_schemas(data)
442+
print(strict_schema)
438443
tool_data = deepcopy(data)
439-
tool_data.json_schema = tool_data.tool_call_schema
444+
#tool_data.json_schema = tool_data.tool_call_schema
445+
tool_data.json_schema = strict_schema # needs strict flag
440446
gen_params = tool_data.to_gen_params()
441447

442448
for idx, gen in enumerate(generations):
@@ -467,50 +473,54 @@ async def generate_tool_calls(
467473

468474
return generations
469475

470-
def create_tool_call_model(data: ChatCompletionRequest):
471-
"""Create a tool call model to guide model based on the tools spec provided"""
472-
dtypes = {
473-
"integer": int,
474-
"string": str,
475-
"boolean": bool,
476-
"object": dict,
477-
"array": list
478-
}
479-
480-
tool_response_models = []
481-
for tool in data.tools:
482-
483-
name = tool.function.name
484-
params = tool.function.parameters.get('properties', {})
485-
required_params = tool.function.parameters.get('required', [])
486-
487-
model_fields = {}
488-
if params:
489-
for arg_key, arg_val in params.items():
490-
arg_name = arg_key
491-
arg_dtype = dtypes[arg_val['type']]
492-
required = arg_name in required_params
493-
model_fields["name"] = name # this need to be a string with a strict value of name
494-
model_fields["arguments"] = {}
495-
496-
# Use Field to set whether the argument is required or not
497-
if required:
498-
model_fields["arguments"][arg_name] = (arg_dtype, ...)
499-
else:
500-
model_fields["arguments"][arg_name] = (arg_dtype, None)
501-
502-
# Create the Pydantic model for the tool
503-
tool_response_model = create_model(name, **model_fields)
504-
tool_response_models.append(tool_response_model)
505-
506-
print(tool_response_models) # these tool_response_model will go into the tool_call as a union of them, need to format correctly
507-
508-
509-
510-
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
511-
tool_calls = json.loads(call_str)
512-
for tool_call in tool_calls:
513-
tool_call["function"]["arguments"] = json.dumps(
514-
tool_call["function"]["arguments"]
515-
)
516-
return [ToolCall(**tool_call) for tool_call in tool_calls]
476+
# def create_tool_call_model(data: ChatCompletionRequest):
477+
# """Create a tool call model to guide model based on the tools spec provided"""
478+
# dtypes = {
479+
# "integer": int,
480+
# "string": str,
481+
# "boolean": bool,
482+
# "object": dict,
483+
# "array": list
484+
# }
485+
486+
# function_models = []
487+
# for tool in data.tools:
488+
489+
# tool_name = tool.function.name
490+
# raw_params = tool.function.parameters.get('properties', {})
491+
# required_params = tool.function.parameters.get('required', [])
492+
493+
# fields = {}
494+
# if raw_params:
495+
# for arg_key, val_dict in raw_params.items():
496+
497+
# arg_name = arg_key
498+
# arg_dtype = dtypes[val_dict['type']]
499+
# required = arg_name in required_params
500+
# fields[arg_name] = (arg_dtype, ... if required else None)
501+
# if not required:
502+
# arg_dtype = Optional[arg_dtype]
503+
504+
# fields[arg_name] = (arg_dtype, ... if required else None)
505+
506+
# arguments_model = create_model(f"{tool_name}Arguments", **fields)
507+
508+
# function_model = create_model(
509+
# f"{tool_name}Model",
510+
# name=(str, tool_name),
511+
# arguments=(arguments_model, ...)
512+
# )
513+
514+
# function_models.append(function_model)
515+
516+
# fucntion_union = Union[tuple(function_models)]
517+
518+
# tool_response_model = create_model(
519+
# "tools_call_response_model",
520+
# id=(str, ...),
521+
# function=(fucntion_union, ...)
522+
# )
523+
524+
# tool_response_model.model_rebuild()
525+
526+
# return tool_response_model

Diff for: endpoints/OAI/utils/tools.py

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""Support functions to enable tool calling"""
2+
3+
from typing import List, Dict
4+
from copy import deepcopy
5+
import json
6+
7+
from endpoints.OAI.types.tools import ToolCall
8+
from endpoints.OAI.types.chat_completion import ChatCompletionRequest
9+
10+
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
11+
print(call_str)
12+
tool_calls = json.loads(call_str)
13+
for tool_call in tool_calls:
14+
tool_call["function"]["arguments"] = json.dumps(
15+
tool_call["function"]["arguments"]
16+
)
17+
return [ToolCall(**tool_call) for tool_call in tool_calls]
18+
19+
20+
def generate_strict_schemas(data: ChatCompletionRequest):
21+
base_schema = {
22+
"$defs": {},
23+
"properties": {
24+
"id": {"title": "Id", "type": "string"},
25+
"function": {"title": "Function"},
26+
"type": {"$ref": "#/$defs/Type"}
27+
},
28+
"required": ["id", "function", "type"],
29+
"title": "ModelItem",
30+
"type": "object"
31+
}
32+
33+
function_schemas = []
34+
argument_schemas = {}
35+
36+
for i, tool in enumerate(data.tools):
37+
function_name = f"Function{i+1}" if i > 0 else "Function"
38+
argument_name = f"Arguments{i+1}" if i > 0 else "Arguments"
39+
name_def = f"Name{i+1}" if i > 0 else "Name"
40+
41+
# Create Name definition
42+
base_schema["$defs"][name_def] = {
43+
"const": tool.function.name,
44+
"enum": [tool.function.name],
45+
"title": name_def,
46+
"type": "string"
47+
}
48+
49+
# Create Arguments definition
50+
arg_properties = {}
51+
required_params = tool.function.parameters.get('required', [])
52+
for arg_name, arg_info in tool.function.parameters.get('properties', {}).items():
53+
arg_properties[arg_name] = {
54+
"title": arg_name.capitalize(),
55+
"type": arg_info['type']
56+
}
57+
58+
argument_schemas[argument_name] = {
59+
"properties": arg_properties,
60+
"required": required_params,
61+
"title": argument_name,
62+
"type": "object"
63+
}
64+
65+
# Create Function definition
66+
function_schema = {
67+
"properties": {
68+
"name": {"$ref": f"#/$defs/{name_def}"},
69+
"arguments": {"$ref": f"#/$defs/{argument_name}"}
70+
},
71+
"required": ["name", "arguments"],
72+
"title": function_name,
73+
"type": "object"
74+
}
75+
76+
function_schemas.append({"$ref": f"#/$defs/{function_name}"})
77+
base_schema["$defs"][function_name] = function_schema
78+
79+
# Add argument schemas to $defs
80+
base_schema["$defs"].update(argument_schemas)
81+
82+
# Add Type definition
83+
base_schema["$defs"]["Type"] = {
84+
"const": "function",
85+
"enum": ["function"],
86+
"title": "Type",
87+
"type": "string"
88+
}
89+
90+
# Set up the function property
91+
base_schema["properties"]["function"]["anyOf"] = function_schemas
92+
93+
return base_schema
94+
95+
96+
# def generate_strict_schemas(data: ChatCompletionRequest):
97+
# schema = {
98+
# "type": "object",
99+
# "properties": {
100+
# "name": {"type": "string"},
101+
# "arguments": {
102+
# "type": "object",
103+
# "properties": {},
104+
# "required": []
105+
# }
106+
# },
107+
# "required": ["name", "arguments"]
108+
# }
109+
110+
# function_schemas = []
111+
112+
# for tool in data.tools:
113+
# func_schema = deepcopy(schema)
114+
# func_schema["properties"]["name"]["enum"] = [tool.function.name]
115+
# raw_params = tool.function.parameters.get('properties', {})
116+
# required_params = tool.function.parameters.get('required', [])
117+
118+
# # Add argument properties and required fields
119+
# arg_properties = {}
120+
# for arg_name, arg_type in raw_params.items():
121+
# arg_properties[arg_name] = {"type": arg_type['type']}
122+
123+
# func_schema["properties"]["arguments"]["properties"] = arg_properties
124+
# func_schema["properties"]["arguments"]["required"] = required_params
125+
126+
# function_schemas.append(func_schema)
127+
128+
# return _create_full_schema(function_schemas)
129+
130+
# def _create_full_schema(function_schemas: List) -> Dict:
131+
# # Define the master schema structure with placeholders for function schemas
132+
# tool_call_schema = {
133+
# "$schema": "http://json-schema.org/draft-07/schema#",
134+
# "type": "array",
135+
# "items": {
136+
# "type": "object",
137+
# "properties": {
138+
# "id": {"type": "string"},
139+
# "function": {
140+
# "type": "object", # Add this line
141+
# "oneOf": function_schemas
142+
# },
143+
# "type": {"type": "string", "enum": ["function"]}
144+
# },
145+
# "required": ["id", "function", "type"]
146+
# }
147+
# }
148+
149+
# print(json.dumps(tool_call_schema, indent=2))
150+
# return tool_call_schema

0 commit comments

Comments
 (0)