|
21 | 21 | TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()
|
22 | 22 |
|
23 | 23 |
|
| 24 | +def noop(x: int): |
| 25 | + return x |
| 26 | + |
| 27 | + |
24 | 28 | def simple_part(text: str) -> glm.Content:
|
25 | 29 | return glm.Content({"parts": [{"text": text}]})
|
26 | 30 |
|
@@ -725,18 +729,33 @@ def test_system_instruction(self, instruction, expected_instr):
|
725 | 729 | self.assertEqual(req.system_instruction, expected_instr)
|
726 | 730 |
|
727 | 731 | @parameterized.named_parameters(
|
728 |
| - ["basic", "Hello"], |
729 |
| - ["list", ["Hello"]], |
| 732 | + ["basic", {"contents": "Hello"}], |
| 733 | + ["list", {"contents": ["Hello"]}], |
730 | 734 | [
|
731 | 735 | "list2",
|
732 |
| - [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}], |
| 736 | + { |
| 737 | + "contents": [ |
| 738 | + {"text": "Hello"}, |
| 739 | + {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}, |
| 740 | + ] |
| 741 | + }, |
733 | 742 | ],
|
734 |
| - ["contents", [{"role": "user", "parts": ["hello"]}]], |
| 743 | + [ |
| 744 | + "contents", |
| 745 | + {"contents": [{"role": "user", "parts": ["hello"]}]}, |
| 746 | + ], |
| 747 | + ["empty", {}], |
| 748 | + [ |
| 749 | + "system_instruction", |
| 750 | + {"system_instruction": ["You are a cat"]}, |
| 751 | + ], |
| 752 | + ["tools", {"tools": [noop]}], |
735 | 753 | )
|
736 |
| - def test_count_tokens_smoke(self, contents): |
| 754 | + def test_count_tokens_smoke(self, kwargs): |
| 755 | + si = kwargs.pop("system_instruction", None) |
737 | 756 | self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
|
738 |
| - model = generative_models.GenerativeModel("gemini-pro-vision") |
739 |
| - response = model.count_tokens(contents) |
| 757 | + model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) |
| 758 | + response = model.count_tokens(**kwargs) |
740 | 759 | self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})
|
741 | 760 |
|
742 | 761 | @parameterized.named_parameters(
|
|
0 commit comments