72
72
"FunctionLibraryType" ,
73
73
]
74
74
75
+ Mode = protos .DynamicRetrievalConfig .Mode
76
+
77
+ ModeOptions = Union [int , str , Mode ]
78
+
79
+ _MODE : dict [ModeOptions , Mode ] = {
80
+ Mode .MODE_UNSPECIFIED : Mode .MODE_UNSPECIFIED ,
81
+ 0 : Mode .MODE_UNSPECIFIED ,
82
+ "mode_unspecified" : Mode .MODE_UNSPECIFIED ,
83
+ "unspecified" : Mode .MODE_UNSPECIFIED ,
84
+ Mode .MODE_DYNAMIC : Mode .MODE_DYNAMIC ,
85
+ 1 : Mode .MODE_DYNAMIC ,
86
+ "mode_dynamic" : Mode .MODE_DYNAMIC ,
87
+ "dynamic" : Mode .MODE_DYNAMIC ,
88
+ }
89
+
90
+
91
+ def to_mode (x : ModeOptions ) -> Mode :
92
+ if isinstance (x , str ):
93
+ x = x .lower ()
94
+ return _MODE [x ]
95
+
75
96
76
97
def _pil_to_blob (image : PIL .Image .Image ) -> protos .Blob :
77
98
# If the image is a local file, return a file-based blob without any modification.
@@ -644,16 +665,54 @@ def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.F
644
665
return fd .to_proto ()
645
666
646
667
668
+ class DynamicRetrievalConfigDict (TypedDict ):
669
+ mode : protos .DynamicRetrievalConfig .mode
670
+ dynamic_threshold : float
671
+
672
+
673
+ DynamicRetrievalConfig = Union [protos .DynamicRetrievalConfig , DynamicRetrievalConfigDict ]
674
+
675
+
676
+ class GoogleSearchRetrievalDict (TypedDict ):
677
+ dynamic_retrieval_config : DynamicRetrievalConfig
678
+
679
+
680
+ GoogleSearchRetrievalType = Union [protos .GoogleSearchRetrieval , GoogleSearchRetrievalDict ]
681
+
682
+
683
+ def _make_google_search_retrieval (gsr : GoogleSearchRetrievalType ):
684
+ if isinstance (gsr , protos .GoogleSearchRetrieval ):
685
+ return gsr
686
+ elif isinstance (gsr , Mapping ):
687
+ drc = gsr .get ("dynamic_retrieval_config" , None )
688
+ if drc is not None and isinstance (drc , Mapping ):
689
+ mode = drc .get ("mode" , None )
690
+ if mode is not None :
691
+ mode = to_mode (mode )
692
+ gsr = gsr .copy ()
693
+ gsr ["dynamic_retrieval_config" ]["mode" ] = mode
694
+ return protos .GoogleSearchRetrieval (gsr )
695
+ else :
696
+ raise TypeError (
697
+ "Invalid input type. Expected an instance of `genai.GoogleSearchRetrieval`.\n "
698
+ f"However, received an object of type: { type (gsr )} .\n "
699
+ f"Object Value: { gsr } "
700
+ )
701
+
702
+
647
703
class Tool :
648
- """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects."""
704
+ """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects,
705
+ protos.CodeExecution object, and protos.GoogleSearchRetrieval object."""
649
706
650
707
def __init__ (
651
708
self ,
709
+ * ,
652
710
function_declarations : Iterable [FunctionDeclarationType ] | None = None ,
711
+ google_search_retrieval : GoogleSearchRetrievalType | None = None ,
653
712
code_execution : protos .CodeExecution | None = None ,
654
713
):
655
714
# The main path doesn't use this but is seems useful.
656
- if function_declarations :
715
+ if function_declarations is not None :
657
716
self ._function_declarations = [
658
717
_make_function_declaration (f ) for f in function_declarations
659
718
]
@@ -668,15 +727,25 @@ def __init__(
668
727
self ._function_declarations = []
669
728
self ._index = {}
670
729
730
+ if google_search_retrieval is not None :
731
+ self ._google_search_retrieval = _make_google_search_retrieval (google_search_retrieval )
732
+ else :
733
+ self ._google_search_retrieval = None
734
+
671
735
self ._proto = protos .Tool (
672
736
function_declarations = [_encode_fd (fd ) for fd in self ._function_declarations ],
737
+ google_search_retrieval = google_search_retrieval ,
673
738
code_execution = code_execution ,
674
739
)
675
740
676
741
@property
677
742
def function_declarations (self ) -> list [FunctionDeclaration | protos .FunctionDeclaration ]:
678
743
return self ._function_declarations
679
744
745
+ @property
746
+ def google_search_retrieval (self ) -> protos .GoogleSearchRetrieval :
747
+ return self ._google_search_retrieval
748
+
680
749
@property
681
750
def code_execution (self ) -> protos .CodeExecution :
682
751
return self ._proto .code_execution
@@ -705,7 +774,7 @@ class ToolDict(TypedDict):
705
774
706
775
707
776
ToolType = Union [
708
- Tool , protos .Tool , ToolDict , Iterable [FunctionDeclarationType ], FunctionDeclarationType
777
+ str , Tool , protos .Tool , ToolDict , Iterable [FunctionDeclarationType ], FunctionDeclarationType
709
778
]
710
779
711
780
@@ -717,20 +786,41 @@ def _make_tool(tool: ToolType) -> Tool:
717
786
code_execution = tool .code_execution
718
787
else :
719
788
code_execution = None
720
- return Tool (function_declarations = tool .function_declarations , code_execution = code_execution )
789
+
790
+ if "google_search_retrieval" in tool :
791
+ google_search_retrieval = tool .google_search_retrieval
792
+ else :
793
+ google_search_retrieval = None
794
+
795
+ return Tool (
796
+ function_declarations = tool .function_declarations ,
797
+ google_search_retrieval = google_search_retrieval ,
798
+ code_execution = code_execution ,
799
+ )
721
800
elif isinstance (tool , dict ):
722
- if "function_declarations" in tool or "code_execution" in tool :
801
+ if (
802
+ "function_declarations" in tool
803
+ or "google_search_retrieval" in tool
804
+ or "code_execution" in tool
805
+ ):
723
806
return Tool (** tool )
724
807
else :
725
808
fd = tool
726
809
return Tool (function_declarations = [protos .FunctionDeclaration (** fd )])
727
810
elif isinstance (tool , str ):
728
811
if tool .lower () == "code_execution" :
729
812
return Tool (code_execution = protos .CodeExecution ())
813
+ # Check to see if one of the mode enums matches
814
+ elif tool .lower () == "google_search_retrieval" :
815
+ return Tool (google_search_retrieval = protos .GoogleSearchRetrieval ())
730
816
else :
731
- raise ValueError ("The only string that can be passed as a tool is 'code_execution'." )
817
+ raise ValueError (
818
+ "The only string that can be passed as a tool is 'code_execution', or one of the specified values for the `mode` parameter for google_search_retrieval."
819
+ )
732
820
elif isinstance (tool , protos .CodeExecution ):
733
821
return Tool (code_execution = tool )
822
+ elif isinstance (tool , protos .GoogleSearchRetrieval ):
823
+ return Tool (google_search_retrieval = tool )
734
824
elif isinstance (tool , Iterable ):
735
825
return Tool (function_declarations = tool )
736
826
else :
@@ -786,7 +876,7 @@ def to_proto(self):
786
876
787
877
def _make_tools (tools : ToolsType ) -> list [Tool ]:
788
878
if isinstance (tools , str ):
789
- if tools .lower () == "code_execution" :
879
+ if tools .lower () == "code_execution" or tools . lower () == "google_search_retrieval" :
790
880
return [_make_tool (tools )]
791
881
else :
792
882
raise ValueError ("The only string that can be passed as a tool is 'code_execution'." )
0 commit comments