diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..230c604acb 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, cast import anyio import httpx @@ -332,6 +332,8 @@ async def _establish_session( async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" + capabilities = cast(types.InitializeResult, session.initialize_result).capabilities + # Create a reverse index so we can find all prompts, resources, and # tools belonging to this session. Used for removing components from # the session group via self.disconnect_from_server. @@ -345,35 +347,38 @@ async def _aggregate_components(self, server_info: types.Implementation, session tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") + if capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") # Clean up exit stack for session if we couldn't retrieve anything # from the server. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..188201074b 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -1,4 +1,5 @@ import contextlib +import logging from unittest import mock import httpx @@ -125,6 +126,83 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_prompts.assert_awaited_once() +@pytest.mark.anyio +async def test_client_session_group_connect_with_session_respects_negotiated_capabilities( + caplog: pytest.LogCaptureFixture, +): + from mcp import Client + from mcp.server import Server, ServerRequestContext + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="ping", + description="Ping", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + return types.CallToolResult(content=[types.TextContent(type="text", text="pong")]) + + server = Server( + "tools-only-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + + group = ClientSessionGroup() + + with caplog.at_level(logging.WARNING): + async with Client(server) as client: + assert client.initialize_result.capabilities.prompts is None + assert client.initialize_result.capabilities.resources is None + + client.session.list_prompts = mock.AsyncMock(side_effect=AssertionError("list_prompts() was called")) + client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called")) + + await group.connect_with_session(client.initialize_result.server_info, client.session) + await group.call_tool("ping") + + assert not caplog.records + + +@pytest.mark.anyio +async def test_client_session_group_skips_unadvertised_tools_and_resources( + caplog: pytest.LogCaptureFixture, +): + from mcp import Client + from mcp.server import Server, ServerRequestContext + + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + return types.ListPromptsResult(prompts=[types.Prompt(name="hello", description="Hello", arguments=[])]) + + server = Server( + "prompts-only-server", + on_list_prompts=handle_list_prompts, + ) + + group = ClientSessionGroup() + + with caplog.at_level(logging.WARNING): + async with Client(server) as client: + assert client.initialize_result.capabilities.tools is None + assert client.initialize_result.capabilities.resources is None + + client.session.list_tools = mock.AsyncMock(side_effect=AssertionError("list_tools() was called")) + client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called")) + + await group.connect_with_session(client.initialize_result.server_info, client.session) + + assert not caplog.records + + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook."""