diff --git a/pyproject.toml b/pyproject.toml index dc64177b1..35012a146 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "googleapis-common-protos>=1.70.0", "culsans>=0.11.0 ; python_full_version < '3.13'", "packaging>=24.0", + "typing-extensions>=4.0.0", ] classifiers = [ diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index ff9ca3ce5..5be3445aa 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -82,9 +82,7 @@ async def _dispatch_notification( ) -> bool: url = push_info.url try: - headers = None - if push_info.token: - headers = {'X-A2A-Notification-Token': push_info.token} + headers = self._build_headers(push_info) response = await self._client.post( url, @@ -103,3 +101,23 @@ async def _dispatch_notification( ) return False return True + + @staticmethod + def _authorization_header( + push_info: TaskPushNotificationConfig, + ) -> str | None: + auth = push_info.authentication + if not auth.scheme or not auth.credentials: + return None + return f'{auth.scheme} {auth.credentials}' + + def _build_headers( + self, push_info: TaskPushNotificationConfig + ) -> dict[str, str] | None: + headers: dict[str, str] = {} + if push_info.token: + headers['X-A2A-Notification-Token'] = push_info.token + authorization = self._authorization_header(push_info) + if authorization: + headers['Authorization'] = authorization + return headers or None diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index 990f6c7f5..5006f7d78 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -8,6 +8,7 @@ BasePushNotificationSender, ) from a2a.types.a2a_pb2 import ( + AuthenticationInfo, StreamResponse, Task, TaskArtifactUpdateEvent, @@ -34,8 +35,11 @@ def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, + authentication: AuthenticationInfo | None = None, ) -> TaskPushNotificationConfig: - return TaskPushNotificationConfig(id=config_id, url=url, token=token) + return TaskPushNotificationConfig( + id=config_id, url=url, token=token, authentication=authentication + ) class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase): @@ -101,6 +105,61 @@ async def test_send_notification_with_token_success(self) -> None: ) mock_response.raise_for_status.assert_called_once() + async def test_send_notification_with_auth_header(self) -> None: + task_id = 'task_send_auth' + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( + url='http://notify.me/here', + token='unique_token', + authentication=AuthenticationInfo( + scheme='Bearer', credentials='token_or_jwt' + ), + ) + self.mock_config_store.get_info_for_dispatch.return_value = [config] + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + self.mock_httpx_client.post.return_value = mock_response + + await self.sender.send_notification(task_id, task_data) + + self.mock_config_store.get_info_for_dispatch.assert_awaited_once_with( + task_data.id + ) + self.mock_httpx_client.post.assert_awaited_once_with( + config.url, + json=MessageToDict(StreamResponse(task=task_data)), + headers={ + 'X-A2A-Notification-Token': 'unique_token', + 'Authorization': 'Bearer token_or_jwt', + }, + ) + mock_response.raise_for_status.assert_called_once() + + def test_authorization_header_requires_scheme_and_credentials(self) -> None: + config = _create_sample_push_config() + self.assertIsNone(self.sender._authorization_header(config)) + + config = _create_sample_push_config( + authentication=AuthenticationInfo(credentials='token_or_jwt') + ) + self.assertIsNone(self.sender._authorization_header(config)) + + config = _create_sample_push_config( + authentication=AuthenticationInfo(scheme='Bearer') + ) + self.assertIsNone(self.sender._authorization_header(config)) + + config = _create_sample_push_config( + authentication=AuthenticationInfo( + scheme='Basic', credentials='token_or_jwt' + ) + ) + self.assertEqual( + self.sender._authorization_header(config), + 'Basic token_or_jwt', + ) + async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' task_data = _create_sample_task(task_id=task_id) diff --git a/uv.lock b/uv.lock index daba3ed6e..fff2dc5fc 100644 --- a/uv.lock +++ b/uv.lock @@ -21,6 +21,7 @@ dependencies = [ { name = "packaging" }, { name = "protobuf" }, { name = "pydantic" }, + { name = "typing-extensions" }, ] [package.optional-dependencies] @@ -148,6 +149,7 @@ requires-dist = [ { name = "starlette", marker = "extra == 'all'" }, { name = "starlette", marker = "extra == 'fastapi'" }, { name = "starlette", marker = "extra == 'http-server'" }, + { name = "typing-extensions", specifier = ">=4.0.0" }, ] provides-extras = ["all", "db-cli", "encryption", "fastapi", "grpc", "http-server", "mysql", "postgresql", "signing", "sql", "sqlite", "telemetry"]