
Commit c3c90f8

Add ability to configure timeouts (#204)
* Add ability to configure timeouts
* Add timeout to docstrings
* Replace timeout default w/ None
* Add missing timeout arg to mocked method
* Make async/sync generate_content have same formatting
* Add missing timeout arg to embedding mocks
* Add request_options param
* Test generate answer called w/ req opts
* Test embed_content called w/ req opts
* Test count_tokens called w/ req opts
* Test async generate_content called w/ req opts
* Test embed/generate text called w/ req opts
* Fixup mocked methods w/ missing req opts
* Test batch_embed_text called w/ req opts
* Test count_text_tokens called w/ req opts
* Test create_corpus called w/ req opts
* Test corpus called w/ req opts
* Unpack req opts, handle None case
* Handle exc instead of assert raises
* Handle failing tests
* Test batch_embed_contents called w/ req opts
* Test get_model called w/ req opts
* Test get_tuned_model called w/ req opts
* Test list_models called w/ req opts
* Test list_tuned_models called w/ req opts
* Test update_tuned_model called w/ req opts
* Test delete_tuned_model called w/ req opts
* Test create_tuned_model called w/ req opts
* Add missing kwarg to mocked client method
* Pass req opts to update tuned model client method
* Test update_corpus called w/ req opts
* Test list_corpora called w/ req opts
* Test query_corpus called w/ req opts
* Test delete_corpus called w/ req opts
* Rm unneeded try/catch on delete corpus
* Test create_document called w/ req opts
* Import missing Any
* Test embed_content called w/ req opts
* Test sync count_tokens called w/ req opts
* Test create_corpus called w/ req opts
* Test async embed_content called w/ req opts
* Test generate_message called w/ req opts
* Test get_corpus called w/ req opts
* Test update_corpus called w/ req opts
* Test get_document called w/ req opts
* Test update_document called w/ req opts
* Test list_documents called w/ req opts
* Test delete_document called w/ req opts
* Test query_document called w/ req opts
* Test create_chunk called w/ req opts
* Test batch_create_chunks called w/ req opts
* Test get_chunk called w/ req opts
* Test list_chunks called w/ req opts
* Test update_chunk called w/ req opts
* Test batch_update_chunks called w/ req opts
* Test delete_chunk called w/ req opts
* Test batch_delete_chunks called w/ req opts
* Add req opts to docstrings
* Add a couple more docstrings for req opts
* Remove stray timeout

---------

Co-authored-by: Mark Daoust <[email protected]>
1 parent: bcec1b5

21 files changed (+1197 −129 lines)
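
Taken together, the diffs below thread a `request_options` dict through the public API so transport options like timeouts can be set per request. A hedged usage sketch of the resulting surface (each entry is unpacked as a keyword argument on the underlying `google.ai.generativelanguage` client call; `timeout` is the standard google-api-core keyword, and the model name here is illustrative):

```python
import google.generativeai as genai

# Sketch of the call pattern this commit enables: request_options
# entries are forwarded as keyword arguments to the underlying glm
# client RPC, so {"timeout": 120} bounds this call at 120 seconds.
result = genai.embed_content(
    model="models/embedding-001",  # illustrative model name
    content="Hello world",
    request_options={"timeout": 120},
)
print(len(result["embedding"]))  # dimensionality of the returned vector
```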

Diff for: google/generativeai/answer.py (+6 −1)

```diff
@@ -165,6 +165,7 @@ def generate_answer(
     safety_settings: safety_types.SafetySettingOptions | None = None,
     temperature: float | None = None,
     client: glm.GenerativeServiceClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ):
     """
     Calls the API and returns a `types.Answer` containing the answer.
@@ -177,10 +178,14 @@ def generate_answer(
         answer_style: Style in which the grounded answer should be returned.
         safety_settings: Safety settings for generated output. Defaults to None.
         client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead.
+        request_options: Options for the request.

     Returns:
         A `types.Answer` containing the model's text answer response.
     """
+    if request_options is None:
+        request_options = {}
+
     if client is None:
         client = get_default_generative_client()

@@ -193,6 +198,6 @@ def generate_answer(
         answer_style=answer_style,
     )

-    response = client.generate_answer(request)
+    response = client.generate_answer(request, **request_options)

     return response
```
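
The same three-step pattern repeats in every function this commit touches: accept `request_options` defaulting to `None`, normalize it to `{}` in the body, then splat it into the client call. A minimal self-contained sketch of that pattern (the `_client_method` stub below is hypothetical, standing in for a glm RPC):

```python
from typing import Any


def _client_method(request: str, *, timeout: float | None = None) -> str:
    """Stand-in for a glm RPC method that accepts transport keywords."""
    return f"sent {request!r} with timeout={timeout}"


def call_service(request: str, request_options: dict[str, Any] | None = None) -> str:
    # Default to None rather than {}: a mutable default argument would be
    # shared across every call to this function.
    if request_options is None:
        request_options = {}
    # Unpacking turns {"timeout": 120} into _client_method(request, timeout=120).
    return _client_method(request, **request_options)


print(call_service("ping", request_options={"timeout": 120}))
```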

Diff for: google/generativeai/discuss.py (+31 −8)

```diff
@@ -18,7 +18,7 @@
 import sys
 import textwrap

-from typing import Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Optional, Union

 import google.ai.generativelanguage as glm

@@ -316,6 +316,7 @@ def chat(
     top_k: float | None = None,
     prompt: discuss_types.MessagePromptOptions | None = None,
     client: glm.DiscussServiceClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> discuss_types.ChatResponse:
     """Calls the API and returns a `types.ChatResponse` containing the response.

@@ -382,6 +383,7 @@ def chat(
             setting `context`/`examples`/`messages`, but not both.
         client: If you're not relying on the default client, you pass a
             `glm.DiscussServiceClient` instead.
+        request_options: Options for the request.

     Returns:
         A `types.ChatResponse` containing the model's reply.
@@ -398,7 +400,7 @@ def chat(
         prompt=prompt,
     )

-    return _generate_response(client=client, request=request)
+    return _generate_response(client=client, request=request, request_options=request_options)


 @string_utils.set_doc(chat.__doc__)
@@ -414,6 +416,7 @@ async def chat_async(
     top_k: float | None = None,
     prompt: discuss_types.MessagePromptOptions | None = None,
     client: glm.DiscussServiceAsyncClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> discuss_types.ChatResponse:
     request = _make_generate_message_request(
         model=model,
@@ -427,7 +430,9 @@ async def chat_async(
         prompt=prompt,
     )

-    return await _generate_response_async(client=client, request=request)
+    return await _generate_response_async(
+        client=client, request=request, request_options=request_options
+    )


 if (sys.version_info.major, sys.version_info.minor) >= (3, 10):
@@ -461,7 +466,11 @@ def last(self, message: discuss_types.MessageOptions):
         self.messages[-1] = message

     @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
-    def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
+    def reply(
+        self,
+        message: discuss_types.MessageOptions,
+        request_options: dict[str, Any] | None = None,
+    ) -> discuss_types.ChatResponse:
         if isinstance(self._client, glm.DiscussServiceAsyncClient):
             raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
         if self.last is None:
@@ -477,7 +486,9 @@ def reply(self, message: discuss_types.MessageOptions) -> discuss_types.ChatResponse:
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
-        return _generate_response(request=request, client=self._client)
+        return _generate_response(
+            request=request, client=self._client, request_options=request_options
+        )

     @string_utils.set_doc(discuss_types.ChatResponse.reply.__doc__)
     async def reply_async(
@@ -526,23 +537,31 @@ def _build_chat_response(
 def _generate_response(
     request: glm.GenerateMessageRequest,
     client: glm.DiscussServiceClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> ChatResponse:
+    if request_options is None:
+        request_options = {}
+
     if client is None:
         client = get_default_discuss_client()

-    response = client.generate_message(request)
+    response = client.generate_message(request, **request_options)

     return _build_chat_response(request, response, client)


 async def _generate_response_async(
     request: glm.GenerateMessageRequest,
     client: glm.DiscussServiceAsyncClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> ChatResponse:
+    if request_options is None:
+        request_options = {}
+
     if client is None:
         client = get_default_discuss_async_client()

-    response = await client.generate_message(request)
+    response = await client.generate_message(request, **request_options)

     return _build_chat_response(request, response, client)

@@ -555,13 +574,17 @@ def count_message_tokens(
     messages: discuss_types.MessagesOptions | None = None,
     model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
     client: glm.DiscussServiceAsyncClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> discuss_types.TokenCount:
     model = model_types.make_model_name(model)
     prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)

+    if request_options is None:
+        request_options = {}
+
     if client is None:
         client = get_default_discuss_client()

-    result = client.count_message_tokens(model=model, prompt=prompt)
+    result = client.count_message_tokens(model=model, prompt=prompt, **request_options)

     return type(result).to_dict(result)
```
Diff for: google/generativeai/embedding.py (+31 −5)

```diff
@@ -17,7 +17,7 @@
 import dataclasses
 from collections.abc import Iterable, Sequence, Mapping
 import itertools
-from typing import Iterable, overload, TypeVar, Union, Mapping
+from typing import Any, Iterable, overload, TypeVar, Union, Mapping

 import google.ai.generativelanguage as glm

@@ -95,6 +95,7 @@ def embed_content(
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
     client: glm.GenerativeServiceClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict: ...


@@ -105,6 +106,7 @@ def embed_content(
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
     client: glm.GenerativeServiceClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> text_types.BatchEmbeddingDict: ...


@@ -114,6 +116,7 @@ def embed_content(
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
     client: glm.GenerativeServiceClient = None,
+    request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
     """Calls the API to create embeddings for content passed in.

@@ -132,13 +135,18 @@ def embed_content(
         title:
             An optional title for the text. Only applicable when task_type is
             `RETRIEVAL_DOCUMENT`.
+        request_options:
+            Options for the request.

     Return:
         Dictionary containing the embedding (list of float values) for the
         input content.
     """
     model = model_types.make_model_name(model)

+    if request_options is None:
+        request_options = {}
+
     if client is None:
         client = get_default_generative_client()

@@ -160,15 +168,21 @@ def embed_content(
         )
         for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
             embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch)
-            embedding_response = client.batch_embed_contents(embedding_request)
+            embedding_response = client.batch_embed_contents(
+                embedding_request,
+                **request_options,
+            )
             embedding_dict = type(embedding_response).to_dict(embedding_response)
             result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
         return result
     else:
         embedding_request = glm.EmbedContentRequest(
             model=model, content=content_types.to_content(content), task_type=task_type, title=title
         )
-        embedding_response = client.embed_content(embedding_request)
+        embedding_response = client.embed_content(
+            embedding_request,
+            **request_options,
+        )
         embedding_dict = type(embedding_response).to_dict(embedding_response)
         embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
         return embedding_dict
@@ -181,6 +195,7 @@ async def embed_content_async(
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
     client: glm.GenerativeServiceAsyncClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict: ...


@@ -191,6 +206,7 @@ async def embed_content_async(
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
     client: glm.GenerativeServiceAsyncClient | None = None,
+    request_options: dict[str, Any] | None = None,
 ) -> text_types.BatchEmbeddingDict: ...


@@ -200,10 +216,14 @@ async def embed_content_async(
     task_type: EmbeddingTaskTypeOptions | None = None,
     title: str | None = None,
     client: glm.GenerativeServiceAsyncClient = None,
+    request_options: dict[str, Any] | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
     """The async version of `genai.embed_content`."""
     model = model_types.make_model_name(model)

+    if request_options is None:
+        request_options = {}
+
     if client is None:
         client = get_default_generative_async_client()

@@ -225,15 +245,21 @@ async def embed_content_async(
         )
         for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE):
             embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch)
-            embedding_response = await client.batch_embed_contents(embedding_request)
+            embedding_response = await client.batch_embed_contents(
+                embedding_request,
+                **request_options,
+            )
             embedding_dict = type(embedding_response).to_dict(embedding_response)
             result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"])
         return result
     else:
         embedding_request = glm.EmbedContentRequest(
             model=model, content=content_types.to_content(content), task_type=task_type, title=title
         )
-        embedding_response = await client.embed_content(embedding_request)
+        embedding_response = await client.embed_content(
+            embedding_request,
+            **request_options,
+        )
         embedding_dict = type(embedding_response).to_dict(embedding_response)
         embedding_dict["embedding"] = embedding_dict["embedding"]["values"]
         return embedding_dict
```