Skip to content

Commit 08028c9

Browse files
committed
experimentation with creating tools specific schemas
1 parent e75d764 commit 08028c9

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

endpoints/OAI/utils/chat_completion.py

+42-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import pathlib
55
from asyncio import CancelledError
66
from copy import deepcopy
7-
from typing import List, Optional
7+
from typing import List, Optional, Type
88
import json
99

1010
from fastapi import HTTPException, Request
1111
from jinja2 import TemplateError
1212
from loguru import logger
13+
from pydantic import BaseModel, create_model
1314

1415
from common import model
1516
from common.networking import (
@@ -433,6 +434,7 @@ async def generate_tool_calls(
433434

434435
# Copy to make sure the parent JSON schema doesn't get modified
435436
# FIXME: May not be necessary depending on how the codebase evolves
437+
create_tool_call_model(data)
436438
tool_data = deepcopy(data)
437439
tool_data.json_schema = tool_data.tool_call_schema
438440
gen_params = tool_data.to_gen_params()
@@ -465,6 +467,45 @@ async def generate_tool_calls(
465467

466468
return generations
467469

470+
def create_tool_call_model(data: ChatCompletionRequest):
471+
"""Create a tool call model to guide model based on the tools spec provided"""
472+
dtypes = {
473+
"integer": int,
474+
"string": str,
475+
"boolean": bool,
476+
"object": dict,
477+
"array": list
478+
}
479+
480+
tool_response_models = []
481+
for tool in data.tools:
482+
483+
name = tool.function.name
484+
params = tool.function.parameters.get('properties', {})
485+
required_params = tool.function.parameters.get('required', [])
486+
487+
model_fields = {}
488+
if params:
489+
for arg_key, arg_val in params.items():
490+
arg_name = arg_key
491+
arg_dtype = dtypes[arg_val['type']]
492+
required = arg_name in required_params
493+
model_fields["name"] = name # this need to be a string with a strict value of name
494+
model_fields["arguments"] = {}
495+
496+
# Use Field to set whether the argument is required or not
497+
if required:
498+
model_fields["arguments"][arg_name] = (arg_dtype, ...)
499+
else:
500+
model_fields["arguments"][arg_name] = (arg_dtype, None)
501+
502+
# Create the Pydantic model for the tool
503+
tool_response_model = create_model(name, **model_fields)
504+
tool_response_models.append(tool_response_model)
505+
506+
print(tool_response_models) # these tool_response_model will go into the tool_call as a union of them, need to format correctly
507+
508+
468509

469510
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
470511
tool_calls = json.loads(call_str)

0 commit comments

Comments
 (0)