From 5412d24c457768a731aa907a39f93ede35b3827d Mon Sep 17 00:00:00 2001
From: Mayuresh Agashe
Date: Wed, 5 Jun 2024 05:51:09 +0530
Subject: [PATCH 1/7] Explicit Caching (#355)

* *Initial prototype for explicit caching
*Add basic CRUD support for caching
*Remove INPUT_ONLY marked fields from CachedContent dataclass
*Rename files 'cached_content*' -> 'caching*'
*Update 'Create' method for explicit instantiation of 'CachedContent'
*Add a factory method to instantiate model with `CachedContent` as its context
*blacken
*Add tests

Change-Id: I694545243efda467d6fd599beded0dc6679b727d

* rename get_cached_content to get

* Strike out functional approach for CachedContent CRUD ops

* blacken

* Improve tests

* fix tests

Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7

* fix tests

Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7

* Validate name checks for CachedContent creation

Change-Id: Ie41602621d99ddff6404c6708c7278e0da790652

* Add tests

Change-Id: I249188fa585bd9b7193efa48b1cfca20b8a79821

* mark name as OPTIONAL for CachedContent creation

If not provided, the name will be randomly generated

Change-Id: Ib95fbafd3dfe098b43164d7ee4d6c2a84b0aae2e

* Add type-annotations to __new__ to fix pytype checks

Change-Id: I6c69c036e54d56d18ea60368fa0a1dcda2d315fd

* Add 'cached_content' to GenerativeModel's repr

Change-Id: I06676fad23895e3e1a6393baa938fc1f2df57d80

* blacken

Change-Id: I4e073d821d29eea30801bdb7e2a8dc01bb7d6b9a

* Fix types

Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0

* Fix docstrings

Change-Id: I6020df4e862a4f1d58462a4cd70876a8448293cf

* Fix types

Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075

* Fix types

Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075

* Refactor for genai.protos module

Change-Id: I2f02d2421d7303f0309ec86f05d33c07332c03c1

* use preview build

Change-Id: Ic1cd4fc28f591794dc5fbff0647a00a77ea7f601

---------

Co-authored-by: Mark Daoust
---
 google/generativeai/caching.py             | 260 +++++++++++++++++++++
 google/generativeai/client.py              |   4 +
 google/generativeai/generative_models.py   |  71 +++++-
 google/generativeai/types/caching_types.py |  53 +++++
 setup.py                                   |   2 +-
 tests/test_caching.py                      | 246 +++++++++++++++++++
 tests/test_generative_models.py            |  79 ++++++-
 7 files changed, 712 insertions(+), 3 deletions(-)
 create mode 100644 google/generativeai/caching.py
 create mode 100644 google/generativeai/types/caching_types.py
 create mode 100644 tests/test_caching.py

diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py
new file mode 100644
index 000000000..a28a50256
--- /dev/null
+++ b/google/generativeai/caching.py
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
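For orientation before the module body below: a minimal sketch of how the CRUD surface introduced by this patch is meant to be driven. This is a sketch only; it assumes `genai.configure(api_key=...)` has already been called, and the model name, cache name, and contents are placeholder values borrowed from the tests later in this patch.

    import datetime
    from google.generativeai import caching

    # Create a cache tied to one specific model; `name` is optional and
    # is generated by the service when omitted.
    cc = caching.CachedContent.create(
        model="models/gemini-1.0-pro-001",
        name="test-cached-content",
        contents=["Add 5 and 6"],
        ttl=datetime.timedelta(minutes=30),
    )

    fetched = caching.CachedContent.get(name=cc.name)      # read
    for item in caching.CachedContent.list(page_size=2):   # list
        print(item.name)
    cc = cc.update(updates={"ttl": datetime.timedelta(hours=2)})  # only `ttl` is updatable
    cc.delete()                                            # delete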
+from __future__ import annotations + +import dataclasses +import datetime +from typing import Any, Iterable, Optional + +from google.generativeai import protos +from google.generativeai.types.model_types import idecode_time +from google.generativeai.types import caching_types +from google.generativeai.types import content_types +from google.generativeai.utils import flatten_update_paths +from google.generativeai.client import get_default_cache_client + +from google.protobuf import field_mask_pb2 +import google.ai.generativelanguage as glm + + +@dataclasses.dataclass +class CachedContent: + """Cached content resource.""" + + name: str + model: str + create_time: datetime.datetime + update_time: datetime.datetime + expire_time: datetime.datetime + + # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+). + # Adding basic support for now. + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + self.delete() + + def _to_dict(self) -> protos.CachedContent: + proto_paths = { + "name": self.name, + "model": self.model, + } + return protos.CachedContent(**proto_paths) + + def _apply_update(self, path, value): + parts = path.split(".") + for part in parts[:-1]: + self = getattr(self, part) + if parts[-1] == "ttl": + value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) + parts[-1] = "expire_time" + setattr(self, parts[-1], value) + + @classmethod + def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent: + # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build + # is returning these, hence setting including_default_value_fields to False + cached_content = type(cached_content).to_dict( + cached_content, including_default_value_fields=False + ) + + idecode_time(cached_content, "create_time") + idecode_time(cached_content, "update_time") + # always decode `expire_time` as Timestamp is returned + # regardless of what was sent on input + idecode_time(cached_content, "expire_time") + return cls(**cached_content) + + @staticmethod + def _prepare_create_request( + model: str, + name: str | None = None, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), + ) -> protos.CreateCachedContentRequest: + """Prepares a CreateCachedContentRequest.""" + if name is not None: + if not caching_types.valid_cached_content_name(name): + raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name)) + + name = "cachedContents/" + name + + if "/" not in model: + model = "models/" + model + + if system_instruction: + system_instruction = content_types.to_content(system_instruction) + + tools_lib = content_types.to_function_library(tools) + if tools_lib: + tools_lib = tools_lib.to_proto() + + if tool_config: + tool_config = content_types.to_tool_config(tool_config) + + if contents: + contents = content_types.to_contents(contents) + + if ttl: + ttl = caching_types.to_ttl(ttl) + + cached_content = protos.CachedContent( + name=name, + model=model, + system_instruction=system_instruction, + contents=contents, + tools=tools_lib, + tool_config=tool_config, + ttl=ttl, + ) + + return protos.CreateCachedContentRequest(cached_content=cached_content) + + @classmethod + def create( + cls, + model: str, + name: str | None = None, + 
system_instruction: Optional[content_types.ContentType] = None,
+        contents: Optional[content_types.ContentsType] = None,
+        tools: Optional[content_types.FunctionLibraryType] = None,
+        tool_config: Optional[content_types.ToolConfigType] = None,
+        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
+        client: glm.CacheServiceClient | None = None,
+    ) -> CachedContent:
+        """Creates a `CachedContent` resource.
+
+        Args:
+            model: The name of the `model` to use for cached content creation.
+                Any `CachedContent` resource can only be used with the
+                `model` it was created for.
+            name: The resource name referring to the cached content.
+            system_instruction: Developer set system instruction.
+            contents: Contents to cache.
+            tools: A list of `Tools` the model may use to generate a response.
+            tool_config: Config to apply to all tools.
+            ttl: TTL for the cached resource (in seconds). Defaults to 1 hour.
+
+        Returns:
+            `CachedContent` resource with the specified name.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        request = cls._prepare_create_request(
+            model=model,
+            name=name,
+            system_instruction=system_instruction,
+            contents=contents,
+            tools=tools,
+            tool_config=tool_config,
+            ttl=ttl,
+        )
+
+        response = client.create_cached_content(request)
+        return cls._decode_cached_content(response)
+
+    @classmethod
+    def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
+        """Fetches the requested `CachedContent` resource.
+
+        Args:
+            name: The resource name referring to the cached content.
+
+        Returns:
+            `CachedContent` resource with the specified `name`.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        if "cachedContents/" not in name:
+            name = "cachedContents/" + name
+
+        request = protos.GetCachedContentRequest(name=name)
+        response = client.get_cached_content(request)
+        return cls._decode_cached_content(response)
+
+    @classmethod
+    def list(
+        cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
+    ) -> Iterable[CachedContent]:
+        """Lists `CachedContent` objects associated with the project.
+
+        Args:
+            page_size: The maximum number of `CachedContent` objects to return (per page).
+                The service may return fewer `CachedContent` objects.
+
+        Returns:
+            A paginated list of `CachedContent` objects.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        request = protos.ListCachedContentsRequest(page_size=page_size)
+        for cached_content in client.list_cached_contents(request):
+            yield cls._decode_cached_content(cached_content)
+
+    def delete(self, client: glm.CacheServiceClient | None = None) -> None:
+        """Deletes a `CachedContent` resource."""
+        if client is None:
+            client = get_default_cache_client()
+
+        request = protos.DeleteCachedContentRequest(name=self.name)
+        client.delete_cached_content(request)
+        return
+
+    def update(
+        self,
+        updates: dict[str, Any],
+        client: glm.CacheServiceClient | None = None,
+    ) -> CachedContent:
+        """Updates the requested `CachedContent` resource.
+
+        Args:
+            updates: The list of fields to update. Currently only
+                `ttl/expire_time` is supported as an update path.
+
+        Returns:
+            `CachedContent` object with specified updates.
+        """
+        if client is None:
+            client = get_default_cache_client()
+
+        updates = flatten_update_paths(updates)
+        for update_path in updates:
+            if update_path == "ttl":
+                updates = updates.copy()
+                update_path_val = updates.get(update_path)
+                updates[update_path] = caching_types.to_ttl(update_path_val)
+            else:
+                raise ValueError(
+                    f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
+                )
+        field_mask = field_mask_pb2.FieldMask()
+
+        for path in updates.keys():
+            field_mask.paths.append(path)
+        for path, value in updates.items():
+            self._apply_update(path, value)
+
+        request = protos.UpdateCachedContentRequest(
+            cached_content=self._to_dict(), update_mask=field_mask
+        )
+        client.update_cached_content(request)
+        return self
diff --git a/google/generativeai/client.py b/google/generativeai/client.py
index 40c2bdcaf..7012ecc7c 100644
--- a/google/generativeai/client.py
+++ b/google/generativeai/client.py
@@ -315,6 +315,10 @@ def configure(
     _client_manager.configure()
 
 
+def get_default_cache_client() -> glm.CacheServiceClient:
+    return _client_manager.get_default_client("cache")
+
+
 def get_default_discuss_client() -> glm.DiscussServiceClient:
     return _client_manager.get_default_client("discuss")
 
diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index 7d69ae8f9..10744a948 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -4,7 +4,7 @@
 from collections.abc import Iterable
 import textwrap
-from typing import Any
+from typing import Any, Union, overload
 import reprlib
 
 # pylint: disable=bad-continuation, line-too-long
@@ -13,6 +13,8 @@
 import google.api_core.exceptions
 from google.generativeai import protos
 from google.generativeai import client
+
+from google.generativeai import caching
 from google.generativeai.types import content_types
 from google.generativeai.types import generation_types
 from google.generativeai.types import helper_types
@@ -94,6 +96,15 @@ def __init__(
         self._client = None
         self._async_client = None
 
+    def __new__(cls, *args, **kwargs) -> GenerativeModel:
+        self = super().__new__(cls)
+
+        if cached_instance := kwargs.pop("cached_content", None):
+            setattr(self, "_cached_content", cached_instance.name)
+            setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))
+
+        return self
+
     @property
     def model_name(self):
         return self._model_name
@@ -112,6 +123,7 @@ def maybe_text(content):
                 safety_settings={self._safety_settings},
                 tools={self._tools},
                 system_instruction={maybe_text(self._system_instruction)},
+                cached_content={getattr(self, "cached_content", None)}
             )"""
         )
 
@@ -127,6 +139,13 @@ def _prepare_request(
         tool_config: content_types.ToolConfigType | None,
     ) -> protos.GenerateContentRequest:
         """Creates a `protos.GenerateContentRequest` from raw inputs."""
+        if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
+            raise ValueError(
+                "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
+            )
+
+        cached_content = getattr(self, "cached_content", None)
+
         tools_lib = self._get_tools_lib(tools)
         if tools_lib is not None:
             tools_lib = tools_lib.to_proto()
@@ -155,6 +174,7 @@ def _prepare_request(
             tools=tools_lib,
             tool_config=tool_config,
             system_instruction=self._system_instruction,
+            cached_content=cached_content,
         )
 
     def _get_tools_lib(
@@ -165,6 +185,55 @@ def _get_tools_lib(
         else:
             return content_types.to_function_library(tools)
 
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @overload
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: caching.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel: ...
+
+    @classmethod
+    def from_cached_content(
+        cls,
+        cached_content: str | caching.CachedContent,
+        generation_config: generation_types.GenerationConfigType | None = None,
+        safety_settings: safety_types.SafetySettingOptions | None = None,
+    ) -> GenerativeModel:
+        """Creates a model with `cached_content` as the model's context.
+
+        Args:
+            cached_content: context for the model.
+
+        Returns:
+            `GenerativeModel` object with `cached_content` as its context.
+        """
+        if isinstance(cached_content, str):
+            cached_content = caching.CachedContent.get(name=cached_content)
+
+        # call __new__ with the cached_content to set the model's context. This is done to avoid
+        # exposing `cached_content` as a public attribute.
+        self = cls.__new__(cls, cached_content=cached_content)
+
+        # call __init__ to set the model's `generation_config` and `safety_settings`.
+        # `model_name` will be the name of the model for which the `cached_content` was created.
+        self.__init__(
+            model_name=cached_content.model,
+            generation_config=generation_config,
+            safety_settings=safety_settings,
+        )
+        return self
+
     def generate_content(
         self,
         contents: content_types.ContentsType,
diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py
new file mode 100644
index 000000000..8d55b70b2
--- /dev/null
+++ b/google/generativeai/types/caching_types.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import datetime
+from typing import Optional, Union
+from typing_extensions import TypedDict
+import re
+
+__all__ = ["TTL"]
+
+
+_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$"
+NAME_ERROR_MESSAGE = (
+    "The `name` must consist of lowercase alphanumeric characters (or `-` or `.`). Received: `{name}`"
+)
+
+
+def valid_cached_content_name(name: str) -> bool:
+    return re.match(_VALID_CACHED_CONTENT_NAME, name) is not None
+
+
+class TTL(TypedDict):
+    seconds: int
+
+
+ExpirationTypes = Union[TTL, int, datetime.timedelta]
+
+
+def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL:
+    if isinstance(expiration, datetime.timedelta):
+        return {"seconds": int(expiration.total_seconds())}
+    elif isinstance(expiration, dict):
+        return expiration
+    elif isinstance(expiration, int):
+        return {"seconds": expiration}
+    else:
+        raise TypeError(
+            f"Could not convert input to `expire_time`\n  type: {type(expiration)}",
+            expiration,
+        )
diff --git a/setup.py b/setup.py
index 6f9545e4f..0575dcd28 100644
--- a/setup.py
+++ b/setup.py
@@ -42,7 +42,7 @@ def get_version():
 release_status = "Development Status :: 5 - Production/Stable"
 
 dependencies = [
-    "google-ai-generativelanguage==0.6.4",
+    "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz",
     "google-api-core",
     "google-api-python-client",
     "google-auth>=2.15.0",  # 2.15 adds API key auth support
diff --git a/tests/test_caching.py b/tests/test_caching.py
new file mode 100644
index 000000000..47692325b
--- /dev/null
+++ b/tests/test_caching.py
@@ -0,0 +1,246 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
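Before the test body: the `to_ttl` helper in `caching_types.py` above normalizes three equivalent TTL spellings into the `TTL` TypedDict. A small illustrative sketch (the values are arbitrary):

    import datetime
    from google.generativeai.types import caching_types

    # All three forms normalize to {"seconds": 7200}:
    caching_types.to_ttl(7200)
    caching_types.to_ttl(datetime.timedelta(hours=2))
    caching_types.to_ttl({"seconds": 7200})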
+import datetime +import unittest + +from google.generativeai import caching +from google.generativeai import protos + +from google.generativeai import client +from absl.testing import absltest +from absl.testing import parameterized + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client._client_manager.clients["cache"] = self.client + + self.observed_requests = [] + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + @add_client_method + def create_cached_content( + request: protos.CreateCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + + @add_client_method + def get_cached_content( + request: protos.GetCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + + @add_client_method + def list_cached_contents( + request: protos.ListCachedContentsRequest, + **kwargs, + ) -> protos.ListCachedContentsResponse: + self.observed_requests.append(request) + return [ + protos.CachedContent( + name="cachedContents/test-cached-content-1", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + protos.CachedContent( + name="cachedContents/test-cached-content-2", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + ] + + @add_client_method + def update_cached_content( + request: protos.UpdateCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T03:01:01.123456Z", + ) + + @add_client_method + def delete_cached_content( + request: protos.DeleteCachedContentRequest, + **kwargs, + ) -> None: + self.observed_requests.append(request) + + def test_create_cached_content(self): + + def add(a: int, b: int) -> int: + return a + b + + cc = caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["Add 5 and 6"], + tools=[add], + tool_config={"function_calling_config": "ANY"}, + system_instruction="Always add 10 to the result.", + ttl=datetime.timedelta(minutes=30), + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + self.assertEqual(cc.name, "cachedContents/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + + @parameterized.named_parameters( + [ + dict( + testcase_name="ttl-is-int-seconds", + ttl=7200, + ), + dict( + testcase_name="ttl-is-timedelta", + ttl=datetime.timedelta(hours=2), + ), + dict( + 
testcase_name="ttl-is-dict", + ttl={"seconds": 7200}, + ), + dict( + testcase_name="ttl-is-none-default-to-1-hr", + ttl=None, + ), + ] + ) + def test_expiration_types_for_create_cached_content(self, ttl): + cc = caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["cache this please for 2 hours"], + ttl=ttl, + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + @parameterized.named_parameters( + [ + dict( + testcase_name="upper_case", + name="Test-cached-content", + ), + dict( + testcase_name="special_characters_except_dot_and_hyphen", + name="test-cac*@/hed-conte#nt", + ), + dict( + testcase_name="empty_name", + name="", + ), + dict( + testcase_name="blank_spaces", + name="test cached content", + ), + ] + ) + def test_create_cached_content_with_invalid_name_format(self, name): + with self.assertRaises(ValueError): + _ = caching.CachedContent.create( + name=name, + model="models/gemini-1.0-pro-001", + ) + + def test_get_cached_content(self): + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + self.assertIsInstance(self.observed_requests[-1], protos.GetCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + self.assertEqual(cc.name, "cachedContents/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + + def test_list_cached_contents(self): + ccs = list(caching.CachedContent.list(page_size=2)) + self.assertIsInstance(self.observed_requests[-1], protos.ListCachedContentsRequest) + self.assertLen(ccs, 2) + self.assertIsInstance(ccs[0], caching.CachedContent) + self.assertIsInstance(ccs[1], caching.CachedContent) + + def test_update_cached_content_invalid_update_paths(self): + update_masks = dict( + name="change", + model="models/gemini-1.5-pro-001", + system_instruction="Always add 10 to the result.", + contents=["add this Content"], + ) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + with self.assertRaises(ValueError): + cc.update(updates=update_masks) + + def test_update_cached_content_valid_update_paths(self): + update_masks = dict( + ttl=datetime.timedelta(hours=2), + ) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc = cc.update(updates=update_masks) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + def test_delete_cached_content(self): + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + def test_auto_delete_cached_content_with_context_manager(self): + with caching.CachedContent.create( + name="test-cached-content", + model="models/gemini-1.0-pro-001", + contents=["Add 5 and 6"], + system_instruction="Always add 10 to the result.", + ttl=datetime.timedelta(minutes=30), + ) as cc: + ... 
# some logic + + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 0ece77e94..73789346d 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -1,6 +1,7 @@ import collections from collections.abc import Iterable import copy +import datetime import pathlib from typing import Any import textwrap @@ -10,11 +11,11 @@ from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai import generative_models +from google.generativeai import caching from google.generativeai.types import content_types from google.generativeai.types import generation_types from google.generativeai.types import helper_types - import PIL.Image HERE = pathlib.Path(__file__).parent @@ -77,6 +78,20 @@ def count_tokens( response = self.responses["count_tokens"].pop(0) return response + def get_cached_content( + self, + request: protos.GetCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + class CUJTests(parameterized.TestCase): """Tests are in order with the design doc.""" @@ -96,6 +111,7 @@ def responses(self): def setUp(self): self.client = MockGenerativeServiceClient(self) client_lib._client_manager.clients["generative"] = self.client + client_lib._client_manager.clients["cache"] = self.client def test_hello(self): # Generate text from text prompt @@ -317,6 +333,56 @@ def test_stream_prompt_feedback_not_blocked(self): text = "".join(chunk.text for chunk in response) self.assertEqual(text, "first second") + @parameterized.named_parameters( + [ + dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), + dict( + testcase_name="test_cached_content_as_CachedContent_object", + cached_content=caching.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.0-pro-001", + create_time=datetime.datetime.now(), + update_time=datetime.datetime.now(), + expire_time=datetime.datetime.now(), + ), + ), + ], + ) + def test_model_with_cached_content_as_context(self, cached_content): + model = generative_models.GenerativeModel.from_cached_content(cached_content=cached_content) + cc_name = model.cached_content # pytype: disable=attribute-error + model_name = model.model_name + self.assertEqual(cc_name, "cachedContents/test-cached-content") + self.assertEqual(model_name, "models/gemini-1.0-pro-001") + self.assertEqual( + model.cached_content, # pytype: disable=attribute-error + "cachedContents/test-cached-content", + ) + + def test_content_generation_with_model_having_context(self): + self.responses["generate_content"] = [simple_response("world!")] + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + response = model.generate_content("Hello") + + self.assertEqual(response.text, "world!") + self.assertEqual( + model.cached_content, # pytype: disable=attribute-error + "cachedContents/test-cached-content", + ) + + def test_fail_content_generation_with_model_having_context(self): + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" 
+        )
+
+        def add(a: int, b: int) -> int:
+            return a + b
+
+        with self.assertRaises(ValueError):
+            model.generate_content("Hello", tools=[add])
+
     def test_chat(self):
         # Multi turn chat
         model = generative_models.GenerativeModel("gemini-pro")
@@ -1140,6 +1206,7 @@ def test_repr_for_multi_turn_chat(self):
                 safety_settings={},
                 tools=None,
                 system_instruction=None,
+                cached_content=None
             ),
             history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})]
         )"""
@@ -1168,6 +1235,7 @@ def test_repr_for_incomplete_streaming_chat(self):
                 safety_settings={},
                 tools=None,
                 system_instruction=None,
+                cached_content=None
             ),
             history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ]
         )"""
@@ -1212,6 +1280,7 @@ def test_repr_for_broken_streaming_chat(self):
                 safety_settings={},
                 tools=None,
                 system_instruction=None,
+                cached_content=None
             ),
             history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ]
         )"""
@@ -1223,6 +1292,14 @@ def test_repr_for_system_instruction(self):
         result = repr(model)
         self.assertIn("system_instruction='Be excellent.'", result)
 
+    def test_repr_for_model_created_from_cached_content(self):
+        model = generative_models.GenerativeModel.from_cached_content(
+            cached_content="test-cached-content"
+        )
+        result = repr(model)
+        self.assertIn("cached_content=cachedContents/test-cached-content", result)
+        self.assertIn("model_name='models/gemini-1.0-pro-001'", result)
+
     def test_count_tokens_called_with_request_options(self):
         self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7))
         request_options = {"timeout": 120}

From b9d5bc00b67aa5bcbf20850d5105fe486910bf68 Mon Sep 17 00:00:00 2001
From: Ryan Wilson
Date: Tue, 11 Jun 2024 11:43:22 -0400
Subject: [PATCH 2/7] Delete .github/ISSUE_TEMPLATE directory (#389)

This will now use the issue templates at the base of the google-gemini
organization instead.
---
 .github/ISSUE_TEMPLATE/bug_report.yml      | 23 ----------------------
 .github/ISSUE_TEMPLATE/feature_request.yml | 23 ----------------------
 2 files changed, 46 deletions(-)
 delete mode 100644 .github/ISSUE_TEMPLATE/bug_report.yml
 delete mode 100644 .github/ISSUE_TEMPLATE/feature_request.yml

diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
deleted file mode 100644
index 16c34f37a..000000000
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ /dev/null
@@ -1,23 +0,0 @@
-name: Bug report
-description: Use this template to report bugs
-labels: ["type:bug", "component:python sdk"]
-body:
-  - type: markdown
-    attributes:
-      value: >
-        **Note:** If this is a support question (e.g. _How do I do XYZ?_), please visit the [Discourse forum](https://discuss.ai.google.dev/). This is a great place to interact with developers, and to learn, share, and support each other.
- - type: textarea - id: description - attributes: - label: > - Description of the bug: - - type: textarea - id: behavior - attributes: - label: > - Actual vs expected behavior: - - type: textarea - id: info - attributes: - label: > - Any other information you'd like to share? diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml deleted file mode 100644 index 91a380e87..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Feature request -description: Use this template to suggest a new feature -labels: ["type:feature request", "component:python sdk"] -body: - - type: markdown - attributes: - value: > - **Note:** If this is a support question (e.g. _How do I do XYZ?_), please visit the [Discourse forum](https://discuss.ai.google.dev/). This is a great place to interact with developers, and to learn, share, and support each other. - - type: textarea - id: description - attributes: - label: > - Description of the feature request: - - type: textarea - id: behavior - attributes: - label: > - What problem are you trying to solve with this feature? - - type: textarea - id: info - attributes: - label: > - Any other information you'd like to share? From 8713cfda26e37f440356378538e448859b083cc2 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Wed, 12 Jun 2024 03:26:41 +0530 Subject: [PATCH 3/7] Update gapic lib to use 0.6.5 and fix tests (#390) Change-Id: Idd9b450daf8b0b2b09a07127aaa37f97e8b8cfbf --- setup.py | 2 +- tests/test_generative_models.py | 5 ++++- tests/test_generative_models_async.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 0575dcd28..89af61515 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py.tar.gz", + "google-ai-generativelanguage==0.6.5", "google-api-core", "google-api-python-client", "google-auth>=2.15.0", # 2.15 adds API key auth support diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 73789346d..c4d46ffec 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -842,7 +842,10 @@ def test_count_tokens_smoke(self, kwargs): self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) response = model.count_tokens(**kwargs) - self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) + self.assertEqual( + type(response).to_dict(response, including_default_value_fields=False), + {"total_tokens": 7}, + ) @parameterized.named_parameters( [ diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 03055ffb3..3dcf49ae4 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -216,7 +216,10 @@ async def test_count_tokens_smoke(self, contents): self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") response = await model.count_tokens_async(contents) - self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) + self.assertEqual( + type(response).to_dict(response, including_default_value_fields=False), + {"total_tokens": 7}, + ) async def 
test_stream_generate_content_called_with_request_options(self): self.client.stream_generate_content = unittest.mock.AsyncMock() From 7313e21f1306e327d32193842966c0ba7c381bc8 Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 13 Jun 2024 15:05:40 +0530 Subject: [PATCH 4/7] Remove duplicate test (#387) * Cover methods like _handle_afc Change-Id: I0f45cb8566b681f5aaeda1500e74d95eaeab10ef * Remove duplicate code match test Change-Id: Ia9b7c3dcb8dd4d6bf163303b483225aa0b00e0e9 --- tests/test_async_code_match.py | 2 +- tests/test_generative_models.py | 50 --------------------------------- 2 files changed, 1 insertion(+), 51 deletions(-) diff --git a/tests/test_async_code_match.py b/tests/test_async_code_match.py index 008d251f2..0ec4550d4 100644 --- a/tests/test_async_code_match.py +++ b/tests/test_async_code_match.py @@ -87,7 +87,7 @@ def test_code_match_for_async_methods(self): for node in ast.walk(source_nodes): if isinstance( node, (ast.FunctionDef, ast.AsyncFunctionDef) - ) and not node.name.startswith("_"): + ) and not node.name.startswith("__"): name = node.name[:-6] if node.name.endswith("_async") else node.name if name in EXEMPT_FUNCTIONS or self._inspect_decorator_exemption(node, fpath): continue diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index c4d46ffec..18daf6707 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -847,56 +847,6 @@ def test_count_tokens_smoke(self, kwargs): {"total_tokens": 7}, ) - @parameterized.named_parameters( - [ - "GenerateContentResponse", - generation_types.GenerateContentResponse, - generation_types.AsyncGenerateContentResponse, - ], - [ - "GenerativeModel.generate_response", - generative_models.GenerativeModel.generate_content, - generative_models.GenerativeModel.generate_content_async, - ], - [ - "GenerativeModel.count_tokens", - generative_models.GenerativeModel.count_tokens, - generative_models.GenerativeModel.count_tokens_async, - ], - [ - "ChatSession.send_message", - generative_models.ChatSession.send_message, - generative_models.ChatSession.send_message_async, - ], - [ - "ChatSession._handle_afc", - generative_models.ChatSession._handle_afc, - generative_models.ChatSession._handle_afc_async, - ], - ) - def test_async_code_match(self, obj, aobj): - import inspect - import re - - source = inspect.getsource(obj) - asource = inspect.getsource(aobj) - - source = re.sub('""".*"""', "", source, flags=re.DOTALL) - asource = re.sub('""".*"""', "", asource, flags=re.DOTALL) - - asource = ( - asource.replace("anext", "next") - .replace("aiter", "iter") - .replace("_async", "") - .replace("async ", "") - .replace("await ", "") - .replace("Async", "") - .replace("ASYNC_", "") - ) - - asource = re.sub(" *?# type: ignore", "", asource) - self.assertEqual(source, asource, f"error in {obj=}") - def test_repr_for_unary_non_streamed_response(self): model = generative_models.GenerativeModel(model_name="gemini-pro") self.responses["generate_content"].append(simple_response("world!")) From dbd5498f749b838b0218fa077d8433cf9bc9c966 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 13 Jun 2024 02:36:15 -0700 Subject: [PATCH 5/7] remove references to pro-vision (#388) Change-Id: I5409ada8470dfda8354beba615ad906778ea13f6 --- tests/test_generative_models.py | 28 +++++++++++++-------------- tests/test_generative_models_async.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 18daf6707..4b9501334 100644 
--- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -434,7 +434,7 @@ def test_chat_streaming_basic(self): iter([simple_response("x"), simple_response("y"), simple_response("z")]), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("letters?", stream=True) @@ -457,7 +457,7 @@ def test_chat_incomplete_streaming_errors(self): iter([simple_response("x"), simple_response("y"), simple_response("z")]), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("letters?", stream=True) @@ -481,7 +481,7 @@ def test_edit_history(self): simple_response("third"), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("hello") @@ -507,7 +507,7 @@ def test_replace_history(self): simple_response("third"), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() chat.send_message("hello1") chat.send_message("hello2") @@ -529,7 +529,7 @@ def test_copy_history(self): simple_response("third"), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat1 = model.start_chat() chat1.send_message("hello1") @@ -574,7 +574,7 @@ def no_throw(): no_throw(), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() # Send a message, the response is okay.. 
@@ -617,7 +617,7 @@ def test_chat_prompt_blocked(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() with self.assertRaises(generation_types.BlockedPromptException): @@ -635,7 +635,7 @@ def test_chat_candidate_blocked(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() with self.assertRaises(generation_types.StopCandidateException): @@ -657,7 +657,7 @@ def test_chat_streaming_unexpected_stop(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("hello", stream=True) @@ -681,7 +681,7 @@ def test_tools(self): dict(name="datetime", description="Returns the current UTC date and time.") ] ) - model = generative_models.GenerativeModel("gemini-pro-vision", tools=tools) + model = generative_models.GenerativeModel("gemini-1.5-flash", tools=tools) self.responses["generate_content"] = [ simple_response("a"), @@ -840,7 +840,7 @@ def test_system_instruction(self, instruction, expected_instr): def test_count_tokens_smoke(self, kwargs): si = kwargs.pop("system_instruction", None) self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] - model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) + model = generative_models.GenerativeModel("gemini-1.5-flash", system_instruction=si) response = model.count_tokens(**kwargs) self.assertEqual( type(response).to_dict(response, including_default_value_fields=False), @@ -1018,7 +1018,7 @@ def no_throw(): no_throw(), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() # Send a message, the response is okay.. 
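The model-name updates above touch the streaming chat tests; for orientation, the usage pattern those tests exercise looks roughly like this (a sketch with a placeholder prompt):

    import google.generativeai as genai

    model = genai.GenerativeModel("gemini-1.5-flash")
    chat = model.start_chat()

    # With stream=True the reply arrives as an iterable of chunks; the
    # turn is only committed to chat.history once the stream is consumed.
    response = chat.send_message("letters?", stream=True)
    text = "".join(chunk.text for chunk in response)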
@@ -1077,7 +1077,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("hello", stream=True) @@ -1257,7 +1257,7 @@ def test_count_tokens_called_with_request_options(self): self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7)) request_options = {"timeout": 120} - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options) self.assertEqual(request_options, self.observed_kwargs[0]) diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 3dcf49ae4..dd9bc3b62 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -214,7 +214,7 @@ async def test_tool_config(self, tool_config, expected_tool_config): ) async def test_count_tokens_smoke(self, contents): self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") response = await model.count_tokens_async(contents) self.assertEqual( type(response).to_dict(response, including_default_value_fields=False), @@ -256,7 +256,7 @@ async def test_count_tokens_called_with_request_options(self): request = unittest.mock.ANY request_options = {"timeout": 120} - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") response = await model.count_tokens_async( contents=["Hello?"], request_options=request_options ) From 23b81d76afd4ebf5fc83fcb9035ff0f51494f9bd Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Thu, 13 Jun 2024 19:38:16 +0530 Subject: [PATCH 6/7] Explicit Caching patch (#377) * Squashed commit of the following: commit acb3806754b7a71a746fc487c596d67a9aee23f0 Author: Mayuresh Agashe Date: Wed Jun 5 00:51:30 2024 +0000 fix update method Change-Id: I433c25b2d80cdf6e483b59f61ff29bb8d2dc6595 commit fb9995c08dd0f39473efbf66fbb21e48b62c28c0 Merge: 4627fe1 7b9758f Author: Mark Daoust Date: Tue Jun 4 09:55:38 2024 -0700 Merge branch 'main' into caching Change-Id: I2bade6b0099f12dd37a24fe26cfda1981c58fbc0 commit 4627fe1b411dcb1b5e3c7c1d882ce18b8eac73f7 Author: Mark Daoust Date: Tue Jun 4 09:54:31 2024 -0700 use preview build Change-Id: Ic1cd4fc28f591794dc5fbff0647a00a77ea7f601 commit 8e86ef19f9b9fce9d384e12ff364c4e8bdb0265f Author: Mayuresh Agashe Date: Thu May 30 16:18:22 2024 +0000 Refactor for genai.protos module Change-Id: I2f02d2421d7303f0309ec86f05d33c07332c03c1 commit 82d3c5a877e799b357ee39df6e63e8e5ca3807a4 Merge: bf6551a f08c789 Author: Mayuresh Agashe Date: Thu May 30 15:57:27 2024 +0000 Merge branch 'main' of https://github.com/mayureshagashe2105/generative-ai-python into caching Change-Id: Id2b259fe4b2c91653bf5e4d5e883f556366d8676 commit bf6551ac133c50be294788357fb52a318d4d5d4d Author: Mayuresh Agashe Date: Mon May 27 11:26:03 2024 +0000 Fix types Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075 commit 67472d32bcd1dbbb62972e1ad626efdee30cf0c1 Author: Mayuresh Agashe Date: Mon May 27 11:26:03 2024 +0000 Fix types Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075 commit a1c8c725540ebe1b3ea486ad1b45ee6836b40ca6 Author: Mayuresh Agashe Date: Mon May 27 11:15:15 
2024 +0000

    Fix docstrings

    Change-Id: I6020df4e862a4f1d58462a4cd70876a8448293cf

commit f48cedc391982f2442dde08e553303298c61f49c
Author: Mayuresh Agashe
Date:   Mon May 27 11:13:44 2024 +0000

    Fix types

    Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0

commit 645ceab6d2bd10524edf0edd43f780e4c93c410b
Author: Mayuresh Agashe
Date:   Mon May 27 05:54:26 2024 +0000

    blacken

    Change-Id: I4e073d821d29eea30801bdb7e2a8dc01bb7d6b9a

commit 17372e3f118d1126ac32e918aac25975d8f455c4
Author: Mayuresh Agashe
Date:   Mon May 27 05:54:06 2024 +0000

    Add 'cached_content' to GenerativeModel's repr

    Change-Id: I06676fad23895e3e1a6393baa938fc1f2df57d80

commit d1fd7496ea09612b6d8df64bd374603589fb62fb
Author: Mayuresh Agashe
Date:   Mon May 27 05:04:43 2024 +0000

    Add type-annotations to __new__ to fix pytype checks

    Change-Id: I6c69c036e54d56d18ea60368fa0a1dcda2d315fd

commit f37df8cc5e3dc5f81603ec013746059ce1abc717
Author: Mayuresh Agashe
Date:   Sun May 26 06:51:54 2024 +0000

    mark name as OPTIONAL for CachedContent creation

    If not provided, the name will be randomly generated

    Change-Id: Ib95fbafd3dfe098b43164d7ee4d6c2a84b0aae2e

commit 59663c88d6fc3958544fe877d3c71962c15bd865
Author: Mayuresh Agashe
Date:   Fri May 24 10:22:08 2024 +0000

    Add tests

    Change-Id: I249188fa585bd9b7193efa48b1cfca20b8a79821

commit e1d8c7ac2785add8b27e4fee8bd7835a98156de7
Author: Mayuresh Agashe
Date:   Fri May 24 10:21:42 2024 +0000

    Validate name checks for CachedContent creation

    Change-Id: Ie41602621d99ddff6404c6708c7278e0da790652

commit 2cde1a21ea15c42eceb6778add040eb6d3a69b95
Author: Mayuresh Agashe
Date:   Thu May 23 18:09:14 2024 +0000

    fix tests

    Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7

commit d862dae543645d13e5cf31512b8306c03dcb3fc1
Author: Mayuresh Agashe
Date:   Thu May 23 18:09:14 2024 +0000

    fix tests

    Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7

commit d35cc7194a905d2776abcc719eafac3f4c91d512
Author: Mayuresh Agashe
Date:   Thu May 23 23:12:38 2024 +0530

    Improve tests

commit e65d16e5a8c72683780631a037769c1e00dc6b7d
Author: Mayuresh Agashe
Date:   Thu May 23 23:12:05 2024 +0530

    blacken

commit cfc936e164d5bc8fcedf6ae2894fa0369f75f762
Author: Mayuresh Agashe
Date:   Thu May 23 23:10:16 2024 +0530

    Strike out functional approach for CachedContent CRUD ops

commit afd066d181b9ffd053731904b7f172898fd07cce
Merge: 6fafe6b 0dca4ce
Author: Mayuresh Agashe
Date:   Wed May 22 23:10:20 2024 +0530

    Merge branch 'main' into caching

commit 6fafe6b329647586ffd073fd22588290fecc28db
Author: Mayuresh Agashe
Date:   Wed May 22 10:49:35 2024 +0530

    rename get_cached_content to get

commit a4ac7a5bfe1d09fbbc3650b51e094516ad8c30ad
Merge: f13228d f987fde
Author: Mayuresh Agashe
Date:   Tue May 21 23:32:41 2024 +0530

    Merge branch 'main' into caching

commit f13228dc01728e410d5ca6916176049a04490218
Author: Mayuresh Agashe
Date:   Fri Apr 26 16:54:09 2024 +0000

    *Initial prototype for explicit caching
    *Add basic CRUD support for caching
    *Remove INPUT_ONLY marked fields from CachedContent dataclass
    *Rename files 'cached_content*' -> 'caching*'
    *Update 'Create' method for explicit instantiation of 'CachedContent'
    *Add a factory method to instantiate model with `CachedContent` as its context
    *blacken
    *Add tests

    Change-Id: I694545243efda467d6fd599beded0dc6679b727d

Change-Id: I7b14d94f729953294780815f4c496888bb2ad46f

* Remove auto cache deletion

Change-Id: I4658e1c57f967faeb3945dffef0181a456d65370

* Rename _to_dict --> _get_update_fields

Change-Id: I3c92c65e8e5b215e98c1ac0eea6db033166dec78

* Fix tests

Change-Id: Id36d7606e13d15caf6870f29a108944c7f36eaeb

* Set 'CachedContent' as a public property

Remove __new__ construct

Change-Id: Ie4f5527270be90730341b6c3b67de71b9b6e9c5c

* blacken

Change-Id: I12498213a7fc2b257827ab0df87c6913e04cad25

* set 'role=user' when content is passed as a str (#4)

'to_content' method assigns a default 'role=user' to all the contents passed as a string

Change-Id: I748514a7839b7f1d36150b879c3d1464ca9e11ba

* Handle ttl and expire_time separately

Change-Id: If9c6f04fe8d419828e3efd2249f0698bca4d5bdc

* Remove name param

Change-Id: I40fe7c8fafdb014fb9c7e74956452aca9a666641

* Update caching_types.py

* Update caching.py

* Update docstrings and error messages

Change-Id: I111a1218a7d9783d494b84f0a11cb3b76c7ad9da

* Update model name to gemini-1.5-pro for caching tests

Change-Id: Ibb1f75c409afaac124ef70232be71e3a882f6015

* Remove default ttl assignment

Let the API set the default

Change-Id: Id8d125a085ed27229ddb78d5812ed5b5ad39227b

* blacken

Change-Id: I1d7fe0ec422589e237502b0eda687cf81ef21a21

* Remove client arg

Change-Id: I17f05a90a1514f404dd3527c0db1ce6147d2c47a

* Add 'usage_metadata' param to CachedContent class

Change-Id: Ic527c157bc2cd114948b73a8f1832c21dd61b52e

* Add 'display_name' to CachedContent class

Change-Id: Id0a9be9d1bfdb94dc9d5c4fc7af9dee89e5365a4

* update generativelanguage version, fix tests

Change-Id: I0acc57853ab7dde863bbbe4b30ae3957e6ec3d11

* format

Change-Id: Ib2e9a16aaa989021d3498f3e59f9983560919159

* fewer automatic 'role' insertions

Change-Id: I0752741532a451f8720fa5e110e68f0b4e66cc4b

* cleanup

Change-Id: I151a809f6d079b8e4b0ed30d1153a638c98cacfd

* Wrap the proto

Change-Id: I14b4c54652fb51b867fb43d4b3e9091e6eaccd4e

* Apply suggestions from code review

Co-authored-by: Mayuresh Agashe

* fix

Change-Id: I381029fc8fc13c39e432b39084fc8feba305514e

* format

Change-Id: I8e0b44aebc102d3b2afb27a422c4d70d6c99d5d2

* cleanup

Change-Id: I024733b53cede5bfdf957ce7e56d6ad01fd4b2bf

* update version

Change-Id: Ic95dffb3e945e31adc0d98787942d27289512b8a

* fix

Change-Id: I6ffdabbddf0e803606b3638521ebfeb6796d2e4b

* typing

Change-Id: I629d4d111f0e640f4f4bf602ea33f70fdc9ca3e4

* Simplify update method

Accept kwargs instead of dict of updates and construct protos using kwargs

Change-Id: I7858d585b1aa6b965134e2fb90adff737172af92

* Add repr to CachedContent

Change-Id: Id4ec78ebf9d6e96f22f6bf37fc4509268fa552f4

* cleanup

Change-Id: I684b46f881735bceb3f9e09d8573721ddb29f98a

* blacken

Change-Id: I773e7a5b8a222c8b4435470cdc2b53be425d95e4

* Apply suggestions from code review

Change-Id: I2a12b9689001bbc41c460db5a9f0e87c77d4caf6

---------

Co-authored-by: Mark Daoust
---
 google/generativeai/caching.py             | 266 +++++++++++++--------
 google/generativeai/generative_models.py   |  57 +++--
 google/generativeai/types/caching_types.py |  72 ++++--
 google/generativeai/version.py             |   2 +-
 tests/test_caching.py                      | 122 ++++++----
 tests/test_generative_models.py            |  30 ++-
 6 files changed, 335 insertions(+), 214 deletions(-)

diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py
index a28a50256..9acad4726 100644
--- a/google/generativeai/caching.py
+++ b/google/generativeai/caching.py
@@ -14,90 +14,130 @@
 # limitations under the License.
from __future__ import annotations -import dataclasses import datetime -from typing import Any, Iterable, Optional +import textwrap +from typing import Iterable, Optional from google.generativeai import protos -from google.generativeai.types.model_types import idecode_time from google.generativeai.types import caching_types from google.generativeai.types import content_types -from google.generativeai.utils import flatten_update_paths from google.generativeai.client import get_default_cache_client from google.protobuf import field_mask_pb2 -import google.ai.generativelanguage as glm + +_USER_ROLE = "user" +_MODEL_ROLE = "model" -@dataclasses.dataclass class CachedContent: """Cached content resource.""" - name: str - model: str - create_time: datetime.datetime - update_time: datetime.datetime - expire_time: datetime.datetime + def __init__(self, name): + """Fetches a `CachedContent` resource. - # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+). - # Adding basic support for now. - def __enter__(self): - return self + Identical to `CachedContent.get`. - def __exit__(self, exc_type, exc_value, exc_tb): - self.delete() - - def _to_dict(self) -> protos.CachedContent: - proto_paths = { - "name": self.name, - "model": self.model, - } - return protos.CachedContent(**proto_paths) - - def _apply_update(self, path, value): - parts = path.split(".") - for part in parts[:-1]: - self = getattr(self, part) - if parts[-1] == "ttl": - value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) - parts[-1] = "expire_time" - setattr(self, parts[-1], value) + Args: + name: The resource name referring to the cached content. + """ + client = get_default_cache_client() - @classmethod - def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent: - # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build - # is returning these, hence setting including_default_value_fields to False - cached_content = type(cached_content).to_dict( - cached_content, including_default_value_fields=False + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + self._proto = response + + @property + def name(self) -> str: + return self._proto.name + + @property + def model(self) -> str: + return self._proto.model + + @property + def display_name(self) -> str: + return self._proto.display_name + + @property + def usage_metadata(self) -> protos.CachedContent.UsageMetadata: + return self._proto.usage_metadata + + @property + def create_time(self) -> datetime.datetime: + return self._proto.create_time + + @property + def update_time(self) -> datetime.datetime: + return self._proto.update_time + + @property + def expire_time(self) -> datetime.datetime: + return self._proto.expire_time + + def __str__(self): + return textwrap.dedent( + f"""\ + CachedContent( + name='{self.name}', + model='{self.model}', + display_name='{self.display_name}', + usage_metadata={'{'} + 'total_token_count': {self.usage_metadata.total_token_count}, + {'}'}, + create_time={self.create_time}, + update_time={self.update_time}, + expire_time={self.expire_time} + )""" ) - idecode_time(cached_content, "create_time") - idecode_time(cached_content, "update_time") - # always decode `expire_time` as Timestamp is returned - # regardless of what was sent on input - idecode_time(cached_content, "expire_time") - return cls(**cached_content) + __repr__ = __str__ + + @classmethod + 
def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
+        """Creates an instance of CachedContent from an object, without calling `get`."""
+        self = cls.__new__(cls)
+        self._proto = protos.CachedContent()
+        self._update(obj)
+        return self
+
+    def _update(self, updates):
+        """Updates this instance in place; does not call the API's `update` method."""
+        if isinstance(updates, CachedContent):
+            updates = updates._proto
+
+        if not isinstance(updates, dict):
+            updates = type(updates).to_dict(updates, including_default_value_fields=False)
+
+        for key, value in updates.items():
+            setattr(self._proto, key, value)
 
     @staticmethod
     def _prepare_create_request(
         model: str,
-        name: str | None = None,
+        *,
+        display_name: str | None = None,
         system_instruction: Optional[content_types.ContentType] = None,
         contents: Optional[content_types.ContentsType] = None,
         tools: Optional[content_types.FunctionLibraryType] = None,
         tool_config: Optional[content_types.ToolConfigType] = None,
-        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
+        ttl: Optional[caching_types.TTLTypes] = None,
+        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
     ) -> protos.CreateCachedContentRequest:
         """Prepares a CreateCachedContentRequest."""
-        if name is not None:
-            if not caching_types.valid_cached_content_name(name):
-                raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))
-
-            name = "cachedContents/" + name
+        if ttl and expire_time:
+            raise ValueError(
+                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
+            )
 
         if "/" not in model:
             model = "models/" + model
 
+        if display_name and len(display_name) > 128:
+            raise ValueError("`display_name` must be no more than 128 unicode characters.")
+
         if system_instruction:
             system_instruction = content_types.to_content(system_instruction)
 
@@ -110,18 +150,21 @@ def _prepare_create_request(
         if contents:
             contents = content_types.to_contents(contents)
+            if not contents[-1].role:
+                contents[-1].role = _USER_ROLE
 
-        if ttl:
-            ttl = caching_types.to_ttl(ttl)
+        ttl = caching_types.to_optional_ttl(ttl)
+        expire_time = caching_types.to_optional_expire_time(expire_time)
 
         cached_content = protos.CachedContent(
-            name=name,
             model=model,
+            display_name=display_name,
             system_instruction=system_instruction,
             contents=contents,
             tools=tools_lib,
             tool_config=tool_config,
             ttl=ttl,
+            expire_time=expire_time,
         )
 
         return protos.CreateCachedContentRequest(cached_content=cached_content)
@@ -130,13 +173,14 @@ def _prepare_create_request(
     def create(
         cls,
         model: str,
-        name: str | None = None,
+        *,
+        display_name: str | None = None,
         system_instruction: Optional[content_types.ContentType] = None,
         contents: Optional[content_types.ContentsType] = None,
         tools: Optional[content_types.FunctionLibraryType] = None,
         tool_config: Optional[content_types.ToolConfigType] = None,
-        ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
-        client: glm.CacheServiceClient | None = None,
+        ttl: Optional[caching_types.TTLTypes] = None,
+        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
     ) -> CachedContent:
         """Creates a `CachedContent` resource.
 
@@ -144,34 +188,40 @@ def create(
             model: The name of the `model` to use for cached content creation.
                 Any `CachedContent` resource can only be used with the
                 `model` it was created for.
-            name: The resource name referring to the cached content.
+            display_name: The user-generated meaningful display name
+                of the cached content. `display_name` must be no
+                more than 128 unicode characters.
             system_instruction: Developer set system instruction.
             contents: Contents to cache.
             tools: A list of `Tools` the model may use to generate a response.
             tool_config: Config to apply to all tools.
             ttl: TTL for the cached resource (in seconds). Defaults to 1 hour.
+                `ttl` and `expire_time` are exclusive arguments.
+            expire_time: Expiration time for the cached resource.
+                `ttl` and `expire_time` are exclusive arguments.
 
         Returns:
             `CachedContent` resource with the specified name.
         """
-        if client is None:
-            client = get_default_cache_client()
+        client = get_default_cache_client()
 
         request = cls._prepare_create_request(
             model=model,
-            name=name,
+            display_name=display_name,
            system_instruction=system_instruction,
             contents=contents,
             tools=tools,
             tool_config=tool_config,
             ttl=ttl,
+            expire_time=expire_time,
         )
 
         response = client.create_cached_content(request)
-        return cls._decode_cached_content(response)
+        result = CachedContent._from_obj(response)
+        return result
 
     @classmethod
-    def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
+    def get(cls, name: str) -> CachedContent:
         """Fetches the requested `CachedContent` resource.
 
         Args:
@@ -180,20 +230,18 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC
         Returns:
             `CachedContent` resource with the specified `name`.
         """
-        if client is None:
-            client = get_default_cache_client()
+        client = get_default_cache_client()
 
         if "cachedContents/" not in name:
             name = "cachedContents/" + name
 
         request = protos.GetCachedContentRequest(name=name)
         response = client.get_cached_content(request)
-        return cls._decode_cached_content(response)
+        result = CachedContent._from_obj(response)
+        return result
 
     @classmethod
-    def list(
-        cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
-    ) -> Iterable[CachedContent]:
+    def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
         """Lists `CachedContent` objects associated with the project.
 
         Args:
@@ -203,17 +251,16 @@ def list(
         Returns:
             A paginated list of `CachedContent` objects.
         """
-        if client is None:
-            client = get_default_cache_client()
+        client = get_default_cache_client()
 
         request = protos.ListCachedContentsRequest(page_size=page_size)
         for cached_content in client.list_cached_contents(request):
-            yield cls._decode_cached_content(cached_content)
+            cached_content = CachedContent._from_obj(cached_content)
+            yield cached_content
 
-    def delete(self, client: glm.CacheServiceClient | None = None) -> None:
+    def delete(self) -> None:
         """Deletes a `CachedContent` resource."""
-        if client is None:
-            client = get_default_cache_client()
+        client = get_default_cache_client()
 
         request = protos.DeleteCachedContentRequest(name=self.name)
         client.delete_cached_content(request)
@@ -221,40 +268,47 @@ def delete(self, client: glm.CacheServiceClient | None = None) -> None:
 
     def update(
         self,
-        updates: dict[str, Any],
-        client: glm.CacheServiceClient | None = None,
-    ) -> CachedContent:
+        *,
+        ttl: Optional[caching_types.TTLTypes] = None,
+        expire_time: Optional[caching_types.ExpireTimeTypes] = None,
+    ) -> None:
         """Updates the requested `CachedContent` resource.
 
         Args:
-            updates: The list of fields to update. Currently only
-                `ttl/expire_time` is supported as an update path.
-
-        Returns:
-            `CachedContent` object with specified updates.
+            ttl: New TTL for the cached resource (in seconds).
+                `ttl` and `expire_time` are exclusive arguments.
+            expire_time: Expiration time for the cached resource.
+                `ttl` and `expire_time` are exclusive arguments.
         """
-        if client is None:
-            client = get_default_cache_client()
-
-        updates = flatten_update_paths(updates)
-        for update_path in updates:
-            if update_path == "ttl":
-                updates = updates.copy()
-                update_path_val = updates.get(update_path)
-                updates[update_path] = caching_types.to_ttl(update_path_val)
-            else:
-                raise ValueError(
-                    f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
-                )
-        field_mask = field_mask_pb2.FieldMask()
+        client = get_default_cache_client()
 
-        for path in updates.keys():
-            field_mask.paths.append(path)
-        for path, value in updates.items():
-            self._apply_update(path, value)
+        if ttl and expire_time:
+            raise ValueError(
+                "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
+            )
 
-        request = protos.UpdateCachedContentRequest(
-            cached_content=self._to_dict(), update_mask=field_mask
+        ttl = caching_types.to_optional_ttl(ttl)
+        expire_time = caching_types.to_optional_expire_time(expire_time)
+
+        updates = protos.CachedContent(
+            name=self.name,
+            ttl=ttl,
+            expire_time=expire_time,
         )
-        client.update_cached_content(request)
-        return self
+
+        field_mask = field_mask_pb2.FieldMask()
+
+        if ttl:
+            field_mask.paths.append("ttl")
+        elif expire_time:
+            field_mask.paths.append("expire_time")
+        else:
+            raise ValueError(
+                "Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
+            )
+
+        request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
+        updated_cc = client.update_cached_content(request)
+        self._update(updated_cc)
+
+        return
diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index 10744a948..e3387a64f 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -20,6 +20,9 @@
 from google.generativeai.types import helper_types
 from google.generativeai.types import safety_types
 
+_USER_ROLE = "user"
+_MODEL_ROLE = "model"
+
 
 class GenerativeModel:
     """
@@ -96,14 +99,9 @@ def __init__(
         self._client = None
         self._async_client = None
 
-    def __new__(cls, *args, **kwargs) -> GenerativeModel:
-        self = super().__new__(cls)
-
-        if cached_instance := kwargs.pop("cached_content", None):
-            setattr(self, "_cached_content", cached_instance.name)
-            setattr(cls, "cached_content", property(fget=lambda self: self._cached_content))
-
-        return self
+    @property
+    def cached_content(self) -> str | None:
+        return getattr(self, "_cached_content", None)
 
     @property
     def model_name(self):
@@ -123,7 +121,7 @@ def maybe_text(content):
             safety_settings={self._safety_settings},
             tools={self._tools},
             system_instruction={maybe_text(self._system_instruction)},
-            cached_content={getattr(self, "cached_content", None)}
+            cached_content={self.cached_content}
         )"""
         )
 
@@ -139,13 +137,11 @@ def _prepare_request(
         tool_config: content_types.ToolConfigType | None,
     ) -> protos.GenerateContentRequest:
         """Creates a `protos.GenerateContentRequest` from raw inputs."""
-        if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]):
+        if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]):
             raise ValueError(
                 "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantiated with `cached_content` as its context."
) - cached_content = getattr(self, "cached_content", None) - tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -174,7 +170,7 @@ def _prepare_request( tools=tools_lib, tool_config=tool_config, system_instruction=self._system_instruction, - cached_content=cached_content, + cached_content=self.cached_content, ) def _get_tools_lib( @@ -190,6 +186,7 @@ def _get_tools_lib( def from_cached_content( cls, cached_content: str, + *, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: ... @@ -199,6 +196,7 @@ def from_cached_content( def from_cached_content( cls, cached_content: caching.CachedContent, + *, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: ... @@ -207,6 +205,7 @@ def from_cached_content( def from_cached_content( cls, cached_content: str | caching.CachedContent, + *, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: @@ -214,6 +213,8 @@ def from_cached_content( Args: cached_content: context for the model. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. Returns: `GenerativeModel` object with `cached_content` as its context. @@ -221,17 +222,16 @@ def from_cached_content( if isinstance(cached_content, str): cached_content = caching.CachedContent.get(name=cached_content) - # call __new__ with the cached_content to set the model's context. This is done to avoid - # the exposing `cached_content` as a public attribute. - self = cls.__new__(cls, cached_content=cached_content) - # call __init__ to set the model's `generation_config`, `safety_settings`. # `model_name` will be the name of the model for which the `cached_content` was created. - self.__init__( + self = cls( model_name=cached_content.model, generation_config=generation_config, safety_settings=safety_settings, ) + + # set the model's context. + setattr(self, "_cached_content", cached_content.name) return self def generate_content( @@ -309,6 +309,10 @@ def generate_content( tools=tools, tool_config=tool_config, ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + if self._client is None: self._client = client.get_default_generative_client() @@ -359,6 +363,10 @@ async def generate_content_async( tools=tools, tool_config=tool_config, ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + if self._async_client is None: self._async_client = client.get_default_generative_async_client() @@ -489,9 +497,6 @@ class ChatSession: history: A chat history to initialize the object with. 
""" - _USER_ROLE = "user" - _MODEL_ROLE = "model" - def __init__( self, model: GenerativeModel, @@ -559,7 +564,7 @@ def send_message( content = content_types.to_content(content) if not content.role: - content.role = self._USER_ROLE + content.role = _USER_ROLE history = self.history[:] history.append(content) @@ -646,7 +651,7 @@ def _handle_afc( ) function_response_parts.append(fr) - send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -688,7 +693,7 @@ async def send_message_async( content = content_types.to_content(content) if not content.role: - content.role = self._USER_ROLE + content.role = _USER_ROLE history = self.history[:] history.append(content) @@ -753,7 +758,7 @@ async def _handle_afc_async( ) function_response_parts.append(fr) - send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -820,7 +825,7 @@ def history(self) -> list[protos.Content]: sent = self._last_sent received = last.candidates[0].content if not received.role: - received.role = self._MODEL_ROLE + received.role = _MODEL_ROLE self._history.extend([sent, received]) self._last_sent = None diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py index 8d55b70b2..4f1a6b8be 100644 --- a/google/generativeai/types/caching_types.py +++ b/google/generativeai/types/caching_types.py @@ -15,39 +15,69 @@ from __future__ import annotations import datetime -from typing import Optional, Union +from typing import Union from typing_extensions import TypedDict -import re -__all__ = ["TTL"] +__all__ = [ + "ExpireTime", + "TTL", + "TTLTypes", + "ExpireTimeTypes", +] -_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$" -NAME_ERROR_MESSAGE = ( - "The `name` must consist of alphanumeric characters (or `-` or `.`). 
Received: `{name}`"
-)
 
+class TTL(TypedDict):
+    # Represents datetime.datetime.now() + desired ttl
+    seconds: int
+    nanos: int
 
-def valid_cached_content_name(name: str) -> bool:
-    return re.match(_VALID_CACHED_CONTENT_NAME, name) is not None
 
+class ExpireTime(TypedDict):
+    # Represents seconds of UTC time since Unix epoch
+    seconds: int
+    nanos: int
 
-class TTL(TypedDict):
-    seconds: int
 
+TTLTypes = Union[TTL, int, datetime.timedelta]
+ExpireTimeTypes = Union[ExpireTime, int, datetime.datetime]
 
-ExpirationTypes = Union[TTL, int, datetime.timedelta]
 
+def to_optional_ttl(ttl: TTLTypes | None) -> TTL | None:
+    if ttl is None:
+        return None
+    elif isinstance(ttl, datetime.timedelta):
+        return {
+            "seconds": int(ttl.total_seconds()),
+            "nanos": int(ttl.microseconds * 1000),
+        }
+    elif isinstance(ttl, dict):
+        return ttl
+    elif isinstance(ttl, int):
+        return {"seconds": ttl, "nanos": 0}
+    else:
+        raise TypeError(
+            f"Could not convert input to `ttl`:\n type: {type(ttl)}\n",
+            ttl,
+        )
 
-def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL:
-    if isinstance(expiration, datetime.timedelta):
-        return {"seconds": int(expiration.total_seconds())}
-    elif isinstance(expiration, dict):
-        return expiration
-    elif isinstance(expiration, int):
-        return {"seconds": expiration}
 
+def to_optional_expire_time(expire_time: ExpireTimeTypes | None) -> ExpireTime | None:
+    if expire_time is None:
+        return expire_time
+    elif isinstance(expire_time, datetime.datetime):
+        timestamp = expire_time.timestamp()
+        seconds = int(timestamp)
+        nanos = int((timestamp % 1) * 1e9)  # fractional seconds, expressed in nanoseconds
+        return {
+            "seconds": seconds,
+            "nanos": nanos,
+        }
+    elif isinstance(expire_time, dict):
+        return expire_time
+    elif isinstance(expire_time, int):
+        return {"seconds": expire_time, "nanos": 0}
     else:
         raise TypeError(
-            f"Could not convert input to `expire_time` \n'" f" type: {type(expiration)}\n",
-            expiration,
+            f"Could not convert input to `expire_time`:\n type: {type(expire_time)}\n",
+            expire_time,
         )
diff --git a/google/generativeai/version.py b/google/generativeai/version.py
index 8018b67ac..69a8b817e 100644
--- a/google/generativeai/version.py
+++ b/google/generativeai/version.py
@@ -14,4 +14,4 @@
 # limitations under the License.
 from __future__ import annotations
-__version__ = "0.6.0"
+__version__ = "0.7.0"
diff --git a/tests/test_caching.py b/tests/test_caching.py
index 47692325b..1d1b2608c 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
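The conversion helpers above normalize every accepted expiration type to a plain `{seconds, nanos}` mapping before it is placed on the request. A minimal sketch of the expected round-trips, assuming only that the module is importable as shown in the diff (pure local computation, no API calls):

import datetime

from google.generativeai.types import caching_types

# timedelta -> TTL dict: whole seconds, with the microsecond remainder as nanos.
assert caching_types.to_optional_ttl(datetime.timedelta(hours=2)) == {"seconds": 7200, "nanos": 0}

# A bare int is interpreted as a duration in seconds.
assert caching_types.to_optional_ttl(90) == {"seconds": 90, "nanos": 0}

# datetime -> ExpireTime dict: seconds since the Unix epoch, plus fractional-second nanos.
dt = datetime.datetime(2024, 6, 6, tzinfo=datetime.timezone.utc)
assert caching_types.to_optional_expire_time(dt) == {"seconds": int(dt.timestamp()), "nanos": 0}

# None passes through, so `create`/`update` can forward optional arguments unchanged.
assert caching_types.to_optional_ttl(None) is None
assert caching_types.to_optional_expire_time(None) is None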
import datetime +import textwrap import unittest from google.generativeai import caching @@ -44,7 +45,9 @@ def create_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -58,7 +61,9 @@ def get_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -73,14 +78,18 @@ def list_cached_contents( return [ protos.CachedContent( name="cachedContents/test-cached-content-1", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", ), protos.CachedContent( name="cachedContents/test-cached-content-2", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -95,7 +104,9 @@ def update_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T03:01:01.123456Z", @@ -114,8 +125,7 @@ def add(a: int, b: int) -> int: return a + b cc = caching.CachedContent.create( - name="test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", contents=["Add 5 and 6"], tools=[add], tool_config={"function_calling_config": "ANY"}, @@ -125,7 +135,7 @@ def add(a: int, b: int) -> int: self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContents/test-cached-content") - self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + self.assertEqual(cc.model, "models/gemini-1.5-pro") @parameterized.named_parameters( [ @@ -147,10 +157,9 @@ def add(a: int, b: int) -> int: ), ] ) - def test_expiration_types_for_create_cached_content(self, ttl): + def test_ttl_types_for_create_cached_content(self, ttl): cc = caching.CachedContent.create( - name="test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", contents=["cache this please for 2 hours"], ttl=ttl, ) @@ -160,28 +169,39 @@ def test_expiration_types_for_create_cached_content(self, ttl): @parameterized.named_parameters( [ dict( - testcase_name="upper_case", - name="Test-cached-content", + testcase_name="expire_time-is-int-seconds", + expire_time=1717653421, ), dict( - testcase_name="special_characters_except_dot_and_hyphen", 
- name="test-cac*@/hed-conte#nt", + testcase_name="expire_time-is-datetime", + expire_time=datetime.datetime.now(), ), dict( - testcase_name="empty_name", - name="", + testcase_name="expire_time-is-dict", + expire_time={"seconds": 1717653421}, ), dict( - testcase_name="blank_spaces", - name="test cached content", + testcase_name="expire_time-is-none-default-to-1-hr", + expire_time=None, ), ] ) - def test_create_cached_content_with_invalid_name_format(self, name): + def test_expire_time_types_for_create_cached_content(self, expire_time): + cc = caching.CachedContent.create( + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + expire_time=expire_time, + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + def test_mutual_exclusivity_for_ttl_and_expire_time_in_create_cached_content(self): with self.assertRaises(ValueError): _ = caching.CachedContent.create( - name=name, - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + ttl=datetime.timedelta(hours=2), + expire_time=datetime.datetime.now(), ) def test_get_cached_content(self): @@ -189,7 +209,7 @@ def test_get_cached_content(self): self.assertIsInstance(self.observed_requests[-1], protos.GetCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContents/test-cached-content") - self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + self.assertEqual(cc.model, "models/gemini-1.5-pro") def test_list_cached_contents(self): ccs = list(caching.CachedContent.list(page_size=2)) @@ -198,25 +218,27 @@ def test_list_cached_contents(self): self.assertIsInstance(ccs[0], caching.CachedContent) self.assertIsInstance(ccs[1], caching.CachedContent) - def test_update_cached_content_invalid_update_paths(self): - update_masks = dict( - name="change", - model="models/gemini-1.5-pro-001", - system_instruction="Always add 10 to the result.", - contents=["add this Content"], - ) + def test_update_cached_content_ttl_and_expire_time_are_mutually_exclusive(self): + ttl = datetime.timedelta(hours=2) + expire_time = datetime.datetime.now() cc = caching.CachedContent.get(name="cachedContents/test-cached-content") with self.assertRaises(ValueError): - cc.update(updates=update_masks) + cc.update(ttl=ttl, expire_time=expire_time) - def test_update_cached_content_valid_update_paths(self): - update_masks = dict( - ttl=datetime.timedelta(hours=2), - ) + @parameterized.named_parameters( + [ + dict(testcase_name="ttl", ttl=datetime.timedelta(hours=2)), + dict( + testcase_name="expire_time", + expire_time=datetime.datetime(2024, 6, 5, 12, 12, 12, 23), + ), + ] + ) + def test_update_cached_content_valid_update_paths(self, ttl=None, expire_time=None): cc = caching.CachedContent.get(name="cachedContents/test-cached-content") - cc = cc.update(updates=update_masks) + cc.update(ttl=ttl, expire_time=expire_time) self.assertIsInstance(self.observed_requests[-1], protos.UpdateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) @@ -229,17 +251,23 @@ def test_delete_cached_content(self): cc.delete() self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) - def test_auto_delete_cached_content_with_context_manager(self): - with caching.CachedContent.create( - name="test-cached-content", - model="models/gemini-1.0-pro-001", - contents=["Add 5 and 6"], - system_instruction="Always add 10 to the result.", - 
ttl=datetime.timedelta(minutes=30),
-        ) as cc:
-            ...  # some logic
-
-        self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest)
+    def test_repr_cached_content(self):
+        expected_repr = textwrap.dedent(
+            """\
+            CachedContent(
+                name='cachedContents/test-cached-content',
+                model='models/gemini-1.5-pro',
+                display_name='Cached content for test',
+                usage_metadata={
+                    'total_token_count': 1,
+                },
+                create_time=2000-01-01 01:01:01.123456+00:00,
+                update_time=2000-01-01 01:01:01.123456+00:00,
+                expire_time=2000-01-01 01:01:01.123456+00:00
+            )"""
+        )
+        cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
+        self.assertEqual(repr(cc), expected_repr)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py
index 4b9501334..cccea9d48 100644
--- a/tests/test_generative_models.py
+++ b/tests/test_generative_models.py
@@ -3,9 +3,7 @@
 import copy
 import datetime
 import pathlib
-from typing import Any
 import textwrap
-import unittest.mock
 from absl.testing import absltest
 from absl.testing import parameterized
 from google.generativeai import protos
@@ -86,7 +84,9 @@ def get_cached_content(
         self.observed_requests.append(request)
         return protos.CachedContent(
             name="cachedContents/test-cached-content",
-            model="models/gemini-1.0-pro-001",
+            model="models/gemini-1.5-pro",
+            display_name="Cached content for test",
+            usage_metadata={"total_token_count": 1},
             create_time="2000-01-01T01:01:01.123456Z",
             update_time="2000-01-01T01:01:01.123456Z",
             expire_time="2000-01-01T01:01:01.123456Z",
@@ -338,12 +338,16 @@ def test_stream_prompt_feedback_not_blocked(self):
            dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"),
            dict(
                testcase_name="test_cached_content_as_CachedContent_object",
-                cached_content=caching.CachedContent(
-                    name="cachedContents/test-cached-content",
-                    model="models/gemini-1.0-pro-001",
-                    create_time=datetime.datetime.now(),
-                    update_time=datetime.datetime.now(),
-                    expire_time=datetime.datetime.now(),
+                cached_content=caching.CachedContent._from_obj(
+                    dict(
+                        name="cachedContents/test-cached-content",
+                        model="models/gemini-1.5-pro",
+                        display_name="Cached content for test",
+                        usage_metadata={"total_token_count": 1},
+                        create_time=datetime.datetime.now(),
+                        update_time=datetime.datetime.now(),
+                        expire_time=datetime.datetime.now(),
+                    )
                ),
            ),
        ],
@@ -353,7 +357,7 @@ def test_model_with_cached_content_as_context(self, cached_content):
        cc_name = model.cached_content  # pytype: disable=attribute-error
        model_name = model.model_name
        self.assertEqual(cc_name, "cachedContents/test-cached-content")
-        self.assertEqual(model_name, "models/gemini-1.0-pro-001")
+        self.assertEqual(model_name, "models/gemini-1.5-pro")
        self.assertEqual(
            model.cached_content,  # pytype: disable=attribute-error
            "cachedContents/test-cached-content",
@@ -801,9 +805,9 @@ def test_tool_config(self, tool_config, expected_tool_config):
        [
            "part_dict",
            {"parts": [{"text": "talk like a pirate"}]},
-            simple_part("talk like a pirate"),
+            protos.Content(parts=[{"text": "talk like a pirate"}]),
        ],
-        ["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])],
+        ["part_list", ["talk like", "a pirate"], iter_part(["talk like", "a pirate"])],
    )
    def test_system_instruction(self, instruction, expected_instr):
        self.responses["generate_content"] = [simple_response("echo echo")]
@@ -1251,7 +1255,7 @@ def test_repr_for_model_created_from_cahced_content(self):
        )
        result = repr(model)
self.assertIn("cached_content=cachedContents/test-cached-content", result) - self.assertIn("model_name='models/gemini-1.0-pro-001'", result) + self.assertIn("model_name='models/gemini-1.5-pro'", result) def test_count_tokens_called_with_request_options(self): self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7)) From 419a7cef81188a2d715b3853113a3213516d6a4a Mon Sep 17 00:00:00 2001 From: Mayuresh Agashe Date: Fri, 14 Jun 2024 11:23:02 +0530 Subject: [PATCH 7/7] Accept partial file names (#386) Adds prefix files/ if not present Change-Id: Iac5c4d0934620f2462cf15ae519474b0ce7908da --- google/generativeai/files.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 4028d37f7..0e8a8ed27 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -81,16 +81,20 @@ def list_files(page_size=100) -> Iterable[file_types.File]: yield file_types.File(proto) -def get_file(name) -> file_types.File: +def get_file(name: str) -> file_types.File: """Calls the API to retrieve a specified file using a supported file service.""" + if "/" not in name: + name = f"files/{name}" client = get_default_file_client() return file_types.File(client.get_file(name=name)) -def delete_file(name): +def delete_file(name: str | file_types.File | protos.File): """Calls the API to permanently delete a specified file using a supported file service.""" if isinstance(name, (file_types.File, protos.File)): name = name.name + elif "/" not in name: + name = f"files/{name}" request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request)
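Taken together, the caching and files changes support the following end-to-end flow. This is a minimal usage sketch rather than part of the patch: the model name, cached contents, and file id are placeholder values, and it assumes `genai.configure(api_key=...)` has already been called and that `get_file`/`delete_file` are re-exported at the package top level:

import datetime

import google.generativeai as genai
from google.generativeai import caching

# Create an explicit cache; `ttl` and `expire_time` are mutually exclusive,
# and the service defaults to a 1 hour TTL when neither is given.
cc = caching.CachedContent.create(
    model="models/gemini-1.5-pro",  # placeholder model name
    display_name="transcript cache",
    contents=["<a long transcript worth reusing across requests>"],
    ttl=datetime.timedelta(hours=2),
)

# Build a model whose context is the cached content, then query it.
model = genai.GenerativeModel.from_cached_content(cached_content=cc)
response = model.generate_content("Summarize the cached transcript.")
print(response.text)

# Only `ttl` or `expire_time` may be updated on an existing cache.
cc.update(ttl=datetime.timedelta(hours=4))
cc.delete()

# With the files change, bare ids are resolved to full resource names.
file = genai.get_file("example-file-id")  # hypothetical id; resolved as "files/example-file-id"
genai.delete_file("example-file-id")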