diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml deleted file mode 100644 index 16c34f37a..000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Bug report -description: Use this template to report bugs -labels: ["type:bug", "component:python sdk"] -body: - - type: markdown - attributes: - value: > - **Note:** If this is a support question (e.g. _How do I do XYZ?_), please visit the [Discourse forum](https://discuss.ai.google.dev/). This is a great place to interact with developers, and to learn, share, and support each other. - - type: textarea - id: description - attributes: - label: > - Description of the bug: - - type: textarea - id: behavior - attributes: - label: > - Actual vs expected behavior: - - type: textarea - id: info - attributes: - label: > - Any other information you'd like to share? diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml deleted file mode 100644 index 91a380e87..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Feature request -description: Use this template to suggest a new feature -labels: ["type:feature request", "component:python sdk"] -body: - - type: markdown - attributes: - value: > - **Note:** If this is a support question (e.g. _How do I do XYZ?_), please visit the [Discourse forum](https://discuss.ai.google.dev/). This is a great place to interact with developers, and to learn, share, and support each other. - - type: textarea - id: description - attributes: - label: > - Description of the feature request: - - type: textarea - id: behavior - attributes: - label: > - What problem are you trying to solve with this feature? - - type: textarea - id: info - attributes: - label: > - Any other information you'd like to share? diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py new file mode 100644 index 000000000..9acad4726 --- /dev/null +++ b/google/generativeai/caching.py @@ -0,0 +1,314 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import datetime +import textwrap +from typing import Iterable, Optional + +from google.generativeai import protos +from google.generativeai.types import caching_types +from google.generativeai.types import content_types +from google.generativeai.client import get_default_cache_client + +from google.protobuf import field_mask_pb2 + +_USER_ROLE = "user" +_MODEL_ROLE = "model" + + +class CachedContent: + """Cached content resource.""" + + def __init__(self, name): + """Fetches a `CachedContent` resource. + + Identical to `CachedContent.get`. + + Args: + name: The resource name referring to the cached content. 
+ """ + client = get_default_cache_client() + + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + self._proto = response + + @property + def name(self) -> str: + return self._proto.name + + @property + def model(self) -> str: + return self._proto.model + + @property + def display_name(self) -> str: + return self._proto.display_name + + @property + def usage_metadata(self) -> protos.CachedContent.UsageMetadata: + return self._proto.usage_metadata + + @property + def create_time(self) -> datetime.datetime: + return self._proto.create_time + + @property + def update_time(self) -> datetime.datetime: + return self._proto.update_time + + @property + def expire_time(self) -> datetime.datetime: + return self._proto.expire_time + + def __str__(self): + return textwrap.dedent( + f"""\ + CachedContent( + name='{self.name}', + model='{self.model}', + display_name='{self.display_name}', + usage_metadata={'{'} + 'total_token_count': {self.usage_metadata.total_token_count}, + {'}'}, + create_time={self.create_time}, + update_time={self.update_time}, + expire_time={self.expire_time} + )""" + ) + + __repr__ = __str__ + + @classmethod + def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent: + """Creates an instance of CachedContent form an object, without calling `get`.""" + self = cls.__new__(cls) + self._proto = protos.CachedContent() + self._update(obj) + return self + + def _update(self, updates): + """Updates this instance inplace, does not call the API's `update` method""" + if isinstance(updates, CachedContent): + updates = updates._proto + + if not isinstance(updates, dict): + updates = type(updates).to_dict(updates, including_default_value_fields=False) + + for key, value in updates.items(): + setattr(self._proto, key, value) + + @staticmethod + def _prepare_create_request( + model: str, + *, + display_name: str | None = None, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> protos.CreateCachedContentRequest: + """Prepares a CreateCachedContentRequest.""" + if ttl and expire_time: + raise ValueError( + "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both." 
+ ) + + if "/" not in model: + model = "models/" + model + + if display_name and len(display_name) > 128: + raise ValueError("`display_name` must be no more than 128 unicode characters.") + + if system_instruction: + system_instruction = content_types.to_content(system_instruction) + + tools_lib = content_types.to_function_library(tools) + if tools_lib: + tools_lib = tools_lib.to_proto() + + if tool_config: + tool_config = content_types.to_tool_config(tool_config) + + if contents: + contents = content_types.to_contents(contents) + if not contents[-1].role: + contents[-1].role = _USER_ROLE + + ttl = caching_types.to_optional_ttl(ttl) + expire_time = caching_types.to_optional_expire_time(expire_time) + + cached_content = protos.CachedContent( + model=model, + display_name=display_name, + system_instruction=system_instruction, + contents=contents, + tools=tools_lib, + tool_config=tool_config, + ttl=ttl, + expire_time=expire_time, + ) + + return protos.CreateCachedContentRequest(cached_content=cached_content) + + @classmethod + def create( + cls, + model: str, + *, + display_name: str | None = None, + system_instruction: Optional[content_types.ContentType] = None, + contents: Optional[content_types.ContentsType] = None, + tools: Optional[content_types.FunctionLibraryType] = None, + tool_config: Optional[content_types.ToolConfigType] = None, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> CachedContent: + """Creates `CachedContent` resource. + + Args: + model: The name of the `model` to use for cached content creation. + Any `CachedContent` resource can be only used with the + `model` it was created for. + display_name: The user-generated meaningful display name + of the cached content. `display_name` must be no + more than 128 unicode characters. + system_instruction: Developer set system instruction. + contents: Contents to cache. + tools: A list of `Tools` the model may use to generate response. + tool_config: Config to apply to all tools. + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + `ttl` and `expire_time` are exclusive arguments. + expire_time: Expiration time for cached resource. + `ttl` and `expire_time` are exclusive arguments. + + Returns: + `CachedContent` resource with specified name. + """ + client = get_default_cache_client() + + request = cls._prepare_create_request( + model=model, + display_name=display_name, + system_instruction=system_instruction, + contents=contents, + tools=tools, + tool_config=tool_config, + ttl=ttl, + expire_time=expire_time, + ) + + response = client.create_cached_content(request) + result = CachedContent._from_obj(response) + return result + + @classmethod + def get(cls, name: str) -> CachedContent: + """Fetches required `CachedContent` resource. + + Args: + name: The resource name referring to the cached content. + + Returns: + `CachedContent` resource with specified `name`. + """ + client = get_default_cache_client() + + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + result = CachedContent._from_obj(response) + return result + + @classmethod + def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]: + """Lists `CachedContent` objects associated with the project. + + Args: + page_size: The maximum number of permissions to return (per page). + The service may return fewer `CachedContent` objects. 
+ + Returns: + A paginated list of `CachedContent` objects. + """ + client = get_default_cache_client() + + request = protos.ListCachedContentsRequest(page_size=page_size) + for cached_content in client.list_cached_contents(request): + cached_content = CachedContent._from_obj(cached_content) + yield cached_content + + def delete(self) -> None: + """Deletes `CachedContent` resource.""" + client = get_default_cache_client() + + request = protos.DeleteCachedContentRequest(name=self.name) + client.delete_cached_content(request) + return + + def update( + self, + *, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> None: + """Updates requested `CachedContent` resource. + + Args: + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + `ttl` and `expire_time` are exclusive arguments. + expire_time: Expiration time for cached resource. + `ttl` and `expire_time` are exclusive arguments. + """ + client = get_default_cache_client() + + if ttl and expire_time: + raise ValueError( + "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both." + ) + + ttl = caching_types.to_optional_ttl(ttl) + expire_time = caching_types.to_optional_expire_time(expire_time) + + updates = protos.CachedContent( + name=self.name, + ttl=ttl, + expire_time=expire_time, + ) + + field_mask = field_mask_pb2.FieldMask() + + if ttl: + field_mask.paths.append("ttl") + elif expire_time: + field_mask.paths.append("expire_time") + else: + raise ValueError( + f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`." + ) + + request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask) + updated_cc = client.update_cached_content(request) + self._update(updated_cc) + + return diff --git a/google/generativeai/client.py b/google/generativeai/client.py index 40c2bdcaf..7012ecc7c 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -315,6 +315,10 @@ def configure( _client_manager.configure() +def get_default_cache_client() -> glm.CacheServiceClient: + return _client_manager.get_default_client("cache") + + def get_default_discuss_client() -> glm.DiscussServiceClient: return _client_manager.get_default_client("discuss") diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 4028d37f7..0e8a8ed27 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -81,16 +81,20 @@ def list_files(page_size=100) -> Iterable[file_types.File]: yield file_types.File(proto) -def get_file(name) -> file_types.File: +def get_file(name: str) -> file_types.File: """Calls the API to retrieve a specified file using a supported file service.""" + if "/" not in name: + name = f"files/{name}" client = get_default_file_client() return file_types.File(client.get_file(name=name)) -def delete_file(name): +def delete_file(name: str | file_types.File | protos.File): """Calls the API to permanently delete a specified file using a supported file service.""" if isinstance(name, (file_types.File, protos.File)): name = name.name + elif "/" not in name: + name = f"files/{name}" request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 7d69ae8f9..e3387a64f 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -4,7 +4,7 @@ from collections.abc import 
Iterable import textwrap -from typing import Any +from typing import Any, Union, overload import reprlib # pylint: disable=bad-continuation, line-too-long @@ -13,11 +13,16 @@ import google.api_core.exceptions from google.generativeai import protos from google.generativeai import client + +from google.generativeai import caching from google.generativeai.types import content_types from google.generativeai.types import generation_types from google.generativeai.types import helper_types from google.generativeai.types import safety_types +_USER_ROLE = "user" +_MODEL_ROLE = "model" + class GenerativeModel: """ @@ -94,6 +99,10 @@ def __init__( self._client = None self._async_client = None + @property + def cached_content(self) -> str: + return getattr(self, "_cached_content", None) + @property def model_name(self): return self._model_name @@ -112,6 +121,7 @@ def maybe_text(content): safety_settings={self._safety_settings}, tools={self._tools}, system_instruction={maybe_text(self._system_instruction)}, + cached_content={self.cached_content} )""" ) @@ -127,6 +137,11 @@ def _prepare_request( tool_config: content_types.ToolConfigType | None, ) -> protos.GenerateContentRequest: """Creates a `protos.GenerateContentRequest` from raw inputs.""" + if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]): + raise ValueError( + "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantinated with `cached_content` as its context." + ) + tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -155,6 +170,7 @@ def _prepare_request( tools=tools_lib, tool_config=tool_config, system_instruction=self._system_instruction, + cached_content=self.cached_content, ) def _get_tools_lib( @@ -165,6 +181,59 @@ def _get_tools_lib( else: return content_types.to_function_library(tools) + @overload + @classmethod + def from_cached_content( + cls, + cached_content: str, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + ) -> GenerativeModel: ... + + @overload + @classmethod + def from_cached_content( + cls, + cached_content: caching.CachedContent, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + ) -> GenerativeModel: ... + + @classmethod + def from_cached_content( + cls, + cached_content: str | caching.CachedContent, + *, + generation_config: generation_types.GenerationConfigType | None = None, + safety_settings: safety_types.SafetySettingOptions | None = None, + ) -> GenerativeModel: + """Creates a model with `cached_content` as model's context. + + Args: + cached_content: context for the model. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. + + Returns: + `GenerativeModel` object with `cached_content` as its context. + """ + if isinstance(cached_content, str): + cached_content = caching.CachedContent.get(name=cached_content) + + # call __init__ to set the model's `generation_config`, `safety_settings`. + # `model_name` will be the name of the model for which the `cached_content` was created. + self = cls( + model_name=cached_content.model, + generation_config=generation_config, + safety_settings=safety_settings, + ) + + # set the model's context. 
+ setattr(self, "_cached_content", cached_content.name) + return self + def generate_content( self, contents: content_types.ContentsType, @@ -240,6 +309,10 @@ def generate_content( tools=tools, tool_config=tool_config, ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + if self._client is None: self._client = client.get_default_generative_client() @@ -290,6 +363,10 @@ async def generate_content_async( tools=tools, tool_config=tool_config, ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + if self._async_client is None: self._async_client = client.get_default_generative_async_client() @@ -420,9 +497,6 @@ class ChatSession: history: A chat history to initialize the object with. """ - _USER_ROLE = "user" - _MODEL_ROLE = "model" - def __init__( self, model: GenerativeModel, @@ -490,7 +564,7 @@ def send_message( content = content_types.to_content(content) if not content.role: - content.role = self._USER_ROLE + content.role = _USER_ROLE history = self.history[:] history.append(content) @@ -577,7 +651,7 @@ def _handle_afc( ) function_response_parts.append(fr) - send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -619,7 +693,7 @@ async def send_message_async( content = content_types.to_content(content) if not content.role: - content.role = self._USER_ROLE + content.role = _USER_ROLE history = self.history[:] history.append(content) @@ -684,7 +758,7 @@ async def _handle_afc_async( ) function_response_parts.append(fr) - send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -751,7 +825,7 @@ def history(self) -> list[protos.Content]: sent = self._last_sent received = last.candidates[0].content if not received.role: - received.role = self._MODEL_ROLE + received.role = _MODEL_ROLE self._history.extend([sent, received]) self._last_sent = None diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py new file mode 100644 index 000000000..4f1a6b8be --- /dev/null +++ b/google/generativeai/types/caching_types.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import datetime +from typing import Union +from typing_extensions import TypedDict + +__all__ = [ + "ExpireTime", + "TTL", + "TTLTypes", + "ExpireTimeTypes", +] + + +class TTL(TypedDict): + # Represents datetime.datetime.now() + desired ttl + seconds: int + nanos: int + + +class ExpireTime(TypedDict): + # Represents seconds of UTC time since Unix epoch + seconds: int + nanos: int + + +TTLTypes = Union[TTL, int, datetime.timedelta] +ExpireTimeTypes = Union[ExpireTime, int, datetime.datetime] + + +def to_optional_ttl(ttl: TTLTypes | None) -> TTL | None: + if ttl is None: + return None + elif isinstance(ttl, datetime.timedelta): + return { + "seconds": int(ttl.total_seconds()), + "nanos": int(ttl.microseconds * 1000), + } + elif isinstance(ttl, dict): + return ttl + elif isinstance(ttl, int): + return {"seconds": ttl, "nanos": 0} + else: + raise TypeError( + f"Could not convert input to `ttl` \n'" f" type: {type(ttl)}\n", + ttl, + ) + + +def to_optional_expire_time(expire_time: ExpireTimeTypes | None) -> ExpireTime | None: + if expire_time is None: + return expire_time + elif isinstance(expire_time, datetime.datetime): + timestamp = expire_time.timestamp() + seconds = int(timestamp) + nanos = int((seconds % 1) * 1000) + return { + "seconds": seconds, + "nanos": nanos, + } + elif isinstance(expire_time, dict): + return expire_time + elif isinstance(expire_time, int): + return {"seconds": expire_time, "nanos": 0} + else: + raise TypeError( + f"Could not convert input to `expire_time` \n'" f" type: {type(expire_time)}\n", + expire_time, + ) diff --git a/google/generativeai/version.py b/google/generativeai/version.py index 8018b67ac..69a8b817e 100644 --- a/google/generativeai/version.py +++ b/google/generativeai/version.py @@ -14,4 +14,4 @@ # limitations under the License. from __future__ import annotations -__version__ = "0.6.0" +__version__ = "0.7.0" diff --git a/setup.py b/setup.py index 6f9545e4f..89af61515 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage==0.6.4", + "google-ai-generativelanguage==0.6.5", "google-api-core", "google-api-python-client", "google-auth>=2.15.0", # 2.15 adds API key auth support diff --git a/tests/test_async_code_match.py b/tests/test_async_code_match.py index 008d251f2..0ec4550d4 100644 --- a/tests/test_async_code_match.py +++ b/tests/test_async_code_match.py @@ -87,7 +87,7 @@ def test_code_match_for_async_methods(self): for node in ast.walk(source_nodes): if isinstance( node, (ast.FunctionDef, ast.AsyncFunctionDef) - ) and not node.name.startswith("_"): + ) and not node.name.startswith("__"): name = node.name[:-6] if node.name.endswith("_async") else node.name if name in EXEMPT_FUNCTIONS or self._inspect_decorator_exemption(node, fpath): continue diff --git a/tests/test_caching.py b/tests/test_caching.py new file mode 100644 index 000000000..1d1b2608c --- /dev/null +++ b/tests/test_caching.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import textwrap +import unittest + +from google.generativeai import caching +from google.generativeai import protos + +from google.generativeai import client +from absl.testing import absltest +from absl.testing import parameterized + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = unittest.mock.MagicMock() + + client._client_manager.clients["cache"] = self.client + + self.observed_requests = [] + + def add_client_method(f): + name = f.__name__ + setattr(self.client, name, f) + return f + + @add_client_method + def create_cached_content( + request: protos.CreateCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + + @add_client_method + def get_cached_content( + request: protos.GetCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + + @add_client_method + def list_cached_contents( + request: protos.ListCachedContentsRequest, + **kwargs, + ) -> protos.ListCachedContentsResponse: + self.observed_requests.append(request) + return [ + protos.CachedContent( + name="cachedContents/test-cached-content-1", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + protos.CachedContent( + name="cachedContents/test-cached-content-2", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ), + ] + + @add_client_method + def update_cached_content( + request: protos.UpdateCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T03:01:01.123456Z", + ) + + @add_client_method + def delete_cached_content( + request: protos.DeleteCachedContentRequest, + **kwargs, + ) -> None: + self.observed_requests.append(request) + + def 
test_create_cached_content(self): + + def add(a: int, b: int) -> int: + return a + b + + cc = caching.CachedContent.create( + model="models/gemini-1.5-pro", + contents=["Add 5 and 6"], + tools=[add], + tool_config={"function_calling_config": "ANY"}, + system_instruction="Always add 10 to the result.", + ttl=datetime.timedelta(minutes=30), + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + self.assertEqual(cc.name, "cachedContents/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.5-pro") + + @parameterized.named_parameters( + [ + dict( + testcase_name="ttl-is-int-seconds", + ttl=7200, + ), + dict( + testcase_name="ttl-is-timedelta", + ttl=datetime.timedelta(hours=2), + ), + dict( + testcase_name="ttl-is-dict", + ttl={"seconds": 7200}, + ), + dict( + testcase_name="ttl-is-none-default-to-1-hr", + ttl=None, + ), + ] + ) + def test_ttl_types_for_create_cached_content(self, ttl): + cc = caching.CachedContent.create( + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + ttl=ttl, + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + @parameterized.named_parameters( + [ + dict( + testcase_name="expire_time-is-int-seconds", + expire_time=1717653421, + ), + dict( + testcase_name="expire_time-is-datetime", + expire_time=datetime.datetime.now(), + ), + dict( + testcase_name="expire_time-is-dict", + expire_time={"seconds": 1717653421}, + ), + dict( + testcase_name="expire_time-is-none-default-to-1-hr", + expire_time=None, + ), + ] + ) + def test_expire_time_types_for_create_cached_content(self, expire_time): + cc = caching.CachedContent.create( + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + expire_time=expire_time, + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + def test_mutual_exclusivity_for_ttl_and_expire_time_in_create_cached_content(self): + with self.assertRaises(ValueError): + _ = caching.CachedContent.create( + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + ttl=datetime.timedelta(hours=2), + expire_time=datetime.datetime.now(), + ) + + def test_get_cached_content(self): + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + self.assertIsInstance(self.observed_requests[-1], protos.GetCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + self.assertEqual(cc.name, "cachedContents/test-cached-content") + self.assertEqual(cc.model, "models/gemini-1.5-pro") + + def test_list_cached_contents(self): + ccs = list(caching.CachedContent.list(page_size=2)) + self.assertIsInstance(self.observed_requests[-1], protos.ListCachedContentsRequest) + self.assertLen(ccs, 2) + self.assertIsInstance(ccs[0], caching.CachedContent) + self.assertIsInstance(ccs[1], caching.CachedContent) + + def test_update_cached_content_ttl_and_expire_time_are_mutually_exclusive(self): + ttl = datetime.timedelta(hours=2) + expire_time = datetime.datetime.now() + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + with self.assertRaises(ValueError): + cc.update(ttl=ttl, expire_time=expire_time) + + @parameterized.named_parameters( + [ + dict(testcase_name="ttl", ttl=datetime.timedelta(hours=2)), + dict( + testcase_name="expire_time", + expire_time=datetime.datetime(2024, 6, 5, 12, 
12, 12, 23), + ), + ] + ) + def test_update_cached_content_valid_update_paths(self, ttl=None, expire_time=None): + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.update(ttl=ttl, expire_time=expire_time) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + def test_delete_cached_content(self): + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + cc.delete() + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + + def test_repr_cached_content(self): + expexted_repr = textwrap.dedent( + """\ + CachedContent( + name='cachedContents/test-cached-content', + model='models/gemini-1.5-pro', + display_name='Cached content for test', + usage_metadata={ + 'total_token_count': 1, + }, + create_time=2000-01-01 01:01:01.123456+00:00, + update_time=2000-01-01 01:01:01.123456+00:00, + expire_time=2000-01-01 01:01:01.123456+00:00 + )""" + ) + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + self.assertEqual(repr(cc), expexted_repr) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 0ece77e94..cccea9d48 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -1,20 +1,19 @@ import collections from collections.abc import Iterable import copy +import datetime import pathlib -from typing import Any import textwrap -import unittest.mock from absl.testing import absltest from absl.testing import parameterized from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai import generative_models +from google.generativeai import caching from google.generativeai.types import content_types from google.generativeai.types import generation_types from google.generativeai.types import helper_types - import PIL.Image HERE = pathlib.Path(__file__).parent @@ -77,6 +76,22 @@ def count_tokens( response = self.responses["count_tokens"].pop(0) return response + def get_cached_content( + self, + request: protos.GetCachedContentRequest, + **kwargs, + ) -> protos.CachedContent: + self.observed_requests.append(request) + return protos.CachedContent( + name="cachedContents/test-cached-content", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time="2000-01-01T01:01:01.123456Z", + update_time="2000-01-01T01:01:01.123456Z", + expire_time="2000-01-01T01:01:01.123456Z", + ) + class CUJTests(parameterized.TestCase): """Tests are in order with the design doc.""" @@ -96,6 +111,7 @@ def responses(self): def setUp(self): self.client = MockGenerativeServiceClient(self) client_lib._client_manager.clients["generative"] = self.client + client_lib._client_manager.clients["cache"] = self.client def test_hello(self): # Generate text from text prompt @@ -317,6 +333,60 @@ def test_stream_prompt_feedback_not_blocked(self): text = "".join(chunk.text for chunk in response) self.assertEqual(text, "first second") + @parameterized.named_parameters( + [ + dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), + dict( + testcase_name="test_cached_content_as_CachedContent_object", + 
cached_content=caching.CachedContent._from_obj( + dict( + name="cachedContents/test-cached-content", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time=datetime.datetime.now(), + update_time=datetime.datetime.now(), + expire_time=datetime.datetime.now(), + ) + ), + ), + ], + ) + def test_model_with_cached_content_as_context(self, cached_content): + model = generative_models.GenerativeModel.from_cached_content(cached_content=cached_content) + cc_name = model.cached_content # pytype: disable=attribute-error + model_name = model.model_name + self.assertEqual(cc_name, "cachedContents/test-cached-content") + self.assertEqual(model_name, "models/gemini-1.5-pro") + self.assertEqual( + model.cached_content, # pytype: disable=attribute-error + "cachedContents/test-cached-content", + ) + + def test_content_generation_with_model_having_context(self): + self.responses["generate_content"] = [simple_response("world!")] + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + response = model.generate_content("Hello") + + self.assertEqual(response.text, "world!") + self.assertEqual( + model.cached_content, # pytype: disable=attribute-error + "cachedContents/test-cached-content", + ) + + def test_fail_content_generation_with_model_having_context(self): + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + + def add(a: int, b: int) -> int: + return a + b + + with self.assertRaises(ValueError): + model.generate_content("Hello", tools=[add]) + def test_chat(self): # Multi turn chat model = generative_models.GenerativeModel("gemini-pro") @@ -368,7 +438,7 @@ def test_chat_streaming_basic(self): iter([simple_response("x"), simple_response("y"), simple_response("z")]), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("letters?", stream=True) @@ -391,7 +461,7 @@ def test_chat_incomplete_streaming_errors(self): iter([simple_response("x"), simple_response("y"), simple_response("z")]), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("letters?", stream=True) @@ -415,7 +485,7 @@ def test_edit_history(self): simple_response("third"), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("hello") @@ -441,7 +511,7 @@ def test_replace_history(self): simple_response("third"), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() chat.send_message("hello1") chat.send_message("hello2") @@ -463,7 +533,7 @@ def test_copy_history(self): simple_response("third"), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat1 = model.start_chat() chat1.send_message("hello1") @@ -508,7 +578,7 @@ def no_throw(): no_throw(), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() # Send a message, the response is okay.. 
@@ -551,7 +621,7 @@ def test_chat_prompt_blocked(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() with self.assertRaises(generation_types.BlockedPromptException): @@ -569,7 +639,7 @@ def test_chat_candidate_blocked(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() with self.assertRaises(generation_types.StopCandidateException): @@ -591,7 +661,7 @@ def test_chat_streaming_unexpected_stop(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("hello", stream=True) @@ -615,7 +685,7 @@ def test_tools(self): dict(name="datetime", description="Returns the current UTC date and time.") ] ) - model = generative_models.GenerativeModel("gemini-pro-vision", tools=tools) + model = generative_models.GenerativeModel("gemini-1.5-flash", tools=tools) self.responses["generate_content"] = [ simple_response("a"), @@ -735,9 +805,9 @@ def test_tool_config(self, tool_config, expected_tool_config): [ "part_dict", {"parts": [{"text": "talk like a pirate"}]}, - simple_part("talk like a pirate"), + protos.Content(parts=[{"text": "talk like a pirate"}]), ], - ["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])], + ["part_list", ["talk like", "a pirate"], iter_part(["talk like", "a pirate"])], ) def test_system_instruction(self, instruction, expected_instr): self.responses["generate_content"] = [simple_response("echo echo")] @@ -774,60 +844,13 @@ def test_system_instruction(self, instruction, expected_instr): def test_count_tokens_smoke(self, kwargs): si = kwargs.pop("system_instruction", None) self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] - model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) + model = generative_models.GenerativeModel("gemini-1.5-flash", system_instruction=si) response = model.count_tokens(**kwargs) - self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) - - @parameterized.named_parameters( - [ - "GenerateContentResponse", - generation_types.GenerateContentResponse, - generation_types.AsyncGenerateContentResponse, - ], - [ - "GenerativeModel.generate_response", - generative_models.GenerativeModel.generate_content, - generative_models.GenerativeModel.generate_content_async, - ], - [ - "GenerativeModel.count_tokens", - generative_models.GenerativeModel.count_tokens, - generative_models.GenerativeModel.count_tokens_async, - ], - [ - "ChatSession.send_message", - generative_models.ChatSession.send_message, - generative_models.ChatSession.send_message_async, - ], - [ - "ChatSession._handle_afc", - generative_models.ChatSession._handle_afc, - generative_models.ChatSession._handle_afc_async, - ], - ) - def test_async_code_match(self, obj, aobj): - import inspect - import re - - source = inspect.getsource(obj) - asource = inspect.getsource(aobj) - - source = re.sub('""".*"""', "", source, flags=re.DOTALL) - asource = re.sub('""".*"""', "", asource, flags=re.DOTALL) - - asource = ( - asource.replace("anext", "next") - .replace("aiter", "iter") - .replace("_async", "") - .replace("async ", "") - .replace("await ", "") - .replace("Async", "") - .replace("ASYNC_", "") + self.assertEqual( + type(response).to_dict(response, 
including_default_value_fields=False), + {"total_tokens": 7}, ) - asource = re.sub(" *?# type: ignore", "", asource) - self.assertEqual(source, asource, f"error in {obj=}") - def test_repr_for_unary_non_streamed_response(self): model = generative_models.GenerativeModel(model_name="gemini-pro") self.responses["generate_content"].append(simple_response("world!")) @@ -999,7 +1022,7 @@ def no_throw(): no_throw(), ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() # Send a message, the response is okay.. @@ -1058,7 +1081,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): ) ] - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") chat = model.start_chat() response = chat.send_message("hello", stream=True) @@ -1140,6 +1163,7 @@ def test_repr_for_multi_turn_chat(self): safety_settings={}, tools=None, system_instruction=None, + cached_content=None ), history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" @@ -1168,6 +1192,7 @@ def test_repr_for_incomplete_streaming_chat(self): safety_settings={}, tools=None, system_instruction=None, + cached_content=None ), history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" @@ -1212,6 +1237,7 @@ def test_repr_for_broken_streaming_chat(self): safety_settings={}, tools=None, system_instruction=None, + cached_content=None ), history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" @@ -1223,11 +1249,19 @@ def test_repr_for_system_instruction(self): result = repr(model) self.assertIn("system_instruction='Be excellent.'", result) + def test_repr_for_model_created_from_cahced_content(self): + model = generative_models.GenerativeModel.from_cached_content( + cached_content="test-cached-content" + ) + result = repr(model) + self.assertIn("cached_content=cachedContents/test-cached-content", result) + self.assertIn("model_name='models/gemini-1.5-pro'", result) + def test_count_tokens_called_with_request_options(self): self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7)) request_options = {"timeout": 120} - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options) self.assertEqual(request_options, self.observed_kwargs[0]) diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 03055ffb3..dd9bc3b62 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -214,9 +214,12 @@ async def test_tool_config(self, tool_config, expected_tool_config): ) async def test_count_tokens_smoke(self, contents): self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] - model = generative_models.GenerativeModel("gemini-pro-vision") + 
model = generative_models.GenerativeModel("gemini-1.5-flash") response = await model.count_tokens_async(contents) - self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) + self.assertEqual( + type(response).to_dict(response, including_default_value_fields=False), + {"total_tokens": 7}, + ) async def test_stream_generate_content_called_with_request_options(self): self.client.stream_generate_content = unittest.mock.AsyncMock() @@ -253,7 +256,7 @@ async def test_count_tokens_called_with_request_options(self): request = unittest.mock.ANY request_options = {"timeout": 120} - model = generative_models.GenerativeModel("gemini-pro-vision") + model = generative_models.GenerativeModel("gemini-1.5-flash") response = await model.count_tokens_async( contents=["Hello?"], request_options=request_options )
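For reference, a minimal usage sketch of the caching flow this patch introduces, built only from the public surface added above (CachedContent.create/get/update/delete and GenerativeModel.from_cached_content). It assumes an API key has already been configured and that the chosen model supports cached content; the display name, prompt strings, and cached contents below are illustrative placeholders, not values from this change.

    import datetime

    import google.generativeai as genai
    from google.generativeai import caching

    genai.configure(api_key="...")  # assumed to be set up as usual for the SDK

    # Cache a large shared prefix once so it can be reused across requests.
    # `ttl` and `expire_time` are mutually exclusive, per _prepare_create_request.
    cc = caching.CachedContent.create(
        model="models/gemini-1.5-pro",
        display_name="shared context",  # illustrative; must be <= 128 characters
        system_instruction="You are an expert on the cached documents.",
        contents=["<large corpus of text to reuse across prompts>"],
        ttl=datetime.timedelta(hours=1),
    )

    # Build a model whose context is the cached content. Note that `tools`,
    # `tool_config`, and `system_instruction` cannot be overridden on such a
    # model; _prepare_request raises ValueError if they are passed.
    model = genai.GenerativeModel.from_cached_content(cached_content=cc)
    response = model.generate_content("Summarize the cached documents.")
    print(response.text)

    # Only `ttl` or `expire_time` can be updated on an existing cache.
    cc.update(ttl=datetime.timedelta(hours=2))

    # Delete the cache when it is no longer needed.
    cc.delete()

The same flow works with the bare resource id (for example `from_cached_content(cached_content="test-cached-content")`), since both CachedContent.get and from_cached_content normalize names missing the "cachedContents/" prefix.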