Commit cfc936e

Strike out functional approach for CachedContent CRUD ops
1 parent afd066d commit cfc936e

5 files changed (+246 -274 lines)


google/generativeai/caching.py

+228 -40
@@ -14,63 +14,251 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Optional, Iterable
-
-import google.ai.generativelanguage as glm
+import dataclasses
+import datetime
+from typing import Any, Iterable, Optional
 
+from google.generativeai.types.model_types import idecode_time
 from google.generativeai.types import caching_types
 from google.generativeai.types import content_types
+from google.generativeai.utils import flatten_update_paths
 from google.generativeai.client import get_default_cache_client
 
+from google.protobuf import field_mask_pb2
+import google.ai.generativelanguage as glm
+
+
+@dataclasses.dataclass
+class CachedContent:
+    """Cached content resource."""
+
+    name: str
+    model: str
+    create_time: datetime.datetime
+    update_time: datetime.datetime
+    expire_time: datetime.datetime
+
+    # NOTE: Automatic CachedContent deletion using a context manager is not P0 (P1+).
+    # Adding basic support for now.
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.delete()
+
+    def _to_dict(self) -> glm.CachedContent:
+        proto_paths = {
+            "name": self.name,
+            "model": self.model,
+        }
+        return glm.CachedContent(**proto_paths)
+
+    def _apply_update(self, path, value):
+        parts = path.split(".")
+        for part in parts[:-1]:
+            self = getattr(self, part)
+        if parts[-1] == "ttl":
+            value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
+            parts[-1] = "expire_time"
+        setattr(self, parts[-1], value)
+
+    @classmethod
+    def _decode_cached_content(cls, cached_content: glm.CachedContent) -> CachedContent:
+        # Not supposed to get INPUT_ONLY repeated fields, but the local gapic lib build
+        # is returning them, hence setting `including_default_value_fields` to False.
+        cached_content = type(cached_content).to_dict(
+            cached_content, including_default_value_fields=False
+        )
+
+        idecode_time(cached_content, "create_time")
+        idecode_time(cached_content, "update_time")
+        # Always decode `expire_time`, as a Timestamp is returned
+        # regardless of what was sent on input.
+        idecode_time(cached_content, "expire_time")
+        return cls(**cached_content)
+
+    @staticmethod
+    def _prepare_create_request(
+        name: str,
+        model: str,
+        system_instruction: Optional[content_types.ContentType] = None,
+        contents: Optional[content_types.ContentsType] = None,
+        tools: Optional[content_types.FunctionLibraryType] = None,
+        tool_config: Optional[content_types.ToolConfigType] = None,
+        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
+    ) -> glm.CreateCachedContentRequest:
+        """Prepares a CreateCachedContentRequest."""
+        if "cachedContents/" not in name:
+            name = "cachedContents/" + name
+
+        if "/" not in model:
+            model = "models/" + model
+
+        if system_instruction:
+            system_instruction = content_types.to_content(system_instruction)
+
+        tools_lib = content_types.to_function_library(tools)
+        if tools_lib:
+            tools_lib = tools_lib.to_proto()
+
+        if tool_config:
+            tool_config = content_types.to_tool_config(tool_config)
+
+        if contents:
+            contents = content_types.to_contents(contents)
+
+        if ttl:
+            ttl = caching_types.to_ttl(ttl)
+
+        cached_content = glm.CachedContent(
+            name=name,
+            model=model,
+            system_instruction=system_instruction,
+            contents=contents,
+            tools=tools_lib,
+            tool_config=tool_config,
+            ttl=ttl,
+        )
+
+        return glm.CreateCachedContentRequest(cached_content=cached_content)
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        model: str,
+        system_instruction: Optional[content_types.ContentType] = None,
+        contents: Optional[content_types.ContentsType] = None,
+        tools: Optional[content_types.FunctionLibraryType] = None,
+        tool_config: Optional[content_types.ToolConfigType] = None,
+        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
+        client: glm.CacheServiceClient | None = None,
+    ) -> CachedContent:
+        """Creates a CachedContent resource.
+
+        Args:
+            name: The resource name referring to the cached content.
+                Format: cachedContents/{id}.
+            model: The name of the `Model` to use for the cached content.
+                Format: models/{model}. A cached content resource can only be
+                used with the model it was created for.
+            system_instruction: Developer-set system instruction.
+            contents: Contents to cache.
+            tools: A list of `Tools` the model may use to generate a response.
+            tool_config: Config to apply to all tools.
+            ttl: TTL for the cached resource (in seconds). Defaults to 1 hour.
+
+        Returns:
+            `CachedContent` resource with the specified name.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        request = cls._prepare_create_request(
+            name=name,
+            model=model,
+            system_instruction=system_instruction,
+            contents=contents,
+            tools=tools,
+            tool_config=tool_config,
+            ttl=ttl,
+        )
+
+        response = client.create_cached_content(request)
+        return cls._decode_cached_content(response)
+
+    @classmethod
+    def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
+        """Fetches the required `CachedContent` resource.
+
+        Args:
+            name: The resource name referring to the cached content.
+
+        Returns:
+            `CachedContent` resource with the specified name.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        if "cachedContents/" not in name:
+            name = "cachedContents/" + name
+
+        request = glm.GetCachedContentRequest(name=name)
+        response = client.get_cached_content(request)
+        return cls._decode_cached_content(response)
+
+    @classmethod
+    def list(
+        cls,
+        page_size: Optional[int] = 1,
+        client: glm.CacheServiceClient | None = None,
+    ) -> Iterable[CachedContent]:
+        """Lists `CachedContent` objects associated with the project.
+
+        Args:
+            page_size: The maximum number of permissions to return (per page).
+                The service may return fewer `CachedContent` objects.
+
+        Returns:
+            A paginated list of `CachedContent` objects.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        request = glm.ListCachedContentsRequest(page_size=page_size)
+        for cached_content in client.list_cached_contents(request):
+            yield cls._decode_cached_content(cached_content)
+
+    def delete(self, client: glm.CacheServiceClient | None = None) -> None:
+        """Deletes the `CachedContent` resource."""
+        if client is None:
+            client = get_default_cache_client()
+
+        request = glm.DeleteCachedContentRequest(name=self.name)
+        client.delete_cached_content(request)
+        return
+
+    def update(
+        self,
+        updates: dict[str, Any],
+        client: glm.CacheServiceClient | None = None,
+    ) -> CachedContent:
+        """Updates the requested `CachedContent` resource.
+
+        Args:
+            updates: The list of fields to update.
+                Currently only `ttl`/`expire_time` is supported as an update path.
+
+        Returns:
+            `CachedContent` object with the specified updates.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        updates = flatten_update_paths(updates)
+        for update_path in updates:
+            if update_path == "ttl":
+                updates = updates.copy()
+                update_path_val = updates.get(update_path)
+                updates[update_path] = caching_types.to_ttl(update_path_val)
+            else:
+                raise ValueError(
+                    f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
+                )
+        field_mask = field_mask_pb2.FieldMask()
+
+        for path in updates.keys():
+            field_mask.paths.append(path)
+        for path, value in updates.items():
+            self._apply_update(path, value)
+
+        request = glm.UpdateCachedContentRequest(
+            cached_content=self._to_dict(), update_mask=field_mask
+        )
+        client.update_cached_content(request)
+        return self
-
-# alias for `caching_types.CachedContent`.
-CachedContent = caching_types.CachedContent
-
-
-def get_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
-    """Fetches required `CachedContent` resource.
-
-    Args:
-        name: The resource name referring to the cached content.
-
-    Returns:
-        `CachedContent` resource with specified name.
-    """
-    return CachedContent.get(name=name, client=client)
-
-
-def delete_cached_content(name: str, client: glm.CacheServiceClient | None = None) -> None:
-    """Deletes `CachedContent` resource.
-
-    Args:
-        name: The resource name referring to the cached content.
-            Format: cachedContents/{id}.
-    """
-    if client is None:
-        client = get_default_cache_client()
-
-    if "cachedContents/" not in name:
-        name = "cachedContents/" + name
-
-    request = glm.DeleteCachedContentRequest(name=name)
-    client.delete_cached_content(request)
-    return
-
-
-def list_cached_contents(
-    page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
-) -> Iterable[CachedContent]:
-    """Lists `CachedContent` objects associated with the project.
-
-    Args:
-        page_size: The maximum number of permissions to return (per page).
-            The service may return fewer `CachedContent` objects.
-
-    Returns:
-        A paginated list of `CachedContent` objects.
-    """
-    if client is None:
-        client = get_default_cache_client()
-
-    request = glm.ListCachedContentsRequest(page_size=page_size)
-    for cached_content in client.list_cached_contents(request):
-        yield caching_types.decode_cached_content(cached_content)
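
Taken together, the new class bundles the CRUD operations that were previously exposed as module-level functions. A minimal usage sketch of the API as it appears at this commit, assuming the SDK has already been configured with an API key; the cache name, model name, and contents below are placeholders, not values from this commit, and signatures may differ in later releases:

import datetime

import google.generativeai as genai
from google.generativeai import caching

genai.configure(api_key="...")  # assumed prior setup; replace with your own key handling

# Create a cache; `ttl` accepts a datetime.timedelta and defaults to 1 hour.
cache = caching.CachedContent.create(
    name="example-cache",                # placeholder resource id
    model="models/gemini-1.5-pro-001",   # placeholder model name
    contents=["<large document to reuse across requests>"],
)

# Fetch and list existing caches.
same_cache = caching.CachedContent.get(name="example-cache")
for cc in caching.CachedContent.list(page_size=10):
    print(cc.name, cc.expire_time)

# Only `ttl` is a valid update path at this commit; anything else raises ValueError.
cache.update(updates={"ttl": datetime.timedelta(hours=2)})

# Delete explicitly, or rely on the context-manager support added above,
# which calls delete() on exit.
cache.delete()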

google/generativeai/generative_models.py

+3 -5
@@ -19,8 +19,6 @@
 from google.generativeai.types import generation_types
 from google.generativeai.types import helper_types
 from google.generativeai.types import safety_types
-from google.generativeai.types import caching_types
-
 
 class GenerativeModel:
     """
@@ -198,15 +196,15 @@ def from_cached_content(
     @classmethod
     def from_cached_content(
         cls,
-        cached_content: caching_types.CachedContent,
+        cached_content: caching.CachedContent,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
     ) -> GenerativeModel: ...
 
     @classmethod
     def from_cached_content(
         cls,
-        cached_content: str | caching_types.CachedContent,
+        cached_content: str | caching.CachedContent,
         generation_config: generation_types.GenerationConfigType | None = None,
         safety_settings: safety_types.SafetySettingOptions | None = None,
     ) -> GenerativeModel:
@@ -219,7 +217,7 @@ def from_cached_content(
             `GenerativeModel` object with `cached_content` as its context.
         """
         if isinstance(cached_content, str):
-            cached_content = caching.get_cached_content(name=cached_content)
+            cached_content = caching.CachedContent.get(name=cached_content)
 
         # call __new__ with the cached_content to set the model's context. This is done to avoid
         # exposing `cached_content` as a public attribute.
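
A corresponding sketch of how the updated `from_cached_content` overloads are used after this change, again with placeholder names and assuming `genai.configure(...)` has already been called:

import google.generativeai as genai
from google.generativeai import caching

# Either a CachedContent object or its resource name is accepted; a plain string
# is now resolved via caching.CachedContent.get(), per the diff above.
cache = caching.CachedContent.get(name="example-cache")  # placeholder name
model = genai.GenerativeModel.from_cached_content(cached_content=cache)
model_by_name = genai.GenerativeModel.from_cached_content(cached_content="example-cache")

response = model.generate_content("Summarize the cached document.")
print(response.text)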
