Skip to content

Explicit Caching patch #377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5824d8e
Squashed commit of the following:
mayureshagashe2105 Jun 5, 2024
3d8ff35
Remove auto cache deletion
mayureshagashe2105 Jun 5, 2024
4c5a31b
Rename _to_dict --> _get_update_fields
mayureshagashe2105 Jun 5, 2024
b6d440f
Fix tests
mayureshagashe2105 Jun 5, 2024
e53dbab
Set 'CachedContent' as a public property
mayureshagashe2105 Jun 5, 2024
f53a7c6
blacken
mayureshagashe2105 Jun 5, 2024
02cba55
set 'role=user' when content is passed as a str (#4)
mayureshagashe2105 Jun 5, 2024
4c495ef
Handle ttl and expire_time separately
mayureshagashe2105 Jun 6, 2024
0f5f8eb
Remove name param
mayureshagashe2105 Jun 6, 2024
cef3fc7
Update caching_types.py
MarkDaoust Jun 6, 2024
f03a765
Update caching.py
MarkDaoust Jun 6, 2024
42b1e35
Update docstrs and error messages
mayureshagashe2105 Jun 7, 2024
e4648a7
Update model name to gemini-1.5-pro for caching tests
mayureshagashe2105 Jun 7, 2024
f2b495f
Merge branch 'magashe-caching-patch-1' of https://github.com/mayuresh…
mayureshagashe2105 Jun 7, 2024
f715ecb
Remove default ttl assignment
mayureshagashe2105 Jun 7, 2024
a576166
blacken
mayureshagashe2105 Jun 7, 2024
5e9b14b
Remove client arg
mayureshagashe2105 Jun 10, 2024
6ccee3e
Add 'usage_metadata' param to CachedContent class
mayureshagashe2105 Jun 11, 2024
3de6909
Add 'display_name' to CachedContent class
mayureshagashe2105 Jun 11, 2024
7fccb32
update generativelanguage version, fix tests
MarkDaoust Jun 11, 2024
2fabe67
format
MarkDaoust Jun 11, 2024
7d14bb1
fewer automatic 'role' insertions
MarkDaoust Jun 11, 2024
3982b48
cleanup
MarkDaoust Jun 11, 2024
940834a
Wrap the proto
MarkDaoust Jun 12, 2024
4a5229e
Apply suggestions from code review
MarkDaoust Jun 12, 2024
c039644
fix
MarkDaoust Jun 12, 2024
fc767a1
format
MarkDaoust Jun 12, 2024
75cc224
cleanup
MarkDaoust Jun 12, 2024
19f0384
update version
MarkDaoust Jun 12, 2024
d438860
fix
MarkDaoust Jun 12, 2024
aa12c3d
typing
MarkDaoust Jun 12, 2024
cc54a87
Merge branch 'main' into magashe-caching-patch-1
MarkDaoust Jun 13, 2024
a6f4355
Simplify update method
mayureshagashe2105 Jun 13, 2024
1c77da4
Add repr to CachedContent
mayureshagashe2105 Jun 13, 2024
0bac36e
cleanup
mayureshagashe2105 Jun 13, 2024
25f4d10
blacken
mayureshagashe2105 Jun 13, 2024
9b48863
Apply suggestions from code review
mayureshagashe2105 Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 160 additions & 106 deletions google/generativeai/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,90 +14,130 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
import datetime
from typing import Any, Iterable, Optional
import textwrap
from typing import Iterable, Optional

from google.generativeai import protos
from google.generativeai.types.model_types import idecode_time
from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai.utils import flatten_update_paths
from google.generativeai.client import get_default_cache_client

from google.protobuf import field_mask_pb2
import google.ai.generativelanguage as glm

_USER_ROLE = "user"
_MODEL_ROLE = "model"


@dataclasses.dataclass
class CachedContent:
"""Cached content resource."""

name: str
model: str
create_time: datetime.datetime
update_time: datetime.datetime
expire_time: datetime.datetime
def __init__(self, name):
"""Fetches a `CachedContent` resource.

# NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+).
# Adding basic support for now.
def __enter__(self):
return self
Identical to `CachedContent.get`.

def __exit__(self, exc_type, exc_value, exc_tb):
self.delete()

def _to_dict(self) -> protos.CachedContent:
proto_paths = {
"name": self.name,
"model": self.model,
}
return protos.CachedContent(**proto_paths)

def _apply_update(self, path, value):
parts = path.split(".")
for part in parts[:-1]:
self = getattr(self, part)
if parts[-1] == "ttl":
value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
parts[-1] = "expire_time"
setattr(self, parts[-1], value)
Args:
name: The resource name referring to the cached content.
"""
client = get_default_cache_client()

@classmethod
def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
# not supposed to get INPUT_ONLY repeated fields, but local gapic lib build
# is returning these, hence setting including_default_value_fields to False
cached_content = type(cached_content).to_dict(
cached_content, including_default_value_fields=False
if "cachedContents/" not in name:
name = "cachedContents/" + name

request = protos.GetCachedContentRequest(name=name)
response = client.get_cached_content(request)
self._proto = response

# Accessor for the `name` field of the wrapped `self._proto` message.
@property
def name(self) -> str:
return self._proto.name

# Accessor for the `model` field of the wrapped `self._proto` message.
@property
def model(self) -> str:
return self._proto.model

# Accessor for the `display_name` field of the wrapped `self._proto` message.
@property
def display_name(self) -> str:
return self._proto.display_name

# Accessor for the `usage_metadata` field of the wrapped `self._proto` message.
# The value is returned as stored on the proto; nothing is copied or converted here.
@property
def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
return self._proto.usage_metadata

# Accessor for the `create_time` field of the wrapped `self._proto` message.
# The value is returned as stored on the proto; no conversion is done here.
@property
def create_time(self) -> datetime.datetime:
return self._proto.create_time

# Accessor for the `update_time` field of the wrapped `self._proto` message.
# The value is returned as stored on the proto; no conversion is done here.
@property
def update_time(self) -> datetime.datetime:
return self._proto.update_time

# Accessor for the `expire_time` field of the wrapped `self._proto` message.
# The value is returned as stored on the proto; no conversion is done here.
@property
def expire_time(self) -> datetime.datetime:
return self._proto.expire_time

# Builds a multi-line, constructor-style text summary of this object from its
# `name`, `model`, `display_name`, `usage_metadata`, `create_time`,
# `update_time` and `expire_time` attributes. Of `usage_metadata`, only
# `total_token_count` is shown.
def __str__(self):
# `{'{'}` and `{'}'}` emit literal braces inside the f-string; the result is
# passed through `textwrap.dedent` before being returned.
return textwrap.dedent(
f"""\
CachedContent(
name='{self.name}',
model='{self.model}',
display_name='{self.display_name}',
usage_metadata={'{'}
'total_token_count': {self.usage_metadata.total_token_count},
{'}'},
create_time={self.create_time},
update_time={self.update_time},
expire_time={self.expire_time}
)"""
)

idecode_time(cached_content, "create_time")
idecode_time(cached_content, "update_time")
# always decode `expire_time` as Timestamp is returned
# regardless of what was sent on input
idecode_time(cached_content, "expire_time")
return cls(**cached_content)
__repr__ = __str__

@classmethod
def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
"""Creates an instance of CachedContent from an object, without calling `get`."""
# `cls.__new__` creates the instance without running `__init__`, so no
# `get_cached_content` request is issued.
self = cls.__new__(cls)
# Start from an empty proto, then copy the fields of `obj` onto it.
self._proto = protos.CachedContent()
self._update(obj)
return self

def _update(self, updates):
"""Updates this instance inplace, does not call the API's `update` method"""
# `updates` may be another `CachedContent`, a dict of field values, or any
# other object whose type provides `to_dict` (assumed to be a proto message).
if isinstance(updates, CachedContent):
# Unwrap to the underlying proto so it goes through the same path below.
updates = updates._proto

if not isinstance(updates, dict):
# NOTE(review): `including_default_value_fields=False` presumably leaves
# unset/default-valued fields out of the dict so they are not copied;
# confirm against the proto-plus `to_dict` documentation.
updates = type(updates).to_dict(updates, including_default_value_fields=False)

# Copy every remaining key/value onto the wrapped proto, field by field.
for key, value in updates.items():
setattr(self._proto, key, value)

@staticmethod
def _prepare_create_request(
model: str,
name: str | None = None,
*,
display_name: str | None = None,
system_instruction: Optional[content_types.ContentType] = None,
contents: Optional[content_types.ContentsType] = None,
tools: Optional[content_types.FunctionLibraryType] = None,
tool_config: Optional[content_types.ToolConfigType] = None,
ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
ttl: Optional[caching_types.TTLTypes] = None,
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
) -> protos.CreateCachedContentRequest:
"""Prepares a CreateCachedContentRequest."""
if name is not None:
if not caching_types.valid_cached_content_name(name):
raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))

name = "cachedContents/" + name
if ttl and expire_time:
raise ValueError(
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
)

if "/" not in model:
model = "models/" + model

if display_name and len(display_name) > 128:
raise ValueError("`display_name` must be no more than 128 unicode characters.")

if system_instruction:
system_instruction = content_types.to_content(system_instruction)

Expand All @@ -110,18 +150,21 @@ def _prepare_create_request(

if contents:
contents = content_types.to_contents(contents)
if not contents[-1].role:
contents[-1].role = _USER_ROLE

if ttl:
ttl = caching_types.to_ttl(ttl)
ttl = caching_types.to_optional_ttl(ttl)
expire_time = caching_types.to_optional_expire_time(expire_time)

cached_content = protos.CachedContent(
name=name,
model=model,
display_name=display_name,
system_instruction=system_instruction,
contents=contents,
tools=tools_lib,
tool_config=tool_config,
ttl=ttl,
expire_time=expire_time,
)

return protos.CreateCachedContentRequest(cached_content=cached_content)
Expand All @@ -130,48 +173,55 @@ def _prepare_create_request(
def create(
cls,
model: str,
name: str | None = None,
*,
display_name: str | None = None,
system_instruction: Optional[content_types.ContentType] = None,
contents: Optional[content_types.ContentsType] = None,
tools: Optional[content_types.FunctionLibraryType] = None,
tool_config: Optional[content_types.ToolConfigType] = None,
ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
client: glm.CacheServiceClient | None = None,
ttl: Optional[caching_types.TTLTypes] = None,
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
) -> CachedContent:
"""Creates `CachedContent` resource.

Args:
model: The name of the `model` to use for cached content creation.
Any `CachedContent` resource can be only used with the
`model` it was created for.
name: The resource name referring to the cached content.
display_name: The user-generated meaningful display name
of the cached content. `display_name` must be no
more than 128 unicode characters.
system_instruction: Developer set system instruction.
contents: Contents to cache.
tools: A list of `Tools` the model may use to generate response.
tool_config: Config to apply to all tools.
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
`ttl` and `expire_time` are exclusive arguments.
expire_time: Expiration time for cached resource.
`ttl` and `expire_time` are exclusive arguments.

Returns:
`CachedContent` resource with specified name.
"""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

request = cls._prepare_create_request(
model=model,
name=name,
display_name=display_name,
system_instruction=system_instruction,
contents=contents,
tools=tools,
tool_config=tool_config,
ttl=ttl,
expire_time=expire_time,
)

response = client.create_cached_content(request)
return cls._decode_cached_content(response)
result = CachedContent._from_obj(response)
return result

@classmethod
def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
def get(cls, name: str) -> CachedContent:
"""Fetches required `CachedContent` resource.

Args:
Expand All @@ -180,20 +230,18 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC
Returns:
`CachedContent` resource with specified `name`.
"""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

if "cachedContents/" not in name:
name = "cachedContents/" + name

request = protos.GetCachedContentRequest(name=name)
response = client.get_cached_content(request)
return cls._decode_cached_content(response)
result = CachedContent._from_obj(response)
return result

@classmethod
def list(
cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
) -> Iterable[CachedContent]:
def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
"""Lists `CachedContent` objects associated with the project.

Args:
Expand All @@ -203,58 +251,64 @@ def list(
Returns:
A paginated list of `CachedContent` objects.
"""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

request = protos.ListCachedContentsRequest(page_size=page_size)
for cached_content in client.list_cached_contents(request):
yield cls._decode_cached_content(cached_content)
cached_content = CachedContent._from_obj(cached_content)
yield cached_content

def delete(self, client: glm.CachedServiceClient | None = None) -> None:
def delete(self) -> None:
"""Deletes `CachedContent` resource."""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

request = protos.DeleteCachedContentRequest(name=self.name)
client.delete_cached_content(request)
return

def update(
self,
updates: dict[str, Any],
client: glm.CacheServiceClient | None = None,
) -> CachedContent:
*,
ttl: Optional[caching_types.TTLTypes] = None,
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
) -> None:
"""Updates requested `CachedContent` resource.

Args:
updates: The list of fields to update. Currently only
`ttl/expire_time` is supported as an update path.

Returns:
`CachedContent` object with specified updates.
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
`ttl` and `expire_time` are exclusive arguments.
expire_time: Expiration time for cached resource.
`ttl` and `expire_time` are exclusive arguments.
"""
if client is None:
client = get_default_cache_client()

updates = flatten_update_paths(updates)
for update_path in updates:
if update_path == "ttl":
updates = updates.copy()
update_path_val = updates.get(update_path)
updates[update_path] = caching_types.to_ttl(update_path_val)
else:
raise ValueError(
f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
)
field_mask = field_mask_pb2.FieldMask()
client = get_default_cache_client()

for path in updates.keys():
field_mask.paths.append(path)
for path, value in updates.items():
self._apply_update(path, value)
if ttl and expire_time:
raise ValueError(
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
)

request = protos.UpdateCachedContentRequest(
cached_content=self._to_dict(), update_mask=field_mask
ttl = caching_types.to_optional_ttl(ttl)
expire_time = caching_types.to_optional_expire_time(expire_time)

updates = protos.CachedContent(
name=self.name,
ttl=ttl,
expire_time=expire_time,
)
client.update_cached_content(request)
return self

field_mask = field_mask_pb2.FieldMask()

if ttl:
field_mask.paths.append("ttl")
elif expire_time:
field_mask.paths.append("expire_time")
else:
raise ValueError(
f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
)

request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
updated_cc = client.update_cached_content(request)
self._update(updated_cc)

return
Loading
Loading