4
4
import pathlib
5
5
from asyncio import CancelledError
6
6
from copy import deepcopy
7
- from typing import List , Optional , Type
7
+ from typing import List , Optional
8
8
import json
9
9
10
10
from fastapi import HTTPException , Request
11
11
from jinja2 import TemplateError
12
12
from loguru import logger
13
- from pydantic import BaseModel , create_model
14
13
15
14
from common import model
16
15
from common .networking import (
32
31
)
33
32
from endpoints .OAI .types .common import UsageStats
34
33
from endpoints .OAI .utils .completion import _stream_collector
35
- from endpoints .OAI .types .tools import ToolCall
34
+
35
+ from endpoints .OAI .utils .tools import (
36
+ postprocess_tool_call ,
37
+ generate_strict_schemas
38
+ )
36
39
37
40
38
41
def _create_response (
@@ -434,9 +437,12 @@ async def generate_tool_calls(
434
437
435
438
# Copy to make sure the parent JSON schema doesn't get modified
436
439
# FIXME: May not be necessary depending on how the codebase evolves
437
- create_tool_call_model (data )
440
+ if data .tools :
441
+ strict_schema = generate_strict_schemas (data )
442
+ print (strict_schema )
438
443
tool_data = deepcopy (data )
439
- tool_data .json_schema = tool_data .tool_call_schema
444
+ #tool_data.json_schema = tool_data.tool_call_schema
445
+ tool_data .json_schema = strict_schema # needs strict flag
440
446
gen_params = tool_data .to_gen_params ()
441
447
442
448
for idx , gen in enumerate (generations ):
@@ -467,50 +473,54 @@ async def generate_tool_calls(
467
473
468
474
return generations
469
475
470
- def create_tool_call_model (data : ChatCompletionRequest ):
471
- """Create a tool call model to guide model based on the tools spec provided"""
472
- dtypes = {
473
- "integer" : int ,
474
- "string" : str ,
475
- "boolean" : bool ,
476
- "object" : dict ,
477
- "array" : list
478
- }
479
-
480
- tool_response_models = []
481
- for tool in data .tools :
482
-
483
- name = tool .function .name
484
- params = tool .function .parameters .get ('properties' , {})
485
- required_params = tool .function .parameters .get ('required' , [])
486
-
487
- model_fields = {}
488
- if params :
489
- for arg_key , arg_val in params .items ():
490
- arg_name = arg_key
491
- arg_dtype = dtypes [arg_val ['type' ]]
492
- required = arg_name in required_params
493
- model_fields ["name" ] = name # this need to be a string with a strict value of name
494
- model_fields ["arguments" ] = {}
495
-
496
- # Use Field to set whether the argument is required or not
497
- if required :
498
- model_fields ["arguments" ][arg_name ] = (arg_dtype , ...)
499
- else :
500
- model_fields ["arguments" ][arg_name ] = (arg_dtype , None )
501
-
502
- # Create the Pydantic model for the tool
503
- tool_response_model = create_model (name , ** model_fields )
504
- tool_response_models .append (tool_response_model )
505
-
506
- print (tool_response_models ) # these tool_response_model will go into the tool_call as a union of them, need to format correctly
507
-
508
-
509
-
510
- def postprocess_tool_call (call_str : str ) -> List [ToolCall ]:
511
- tool_calls = json .loads (call_str )
512
- for tool_call in tool_calls :
513
- tool_call ["function" ]["arguments" ] = json .dumps (
514
- tool_call ["function" ]["arguments" ]
515
- )
516
- return [ToolCall (** tool_call ) for tool_call in tool_calls ]
476
+ # def create_tool_call_model(data: ChatCompletionRequest):
477
+ # """Create a tool call model to guide model based on the tools spec provided"""
478
+ # dtypes = {
479
+ # "integer": int,
480
+ # "string": str,
481
+ # "boolean": bool,
482
+ # "object": dict,
483
+ # "array": list
484
+ # }
485
+
486
+ # function_models = []
487
+ # for tool in data.tools:
488
+
489
+ # tool_name = tool.function.name
490
+ # raw_params = tool.function.parameters.get('properties', {})
491
+ # required_params = tool.function.parameters.get('required', [])
492
+
493
+ # fields = {}
494
+ # if raw_params:
495
+ # for arg_key, val_dict in raw_params.items():
496
+
497
+ # arg_name = arg_key
498
+ # arg_dtype = dtypes[val_dict['type']]
499
+ # required = arg_name in required_params
500
+ # fields[arg_name] = (arg_dtype, ... if required else None)
501
+ # if not required:
502
+ # arg_dtype = Optional[arg_dtype]
503
+
504
+ # fields[arg_name] = (arg_dtype, ... if required else None)
505
+
506
+ # arguments_model = create_model(f"{tool_name}Arguments", **fields)
507
+
508
+ # function_model = create_model(
509
+ # f"{tool_name}Model",
510
+ # name=(str, tool_name),
511
+ # arguments=(arguments_model, ...)
512
+ # )
513
+
514
+ # function_models.append(function_model)
515
+
516
+ # fucntion_union = Union[tuple(function_models)]
517
+
518
+ # tool_response_model = create_model(
519
+ # "tools_call_response_model",
520
+ # id=(str, ...),
521
+ # function=(fucntion_union, ...)
522
+ # )
523
+
524
+ # tool_response_model.model_rebuild()
525
+
526
+ # return tool_response_model
0 commit comments