From 4120cc51cc1fe694390f072f1dec50eddbdff085 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Tue, 26 May 2026 22:12:35 +0200 Subject: [PATCH 1/2] fix(client): respect negotiated capabilities in ClientSessionGroup --- src/mcp/client/session_group.py | 55 ++++++++++++++++-------------- tests/client/test_session_group.py | 45 ++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 25 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 9610212642..d65a0b9150 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -332,6 +332,8 @@ async def _establish_session( async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" + capabilities = session.initialize_result.capabilities if session.initialize_result else None + # Create a reverse index so we can find all prompts, resources, and # tools belonging to this session. Used for removing components from # the session group via self.disconnect_from_server. @@ -345,35 +347,38 @@ async def _aggregate_components(self, server_info: types.Implementation, session tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch prompts: {err}") + if capabilities is None or capabilities.prompts is not None: + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch resources: {err}") + if capabilities is None or capabilities.resources is not None: + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except MCPError as err: # pragma: no cover - logging.warning(f"Could not fetch tools: {err}") + if capabilities is None or capabilities.tools is not None: + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except MCPError as err: # pragma: no cover + logging.warning(f"Could not fetch tools: {err}") # Clean up exit stack for session if we couldn't retrieve anything # from the server. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 6a58b39f39..5e344259d9 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -1,4 +1,5 @@ import contextlib +import logging from unittest import mock import httpx @@ -125,6 +126,50 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli mock_session.list_prompts.assert_awaited_once() +@pytest.mark.anyio +async def test_client_session_group_connect_with_session_respects_negotiated_capabilities( + caplog: pytest.LogCaptureFixture, +): + from mcp import Client + from mcp.server import Server, ServerRequestContext + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool( + name="ping", + description="Ping", + input_schema={"type": "object", "properties": {}}, + ) + ] + ) + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + return types.CallToolResult(content=[types.TextContent(type="text", text="pong")]) + + server = Server( + "tools-only-server", + on_list_tools=handle_list_tools, + on_call_tool=handle_call_tool, + ) + + group = ClientSessionGroup() + + with caplog.at_level(logging.WARNING): + async with Client(server) as client: + assert client.initialize_result.capabilities.prompts is None + assert client.initialize_result.capabilities.resources is None + + client.session.list_prompts = mock.AsyncMock(side_effect=AssertionError("list_prompts() was called")) + client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called")) + + await group.connect_with_session(client.initialize_result.server_info, client.session) + + assert not caplog.records + + @pytest.mark.anyio async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack): """Test connecting with a component name hook.""" From 35579f326c446e191fc49270f960b176b8e66f47 Mon Sep 17 00:00:00 2001 From: Jianke LIN Date: Tue, 26 May 2026 22:48:04 +0200 Subject: [PATCH 2/2] test(client): cover unadvertised capability branches --- src/mcp/client/session_group.py | 10 ++++----- tests/client/test_session_group.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index d65a0b9150..230c604acb 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,7 +11,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, cast import anyio import httpx @@ -332,7 +332,7 @@ async def _establish_session( async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: """Aggregates prompts, resources, and tools from a given session.""" - capabilities = session.initialize_result.capabilities if session.initialize_result else None + capabilities = cast(types.InitializeResult, session.initialize_result).capabilities # Create a reverse index so we can find all prompts, resources, and # tools belonging to this session. Used for removing components from @@ -347,7 +347,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session tool_to_session_temp: dict[str, mcp.ClientSession] = {} # Query the server for its prompts and aggregate to list. - if capabilities is None or capabilities.prompts is not None: + if capabilities.prompts is not None: try: prompts = (await session.list_prompts()).prompts for prompt in prompts: @@ -358,7 +358,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session logging.warning(f"Could not fetch prompts: {err}") # Query the server for its resources and aggregate to list. - if capabilities is None or capabilities.resources is not None: + if capabilities.resources is not None: try: resources = (await session.list_resources()).resources for resource in resources: @@ -369,7 +369,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session logging.warning(f"Could not fetch resources: {err}") # Query the server for its tools and aggregate to list. - if capabilities is None or capabilities.tools is not None: + if capabilities.tools is not None: try: tools = (await session.list_tools()).tools for tool in tools: diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 5e344259d9..188201074b 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -166,6 +166,39 @@ async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequ client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called")) await group.connect_with_session(client.initialize_result.server_info, client.session) + await group.call_tool("ping") + + assert not caplog.records + + +@pytest.mark.anyio +async def test_client_session_group_skips_unadvertised_tools_and_resources( + caplog: pytest.LogCaptureFixture, +): + from mcp import Client + from mcp.server import Server, ServerRequestContext + + async def handle_list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + return types.ListPromptsResult(prompts=[types.Prompt(name="hello", description="Hello", arguments=[])]) + + server = Server( + "prompts-only-server", + on_list_prompts=handle_list_prompts, + ) + + group = ClientSessionGroup() + + with caplog.at_level(logging.WARNING): + async with Client(server) as client: + assert client.initialize_result.capabilities.tools is None + assert client.initialize_result.capabilities.resources is None + + client.session.list_tools = mock.AsyncMock(side_effect=AssertionError("list_tools() was called")) + client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called")) + + await group.connect_with_session(client.initialize_result.server_info, client.session) assert not caplog.records