1
+ # Copyright 2024 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
1
16
from __future__ import annotations
2
17
3
18
from collections .abc import Iterable , Mapping , Sequence
@@ -300,7 +315,12 @@ def to_contents(contents: ContentsType) -> list[glm.Content]:
300
315
return contents
301
316
302
317
303
- def _generate_schema (
318
+ def _schema_for_class (cls : TypedDict ) -> dict [str , Any ]:
319
+ schema = _build_schema ("dummy" , {"dummy" : (cls , pydantic .Field ())})
320
+ return schema ["properties" ]["dummy" ]
321
+
322
+
323
+ def _schema_for_function (
304
324
f : Callable [..., Any ],
305
325
* ,
306
326
descriptions : Mapping [str , str ] | None = None ,
@@ -323,52 +343,36 @@ def _generate_schema(
323
343
"""
324
344
if descriptions is None :
325
345
descriptions = {}
326
- if required is None :
327
- required = []
328
346
defaults = dict (inspect .signature (f ).parameters )
329
- fields_dict = {
330
- name : (
331
- # 1. We infer the argument type here: use Any rather than None so
332
- # it will not try to auto-infer the type based on the default value.
333
- (param .annotation if param .annotation != inspect .Parameter .empty else Any ),
334
- pydantic .Field (
335
- # 2. We do not support default values for now.
336
- # default=(
337
- # param.default if param.default != inspect.Parameter.empty
338
- # else None
339
- # ),
340
- # 3. We support user-provided descriptions.
341
- description = descriptions .get (name , None ),
342
- ),
343
- )
344
- for name , param in defaults .items ()
345
- # We do not support *args or **kwargs
346
- if param .kind
347
- in (
347
+
348
+ fields_dict = {}
349
+ for name , param in defaults .items ():
350
+ if param .kind in (
348
351
inspect .Parameter .POSITIONAL_OR_KEYWORD ,
349
352
inspect .Parameter .KEYWORD_ONLY ,
350
353
inspect .Parameter .POSITIONAL_ONLY ,
351
- )
352
- }
353
- parameters = pydantic .create_model (f .__name__ , ** fields_dict ).schema ()
354
- # Postprocessing
355
- # 4. Suppress unnecessary title generation:
356
- # * https://github.com/pydantic/pydantic/issues/1051
357
- # * http://cl/586221780
358
- parameters .pop ("title" , None )
359
- for name , function_arg in parameters .get ("properties" , {}).items ():
360
- function_arg .pop ("title" , None )
361
- annotation = defaults [name ].annotation
362
- # 5. Nullable fields:
363
- # * https://github.com/pydantic/pydantic/issues/1270
364
- # * https://stackoverflow.com/a/58841311
365
- # * https://github.com/pydantic/pydantic/discussions/4872
366
- if typing .get_origin (annotation ) is typing .Union and type (None ) in typing .get_args (
367
- annotation
368
354
):
369
- function_arg ["nullable" ] = True
355
+ # We do not support default values for now.
356
+ # default=(
357
+ # param.default if param.default != inspect.Parameter.empty
358
+ # else None
359
+ # ),
360
+ field = pydantic .Field (
361
+ # We support user-provided descriptions.
362
+ description = descriptions .get (name , None )
363
+ )
364
+
365
+ # 1. We infer the argument type here: use Any rather than None so
366
+ # it will not try to auto-infer the type based on the default value.
367
+ if param .annotation != inspect .Parameter .empty :
368
+ fields_dict [name ] = param .annotation , field
369
+ else :
370
+ fields_dict [name ] = Any , field
371
+
372
+ parameters = _build_schema (f .__name__ , fields_dict )
373
+
370
374
# 6. Annotate required fields.
371
- if required :
375
+ if required is not None :
372
376
# We use the user-provided "required" fields if specified.
373
377
parameters ["required" ] = required
374
378
else :
@@ -387,9 +391,112 @@ def _generate_schema(
387
391
)
388
392
]
389
393
schema = dict (name = f .__name__ , description = f .__doc__ , parameters = parameters )
394
+
390
395
return schema
391
396
392
397
398
+ def _build_schema (fname , fields_dict ):
399
+ parameters = pydantic .create_model (fname , ** fields_dict ).schema ()
400
+ defs = parameters .pop ("$defs" , {})
401
+ # flatten the defs
402
+ for name , value in defs .items ():
403
+ unpack_defs (value , defs )
404
+ unpack_defs (parameters , defs )
405
+
406
+ # 5. Nullable fields:
407
+ # * https://github.com/pydantic/pydantic/issues/1270
408
+ # * https://stackoverflow.com/a/58841311
409
+ # * https://github.com/pydantic/pydantic/discussions/4872
410
+ convert_to_nullable (parameters )
411
+ add_object_type (parameters )
412
+ # Postprocessing
413
+ # 4. Suppress unnecessary title generation:
414
+ # * https://github.com/pydantic/pydantic/issues/1051
415
+ # * http://cl/586221780
416
+ strip_titles (parameters )
417
+ return parameters
418
+
419
+
420
+ def unpack_defs (schema , defs ):
421
+ properties = schema ["properties" ]
422
+ for name , value in properties .items ():
423
+ ref_key = value .get ("$ref" , None )
424
+ if ref_key is not None :
425
+ ref = defs [ref_key .split ("defs/" )[- 1 ]]
426
+ unpack_defs (ref , defs )
427
+ properties [name ] = ref
428
+ continue
429
+
430
+ anyof = value .get ("anyOf" , None )
431
+ if anyof is not None :
432
+ for i , atype in enumerate (anyof ):
433
+ ref_key = atype .get ("$ref" , None )
434
+ if ref_key is not None :
435
+ ref = defs [ref_key .split ("defs/" )[- 1 ]]
436
+ unpack_defs (ref , defs )
437
+ anyof [i ] = ref
438
+ continue
439
+
440
+ items = value .get ("items" , None )
441
+ if items is not None :
442
+ ref_key = items .get ("$ref" , None )
443
+ if ref_key is not None :
444
+ ref = defs [ref_key .split ("defs/" )[- 1 ]]
445
+ unpack_defs (ref , defs )
446
+ value ["items" ] = ref
447
+ continue
448
+
449
+
450
+ def strip_titles (schema ):
451
+ title = schema .pop ("title" , None )
452
+
453
+ properties = schema .get ("properties" , None )
454
+ if properties is not None :
455
+ for name , value in properties .items ():
456
+ strip_titles (value )
457
+
458
+ items = schema .get ("items" , None )
459
+ if items is not None :
460
+ strip_titles (items )
461
+
462
+
463
+ def add_object_type (schema ):
464
+ properties = schema .get ("properties" , None )
465
+ if properties is not None :
466
+ schema .pop ("required" , None )
467
+ schema ["type" ] = "object"
468
+ for name , value in properties .items ():
469
+ add_object_type (value )
470
+
471
+ items = schema .get ("items" , None )
472
+ if items is not None :
473
+ add_object_type (items )
474
+
475
+
476
+ def convert_to_nullable (schema ):
477
+ anyof = schema .pop ("anyOf" , None )
478
+ if anyof is not None :
479
+ if len (anyof ) != 2 :
480
+ raise ValueError ("Type Unions are not supported (except for Optional)" )
481
+ a , b = anyof
482
+ if a == {"type" : "null" }:
483
+ schema .update (b )
484
+ elif b == {"type" : "null" }:
485
+ schema .update (a )
486
+ else :
487
+ raise ValueError ("Type Unions are not supported (except for Optional)" )
488
+ schema ["nullable" ] = True
489
+
490
+ properties = schema .get ("properties" , None )
491
+ if properties is not None :
492
+ for name , value in properties .items ():
493
+ convert_to_nullable (value )
494
+
495
+ items = schema .get ("items" , None )
496
+ if items is not None :
497
+ convert_to_nullable (items )
498
+
499
+
393
500
def _rename_schema_fields (schema ):
394
501
if schema is None :
395
502
return schema
@@ -460,7 +567,7 @@ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | N
460
567
if descriptions is None :
461
568
descriptions = {}
462
569
463
- schema = _generate_schema (function , descriptions = descriptions )
570
+ schema = _schema_for_function (function , descriptions = descriptions )
464
571
465
572
return CallableFunctionDeclaration (** schema , function = function )
466
573
0 commit comments