Skip to content

Commit e09e7f2

Browse files
authored
Improve schema support (#309)
* handle nested schemas Change-Id: I22476536eb12027eb6b3a6dfcfa95cf61d2f4c0c * Improve support for nested schemas Change-Id: I51f761d87ab62465c50881301714aa5c38e7056d * Improve support for nested schemas Change-Id: I4739d8c46b0815134d55fbff4413544cb71a39fe * Improve support for nested schemas Change-Id: If97e7265954db092cfba54b0f61c1606d4b9b1d2 * Improve support for nested schemas Change-Id: I426db26133356eed885f7702ff2c465631adc418 * format Change-Id: Id722f2a02b0115dfbdaafe5b9a9f56ad4c6737b1 * more tests that will need to pass Change-Id: I3595531b4c974a3bee0291abec470e625722dfb2 * work on nested schema. Change-Id: Ia05084dd6e59009f6fca590c5a7e42b537964a51 * format Change-Id: I98cb8da98b0bb9aae7adcf073cd648b152410552 * service fails if 'required' is used in nested objects Change-Id: Iade8b6f91b2d26a29c90890a4b67678927f73a44 * format Change-Id: Id6f123168f12657eb2c01f36aff848d717244554 * Add support for types in "response_schema" Change-Id: Id7a17d5fba055020bc9bd94d98bd585ed19171df * add missing import Change-Id: Iacbcb1acbd468347ffb2b873258a1d0737c947d7 * update generativelanguage version Change-Id: I106cdf98a950ae6bf92dcf58c98064c09f5da5f4 * add tests Change-Id: I1de22340f48ed2d6ae54423419a33965a7bc3a67
1 parent a89469f commit e09e7f2

File tree

5 files changed

+332
-81
lines changed

5 files changed

+332
-81
lines changed

google/generativeai/types/content_types.py

Lines changed: 149 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
116
from __future__ import annotations
217

318
from collections.abc import Iterable, Mapping, Sequence
@@ -300,7 +315,12 @@ def to_contents(contents: ContentsType) -> list[glm.Content]:
300315
return contents
301316

302317

303-
def _generate_schema(
318+
def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
    """Return the schema that `_build_schema` produces for the type `cls`.

    `_build_schema` works on a mapping of named fields, so `cls` is wrapped
    as the only field (named "dummy") of a throwaway model, and that field's
    schema is then pulled back out of the result.

    Args:
        cls: The type to describe.

    Returns:
        The schema dict found under the placeholder field.
    """
    placeholder = "dummy"
    wrapper_fields = {placeholder: (cls, pydantic.Field())}
    wrapper_schema = _build_schema(placeholder, wrapper_fields)
    return wrapper_schema["properties"][placeholder]
321+
322+
323+
def _schema_for_function(
304324
f: Callable[..., Any],
305325
*,
306326
descriptions: Mapping[str, str] | None = None,
@@ -323,52 +343,36 @@ def _generate_schema(
323343
"""
324344
if descriptions is None:
325345
descriptions = {}
326-
if required is None:
327-
required = []
328346
defaults = dict(inspect.signature(f).parameters)
329-
fields_dict = {
330-
name: (
331-
# 1. We infer the argument type here: use Any rather than None so
332-
# it will not try to auto-infer the type based on the default value.
333-
(param.annotation if param.annotation != inspect.Parameter.empty else Any),
334-
pydantic.Field(
335-
# 2. We do not support default values for now.
336-
# default=(
337-
# param.default if param.default != inspect.Parameter.empty
338-
# else None
339-
# ),
340-
# 3. We support user-provided descriptions.
341-
description=descriptions.get(name, None),
342-
),
343-
)
344-
for name, param in defaults.items()
345-
# We do not support *args or **kwargs
346-
if param.kind
347-
in (
347+
348+
fields_dict = {}
349+
for name, param in defaults.items():
350+
if param.kind in (
348351
inspect.Parameter.POSITIONAL_OR_KEYWORD,
349352
inspect.Parameter.KEYWORD_ONLY,
350353
inspect.Parameter.POSITIONAL_ONLY,
351-
)
352-
}
353-
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
354-
# Postprocessing
355-
# 4. Suppress unnecessary title generation:
356-
# * https://github.com/pydantic/pydantic/issues/1051
357-
# * http://cl/586221780
358-
parameters.pop("title", None)
359-
for name, function_arg in parameters.get("properties", {}).items():
360-
function_arg.pop("title", None)
361-
annotation = defaults[name].annotation
362-
# 5. Nullable fields:
363-
# * https://github.com/pydantic/pydantic/issues/1270
364-
# * https://stackoverflow.com/a/58841311
365-
# * https://github.com/pydantic/pydantic/discussions/4872
366-
if typing.get_origin(annotation) is typing.Union and type(None) in typing.get_args(
367-
annotation
368354
):
369-
function_arg["nullable"] = True
355+
# We do not support default values for now.
356+
# default=(
357+
# param.default if param.default != inspect.Parameter.empty
358+
# else None
359+
# ),
360+
field = pydantic.Field(
361+
# We support user-provided descriptions.
362+
description=descriptions.get(name, None)
363+
)
364+
365+
# 1. We infer the argument type here: use Any rather than None so
366+
# it will not try to auto-infer the type based on the default value.
367+
if param.annotation != inspect.Parameter.empty:
368+
fields_dict[name] = param.annotation, field
369+
else:
370+
fields_dict[name] = Any, field
371+
372+
parameters = _build_schema(f.__name__, fields_dict)
373+
370374
# 6. Annotate required fields.
371-
if required:
375+
if required is not None:
372376
# We use the user-provided "required" fields if specified.
373377
parameters["required"] = required
374378
else:
@@ -387,9 +391,112 @@ def _generate_schema(
387391
)
388392
]
389393
schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
394+
390395
return schema
391396

392397

398+
def _build_schema(fname, fields_dict):
    """Create a pydantic model from `fields_dict` and return its cleaned schema.

    Args:
        fname: Name given to the dynamically created pydantic model.
        fields_dict: Mapping of field name to a `(type, pydantic.Field)` pair,
            forwarded to `pydantic.create_model` as keyword arguments.

    Returns:
        The model's schema dict after the "$defs" references have been
        inlined, `Optional` unions converted to `nullable`, object types
        annotated, and titles removed.
    """
    model = pydantic.create_model(fname, **fields_dict)
    parameters = model.schema()
    defs = parameters.pop("$defs", {})

    # Flatten the definitions first (they may reference one another), then
    # inline them into the top-level schema.
    for definition in defs.values():
        unpack_defs(definition, defs)
    unpack_defs(parameters, defs)

    # Nullable fields:
    # * https://github.com/pydantic/pydantic/issues/1270
    # * https://stackoverflow.com/a/58841311
    # * https://github.com/pydantic/pydantic/discussions/4872
    convert_to_nullable(parameters)
    add_object_type(parameters)

    # Postprocessing: suppress unnecessary title generation.
    # * https://github.com/pydantic/pydantic/issues/1051
    # * http://cl/586221780
    strip_titles(parameters)
    return parameters
418+
419+
420+
def unpack_defs(schema, defs):
    """Inline "$ref" pointers found in the properties of `schema`, in place.

    Each property of `schema` is inspected in three places: the property
    itself, the members of its "anyOf" list, and its "items" entry. Wherever
    a "$ref" is found, the referenced definition is looked up in `defs`,
    recursively unpacked, and substituted for the reference.

    Args:
        schema: A JSON-schema dict. Schemas without a "properties" key
            (for example enum definitions) are left untouched.
        defs: Mapping of definition name to definition schema, i.e. the
            "$defs" section the references point into.

    Raises:
        KeyError: If a "$ref" names a definition that is not in `defs`.
    """
    properties = schema.get("properties", None)
    if properties is None:
        # Not every definition is an object: enums, for instance, carry no
        # "properties". There is nothing to inline in that case.
        return

    def resolve(ref_key):
        # A ref looks like "#/$defs/Name"; the definition name is whatever
        # follows the last "defs/".
        target = defs[ref_key.split("defs/")[-1]]
        unpack_defs(target, defs)
        return target

    for name, value in properties.items():
        ref_key = value.get("$ref", None)
        if ref_key is not None:
            # Replacing the value of an existing key is safe while iterating.
            properties[name] = resolve(ref_key)
            continue

        anyof = value.get("anyOf", None)
        if anyof is not None:
            for i, atype in enumerate(anyof):
                ref_key = atype.get("$ref", None)
                if ref_key is not None:
                    anyof[i] = resolve(ref_key)

        items = value.get("items", None)
        if items is not None:
            ref_key = items.get("$ref", None)
            if ref_key is not None:
                value["items"] = resolve(ref_key)
448+
449+
450+
def strip_titles(schema):
    """Remove the "title" key from `schema` and its nested schemas, in place.

    The walk descends into every value under "properties" and into "items".
    A property that is itself *named* "title" is kept; only the "title"
    metadata key of each schema dict is dropped.

    Args:
        schema: A JSON-schema dict; it is mutated.
    """
    schema.pop("title", None)

    nested = schema.get("properties", None)
    if nested is not None:
        for child in nested.values():
            strip_titles(child)

    element = schema.get("items", None)
    if element is not None:
        strip_titles(element)
461+
462+
463+
def add_object_type(schema):
    """Mark every schema that has "properties" as an object, in place.

    For each such schema the "required" key is removed and "type" is set to
    "object". The walk then descends into every property value and into
    "items".

    Args:
        schema: A JSON-schema dict; it is mutated.
    """
    nested = schema.get("properties", None)
    if nested is not None:
        schema.pop("required", None)
        schema["type"] = "object"
        for child in nested.values():
            add_object_type(child)

    element = schema.get("items", None)
    if element is not None:
        add_object_type(element)
474+
475+
476+
def convert_to_nullable(schema):
    """Rewrite `Optional`-style "anyOf" unions as `nullable` schemas, in place.

    An "anyOf" of exactly two members, one of which is `{"type": "null"}`,
    is replaced by the other member's keys plus `"nullable": True`. The walk
    then descends into every property value and into "items".

    Args:
        schema: A JSON-schema dict; it is mutated.

    Raises:
        ValueError: If an "anyOf" does not have exactly two members, or if
            neither member is the null type.
    """
    null_type = {"type": "null"}

    union = schema.pop("anyOf", None)
    if union is not None:
        if len(union) != 2:
            raise ValueError("Type Unions are not supported (except for Optional)")
        first, second = union
        if first == null_type:
            kept = second
        elif second == null_type:
            kept = first
        else:
            raise ValueError("Type Unions are not supported (except for Optional)")
        schema.update(kept)
        schema["nullable"] = True

    nested = schema.get("properties", None)
    if nested is not None:
        for child in nested.values():
            convert_to_nullable(child)

    element = schema.get("items", None)
    if element is not None:
        convert_to_nullable(element)
498+
499+
393500
def _rename_schema_fields(schema):
394501
if schema is None:
395502
return schema
@@ -460,7 +567,7 @@ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | N
460567
if descriptions is None:
461568
descriptions = {}
462569

463-
schema = _generate_schema(function, descriptions=descriptions)
570+
schema = _schema_for_function(function, descriptions=descriptions)
464571

465572
return CallableFunctionDeclaration(**schema, function=function)
466573

google/generativeai/types/generation_types.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@
2525
import textwrap
2626
from typing import Union, Any
2727
from typing_extensions import TypedDict
28+
import types
2829

2930
import google.protobuf.json_format
3031
import google.api_core.exceptions
3132

3233
from google.ai import generativelanguage as glm
3334
from google.generativeai import string_utils
35+
from google.generativeai.types import content_types
3436
from google.generativeai.responder import _rename_schema_fields
3537

3638
__all__ = [
@@ -174,8 +176,20 @@ def _normalize_schema(generation_config):
174176
response_schema = generation_config.get("response_schema", None)
175177
if response_schema is None:
176178
return
179+
177180
if isinstance(response_schema, glm.Schema):
178181
return
182+
183+
if isinstance(response_schema, type):
184+
response_schema = content_types._schema_for_class(response_schema)
185+
elif isinstance(response_schema, types.GenericAlias):
186+
if not str(response_schema).startswith("list["):
187+
raise ValueError(
188+
f"Could not understand {response_schema}, expected: `int`, `float`, `str`, `bool`, "
189+
"`typing_extensions.TypedDict`, `dataclass`, or `list[...]`"
190+
)
191+
response_schema = content_types._schema_for_class(response_schema)
192+
179193
response_schema = _rename_schema_fields(response_schema)
180194
generation_config["response_schema"] = glm.Schema(response_schema)
181195

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_version():
4242
release_status = "Development Status :: 5 - Production/Stable"
4343

4444
dependencies = [
45-
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py-2.tar.gz",
45+
"google-ai-generativelanguage==0.6.3",
4646
"google-api-core",
4747
"google-api-python-client",
4848
"google-auth>=2.15.0", # 2.15 adds API key auth support

0 commit comments

Comments
 (0)