2
2
3
3
import os
4
4
import contextlib
5
+ import inspect
5
6
import dataclasses
6
7
import pathlib
7
- import types
8
8
from typing import Any , cast
9
9
from collections .abc import Sequence
10
10
import httplib2
30
30
__version__ = "0.0.0"
31
31
32
32
USER_AGENT = "genai-py"
33
+
34
+ #### Caution! ####
35
+ # - It would make sense for the discovery URL to respect the client_options.endpoint setting.
36
+ # - That would make testing Files on the staging server possible.
37
+ # - We tried fixing this once, but broke colab in the process because their endpoint didn't forward the discovery
38
+ # requests. https://github.com/google-gemini/generative-ai-python/pull/333
39
+ # - Kaggle would have a similar problem (b/362278209).
40
+ # - I think their proxy would forward the discovery traffic.
41
+ # - But they don't need to intercept the files-service at all, and uploads of large files could overload them.
42
+ # - Do the scotty uploads go to the same domain?
43
+ # - If you do route the discovery call to kaggle, be sure to attach the default_metadata (they need it).
44
+ # - One solution to all this would be if configure could take overrides per service.
45
+ # - set client_options.endpoint, but use a different endpoint for file service? It's not clear how best to do that
46
+ # through the file service.
47
+ ##################
33
48
GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest"
34
49
35
50
@@ -50,7 +65,7 @@ def __init__(self, *args, **kwargs):
50
65
self ._discovery_api = None
51
66
super ().__init__ (* args , ** kwargs )
52
67
53
- def _setup_discovery_api (self ):
68
+ def _setup_discovery_api (self , metadata : dict | Sequence [ tuple [ str , str ]] = () ):
54
69
api_key = self ._client_options .api_key
55
70
if api_key is None :
56
71
raise ValueError (
@@ -61,6 +76,7 @@ def _setup_discovery_api(self):
61
76
http = httplib2 .Http (),
62
77
postproc = lambda resp , content : (resp , content ),
63
78
uri = f"{ GENAI_API_DISCOVERY_URL } ?version=v1beta&key={ api_key } " ,
79
+ headers = dict (metadata ),
64
80
)
65
81
response , content = request .execute ()
66
82
request .http .close ()
@@ -78,9 +94,10 @@ def create_file(
78
94
name : str | None = None ,
79
95
display_name : str | None = None ,
80
96
resumable : bool = True ,
97
+ metadata : Sequence [tuple [str , str ]] = (),
81
98
) -> protos .File :
82
99
if self ._discovery_api is None :
83
- self ._setup_discovery_api ()
100
+ self ._setup_discovery_api (metadata )
84
101
85
102
file = {}
86
103
if name is not None :
@@ -92,6 +109,8 @@ def create_file(
92
109
filename = path , mimetype = mime_type , resumable = resumable
93
110
)
94
111
request = self ._discovery_api .media ().upload (body = {"file" : file }, media_body = media )
112
+ for key , value in metadata :
113
+ request .headers [key ] = value
95
114
result = request .execute ()
96
115
97
116
return self .get_file ({"name" : result ["file" ]["name" ]})
@@ -108,9 +127,6 @@ async def create_file(self, *args, **kwargs):
108
127
class _ClientManager :
109
128
client_config : dict [str , Any ] = dataclasses .field (default_factory = dict )
110
129
default_metadata : Sequence [tuple [str , str ]] = ()
111
-
112
- discuss_client : glm .DiscussServiceClient | None = None
113
- discuss_async_client : glm .DiscussServiceAsyncClient | None = None
114
130
clients : dict [str , Any ] = dataclasses .field (default_factory = dict )
115
131
116
132
def configure (
@@ -119,7 +135,7 @@ def configure(
119
135
api_key : str | None = None ,
120
136
credentials : ga_credentials .Credentials | dict | None = None ,
121
137
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
122
- # See ` _transport_registry` in `DiscussServiceClientMeta` .
138
+ # See _transport_registry in the google.ai.generativelanguage package .
123
139
# Since the transport classes align with the client classes it wouldn't make
124
140
# sense to accept a `Transport` object here even though the client classes can.
125
141
# We could accept a dict since all the `Transport` classes take the same args,
@@ -229,16 +245,14 @@ def make_client(self, name):
229
245
def keep (name , f ):
230
246
if name .startswith ("_" ):
231
247
return False
232
- elif name == "create_file" :
233
- return False
234
- elif not isinstance (f , types .FunctionType ):
235
- return False
236
- elif isinstance (f , classmethod ):
248
+
249
+ if not callable (f ):
237
250
return False
238
- elif isinstance (f , staticmethod ):
251
+
252
+ if "metadata" not in inspect .signature (f ).parameters .keys ():
239
253
return False
240
- else :
241
- return True
254
+
255
+ return True
242
256
243
257
def add_default_metadata_wrapper (f ):
244
258
def call (* args , metadata = (), ** kwargs ):
@@ -247,7 +261,7 @@ def call(*args, metadata=(), **kwargs):
247
261
248
262
return call
249
263
250
- for name , value in cls . __dict__ . items ( ):
264
+ for name , value in inspect . getmembers ( cls ):
251
265
if not keep (name , value ):
252
266
continue
253
267
f = getattr (client , name )
@@ -281,7 +295,6 @@ def configure(
281
295
api_key : str | None = None ,
282
296
credentials : ga_credentials .Credentials | dict | None = None ,
283
297
# The user can pass a string to choose `rest` or `grpc` or 'grpc_asyncio'.
284
- # See `_transport_registry` in `DiscussServiceClientMeta`.
285
298
# Since the transport classes align with the client classes it wouldn't make
286
299
# sense to accept a `Transport` object here even though the client classes can.
287
300
# We could accept a dict since all the `Transport` classes take the same args,
@@ -326,14 +339,6 @@ def get_default_cache_client() -> glm.CacheServiceClient:
326
339
return _client_manager .get_default_client ("cache" )
327
340
328
341
329
- def get_default_discuss_client () -> glm .DiscussServiceClient :
330
- return _client_manager .get_default_client ("discuss" )
331
-
332
-
333
- def get_default_discuss_async_client () -> glm .DiscussServiceAsyncClient :
334
- return _client_manager .get_default_client ("discuss_async" )
335
-
336
-
337
342
def get_default_file_client () -> glm .FilesServiceClient :
338
343
return _client_manager .get_default_client ("file" )
339
344
@@ -350,10 +355,6 @@ def get_default_generative_async_client() -> glm.GenerativeServiceAsyncClient:
350
355
return _client_manager .get_default_client ("generative_async" )
351
356
352
357
353
- def get_default_text_client () -> glm .TextServiceClient :
354
- return _client_manager .get_default_client ("text" )
355
-
356
-
357
358
def get_default_operations_client () -> operations_v1 .OperationsClient :
358
359
return _client_manager .get_default_client ("operations" )
359
360
0 commit comments