|
4 | 4 | import pathlib
|
5 | 5 | from asyncio import CancelledError
|
6 | 6 | from copy import deepcopy
|
7 |
| -from typing import List, Optional |
| 7 | +from typing import List, Optional, Type |
8 | 8 | import json
|
9 | 9 |
|
10 | 10 | from fastapi import HTTPException, Request
|
11 | 11 | from jinja2 import TemplateError
|
12 | 12 | from loguru import logger
|
| 13 | +from pydantic import BaseModel, create_model |
13 | 14 |
|
14 | 15 | from common import model
|
15 | 16 | from common.networking import (
|
@@ -433,6 +434,7 @@ async def generate_tool_calls(
|
433 | 434 |
|
434 | 435 | # Copy to make sure the parent JSON schema doesn't get modified
|
435 | 436 | # FIXME: May not be necessary depending on how the codebase evolves
|
| 437 | + create_tool_call_model(data) |
436 | 438 | tool_data = deepcopy(data)
|
437 | 439 | tool_data.json_schema = tool_data.tool_call_schema
|
438 | 440 | gen_params = tool_data.to_gen_params()
|
@@ -465,6 +467,45 @@ async def generate_tool_calls(
|
465 | 467 |
|
466 | 468 | return generations
|
467 | 469 |
|
| 470 | +def create_tool_call_model(data: ChatCompletionRequest): |
| 471 | + """Create a tool call model to guide model based on the tools spec provided""" |
| 472 | + dtypes = { |
| 473 | + "integer": int, |
| 474 | + "string": str, |
| 475 | + "boolean": bool, |
| 476 | + "object": dict, |
| 477 | + "array": list |
| 478 | + } |
| 479 | + |
| 480 | + tool_response_models = [] |
| 481 | + for tool in data.tools: |
| 482 | + |
| 483 | + name = tool.function.name |
| 484 | + params = tool.function.parameters.get('properties', {}) |
| 485 | + required_params = tool.function.parameters.get('required', []) |
| 486 | + |
| 487 | + model_fields = {} |
| 488 | + if params: |
| 489 | + for arg_key, arg_val in params.items(): |
| 490 | + arg_name = arg_key |
| 491 | + arg_dtype = dtypes[arg_val['type']] |
| 492 | + required = arg_name in required_params |
| 493 | + model_fields["name"] = name # this need to be a string with a strict value of name |
| 494 | + model_fields["arguments"] = {} |
| 495 | + |
| 496 | + # Use Field to set whether the argument is required or not |
| 497 | + if required: |
| 498 | + model_fields["arguments"][arg_name] = (arg_dtype, ...) |
| 499 | + else: |
| 500 | + model_fields["arguments"][arg_name] = (arg_dtype, None) |
| 501 | + |
| 502 | + # Create the Pydantic model for the tool |
| 503 | + tool_response_model = create_model(name, **model_fields) |
| 504 | + tool_response_models.append(tool_response_model) |
| 505 | + |
| 506 | + print(tool_response_models) # these tool_response_model will go into the tool_call as a union of them, need to format correctly |
| 507 | + |
| 508 | + |
468 | 509 |
|
469 | 510 | def postprocess_tool_call(call_str: str) -> List[ToolCall]:
|
470 | 511 | tool_calls = json.loads(call_str)
|
|
0 commit comments