Skip to content

Commit d5103eb

Browse files
Search grounding (#558)
* Updated tests and current progress on adding search grounding. * Update google/generativeai/types/content_types.py Co-authored-by: Mark Daoust <[email protected]> * Update tests/test_content.py Co-authored-by: Mark Daoust <[email protected]> * Update search grounding * update content_types * Update and add aditional test cases * update test case on empty_dictionary_with_dynamic_retrieval_config * Update test cases and _make_search_grounding * fix tests Change-Id: Ib9e19d78861da180f713e09ec93d366d5d7b5762 * Remove print statement * Fix tuned model tests Change-Id: I5ace9222954be7d903ebbdabab9efc663fa79174 * Fix tests Change-Id: Ifa610965c5d6c38123080a7e16416ac325418285 * format Change-Id: Iab48a9400d53f3cbdc5ca49c73df4f6a186a867b * fix typing Change-Id: If892b20ca29d1afb82c48ae1a49bef58e0421bab * Format Change-Id: I51a51150879adb3d4b6b00323e0d8eaf4c0b2515 --------- Co-authored-by: Mark Daoust <[email protected]>
1 parent 6c8dad1 commit d5103eb

File tree

2 files changed

+169
-13
lines changed

2 files changed

+169
-13
lines changed

google/generativeai/types/content_types.py

+97-7
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,27 @@
7272
"FunctionLibraryType",
7373
]
7474

75+
Mode = protos.DynamicRetrievalConfig.Mode
76+
77+
ModeOptions = Union[int, str, Mode]
78+
79+
_MODE: dict[ModeOptions, Mode] = {
80+
Mode.MODE_UNSPECIFIED: Mode.MODE_UNSPECIFIED,
81+
0: Mode.MODE_UNSPECIFIED,
82+
"mode_unspecified": Mode.MODE_UNSPECIFIED,
83+
"unspecified": Mode.MODE_UNSPECIFIED,
84+
Mode.MODE_DYNAMIC: Mode.MODE_DYNAMIC,
85+
1: Mode.MODE_DYNAMIC,
86+
"mode_dynamic": Mode.MODE_DYNAMIC,
87+
"dynamic": Mode.MODE_DYNAMIC,
88+
}
89+
90+
91+
def to_mode(x: ModeOptions) -> Mode:
92+
if isinstance(x, str):
93+
x = x.lower()
94+
return _MODE[x]
95+
7596

7697
def _pil_to_blob(image: PIL.Image.Image) -> protos.Blob:
7798
# If the image is a local file, return a file-based blob without any modification.
@@ -644,16 +665,54 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
644665
return fd.to_proto()
645666

646667

668+
class DynamicRetrievalConfigDict(TypedDict):
669+
mode: protos.DynamicRetrievalConfig.mode
670+
dynamic_threshold: float
671+
672+
673+
DynamicRetrievalConfig = Union[protos.DynamicRetrievalConfig, DynamicRetrievalConfigDict]
674+
675+
676+
class GoogleSearchRetrievalDict(TypedDict):
677+
dynamic_retrieval_config: DynamicRetrievalConfig
678+
679+
680+
GoogleSearchRetrievalType = Union[protos.GoogleSearchRetrieval, GoogleSearchRetrievalDict]
681+
682+
683+
def _make_google_search_retrieval(gsr: GoogleSearchRetrievalType):
684+
if isinstance(gsr, protos.GoogleSearchRetrieval):
685+
return gsr
686+
elif isinstance(gsr, Mapping):
687+
drc = gsr.get("dynamic_retrieval_config", None)
688+
if drc is not None and isinstance(drc, Mapping):
689+
mode = drc.get("mode", None)
690+
if mode is not None:
691+
mode = to_mode(mode)
692+
gsr = gsr.copy()
693+
gsr["dynamic_retrieval_config"]["mode"] = mode
694+
return protos.GoogleSearchRetrieval(gsr)
695+
else:
696+
raise TypeError(
697+
"Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n"
698+
f"However, received an object of type: {type(gsr)}.\n"
699+
f"Object Value: {gsr}"
700+
)
701+
702+
647703
class Tool:
648-
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
704+
"""A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
705+
protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""
649706

650707
def __init__(
651708
self,
709+
*,
652710
function_declarations: Iterable[FunctionDeclarationType] | None = None,
711+
google_search_retrieval: GoogleSearchRetrievalType | None = None,
653712
code_execution: protos.CodeExecution | None = None,
654713
):
655714
# The main path doesn't use this but is seems useful.
656-
if function_declarations:
715+
if function_declarations is not None:
657716
self._function_declarations = [
658717
_make_function_declaration(f) for f in function_declarations
659718
]
@@ -668,15 +727,25 @@ def __init__(
668727
self._function_declarations = []
669728
self._index = {}
670729

730+
if google_search_retrieval is not None:
731+
self._google_search_retrieval = _make_google_search_retrieval(google_search_retrieval)
732+
else:
733+
self._google_search_retrieval = None
734+
671735
self._proto = protos.Tool(
672736
function_declarations=[_encode_fd(fd) for fd in self._function_declarations],
737+
google_search_retrieval=google_search_retrieval,
673738
code_execution=code_execution,
674739
)
675740

676741
@property
677742
def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]:
678743
return self._function_declarations
679744

745+
@property
746+
def google_search_retrieval(self) -> protos.GoogleSearchRetrieval:
747+
return self._google_search_retrieval
748+
680749
@property
681750
def code_execution(self) -> protos.CodeExecution:
682751
return self._proto.code_execution
@@ -705,7 +774,7 @@ class ToolDict(TypedDict):
705774

706775

707776
ToolType = Union[
708-
Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
777+
str, Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType
709778
]
710779

711780

@@ -717,20 +786,41 @@ def _make_tool(tool: ToolType) -> Tool:
717786
code_execution = tool.code_execution
718787
else:
719788
code_execution = None
720-
return Tool(function_declarations=tool.function_declarations, code_execution=code_execution)
789+
790+
if "google_search_retrieval" in tool:
791+
google_search_retrieval = tool.google_search_retrieval
792+
else:
793+
google_search_retrieval = None
794+
795+
return Tool(
796+
function_declarations=tool.function_declarations,
797+
google_search_retrieval=google_search_retrieval,
798+
code_execution=code_execution,
799+
)
721800
elif isinstance(tool, dict):
722-
if "function_declarations" in tool or "code_execution" in tool:
801+
if (
802+
"function_declarations" in tool
803+
or "google_search_retrieval" in tool
804+
or "code_execution" in tool
805+
):
723806
return Tool(**tool)
724807
else:
725808
fd = tool
726809
return Tool(function_declarations=[protos.FunctionDeclaration(**fd)])
727810
elif isinstance(tool, str):
728811
if tool.lower() == "code_execution":
729812
return Tool(code_execution=protos.CodeExecution())
813+
# Check to see if one of the mode enums matches
814+
elif tool.lower() == "google_search_retrieval":
815+
return Tool(google_search_retrieval=protos.GoogleSearchRetrieval())
730816
else:
731-
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")
817+
raise ValueError(
818+
"The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
819+
)
732820
elif isinstance(tool, protos.CodeExecution):
733821
return Tool(code_execution=tool)
822+
elif isinstance(tool, protos.GoogleSearchRetrieval):
823+
return Tool(google_search_retrieval=tool)
734824
elif isinstance(tool, Iterable):
735825
return Tool(function_declarations=tool)
736826
else:
@@ -786,7 +876,7 @@ def to_proto(self):
786876

787877
def _make_tools(tools: ToolsType) -> list[Tool]:
788878
if isinstance(tools, str):
789-
if tools.lower() == "code_execution":
879+
if tools.lower() == "code_execution" or tools.lower() == "google_search_retrieval":
790880
return [_make_tool(tools)]
791881
else:
792882
raise ValueError("The only string that can be passed as a tool is 'code_execution'.")

tests/test_content.py

+72-6
Original file line numberDiff line numberDiff line change
@@ -435,12 +435,78 @@ def no_args():
435435
["empty_dictionary_list", [{"code_execution": {}}]],
436436
)
437437
def test_code_execution(self, tools):
438-
if isinstance(tools, Iterable):
439-
t = content_types._make_tools(tools)
440-
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
441-
else:
442-
t = content_types._make_tool(tools) # Pass code execution into tools
443-
self.assertIsInstance(t.code_execution, protos.CodeExecution)
438+
t = content_types._make_tools(tools)
439+
self.assertIsInstance(t[0].code_execution, protos.CodeExecution)
440+
441+
@parameterized.named_parameters(
442+
["string", "google_search_retrieval"],
443+
["empty_dictionary", {"google_search_retrieval": {}}],
444+
[
445+
"empty_dictionary_with_dynamic_retrieval_config",
446+
{"google_search_retrieval": {"dynamic_retrieval_config": {}}},
447+
],
448+
[
449+
"dictionary_with_mode_integer",
450+
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": 0}}},
451+
],
452+
[
453+
"dictionary_with_mode_string",
454+
{"google_search_retrieval": {"dynamic_retrieval_config": {"mode": "DYNAMIC"}}},
455+
],
456+
[
457+
"dictionary_with_dynamic_retrieval_config",
458+
{
459+
"google_search_retrieval": {
460+
"dynamic_retrieval_config": {"mode": "unspecified", "dynamic_threshold": 0.5}
461+
}
462+
},
463+
],
464+
[
465+
"proto_object",
466+
protos.GoogleSearchRetrieval(
467+
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
468+
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
469+
)
470+
),
471+
],
472+
[
473+
"proto_passed_in",
474+
protos.Tool(
475+
google_search_retrieval=protos.GoogleSearchRetrieval(
476+
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
477+
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
478+
)
479+
)
480+
),
481+
],
482+
[
483+
"proto_object_list",
484+
[
485+
protos.GoogleSearchRetrieval(
486+
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
487+
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
488+
)
489+
)
490+
],
491+
],
492+
[
493+
"proto_passed_in_list",
494+
[
495+
protos.Tool(
496+
google_search_retrieval=protos.GoogleSearchRetrieval(
497+
dynamic_retrieval_config=protos.DynamicRetrievalConfig(
498+
mode="MODE_UNSPECIFIED", dynamic_threshold=0.5
499+
)
500+
)
501+
)
502+
],
503+
],
504+
)
505+
def test_search_grounding(self, tools):
506+
if self._testMethodName == "test_search_grounding_empty_dictionary":
507+
pass
508+
t = content_types._make_tools(tools)
509+
self.assertIsInstance(t[0].google_search_retrieval, protos.GoogleSearchRetrieval)
444510

445511
def test_two_fun_is_one_tool(self):
446512
def a():

0 commit comments

Comments
 (0)