
Commit 05877f7

Allow empty contents with count_tokens (#342)
Change-Id: Ic20e2f88427d2e4fbc97847cf5c2df1f80a9a5a1
1 parent: 88f7ab3

2 files changed: +32 -10 lines


Diff for: google/generativeai/generative_models.py

+6 -3
@@ -129,9 +129,6 @@ def _prepare_request(
         tool_config: content_types.ToolConfigType | None,
     ) -> glm.GenerateContentRequest:
         """Creates a `glm.GenerateContentRequest` from raw inputs."""
-        if not contents:
-            raise TypeError("contents must not be empty")
-
         tools_lib = self._get_tools_lib(tools)
         if tools_lib is not None:
             tools_lib = tools_lib.to_proto()
@@ -235,6 +232,9 @@ def generate_content(
             tools: `glm.Tools` more info coming soon.
             request_options: Options for the request.
         """
+        if not contents:
+            raise TypeError("contents must not be empty")
+
         request = self._prepare_request(
             contents=contents,
             generation_config=generation_config,
@@ -282,6 +282,9 @@ async def generate_content_async(
         request_options: helper_types.RequestOptionsType | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `GenerativeModel.generate_content`."""
+        if not contents:
+            raise TypeError("contents must not be empty")
+
         request = self._prepare_request(
             contents=contents,
             generation_config=generation_config,
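
Net effect: the empty-contents guard moves out of `_prepare_request`, which is shared by `generate_content`, `generate_content_async`, and `count_tokens`, and into the two generate methods, so token counting can now build a request with no contents at all. A minimal sketch of the resulting guard placement (a hypothetical stand-in class, not the SDK's real signatures):

    # Sketch of where the validation lives after this commit.
    class SketchModel:
        def _prepare_request(self, contents):
            # No emptiness check here anymore; shared by all entry points.
            return {"contents": list(contents or [])}

        def generate_content(self, contents):
            # Generation still rejects empty contents up front...
            if not contents:
                raise TypeError("contents must not be empty")
            return self._prepare_request(contents)

        def count_tokens(self, contents=None):
            # ...while token counting passes them through, so a request
            # carrying only system_instruction or tools is now valid.
            return self._prepare_request(contents)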

Diff for: tests/test_generative_models.py

+26 -7
@@ -21,6 +21,10 @@
 TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()
 
 
+def noop(x: int):
+    return x
+
+
 def simple_part(text: str) -> glm.Content:
     return glm.Content({"parts": [{"text": text}]})
 
@@ -725,18 +729,33 @@ def test_system_instruction(self, instruction, expected_instr):
         self.assertEqual(req.system_instruction, expected_instr)
 
     @parameterized.named_parameters(
-        ["basic", "Hello"],
-        ["list", ["Hello"]],
+        ["basic", {"contents": "Hello"}],
+        ["list", {"contents": ["Hello"]}],
         [
             "list2",
-            [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
+            {
+                "contents": [
+                    {"text": "Hello"},
+                    {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}},
+                ]
+            },
         ],
-        ["contents", [{"role": "user", "parts": ["hello"]}]],
+        [
+            "contents",
+            {"contents": [{"role": "user", "parts": ["hello"]}]},
+        ],
+        ["empty", {}],
+        [
+            "system_instruction",
+            {"system_instruction": ["You are a cat"]},
+        ],
+        ["tools", {"tools": [noop]}],
     )
-    def test_count_tokens_smoke(self, contents):
+    def test_count_tokens_smoke(self, kwargs):
+        si = kwargs.pop("system_instruction", None)
         self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
-        model = generative_models.GenerativeModel("gemini-pro-vision")
-        response = model.count_tokens(contents)
+        model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si)
+        response = model.count_tokens(**kwargs)
         self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
 
     @parameterized.named_parameters(
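
The new "empty", "system_instruction", and "tools" test cases cover calls that raised `TypeError` before this commit. A usage sketch of what they now permit (illustrative model name and instruction; assumes an API key has been set via `genai.configure`):

    import google.generativeai as genai

    model = genai.GenerativeModel(
        "gemini-pro",
        system_instruction="You are a cat",
    )

    # Previously this raised TypeError("contents must not be empty");
    # now the count covers just the system instruction (plus any tools).
    response = model.count_tokens()
    print(response.total_tokens)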
