diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index 6652a4ee1a..91737e688a 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -37,6 +37,13 @@ from google.auth.transport import requests as requests_auth if TYPE_CHECKING: + try: + from google.genai import types + + types = types + except (ImportError, AttributeError): + types = Any + try: from google.adk.events.event import Event @@ -1795,6 +1802,263 @@ async def async_search_memory(self, *, user_id: str, query: str): query=query, ) + async def async_save_artifact( + self, + *, + user_id: str, + filename: str, + artifact: Union["types.Part", Dict[str, Any], str], + session_id: Optional[str] = None, + custom_metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Saves an artifact to the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + artifact (Union[types.Part, Dict[str, Any], str]): + Required. The artifact to save. + session_id (Optional[str]): + Optional. The ID of the session. + custom_metadata (Optional[Dict[str, Any]]): + Optional. Custom metadata to associate with the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + int: The revision ID. + """ + if isinstance(artifact, str): + try: + from google.genai import types + except ImportError: + raise ImportError( + "The `google-genai` package is required to use AdkApp. " + "Please install it with `pip install google-genai`." + ) + artifact = types.Part(text=artifact) + + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").save_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + artifact=artifact, + session_id=session_id, + custom_metadata=custom_metadata, + **kwargs, + ) + + async def async_load_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets an artifact from the artifact service storage. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[types.Part]: The artifact or None if not found. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").load_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + + async def async_list_artifact_keys( + self, + *, + user_id: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all the artifact filenames within a session. + + Args: + user_id (str): + Required. The ID of the user. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[str]: A list of artifact filenames. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_keys( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + + async def async_delete_artifact( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Deletes an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + await self._tmpl_attrs.get("artifact_service").delete_artifact( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[int]: A list of all available versions of the artifact. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_list_artifact_versions( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + **kwargs, + ): + """Lists all versions and their metadata for a specific artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + list[ArtifactVersion]: A list of ArtifactVersion objects. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_versions( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + **kwargs, + ) + + async def async_get_artifact_version( + self, + *, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + **kwargs, + ): + """Gets the metadata for a specific version of an artifact. + + Args: + user_id (str): + Required. The ID of the user. + filename (str): + Required. The filename of the artifact. + session_id (Optional[str]): + Optional. The ID of the session. + version (Optional[int]): + Optional. The version number of the artifact. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + artifact service. + + Returns: + Optional[ArtifactVersion]: An ArtifactVersion object or None. + """ + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").get_artifact_version( + app_name=self._app_name(), + user_id=user_id, + filename=filename, + session_id=session_id, + version=version, + **kwargs, + ) + def register_operations(self) -> Dict[str, List[str]]: """Registers the operations of the ADK application.""" return { @@ -1811,6 +2075,13 @@ def register_operations(self) -> Dict[str, List[str]]: "async_delete_session", "async_add_session_to_memory", "async_search_memory", + "async_save_artifact", + "async_load_artifact", + "async_list_artifact_keys", + "async_delete_artifact", + "async_list_versions", + "async_list_artifact_versions", + "async_get_artifact_version", ], "stream": ["stream_query"], "async_stream": [ diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py index e882987903..59f5f85368 100644 --- a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py +++ b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py @@ -406,12 +406,26 @@ def test_adk_version(self): adk_template.AdkApp(agent=_TEST_AGENT) def setup_method(self): + for key in [ + "GOOGLE_CLOUD_PROJECT", + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", + "GOOGLE_CLOUD_LOCATION", + "GOOGLE_GENAI_USE_VERTEXAI", + ]: + os.environ.pop(key, None) importlib.reload(initializer) importlib.reload(agentplatform) agentplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) def teardown_method(self): initializer.global_pool.shutdown(wait=True) + for key in [ + "GOOGLE_CLOUD_PROJECT", + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", + "GOOGLE_CLOUD_LOCATION", + "GOOGLE_GENAI_USE_VERTEXAI", + ]: + os.environ.pop(key, None) def test_initialization(self): app = adk_template.AdkApp(agent=_TEST_AGENT) @@ -825,6 +839,160 @@ async def test_async_add_session_to_memory( assert len(response.memories) >= 1 @pytest.mark.asyncio + async def test_async_artifact_management(self, get_project_id_mock: mock.Mock): + app = adk_template.AdkApp(agent=_TEST_AGENT) + session = await app.async_create_session(user_id=_TEST_USER_ID) + session_id = session["id"] + + part = types.Part(text="test artifact content") + version = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + artifact=part, + session_id=session_id, + ) + assert version == 0 + + # Test string conversion + version_str = await app.async_save_artifact( + user_id=_TEST_USER_ID, + filename="test2.txt", + artifact="raw string content", + session_id=session_id, + ) + assert version_str == 0 + + loaded_str = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="test2.txt", + session_id=session_id, + ) + assert loaded_str.text == "raw string content" + + loaded = await app.async_load_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert loaded.text == "test artifact content" + + keys = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert set(keys) == {"test.txt", "test2.txt"} + + versions = await app.async_list_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert versions == [0] + + art_versions = await app.async_list_artifact_versions( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + assert len(art_versions) == 1 + assert art_versions[0].version == 0 + + art_ver = await app.async_get_artifact_version( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + version=0, + ) + assert art_ver.version == 0 + + await app.async_delete_artifact( + user_id=_TEST_USER_ID, + filename="test.txt", + session_id=session_id, + ) + keys_after = await app.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id=session_id, + ) + assert set(keys_after) == {"test2.txt"} + + @pytest.mark.asyncio + async def test_async_artifact_management_lazy_init( + self, get_project_id_mock: mock.Mock + ): + part = types.Part(text="test lazy content") + + # 1. Save Artifact lazy init + app1 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app1._tmpl_attrs.get("artifact_service") is None + version = await app1.async_save_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + artifact=part, + session_id="lazy_session", + ) + assert version == 0 + assert app1._tmpl_attrs.get("artifact_service") is not None + + # 2. Load Artifact lazy init + app2 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app2._tmpl_attrs.get("artifact_service") is None + await app2.async_load_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app2._tmpl_attrs.get("artifact_service") is not None + + # 3. List keys lazy init + app3 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app3._tmpl_attrs.get("artifact_service") is None + await app3.async_list_artifact_keys( + user_id=_TEST_USER_ID, + session_id="lazy_session", + ) + assert app3._tmpl_attrs.get("artifact_service") is not None + + # 4. Delete lazy init + app4 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app4._tmpl_attrs.get("artifact_service") is None + await app4.async_delete_artifact( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app4._tmpl_attrs.get("artifact_service") is not None + + # 5. List versions lazy init + app5 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app5._tmpl_attrs.get("artifact_service") is None + await app5.async_list_versions( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app5._tmpl_attrs.get("artifact_service") is not None + + # 6. List artifact versions lazy init + app6 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app6._tmpl_attrs.get("artifact_service") is None + await app6.async_list_artifact_versions( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + ) + assert app6._tmpl_attrs.get("artifact_service") is not None + + # 7. Get version lazy init + app7 = adk_template.AdkApp(agent=_TEST_AGENT) + assert app7._tmpl_attrs.get("artifact_service") is None + await app7.async_get_artifact_version( + user_id=_TEST_USER_ID, + filename="lazy.txt", + session_id="lazy_session", + version=0, + ) + assert app7._tmpl_attrs.get("artifact_service") is not None async def test_async_add_session_to_memory_dict( self, get_project_id_mock: mock.Mock,