From b04d7e024054a490d9cbfcb24b581125baaa67fd Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 15:45:25 +0000 Subject: [PATCH 01/34] test: add interaction-model e2e suite with requirements manifest New tests/interaction/ suite asserting client<->server round trips through the public API only. Tests are organised around a requirements manifest (_requirements.py) mapping each test to the spec or SDK behaviour it exercises, with known divergences from the spec recorded on the requirement; test_coverage.py enforces that every non-deferred requirement is exercised by at least one test. Covers tools, prompts, resources, and ping against the low-level Server, plus MCPServer tool-call behaviours. Removes two 'pragma: no cover' comments on the ping send/answer paths now that they are covered. --- pyproject.toml | 3 + src/mcp/client/session.py | 2 +- src/mcp/server/session.py | 2 +- tests/interaction/__init__.py | 0 tests/interaction/_requirements.py | 213 +++++++++++++++ tests/interaction/lowlevel/__init__.py | 0 tests/interaction/lowlevel/test_ping.py | 51 ++++ tests/interaction/lowlevel/test_prompts.py | 136 ++++++++++ tests/interaction/lowlevel/test_resources.py | 135 ++++++++++ tests/interaction/lowlevel/test_tools.py | 266 +++++++++++++++++++ tests/interaction/mcpserver/__init__.py | 0 tests/interaction/mcpserver/test_tools.py | 84 ++++++ tests/interaction/test_coverage.py | 47 ++++ 13 files changed, 937 insertions(+), 2 deletions(-) create mode 100644 tests/interaction/__init__.py create mode 100644 tests/interaction/_requirements.py create mode 100644 tests/interaction/lowlevel/__init__.py create mode 100644 tests/interaction/lowlevel/test_ping.py create mode 100644 tests/interaction/lowlevel/test_prompts.py create mode 100644 tests/interaction/lowlevel/test_resources.py create mode 100644 tests/interaction/lowlevel/test_tools.py create mode 100644 tests/interaction/mcpserver/__init__.py create mode 100644 
tests/interaction/mcpserver/test_tools.py create mode 100644 tests/interaction/test_coverage.py diff --git a/pyproject.toml b/pyproject.toml index d88869da1c..b98e64a487 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,6 +193,9 @@ strict-no-cover = { git = "https://github.com/pydantic/strict-no-cover" } [tool.pytest.ini_options] log_cli = true xfail_strict = true +markers = [ + "requirement(id): links a test to the entry in tests/interaction/_requirements.py it exercises", +] addopts = """ --color=yes --capture=fd diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a77..b26b47870f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -449,7 +449,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques client_response = ClientResponse.validate_python(response) await responder.respond(client_response) - case types.PingRequest(): # pragma: no cover + case types.PingRequest(): with responder: return await responder.respond(types.EmptyResult()) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 20b640527a..e775cb8954 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -447,7 +447,7 @@ async def elicit_url( metadata=ServerMessageMetadata(related_request_id=related_request_id), ) - async def send_ping(self) -> types.EmptyResult: # pragma: no cover + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( types.PingRequest(), diff --git a/tests/interaction/__init__.py b/tests/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py new file mode 100644 index 0000000000..5d3437ed70 --- /dev/null +++ b/tests/interaction/_requirements.py @@ -0,0 +1,213 @@ +"""Requirements manifest for the interaction-model test suite. + +Every user-facing behaviour the SDK must satisfy, keyed by a stable `<area>:<feature>[:<case>]` ID.
Each entry owns the tests that exercise it: tests declare `@requirement("<id>")` and +`test_coverage.py` enforces that every non-deferred requirement is exercised by at least one test. + +Sources: + spec URL -- externally mandated by the MCP specification (deep link to the section) + `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated + `issue:#n` -- regression lock-in for a previously fixed bug + +The `behavior` sentence describes what the suite *asserts* -- which is always the SDK's current +behaviour. Where that differs from what `source` mandates, the gap is recorded in `divergence` +and the tests still pin current behaviour: this suite is the parity bar for the receive-path +rewrite, so a test that fails today proves nothing about equivalence. +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TypeVar + +import pytest + +SPEC_REVISION = "2025-11-25" +SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" + +_TestFn = TypeVar("_TestFn", bound=Callable[..., object]) + + +@dataclass(frozen=True, kw_only=True) +class Divergence: + """A documented gap between the SDK behaviour this suite pins and what `source` mandates.""" + + note: str + issue: str | None = None + + +@dataclass(frozen=True, kw_only=True) +class Requirement: + """A single testable behaviour and the provenance of why it must hold.""" + + source: str + behavior: str + divergence: Divergence | None = None + deferred: str | None = None + + +REQUIREMENTS: dict[str, Requirement] = { + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="An unhandled exception in a request handler is returned to the caller as a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec reserves -32603
Internal error for this; the low-level Server returns code 0 " + "(not a defined JSON-RPC code) and leaks str(exc) as the error message." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Ping + # ═══════════════════════════════════════════════════════════════════════════ + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools + # ═══════════════════════════════════════════════════════════════════════════ + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:optional-fields": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." 
+ ), + ), + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", + ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), + "tools:call:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#audio-content", + behavior="A tool result can carry audio content: base64 data with a mimeType.", + ), + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", + ), + "tools:call:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", + behavior="A tool result can carry an embedded resource with full text or blob contents.", + ), + "tools:call:content:multiple": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#calling-tools", + behavior="A tool result can carry multiple content blocks of different types; order is preserved.", + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." 
+ ), + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Resources + # ═══════════════════════════════════════════════════════════════════════════ + "resources:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#listing-resources", + behavior=( + "resources/list returns the registered resources with uri, name, and the optional descriptive " + "fields supplied by the server." + ), + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), + "resources:read:binary": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns binary contents base64-encoded in blob.", + ), + "resources:read:not-found": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for an unknown URI returns a JSON-RPC error; the spec reserves -32002 for it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:get:arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:get:multi-message": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="A prompt can return multiple messages mixing user 
and assistant roles; order is preserved.", + ), + "prompts:get:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get for an unknown prompt name returns a JSON-RPC error.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer behavioural guarantees (not spec-mandated) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:tools:handler-exception": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." + ), + ), + "mcpserver:tools:unknown-name": Requirement( + source="sdk", + behavior="Calling a tool name that was never registered returns a tool result with isError true.", + divergence=Divergence( + note=( + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error instead. The low-level path follows the " + "spec example (see tools:call:unknown-name)." + ), + ), + ), +} + + +def requirement(requirement_id: str) -> Callable[[_TestFn], _TestFn]: + """Mark a test as exercising a requirement from :data:`REQUIREMENTS`. + + Applies the `requirement` pytest marker and records the coverage link checked by + `test_coverage.py`. Unknown IDs fail at import time so a typo surfaces as a collection + error on the offending test, not as a missing-coverage report later. 
+ """ + if requirement_id not in REQUIREMENTS: + raise KeyError(f"Unknown requirement id {requirement_id!r}: add it to REQUIREMENTS in {__name__}") + + def apply(test_fn: _TestFn) -> _TestFn: + covered_by(requirement_id).append(f"{test_fn.__module__}.{test_fn.__qualname__}") + return pytest.mark.requirement(requirement_id)(test_fn) + + return apply + + +_COVERAGE: dict[str, list[str]] = {} + + +def covered_by(requirement_id: str) -> list[str]: + """Return the (mutable) list of test names recorded as exercising `requirement_id`.""" + return _COVERAGE.setdefault(requirement_id, []) diff --git a/tests/interaction/lowlevel/__init__.py b/tests/interaction/lowlevel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py new file mode 100644 index 0000000000..48dc2717de --- /dev/null +++ b/tests/interaction/lowlevel/test_ping.py @@ -0,0 +1,51 @@ +"""Ping interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("ping:client-to-server") +async def test_client_ping_returns_empty_result() -> None: + """A client ping is answered with an empty result, even by a server with no handlers.""" + server = Server("silent") + + async with Client(server) as client: + result = await client.send_ping() + + assert result == snapshot(EmptyResult()) + + +@requirement("ping:server-to-client") +async def test_server_ping_returns_empty_result() -> None: + """A server-initiated ping sent while a request is in flight is answered by the client. 
+ + The tool returns the type of the ping response, proving the round trip completed inside + the handler before the tool result was produced. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ping_back", description="Ping the client.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ping_back" + pong = await ctx.session.send_ping() + return CallToolResult(content=[TextContent(text=type(pong).__name__)]) + + server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("ping_back", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py new file mode 100644 index 0000000000..64ca0ce055 --- /dev/null +++ b/tests/interaction/lowlevel/test_prompts.py @@ -0,0 +1,136 @@ +"""Prompt interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("prompts:list:basic") +async def test_list_prompts_returns_registered_prompts() -> None: + """The prompts returned by the handler reach the client with their argument declarations intact.""" + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) 
-> ListPromptsResult: + return ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + ), + Prompt(name="daily_standup"), + ] + ) + + server = Server("prompter", on_list_prompts=list_prompts) + + async with Client(server) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", description="The code to review.", required=True), + PromptArgument(name="style_guide", description="Optional style guide to apply."), + ], + ), + Prompt(name="daily_standup"), + ] + ) + ) + + +@requirement("prompts:get:arguments") +async def test_get_prompt_substitutes_arguments() -> None: + """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "greet" + assert params.arguments is not None + return GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text=f"Hello, {params.arguments['name']}!"))], + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with Client(server) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Hello, Ada!"))], + ) + ) + + +@requirement("prompts:get:multi-message") +async def test_get_prompt_multiple_messages_preserve_roles_and_order() -> None: + """A prompt returning a user/assistant conversation reaches the client with roles and 
order intact.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "geography_quiz" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with Client(server) as client: + result = await client.get_prompt("geography_quiz") + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=TextContent(text="What is the capital of France?")), + PromptMessage(role="assistant", content=TextContent(text="The capital of France is Paris.")), + PromptMessage(role="user", content=TextContent(text="And of Italy?")), + ] + ) + ) + + +@requirement("prompts:get:unknown-name") +async def test_get_prompt_unknown_name_is_protocol_error() -> None: + """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. + + The error's code and message chosen by the handler reach the client verbatim. 
+ """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.name}") + + server = Server("prompter", on_get_prompt=get_prompt) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_PARAMS, message="Unknown prompt: nope")) diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py new file mode 100644 index 0000000000..1d66e6722a --- /dev/null +++ b/tests/interaction/lowlevel/test_resources.py @@ -0,0 +1,135 @@ +"""Resource interactions against the low-level Server, driven through the public Client API.""" + +import base64 + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + Annotations, + BlobResourceContents, + ErrorData, + ListResourcesResult, + ReadResourceResult, + Resource, + TextResourceContents, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("resources:list:basic") +async def test_list_resources_returns_registered_resources() -> None: + """Listed resources reach the client with their URIs, names, and optional descriptive fields intact.""" + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + return ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations(audience=["user", "assistant"], priority=0.8), + ), + ] + ) + + server = 
Server("library", on_list_resources=list_resources) + + async with Client(server) as client: + result = await client.list_resources() + + assert result == snapshot( + ListResourcesResult( + resources=[ + Resource(uri="memo://minimal", name="minimal"), + Resource( + uri="file:///project/README.md", + name="readme", + title="Project README", + description="The project's front page.", + mime_type="text/markdown", + size=1024, + annotations=Annotations(audience=["user", "assistant"], priority=0.8), + ), + ] + ) + ) + + +@requirement("resources:read:text") +async def test_read_resource_text() -> None: + """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[TextResourceContents(uri=params.uri, mime_type="text/plain", text="Hello, world!")] + ) + + server = Server("library", on_read_resource=read_resource) + + async with Client(server) as client: + result = await client.read_resource("file:///greeting.txt") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="file:///greeting.txt", mime_type="text/plain", text="Hello, world!")] + ) + ) + + +@requirement("resources:read:binary") +async def test_read_resource_binary() -> None: + """Reading a binary resource returns its contents base64-encoded in the blob field.""" + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + return ReadResourceResult( + contents=[ + BlobResourceContents( + uri=params.uri, + mime_type="image/png", + blob=base64.b64encode(b"\x89PNG").decode(), + ) + ] + ) + + server = Server("library", on_read_resource=read_resource) + + async with Client(server) as client: + result = await client.read_resource("file:///pixel.png") + + assert result == snapshot( + ReadResourceResult( + 
contents=[BlobResourceContents(uri="file:///pixel.png", mime_type="image/png", blob="iVBORw==")] + ) + ) + + +@requirement("resources:read:not-found") +async def test_read_resource_unknown_uri_is_protocol_error() -> None: + """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches + the client verbatim. + """ + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + raise MCPError(code=-32002, message=f"Resource not found: {params.uri}") + + server = Server("library", on_read_resource=read_resource) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("file:///missing.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py new file mode 100644 index 0000000000..0dc899ef9c --- /dev/null +++ b/tests/interaction/lowlevel/test_tools.py @@ -0,0 +1,266 @@ +"""Tool interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + INVALID_PARAMS, + AudioContent, + CallToolResult, + EmbeddedResource, + ErrorData, + Icon, + ImageContent, + ListToolsResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, + ToolAnnotations, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def list_tools( + 
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="add", description="Add two integers.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) + + +@requirement("tools:call:is-error") +async def test_call_tool_execution_error_is_returned_as_result() -> None: + """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. + + Tool execution errors are part of the result so the caller (typically a model) can see + them; only protocol-level failures become JSON-RPC errors. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "flux" + return CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + + server = Server("errors", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="the flux capacitor is offline")], is_error=True) + ) + + +@requirement("tools:call:unknown-name") +async def test_call_tool_unknown_tool_is_protocol_error() -> None: + """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. + + The error's code, message, and data chosen by the handler reach the client verbatim. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + raise MCPError(code=INVALID_PARAMS, message=f"Unknown tool: {params.name}", data={"requested": params.name}) + + server = Server("errors", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("nope", {}) + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Unknown tool: nope", data={"requested": "nope"}) + ) + + +@requirement("protocol:error:internal-error") +async def test_call_tool_uncaught_exception_becomes_error_response() -> None: + """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. + + The low-level server reports it with code 0 and the exception text as the message; see the + divergence note on the requirement. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "explode" + raise ValueError("boom") + + server = Server("errors", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("explode", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="boom")) + + +@requirement("tools:list:basic") +async def test_list_tools_returns_registered_tools() -> None: + """The tools advertised by the server's list handler arrive at the client unchanged.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + + server = 
Server("calculator", on_list_tools=list_tools) + + async with Client(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="add", + description="Add two integers.", + input_schema={ + "type": "object", + "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, + "required": ["a", "b"], + }, + ), + Tool(name="reset", description="Reset the calculator.", input_schema={"type": "object"}), + ] + ) + ) + + +@requirement("tools:list:optional-fields") +async def test_list_tools_optional_fields_round_trip() -> None: + """Every optional Tool field the server supplies reaches the client unchanged.""" + + tool = Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[tool]) + + server = Server("annotated", on_list_tools=list_tools) + + async with Client(server) as client: + result = await client.list_tools() + + assert result == snapshot( + ListToolsResult( + tools=[ + Tool( + name="annotated", + title="Annotated tool", + description="A tool carrying every optional field.", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"answer": {"type": "integer"}}}, + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + annotations=ToolAnnotations(title="Display title", read_only_hint=True, idempotent_hint=True), + _meta={"example.com/source": "interaction-suite"}, + ) + ] + 
) + ) + + +@requirement("tools:call:content:multiple") +@requirement("tools:call:content:image") +@requirement("tools:call:content:audio") +@requirement("tools:call:content:resource-link") +@requirement("tools:call:content:embedded-resource") +async def test_call_tool_multiple_content_block_types() -> None: + """A tool result can mix every content block type; all of them arrive in order. + + The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the + snapshot pins the exact bytes the client receives. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="render", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "render" + return CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + + server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("render", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="all five content block types"), + ImageContent(data="aW1n", mime_type="image/png"), + AudioContent(data="YXVk", mime_type="audio/wav"), + ResourceLink(name="report", uri="resource://reports/1", description="The full report"), + EmbeddedResource( + resource=TextResourceContents(uri="resource://reports/1", mime_type="text/plain", text="contents") + ), + ] + ) + ) + + +@requirement("tools:call:structured-content") +async def 
test_call_tool_structured_content() -> None: + """A tool result carrying structured content alongside content delivers both to the client.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="sum", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "sum" + return CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5}) + + server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("sum", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) diff --git a/tests/interaction/mcpserver/__init__.py b/tests/interaction/mcpserver/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py new file mode 100644 index 0000000000..ff383d7726 --- /dev/null +++ b/tests/interaction/mcpserver/test_tools.py @@ -0,0 +1,84 @@ +"""Tool interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver.exceptions import ToolError +from mcp.types import CallToolResult, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool function; its return value comes back as text content. + + MCPServer also derives an output schema from the return annotation and attaches the + matching structuredContent to the result. 
+ """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + return str(a + b) + + async with Client(mcp) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) + + +@requirement("mcpserver:tools:handler-exception") +async def test_call_tool_function_exception_becomes_error_result() -> None: + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def explode() -> str: + raise ValueError("boom") + + async with Client(mcp) as client: + result = await client.call_tool("explode", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool explode: boom")], is_error=True) + ) + + +@requirement("mcpserver:tools:handler-exception") +async def test_call_tool_tool_error_becomes_error_result() -> None: + """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + mcp = MCPServer("errors") + + @mcp.tool() + def flux() -> str: + raise ToolError("flux capacitor offline") + + async with Client(mcp) as client: + result = await client.call_tool("flux", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool flux: flux capacitor offline")], is_error=True) + ) + + +@requirement("mcpserver:tools:unknown-name") +async def test_call_tool_unknown_name_returns_error_result() -> None: + """Calling a tool name that was never registered is reported as an is_error result. + + The spec classifies unknown tools as a protocol error; see the divergence note on the + requirement. 
+ """ + mcp = MCPServer("errors") + + @mcp.tool() + def add() -> None: + """A registered tool; the test calls a different name.""" + + async with Client(mcp) as client: + result = await client.call_tool("nope", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py new file mode 100644 index 0000000000..929bb103ed --- /dev/null +++ b/tests/interaction/test_coverage.py @@ -0,0 +1,47 @@ +"""Enforces the contract between the requirements manifest and the test suite. + +Every non-deferred entry in :data:`REQUIREMENTS` must be exercised by at least one test, and every +`@requirement(...)` mark must reference a manifest entry. Test modules are imported directly +(rather than relying on pytest collection) so the check holds even when only this file is run. +""" + +import importlib +from pathlib import Path + +import pytest + +from tests.interaction._requirements import REQUIREMENTS, covered_by, requirement + +_SUITE_ROOT = Path(__file__).parent + + +def _import_all_test_modules() -> None: + """Import every test module in the suite so their `@requirement` decorators register.""" + for path in sorted(_SUITE_ROOT.rglob("test_*.py")): + relative = path.relative_to(_SUITE_ROOT).with_suffix("") + importlib.import_module(f"{__package__}.{'.'.join(relative.parts)}") + + +def test_every_requirement_is_exercised() -> None: + """Each non-deferred requirement is covered by at least one test (deferred ones by none).""" + _import_all_test_modules() + + uncovered = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is None and not covered_by(requirement_id) + ] + assert not uncovered, f"Requirements with no test and no deferred reason: {uncovered}" + + stale_deferrals = [ + requirement_id + for requirement_id, spec in sorted(REQUIREMENTS.items()) + if spec.deferred is not None and 
covered_by(requirement_id) + ] + assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" + + +def test_unknown_requirement_id_is_rejected() -> None: + """Marking a test with an ID that is not in the manifest fails at decoration time.""" + with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): + requirement("tools:call:does-not-exist") From 571066285a89cd944781cf491eab3cfd4466dc3e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 16:06:22 +0000 Subject: [PATCH 02/34] test: add lifecycle, completion, logging, and MCPServer feature interaction tests Extends the interaction suite with the initialize handshake (server identity, instructions, capability derivation, client identity and capabilities as seen by the server), completion round trips, logging notifications, and the MCPServer resource/prompt/structured-output behaviours. Records two more divergences on the requirements manifest: MCPServer reports unknown resources and prompts with error code 0 rather than the codes the spec documents. Removes the 'pragma: no cover' from the method-not-found fallback now that it is covered. 
--- src/mcp/server/lowlevel/server.py | 2 +- tests/interaction/_requirements.py | 138 ++++++++++++ tests/interaction/lowlevel/test_completion.py | 106 ++++++++++ tests/interaction/lowlevel/test_initialize.py | 199 ++++++++++++++++++ tests/interaction/lowlevel/test_logging.py | 112 ++++++++++ tests/interaction/mcpserver/test_prompts.py | 90 ++++++++ tests/interaction/mcpserver/test_resources.py | 129 ++++++++++++ tests/interaction/mcpserver/test_tools.py | 102 +++++++++ 8 files changed, 877 insertions(+), 1 deletion(-) create mode 100644 tests/interaction/lowlevel/test_completion.py create mode 100644 tests/interaction/lowlevel/test_initialize.py create mode 100644 tests/interaction/lowlevel/test_logging.py create mode 100644 tests/interaction/mcpserver/test_prompts.py create mode 100644 tests/interaction/mcpserver/test_resources.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace45..419e06f770 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -513,7 +513,7 @@ async def _handle_request( if raise_exceptions: # pragma: no cover raise err response = types.ErrorData(code=0, message=str(err)) - else: # pragma: no cover + else: response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") if isinstance(response, types.ErrorData) and span is not None: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 5d3437ed70..d73c74c5fa 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -60,6 +60,42 @@ class Requirement: ), ), # ═══════════════════════════════════════════════════════════════════════════ + # Lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "lifecycle:initialize:server-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The initialize result identifies the server: name and version, plus title when 
declared.", + ), + "lifecycle:initialize:instructions": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Server-declared instructions are returned in the initialize result, and omitted when the " + "server declares none." + ), + ), + "lifecycle:initialize:capabilities:from-handlers": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The server advertises a capability for each feature area it has a registered handler for, " + "and omits the capability for areas it does not." + ), + ), + "lifecycle:initialize:capabilities:minimal": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with no feature handlers advertises no feature capabilities.", + ), + "lifecycle:initialize:client-info": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="The client's name, version, and title are visible to server handlers after initialization.", + ), + "lifecycle:initialize:client-capabilities": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "The client capabilities visible to the server reflect which client callbacks are configured " + "(sampling, elicitation, roots)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ # Ping # ═══════════════════════════════════════════════════════════════════════════ "ping:client-to-server": Requirement( @@ -123,6 +159,53 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", ), + "tools:call:invalid-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content), not a protocol error." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:complete:prompt-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:complete:resource-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:complete:context": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#context", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:message:notification": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data, in the order the server sent them." 
+ ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), # ═══════════════════════════════════════════════════════════════════════════ # Resources # ═══════════════════════════════════════════════════════════════════════════ @@ -167,6 +250,61 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # MCPServer behavioural guarantees (not spec-mandated) # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:tools:output-schema:model": Requirement( + source="sdk", + behavior=( + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." + ), + ), + "mcpserver:tools:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." + ), + ), + "mcpserver:resources:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), + ), + "mcpserver:resources:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." + ), + ), + "mcpserver:resources:unknown-uri": Requirement( + source="sdk", + behavior="resources/read for a URI matching no registered resource returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." 
+ ), + ), + ), + "mcpserver:prompts:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." + ), + ), + "mcpserver:prompts:unknown-name": Requirement( + source="sdk", + behavior="prompts/get for a name that was never registered returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), "mcpserver:tools:handler-exception": Requirement( source="sdk", behavior=( diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py new file mode 100644 index 0000000000..91fd20a5a0 --- /dev/null +++ b/tests/interaction/lowlevel/test_completion.py @@ -0,0 +1,106 @@ +"""Completion interactions against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + METHOD_NOT_FOUND, + CompleteResult, + Completion, + ErrorData, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("completion:complete:prompt-ref") +async def test_complete_prompt_argument() -> None: + """Completing a prompt argument delivers the ref, argument name, and current value to the handler. + + The returned values are filtered by the argument's value, proving the value reached the handler. 
+ """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + assert params.ref.name == "code_review" + assert params.argument.name == "language" + candidates = ["python", "pytorch", "ruby"] + matches = [candidate for candidate in candidates if candidate.startswith(params.argument.value)] + return CompleteResult(completion=Completion(values=matches, total=len(matches), has_more=False)) + + server = Server("completer", on_completion=completion) + + async with Client(server) as client: + result = await client.complete( + PromptReference(name="code_review"), argument={"name": "language", "value": "py"} + ) + + assert result == snapshot( + CompleteResult(completion=Completion(values=["python", "pytorch"], total=2, has_more=False)) + ) + + +@requirement("completion:complete:resource-ref") +async def test_complete_resource_template_variable() -> None: + """Completing a URI template variable delivers the template URI and variable name to the handler.""" + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, ResourceTemplateReference) + assert params.ref.uri == "github://repos/{owner}/{repo}" + assert params.argument.name == "owner" + return CompleteResult(completion=Completion(values=[f"{params.argument.value}contextprotocol"])) + + server = Server("completer", on_completion=completion) + + async with Client(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "owner", "value": "model"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"]))) + + +@requirement("completion:complete:context") +async def test_complete_receives_context_arguments() -> None: + """Previously-resolved arguments passed as completion context reach the handler. 
+ + The returned value is derived from the context, proving it arrived. + """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert params.argument.name == "repo" + assert params.context is not None + assert params.context.arguments is not None + return CompleteResult(completion=Completion(values=[f"{params.context.arguments['owner']}/python-sdk"])) + + server = Server("completer", on_completion=completion) + + async with Client(server) as client: + result = await client.complete( + ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol/python-sdk"]))) + + +@requirement("completion:complete:not-supported") +async def test_complete_without_handler_is_method_not_found() -> None: + """A server with no completion handler advertises no completions capability and rejects the request.""" + server = Server("incomplete") + + async with Client(server) as client: + assert client.initialize_result.capabilities.completions is None + + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="anything"), argument={"name": "topic", "value": ""}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py new file mode 100644 index 0000000000..029104f0d9 --- /dev/null +++ b/tests/interaction/lowlevel/test_initialize.py @@ -0,0 +1,199 @@ +"""Initialization handshake against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + 
CallToolResult, + CompletionsCapability, + Icon, + Implementation, + LoggingCapability, + PromptsCapability, + ResourcesCapability, + ServerCapabilities, + TextContent, + ToolsCapability, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("lifecycle:initialize:server-info") +async def test_initialize_returns_server_info() -> None: + """Every identity field the server declares is returned to the client in server_info.""" + server = Server( + "greeter", + version="1.2.3", + title="Greeter", + description="Greets people.", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + + async with Client(server) as client: + server_info = client.initialize_result.server_info + + assert server_info == snapshot( + Implementation( + name="greeter", + title="Greeter", + description="Greets people.", + version="1.2.3", + website_url="https://example.com/greeter", + icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], + ) + ) + + +@requirement("lifecycle:initialize:instructions") +async def test_initialize_returns_instructions() -> None: + """Instructions are returned when the server declares them and omitted when it does not.""" + async with Client(Server("guided", instructions="Call the add tool.")) as client: + assert client.initialize_result.instructions == snapshot("Call the add tool.") + + async with Client(Server("unguided")) as client: + assert client.initialize_result.instructions is None + + +@requirement("lifecycle:initialize:capabilities:from-handlers") +async def test_initialize_capabilities_reflect_registered_handlers() -> None: + """Each feature area with a registered handler is advertised as a capability. + + The in-memory transport connects with default initialization options, so the + list_changed flags are always False regardless of the server's notification behaviour. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> types.EmptyResult: + """Registered only so the subscribe sub-capability is advertised; never called.""" + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered only so the prompts capability is advertised; never called.""" + raise NotImplementedError + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> types.EmptyResult: + """Registered only so the logging capability is advertised; never called.""" + raise NotImplementedError + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + server = Server( + "full", + on_list_tools=list_tools, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + on_list_prompts=list_prompts, + on_set_logging_level=set_logging_level, + on_completion=completion, + ) + + async with Client(server) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot( + ServerCapabilities( + experimental={}, + logging=LoggingCapability(), + prompts=PromptsCapability(list_changed=False), + resources=ResourcesCapability(subscribe=True, list_changed=False), + 
tools=ToolsCapability(list_changed=False), + completions=CompletionsCapability(), + ) + ) + + +@requirement("lifecycle:initialize:capabilities:minimal") +async def test_initialize_minimal_server_advertises_no_capabilities() -> None: + """A server with no feature handlers advertises no feature capabilities.""" + async with Client(Server("bare")) as client: + capabilities = client.initialize_result.capabilities + + assert capabilities == snapshot(ServerCapabilities(experimental={})) + + +@requirement("lifecycle:initialize:client-info") +async def test_initialize_server_sees_client_info() -> None: + """The client identity supplied to Client is visible to server handlers after initialization.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="whoami", description="Report the caller.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + assert ctx.session.client_params is not None + client_info = ctx.session.client_params.client_info + return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + client = Client(server, client_info=Implementation(name="acme-agent", version="9.9.9")) + + async with client: + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) + + +@requirement("lifecycle:initialize:client-capabilities") +async def test_initialize_server_sees_client_capabilities() -> None: + """The client capabilities visible to the server reflect which callbacks the client configured.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> 
types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="abilities", description="Report capabilities.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "abilities" + assert ctx.session.client_params is not None + capabilities = ctx.session.client_params.capabilities + declared = [ + name + for name, value in ( + ("sampling", capabilities.sampling), + ("elicitation", capabilities.elicitation), + ("roots", capabilities.roots), + ) + if value is not None + ] + return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) + + async def list_roots(context: object) -> types.ListRootsResult: + """Registered only so the client declares the roots capability; never called.""" + raise NotImplementedError + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) + + async with Client(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("abilities", {}) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots")])) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py new file mode 100644 index 0000000000..600724259f --- /dev/null +++ b/tests/interaction/lowlevel/test_logging.py @@ -0,0 +1,112 @@ +"""Logging interactions against the low-level Server, driven through the public Client API. + +Notification ordering: the in-memory transport delivers every server-to-client message on one +ordered stream, and the client's receive loop dispatches each incoming message to completion +before reading the next one. 
Together these guarantee that every notification the server sends +before its response reaches the client callback before the originating request returns, so tests +collect notifications into a plain list and assert after the request completes -- no events, no +waiting. This does not generalise to transports that split messages across streams (the +streamable HTTP standalone GET stream); tests over those transports must synchronise explicitly. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ALL_LEVELS: tuple[types.LoggingLevel, ...] = ( + "debug", + "info", + "notice", + "warning", + "error", + "critical", + "alert", + "emergency", +) + + +@requirement("logging:set-level") +async def test_set_logging_level_reaches_handler() -> None: + """The level requested by the client is delivered to the server's handler verbatim.""" + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + assert params.level == "warning" + return EmptyResult() + + server = Server("logger", on_set_logging_level=set_logging_level) + + async with Client(server) as client: + result = await client.set_logging_level("warning") + + assert result == snapshot(EmptyResult()) + + +@requirement("logging:message:notification") +async def test_log_messages_reach_logging_callback_in_order() -> None: + """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. + + The two messages pin the full notification shape: severity, optional logger name, and both + string and structured data payloads. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="chatty", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "chatty" + await ctx.session.send_log_message(level="info", data="starting up", logger="app.lifecycle") + await ctx.session.send_log_message(level="error", data={"code": 502, "retryable": True}) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server, logging_callback=collect) as client: + result = await client.call_tool("chatty", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="info", logger="app.lifecycle", data="starting up"), + LoggingMessageNotificationParams(level="error", data={"code": 502, "retryable": True}), + ] + ) + + +@requirement("logging:message:all-levels") +async def test_log_messages_at_every_severity_level() -> None: + """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="siren", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name 
== "siren" + for level in ALL_LEVELS: + await ctx.session.send_log_message(level=level, data=f"a {level} message") + return CallToolResult(content=[TextContent(text="logged")]) + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server, logging_callback=collect) as client: + await client.call_tool("siren", {}) + + assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py new file mode 100644 index 0000000000..27b44773a6 --- /dev/null +++ b/tests/interaction/mcpserver/test_prompts.py @@ -0,0 +1,90 @@ +"""Prompt interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.client.client import Client +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + GetPromptResult, + ListPromptsResult, + Prompt, + PromptArgument, + PromptMessage, + TextContent, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:prompts:decorated") +async def test_list_prompts_derives_arguments_from_signature() -> None: + """A decorated prompt is listed with arguments derived from the function signature. + + Parameters without a default are required; the description comes from the docstring. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def code_review(code: str, style_guide: str = "pep8") -> str: + """Review a piece of code.""" + raise NotImplementedError # registered for listing only; never rendered + + async with Client(mcp) as client: + result = await client.list_prompts() + + assert result == snapshot( + ListPromptsResult( + prompts=[ + Prompt( + name="code_review", + description="Review a piece of code.", + arguments=[ + PromptArgument(name="code", required=True), + PromptArgument(name="style_guide", required=False), + ], + ) + ] + ) + ) + + +@requirement("mcpserver:prompts:decorated") +async def test_get_prompt_renders_function_return() -> None: + """The decorated function's string return value is rendered as a single user message.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A personalised greeting.""" + return f"Say hello to {name}." + + async with Client(mcp) as client: + result = await client.get_prompt("greet", {"name": "Ada"}) + + assert result == snapshot( + GetPromptResult( + description="A personalised greeting.", + messages=[PromptMessage(role="user", content=TextContent(text="Say hello to Ada."))], + ) + ) + + +@requirement("mcpserver:prompts:unknown-name") +async def test_get_unknown_prompt_is_error() -> None: + """Getting a prompt name that was never registered fails with a JSON-RPC error.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; the test requests a different name.""" + raise NotImplementedError + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("nope") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py new file mode 100644 index 0000000000..801e60663a --- /dev/null +++ 
b/tests/interaction/mcpserver/test_resources.py @@ -0,0 +1,129 @@ +"""Resource interactions against MCPServer, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError +from mcp.client.client import Client +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + ErrorData, + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:resources:static") +async def test_read_static_resource() -> None: + """A function registered for a fixed URI is served at that URI with its return value as text.""" + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + async with Client(mcp) as client: + result = await client.read_resource("config://app") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="theme = dark")] + ) + ) + + +@requirement("mcpserver:resources:static") +async def test_list_static_and_templated_resources() -> None: + """Statically-registered resources appear in resources/list; templated ones only in templates/list. + + The name and description are derived from the function name and docstring; the MIME type + defaults to text/plain. 
+ """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + raise NotImplementedError # registered for listing only; never read + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + raise NotImplementedError # registered for listing only; never read + + async with Client(mcp) as client: + resources = await client.list_resources() + templates = await client.list_resource_templates() + + assert resources == snapshot( + ListResourcesResult( + resources=[ + Resource( + name="app_config", + uri="config://app", + description="The application configuration.", + mime_type="text/plain", + ) + ] + ) + ) + assert templates == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate( + name="user_profile", + uri_template="users://{user_id}/profile", + description="A user's profile.", + mime_type="text/plain", + ) + ] + ) + ) + + +@requirement("mcpserver:resources:template") +async def test_read_templated_resource() -> None: + """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" + mcp = MCPServer("library") + + @mcp.resource("users://{user_id}/profile") + def user_profile(user_id: str) -> str: + """A user's profile.""" + return f"profile for {user_id}" + + async with Client(mcp) as client: + result = await client.read_resource("users://42/profile") + + assert result == snapshot( + ReadResourceResult( + contents=[TextResourceContents(uri="users://42/profile", mime_type="text/plain", text="profile for 42")] + ) + ) + + +@requirement("mcpserver:resources:unknown-uri") +async def test_read_unknown_uri_is_error() -> None: + """Reading a URI that matches no registered resource fails with a JSON-RPC error. + + The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. 
+ """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """A registered resource; the test reads a different URI.""" + raise NotImplementedError + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("config://missing") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index ff383d7726..30def4870a 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -2,6 +2,7 @@ import pytest from inline_snapshot import snapshot +from pydantic import BaseModel from mcp.client.client import Client from mcp.server.mcpserver import MCPServer @@ -82,3 +83,104 @@ def add() -> None: result = await client.call_tool("nope", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) + + +@requirement("mcpserver:tools:output-schema:model") +async def test_call_tool_model_return_becomes_structured_content() -> None: + """A tool returning a pydantic model advertises the model's schema as the tool's output schema + and returns the model's fields as structured content alongside a serialised text block. 
+ """ + mcp = MCPServer("weather") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def get_weather() -> Weather: + return Weather(temperature=22.5, conditions="sunny") + + async with Client(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("get_weather", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": { + "temperature": {"title": "Temperature", "type": "number"}, + "conditions": {"title": "Conditions", "type": "string"}, + }, + "required": ["temperature", "conditions"], + "title": "Weather", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="""\ +{ + "temperature": 22.5, + "conditions": "sunny" +}\ +""" + ) + ], + structured_content={"temperature": 22.5, "conditions": "sunny"}, + ) + ) + + +@requirement("mcpserver:tools:output-schema:wrapped") +async def test_call_tool_list_return_is_wrapped_in_result_key() -> None: + """A tool returning a list wraps the value under a "result" key in both the generated output + schema and the structured content. 
+ """ + mcp = MCPServer("primes") + + @mcp.tool() + def primes() -> list[int]: + return [2, 3, 5] + + async with Client(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("primes", {}) + + assert listed.tools[0].output_schema == snapshot( + { + "properties": {"result": {"items": {"type": "integer"}, "title": "Result", "type": "array"}}, + "required": ["result"], + "title": "primesOutput", + "type": "object", + } + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="2"), TextContent(text="3"), TextContent(text="5")], + structured_content={"result": [2, 3, 5]}, + ) + ) + + +@requirement("tools:call:invalid-arguments") +async def test_call_tool_invalid_arguments_become_error_result() -> None: + """Arguments that fail validation against the tool's signature are reported as an is_error + result describing the failure, not as a protocol error. + + The description is raw pydantic output (version-dependent and leaking the internal argument + model name), so only the stable prefix is asserted rather than the full text. 
+ """ + mcp = MCPServer("adder") + + @mcp.tool() + def add(a: int, b: int) -> str: + """Validation rejects the arguments before the function is ever called.""" + raise NotImplementedError + + async with Client(mcp) as client: + result = await client.call_tool("add", {"b": 3}) + + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + assert result.content[0].text.startswith("Error executing tool add: 1 validation error") From 521699705edef6733ea87f55632bf9a0d9a952a2 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 16:58:27 +0000 Subject: [PATCH 03/34] test: add server-initiated request and notification interaction tests Covers the server-to-client half of the interaction model: sampling, form-mode elicitation, roots, progress in both directions, list_changed notifications, and request cancellation, all against the low-level Server through the public Client API. Records a further divergence: the server answers cancelled requests with an error response where the spec says no response should be sent. Removes five more 'pragma: no cover' comments on paths these tests now cover (server list_changed senders, the client roots send path, and the default elicitation callback). 
--- src/mcp/client/client.py | 2 +- src/mcp/client/session.py | 4 +- src/mcp/server/session.py | 4 +- tests/interaction/_requirements.py | 152 +++++++++ .../interaction/lowlevel/test_cancellation.py | 138 +++++++++ .../interaction/lowlevel/test_elicitation.py | 161 ++++++++++ tests/interaction/lowlevel/test_initialize.py | 3 +- .../interaction/lowlevel/test_list_changed.py | 109 +++++++ tests/interaction/lowlevel/test_progress.py | 127 ++++++++ tests/interaction/lowlevel/test_roots.py | 130 ++++++++ tests/interaction/lowlevel/test_sampling.py | 291 ++++++++++++++++++ 11 files changed, 1115 insertions(+), 6 deletions(-) create mode 100644 tests/interaction/lowlevel/test_cancellation.py create mode 100644 tests/interaction/lowlevel/test_elicitation.py create mode 100644 tests/interaction/lowlevel/test_list_changed.py create mode 100644 tests/interaction/lowlevel/test_progress.py create mode 100644 tests/interaction/lowlevel/test_roots.py create mode 100644 tests/interaction/lowlevel/test_sampling.py diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 34d6a360fa..b33fea4052 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -305,4 +305,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta async def send_roots_list_changed(self) -> None: """Send a notification that the roots list has changed.""" # TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support. 
- await self.session.send_roots_list_changed() # pragma: no cover + await self.session.send_roots_list_changed() diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b26b47870f..cf92696682 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -74,7 +74,7 @@ async def _default_elicitation_callback( context: RequestContext[ClientSession], params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( # pragma: no cover + return types.ErrorData( code=types.INVALID_REQUEST, message="Elicitation not supported", ) @@ -408,7 +408,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None return result - async def send_roots_list_changed(self) -> None: # pragma: no cover + async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index e775cb8954..b577f278a7 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -479,11 +479,11 @@ async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) - async def send_tool_list_changed(self) -> None: # pragma: no cover + async def send_tool_list_changed(self) -> None: """Send a tool list changed notification.""" await self.send_notification(types.ToolListChangedNotification()) - async def send_prompt_list_changed(self) -> None: # pragma: no cover + async def send_prompt_list_changed(self) -> None: """Send a prompt list changed notification.""" await self.send_notification(types.PromptListChangedNotification()) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index d73c74c5fa..1091ade9a9 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -96,6 +96,61 @@ class 
Requirement: ), ), # ═══════════════════════════════════════════════════════════════════════════ + # Cancellation + # ═══════════════════════════════════════════════════════════════════════════ + "cancellation:in-flight": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification for an in-flight request stops the server-side handler, and the " + "caller's pending request fails with an error response." + ), + divergence=Divergence( + note=( + "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " + "request; the server sends an error response (code 0, 'Request cancelled'), which is what " + "unblocks the SDK client's pending call." + ), + ), + ), + "cancellation:server-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), + "cancellation:unknown-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A cancellation notification referencing an unknown or already-completed request is ignored without error." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Progress + # ═══════════════════════════════════════════════════════════════════════════ + "progress:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a handler during a request are delivered to the caller's " + "progress callback, in order, with their progress, total, and message." 
+ ), + ), + "progress:token-propagation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Supplying a progress callback attaches a progress token to the outgoing request, which the " + "server-side handler can observe in its request metadata." + ), + ), + "progress:no-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Without a progress callback no token is attached, and a handler that reports progress anyway " + "sends nothing." + ), + ), + "progress:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior="A progress notification sent by the client is delivered to the server's progress handler.", + ), + # ═══════════════════════════════════════════════════════════════════════════ # Ping # ═══════════════════════════════════════════════════════════════════════════ "ping:client-to-server": Requirement( @@ -229,6 +284,21 @@ class Requirement: behavior="resources/read for an unknown URI returns a JSON-RPC error; the spec reserves -32002 for it.", ), # ═══════════════════════════════════════════════════════════════════════════ + # Notifications: list_changed (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "notifications:tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior="A tools/list_changed notification sent by the server reaches the client's message handler.", + ), + "notifications:resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior="A resources/list_changed notification sent by the server reaches the client's message handler.", + ), + "notifications:prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior="A prompts/list_changed notification sent by the server reaches the 
client's message handler.", + ), + # ═══════════════════════════════════════════════════════════════════════════ # Prompts # ═══════════════════════════════════════════════════════════════════════════ "prompts:list:basic": Requirement( @@ -248,6 +318,88 @@ class Requirement: behavior="prompts/get for an unknown prompt name returns a JSON-RPC error.", ), # ═══════════════════════════════════════════════════════════════════════════ + # Sampling (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "sampling:create-message:round-trip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A sampling/createMessage request from a server handler is answered by the client's sampling " + "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." + ), + ), + "sampling:create-message:params": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "The sampling parameters supplied by the server (messages, maxTokens, systemPrompt, " + "modelPreferences, temperature, stopSequences) reach the client callback intact." + ), + ), + "sampling:create-message:image-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#message-content", + behavior="Sampling messages can carry image content: base64 data with a mimeType.", + ), + "sampling:create-message:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", + behavior="A sampling callback that returns an error is surfaced to the requesting handler as an MCPError.", + ), + "sampling:create-message:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Elicitation (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "elicitation:form:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation", + behavior=( + "A form-mode elicitation answered with action 'accept' returns the user's content to the " + "requesting handler, validated against the requested schema." + ), + ), + "elicitation:form:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", + ), + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability fails with " + "an error rather than hanging or being silently dropped." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list:round-trip": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." 
+ ), + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#capabilities", + behavior=( + "A roots/list request to a client that did not declare the roots capability fails with an " + "error rather than hanging or being silently dropped." + ), + ), + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + # ═══════════════════════════════════════════════════════════════════════════ # MCPServer behavioural guarantees (not spec-mandated) # ═══════════════════════════════════════════════════════════════════════════ "mcpserver:tools:output-schema:model": Requirement( diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py new file mode 100644 index 0000000000..30821c1294 --- /dev/null +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -0,0 +1,138 @@ +"""Cancellation interactions against the low-level Server, driven through the public Client API. + +There is no client-side cancellation API: cancelling means sending a CancelledNotification +carrying the request id, which only the server-side handler can observe (`ctx.request_id`), so +these tests capture the id from inside the blocked handler before cancelling. The handler blocks +on an Event rather than a sleep, and every wait is bounded by `anyio.fail_after`. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("cancellation:in-flight") +async def test_cancellation_stops_in_flight_handler() -> None: + """Cancelling an in-flight request interrupts its handler and fails the pending call. + + The server answers the cancelled request with an error response (the spec says it should + not respond at all; see the divergence note on the requirement), so the caller's pending + request raises rather than hanging. + """ + started = anyio.Event() + handler_cancelled = anyio.Event() + request_ids: list[types.RequestId] = [] + errors: list[ErrorData] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + try: + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + errors.append(exc_info.value.error) + + task_group.start_soon(call_and_capture_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification( + params=types.CancelledNotificationParams(request_id=request_ids[0], reason="user aborted") + ) + ) + + await 
handler_cancelled.wait() + + assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) + + +@requirement("cancellation:server-survives") +async def test_session_serves_requests_after_cancellation() -> None: + """A request cancelled mid-flight does not poison the session: the next request succeeds.""" + started = anyio.Event() + request_ids: list[types.RequestId] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + assert ctx.request_id is not None + request_ids.append(ctx.request_id) + started.set() + await anyio.Event().wait() # blocks until cancelled + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_swallow_cancellation_error() -> None: + with pytest.raises(MCPError): + await client.call_tool("block", {}) + + task_group.start_soon(call_and_swallow_cancellation_error) + await started.wait() + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=request_ids[0])) + ) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("cancellation:unknown-request") +async def test_cancellation_for_unknown_request_is_ignored() -> None: + """A cancellation referencing a request id that is not in flight is ignored without error.""" + + 
async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="unbothered")]) + + server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + await client.session.send_notification( + types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) + ) + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py new file mode 100644 index 0000000000..6017580d86 --- /dev/null +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -0,0 +1,161 @@ +"""Form-mode elicitation against the low-level Server, driven through the public Client API.""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, ElicitRequestFormParams, ElicitResult, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +REQUESTED_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], +} + + +@requirement("elicitation:form:accept") +async def test_elicit_form_accepted_content_returns_to_handler() -> None: + """An accepted form elicitation returns the user's content to the requesting handler. 
+ + The tool reports the action as text and the received content as structured content, proving + the client's answer made it back into the tool's own result. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form("Choose a username.", REQUESTED_SCHEMA) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) + + async with Client(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Choose a username.", + requested_schema={ + "type": "object", + "properties": { + "username": {"type": "string"}, + "newsletter": {"type": "boolean"}, + }, + "required": ["username"], + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept")], + structured_content={"username": "ada", "newsletter": True}, + ) + ) + + +@requirement("elicitation:form:decline") +async def test_elicit_form_decline_returns_no_content() -> None: + """A declined form elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> 
types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with Client(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:form:cancel") +async def test_elicit_form_cancel_returns_no_content() -> None: + """A cancelled form elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="confirm", description="Ask for confirmation.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "confirm" + answer = await ctx.session.elicit_form("Proceed?", {"type": "object", "properties": {}}) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("confirmer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async 
with Client(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("confirm", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:form:not-supported") +async def test_elicit_form_without_callback_is_error() -> None: + """Eliciting from a client that configured no elicitation callback fails with an error. + + The client's default callback answers with an Invalid request error, which the server-side + elicit call raises as an MCPError; the tool reports the code and message it caught. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="ask", description="Ask the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask" + try: + await ctx.session.elicit_form("Anyone there?", {"type": "object", "properties": {}}) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # elicit_form cannot succeed without a client callback + + server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("ask", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 029104f0d9..6ade1de9da 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -4,6 +4,7 @@ from inline_snapshot import snapshot from mcp import types +from mcp.client import ClientRequestContext from mcp.client.client import Client from mcp.server 
import Server, ServerRequestContext from mcp.types import ( @@ -184,7 +185,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ] return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) - async def list_roots(context: object) -> types.ListRootsResult: + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: """Registered only so the client declares the roots capability; never called.""" raise NotImplementedError diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py new file mode 100644 index 0000000000..de37d0e4eb --- /dev/null +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -0,0 +1,109 @@ +"""List-changed notifications from the low-level Server, driven through the public Client API. + +The notifications are emitted from inside a tool call, so the ordering guarantee described in +test_logging.py applies: they reach the client's message handler before the tool call returns, +and the tests assert on a plain collected list with no synchronisation. The collector records +every message the handler receives, so the assertions also prove nothing else was delivered. 
+""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.session import RequestResponder +from mcp.types import ( + CallToolResult, + ClientResult, + PromptListChangedNotification, + ResourceListChangedNotification, + ServerNotification, + ServerRequest, + TextContent, + ToolListChangedNotification, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" + + +@requirement("notifications:tools:list-changed") +async def test_tool_list_changed_notification() -> None: + """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="install", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "install" + await ctx.session.send_tool_list_changed() + return CallToolResult(content=[TextContent(text="installed")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server, message_handler=collect) as client: + await client.call_tool("install", {}) + + assert received == snapshot([ToolListChangedNotification()]) + + +@requirement("notifications:resources:list-changed") +async def test_resource_list_changed_notification() -> None: + """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" + received: 
list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="mount", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "mount" + await ctx.session.send_resource_list_changed() + return CallToolResult(content=[TextContent(text="mounted")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server, message_handler=collect) as client: + await client.call_tool("mount", {}) + + assert received == snapshot([ResourceListChangedNotification()]) + + +@requirement("notifications:prompts:list-changed") +async def test_prompt_list_changed_notification() -> None: + """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="learn", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "learn" + await ctx.session.send_prompt_list_changed() + return CallToolResult(content=[TextContent(text="learned")]) + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server, message_handler=collect) as client: + await client.call_tool("learn", {}) + + assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_progress.py 
b/tests/interaction/lowlevel/test_progress.py new file mode 100644 index 0000000000..229f8edf6f --- /dev/null +++ b/tests/interaction/lowlevel/test_progress.py @@ -0,0 +1,127 @@ +"""Progress interactions against the low-level Server, driven through the public Client API. + +Server-to-client progress emitted during a request follows the same ordering guarantee as +logging notifications (see test_logging.py): everything the server sends before its response is +dispatched to the progress callback before the request returns, so no synchronisation is needed. +The client-to-server direction is a standalone notification with no response to await, so that +test waits on an event set by the server's handler. +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, ProgressNotificationParams, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("progress:server-to-client") +async def test_progress_during_tool_call_reaches_callback_in_order() -> None: + """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" + received: list[tuple[float, float | None, str | None]] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="download", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "download" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await 
ctx.session.send_progress_notification(token, 1.0, total=3.0, message="first chunk") + await ctx.session.send_progress_notification(token, 2.0, total=3.0, message="second chunk") + await ctx.session.send_progress_notification(token, 3.0, total=3.0, message="done") + return CallToolResult(content=[TextContent(text="downloaded")]) + + server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("download", {}, progress_callback=collect) + + assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) + assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) + + +@requirement("progress:token-propagation") +async def test_progress_token_visible_to_handler() -> None: + """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def ignore(progress: float, total: float | None, message: str | None) -> None: + """A progress callback that is never invoked; the tool only inspects the token.""" + raise NotImplementedError + + async with Client(server) as client: + result = await client.call_tool("inspect", {}, progress_callback=ignore) + + # The token is the request id of the tools/call request itself (initialize is request 0). 
+ assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + + +@requirement("progress:no-token") +async def test_no_progress_callback_means_no_token() -> None: + """Without a progress callback the request carries no progress token. + + The low-level API has no way to report request-scoped progress without a token, so a handler + that sees no token has nothing to send progress against. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="inspect", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "inspect" + assert ctx.meta is not None + return CallToolResult(content=[TextContent(text=str(ctx.meta.get("progress_token")))]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("inspect", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) + + +@requirement("progress:client-to-server") +async def test_client_progress_notification_reaches_server_handler() -> None: + """A progress notification sent by the client is delivered to the server's progress handler.""" + received: list[ProgressNotificationParams] = [] + delivered = anyio.Event() + + async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationParams) -> None: + received.append(params) + delivered.set() + + server = Server("observer", on_progress=on_progress) + + async with Client(server) as client: + await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot( + [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] + ) diff --git 
a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py new file mode 100644 index 0000000000..c87a00735d --- /dev/null +++ b/tests/interaction/lowlevel/test_roots.py @@ -0,0 +1,130 @@ +"""Roots interactions against the low-level Server, driven through the public Client API.""" + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import FileUrl + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, ListRootsResult, Root, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("roots:list:round-trip") +async def test_list_roots_round_trip() -> None: + """A roots/list request from a tool handler is answered by the client's roots callback. + + The tool reports the URIs and names it received, proving the client's roots reached the server. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + result = await ctx.session.list_roots() + lines = [f"{root.uri} name={root.name}" for root in result.roots] + return CallToolResult(content=[TextContent(text="\n".join(lines))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + return ListRootsResult( + roots=[ + Root(uri=FileUrl("file:///home/alice/project"), name="project"), + Root(uri=FileUrl("file:///home/alice/scratch")), + ] + ) + + async with Client(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="file:///home/alice/project name=project\nfile:///home/alice/scratch name=None")] + ) + ) + + +@requirement("roots:list:empty") +async def test_list_roots_empty() -> None: + """A client with no roots to offer answers roots/list with an empty list, not an error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="count_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "count_roots" + result = await ctx.session.list_roots() + return CallToolResult(content=[TextContent(text=str(len(result.roots)))]) + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: 
+ return ListRootsResult(roots=[]) + + async with Client(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("count_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) + + +@requirement("roots:list:not-supported") +async def test_list_roots_without_callback_is_error() -> None: + """A roots/list request to a client with no roots callback fails with an error the handler can observe. + + The client's default callback answers with INVALID_REQUEST rather than leaving the server hanging. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # list_roots cannot succeed without a client callback + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) + + +@requirement("roots:list-changed") +async def test_roots_list_changed_reaches_server_handler() -> None: + """A roots/list_changed notification from the client is delivered to the server's handler. + + Unlike a request, a notification has no response to await: the handler sets an event and the + test waits on it, which is the only synchronisation point proving delivery. 
+ """ + delivered = anyio.Event() + received: list[types.NotificationParams | None] = [] + + async def roots_list_changed(ctx: ServerRequestContext, params: types.NotificationParams | None) -> None: + received.append(params) + delivered.set() + + server = Server("rooted", on_roots_list_changed=roots_list_changed) + + async with Client(server) as client: + await client.send_roots_list_changed() + with anyio.fail_after(5): + await delivered.wait() + + assert received == snapshot([None]) diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py new file mode 100644 index 0000000000..d109a32764 --- /dev/null +++ b/tests/interaction/lowlevel/test_sampling.py @@ -0,0 +1,291 @@ +"""Sampling interactions against the low-level Server, driven through the public Client API. + +Each test nests a sampling/createMessage request inside a tool call: the tool handler calls +ctx.session.create_message(), the client's sampling callback answers it, and the handler +round-trips what it received back to the test through its tool result. +""" + +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + ErrorData, + ImageContent, + ModelHint, + ModelPreferences, + SamplingMessage, + TextContent, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("sampling:create-message:round-trip") +async def test_create_message_round_trip() -> None: + """A handler's sampling request is answered by the client callback, and the callback's result + (role, content, model, stop reason) is returned to the handler. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=f"{result.model}/{result.stop_reason}: {result.content.text}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult( + role="assistant", + content=TextContent(text="Hello to you too."), + model="mock-llm-1", + stop_reason="endTurn", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:params") +async def test_create_message_params_reach_callback() -> None: + """Every sampling parameter the handler supplies arrives at the client callback unchanged.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return 
types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + max_tokens=50, + system_prompt="You are terse.", + temperature=0.7, + stop_sequences=["\n\n", "END"], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], + model_preferences=ModelPreferences( + hints=[ModelHint(name="claude"), ModelHint(name="gpt")], + cost_priority=0.2, + speed_priority=0.3, + intelligence_priority=0.9, + ), + system_prompt="You are terse.", + temperature=0.7, + max_tokens=50, + stop_sequences=["\n\n", "END"], + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_request_with_image_content_reaches_callback() -> None: + """A sampling request message carrying image content arrives at the client callback 
intact. + + This is the server-to-client direction: the server includes an image in the conversation it + asks the client to sample from. + """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="describe_image", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "describe_image" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + image = params.messages[0].content + assert isinstance(image, ImageContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"described {image.mime_type} ({image.data})"), + model="mock-vision-1", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("describe_image", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:image-content") +async def test_create_message_result_with_image_content_returns_to_handler() -> None: + """A sampling result whose content is an image is returned to the 
requesting handler intact. + + This is the client-to-server direction: the model's response is an image rather than text. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="draw", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "draw" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Draw a cat."))], + max_tokens=100, + ) + image = result.content + assert isinstance(image, ImageContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {image.mime_type} {image.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=ImageContent(data="Y2F0", mime_type="image/png"), + model="mock-vision-1", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("draw", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) + + +@requirement("sampling:create-message:client-error") +async def test_create_message_callback_error() -> None: + """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. + + The error here is the spec's own example for a user rejecting a sampling request (code -1); + the callback's code and message reach the handler verbatim, whatever they are. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: + return ErrorData(code=-1, message="User rejected sampling request") + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) + + +@requirement("sampling:create-message:not-supported") +async def test_create_message_without_callback_is_error() -> None: + """A sampling request to a client with no sampling callback fails with the SDK's default error.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=100, + ) + except MCPError as 
exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # create_message cannot succeed without a client callback + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) From d4a35585b9971a505ed91f9dc07039ecb7b6b7a7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 17:44:43 +0000 Subject: [PATCH 04/34] test: add URL elicitation, subscriptions, pagination, timeouts, and meta interaction tests Covers URL-mode elicitation (including the elicitation/complete lifecycle and the -32042 rejection flow), resource subscriptions and update notifications, cursor pagination across all four list methods, request and session read timeouts, _meta round trips, and the MCPServer Context convenience methods. Removes the 'pragma: no cover' from the resource-updated send path now that it is covered. 
--- src/mcp/server/session.py | 2 +- tests/interaction/_helpers.py | 17 ++ tests/interaction/_requirements.py | 142 ++++++++++++ .../interaction/lowlevel/test_elicitation.py | 216 +++++++++++++++++- .../interaction/lowlevel/test_list_changed.py | 8 +- tests/interaction/lowlevel/test_meta.py | 63 +++++ tests/interaction/lowlevel/test_pagination.py | 173 ++++++++++++++ tests/interaction/lowlevel/test_resources.py | 113 +++++++++ tests/interaction/lowlevel/test_timeouts.py | 112 +++++++++ tests/interaction/mcpserver/test_context.py | 165 +++++++++++++ 10 files changed, 1000 insertions(+), 11 deletions(-) create mode 100644 tests/interaction/_helpers.py create mode 100644 tests/interaction/lowlevel/test_meta.py create mode 100644 tests/interaction/lowlevel/test_pagination.py create mode 100644 tests/interaction/lowlevel/test_timeouts.py create mode 100644 tests/interaction/mcpserver/test_context.py diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index b577f278a7..fc2f97a9cb 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -223,7 +223,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py new file mode 100644 index 0000000000..3fe0f35324 --- /dev/null +++ b/tests/interaction/_helpers.py @@ -0,0 +1,17 @@ +"""Shared type aliases for the interaction suite. + +Keep this module small: it exists only for types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback. Server fixtures and assertion +helpers belong in the test that uses them. 
+""" + +from mcp.shared.session import RequestResponder +from mcp.types import ClientResult, ServerNotification, ServerRequest + +# TODO: this union is the parameter type of every client message handler (MessageHandlerFnT), +# but the SDK does not export a name for it -- writing a correctly-typed handler requires +# importing RequestResponder from mcp.shared.session and assembling the union by hand. It +# should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is +# for the request callbacks), at which point this module can be deleted. +IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception +"""Everything a client message handler can receive.""" diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 1091ade9a9..c815a902b9 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -151,6 +151,71 @@ class Requirement: behavior="A progress notification sent by the client is delivered to the server's progress handler.", ), # ═══════════════════════════════════════════════════════════════════════════ + # Timeouts + # ═══════════════════════════════════════════════════════════════════════════ + "timeouts:per-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior=( + "A request that exceeds its read timeout fails with a request-timeout error instead of " + "waiting forever for the response." + ), + divergence=Divergence( + note=( + "The spec says the requester SHOULD issue a cancellation notification for the timed-out " + "request; the client only raises locally and sends nothing, so the server keeps running " + "the handler." 
+ ), + ), + ), + "timeouts:session-survives": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="The session continues to serve new requests after an earlier request timed out.", + ), + "timeouts:session-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A session-level read timeout applies to every request that does not override it.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:cursor-round-trip": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "The nextCursor returned by a list handler reaches the client, and the cursor the client " + "sends back on the next call reaches the handler as an opaque string." + ), + ), + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." 
+ ), + ), + "pagination:resources": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", + ), + "pagination:resource-templates": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", + ), + "pagination:prompts": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Request metadata + # ═══════════════════════════════════════════════════════════════════════════ + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), + # ═══════════════════════════════════════════════════════════════════════════ # Ping # ═══════════════════════════════════════════════════════════════════════════ "ping:client-to-server": Requirement( @@ -283,6 +348,29 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/resources#error-handling", behavior="resources/read for an unknown URI returns a JSON-RPC error; the spec reserves -32002 for it.", ), + "resources:templates:list": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior=( + "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." 
+ ), + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", + ), + "resources:unsubscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." + ), + ), + "resources:updated-notification": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior=( + "A resources/updated notification sent by the server reaches the client carrying the URI of " + "the changed resource." + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Notifications: list_changed (server → client) # ═══════════════════════════════════════════════════════════════════════════ @@ -367,6 +455,36 @@ class Requirement: source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", ), + "elicitation:url:accept": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." 
+ ), + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notification", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." + ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), "elicitation:form:not-supported": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", behavior=( @@ -457,6 +575,30 @@ class Requirement: ), ), ), + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." + ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." 
+ ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), "mcpserver:tools:handler-exception": Requirement( source="sdk", behavior=( diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index 6017580d86..f2f7b54d01 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -1,13 +1,23 @@ -"""Form-mode elicitation against the low-level Server, driven through the public Client API.""" +"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API.""" import pytest from inline_snapshot import snapshot -from mcp import MCPError, types +from mcp import MCPError, UrlElicitationRequiredError, types from mcp.client import ClientRequestContext from mcp.client.client import Client from mcp.server import Server, ServerRequestContext -from mcp.types import CallToolResult, ElicitRequestFormParams, ElicitResult, TextContent +from mcp.types import ( + CallToolResult, + ElicitCompleteNotification, + ElicitCompleteNotificationParams, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ErrorData, + TextContent, +) +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -159,3 +169,203 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("ask", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) + + +@requirement("elicitation:url:accept") +async def test_elicit_url_delivers_url_and_returns_accept_without_content() -> None: + """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it + returns the action with no content. 
+ + Accept means the user agreed to visit the URL, not that the out-of-band interaction finished, + so there is never form content to return. + """ + received: list[types.ElicitRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept") + + async with Client(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert received == snapshot( + [ + ElicitRequestURLParams( + _meta={}, + message="Authorize access to your calendar.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="accept content=None")])) + + +@requirement("elicitation:url:decline") +async def test_elicit_url_decline_returns_no_content() -> None: + """A declined URL elicitation returns the decline action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) 
+ + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="decline") + + async with Client(server, elicitation_callback=answer_url) as client: + result = await client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) + + +@requirement("elicitation:url:cancel") +async def test_elicit_url_cancel_returns_no_content() -> None: + """A cancelled URL elicitation returns the cancel action to the handler with no content.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="authorize", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "authorize" + answer = await ctx.session.elicit_url( + "Authorize access to your calendar.", "https://example.com/oauth/authorize", "auth-001" + ) + return CallToolResult(content=[TextContent(text=f"{answer.action} content={answer.content}")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="cancel") + + async with Client(server, elicitation_callback=answer_url) as client: + result = await 
client.call_tool("authorize", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) + + +@requirement("elicitation:complete-notification") +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client() -> None: + """After a URL elicitation finishes, the server announces it with a notification carrying the same id. + + The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user + agrees to visit the URL, the out-of-band interaction finishes, and the server emits + elicitation/complete so the client can correlate the completion with the elicitation it + accepted earlier. Both messages arrive before the tool call returns, so a plain collected + list needs no synchronisation. + """ + elicitation_id = "auth-001" + elicited_ids: list[str] = [] + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="link_account", description="Link an account.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "link_account" + answer = await ctx.session.elicit_url( + "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id + ) + assert answer.action == "accept" + await ctx.session.send_elicit_complete(elicitation_id) + return CallToolResult(content=[TextContent(text="linked")]) + + server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestURLParams) + elicited_ids.append(params.elicitation_id) + return 
ElicitResult(action="accept") + + async with Client(server, message_handler=collect, elicitation_callback=answer_url) as client: + await client.call_tool("link_account", {}) + + # The completion notification refers to the same elicitation the client accepted. + assert elicited_ids == [elicitation_id] + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="auth-001"))] + ) + + +@requirement("elicitation:url:required-error") +async def test_url_elicitation_required_error_carries_pending_elicitations() -> None: + """A request that cannot proceed until a URL interaction completes is rejected with error -32042. + + This is the non-interactive alternative to elicit_url: instead of asking and waiting, the + handler rejects the whole request and lists the required URL elicitations in the error data. + The client is expected to present those URLs, wait for the matching elicitation/complete + notifications, and retry the original request. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + server = Server("authorizer", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index de37d0e4eb..9bbdf7ee75 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -12,24 +12,18 @@ from mcp import types from mcp.client.client import Client from mcp.server import Server, ServerRequestContext -from mcp.shared.session import RequestResponder from mcp.types import ( CallToolResult, - ClientResult, PromptListChangedNotification, ResourceListChangedNotification, - ServerNotification, - ServerRequest, TextContent, ToolListChangedNotification, ) +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio -IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception -"""Everything a client message handler can receive.""" - @requirement("notifications:tools:list-changed") async def test_tool_list_changed_notification() -> None: diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py new 
file mode 100644 index 0000000000..a63acbfa5c --- /dev/null +++ b/tests/interaction/lowlevel/test_meta.py @@ -0,0 +1,63 @@ +"""Request and result _meta round trips against the low-level Server, through the public Client API. + +Meta is opaque pass-through data, so these tests assert identity against the value that was sent +rather than snapshotting a literal: the expected value and the sent value are the same variable, +which also proves the SDK injected nothing alongside it. +""" + +import pytest + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("meta:request-to-handler") +async def test_request_meta_reaches_handler() -> None: + """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" + request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} + observed_metas: list[dict[str, object]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="traced", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "traced" + assert ctx.meta is not None + observed_metas.append(dict(ctx.meta)) + return CallToolResult(content=[TextContent(text="traced")]) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + await client.call_tool("traced", {}, meta=request_meta) + + assert observed_metas == [dict(request_meta)] + + +@requirement("meta:result-to-client") +async def test_result_meta_reaches_client() -> None: + """The _meta object a handler attaches to its result is delivered to 
the client unchanged.""" + result_meta = {"example.com/cost": 3} + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="metered", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "metered" + return CallToolResult(content=[TextContent(text="done")], _meta=result_meta) + + server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("metered", {}) + + assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py new file mode 100644 index 0000000000..0c585d7896 --- /dev/null +++ b/tests/interaction/lowlevel/test_pagination.py @@ -0,0 +1,173 @@ +"""Cursor pagination of the list operations against the low-level Server. + +The cursor is an opaque string chosen by the server: the suite only asserts that whatever the +handler returns as next_cursor comes back verbatim on the client's next call, not any particular +pagination scheme. 
+""" + +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import ( + ListPromptsResult, + ListResourcesResult, + ListResourceTemplatesResult, + ListToolsResult, + Prompt, + Resource, + ResourceTemplate, + Tool, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("pagination:cursor-round-trip") +async def test_next_cursor_round_trips_through_the_client() -> None: + """The next_cursor a list handler returns reaches the client, and the cursor the client sends + back on the following call reaches the handler verbatim. + """ + seen_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None # the client always sends params, even without a cursor + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListToolsResult( + tools=[Tool(name="alpha", input_schema={"type": "object"})], + next_cursor="page-2", + ) + return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) + + server = Server("paginated", on_list_tools=list_tools) + + async with Client(server) as client: + first_page = await client.list_tools() + second_page = await client.list_tools(cursor="page-2") + + assert first_page == snapshot( + ListToolsResult(tools=[Tool(name="alpha", input_schema={"type": "object"})], next_cursor="page-2") + ) + assert second_page == snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) + assert seen_cursors == snapshot([None, "page-2"]) + + +@requirement("pagination:exhaustion") +async def test_paginating_until_next_cursor_is_absent_yields_every_page() -> None: + """Following next_cursor until it is absent visits every page exactly once, in order.""" + pages: dict[str | None, tuple[str, str | None]] = { + 
None: ("alpha", "page-2"), + "page-2": ("beta", "page-3"), + "page-3": ("gamma", None), + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + tool_name, next_cursor = pages[params.cursor] + return ListToolsResult(tools=[Tool(name=tool_name, input_schema={"type": "object"})], next_cursor=next_cursor) + + server = Server("paginated", on_list_tools=list_tools) + + collected: list[str] = [] + cursor: str | None = None + requests_made = 0 + async with Client(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + requests_made += 1 + assert requests_made <= len(pages), "the server kept returning next_cursor past the last page" + collected.extend(tool.name for tool in result.tools) + if result.next_cursor is None: + break + cursor = result.next_cursor + + assert collected == snapshot(["alpha", "beta", "gamma"]) + assert requests_made == len(pages) + + +@requirement("pagination:resources") +async def test_resources_list_supports_cursor_pagination() -> None: + """resources/list round-trips the cursor like every other list operation.""" + seen_cursors: list[str | None] = [] + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor="page-2") + return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) + + server = Server("paginated", on_list_resources=list_resources) + + async with Client(server) as client: + first_page = await client.list_resources() + second_page = await client.list_resources(cursor="page-2") + + assert seen_cursors == snapshot([None, "page-2"]) + assert [resource.name for resource in first_page.resources] == ["first"] + assert first_page.next_cursor == 
"page-2" + assert [resource.name for resource in second_page.resources] == ["second"] + assert second_page.next_cursor is None + + +@requirement("pagination:resource-templates") +async def test_resource_templates_list_supports_cursor_pagination() -> None: + """resources/templates/list round-trips the cursor like every other list operation.""" + seen_cursors: list[str | None] = [] + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], + next_cursor="page-2", + ) + return ListResourceTemplatesResult( + resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] + ) + + server = Server("paginated", on_list_resource_templates=list_resource_templates) + + async with Client(server) as client: + first_page = await client.list_resource_templates() + second_page = await client.list_resource_templates(cursor="page-2") + + assert seen_cursors == snapshot([None, "page-2"]) + assert [template.name for template in first_page.resource_templates] == ["first"] + assert first_page.next_cursor == "page-2" + assert [template.name for template in second_page.resource_templates] == ["second"] + assert second_page.next_cursor is None + + +@requirement("pagination:prompts") +async def test_prompts_list_supports_cursor_pagination() -> None: + """prompts/list round-trips the cursor like every other list operation.""" + seen_cursors: list[str | None] = [] + + async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: + assert params is not None + seen_cursors.append(params.cursor) + if params.cursor is None: + return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor="page-2") + return 
ListPromptsResult(prompts=[Prompt(name="second")]) + + server = Server("paginated", on_list_prompts=list_prompts) + + async with Client(server) as client: + first_page = await client.list_prompts() + second_page = await client.list_prompts(cursor="page-2") + + assert seen_cursors == snapshot([None, "page-2"]) + assert [prompt.name for prompt in first_page.prompts] == ["first"] + assert first_page.next_cursor == "page-2" + assert [prompt.name for prompt in second_page.prompts] == ["second"] + assert second_page.next_cursor is None diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 1d66e6722a..96b42d25a2 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -11,12 +11,20 @@ from mcp.types import ( Annotations, BlobResourceContents, + CallToolResult, + EmptyResult, ErrorData, ListResourcesResult, + ListResourceTemplatesResult, ReadResourceResult, Resource, + ResourceTemplate, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, + TextContent, TextResourceContents, ) +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -133,3 +141,108 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq await client.read_resource("file:///missing.txt") assert exc_info.value.error == snapshot(ErrorData(code=-32002, message="Resource not found: file:///missing.txt")) + + +@requirement("resources:templates:list") +async def test_list_resource_templates_returns_registered_templates() -> None: + """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" + + async def list_resource_templates( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourceTemplatesResult: + return ListResourceTemplatesResult( + resource_templates=[ + 
ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + ), + ] + ) + + server = Server("library", on_list_resource_templates=list_resource_templates) + + async with Client(server) as client: + result = await client.list_resource_templates() + + assert result == snapshot( + ListResourceTemplatesResult( + resource_templates=[ + ResourceTemplate(uri_template="users://{user_id}", name="user"), + ResourceTemplate( + uri_template="logs://{service}/{date}", + name="service_logs", + title="Service logs", + description="One day of logs for one service.", + mime_type="text/plain", + ), + ] + ) + ) + + +@requirement("resources:subscribe") +async def test_subscribe_resource_delivers_uri_to_handler() -> None: + """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_subscribe_resource=subscribe_resource) + + async with Client(server) as client: + result = await client.subscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:unsubscribe") +async def test_unsubscribe_resource_delivers_uri_to_handler() -> None: + """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" + + async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: + assert params.uri == "file:///watched.txt" + return EmptyResult() + + server = Server("library", on_unsubscribe_resource=unsubscribe_resource) + + async with Client(server) as client: + result = await 
client.unsubscribe_resource("file:///watched.txt") + + assert result == snapshot(EmptyResult()) + + +@requirement("resources:updated-notification") +async def test_resource_updated_notification_reaches_client() -> None: + """A resources/updated notification sent during a tool call reaches the client with the resource URI. + + The collector records every message the handler receives, so the assertion also proves nothing + else was delivered. + """ + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="touch", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "touch" + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="touched")]) + + server = Server("library", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server, message_handler=collect) as client: + await client.call_tool("touch", {}) + + assert received == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py new file mode 100644 index 0000000000..ebd8cbde19 --- /dev/null +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -0,0 +1,112 @@ +"""Request timeouts against the low-level Server, driven through the public Client API. + +The handler blocks on an event that is never set, so the awaited response can never arrive and +any positive timeout fires deterministically on the next event-loop pass. The timeout is therefore +set to an effectively-zero duration: the tests add no wall-clock time to the suite. 
(Zero itself +cannot be used: a falsy read_timeout_seconds is silently treated as "no timeout".) +""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, types +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("timeouts:per-request") +async def test_request_timeout_fails_the_pending_call() -> None: + """A request whose response does not arrive within its read timeout fails with a timeout error. + + No cancellation is sent to the server (see the divergence note on the requirement): the handler + starts and is still running after the caller has already given up. The test waits for the + handler to have started only after the timeout has fired, so the timeout itself races nothing. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + # The request was already on the wire: the handler still runs even though the caller gave up. + with anyio.fail_after(5): + await handler_started.wait() + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. 
Waited 1e-06 seconds.", + ) + ) + + +@requirement("timeouts:session-survives") +async def test_session_serves_requests_after_timeout() -> None: + """A timed-out request does not poison the session: the next request succeeds.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[ + types.Tool(name="block", input_schema={"type": "object"}), + types.Tool(name="echo", input_schema={"type": "object"}), + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + if params.name == "echo": + return CallToolResult(content=[TextContent(text="still alive")]) + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with pytest.raises(MCPError): + await client.call_tool("block", {}, read_timeout_seconds=0.000001) + + result = await client.call_tool("echo", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) + + +@requirement("timeouts:session-default") +async def test_session_level_timeout_applies_to_every_request() -> None: + """A read timeout configured on the client applies to requests that do not set their own. + + The session default also governs the initialize handshake, so this is the one test in the + suite that needs a real (50ms) timeout: it must be long enough for the in-process handshake + to complete and is then waited out in full by the blocked tool call. 
+ """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + await anyio.Event().wait() # blocks until the session is torn down + raise NotImplementedError # unreachable + + server = Server("blocker", on_call_tool=call_tool) + + async with Client(server, read_timeout_seconds=0.05) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("block", {}) + + assert exc_info.value.error == snapshot( + ErrorData( + code=REQUEST_TIMEOUT, + message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + ) + ) diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py new file mode 100644 index 0000000000..e1c678fe8c --- /dev/null +++ b/tests/interaction/mcpserver/test_context.py @@ -0,0 +1,165 @@ +"""The Context convenience methods MCPServer injects into tool functions, observed from the client.""" + +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp.client import ClientRequestContext +from mcp.client.client import Client +from mcp.server.elicitation import AcceptedElicitation +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + ElicitRequestFormParams, + ElicitRequestParams, + ElicitResult, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:context:logging") +async def test_context_logging_helpers_send_log_notifications() -> None: + """Each Context logging helper sends a log message notification at the matching severity. + + All four notifications reach the client's logging callback before the tool call returns; none + of them carry a logger name unless one is passed explicitly. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("chatty") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + await ctx.debug("d") + await ctx.info("i") + await ctx.warning("w") + await ctx.error("e") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with Client(mcp, logging_callback=collect) as client: + result = await client.call_tool("narrate", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="d"), + LoggingMessageNotificationParams(level="info", data="i"), + LoggingMessageNotificationParams(level="warning", data="w"), + LoggingMessageNotificationParams(level="error", data="e"), + ] + ) + + +@requirement("mcpserver:context:progress") +async def test_context_report_progress_sends_progress_notifications() -> None: + """Context.report_progress sends progress notifications correlated to the calling request. + + The caller's progress callback receives each report, in order, before the tool call returns. 
+ """ + received: list[tuple[float, float | None, str | None]] = [] + mcp = MCPServer("worker") + + @mcp.tool() + async def crunch(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.report_progress(2, 3, "halfway there") + return "crunched" + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with Client(mcp) as client: + result = await client.call_tool("crunch", {}, progress_callback=on_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="crunched")], structured_content={"result": "crunched"}) + ) + assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) + + +@requirement("mcpserver:context:elicit") +async def test_context_elicit_returns_typed_result() -> None: + """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. + + The client sees the JSON schema generated from the model; the accepted content is validated + back into the model and handed to the tool as result.data. 
+ """ + received: list[ElicitRequestParams] = [] + mcp = MCPServer("travel") + + class TravelPreferences(BaseModel): + destination: str + window_seat: bool + + @mcp.tool() + async def book_flight(ctx: Context) -> str: + answer = await ctx.elicit("Where to?", TravelPreferences) + assert isinstance(answer, AcceptedElicitation) + return f"{answer.action}: {answer.data.destination} window={answer.data.window_seat}" + + async def answer_form(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) + + async with Client(mcp, elicitation_callback=answer_form) as client: + result = await client.call_tool("book_flight", {}) + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta={}, + message="Where to?", + requested_schema={ + "properties": { + "destination": {"title": "Destination", "type": "string"}, + "window_seat": {"title": "Window Seat", "type": "boolean"}, + }, + "required": ["destination", "window_seat"], + "title": "TravelPreferences", + "type": "object", + }, + ) + ] + ) + assert result == snapshot( + CallToolResult( + content=[TextContent(text="accept: Lisbon window=True")], + structured_content={"result": "accept: Lisbon window=True"}, + ) + ) + + +@requirement("mcpserver:context:read-resource") +async def test_context_read_resource_reads_registered_resource() -> None: + """Context.read_resource lets a tool read a resource registered on the same server. + + The tool reports the MIME type and content it read, proving the resource function ran and its + return value came back through the context. 
+ """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def app_config() -> str: + """The application configuration.""" + return "theme = dark" + + @mcp.tool() + async def show_config(ctx: Context) -> str: + contents = list(await ctx.read_resource("config://app")) + return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) + + async with Client(mcp) as client: + result = await client.call_tool("show_config", {}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="text/plain: 'theme = dark'")], + structured_content={"result": "text/plain: 'theme = dark'"}, + ) + ) From d6c9b63c8c44bb815cca39b816decbe1137e31f4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 18:18:07 +0000 Subject: [PATCH 05/34] test: add lifecycle edge cases, concurrency, and behaviour-gap interaction tests Adds ClientSession-level tests for pre-initialization request rejection and protocol version negotiation, a proof that concurrent tool calls are dispatched simultaneously and answered independently, and tests pinning three behaviour gaps: tool-set mutations send no list_changed notification, logging/setLevel is not supported by MCPServer and no level filtering exists, and tool-enabled sampling is rejected because the high-level client cannot declare the sampling.tools capability. 
--- tests/interaction/_requirements.py | 65 +++++++++++++++ tests/interaction/lowlevel/test_initialize.py | 83 ++++++++++++++++++- tests/interaction/lowlevel/test_sampling.py | 41 +++++++++ tests/interaction/lowlevel/test_tools.py | 53 ++++++++++++ tests/interaction/mcpserver/test_context.py | 39 +++++++++ tests/interaction/mcpserver/test_tools.py | 54 +++++++++++- 6 files changed, 330 insertions(+), 5 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index c815a902b9..139f1faa2c 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -95,6 +95,17 @@ class Requirement: "(sampling, elicitation, roots)." ), ), + "lifecycle:initialize:protocol-version": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "The server echoes a requested protocol version it supports, and answers an unsupported " + "requested version with its own latest supported version rather than an error." + ), + ), + "lifecycle:requests-before-initialized": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior="A request sent before the initialization handshake completes is rejected with an error.", + ), # ═══════════════════════════════════════════════════════════════════════════ # Cancellation # ═══════════════════════════════════════════════════════════════════════════ @@ -279,6 +290,13 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", ), + "tools:call:concurrent": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " + "receives the response to its own request." 
+ ), + ), "tools:call:invalid-arguments": Requirement( source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior=( @@ -326,6 +344,21 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", ), + "logging:set-level:filtering": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior=( + "MCPServer registers no logging/setLevel handler (the request is rejected with method-not-found) " + "and log messages are delivered at every severity regardless of any requested level." + ), + divergence=Divergence( + note=( + "The spec says servers SHOULD only send log messages at or above the level the client " + "configured via logging/setLevel. Neither MCPServer (which rejects the request outright) " + "nor the low-level Server (which leaves the handler entirely to the author) implements " + "any filtering." + ), + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Resources # ═══════════════════════════════════════════════════════════════════════════ @@ -426,6 +459,25 @@ class Requirement: source=f"{SPEC_BASE_URL}/client/sampling#message-content", behavior="Sampling messages can carry image content: base64 data with a mimeType.", ), + "sampling:create-message:tools:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire, with an Invalid params error." + ), + ), + "sampling:create-message:tools:round-trip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#sampling-with-tools", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." 
+ ), + deferred=( + "Not expressible through the public API: Client does not expose ClientSession's " + "sampling_capabilities parameter, so a client can never declare sampling.tools and the " + "server-side validator rejects every tool-enabled request before it is sent." + ), + ), "sampling:create-message:client-error": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#error-handling", behavior="A sampling callback that returns an error is surfaced to the requesting handler as an MCPError.", @@ -599,6 +651,19 @@ class Requirement: source="sdk", behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", ), + "mcpserver:tools:list-changed-on-mutation": Requirement( + source="sdk", + behavior=( + "Adding or removing a tool on a running server changes what tools/list returns but sends no " + "notification to connected clients." + ), + divergence=Divergence( + note=( + "The spec provides notifications/tools/list_changed for exactly this case; MCPServer never " + "sends it, so a connected client cannot learn that the tool set changed without polling." + ), + ), + ), "mcpserver:tools:handler-exception": Requirement( source="sdk", behavior=( diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 6ade1de9da..be0b0ac2ef 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -1,17 +1,33 @@ -"""Initialization handshake against the low-level Server, driven through the public Client API.""" +"""Initialization handshake against the low-level Server, driven through the public Client API. +The last two tests drive a bare ClientSession over an InMemoryTransport instead: Client always +performs the full handshake with the latest protocol version, so skipping initialization or +requesting a different version can only be expressed one level down. 
+""" + +import anyio import pytest from inline_snapshot import snapshot -from mcp import types -from mcp.client import ClientRequestContext +from mcp import MCPError, types +from mcp.client import ClientRequestContext, ClientSession +from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( + INVALID_PARAMS, CallToolResult, + ClientCapabilities, CompletionsCapability, + EmptyResult, + ErrorData, Icon, Implementation, + InitializeRequest, + InitializeRequestParams, + InitializeResult, + ListToolsRequest, + ListToolsResult, LoggingCapability, PromptsCapability, ResourcesCapability, @@ -198,3 +214,64 @@ async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: async with Client(server, list_roots_callback=list_roots) as client: result = await client.call_tool("abilities", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="roots")])) + + +@requirement("lifecycle:requests-before-initialized") +async def test_request_before_initialization_is_rejected() -> None: + """A feature request sent before the handshake completes is rejected; ping is exempt. + + Client always initializes on entry, so this drives a bare ClientSession that never sends + initialize. The server's stated reason for the rejection never reaches the client: the error + is reported as a generic invalid-params failure. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + """Registered so the request is routed to a real handler; never reached.""" + raise NotImplementedError + + server = Server("strict", on_list_tools=list_tools) + + async with InMemoryTransport(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) + + # Ping is explicitly permitted before initialization completes. + pong = await session.send_ping() + + assert exc_info.value.error == snapshot( + ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + assert pong == snapshot(EmptyResult()) + + +@requirement("lifecycle:initialize:protocol-version") +async def test_initialize_negotiates_protocol_version() -> None: + """The server echoes a supported requested version and answers an unsupported one with its latest. + + Client always requests the latest version, so each half hand-builds an InitializeRequest on a + bare ClientSession to control the requested version. 
+ """ + server = Server("negotiator") + + def initialize_request(protocol_version: str) -> InitializeRequest: + return InitializeRequest( + params=InitializeRequestParams( + protocol_version=protocol_version, + capabilities=ClientCapabilities(), + client_info=Implementation(name="time-traveller", version="0.0.1"), + ) + ) + + async with InMemoryTransport(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + assert result.protocol_version == snapshot("2025-03-26") + + async with InMemoryTransport(server) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + assert result.protocol_version == snapshot("2025-11-25") diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index d109a32764..7a0a396be9 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -289,3 +289,44 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("ask_model", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) + + +@requirement("sampling:create-message:tools:not-supported") +async def test_create_message_with_tools_is_rejected_for_unsupporting_client() -> None: + """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. + + The client supports plain sampling but cannot declare the tools sub-capability (Client does not + expose it), so the server-side validator rejects the request before anything reaches the wire. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="What is the weather?"))], + max_tokens=100, + tools=[types.Tool(name="get_weather", input_schema={"type": "object"})], + ) + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the validator rejects every tool-enabled request + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the plain sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 0dc899ef9c..071180ddfd 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -1,5 +1,6 @@ """Tool interactions against the low-level Server, driven through the public Client API.""" +import anyio import pytest from inline_snapshot import snapshot @@ -264,3 +265,55 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("sum", {}) assert result == 
snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) + + +@requirement("tools:call:concurrent") +async def test_concurrent_tool_calls_complete_independently() -> None: + """Two tool calls in flight at once run concurrently and each caller gets its own answer. + + Both handlers are held on a shared event after signalling that they have started, and the test + only releases them once both signals have arrived -- a server that processed requests + sequentially would never start the second handler and the test would time out instead. + """ + started: list[str] = [] + started_events = {"first": anyio.Event(), "second": anyio.Event()} + release = anyio.Event() + results: dict[str, CallToolResult] = {} + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + tag = params.arguments["tag"] + assert isinstance(tag, str) + started.append(tag) + started_events[tag].set() + await release.wait() + return CallToolResult(content=[TextContent(text=tag)]) + + server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("echo", {"tag": tag}) + + task_group.start_soon(call_and_record, "first") + task_group.start_soon(call_and_record, "second") + + # Both handlers are running at the same time before either is allowed to finish. 
+ await started_events["first"].wait() + await started_events["second"].wait() + release.set() + + assert sorted(started) == ["first", "second"] + assert results == snapshot( + { + "first": CallToolResult(content=[TextContent(text="first")]), + "second": CallToolResult(content=[TextContent(text="second")]), + } + ) diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index e1c678fe8c..c6218fc58e 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -4,15 +4,18 @@ from inline_snapshot import snapshot from pydantic import BaseModel +from mcp import MCPError from mcp.client import ClientRequestContext from mcp.client.client import Client from mcp.server.elicitation import AcceptedElicitation from mcp.server.mcpserver import Context, MCPServer from mcp.types import ( + METHOD_NOT_FOUND, CallToolResult, ElicitRequestFormParams, ElicitRequestParams, ElicitResult, + ErrorData, LoggingMessageNotificationParams, TextContent, ) @@ -163,3 +166,39 @@ async def show_config(ctx: Context) -> str: structured_content={"result": "text/plain: 'theme = dark'"}, ) ) + + +@requirement("logging:set-level:filtering") +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered() -> None: + """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. + + The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, + and every message a tool emits is delivered regardless of level. The spec says the server + should only send messages at or above the configured level; with no way to configure one, + everything is sent. 
+ """ + received: list[LoggingMessageNotificationParams] = [] + mcp = MCPServer("unfilterable") + + @mcp.tool() + async def chatter(ctx: Context) -> str: + await ctx.debug("noise") + await ctx.error("signal") + return "done" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + async with Client(mcp, logging_callback=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.set_logging_level("error") + + await client.call_tool("chatter", {}) + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + assert received == snapshot( + [ + LoggingMessageNotificationParams(level="debug", data="noise"), + LoggingMessageNotificationParams(level="error", data="signal"), + ] + ) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index 30def4870a..767750884e 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -5,9 +5,15 @@ from pydantic import BaseModel from mcp.client.client import Client -from mcp.server.mcpserver import MCPServer +from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError -from mcp.types import CallToolResult, TextContent +from mcp.types import ( + CallToolResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -184,3 +190,47 @@ def add(a: int, b: int) -> str: assert result.is_error is True assert isinstance(result.content[0], TextContent) assert result.content[0].text.startswith("Error executing tool add: 1 validation error") + + +@requirement("mcpserver:tools:list-changed-on-mutation") +async def test_adding_and_removing_tools_does_not_notify_connected_clients() -> None: + """Mutating the tool set on a running 
server changes tools/list but sends no notification. + + add_tool and remove_tool only update the registry: a connected client that listed the tools + before the mutation has no way to learn it should list them again. The spec provides + notifications/tools/list_changed for exactly this; MCPServer never sends it. The tool emits + one log message as a sentinel so the test proves notifications do reach the collector -- the + log message arrives, a list_changed does not. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("mutable") + + def extra() -> str: + """A tool registered at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + def doomed() -> str: + """A tool removed at runtime; never called.""" + raise NotImplementedError + + @mcp.tool() + async def grow(ctx: Context) -> str: + mcp.add_tool(extra, name="extra") + mcp.remove_tool("doomed") + await ctx.info("tool set changed") + return "mutated" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with Client(mcp, message_handler=collect) as client: + before = await client.list_tools() + await client.call_tool("grow", {}) + after = await client.list_tools() + + assert [tool.name for tool in before.tools] == ["doomed", "grow"] + assert [tool.name for tool in after.tools] == ["grow", "extra"] + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="tool set changed"))] + ) From a358aa439de6996325b5d16f3c83e658970b6fb4 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 18:27:30 +0000 Subject: [PATCH 06/34] test: add wire-level invariant tests via a recording transport A RecordingTransport wrapper tees every message crossing the client's transport boundary so the suite can assert properties that are invisible to API callers: request ids are unique and never null, notifications are never answered, and exactly one initialized 
notification is sent between the initialize response and the first feature request. --- tests/interaction/_helpers.py | 100 +++++++++++++++- tests/interaction/_requirements.py | 21 ++++ tests/interaction/lowlevel/test_timeouts.py | 13 ++- tests/interaction/lowlevel/test_wire.py | 122 ++++++++++++++++++++ tests/interaction/mcpserver/test_tools.py | 6 +- 5 files changed, 248 insertions(+), 14 deletions(-) create mode 100644 tests/interaction/lowlevel/test_wire.py diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py index 3fe0f35324..25833b0ca5 100644 --- a/tests/interaction/_helpers.py +++ b/tests/interaction/_helpers.py @@ -1,10 +1,18 @@ -"""Shared type aliases for the interaction suite. +"""Shared helpers for the interaction suite. -Keep this module small: it exists only for types that every test would otherwise have to -assemble from the SDK's internals to annotate a client callback. Server fixtures and assertion -helpers belong in the test that uses them. +Keep this module small: it exists only for (a) types that every test would otherwise have to +assemble from the SDK's internals to annotate a client callback, and (b) the recording transport +used by the wire-level tests. Server fixtures and assertion helpers belong in the test that uses +them. """ +from types import TracebackType + +import anyio +from typing_extensions import Self + +from mcp.client._transport import ReadStream, Transport, TransportStreams, WriteStream +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.types import ClientResult, ServerNotification, ServerRequest @@ -12,6 +20,88 @@ # but the SDK does not export a name for it -- writing a correctly-typed handler requires # importing RequestResponder from mcp.shared.session and assembling the union by hand. It # should be a named, exported alias next to MessageHandlerFnT (like ClientRequestContext is -# for the request callbacks), at which point this module can be deleted. 
+# for the request callbacks), at which point this alias can be deleted. IncomingMessage = RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception """Everything a client message handler can receive.""" + + +class _RecordingReadStream: + """Delegates to a read stream, appending every received message to a log.""" + + def __init__(self, inner: ReadStream[SessionMessage | Exception], log: list[SessionMessage | Exception]) -> None: + self._inner = inner + self._log = log + + async def receive(self) -> SessionMessage | Exception: + item = await self._inner.receive() + self._log.append(item) + return item + + async def aclose(self) -> None: + await self._inner.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> SessionMessage | Exception: + try: + return await self.receive() + except anyio.EndOfStream: + raise StopAsyncIteration from None + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class _RecordingWriteStream: + """Delegates to a write stream, appending every sent message to a log.""" + + def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage]) -> None: + self._inner = inner + self._log = log + + async def send(self, item: SessionMessage, /) -> None: + self._log.append(item) + await self._inner.send(item) + + async def aclose(self) -> None: + await self._inner.aclose() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + await self.aclose() + return None + + +class RecordingTransport: + """Wraps a Transport and records every message crossing the client's transport boundary. 
+ + `sent` holds everything the client wrote towards the server; `received` holds everything the + server delivered to the client. The recording sits at the transport seam -- the exact payloads + a real transport would serialise -- and never touches the session, so wire-level assertions + written against it survive changes to the receive path. + """ + + def __init__(self, inner: Transport) -> None: + self.inner = inner + self.sent: list[SessionMessage] = [] + self.received: list[SessionMessage | Exception] = [] + + async def __aenter__(self) -> TransportStreams: + read_stream, write_stream = await self.inner.__aenter__() + return _RecordingReadStream(read_stream, self.received), _RecordingWriteStream(write_stream, self.sent) + + async def __aexit__( + self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None + ) -> bool | None: + return await self.inner.__aexit__(exc_type, exc_val, exc_tb) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 139f1faa2c..0cde7c9363 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -49,6 +49,20 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Protocol primitives # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null integer id; ids are never reused " + "within the session." + ), + ), + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", + behavior=( + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." 
+ ), + ), "protocol:error:internal-error": Requirement( source=f"{SPEC_BASE_URL}/basic#responses", behavior="An unhandled exception in a request handler is returned to the caller as a JSON-RPC error.", @@ -106,6 +120,13 @@ class Requirement: source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", behavior="A request sent before the initialization handshake completes is rejected with an error.", ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "The client sends exactly one initialized notification, after the initialize response and " + "before its first feature request." + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Cancellation # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index ebd8cbde19..4e7c64fba2 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -86,12 +86,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("timeouts:session-default") async def test_session_level_timeout_applies_to_every_request() -> None: - """A read timeout configured on the client applies to requests that do not set their own. - - The session default also governs the initialize handshake, so this is the one test in the - suite that needs a real (50ms) timeout: it must be long enough for the in-process handshake - to complete and is then waited out in full by the blocked tool call. 
- """ + """A read timeout configured on the client applies to requests that do not set their own.""" async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "block" @@ -100,6 +95,12 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_call_tool=call_tool) + # The one real wall-clock wait in the suite, and it cannot be made effectively zero like the + # per-request timeouts: a session-level timeout also governs the initialize handshake, so the + # value must be long enough for the in-process handshake to complete before the blocked tool + # call waits it out in full. 50ms buys a ~50x safety margin over the handshake's actual + # latency; lowering it only erodes the margin against CI scheduler jitter without saving + # anything perceptible. async with Client(server, read_timeout_seconds=0.05) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py new file mode 100644 index 0000000000..4ba86a9404 --- /dev/null +++ b/tests/interaction/lowlevel/test_wire.py @@ -0,0 +1,122 @@ +"""Wire-level invariants observed at the client's transport boundary. + +These behaviours are invisible to API callers -- they are properties of the raw JSON-RPC frames. +The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing +the transport seam into a list without touching the session, so the assertions hold for whatever +the session implementation sends rather than for what its API returns. 
+""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import types +from mcp.client._memory import InMemoryTransport +from mcp.client.client import Client +from mcp.server import Server, ServerRequestContext +from mcp.shared.message import SessionMessage +from mcp.types import CallToolResult, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, TextContent +from tests.interaction._helpers import RecordingTransport, _RecordingReadStream +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _echo_server() -> Server: + """A server with one echo tool, used by every test in this module.""" + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="echo", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + return CallToolResult(content=[TextContent(text="ok")]) + + return Server("wire", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("protocol:request-id:unique") +async def test_request_ids_are_unique_and_never_null() -> None: + """Every request the client sends carries a distinct, non-null id. + + The id sequence is pinned: sequential integers from zero, in send order. No schema-cache + refresh appears in it, because the explicit tools/list call already populated the cache. 
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + await client.call_tool("echo", {}) + await client.call_tool("echo", {}) + await client.send_ping() + + sent = [message.message for message in recording.sent] + request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + assert all(request_id is not None for request_id in request_ids) + assert len(request_ids) == len(set(request_ids)) + # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a + # schema-cache refresh here because the explicit tools/list already populated the cache. + assert request_ids == snapshot([0, 1, 2, 3, 4]) + + +@requirement("protocol:notifications:no-response") +async def test_notifications_are_never_answered() -> None: + """A notification produces no response: everything the server sends back answers a request. + + The client sends two notifications (initialized and roots/list_changed) and two requests (initialize and ping); + the messages received from the server must be exactly one response per request, each carrying + the id of the request it answers, and nothing else. 
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.send_roots_list_changed() + await client.send_ping() + + sent = [message.message for message in recording.sent] + sent_request_ids = [message.id for message in sent if isinstance(message, JSONRPCRequest)] + sent_notifications = [message for message in sent if isinstance(message, JSONRPCNotification)] + received = [message.message for message in recording.received if isinstance(message, SessionMessage)] + received_responses = [message for message in received if isinstance(message, JSONRPCResponse)] + + assert len(sent_notifications) == 2 # notifications/initialized and notifications/roots/list_changed + assert len(received_responses) == len(received) # nothing the server sent was anything but a response + assert [message.id for message in received_responses] == sent_request_ids + + +async def test_recording_read_stream_ends_iteration_when_the_sender_closes() -> None: + """The recording wrapper preserves the end-of-stream behaviour of the stream it wraps. + + This exercises the helper itself rather than an interaction-model behaviour: a transport whose + far end closes must end the client's receive loop cleanly, and the wrapper must not swallow or + mistranslate that. + """ + send_stream, receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1) + log: list[SessionMessage | Exception] = [] + async with send_stream, _RecordingReadStream(receive_stream, log) as wrapped: + await send_stream.aclose() + items = [item async for item in wrapped] + assert items == [] + assert log == [] + + +@requirement("lifecycle:initialized-notification") +async def test_exactly_one_initialized_notification_is_sent_after_the_handshake() -> None: + """The client sends initialized exactly once, between the initialize response and its first request. + + The full method sequence the client puts on the wire is pinned in send order. 
+ """ + recording = RecordingTransport(InMemoryTransport(_echo_server())) + + async with Client(recording) as client: + await client.list_tools() + + sent_methods = [ + message.message.method + for message in recording.sent + if isinstance(message.message, JSONRPCRequest | JSONRPCNotification) + ] + assert sent_methods.count("notifications/initialized") == 1 + assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index 767750884e..1724360d5e 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -173,9 +173,6 @@ def primes() -> list[int]: async def test_call_tool_invalid_arguments_become_error_result() -> None: """Arguments that fail validation against the tool's signature are reported as an is_error result describing the failure, not as a protocol error. - - The description is raw pydantic output (version-dependent and leaking the internal argument - model name), so only the stable prefix is asserted rather than the full text. """ mcp = MCPServer("adder") @@ -187,6 +184,9 @@ def add(a: int, b: int) -> str: async with Client(mcp) as client: result = await client.call_tool("add", {"b": 3}) + # The description is raw pydantic output -- it embeds a pydantic-version-specific + # errors.pydantic.dev URL and the internal `addArguments` model name -- so only the stable + # prefix is asserted; a full snapshot would break on every pydantic upgrade. 
assert result.is_error is True assert isinstance(result.content[0], TextContent) assert result.content[0].text.startswith("Error executing tool add: 1 validation error") From d739975cf3936fb92c590b5eb6a9ba4f1589ad6f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 18:37:36 +0000 Subject: [PATCH 07/34] test: add in-process streamable HTTP transport smoke tests Drives the streamable HTTP Starlette app through httpx's ASGI transport so the full HTTP framing layer (session ids, SSE and JSON response encoding, stateful and stateless modes) runs with no sockets, threads, or subprocesses. Covers the handshake, tool calls and errors, mid-call notifications, the stateless rejection of server-initiated requests, and the routing of unrelated server messages to the standalone stream. Removes the 'pragma: no cover' comments these tests now cover (the session-manager accessors, the no-session-id validation path, and the related-request-id routing branch). The session-manager accessor's unreachable error guard keeps its pragma, moved onto the raise statement itself so the now-executed condition above it is measured normally. 
--- src/mcp/server/lowlevel/server.py | 6 +- src/mcp/server/mcpserver/server.py | 2 +- src/mcp/server/streamable_http.py | 4 +- tests/interaction/_requirements.py | 61 +++++ tests/interaction/transports/__init__.py | 0 .../transports/test_streamable_http.py | 224 ++++++++++++++++++ 6 files changed, 291 insertions(+), 6 deletions(-) create mode 100644 tests/interaction/transports/__init__.py create mode 100644 tests/interaction/transports/test_streamable_http.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 419e06f770..d1a15120af 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -349,12 +349,12 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - if self._session_manager is None: # pragma: no cover - raise RuntimeError( + if self._session_manager is None: + raise RuntimeError( # pragma: no cover "Session manager can only be accessed after calling streamable_http_app(). " "The session manager is created lazily to avoid unnecessary initialization." ) - return self._session_manager # pragma: no cover + return self._session_manager async def run( self, diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index b3471163b7..ec2365810e 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -244,7 +244,7 @@ def session_manager(self) -> StreamableHTTPSessionManager: Raises: RuntimeError: If called before streamable_http_app() has been called. """ - return self._lowlevel_server.session_manager # pragma: no cover + return self._lowlevel_server.session_manager @overload def run(self, transport: Literal["stdio"] = ...) -> None: ... 
diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f14201857c..fbe3bd9676 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -818,7 +818,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # If we're not using session IDs, return True return True @@ -993,7 +993,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 0cde7c9363..fe70fd828c 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -591,6 +591,67 @@ class Requirement: behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", ), # ═══════════════════════════════════════════════════════════════════════════ + # Transports + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." 
+ ), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." + ), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream; a client that never opened one does not receive it." + ), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + deferred=( + "The in-process ASGI client buffers each response in full, which deadlocks on a " + "server-to-client request nested inside a still-open call. Covered over a real socket by " + "tests/shared/test_streamable_http.py." 
+ ), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="The interaction round trip works over a stdio subprocess.", + deferred=( + "Requires a real subprocess. Process lifecycle is covered by tests/client/test_stdio.py and " + "end-to-end stdio coverage belongs to the cross-SDK conformance suite." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ # MCPServer behavioural guarantees (not spec-mandated) # ═══════════════════════════════════════════════════════════════════════════ "mcpserver:tools:output-schema:model": Requirement( diff --git a/tests/interaction/transports/__init__.py b/tests/interaction/transports/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py new file mode 100644 index 0000000000..4e5dd306c5 --- /dev/null +++ b/tests/interaction/transports/test_streamable_http.py @@ -0,0 +1,224 @@ +"""Smoke tests for the interaction model over the streamable HTTP transport, entirely in process. + +The Starlette app a real deployment would hand to uvicorn is driven through httpx's ASGI +transport instead: the full HTTP framing layer runs (session ids, SSE and JSON response +encoding, stateful and stateless session management) with no sockets, threads, or subprocesses, +so these tests are as deterministic as the in-memory ones. + +The ASGI client buffers each response in full before the client sees any of it. Request, +response, and notification flows are unaffected -- notifications are written to the request's +SSE stream before the response and arrive in order -- but a server-initiated request nested +inside a still-open call would deadlock, so that scenario is deferred to the real-socket +transport tests (see the `transport:streamable-http:server-to-client` requirement). 
+""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +import httpx +import pytest +from inline_snapshot import snapshot +from pydantic import BaseModel + +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server.mcpserver import Context, MCPServer +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + CallToolResult, + LoggingMessageNotification, + LoggingMessageNotificationParams, + TextContent, +) +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _smoke_server() -> MCPServer: + """A server exercising one example of each message shape the smoke tests need.""" + mcp = MCPServer("smoke", instructions="Talk to the smoke server.") + + @mcp.tool() + def echo(text: str) -> str: + """Echo the text back.""" + return text + + @mcp.tool() + def fail() -> str: + """Always fails.""" + raise ValueError("deliberately broken") + + @mcp.tool() + async def narrate(ctx: Context) -> str: + """Send a log message and a progress update, then return.""" + await ctx.info("starting") + await ctx.report_progress(1, 2) + await ctx.info("finishing") + return "narrated" + + class Confirmation(BaseModel): + confirmed: bool + + @mcp.tool() + async def ask(ctx: Context) -> str: + """Attempt a server-initiated elicitation.""" + await ctx.elicit("Proceed?", Confirmation) + raise NotImplementedError # only called in stateless mode, where the elicit cannot succeed + + @mcp.tool() + async def announce(ctx: Context) -> str: + """Send one notification related to this request and one that is not.""" + await ctx.info("about to announce") + await ctx.session.send_resource_updated("file:///watched.txt") + return "announced" + + return mcp + + +@asynccontextmanager +async def _connected( + mcp: MCPServer, *, stateless_http: bool = False, json_response: bool = 
False +) -> AsyncIterator[Client]: + """Yield a Client connected to the server through the in-process streamable HTTP stack.""" + # DNS-rebinding protection validates Host/Origin headers against a real network attack that + # cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation + # branch (deliberately uncovered in src) into coverage. + app = mcp.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + async with mcp.session_manager.run(): + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://127.0.0.1:8000") as http: + transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) + async with Client(transport) as client: + yield client + + +@requirement("transport:streamable-http:stateful") +async def test_initialize_and_call_a_tool_over_streamable_http() -> None: + """The handshake and a tool round trip work through the stateful SSE framing.""" + async with _connected(_smoke_server()) as client: + assert client.initialize_result.server_info.name == "smoke" + assert client.initialize_result.instructions == "Talk to the smoke server." 
+ result = await client.call_tool("echo", {"text": "over http"}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="over http")], structured_content={"result": "over http"}) + ) + + +@requirement("transport:streamable-http:stateful") +async def test_tool_errors_round_trip_over_streamable_http() -> None: + """A tool execution error crosses the HTTP framing as an is_error result, not a transport failure.""" + async with _connected(_smoke_server()) as client: + result = await client.call_tool("fail", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="Error executing tool fail: deliberately broken")], is_error=True) + ) + + +@requirement("transport:streamable-http:json-response") +async def test_tool_call_over_streamable_http_with_json_responses() -> None: + """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" + async with _connected(_smoke_server(), json_response=True) as client: + assert client.initialize_result.server_info.name == "smoke" + result = await client.call_tool("echo", {"text": "as json"}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="as json")], structured_content={"result": "as json"}) + ) + + +@requirement("transport:streamable-http:stateless") +async def test_tool_calls_over_stateless_streamable_http() -> None: + """Consecutive requests each succeed against a stateless server with no session to share.""" + async with _connected(_smoke_server(), stateless_http=True) as client: + first = await client.call_tool("echo", {"text": "first"}) + second = await client.call_tool("echo", {"text": "second"}) + + assert first == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + assert second == snapshot( + CallToolResult(content=[TextContent(text="second")], structured_content={"result": "second"}) + ) + + +@requirement("transport:streamable-http:notifications") +async def 
test_notifications_during_a_tool_call_arrive_before_the_response() -> None: + """Log and progress notifications emitted mid-call are delivered on the call's SSE stream in order.""" + logs: list[LoggingMessageNotificationParams] = [] + progress_updates: list[tuple[float, float | None, str | None]] = [] + + async def collect_log(params: LoggingMessageNotificationParams) -> None: + logs.append(params) + + async def collect_progress(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append((progress, total, message)) + + server = _smoke_server() + app = server.streamable_http_app( + transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://127.0.0.1:8000") as http: + transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) + async with Client(transport, logging_callback=collect_log) as client: + result = await client.call_tool("narrate", {}, progress_callback=collect_progress) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="narrated")], structured_content={"result": "narrated"}) + ) + assert [params.data for params in logs] == snapshot(["starting", "finishing"]) + assert progress_updates == snapshot([(1.0, 2.0, None)]) + + +@requirement("transport:streamable-http:stateless-restrictions") +async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: + """A handler that tries to call back to the client in stateless mode fails: there is no session.""" + async with _connected(_smoke_server(), stateless_http=True) as client: + result = await client.call_tool("ask", {}) + + assert result.is_error is True + assert isinstance(result.content[0], TextContent) + # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error + # path; pin the stable prefix rather than the full 
exception prose. + assert result.content[0].text.startswith("Error executing tool ask:") + + +@requirement("transport:streamable-http:unrelated-messages") +async def test_unrelated_server_messages_are_not_delivered_without_a_listening_stream() -> None: + """A server message with no related request goes to the standalone GET stream, not the call's stream. + + The client never opens the standalone stream, so the resource-updated notification is silently + dropped. The log notification sent by the same tool IS related to the call and does arrive, + proving the collector works and making the absence of the unrelated one meaningful. This is + the transport behaviour that makes `related_request_id` matter. + """ + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + server = _smoke_server() + app = server.streamable_http_app( + transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://127.0.0.1:8000") as http: + transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) + async with Client(transport, message_handler=collect) as client: + result = await client.call_tool("announce", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) + ) + # Only the related log notification arrives; the resource-updated notification went to the + # standalone stream nobody is reading. 
+ assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] + ) From 2f0da6e223b881a69682d1b9123077f91e2d3b3c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Sat, 23 May 2026 19:34:59 +0000 Subject: [PATCH 08/34] test: document the interaction suite's conventions and manifest workflow --- tests/interaction/README.md | 168 ++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 tests/interaction/README.md diff --git a/tests/interaction/README.md b/tests/interaction/README.md new file mode 100644 index 0000000000..4f7e3dc1f3 --- /dev/null +++ b/tests/interaction/README.md @@ -0,0 +1,168 @@ +# Interaction-model test suite + +This suite enumerates the MCP interaction model as end-to-end tests: one test per piece of +functionality, asserting the full client↔server round trip through the public API. It exists to +pin the SDK's observable behaviour — every request type, every notification direction, every +error plane — so that internal rewrites of the send/receive path can be proven equivalent by +running the suite before and after. + +```bash +uv run --frozen pytest tests/interaction/ +``` + +The whole suite is in-memory and event-driven; it runs in about a second. + +## Ground rules + +- **Public API only.** Tests drive a `Client` connected to a `Server` or `MCPServer`. Nothing + reaches into session internals, so the suite keeps working when those internals change. + `ClientSession` is used directly only for behaviours `Client` cannot express (skipping + initialization, requesting a non-default protocol version). +- **Pin current behaviour.** Every test passes against the current `main`, including behaviours + that diverge from the specification. A failing or xfailed test proves nothing about whether a + rewrite preserved behaviour; a passing test that pins the wrong output exactly does. 
Known + divergences are recorded as data on the requirement (see below), not worked around in the test. +- **Spec-mandated assertions, not implementation quirks.** Error *codes* are asserted against + the constants in `mcp.types`; error *message strings* are pinned only where they are the + SDK's own deliberate output. +- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that + could hang is bounded by `anyio.fail_after(5)`. The streamable HTTP tests drive the Starlette + app in-process through `httpx.ASGITransport` — no sockets, threads, or subprocesses anywhere. + +## Layout + +```text +tests/interaction/ + _requirements.py the requirements manifest (see below) + _helpers.py shared type aliases + the wire-recording transport + test_coverage.py enforces the manifest ↔ test contract + lowlevel/ one file per feature area, against the low-level Server + mcpserver/ the same feature areas in MCPServer's natural idiom + transports/ a smoke subset over the streamable HTTP framing +``` + +The two server APIs produce genuinely different wire output for the same conceptual feature +(`MCPServer` generates schemas, converts exceptions to `isError` results, attaches structured +content), so they get parallel directories with mirrored file names rather than one parametrized +test body — each directory pins its flavour's true output exactly. + +## The requirements manifest + +`_requirements.py` maps every behaviour the suite covers to the reason it must hold: + +```python +"tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content.", +), +``` + +- **`source`** is a deep link into the MCP specification for externally mandated behaviour, + the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or + `"issue:#n"` for a regression lock. 
+- **`behavior`** describes what the suite *asserts* — which is always the SDK's current + behaviour, never an aspiration. +- **`divergence`** records the gap when current behaviour differs from what `source` mandates, + with an issue link once one exists. The test still pins current behaviour. +- **`deferred`** marks a behaviour that is deliberately not covered, with the reason. + +Tests link themselves to the manifest with a decorator: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: ... +``` + +`test_coverage.py` enforces the contract in both directions: every non-deferred requirement must +be exercised by at least one test, every deferred requirement by none, and an unknown ID fails at +import time. A behaviour without a manifest entry cannot be silently half-tested, and a manifest +entry without a test cannot be silently aspirational. + +### The divergence lifecycle + +1. A test reveals that the SDK does not do what the spec says. The test pins what the SDK + *actually does* and a `Divergence(note=..., issue=...)` goes on the requirement. +2. When the behaviour is eventually fixed, the pinned test fails. Whoever makes the change finds + the divergence note explaining that the old behaviour was a known gap, re-pins the test to the + spec-correct output, and deletes the `Divergence`. +3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. + +This is also the triage key for any rewrite: a test that fails on the new code path either has a +divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or +it does not (the rewrite broke something that was correct — fix the rewrite). + +### When a new spec revision is released + +1. Update `SPEC_REVISION` and walk the new revision's changelog. +2. 
For each changed interaction, find its requirements (the IDs use the wire method strings the + changelog speaks in), re-audit the tests against the new text, and update `source` links and + assertions where behaviour legitimately changed. +3. New interactions get new requirements and new tests; removed interactions get their + requirements deleted along with their tests. +4. A behaviour that is correct under both revisions needs no change beyond the `source` link. + +## Writing a test + +The shortest complete example of the conventions: + +```python +@requirement("tools:call:content:text") +async def test_call_tool_returns_text_content() -> None: + """Arguments reach the tool handler; its content comes back as the call result.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "add" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["a"] + params.arguments["b"]))]) + + server = Server("adder", on_call_tool=call_tool) + + async with Client(server) as client: + result = await client.call_tool("add", {"a": 2, "b": 3}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) +``` + +- **The server is defined inside the test** (or in a small fixture at the top of the file when + several tests genuinely share it). The whole observable behaviour fits on one screen. +- **Test names are behaviour sentences** — they state the observable outcome, not the feature + being poked. Docstrings add the one or two sentences of context a reviewer needs, including + whether the assertion is spec-mandated, SDK-defined, or a known divergence. +- **Handlers assert their dispatch identity first** (`assert params.name == "add"`), proving the + request that arrived is the request the test sent. 
+- **The result proves the round trip.** Server-side observations travel back to the test through + the protocol itself (a tool returns what it saw) or through a closure-captured list; the test + asserts after the call returns. +- **Order within a test**: server handlers → server construction → client callbacks → connect → + act → assert. The test reads in the order the conversation happens. +- A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body + so it cannot silently become load-bearing. + +### Choosing an assertion + +| The property under test is… | Assert with | +|---|---| +| the result of a transformation (arguments → output, exception → error result) | `result == snapshot(...)` of the full object, so any field the implementation adds or drops fails the test | +| pass-through of an opaque value (`_meta`, cursors) | identity against the same variable that was sent — a snapshot of a pass-through value only matches the input because a human checked two literals correspond | +| an error | `pytest.raises(MCPError)` and a snapshot of `exc.value.error` when the message is the SDK's own; a plain `==` on `.code` against the `mcp.types` constant when it is not | +| third-party output embedded in a result (validation messages) | the stable prefix only — never pin text that changes with a dependency upgrade | + +### Notifications and concurrency + +The client's receive loop dispatches each incoming message to completion before reading the next, +and the in-memory transport delivers everything on one ordered stream. Together these guarantee +that every notification a server handler emits before its response reaches the client callback +before the originating request returns — so tests collect notifications into a plain list and +assert after the call, with no synchronisation. 
The exceptions: + +- a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in + the receiving handler and awaited under `anyio.fail_after(5)`; +- the ordering guarantee does not survive transports that split messages across streams (the + streamable HTTP standalone GET stream) — see `transports/test_streamable_http.py`. + +### Coverage + +CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the +build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a +pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# pragma`, +`# type: ignore`, or `# noqa` comments; restructure instead. From cce06b2f68a89e1787834d16f12645873742ecf7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 09:44:11 +0000 Subject: [PATCH 09/34] test: correct spec anchors and record further divergences in the requirements manifest Fixes the spec deep links that pointed at non-existent anchors, records the divergences for the client's default not-supported answers (the spec names -32601 for roots and -32602 for an undeclared elicitation mode; the default callbacks answer -32600), and adds a logging:capability requirement noting that MCPServer emits log message notifications without declaring the logging capability. Also tightens behaviour sentences and docstrings to match what the tests assert, and adds a test pinning that Context.report_progress is a silent no-op when the caller supplied no progress token, removing the pragma on that path. 
--- src/mcp/server/mcpserver/context.py | 2 +- tests/interaction/_requirements.py | 56 ++++++++++++++----- .../interaction/lowlevel/test_elicitation.py | 3 +- tests/interaction/lowlevel/test_roots.py | 3 +- tests/interaction/lowlevel/test_wire.py | 3 +- tests/interaction/mcpserver/test_context.py | 40 ++++++++++++- 6 files changed, 87 insertions(+), 20 deletions(-) diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index e87388eee9..d4344daa92 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -94,7 +94,7 @@ async def report_progress(self, progress: float, total: float | None = None, mes """ progress_token = self.request_context.meta.get("progress_token") if self.request_context.meta else None - if progress_token is None: # pragma: no cover + if progress_token is None: return await self.request_context.session.send_progress_notification( diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index fe70fd828c..2e83fb006b 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -240,11 +240,11 @@ class Requirement: # Request metadata # ═══════════════════════════════════════════════════════════════════════════ "meta:request-to-handler": Requirement( - source=f"{SPEC_BASE_URL}/basic#meta", + source=f"{SPEC_BASE_URL}/basic#_meta", behavior="The _meta object the client attaches to a request is visible to the server handler.", ), "meta:result-to-client": Requirement( - source=f"{SPEC_BASE_URL}/basic#meta", + source=f"{SPEC_BASE_URL}/basic#_meta", behavior="The _meta object a handler attaches to its result is delivered to the client.", ), # ═══════════════════════════════════════════════════════════════════════════ @@ -337,7 +337,7 @@ class Requirement: behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", ), "completion:complete:context": Requirement( - 
source=f"{SPEC_BASE_URL}/server/utilities/completion#context", + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", ), "completion:complete:not-supported": Requirement( @@ -351,7 +351,7 @@ class Requirement: # Logging # ═══════════════════════════════════════════════════════════════════════════ "logging:set-level": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", ), "logging:message:notification": Requirement( @@ -365,8 +365,22 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", ), + "logging:capability": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "MCPServer tools emit log message notifications through the Context helpers while the server's " + "advertised capabilities omit logging." + ), + divergence=Divergence( + note=( + "The spec says servers that emit log message notifications MUST declare the logging " + "capability; MCPServer registers no setLevel handler, so capability derivation leaves " + "logging unset even though the Context helpers send the notifications." + ), + ), + ), "logging:set-level:filtering": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", behavior=( "MCPServer registers no logging/setLevel handler (the request is rejected with method-not-found) " "and log messages are delivered at every severity regardless of any requested level." 
@@ -477,7 +491,7 @@ class Requirement: ), ), "sampling:create-message:image-content": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#message-content", + source=f"{SPEC_BASE_URL}/client/sampling#image-content", behavior="Sampling messages can carry image content: base64 data with a mimeType.", ), "sampling:create-message:tools:not-supported": Requirement( @@ -506,18 +520,19 @@ class Requirement: "sampling:create-message:not-supported": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#capabilities", behavior=( - "A sampling request to a client that did not declare the sampling capability fails with an " - "error rather than hanging or being silently dropped." + "A sampling request to a client that did not declare the sampling capability fails with the " + "client's default-callback error (-32600 Invalid request) rather than hanging or being " + "silently dropped; the spec names no error code for this case." ), ), # ═══════════════════════════════════════════════════════════════════════════ # Elicitation (server → client) # ═══════════════════════════════════════════════════════════════════════════ "elicitation:form:accept": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation", + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", behavior=( "A form-mode elicitation answered with action 'accept' returns the user's content to the " - "requesting handler, validated against the requested schema." + "requesting handler." 
), ), "elicitation:form:decline": Requirement( @@ -529,7 +544,7 @@ class Requirement: behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", ), "elicitation:url:accept": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation", + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", behavior=( "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " "response carries no content (accept means the user agreed to visit the URL, not that the " @@ -545,7 +560,7 @@ class Requirement: behavior="A URL-mode elicitation answered with cancel returns the action with no content.", ), "elicitation:complete-notification": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#completion-notification", + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", behavior=( "An elicitation/complete notification sent by the server after an out-of-band elicitation " "finishes reaches the client carrying the elicitationId." @@ -559,11 +574,18 @@ class Requirement: ), ), "elicitation:form:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", behavior=( "An elicitation request to a client that did not declare the elicitation capability fails with " "an error rather than hanging or being silently dropped." ), + divergence=Divergence( + note=( + "The spec says a request for an elicitation mode the client has not declared MUST be " + "answered with -32602 Invalid params; the client's default callback answers with -32600 " + "Invalid request." 
+ ), + ), ), # ═══════════════════════════════════════════════════════════════════════════ # Roots (server → client) @@ -580,11 +602,17 @@ class Requirement: behavior="An empty roots list is a valid response and reaches the handler as such.", ), "roots:list:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/client/roots#capabilities", + source=f"{SPEC_BASE_URL}/client/roots#error-handling", behavior=( "A roots/list request to a client that did not declare the roots capability fails with an " "error rather than hanging or being silently dropped." ), + divergence=Divergence( + note=( + "The spec says a client that does not support roots SHOULD answer with -32601 Method not " + "found; the client's default callback answers with -32600 Invalid request." + ), + ), ), "roots:list-changed": Requirement( source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index f2f7b54d01..056173ac2b 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -145,7 +145,8 @@ async def test_elicit_form_without_callback_is_error() -> None: """Eliciting from a client that configured no elicitation callback fails with an error. The client's default callback answers with an Invalid request error, which the server-side - elicit call raises as an MCPError; the tool reports the code and message it caught. + elicit call raises as an MCPError; the tool reports the code and message it caught. The spec + requires -32602 for an undeclared mode (see the divergence note on the requirement). 
""" async def list_tools( diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index c87a00735d..b98e7ff315 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -82,7 +82,8 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: async def test_list_roots_without_callback_is_error() -> None: """A roots/list request to a client with no roots callback fails with an error the handler can observe. - The client's default callback answers with INVALID_REQUEST rather than leaving the server hanging. + The client's default callback answers with INVALID_REQUEST rather than leaving the server + hanging; the spec names -32601 for this case (see the divergence note on the requirement). """ async def list_tools( diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index 4ba86a9404..f7e55ecaf3 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -41,8 +41,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_request_ids_are_unique_and_never_null() -> None: """Every request the client sends carries a distinct, non-null id. - The id sequence is pinned: sequential integers from zero, in send order, including the - schema-cache refresh the client performs after the first successful tool call. + The id sequence is pinned: sequential integers from zero, in send order. 
""" recording = RecordingTransport(InMemoryTransport(_echo_server())) diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index c6218fc58e..d24fd62511 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -16,20 +16,24 @@ ElicitRequestParams, ElicitResult, ErrorData, + LoggingMessageNotification, LoggingMessageNotificationParams, TextContent, ) +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @requirement("mcpserver:context:logging") +@requirement("logging:capability") async def test_context_logging_helpers_send_log_notifications() -> None: """Each Context logging helper sends a log message notification at the matching severity. All four notifications reach the client's logging callback before the tool call returns; none - of them carry a logger name unless one is passed explicitly. + of them carry a logger name unless one is passed explicitly. The server emits these without + advertising the logging capability (see the divergence note on logging:capability). """ received: list[LoggingMessageNotificationParams] = [] mcp = MCPServer("chatty") @@ -47,6 +51,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: async with Client(mcp, logging_callback=collect) as client: result = await client.call_tool("narrate", {}) + advertised_logging = client.initialize_result.capabilities.logging assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) assert received == snapshot( @@ -57,6 +62,8 @@ async def collect(params: LoggingMessageNotificationParams) -> None: LoggingMessageNotificationParams(level="error", data="e"), ] ) + # The spec requires servers that emit log notifications to declare the logging capability. 
+ assert advertised_logging is None @requirement("mcpserver:context:progress") @@ -86,6 +93,37 @@ async def on_progress(progress: float, total: float | None, message: str | None) assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) +@requirement("progress:no-token") +async def test_report_progress_without_a_progress_token_sends_nothing() -> None: + """When the caller supplied no progress callback, Context.report_progress is a silent no-op. + + The tool also emits one log message as a sentinel: the message handler receives only that, + proving the notification pipeline works and no progress notification was sent for the + token-less request. + """ + received: list[IncomingMessage] = [] + mcp = MCPServer("quiet") + + @mcp.tool() + async def mill(ctx: Context) -> str: + await ctx.report_progress(1, 3) + await ctx.info("milling done") + return "milled" + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with Client(mcp, message_handler=collect) as client: + result = await client.call_tool("mill", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="milled")], structured_content={"result": "milled"}) + ) + assert received == snapshot( + [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="milling done"))] + ) + + @requirement("mcpserver:context:elicit") async def test_context_elicit_returns_typed_result() -> None: """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. 
From 7709b98816366b0ac2f34b7c7e13e6aa7f8d9933 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 12:50:41 +0000 Subject: [PATCH 10/34] test: add output schema, sampling constraint, roots error, and version-rejection interaction tests --- tests/interaction/_requirements.py | 78 ++++++++++++++++++- tests/interaction/lowlevel/test_initialize.py | 60 +++++++++++++- tests/interaction/lowlevel/test_prompts.py | 3 + tests/interaction/lowlevel/test_resources.py | 5 ++ tests/interaction/lowlevel/test_roots.py | 33 +++++++- tests/interaction/lowlevel/test_sampling.py | 56 +++++++++++++ tests/interaction/lowlevel/test_tools.py | 36 +++++++++ tests/interaction/mcpserver/test_prompts.py | 22 ++++++ 8 files changed, 287 insertions(+), 6 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 2e83fb006b..c170d772bf 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -116,6 +116,13 @@ class Requirement: "requested version with its own latest supported version rather than an error." ), ), + "lifecycle:initialize:protocol-version:client-rejects": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "A client that receives an initialize response carrying a protocol version it does not " + "support fails initialization with an error rather than proceeding with the session." + ), + ), "lifecycle:requests-before-initialized": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", behavior="A request sent before the initialization handshake completes is rejected with an error.", @@ -154,6 +161,19 @@ class Requirement: "A cancellation notification referencing an unknown or already-completed request is ignored without error." 
), ), + "cancellation:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " + "cancels it, and the client stops processing the cancelled request." + ), + deferred=( + "Not expressible through the public API: abandoning a server-side send_request emits no " + "cancellation notification (the same sender-side gap recorded on timeouts:per-request), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Progress # ═══════════════════════════════════════════════════════════════════════════ @@ -325,6 +345,13 @@ class Requirement: "with the validation failure described in content), not a protocol error." ), ), + "tools:call:output-schema-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-result", + behavior=( + "A tool result whose structuredContent does not conform to the tool's declared outputSchema " + "is rejected by the client: the call raises instead of returning the invalid result." 
+ ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Completion # ═══════════════════════════════════════════════════════════════════════════ @@ -473,6 +500,17 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/prompts#error-handling", behavior="prompts/get for an unknown prompt name returns a JSON-RPC error.", ), + "prompts:get:missing-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get with a required argument missing returns a JSON-RPC error.", + divergence=Divergence( + note=( + "The spec says missing required arguments are answered with -32602 Invalid params; " + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." + ), + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Sampling (server → client) # ═══════════════════════════════════════════════════════════════════════════ @@ -487,7 +525,7 @@ class Requirement: source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", behavior=( "The sampling parameters supplied by the server (messages, maxTokens, systemPrompt, " - "modelPreferences, temperature, stopSequences) reach the client callback intact." + "modelPreferences, temperature, stopSequences, includeContext) reach the client callback intact." ), ), "sampling:create-message:image-content": Requirement( @@ -501,6 +539,13 @@ class Requirement: "by the server before anything reaches the wire, with an Invalid params error." ), ), + "sampling:create-message:tools:message-constraints": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#message-content-constraints", + behavior=( + "A sampling request whose messages violate the tool_use/tool_result pairing rules is rejected " + "by the server-side validator before anything reaches the wire." 
+ ), + ), "sampling:create-message:tools:round-trip": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#sampling-with-tools", behavior=( @@ -587,6 +632,17 @@ class Requirement: ), ), ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not expressible through the public API: a Client with an elicitation callback always declares " + "both the form and url sub-capabilities, so a form-only client cannot be constructed." + ), + ), # ═══════════════════════════════════════════════════════════════════════════ # Roots (server → client) # ═══════════════════════════════════════════════════════════════════════════ @@ -614,6 +670,10 @@ class Requirement: ), ), ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), "roots:list-changed": Requirement( source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", @@ -671,6 +731,22 @@ class Requirement: "tests/shared/test_streamable_http.py." ), ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + deferred=( + "Replay requires dropping and re-establishing the SSE connection, which the in-process ASGI " + "client cannot express. Covered over a real socket by tests/shared/test_streamable_http.py." 
+ ), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="Requests with a disallowed Origin or Host header are rejected before reaching the session.", + deferred=( + "The in-process fixture disables DNS-rebinding protection because no network attack surface " + "exists in-process. Covered by tests/server/test_streamable_http_security.py." + ), + ), "transport:stdio": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#stdio", behavior="The interaction round trip works over a stdio subprocess.", diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index be0b0ac2ef..c9debb06be 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -1,8 +1,10 @@ """Initialization handshake against the low-level Server, driven through the public Client API. -The last two tests drive a bare ClientSession over an InMemoryTransport instead: Client always +The later tests drive a bare ClientSession over an InMemoryTransport instead: Client always performs the full handshake with the latest protocol version, so skipping initialization or -requesting a different version can only be expressed one level down. +requesting a different version can only be expressed one level down. The final test goes one step +further and plays the server's side of the wire by hand, because no real Server can be made to +answer initialize with an unsupported protocol version. 
""" import anyio @@ -14,6 +16,8 @@ from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage from mcp.types import ( INVALID_PARAMS, CallToolResult, @@ -26,6 +30,8 @@ InitializeRequest, InitializeRequestParams, InitializeResult, + JSONRPCRequest, + JSONRPCResponse, ListToolsRequest, ListToolsResult, LoggingCapability, @@ -195,10 +201,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara for name, value in ( ("sampling", capabilities.sampling), ("elicitation", capabilities.elicitation), - ("roots", capabilities.roots), ) if value is not None ] + if capabilities.roots is not None: + declared.append(f"roots(list_changed={capabilities.roots.list_changed})") return CallToolResult(content=[TextContent(text=",".join(declared) or "none")]) async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: @@ -213,7 +220,7 @@ async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: async with Client(server, list_roots_callback=list_roots) as client: result = await client.call_tool("abilities", {}) - assert result == snapshot(CallToolResult(content=[TextContent(text="roots")])) + assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) @requirement("lifecycle:requests-before-initialized") @@ -275,3 +282,48 @@ def initialize_request(protocol_version: str) -> InitializeRequest: with anyio.fail_after(5): result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) assert result.protocol_version == snapshot("2025-11-25") + + +@requirement("lifecycle:initialize:protocol-version:client-rejects") +async def test_unsupported_server_protocol_version_fails_initialization() -> None: + """An initialize response carrying a protocol version the client does not support 
fails initialization. + + A real Server only ever answers with a version it supports, so this test alone plays the + server's side of the wire by hand: it reads the initialize request off the raw stream and + answers it with a hand-built result. Reserve this pattern for behaviour no real server can + be made to produce. + """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. 
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(scripted_server) + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() + + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py index 64ca0ce055..dd27b0f659 100644 --- a/tests/interaction/lowlevel/test_prompts.py +++ b/tests/interaction/lowlevel/test_prompts.py @@ -10,6 +10,7 @@ INVALID_PARAMS, ErrorData, GetPromptResult, + Icon, ListPromptsResult, Prompt, PromptArgument, @@ -35,6 +36,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest PromptArgument(name="code", description="The code to review.", required=True), PromptArgument(name="style_guide", description="Optional style guide to apply."), ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], ), Prompt(name="daily_standup"), ] @@ -55,6 +57,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest PromptArgument(name="code", description="The code to review.", required=True), PromptArgument(name="style_guide", description="Optional style guide to apply."), ], + icons=[Icon(src="https://example.com/review.png", mime_type="image/png", sizes=["48x48"])], ), Prompt(name="daily_standup"), ] diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 96b42d25a2..44d69209f4 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -14,6 +14,7 @@ CallToolResult, EmptyResult, ErrorData, + Icon, ListResourcesResult, ListResourceTemplatesResult, ReadResourceResult, @@ -48,6 +49,7 @@ async def list_resources( 
mime_type="text/markdown", size=1024, annotations=Annotations(audience=["user", "assistant"], priority=0.8), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], ), ] ) @@ -69,6 +71,7 @@ async def list_resources( mime_type="text/markdown", size=1024, annotations=Annotations(audience=["user", "assistant"], priority=0.8), + icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], ), ] ) @@ -159,6 +162,7 @@ async def list_resource_templates( title="Service logs", description="One day of logs for one service.", mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], ), ] ) @@ -178,6 +182,7 @@ async def list_resource_templates( title="Service logs", description="One day of logs for one service.", mime_type="text/plain", + icons=[Icon(src="https://example.com/logs.png", mime_type="image/png", sizes=["48x48"])], ), ] ) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index b98e7ff315..221be372d5 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -9,7 +9,7 @@ from mcp.client import ClientRequestContext from mcp.client.client import Client from mcp.server import Server, ServerRequestContext -from mcp.types import CallToolResult, ListRootsResult, Root, TextContent +from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -107,6 +107,37 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) +@requirement("roots:list:client-error") +async def test_list_roots_callback_error_surfaces_to_the_handler() -> None: + """A roots callback that answers with an error fails the roots/list 
request with that exact error. + + The callback's code and message reach the requesting handler verbatim as an MCPError. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="show_roots", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "show_roots" + try: + await ctx.session.list_roots() + except MCPError as exc: + return CallToolResult(content=[TextContent(text=f"{exc.error.code}: {exc.error.message}")]) + raise NotImplementedError # the callback always answers with an error + + server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) + + async def list_roots(context: ClientRequestContext) -> ErrorData: + return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") + + async with Client(server, list_roots_callback=list_roots) as client: + result = await client.call_tool("show_roots", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) + + @requirement("roots:list-changed") async def test_roots_list_changed_reaches_server_handler() -> None: """A roots/list_changed notification from the client is delivered to the server's handler. 
diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 7a0a396be9..b2b268d9b7 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -22,6 +22,7 @@ ModelPreferences, SamplingMessage, TextContent, + ToolResultContent, ) from tests.interaction._requirements import requirement @@ -93,6 +94,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara messages=[SamplingMessage(role="user", content=TextContent(text="Pick a model."))], max_tokens=50, system_prompt="You are terse.", + include_context="thisServer", temperature=0.7, stop_sequences=["\n\n", "END"], model_preferences=ModelPreferences( @@ -129,6 +131,7 @@ async def sampling_callback( intelligence_priority=0.9, ), system_prompt="You are terse.", + include_context="thisServer", temperature=0.7, max_tokens=50, stop_sequences=["\n\n", "END"], @@ -330,3 +333,56 @@ async def sampling_callback( assert result == snapshot( CallToolResult(content=[TextContent(text="-32602: Client does not support sampling tools capability")]) ) + + +@requirement("sampling:create-message:tools:message-constraints") +async def test_create_message_with_unbalanced_tool_messages_is_rejected() -> None: + """A sampling request whose messages mix tool results with other content never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, even + when no tools are passed, so the client callback is never invoked and the handler observes the + ValueError directly. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="summarise_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "summarise_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + ToolResultContent(tool_use_id="call-1", content=[TextContent(text="42")]), + TextContent(text="Also, a comment alongside the result."), + ], + ) + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("summarise_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent(text="ValueError: The last message must contain only tool_result content if any is present") + ] + ) + ) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 071180ddfd..81664e8a51 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -317,3 +317,39 @@ async def call_and_record(tag: str) -> None: "second": CallToolResult(content=[TextContent(text="second")]), } ) + + +@requirement("tools:call:output-schema-validation") +async def 
test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client() -> None: + """A result whose structured content does not conform to the tool's declared output schema never + reaches the caller: the client validates it against the schema cached from tools/list and raises. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")], structured_content={"temperature": "warm"}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with Client(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. + assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 27b44773a6..3f865b077a 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -88,3 +88,25 @@ def greet(name: str) -> str: await client.get_prompt("nope") assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) + + +@requirement("prompts:get:missing-arguments") +async def test_get_prompt_with_a_missing_required_argument_is_an_error() -> None: + """Getting a prompt without one of its required arguments fails with a JSON-RPC error. 
+ + The missing argument is detected before the prompt function is called, but the spec's -32602 + Invalid params is reported as error code 0 with the bare exception text (see the divergence + note on the requirement). + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet(name: str) -> str: + """A registered prompt; validation rejects the call before the function runs.""" + raise NotImplementedError + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("greet") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) From bdfded0ce99ce177c1fa936fa5b1c53899939189 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 13:55:00 +0000 Subject: [PATCH 11/34] test: align requirement IDs, add transport applicability, and enforce two-way test coverage --- tests/interaction/_requirements.py | 368 ++++++++++++++---- .../interaction/lowlevel/test_cancellation.py | 6 +- tests/interaction/lowlevel/test_completion.py | 7 +- .../interaction/lowlevel/test_elicitation.py | 13 +- tests/interaction/lowlevel/test_initialize.py | 9 +- .../interaction/lowlevel/test_list_changed.py | 6 +- tests/interaction/lowlevel/test_logging.py | 3 +- tests/interaction/lowlevel/test_pagination.py | 9 +- tests/interaction/lowlevel/test_progress.py | 9 +- tests/interaction/lowlevel/test_prompts.py | 2 +- tests/interaction/lowlevel/test_resources.py | 4 +- tests/interaction/lowlevel/test_roots.py | 2 +- tests/interaction/lowlevel/test_sampling.py | 13 +- tests/interaction/lowlevel/test_timeouts.py | 7 +- tests/interaction/lowlevel/test_tools.py | 6 +- tests/interaction/mcpserver/test_context.py | 7 +- tests/interaction/mcpserver/test_prompts.py | 8 +- tests/interaction/mcpserver/test_resources.py | 9 +- tests/interaction/mcpserver/test_tools.py | 15 +- tests/interaction/test_coverage.py | 47 ++- 20 files changed, 412 insertions(+), 
138 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index c170d772bf..fb502d96c3 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1,8 +1,10 @@ """Requirements manifest for the interaction-model test suite. Every user-facing behaviour the SDK must satisfy, keyed by a stable `<area>:<feature>[:<case>]` -ID. Each entry owns the tests that exercise it: tests declare `@requirement("<id>")` and -`test_coverage.py` enforces that every non-deferred requirement is exercised by at least one test. +ID. Each entry owns the tests that exercise it: tests declare `@requirement("<id>")` (a test that +proves several behaviours stacks several decorators) and `test_coverage.py` enforces the contract +in both directions: every non-deferred requirement has at least one test, and every test carries +at least one requirement. Sources: spec URL -- externally mandated by the MCP specification (deep link to the section) @@ -13,19 +15,31 @@ behaviour. Where that differs from what `source` mandates, the gap is recorded in `divergence` and the tests still pin current behaviour: this suite is the parity bar for the receive-path rewrite, so a test that fails today proves nothing about equivalence. + +`transports` records which transports a behaviour applies to (or is observable on); None means +the behaviour is transport-independent. + +The ID vocabulary and entry granularity are aligned with the TypeScript SDK's end-to-end +requirements suite, so coverage and recorded divergences can be compared across the two SDKs +entry by entry; IDs that exist in only one SDK reflect genuinely different API surface. 
""" +import re from collections.abc import Callable from dataclasses import dataclass -from typing import TypeVar +from typing import Literal, TypeVar import pytest SPEC_REVISION = "2025-11-25" SPEC_BASE_URL = f"https://modelcontextprotocol.io/specification/{SPEC_REVISION}" +Transport = Literal["in-memory", "stdio", "streamable-http", "sse"] + _TestFn = TypeVar("_TestFn", bound=Callable[..., object]) +_SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") + @dataclass(frozen=True, kw_only=True) class Divergence: @@ -41,9 +55,14 @@ class Requirement: source: str behavior: str + transports: tuple[Transport, ...] | None = None divergence: Divergence | None = None deferred: str | None = None + def __post_init__(self) -> None: + if not _SOURCE_PATTERN.fullmatch(self.source): + raise ValueError(f"source must be a specification URL, 'sdk', or 'issue:#n', got {self.source!r}") + REQUIREMENTS: dict[str, Requirement] = { # ═══════════════════════════════════════════════════════════════════════════ @@ -73,6 +92,10 @@ class Requirement: ), ), ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), # ═══════════════════════════════════════════════════════════════════════════ # Lifecycle # ═══════════════════════════════════════════════════════════════════════════ @@ -109,14 +132,18 @@ class Requirement: "(sampling, elicitation, roots)." 
), ), - "lifecycle:initialize:protocol-version": Requirement( + "lifecycle:version:match": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior="The server echoes a requested protocol version it supports in the initialize result.", + ), + "lifecycle:version:server-fallback-latest": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", behavior=( - "The server echoes a requested protocol version it supports, and answers an unsupported " - "requested version with its own latest supported version rather than an error." + "An initialize request carrying a protocol version the server does not support is answered " + "with the server's latest supported version rather than an error." ), ), - "lifecycle:initialize:protocol-version:client-rejects": Requirement( + "lifecycle:version:reject-unsupported": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", behavior=( "A client that receives an initialize response carrying a protocol version it does not " @@ -137,7 +164,7 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Cancellation # ═══════════════════════════════════════════════════════════════════════════ - "cancellation:in-flight": Requirement( + "protocol:cancel:in-flight": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior=( "A cancellation notification for an in-flight request stops the server-side handler, and the " @@ -151,17 +178,17 @@ class Requirement: ), ), ), - "cancellation:server-survives": Requirement( + "protocol:cancel:server-survives": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior="The session continues to serve new requests after an earlier request was cancelled.", ), - "cancellation:unknown-request": Requirement( + "protocol:cancel:unknown-id-ignored": Requirement( 
source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior=( "A cancellation notification referencing an unknown or already-completed request is ignored without error." ), ), - "cancellation:server-to-client": Requirement( + "protocol:cancel:server-to-client": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior=( "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " @@ -169,48 +196,53 @@ class Requirement: ), deferred=( "Not expressible through the public API: abandoning a server-side send_request emits no " - "cancellation notification (the same sender-side gap recorded on timeouts:per-request), and " - "the client could not act on one anyway because client callbacks run inline in the receive " - "loop, so a cancellation would not even be read until the callback had already finished." + "cancellation notification (the same sender-side gap recorded on " + "protocol:timeout:sends-cancellation), and the client could not act on one anyway because " + "client callbacks run inline in the receive loop, so a cancellation would not even be read " + "until the callback had already finished." ), ), # ═══════════════════════════════════════════════════════════════════════════ # Progress # ═══════════════════════════════════════════════════════════════════════════ - "progress:server-to-client": Requirement( + "protocol:progress:callback": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=( "Progress notifications emitted by a handler during a request are delivered to the caller's " "progress callback, in order, with their progress, total, and message." 
), ), - "progress:token-propagation": Requirement( + "protocol:progress:token-injected": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=( "Supplying a progress callback attaches a progress token to the outgoing request, which the " "server-side handler can observe in its request metadata." ), ), - "progress:no-token": Requirement( + "protocol:progress:no-token": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=( "Without a progress callback no token is attached, and a handler that reports progress anyway " "sends nothing." ), ), - "progress:client-to-server": Requirement( + "protocol:progress:client-to-server": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior="A progress notification sent by the client is delivered to the server's progress handler.", ), # ═══════════════════════════════════════════════════════════════════════════ # Timeouts # ═══════════════════════════════════════════════════════════════════════════ - "timeouts:per-request": Requirement( + "protocol:timeout:basic": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", behavior=( "A request that exceeds its read timeout fails with a request-timeout error instead of " "waiting forever for the response." 
), + ), + "protocol:timeout:sends-cancellation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A request that times out fails the caller; the server handler is not cancelled and keeps running.", divergence=Divergence( note=( "The spec says the requester SHOULD issue a cancellation notification for the timed-out " @@ -219,18 +251,18 @@ class Requirement: ), ), ), - "timeouts:session-survives": Requirement( + "protocol:timeout:session-survives": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", behavior="The session continues to serve new requests after an earlier request timed out.", ), - "timeouts:session-default": Requirement( + "protocol:timeout:session-default": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", behavior="A session-level read timeout applies to every request that does not override it.", ), # ═══════════════════════════════════════════════════════════════════════════ # Pagination # ═══════════════════════════════════════════════════════════════════════════ - "pagination:cursor-round-trip": Requirement( + "tools:list:pagination": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", behavior=( "The nextCursor returned by a list handler reaches the client, and the cursor the client " @@ -244,15 +276,15 @@ class Requirement: "nextCursor ends the sequence." 
), ), - "pagination:resources": Requirement( + "resources:list:pagination": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", behavior="resources/list supports cursor pagination.", ), - "pagination:resource-templates": Requirement( + "resources:templates:pagination": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", behavior="resources/templates/list supports cursor pagination.", ), - "pagination:prompts": Requirement( + "prompts:list:pagination": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", behavior="prompts/list supports cursor pagination.", ), @@ -281,11 +313,15 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Tools # ═══════════════════════════════════════════════════════════════════════════ + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), "tools:list:basic": Requirement( source=f"{SPEC_BASE_URL}/server/tools#listing-tools", behavior="tools/list returns the registered tools with name, description, and inputSchema.", ), - "tools:list:optional-fields": Requirement( + "tools:list:metadata": Requirement( source=f"{SPEC_BASE_URL}/server/tools#tool", behavior=( "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " @@ -312,7 +348,7 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", behavior="A tool result can carry an embedded resource with full text or blob contents.", ), - "tools:call:content:multiple": Requirement( + "tools:call:content:mixed": Requirement( source=f"{SPEC_BASE_URL}/server/tools#calling-tools", behavior="A tool result can carry multiple content blocks of different types; order is preserved.", ), @@ 
-320,6 +356,10 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/tools#structured-content", behavior="A tool result can carry structuredContent alongside content; the client receives both.", ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), "tools:call:is-error": Requirement( source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior=( @@ -338,14 +378,35 @@ class Requirement: "receives the response to its own request." ), ), - "tools:call:invalid-arguments": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#error-handling", + "tools:call:elicitation-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", behavior=( - "Arguments that fail the tool's input validation produce a tool execution error (isError true " - "with the validation failure described in content), not a protocol error." + "A tool handler that issues an elicitation receives the client's result and can embed it in " + "the tool call result." + ), + ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." + ), + ), + "tools:call:logging-mid-execution": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "Log notifications emitted by a tool handler during execution reach the client's logging " + "callback before the tool result returns." 
), ), - "tools:call:output-schema-validation": Requirement( + "tools:call:progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "Progress notifications emitted by a tool handler reach the caller's progress callback before " + "the tool result returns." + ), + ), + "client:output-schema:validate": Requirement( source=f"{SPEC_BASE_URL}/server/tools#tool-result", behavior=( "A tool result whose structuredContent does not conform to the tool's declared outputSchema " @@ -355,15 +416,19 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Completion # ═══════════════════════════════════════════════════════════════════════════ - "completion:complete:prompt-ref": Requirement( + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:prompt-arg": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", ), - "completion:complete:resource-ref": Requirement( + "completion:resource-template-arg": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", ), - "completion:complete:context": Requirement( + "completion:context-arguments": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", ), @@ -381,7 +446,7 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", behavior="logging/setLevel delivers the requested level to the server's 
handler and returns an empty result.", ), - "logging:message:notification": Requirement( + "logging:message:fields": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", behavior=( "A log message sent by a server handler is delivered to the client's logging callback with its " @@ -392,7 +457,7 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", ), - "logging:capability": Requirement( + "logging:capability:declared": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", behavior=( "MCPServer tools emit log message notifications through the Context helpers while the server's " @@ -406,7 +471,7 @@ class Requirement: ), ), ), - "logging:set-level:filtering": Requirement( + "logging:message:filtered": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", behavior=( "MCPServer registers no logging/setLevel handler (the request is rejected with method-not-found) " @@ -424,6 +489,13 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Resources # ═══════════════════════════════════════════════════════════════════════════ + "resources:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior=( + "A server with resource handlers advertises the resources capability, including the subscribe " + "sub-flag when a subscribe handler is registered." 
+ ), + ), "resources:list:basic": Requirement( source=f"{SPEC_BASE_URL}/server/resources#listing-resources", behavior=( @@ -435,11 +507,11 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/resources#reading-resources", behavior="resources/read returns text contents carrying uri, mimeType, and the text.", ), - "resources:read:binary": Requirement( + "resources:read:blob": Requirement( source=f"{SPEC_BASE_URL}/server/resources#reading-resources", behavior="resources/read returns binary contents base64-encoded in blob.", ), - "resources:read:not-found": Requirement( + "resources:read:unknown-uri": Requirement( source=f"{SPEC_BASE_URL}/server/resources#error-handling", behavior="resources/read for an unknown URI returns a JSON-RPC error; the spec reserves -32002 for it.", ), @@ -449,6 +521,10 @@ class Requirement: "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." ), ), + "resources:read:template-vars": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#resource-templates", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), "resources:subscribe": Requirement( source=f"{SPEC_BASE_URL}/server/resources#subscriptions", behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", @@ -459,6 +535,15 @@ class Requirement: "resources/unsubscribe delivers the URI to the server's unsubscribe handler and returns an empty result." ), ), + "resources:unsubscribe:stops-updates": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", + deferred=( + "The SDK keeps no subscription state -- emitting updated notifications is entirely handler " + "code -- so there is no SDK behaviour to pin beyond the unsubscribe request reaching the " + "handler (covered by resources:unsubscribe)." 
+ ), + ), "resources:updated-notification": Requirement( source=f"{SPEC_BASE_URL}/server/resources#subscriptions", behavior=( @@ -469,26 +554,30 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Notifications: list_changed (server → client) # ═══════════════════════════════════════════════════════════════════════════ - "notifications:tools:list-changed": Requirement( + "tools:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", behavior="A tools/list_changed notification sent by the server reaches the client's message handler.", ), - "notifications:resources:list-changed": Requirement( + "resources:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", behavior="A resources/list_changed notification sent by the server reaches the client's message handler.", ), - "notifications:prompts:list-changed": Requirement( + "prompts:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", behavior="A prompts/list_changed notification sent by the server reaches the client's message handler.", ), # ═══════════════════════════════════════════════════════════════════════════ # Prompts # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", + ), "prompts:list:basic": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", ), - "prompts:get:arguments": Requirement( + "prompts:get:with-args": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", behavior="prompts/get delivers the supplied arguments to the prompt handler and 
returns its messages.", ), @@ -500,7 +589,7 @@ class Requirement: source=f"{SPEC_BASE_URL}/server/prompts#error-handling", behavior="prompts/get for an unknown prompt name returns a JSON-RPC error.", ), - "prompts:get:missing-arguments": Requirement( + "prompts:get:missing-required-args": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#error-handling", behavior="prompts/get with a required argument missing returns a JSON-RPC error.", divergence=Divergence( @@ -514,39 +603,47 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Sampling (server → client) # ═══════════════════════════════════════════════════════════════════════════ - "sampling:create-message:round-trip": Requirement( + "sampling:create:basic": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", behavior=( "A sampling/createMessage request from a server handler is answered by the client's sampling " "callback, and the callback's result (role, content, model, stopReason) is returned to the handler." ), ), - "sampling:create-message:params": Requirement( + "sampling:create:include-context": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The includeContext value supplied by the server reaches the client callback intact.", + ), + "sampling:create:model-preferences": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", behavior=( - "The sampling parameters supplied by the server (messages, maxTokens, systemPrompt, " - "modelPreferences, temperature, stopSequences, includeContext) reach the client callback intact." + "The model preferences supplied by the server (hints and the cost, speed, and intelligence " + "priorities) reach the client callback intact." 
), ), + "sampling:create:system-prompt": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior="The system prompt supplied by the server reaches the client callback intact.", + ), "sampling:create-message:image-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#image-content", behavior="Sampling messages can carry image content: base64 data with a mimeType.", ), - "sampling:create-message:tools:not-supported": Requirement( + "sampling:tools:server-gated-by-capability": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#capabilities", behavior=( "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " "by the server before anything reaches the wire, with an Invalid params error." ), ), - "sampling:create-message:tools:message-constraints": Requirement( + "sampling:tool-result:no-mixed-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#message-content-constraints", behavior=( "A sampling request whose messages violate the tool_use/tool_result pairing rules is rejected " "by the server-side validator before anything reaches the wire." ), ), - "sampling:create-message:tools:round-trip": Requirement( + "sampling:create:tools": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#sampling-with-tools", behavior=( "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " @@ -558,7 +655,7 @@ class Requirement: "server-side validator rejects every tool-enabled request before it is sent." 
), ), - "sampling:create-message:client-error": Requirement( + "sampling:error:user-rejected": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#error-handling", behavior="A sampling callback that returns an error is surfaced to the requesting handler as an MCPError.", ), @@ -573,22 +670,36 @@ class Requirement: # ═══════════════════════════════════════════════════════════════════════════ # Elicitation (server → client) # ═══════════════════════════════════════════════════════════════════════════ - "elicitation:form:accept": Requirement( + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior=( + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." + ), + ), + "elicitation:form:action:accept": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", behavior=( "A form-mode elicitation answered with action 'accept' returns the user's content to the " "requesting handler." ), ), - "elicitation:form:decline": Requirement( + "elicitation:form:action:decline": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", ), - "elicitation:form:cancel": Requirement( + "elicitation:form:action:cancel": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", ), - "elicitation:url:accept": Requirement( + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." 
+ ), + ), + "elicitation:url:action:accept-no-content": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", behavior=( "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " @@ -604,7 +715,7 @@ class Requirement: source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", behavior="A URL-mode elicitation answered with cancel returns the action with no content.", ), - "elicitation:complete-notification": Requirement( + "elicitation:url:complete-notification": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", behavior=( "An elicitation/complete notification sent by the server after an out-of-band elicitation " @@ -643,10 +754,15 @@ class Requirement: "both the form and url sub-capabilities, so a form-only client cannot be constructed." ), ), + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + behavior="A client that declares the defaults capability receives requested schemas with defaults applied.", + deferred="The SDK does not implement the defaults sub-capability on either side.", + ), # ═══════════════════════════════════════════════════════════════════════════ # Roots (server → client) # ═══════════════════════════════════════════════════════════════════════════ - "roots:list:round-trip": Requirement( + "roots:list:basic": Requirement( source=f"{SPEC_BASE_URL}/client/roots#listing-roots", behavior=( "A roots/list request from a server handler is answered by the client's roots callback, and " @@ -687,10 +803,12 @@ class Requirement: "The interaction round trip (initialize, tool calls, tool errors) works through the " "streamable HTTP framing in its default stateful SSE-response mode." 
), + transports=("streamable-http",), ), "transport:streamable-http:json-response": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), ), "transport:streamable-http:stateless": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", @@ -698,6 +816,7 @@ class Requirement: "The interaction round trip works in stateless mode, where every request is served by a " "fresh transport with no session id." ), + transports=("streamable-http",), ), "transport:streamable-http:notifications": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", @@ -705,6 +824,7 @@ class Requirement: "Notifications emitted during a request are delivered on that request's SSE stream and reach " "the client's callbacks, in order, before the response." ), + transports=("streamable-http",), ), "transport:streamable-http:stateless-restrictions": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", @@ -712,6 +832,7 @@ class Requirement: "A handler that attempts a server-initiated request in stateless mode fails with an error " "result, because there is no session to call back through." ), + transports=("streamable-http",), ), "transport:streamable-http:unrelated-messages": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", @@ -719,12 +840,14 @@ class Requirement: "A server-to-client message that is not related to an in-flight request is routed to the " "standalone GET stream; a client that never opened one does not receive it." ), + transports=("streamable-http",), ), "transport:streamable-http:server-to-client": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", behavior=( "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." 
), + transports=("streamable-http",), deferred=( "The in-process ASGI client buffers each response in full, which deadlocks on a " "server-to-client request nested inside a still-open call. Covered over a real socket by " @@ -734,6 +857,7 @@ class Requirement: "transport:streamable-http:resumability": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), deferred=( "Replay requires dropping and re-establishing the SSE connection, which the in-process ASGI " "client cannot express. Covered over a real socket by tests/shared/test_streamable_http.py." @@ -742,51 +866,150 @@ class Requirement: "transport:streamable-http:origin-validation": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", behavior="Requests with a disallowed Origin or Host header are rejected before reaching the session.", + transports=("streamable-http",), deferred=( "The in-process fixture disables DNS-rebinding protection because no network attack surface " "exists in-process. Covered by tests/server/test_streamable_http_security.py." ), ), + "transport:streamable-http:session-management": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "The server issues a session id on initialize, validates it on subsequent requests, isolates " + "sessions, and tears the session down on DELETE." + ), + transports=("streamable-http",), + deferred=( + "Covered at the wire level by tests/shared/test_streamable_http.py and " + "tests/server/test_streamable_http_manager.py; this suite drives sessions only through the " + "client API." 
+ ), + ), + "transport:streamable-http:wire-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The server validates Accept and Content-Type headers, the protocol-version header, and " + "malformed JSON bodies, answering with the documented HTTP status codes." + ), + transports=("streamable-http",), + deferred=( + "Raw-HTTP request/response validation is covered by tests/shared/test_streamable_http.py; " + "this suite only sends well-formed traffic through the client." + ), + ), + "transport:streamable-http:client-reconnect": Requirement( + source="sdk", + behavior=( + "The HTTP client transport reconnects dropped SSE streams, honours the server-provided retry " + "interval, and resumes from the last event id." + ), + transports=("streamable-http",), + deferred=( + "Reconnection and resumption behaviour needs a droppable connection; covered by " + "tests/shared/test_streamable_http.py over a real socket." + ), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." + ), + transports=("sse",), + deferred=( + "The legacy SSE transport is covered by tests/shared/test_sse.py; in-process coverage in this " + "suite arrives with the transport fixture work." + ), + ), "transport:stdio": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#stdio", behavior="The interaction round trip works over a stdio subprocess.", + transports=("stdio",), deferred=( "Requires a real subprocess. Process lifecycle is covered by tests/client/test_stdio.py and " "end-to-end stdio coverage belongs to the cross-SDK conformance suite." 
), ), # ═══════════════════════════════════════════════════════════════════════════ - # MCPServer behavioural guarantees (not spec-mandated) + # Authorization # ═══════════════════════════════════════════════════════════════════════════ - "mcpserver:tools:output-schema:model": Requirement( + "auth:client-oauth": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization", + behavior=( + "The client performs the OAuth 2.1 authorization flow (metadata discovery, PKCE, dynamic " + "client registration, token refresh, resource parameter) when a server requires authorization." + ), + transports=("streamable-http",), + deferred=( + "Authorization is out of scope for this suite. Client-side flow coverage lives in " + "tests/client/test_auth.py, tests/client/auth/, and tests/shared/test_auth_utils.py." + ), + ), + "auth:server-enforcement": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization", + behavior=( + "A server protecting its endpoints rejects missing, invalid, expired, or under-scoped tokens " + "with 401/403 and serves protected-resource metadata." + ), + transports=("streamable-http",), + deferred=( + "Authorization is out of scope for this suite. Server-side enforcement coverage lives in " + "tests/server/auth/ and tests/shared/test_auth.py." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:experimental": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks", + behavior=( + "Task-augmented requests (tasks/create, tasks/get, tasks/list, tasks/cancel, task-status " + "notifications and task-scoped side-channel requests) run the documented task lifecycle." + ), + deferred=( + "Tasks are experimental and under active spec revision; the suite excludes them. Python task " + "behaviour is covered by tests/experimental/tasks/." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer behaviours + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior=( + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content), not a protocol error." + ), + ), + "mcpserver:tool:output-schema:model": Requirement( source="sdk", behavior=( "A tool returning a typed model advertises a matching generated outputSchema and returns the " "model's fields as structuredContent alongside a serialised text block." ), ), - "mcpserver:tools:output-schema:wrapped": Requirement( + "mcpserver:tool:output-schema:wrapped": Requirement( source="sdk", behavior=( "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " "structuredContent, with a matching generated outputSchema." ), ), - "mcpserver:resources:static": Requirement( + "mcpserver:resource:static": Requirement( source="sdk", behavior=( "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " "served by resources/read at that URI." ), ), - "mcpserver:resources:template": Requirement( + "mcpserver:resource:template": Requirement( source="sdk", behavior=( "A function registered with a URI template is listed by resources/templates/list and matched " "by resources/read, receiving the parameters extracted from the requested URI." 
), ), - "mcpserver:resources:unknown-uri": Requirement( + "mcpserver:resource:unknown-uri": Requirement( source="sdk", behavior="resources/read for a URI matching no registered resource returns a JSON-RPC error.", divergence=Divergence( @@ -796,14 +1019,14 @@ class Requirement: ), ), ), - "mcpserver:prompts:decorated": Requirement( + "mcpserver:prompt:decorated": Requirement( source="sdk", behavior=( "A function registered with @mcp.prompt() is listed with arguments derived from its signature " "and rendered into prompt messages by prompts/get." ), ), - "mcpserver:prompts:unknown-name": Requirement( + "mcpserver:prompt:unknown-name": Requirement( source="sdk", behavior="prompts/get for a name that was never registered returns a JSON-RPC error.", divergence=Divergence( @@ -837,12 +1060,9 @@ class Requirement: source="sdk", behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", ), - "mcpserver:tools:list-changed-on-mutation": Requirement( + "mcpserver:register:post-connect": Requirement( source="sdk", - behavior=( - "Adding or removing a tool on a running server changes what tools/list returns but sends no " - "notification to connected clients." - ), + behavior=("A tool added or removed after the client connected is reflected in subsequent tools/list results."), divergence=Divergence( note=( "The spec provides notifications/tools/list_changed for exactly this case; MCPServer never " @@ -850,14 +1070,14 @@ class Requirement: ), ), ), - "mcpserver:tools:handler-exception": Requirement( + "mcpserver:tool:handler-throws": Requirement( source="sdk", behavior=( "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." 
), ), - "mcpserver:tools:unknown-name": Requirement( + "mcpserver:tool:unknown-name": Requirement( source="sdk", behavior="Calling a tool name that was never registered returns a tool result with isError true.", divergence=Divergence( diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 30821c1294..591c66efa2 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -19,7 +19,7 @@ pytestmark = pytest.mark.anyio -@requirement("cancellation:in-flight") +@requirement("protocol:cancel:in-flight") async def test_cancellation_stops_in_flight_handler() -> None: """Cancelling an in-flight request interrupts its handler and fails the pending call. @@ -68,7 +68,7 @@ async def call_and_capture_error() -> None: assert errors == snapshot([ErrorData(code=0, message="Request cancelled")]) -@requirement("cancellation:server-survives") +@requirement("protocol:cancel:server-survives") async def test_session_serves_requests_after_cancellation() -> None: """A request cancelled mid-flight does not poison the session: the next request succeeds.""" started = anyio.Event() @@ -114,7 +114,7 @@ async def call_and_swallow_cancellation_error() -> None: assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) -@requirement("cancellation:unknown-request") +@requirement("protocol:cancel:unknown-id-ignored") async def test_cancellation_for_unknown_request_is_ignored() -> None: """A cancellation referencing a request id that is not in flight is ignored without error.""" diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py index 91fd20a5a0..f5deaa89f6 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -19,7 +19,7 @@ pytestmark = pytest.mark.anyio -@requirement("completion:complete:prompt-ref") +@requirement("completion:prompt-arg") 
async def test_complete_prompt_argument() -> None: """Completing a prompt argument delivers the ref, argument name, and current value to the handler. @@ -46,7 +46,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar ) -@requirement("completion:complete:resource-ref") +@requirement("completion:resource-template-arg") async def test_complete_resource_template_variable() -> None: """Completing a URI template variable delivers the template URI and variable name to the handler.""" @@ -67,7 +67,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol"]))) -@requirement("completion:complete:context") +@requirement("completion:context-arguments") async def test_complete_receives_context_arguments() -> None: """Previously-resolved arguments passed as completion context reach the handler. @@ -93,6 +93,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar @requirement("completion:complete:not-supported") +@requirement("protocol:error:method-not-found") async def test_complete_without_handler_is_method_not_found() -> None: """A server with no completion handler advertises no completions capability and rejects the request.""" server = Server("incomplete") diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index 056173ac2b..d46728be2e 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -32,7 +32,9 @@ } -@requirement("elicitation:form:accept") +@requirement("elicitation:form:action:accept") +@requirement("elicitation:form:basic") +@requirement("tools:call:elicitation-roundtrip") async def test_elicit_form_accepted_content_returns_to_handler() -> None: """An accepted form elicitation returns the user's content to the requesting handler. 
@@ -86,7 +88,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest ) -@requirement("elicitation:form:decline") +@requirement("elicitation:form:action:decline") async def test_elicit_form_decline_returns_no_content() -> None: """A declined form elicitation returns the decline action to the handler with no content.""" @@ -113,7 +115,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) -@requirement("elicitation:form:cancel") +@requirement("elicitation:form:action:cancel") async def test_elicit_form_cancel_returns_no_content() -> None: """A cancelled form elicitation returns the cancel action to the handler with no content.""" @@ -172,7 +174,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) -@requirement("elicitation:url:accept") +@requirement("elicitation:url:action:accept-no-content") +@requirement("elicitation:url:basic") async def test_elicit_url_delivers_url_and_returns_accept_without_content() -> None: """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it returns the action with no content. @@ -276,7 +279,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) -@requirement("elicitation:complete-notification") +@requirement("elicitation:url:complete-notification") async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client() -> None: """After a URL elicitation finishes, the server announces it with a notification carrying the same id. 
diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index c9debb06be..074a8c12c2 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -84,6 +84,10 @@ async def test_initialize_returns_instructions() -> None: @requirement("lifecycle:initialize:capabilities:from-handlers") +@requirement("tools:capability:declared") +@requirement("resources:capability:declared") +@requirement("prompts:capability:declared") +@requirement("completion:capability:declared") async def test_initialize_capabilities_reflect_registered_handlers() -> None: """Each feature area with a registered handler is advertised as a capability. @@ -253,7 +257,8 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa assert pong == snapshot(EmptyResult()) -@requirement("lifecycle:initialize:protocol-version") +@requirement("lifecycle:version:match") +@requirement("lifecycle:version:server-fallback-latest") async def test_initialize_negotiates_protocol_version() -> None: """The server echoes a supported requested version and answers an unsupported one with its latest. @@ -284,7 +289,7 @@ def initialize_request(protocol_version: str) -> InitializeRequest: assert result.protocol_version == snapshot("2025-11-25") -@requirement("lifecycle:initialize:protocol-version:client-rejects") +@requirement("lifecycle:version:reject-unsupported") async def test_unsupported_server_protocol_version_fails_initialization() -> None: """An initialize response carrying a protocol version the client does not support fails initialization. 
diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index 9bbdf7ee75..e06c6f33f6 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -25,7 +25,7 @@ pytestmark = pytest.mark.anyio -@requirement("notifications:tools:list-changed") +@requirement("tools:list-changed") async def test_tool_list_changed_notification() -> None: """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] @@ -51,7 +51,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert received == snapshot([ToolListChangedNotification()]) -@requirement("notifications:resources:list-changed") +@requirement("resources:list-changed") async def test_resource_list_changed_notification() -> None: """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] @@ -77,7 +77,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert received == snapshot([ResourceListChangedNotification()]) -@requirement("notifications:prompts:list-changed") +@requirement("prompts:list-changed") async def test_prompt_list_changed_notification() -> None: """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index 600724259f..9f9110a3cf 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -48,7 +48,8 @@ async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelReq assert result == snapshot(EmptyResult()) -@requirement("logging:message:notification") +@requirement("logging:message:fields") 
+@requirement("tools:call:logging-mid-execution") async def test_log_messages_reach_logging_callback_in_order() -> None: """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 0c585d7896..3304450a6f 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -26,7 +26,7 @@ pytestmark = pytest.mark.anyio -@requirement("pagination:cursor-round-trip") +@requirement("tools:list:pagination") async def test_next_cursor_round_trips_through_the_client() -> None: """The next_cursor a list handler returns reaches the client, and the cursor the client sends back on the following call reaches the handler verbatim. @@ -57,6 +57,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa @requirement("pagination:exhaustion") +@requirement("tools:list:pagination") async def test_paginating_until_next_cursor_is_absent_yields_every_page() -> None: """Following next_cursor until it is absent visits every page exactly once, in order.""" pages: dict[str | None, tuple[str, str | None]] = { @@ -89,7 +90,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa assert requests_made == len(pages) -@requirement("pagination:resources") +@requirement("resources:list:pagination") async def test_resources_list_supports_cursor_pagination() -> None: """resources/list round-trips the cursor like every other list operation.""" seen_cursors: list[str | None] = [] @@ -116,7 +117,7 @@ async def list_resources( assert second_page.next_cursor is None -@requirement("pagination:resource-templates") +@requirement("resources:templates:pagination") async def test_resource_templates_list_supports_cursor_pagination() -> None: """resources/templates/list round-trips the cursor like every other list operation.""" seen_cursors: list[str | None] 
= [] @@ -148,7 +149,7 @@ async def list_resource_templates( assert second_page.next_cursor is None -@requirement("pagination:prompts") +@requirement("prompts:list:pagination") async def test_prompts_list_supports_cursor_pagination() -> None: """prompts/list round-trips the cursor like every other list operation.""" seen_cursors: list[str | None] = [] diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 229f8edf6f..f39737a27f 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -20,7 +20,8 @@ pytestmark = pytest.mark.anyio -@requirement("progress:server-to-client") +@requirement("protocol:progress:callback") +@requirement("tools:call:progress") async def test_progress_during_tool_call_reaches_callback_in_order() -> None: """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" received: list[tuple[float, float | None, str | None]] = [] @@ -52,7 +53,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert received == snapshot([(1.0, 3.0, "first chunk"), (2.0, 3.0, "second chunk"), (3.0, 3.0, "done")]) -@requirement("progress:token-propagation") +@requirement("protocol:progress:token-injected") async def test_progress_token_visible_to_handler() -> None: """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" @@ -79,7 +80,7 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) -@requirement("progress:no-token") +@requirement("protocol:progress:no-token") async def test_no_progress_callback_means_no_token() -> None: """Without a progress callback the request carries no progress token. 
@@ -105,7 +106,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) -@requirement("progress:client-to-server") +@requirement("protocol:progress:client-to-server") async def test_client_progress_notification_reaches_server_handler() -> None: """A progress notification sent by the client is delivered to the server's progress handler.""" received: list[ProgressNotificationParams] = [] diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py index dd27b0f659..52ef3a85d4 100644 --- a/tests/interaction/lowlevel/test_prompts.py +++ b/tests/interaction/lowlevel/test_prompts.py @@ -65,7 +65,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest ) -@requirement("prompts:get:arguments") +@requirement("prompts:get:with-args") async def test_get_prompt_substitutes_arguments() -> None: """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 44d69209f4..5b02797020 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -99,7 +99,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq ) -@requirement("resources:read:binary") +@requirement("resources:read:blob") async def test_read_resource_binary() -> None: """Reading a binary resource returns its contents base64-encoded in the blob field.""" @@ -126,7 +126,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq ) -@requirement("resources:read:not-found") +@requirement("resources:read:unknown-uri") async def test_read_resource_unknown_uri_is_protocol_error() -> None: """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. 
diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index 221be372d5..94cd1b9303 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -15,7 +15,7 @@ pytestmark = pytest.mark.anyio -@requirement("roots:list:round-trip") +@requirement("roots:list:basic") async def test_list_roots_round_trip() -> None: """A roots/list request from a tool handler is answered by the client's roots callback. diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index b2b268d9b7..6903f86abb 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -29,7 +29,8 @@ pytestmark = pytest.mark.anyio -@requirement("sampling:create-message:round-trip") +@requirement("sampling:create:basic") +@requirement("tools:call:sampling-roundtrip") async def test_create_message_round_trip() -> None: """A handler's sampling request is answered by the client callback, and the callback's result (role, content, model, stop reason) is returned to the handler. @@ -78,7 +79,9 @@ async def sampling_callback( ) -@requirement("sampling:create-message:params") +@requirement("sampling:create:include-context") +@requirement("sampling:create:model-preferences") +@requirement("sampling:create:system-prompt") async def test_create_message_params_reach_callback() -> None: """Every sampling parameter the handler supplies arrives at the client callback unchanged.""" received: list[CreateMessageRequestParams] = [] @@ -231,7 +234,7 @@ async def sampling_callback( assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) -@requirement("sampling:create-message:client-error") +@requirement("sampling:error:user-rejected") async def test_create_message_callback_error() -> None: """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. 
@@ -294,7 +297,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) -@requirement("sampling:create-message:tools:not-supported") +@requirement("sampling:tools:server-gated-by-capability") async def test_create_message_with_tools_is_rejected_for_unsupporting_client() -> None: """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. @@ -335,7 +338,7 @@ async def sampling_callback( ) -@requirement("sampling:create-message:tools:message-constraints") +@requirement("sampling:tool-result:no-mixed-content") async def test_create_message_with_unbalanced_tool_messages_is_rejected() -> None: """A sampling request whose messages mix tool results with other content never leaves the server. diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index 4e7c64fba2..a9c83d641d 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -19,7 +19,8 @@ pytestmark = pytest.mark.anyio -@requirement("timeouts:per-request") +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") async def test_request_timeout_fails_the_pending_call() -> None: """A request whose response does not arrive within its read timeout fails with a timeout error. 
@@ -53,7 +54,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) -@requirement("timeouts:session-survives") +@requirement("protocol:timeout:session-survives") async def test_session_serves_requests_after_timeout() -> None: """A timed-out request does not poison the session: the next request succeeds.""" @@ -84,7 +85,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="still alive")])) -@requirement("timeouts:session-default") +@requirement("protocol:timeout:session-default") async def test_session_level_timeout_applies_to_every_request() -> None: """A read timeout configured on the client applies to requests that do not set their own.""" diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 81664e8a51..a2ee65109c 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -158,7 +158,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ) -@requirement("tools:list:optional-fields") +@requirement("tools:list:metadata") async def test_list_tools_optional_fields_round_trip() -> None: """Every optional Tool field the server supplies reaches the client unchanged.""" @@ -199,7 +199,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ) -@requirement("tools:call:content:multiple") +@requirement("tools:call:content:mixed") @requirement("tools:call:content:image") @requirement("tools:call:content:audio") @requirement("tools:call:content:resource-link") @@ -319,7 +319,7 @@ async def call_and_record(tag: str) -> None: ) -@requirement("tools:call:output-schema-validation") +@requirement("client:output-schema:validate") async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client() -> None: """A result whose structured content does not conform to the 
tool's declared output schema never reaches the caller: the client validates it against the schema cached from tools/list and raises. diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index d24fd62511..9ccbd8fdd8 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -27,7 +27,7 @@ @requirement("mcpserver:context:logging") -@requirement("logging:capability") +@requirement("logging:capability:declared") async def test_context_logging_helpers_send_log_notifications() -> None: """Each Context logging helper sends a log message notification at the matching severity. @@ -93,7 +93,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) -@requirement("progress:no-token") +@requirement("protocol:progress:no-token") async def test_report_progress_without_a_progress_token_sends_nothing() -> None: """When the caller supplied no progress callback, Context.report_progress is a silent no-op. @@ -125,6 +125,7 @@ async def collect(message: IncomingMessage) -> None: @requirement("mcpserver:context:elicit") +@requirement("tools:call:elicitation-roundtrip") async def test_context_elicit_returns_typed_result() -> None: """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. @@ -206,7 +207,7 @@ async def show_config(ctx: Context) -> str: ) -@requirement("logging:set-level:filtering") +@requirement("logging:message:filtered") async def test_set_logging_level_is_rejected_and_messages_are_never_filtered() -> None: """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. 
diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 3f865b077a..62c7c33558 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -20,7 +20,7 @@ pytestmark = pytest.mark.anyio -@requirement("mcpserver:prompts:decorated") +@requirement("mcpserver:prompt:decorated") async def test_list_prompts_derives_arguments_from_signature() -> None: """A decorated prompt is listed with arguments derived from the function signature. @@ -52,7 +52,7 @@ def code_review(code: str, style_guide: str = "pep8") -> str: ) -@requirement("mcpserver:prompts:decorated") +@requirement("mcpserver:prompt:decorated") async def test_get_prompt_renders_function_return() -> None: """The decorated function's string return value is rendered as a single user message.""" mcp = MCPServer("prompter") @@ -73,7 +73,7 @@ def greet(name: str) -> str: ) -@requirement("mcpserver:prompts:unknown-name") +@requirement("mcpserver:prompt:unknown-name") async def test_get_unknown_prompt_is_error() -> None: """Getting a prompt name that was never registered fails with a JSON-RPC error.""" mcp = MCPServer("prompter") @@ -90,7 +90,7 @@ def greet(name: str) -> str: assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown prompt: nope")) -@requirement("prompts:get:missing-arguments") +@requirement("prompts:get:missing-required-args") async def test_get_prompt_with_a_missing_required_argument_is_an_error() -> None: """Getting a prompt without one of its required arguments fails with a JSON-RPC error. 
diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py index 801e60663a..4ad9ed356b 100644 --- a/tests/interaction/mcpserver/test_resources.py +++ b/tests/interaction/mcpserver/test_resources.py @@ -20,7 +20,7 @@ pytestmark = pytest.mark.anyio -@requirement("mcpserver:resources:static") +@requirement("mcpserver:resource:static") async def test_read_static_resource() -> None: """A function registered for a fixed URI is served at that URI with its return value as text.""" mcp = MCPServer("library") @@ -40,7 +40,7 @@ def app_config() -> str: ) -@requirement("mcpserver:resources:static") +@requirement("mcpserver:resource:static") async def test_list_static_and_templated_resources() -> None: """Statically-registered resources appear in resources/list; templated ones only in templates/list. @@ -89,7 +89,8 @@ def user_profile(user_id: str) -> str: ) -@requirement("mcpserver:resources:template") +@requirement("mcpserver:resource:template") +@requirement("resources:read:template-vars") async def test_read_templated_resource() -> None: """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" mcp = MCPServer("library") @@ -109,7 +110,7 @@ def user_profile(user_id: str) -> str: ) -@requirement("mcpserver:resources:unknown-uri") +@requirement("mcpserver:resource:unknown-uri") async def test_read_unknown_uri_is_error() -> None: """Reading a URI that matches no registered resource fails with a JSON-RPC error. 
diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index 1724360d5e..bd63fd5e61 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -38,7 +38,7 @@ def add(a: int, b: int) -> str: assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) -@requirement("mcpserver:tools:handler-exception") +@requirement("mcpserver:tool:handler-throws") async def test_call_tool_function_exception_becomes_error_result() -> None: """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" mcp = MCPServer("errors") @@ -55,7 +55,7 @@ def explode() -> str: ) -@requirement("mcpserver:tools:handler-exception") +@requirement("mcpserver:tool:handler-throws") async def test_call_tool_tool_error_becomes_error_result() -> None: """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" mcp = MCPServer("errors") @@ -72,7 +72,7 @@ def flux() -> str: ) -@requirement("mcpserver:tools:unknown-name") +@requirement("mcpserver:tool:unknown-name") async def test_call_tool_unknown_name_returns_error_result() -> None: """Calling a tool name that was never registered is reported as an is_error result. @@ -91,7 +91,8 @@ def add() -> None: assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) -@requirement("mcpserver:tools:output-schema:model") +@requirement("mcpserver:tool:output-schema:model") +@requirement("tools:call:structured-content:text-mirror") async def test_call_tool_model_return_becomes_structured_content() -> None: """A tool returning a pydantic model advertises the model's schema as the tool's output schema and returns the model's fields as structured content alongside a serialised text block. 
@@ -138,7 +139,7 @@ def get_weather() -> Weather: ) -@requirement("mcpserver:tools:output-schema:wrapped") +@requirement("mcpserver:tool:output-schema:wrapped") async def test_call_tool_list_return_is_wrapped_in_result_key() -> None: """A tool returning a list wraps the value under a "result" key in both the generated output schema and the structured content. @@ -169,7 +170,7 @@ def primes() -> list[int]: ) -@requirement("tools:call:invalid-arguments") +@requirement("mcpserver:tool:input-validation") async def test_call_tool_invalid_arguments_become_error_result() -> None: """Arguments that fail validation against the tool's signature are reported as an is_error result describing the failure, not as a protocol error. @@ -192,7 +193,7 @@ def add(a: int, b: int) -> str: assert result.content[0].text.startswith("Error executing tool add: 1 validation error") -@requirement("mcpserver:tools:list-changed-on-mutation") +@requirement("mcpserver:register:post-connect") async def test_adding_and_removing_tools_does_not_notify_connected_clients() -> None: """Mutating the tool set on a running server changes tools/list but sends no notification. diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 929bb103ed..6ef499e8ab 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -1,25 +1,37 @@ """Enforces the contract between the requirements manifest and the test suite. -Every non-deferred entry in :data:`REQUIREMENTS` must be exercised by at least one test, and every -`@requirement(...)` mark must reference a manifest entry. Test modules are imported directly +The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be +exercised by at least one test, and every test in the suite must carry at least one +`@requirement(...)` mark referencing a manifest entry. 
Test modules are imported directly (rather than relying on pytest collection) so the check holds even when only this file is run. """ import importlib from pathlib import Path +from types import ModuleType import pytest -from tests.interaction._requirements import REQUIREMENTS, covered_by, requirement +from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement _SUITE_ROOT = Path(__file__).parent +# Tests that exercise the suite's own helpers rather than an interaction-model behaviour. +# Anything listed here is exempt from the every-test-has-a-requirement check. +_HARNESS_SELF_TESTS = { + "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", +} -def _import_all_test_modules() -> None: - """Import every test module in the suite so their `@requirement` decorators register.""" + +def _import_all_test_modules() -> list[ModuleType]: + """Import every other test module in the suite so their `@requirement` decorators register.""" + modules: list[ModuleType] = [] for path in sorted(_SUITE_ROOT.rglob("test_*.py")): relative = path.relative_to(_SUITE_ROOT).with_suffix("") - importlib.import_module(f"{__package__}.{'.'.join(relative.parts)}") + name = f"{__package__}.{'.'.join(relative.parts)}" + if name != __name__: + modules.append(importlib.import_module(name)) + return modules def test_every_requirement_is_exercised() -> None: @@ -41,7 +53,30 @@ def test_every_requirement_is_exercised() -> None: assert not stale_deferrals, f"Deferred requirements that now have tests (remove deferred): {stale_deferrals}" +def test_every_test_exercises_a_requirement() -> None: + """Each test in the suite carries at least one `@requirement` mark (harness self-tests excepted).""" + all_tests = { + f"{module.__name__}.{name}" + for module in _import_all_test_modules() + for name in vars(module) + if name.startswith("test_") + } + linked_tests = {test_name for requirement_id in REQUIREMENTS for test_name in 
covered_by(requirement_id)} + + unlinked = sorted(all_tests - linked_tests - _HARNESS_SELF_TESTS) + assert not unlinked, f"Tests with no @requirement mark: {unlinked}" + + stale_exemptions = sorted(_HARNESS_SELF_TESTS - all_tests) + assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" + + def test_unknown_requirement_id_is_rejected() -> None: """Marking a test with an ID that is not in the manifest fails at decoration time.""" with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): requirement("tools:call:does-not-exist") + + +def test_invalid_requirement_source_is_rejected() -> None: + """A requirement whose source is not a spec URL, 'sdk', or an issue reference fails at construction.""" + with pytest.raises(ValueError, match="source must be a specification URL"): + Requirement(source="https://example.com/not-the-spec", behavior="Never constructed.") From d07f01f378dbaab89df6cc3cbf47fd0781a1dba1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 14:58:24 +0000 Subject: [PATCH 12/34] test: track the full requirements surface in the interaction manifest --- tests/interaction/README.md | 14 +- tests/interaction/_requirements.py | 2582 +++++++++++++---- .../interaction/lowlevel/test_cancellation.py | 1 + tests/interaction/lowlevel/test_completion.py | 1 + tests/interaction/lowlevel/test_initialize.py | 1 + tests/interaction/lowlevel/test_ping.py | 2 + tests/interaction/test_coverage.py | 20 +- 7 files changed, 2082 insertions(+), 539 deletions(-) diff --git a/tests/interaction/README.md b/tests/interaction/README.md index 4f7e3dc1f3..487908ff0e 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -60,11 +60,15 @@ test body — each directory pins its flavour's true output exactly. 
- **`source`** is a deep link into the MCP specification for externally mandated behaviour, the literal string `"sdk"` for behaviour the SDK chose where the spec is silent, or `"issue:#n"` for a regression lock. -- **`behavior`** describes what the suite *asserts* — which is always the SDK's current - behaviour, never an aspiration. -- **`divergence`** records the gap when current behaviour differs from what `source` mandates, - with an issue link once one exists. The test still pins current behaviour. -- **`deferred`** marks a behaviour that is deliberately not covered, with the reason. +- **`behavior`** describes the *required* behaviour — what the specification (or the SDK's own + contract) says should happen. Tests always pin the SDK's current behaviour; where that falls + short of `behavior`, the gap is recorded as data rather than hidden in the test. +- **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. +- **`deferred`** marks a behaviour that is tracked but not yet covered by a test in this suite. + The reason names the covering tests elsewhere in the repo, starts with "Not implemented in the + SDK" for genuine feature gaps, or starts with "Not yet covered here" for tests that are planned. +- **`transports`** names the transports a behaviour applies to; omitted means transport-independent. +- **`issue`** carries the tracking link for a recorded gap once one is filed. Tests link themselves to the manifest with a decorator: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index fb502d96c3..9249a51cf4 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -11,10 +11,16 @@ `sdk` -- a behavioural guarantee the SDK chose; not spec-mandated `issue:#n` -- regression lock-in for a previously fixed bug -The `behavior` sentence describes what the suite *asserts* -- which is always the SDK's current -behaviour. 
Where that differs from what `source` mandates, the gap is recorded in `divergence` -and the tests still pin current behaviour: this suite is the parity bar for the receive-path -rewrite, so a test that fails today proves nothing about equivalence. +The `behavior` sentence describes the REQUIRED behaviour -- what the specification (or the SDK's +own contract) says should happen. Tests always pin the SDK's current behaviour. Where current +behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose +tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered +by a test in this suite. `issue` carries the tracking link for a recorded gap once one is filed. + +`deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this +repo the reason names the covering test path; where the SDK does not implement the behaviour at +all the reason starts with "Not implemented in the SDK"; and where an interaction-level test is +planned but not yet written the reason starts with "Not yet covered here". `transports` records which transports a behaviour applies to (or is observable on); None means the behaviour is transport-independent. @@ -40,6 +46,11 @@ _SOURCE_PATTERN = re.compile(r"https://modelcontextprotocol\.io/specification/.+|sdk|issue:#\d+") +_TASKS_DEFERRAL = ( + "Tasks are experimental and the spec is being substantially revised; python task behaviour is " + "covered by tests/experimental/tasks/ until the next spec revision settles." +) + @dataclass(frozen=True, kw_only=True) class Divergence: @@ -58,6 +69,7 @@ class Requirement: transports: tuple[Transport, ...] 
| None = None divergence: Divergence | None = None deferred: str | None = None + issue: str | None = None def __post_init__(self) -> None: if not _SOURCE_PATTERN.fullmatch(self.source): @@ -66,49 +78,42 @@ def __post_init__(self) -> None: REQUIREMENTS: dict[str, Requirement] = { # ═══════════════════════════════════════════════════════════════════════════ - # Protocol primitives + # Lifecycle & version negotiation # ═══════════════════════════════════════════════════════════════════════════ - "protocol:request-id:unique": Requirement( - source=f"{SPEC_BASE_URL}/basic#requests", + "lifecycle:capability:client-not-declared": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", behavior=( - "Every request sent on a session carries a unique, non-null integer id; ids are never reused " - "within the session." + "The client rejects sending notifications or registering handlers for capabilities it did not declare." + ), + deferred=( + "Not implemented in the SDK: the client does not check its own declared capabilities before " + "sending notifications or serving callbacks." ), ), - "protocol:notifications:no-response": Requirement( - source=f"{SPEC_BASE_URL}/basic#notifications", + "lifecycle:capability:server-not-advertised": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#operation", behavior=( - "Notifications are never answered: every message the server delivers is either the response " - "to a request the client sent or a notification carrying no id." + "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." 
), - ), - "protocol:error:internal-error": Requirement( - source=f"{SPEC_BASE_URL}/basic#responses", - behavior="An unhandled exception in a request handler is returned to the caller as a JSON-RPC error.", - divergence=Divergence( - note=( - "The spec reserves -32603 Internal error for this; the low-level Server returns code 0 " - "(not a defined JSON-RPC code) and leaks str(exc) as the error message." - ), + deferred=( + "Not implemented in the SDK: the client sends any request regardless of the server's " + "advertised capabilities and surfaces whatever the server answers." ), ), - "protocol:error:method-not-found": Requirement( - source=f"{SPEC_BASE_URL}/basic#responses", - behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + "lifecycle:initialize:basic": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Connecting sends initialize with the protocol version, client capabilities, and client " + "info; the server responds with its own and the connection is established." + ), ), - # ═══════════════════════════════════════════════════════════════════════════ - # Lifecycle - # ═══════════════════════════════════════════════════════════════════════════ "lifecycle:initialize:server-info": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", behavior="The initialize result identifies the server: name and version, plus title when declared.", ), "lifecycle:initialize:instructions": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", - behavior=( - "Server-declared instructions are returned in the initialize result, and omitted when the " - "server declares none." 
- ), + behavior="A server may include an instructions string in the initialize result; the client exposes it.", ), "lifecycle:initialize:capabilities:from-handlers": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", @@ -132,15 +137,67 @@ def __post_init__(self) -> None: "(sampling, elicitation, roots)." ), ), + "lifecycle:initialized-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "After successful initialization, the client sends exactly one initialized notification, " + "before any non-ping request." + ), + ), + "lifecycle:ping": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="ping in either direction returns an empty result.", + ), + "ping:client-to-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A client-initiated ping receives an empty result from the server.", + ), + "ping:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", + behavior="A server-initiated ping receives an empty result from the client.", + ), + "lifecycle:requests-before-initialized": Requirement( + source="sdk", + behavior=( + "A request other than ping sent before the initialization handshake completes is rejected with an error." + ), + ), + "lifecycle:pre-initialization-ordering": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + behavior=( + "Before initialization completes, the client sends no requests other than pings, and the " + "server sends no requests other than pings and logging." + ), + deferred=( + "Not yet covered here: the sender-side restraint (especially the server half — no sampling, " + "elicitation, or roots requests before the initialized notification) has no test yet." 
+ ), + ), + "lifecycle:version:downgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", + behavior=( + "When the server returns an older supported protocol version, the client downgrades to it " + "and the connection succeeds at that version." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here: observing the negotiated version requires the MCP-Protocol-Version " + "request header, which only exists on the HTTP transport; planned with the transport " + "conformance work." + ), + ), "lifecycle:version:match": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", - behavior="The server echoes a requested protocol version it supports in the initialize result.", + behavior=( + "When the server supports the requested protocol version it echoes that version in the " + "initialize result, and the connection proceeds at that version." + ), ), "lifecycle:version:server-fallback-latest": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", behavior=( "An initialize request carrying a protocol version the server does not support is answered " - "with the server's latest supported version rather than an error." + "with another version the server supports — the latest one — rather than an error." ), ), "lifecycle:version:reject-unsupported": Requirement( @@ -150,25 +207,44 @@ def __post_init__(self) -> None: "support fails initialization with an error rather than proceeding with the session." 
), ), - "lifecycle:requests-before-initialized": Requirement( - source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", - behavior="A request sent before the initialization handshake completes is rejected with an error.", + # ═══════════════════════════════════════════════════════════════════════════ + # Protocol primitives: cancellation, timeout, progress, errors, _meta + # ═══════════════════════════════════════════════════════════════════════════ + "protocol:request-id:unique": Requirement( + source=f"{SPEC_BASE_URL}/basic#requests", + behavior=( + "Every request sent on a session carries a unique, non-null string or integer id; ids are " + "never reused within the session." + ), ), - "lifecycle:initialized-notification": Requirement( - source=f"{SPEC_BASE_URL}/basic/lifecycle#initialization", + "protocol:notifications:no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic#notifications", behavior=( - "The client sends exactly one initialized notification, after the initialize response and " - "before its first feature request." + "Notifications are never answered: every message the server delivers is either the response " + "to a request the client sent or a notification carrying no id." ), ), - # ═══════════════════════════════════════════════════════════════════════════ - # Cancellation - # ═══════════════════════════════════════════════════════════════════════════ + "protocol:cancel:abort-signal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#cancellation-flow", + behavior=( + "Cancelling an in-flight request through the client API sends notifications/cancelled with " + "the request id and fails the local call." + ), + deferred=( + "Not implemented in the SDK: there is no public client-side API to cancel an in-flight " + "request; cancellation requires hand-constructing the notification (which is how " + "protocol:cancel:in-flight exercises the receiving side)." 
+ ), + ), + "protocol:cancel:handler-abort-propagates": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior="On the receiving side, a cancellation notification stops the running request handler.", + ), "protocol:cancel:in-flight": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior=( "A cancellation notification for an in-flight request stops the server-side handler, and the " - "caller's pending request fails with an error response." + "receiver does not send a response for the cancelled request." ), divergence=Divergence( note=( @@ -178,16 +254,29 @@ def __post_init__(self) -> None: ), ), ), - "protocol:cancel:server-survives": Requirement( + "protocol:cancel:initialize-not-cancellable": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", - behavior="The session continues to serve new requests after an earlier request was cancelled.", + behavior="The client never sends notifications/cancelled for the initialize request.", + deferred=( + "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " + "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." + ), ), - "protocol:cancel:unknown-id-ignored": Requirement( + "protocol:cancel:late-response-ignored": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior=( - "A cancellation notification referencing an unknown or already-completed request is ignored without error." + "A response that arrives after the sender issued notifications/cancelled is ignored; the " + "request stays failed and no error is raised." + ), + deferred=( + "Not yet covered here: needs the scripted-peer wire pattern to deliver a response after a " + "cancellation; today the receive loop logs an unknown-request-id error for such responses." 
), ), + "protocol:cancel:server-survives": Requirement( + source="sdk", + behavior="The session continues to serve new requests after an earlier request was cancelled.", + ), "protocol:cancel:server-to-client": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior=( @@ -195,16 +284,76 @@ def __post_init__(self) -> None: "cancels it, and the client stops processing the cancelled request." ), deferred=( - "Not expressible through the public API: abandoning a server-side send_request emits no " - "cancellation notification (the same sender-side gap recorded on " - "protocol:timeout:sends-cancellation), and the client could not act on one anyway because " - "client callbacks run inline in the receive loop, so a cancellation would not even be read " - "until the callback had already finished." + "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " + "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " + "the client could not act on one anyway because client callbacks run inline in the receive " + "loop, so a cancellation would not even be read until the callback had already finished." ), ), - # ═══════════════════════════════════════════════════════════════════════════ - # Progress - # ═══════════════════════════════════════════════════════════════════════════ + "protocol:cancel:unknown-id-ignored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", + behavior=( + "The receiver silently ignores a cancellation notification referencing an unknown or " + "already-completed request id; no error response is sent and no exception is raised." 
+ ), + ), + "protocol:cancel:sender-targeting": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", + behavior=( + "Cancellation notifications reference only requests that were previously issued in the same " + "direction and are believed to still be in flight." + ), + deferred=( + "Not yet covered here: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin yet." + ), + ), + "protocol:error:connection-closed": Requirement( + source="sdk", + behavior="Closing the transport fails all in-flight requests with a connection-closed error.", + deferred=( + "Not yet covered here: planned gap test (close the transport while a request is in flight and " + "pin the error the caller receives)." + ), + ), + "protocol:error:internal-error": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior=( + "An unhandled exception in a request handler is returned to the caller as JSON-RPC error " + "-32603 Internal error." + ), + divergence=Divergence( + note=( + "The low-level Server returns code 0 (not a defined JSON-RPC code) instead of -32603 and " + "leaks str(exc) as the error message." + ), + ), + ), + "protocol:error:invalid-params": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", + deferred=( + "Not yet covered here: the typed client API cannot send malformed params; needs a request " + "driven one level below it (planned gap test)." 
+ ), + ), + "protocol:error:method-not-found": Requirement( + source=f"{SPEC_BASE_URL}/basic#responses", + behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", + ), + "protocol:meta:related-task": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="Messages may carry related-task _meta associating them with a task.", + deferred=_TASKS_DEFERRAL, + ), + "meta:request-to-handler": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object the client attaches to a request is visible to the server handler.", + ), + "meta:result-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic#_meta", + behavior="The _meta object a handler attaches to its result is delivered to the client.", + ), "protocol:progress:callback": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=( @@ -219,6 +368,32 @@ def __post_init__(self) -> None: "server-side handler can observe in its request metadata." ), ), + "protocol:progress:token-unique": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), + deferred=( + "Not yet covered here: planned gap test (two concurrent requests with progress callbacks, " + "asserting their tokens differ and each callback only sees its own notifications)." + ), + ), + "protocol:progress:monotonic": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", + behavior=( + "The progress value increases with each notification for a given token, even when the total is unknown." + ), + deferred=( + "Not implemented in the SDK: progress values are not validated anywhere; a handler can emit " + "non-increasing values and they are forwarded as-is." 
+ ), + ), + "protocol:progress:stops-after-completion": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", + behavior="Progress notifications for a token stop once the associated request completes.", + deferred=( + "Not yet covered here: needs a test that a handler reporting progress after its request " + "completed produces no further notifications for the caller." + ), + ), "protocol:progress:no-token": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=( @@ -230,9 +405,6 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior="A progress notification sent by the client is delivered to the server's progress handler.", ), - # ═══════════════════════════════════════════════════════════════════════════ - # Timeouts - # ═══════════════════════════════════════════════════════════════════════════ "protocol:timeout:basic": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", behavior=( @@ -240,14 +412,31 @@ def __post_init__(self) -> None: "waiting forever for the response." ), ), + "protocol:timeout:max-total": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + deferred=( + "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " + "read timeout exists." + ), + ), + "protocol:timeout:reset-on-progress": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", + behavior="When configured to do so, each progress notification resets the request's read timeout.", + deferred=( + "Not implemented in the SDK: progress notifications do not reset the request read timeout and " + "no option exists to enable that." 
+ ), + ), "protocol:timeout:sends-cancellation": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", - behavior="A request that times out fails the caller; the server handler is not cancelled and keeps running.", + behavior=( + "When a request times out, the sender issues notifications/cancelled for that request before " + "failing the local call." + ), divergence=Divergence( note=( - "The spec says the requester SHOULD issue a cancellation notification for the timed-out " - "request; the client only raises locally and sends nothing, so the server keeps running " - "the handler." + "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." ), ), ), @@ -260,136 +449,51 @@ def __post_init__(self) -> None: behavior="A session-level read timeout applies to every request that does not override it.", ), # ═══════════════════════════════════════════════════════════════════════════ - # Pagination - # ═══════════════════════════════════════════════════════════════════════════ - "tools:list:pagination": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", - behavior=( - "The nextCursor returned by a list handler reaches the client, and the cursor the client " - "sends back on the next call reaches the handler as an opaque string." - ), - ), - "pagination:exhaustion": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", - behavior=( - "Following nextCursor until it is absent yields every page exactly once; a result without " - "nextCursor ends the sequence." 
- ), - ), - "resources:list:pagination": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", - behavior="resources/list supports cursor pagination.", - ), - "resources:templates:pagination": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", - behavior="resources/templates/list supports cursor pagination.", - ), - "prompts:list:pagination": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", - behavior="prompts/list supports cursor pagination.", - ), - # ═══════════════════════════════════════════════════════════════════════════ - # Request metadata - # ═══════════════════════════════════════════════════════════════════════════ - "meta:request-to-handler": Requirement( - source=f"{SPEC_BASE_URL}/basic#_meta", - behavior="The _meta object the client attaches to a request is visible to the server handler.", - ), - "meta:result-to-client": Requirement( - source=f"{SPEC_BASE_URL}/basic#_meta", - behavior="The _meta object a handler attaches to its result is delivered to the client.", - ), - # ═══════════════════════════════════════════════════════════════════════════ - # Ping - # ═══════════════════════════════════════════════════════════════════════════ - "ping:client-to-server": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", - behavior="A client-initiated ping receives an empty result from the server.", - ), - "ping:server-to-client": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/ping#behavior-requirements", - behavior="A server-initiated ping receives an empty result from the client.", - ), - # ═══════════════════════════════════════════════════════════════════════════ # Tools # ═══════════════════════════════════════════════════════════════════════════ - "tools:capability:declared": Requirement( - source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", - 
behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", - ), - "tools:list:basic": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#listing-tools", - behavior="tools/list returns the registered tools with name, description, and inputSchema.", - ), - "tools:list:metadata": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#tool", - behavior=( - "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " - "are delivered to the client unchanged." - ), - ), - "tools:call:content:text": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#text-content", - behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", - ), - "tools:call:content:image": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#image-content", - behavior="A tool result can carry image content: base64 data with a mimeType.", - ), "tools:call:content:audio": Requirement( source=f"{SPEC_BASE_URL}/server/tools#audio-content", behavior="A tool result can carry audio content: base64 data with a mimeType.", ), - "tools:call:content:resource-link": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#resource-links", - behavior="A tool result can carry a resource_link content block referencing a resource by URI.", - ), "tools:call:content:embedded-resource": Requirement( source=f"{SPEC_BASE_URL}/server/tools#embedded-resources", behavior="A tool result can carry an embedded resource with full text or blob contents.", ), + "tools:call:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#image-content", + behavior="A tool result can carry image content: base64 data with a mimeType.", + ), "tools:call:content:mixed": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#calling-tools", + source=f"{SPEC_BASE_URL}/server/tools#tool-result", behavior="A tool result can carry multiple content blocks of different types; order is preserved.", ), - 
"tools:call:structured-content": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#structured-content", - behavior="A tool result can carry structuredContent alongside content; the client receives both.", + "tools:call:content:resource-link": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior="A tool result can carry a resource_link content block referencing a resource by URI.", ), - "tools:call:structured-content:text-mirror": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#structured-content", - behavior="A tool returning structured content also returns the serialized JSON as a text content block.", - ), - "tools:call:is-error": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#error-handling", - behavior=( - "A tool execution failure is returned as a result with isError true and the failure described " - "in content, not as a JSON-RPC error." - ), - ), - "tools:call:unknown-name": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#error-handling", - behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + "tools:call:content:text": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#text-content", + behavior="tools/call delivers arguments to the tool handler and returns its text content to the caller.", ), "tools:call:concurrent": Requirement( - source=f"{SPEC_BASE_URL}/basic#requests", + source="sdk", behavior=( "Multiple tool calls in flight on one session are dispatched concurrently, and each caller " "receives the response to its own request." ), ), "tools:call:elicitation-roundtrip": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + source=f"{SPEC_BASE_URL}/client/elicitation#user-interaction-model", behavior=( "A tool handler that issues an elicitation receives the client's result and can embed it in " "the tool call result." 
), ), - "tools:call:sampling-roundtrip": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + "tools:call:is-error": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior=( - "A tool handler that issues a sampling request receives the client's completion and can embed " - "it in the tool call result." + "A tool execution failure is returned as a result with isError true and the failure described " + "in content, not as a JSON-RPC error." ), ), "tools:call:logging-mid-execution": Requirement( @@ -406,96 +510,256 @@ def __post_init__(self) -> None: "the tool result returns." ), ), + "tools:call:sampling-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + behavior=( + "A tool handler that issues a sampling request receives the client's completion and can embed " + "it in the tool call result." + ), + ), + "tools:call:structured-content": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool result can carry structuredContent alongside content; the client receives both.", + ), + "tools:call:structured-content:text-mirror": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#structured-content", + behavior="A tool returning structured content also returns the serialized JSON as a text content block.", + ), + "tools:call:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name the server does not recognise returns a JSON-RPC error.", + ), + "tools:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#capabilities", + behavior="A server with a list_tools handler advertises the tools capability in its initialize result.", + ), + "tools:input-schema:json-schema-2020-12": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " + "is discoverable and 
callable." + ), + deferred=( + "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " + "interaction-level passthrough test is planned with the gap batch." + ), + ), + "tools:input-schema:preserve-additional-properties": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema additionalProperties as registered.", + deferred=( + "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " + "interaction-level passthrough test is planned with the gap batch." + ), + ), + "tools:input-schema:preserve-defs": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves inputSchema $defs as registered.", + deferred=( + "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " + "interaction-level passthrough test is planned with the gap batch." + ), + ), + "tools:input-schema:preserve-schema-dialect": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", + deferred=( + "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " + "interaction-level passthrough test is planned with the gap batch." + ), + ), + "tools:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", + behavior=( + "When the tool set changes, a server that declared the tools listChanged capability sends " + "notifications/tools/list_changed and it reaches the client's handler." 
+ ), + ), + "tools:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#listing-tools", + behavior="tools/list returns the registered tools with name, description, and inputSchema.", + ), + "tools:list:metadata": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool", + behavior=( + "Optional Tool fields supplied by the server (title, annotations, outputSchema, icons, _meta) " + "are delivered to the client unchanged." + ), + ), + "tools:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "tools/list supports cursor pagination: the nextCursor returned by a list handler round-trips " + "back to the handler as an opaque cursor until the listing is exhausted." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tools: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "client:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="The client skips structured-content validation when the tool result has isError true.", + deferred=( + "Not yet covered here: planned gap test (an isError result with mismatching structuredContent " + "is returned to the caller rather than rejected)." + ), + ), "client:output-schema:validate": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#tool-result", + source=f"{SPEC_BASE_URL}/server/tools#output-schema", behavior=( "A tool result whose structuredContent does not conform to the tool's declared outputSchema " "is rejected by the client: the call raises instead of returning the invalid result." 
), ), - # ═══════════════════════════════════════════════════════════════════════════ - # Completion - # ═══════════════════════════════════════════════════════════════════════════ - "completion:capability:declared": Requirement( - source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", - behavior="A server with a completion handler advertises the completions capability in its initialize result.", + "mcpserver:output-schema:missing-structured": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior="A tool with an output schema whose function returns no structured content produces a server error.", + deferred="Not yet covered here: planned gap test (output schema declared but no structured content returned).", ), - "completion:prompt-arg": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", - behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + "mcpserver:output-schema:server-validate": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#output-schema", + behavior=( + "MCPServer validates structured content against the tool's output schema before returning; a " + "mismatch produces a server error." 
+ ), + deferred="Not yet covered here: planned gap test (server-side output schema validation failure).", ), - "completion:resource-template-arg": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", - behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + "mcpserver:output-schema:skip-on-error": Requirement( + source="sdk", + behavior="Server-side output schema validation is skipped when the tool returns an isError result.", + deferred="Not yet covered here: planned gap test (isError results bypass server-side schema validation).", ), - "completion:context-arguments": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", - behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + "mcpserver:tool:duplicate-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#tool-names", + behavior="Registering a tool with a name already in use is rejected at registration time.", + deferred="Not yet covered here: planned gap test (duplicate tool registration).", ), - "completion:complete:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + "mcpserver:tool:extra": Requirement( + source="sdk", behavior=( - "A server with no completion handler does not advertise the completions capability and rejects " - "completion/complete with METHOD_NOT_FOUND." + "Tool functions can access request metadata (request id, client params, session, lifespan " + "state) through the Context parameter." 
), + deferred="Not yet covered here: planned gap test (Context request-metadata access from inside a tool).", ), - # ═══════════════════════════════════════════════════════════════════════════ - # Logging - # ═══════════════════════════════════════════════════════════════════════════ - "logging:set-level": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", - behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + "mcpserver:tool:handler-throws": Requirement( + source="sdk", + behavior=( + "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " + "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." + ), ), - "logging:message:fields": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + "mcpserver:tool:input-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", behavior=( - "A log message sent by a server handler is delivered to the client's logging callback with its " - "severity level, logger name, and data, in the order the server sent them." + "Arguments that fail the tool's input validation produce a tool execution error (isError true " + "with the validation failure described in content) without invoking the function." 
), ), - "logging:message:all-levels": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", - behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + "mcpserver:tool:naming-validation": Requirement( + source="sdk", + behavior="Tool names that violate the spec's naming rules are rejected at registration time.", + deferred="Not yet covered here: tool-name validation at registration has not been pinned yet.", ), - "logging:capability:declared": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + "mcpserver:tool:output-schema:model": Requirement( + source="sdk", behavior=( - "MCPServer tools emit log message notifications through the Context helpers while the server's " - "advertised capabilities omit logging." + "A tool returning a typed model advertises a matching generated outputSchema and returns the " + "model's fields as structuredContent alongside a serialised text block." ), - divergence=Divergence( - note=( - "The spec says servers that emit log message notifications MUST declare the logging " - "capability; MCPServer registers no setLevel handler, so capability derivation leaves " - "logging unset even though the Context helpers send the notifications." - ), + ), + "mcpserver:tool:output-schema:wrapped": Requirement( + source="sdk", + behavior=( + "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " + "structuredContent, with a matching generated outputSchema." ), ), - "logging:message:filtered": Requirement( - source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + "mcpserver:tool:schema-variants": Requirement( + source="sdk", behavior=( - "MCPServer registers no logging/setLevel handler (the request is rejected with method-not-found) " - "and log messages are delivered at every severity regardless of any requested level." 
+ "Tool input schemas generated from complex parameter types (unions, nested models, " + "constrained types) validate and coerce arguments before the function runs." + ), + deferred=( + "Not yet covered here: planned gap test (complex parameter types validated and coerced before " + "the function runs)." ), + ), + "mcpserver:tool:unknown-name": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#error-handling", + behavior="tools/call for a name that was never registered returns a JSON-RPC error.", divergence=Divergence( note=( - "The spec says servers SHOULD only send log messages at or above the level the client " - "configured via logging/setLevel. Neither MCPServer (which rejects the request outright) " - "nor the low-level Server (which leaves the handler entirely to the author) implements " - "any filtering." + "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " + "params); MCPServer reports a tool execution error (isError true) instead. The low-level " + "path follows the spec example (see tools:call:unknown-name)." ), ), ), + "mcpserver:tool:url-elicitation-error": Requirement( + source="sdk", + behavior=( + "A tool function that raises the URL-elicitation-required error surfaces to the caller as " + "error -32042 with the elicitation parameters intact." + ), + deferred=( + "Not yet covered here: the low-level equivalent is pinned by elicitation:url:required-error; " + "the MCPServer-decorated path is a planned gap test." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # MCPServer: Context helpers (SDK) + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:context:logging": Requirement( + source="sdk", + behavior=( + "The Context logging helpers (debug/info/warning/error) send log message notifications at the " + "corresponding severity." 
+ ), + ), + "mcpserver:context:progress": Requirement( + source="sdk", + behavior=( + "Context.report_progress sends a progress notification against the requesting client's progress token." + ), + ), + "mcpserver:context:elicit": Requirement( + source="sdk", + behavior=( + "Context.elicit sends a form elicitation built from a typed schema and returns a typed " + "accepted/declined/cancelled result." + ), + ), + "mcpserver:context:read-resource": Requirement( + source="sdk", + behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + ), # ═══════════════════════════════════════════════════════════════════════════ # Resources # ═══════════════════════════════════════════════════════════════════════════ + "resources:annotations": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#annotations", + behavior=( + "Resource annotations (audience, priority, lastModified) supplied by the server round-trip to " + "the client in list and read results." + ), + deferred="Not yet covered here: planned gap test (annotations passthrough on list and read results).", + ), "resources:capability:declared": Requirement( - source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + source=f"{SPEC_BASE_URL}/server/resources#capabilities", behavior=( "A server with resource handlers advertises the resources capability, including the subscribe " "sub-flag when a subscribe handler is registered." ), ), + "resources:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", + behavior=( + "When the resource set changes, a server that declared the resources listChanged capability " + "sends notifications/resources/list_changed and it reaches the client's handler." + ), + ), "resources:list:basic": Requirement( source=f"{SPEC_BASE_URL}/server/resources#listing-resources", behavior=( @@ -503,17 +767,48 @@ def __post_init__(self) -> None: "fields supplied by the server." 
), ), - "resources:read:text": Requirement( - source=f"{SPEC_BASE_URL}/server/resources#reading-resources", - behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + "resources:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/list supports cursor pagination.", ), "resources:read:blob": Requirement( source=f"{SPEC_BASE_URL}/server/resources#reading-resources", behavior="resources/read returns binary contents base64-encoded in blob.", ), + "resources:read:template-vars": Requirement( + source="sdk", + behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", + ), + "resources:read:text": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#reading-resources", + behavior="resources/read returns text contents carrying uri, mimeType, and the text.", + ), "resources:read:unknown-uri": Requirement( source=f"{SPEC_BASE_URL}/server/resources#error-handling", - behavior="resources/read for an unknown URI returns a JSON-RPC error; the spec reserves -32002 for it.", + behavior="resources/read for an unknown URI returns JSON-RPC error -32002 (resource not found).", + ), + "resources:subscribe": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", + ), + "resources:subscribe:capability-required": Requirement( + source="sdk", + behavior=( + "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." + ), + deferred=( + "Not yet covered here: planned gap test (subscribe rejected with METHOD_NOT_FOUND when no " + "subscribe handler is registered)." 
+ ), + ), + "resources:subscribe:updated": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#subscriptions", + behavior="After resources/subscribe, changes to that resource send notifications/resources/updated.", + deferred=( + "Not implemented in the SDK: the server keeps no subscription state linking subscribe to " + "updated notifications; emitting updates is entirely handler code. The two halves are pinned " + "separately by resources:subscribe and resources:updated-notification." + ), ), "resources:templates:list": Requirement( source=f"{SPEC_BASE_URL}/server/resources#resource-templates", @@ -521,13 +816,9 @@ def __post_init__(self) -> None: "resources/templates/list returns the registered templates with their uriTemplate and descriptive fields." ), ), - "resources:read:template-vars": Requirement( - source=f"{SPEC_BASE_URL}/server/resources#resource-templates", - behavior="Variables extracted from a templated resource URI reach the resource function as typed arguments.", - ), - "resources:subscribe": Requirement( - source=f"{SPEC_BASE_URL}/server/resources#subscriptions", - behavior="resources/subscribe delivers the URI to the server's subscribe handler and returns an empty result.", + "resources:templates:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="resources/templates/list supports cursor pagination.", ), "resources:unsubscribe": Requirement( source=f"{SPEC_BASE_URL}/server/resources#subscriptions", @@ -552,57 +843,243 @@ def __post_init__(self) -> None: ), ), # ═══════════════════════════════════════════════════════════════════════════ - # Notifications: list_changed (server → client) + # Resources: SDK guarantees # ═══════════════════════════════════════════════════════════════════════════ - "tools:list-changed": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", - behavior="A tools/list_changed notification sent by the server reaches 
the client's message handler.", + "mcpserver:resource:duplicate-name": Requirement( + source="sdk", + behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", + deferred="Not yet covered here: planned gap test (duplicate resource registration).", ), - "resources:list-changed": Requirement( - source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", - behavior="A resources/list_changed notification sent by the server reaches the client's message handler.", + "mcpserver:resource:read-throws-surfaced": Requirement( + source="sdk", + behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", + deferred="Not yet covered here: planned gap test (resource function raising during read).", ), - "prompts:list-changed": Requirement( - source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", - behavior="A prompts/list_changed notification sent by the server reaches the client's message handler.", + "mcpserver:resource:static": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " + "served by resources/read at that URI." + ), ), - # ═══════════════════════════════════════════════════════════════════════════ - # Prompts - # ═══════════════════════════════════════════════════════════════════════════ - "prompts:capability:declared": Requirement( - source=f"{SPEC_BASE_URL}/basic/lifecycle#capability-negotiation", + "mcpserver:resource:template": Requirement( + source="sdk", + behavior=( + "A function registered with a URI template is listed by resources/templates/list and matched " + "by resources/read, receiving the parameters extracted from the requested URI." 
+ ), + ), + "mcpserver:resource:unknown-uri": Requirement( + source=f"{SPEC_BASE_URL}/server/resources#error-handling", + behavior="resources/read for a URI matching no registered resource returns JSON-RPC error -32002.", + divergence=Divergence( + note=( + "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " + "the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts + # ═══════════════════════════════════════════════════════════════════════════ + "prompts:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#capabilities", behavior="A server with a list_prompts handler advertises the prompts capability in its initialize result.", ), - "prompts:list:basic": Requirement( - source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", - behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + "prompts:get:content:audio": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#audio-content", + behavior="Prompt messages may contain audio content with base64 data and a mimeType.", + deferred="Not yet covered here: planned gap test (audio content in prompt messages).", ), - "prompts:get:with-args": Requirement( - source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", - behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + "prompts:get:content:embedded-resource": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", + behavior="Prompt messages may contain embedded resource content.", + deferred="Not yet covered here: planned gap test (embedded resources in prompt messages).", + ), + "prompts:get:content:image": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#image-content", + behavior="Prompt messages may contain image content.", + deferred="Not yet covered here: planned gap test 
(image content in prompt messages).", + ), + "prompts:get:missing-required-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#error-handling", + behavior="prompts/get omitting a required argument returns JSON-RPC error -32602 (Invalid params).", + divergence=Divergence( + note=( + "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " + "which the low-level server converts to error code 0 with the exception text as the message." + ), + ), ), "prompts:get:multi-message": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", behavior="A prompt can return multiple messages mixing user and assistant roles; order is preserved.", ), + "prompts:get:no-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get with no arguments returns the prompt's messages.", + deferred="Not yet covered here: planned gap test (argument-free prompt fetched without arguments).", + ), "prompts:get:unknown-name": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#error-handling", - behavior="prompts/get for an unknown prompt name returns a JSON-RPC error.", + behavior="prompts/get for an unknown prompt name returns JSON-RPC error -32602 (Invalid params).", ), - "prompts:get:missing-required-args": Requirement( + "prompts:get:with-args": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", + behavior="prompts/get delivers the supplied arguments to the prompt handler and returns its messages.", + ), + "prompts:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", + behavior=( + "When the prompt set changes, a server that declared the prompts listChanged capability sends " + "notifications/prompts/list_changed and it reaches the client's handler." 
+ ), + ), + "prompts:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#listing-prompts", + behavior="prompts/list returns the registered prompts with name, description, and argument declarations.", + ), + "prompts:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#operations-supporting-pagination", + behavior="prompts/list supports cursor pagination.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Prompts: SDK guarantees + # ═══════════════════════════════════════════════════════════════════════════ + "mcpserver:prompt:args-validation": Requirement( + source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", + behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", + deferred="Not yet covered here: planned gap test (argument validation on decorated prompts).", + ), + "mcpserver:prompt:decorated": Requirement( + source="sdk", + behavior=( + "A function registered with @mcp.prompt() is listed with arguments derived from its signature " + "and rendered into prompt messages by prompts/get." 
+ ), + ), + "mcpserver:prompt:duplicate-name": Requirement( + source="sdk", + behavior="Registering a duplicate prompt name is rejected at registration time.", + deferred="Not yet covered here: planned gap test (duplicate prompt registration).", + ), + "mcpserver:prompt:optional-args": Requirement( + source="sdk", + behavior="A prompt with optional arguments can be fetched without supplying them.", + deferred="Not yet covered here: planned gap test (optional prompt arguments omitted).", + ), + "mcpserver:prompt:unknown-name": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#error-handling", - behavior="prompts/get with a required argument missing returns a JSON-RPC error.", + behavior="prompts/get for a name that was never registered returns JSON-RPC error -32602 (Invalid params).", divergence=Divergence( note=( - "The spec says missing required arguments are answered with -32602 Invalid params; " - "MCPServer's prompt renderer raises a plain ValueError before the prompt function runs, " - "which the low-level server converts to error code 0 with the exception text as the message." + "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " + "ValueError, which the low-level server converts to error code 0." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Completion + # ═══════════════════════════════════════════════════════════════════════════ + "completion:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior="A server with a completion handler advertises the completions capability in its initialize result.", + ), + "completion:complete:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#capabilities", + behavior=( + "A server with no completion handler does not advertise the completions capability and rejects " + "completion/complete with METHOD_NOT_FOUND." 
+ ), + ), + "completion:context-arguments": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#requesting-completions", + behavior="Previously-resolved argument values supplied in context.arguments reach the completion handler.", + ), + "completion:error:invalid-ref": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#error-handling", + behavior=( + "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " + "JSON-RPC error -32602 (Invalid params)." + ), + deferred="Not yet covered here: planned gap test (completion against an unknown ref).", + ), + "completion:prompt-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/prompt returns suggested values for the named prompt argument.", + ), + "completion:resource-template-arg": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", + behavior="completion/complete with a ref/resource returns suggested values for a URI template variable.", + ), + "completion:result-shape": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/completion#completion-results", + behavior="The completion result carries values (at most 100), an optional total, and an optional hasMore flag.", + ), + "mcpserver:completion:capability-auto": Requirement( + source="sdk", + behavior=( + "MCPServer advertises the completions capability when at least one completion source is " + "registered, and omits it otherwise." 
+ ), + deferred="Not yet covered here: planned gap test (automatic completions capability derivation).", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Logging + # ═══════════════════════════════════════════════════════════════════════════ + "logging:capability:declared": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#capabilities", + behavior=( + "A server that emits log message notifications declares the logging capability in its initialize result." + ), + divergence=Divergence( + note=( + "MCPServer registers no setLevel handler, so capability derivation leaves logging unset " + "even though the Context helpers send log message notifications." + ), + ), + ), + "logging:message:all-levels": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-levels", + behavior="All eight RFC 5424 severity levels are deliverable as log message notifications.", + ), + "logging:message:fields": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", + behavior=( + "A log message sent by a server handler is delivered to the client's logging callback with its " + "severity level, logger name, and data, in the order the server sent them." + ), + ), + "logging:message:filtered": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="After logging/setLevel, log messages below the configured level are not sent.", + divergence=Divergence( + note=( + "Neither MCPServer (which rejects logging/setLevel with method-not-found) nor the " + "low-level Server (which leaves the handler entirely to the author) implements any " + "filtering; messages are delivered at every severity regardless of the requested level." 
), ), ), + "logging:set-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#setting-log-level", + behavior="logging/setLevel delivers the requested level to the server's handler and returns an empty result.", + ), + "logging:set-level:invalid-level": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", + behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", + deferred="Not yet covered here: planned gap test (invalid level value on setLevel).", + ), # ═══════════════════════════════════════════════════════════════════════════ # Sampling (server → client) # ═══════════════════════════════════════════════════════════════════════════ + "sampling:capability:declare": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "A client that handles sampling requests advertises the sampling capability in its initialize request." + ), + deferred="Not yet covered here: planned gap test (positive sampling capability declaration).", + ), "sampling:create:basic": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", behavior=( @@ -611,11 +1088,22 @@ def __post_init__(self) -> None: ), ), "sampling:create:include-context": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", behavior="The includeContext value supplied by the server reaches the client callback intact.", ), + "sampling:context:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + behavior=( + "The server does not use includeContext values thisServer or allServers unless the client " + "declared the sampling.context capability." + ), + deferred=( + "Not implemented in the SDK: include_context is forwarded regardless of the client's declared " + "sampling.context capability (unlike tools, which are gated by the server-side validator)." 
+ ), + ), "sampling:create:model-preferences": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", + source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", behavior=( "The model preferences supplied by the server (hints and the cost, speed, and intelligence " "priorities) reach the client callback intact." @@ -625,468 +1113,1496 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", behavior="The system prompt supplied by the server reaches the client callback intact.", ), + "sampling:create:tools": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", + behavior=( + "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " + "with a toolUse stop reason returns to the requesting handler." + ), + deferred=( + "Not implemented in the SDK: Client does not expose ClientSession's sampling_capabilities " + "parameter, so a client can never declare sampling.tools and the server-side validator " + "rejects every tool-enabled request before it is sent." + ), + ), + "sampling:create-message:audio-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#audio-content", + behavior="Sampling messages can carry audio content: base64 data with a mimeType.", + deferred="Not yet covered here: planned gap test (audio content in sampling messages, both directions).", + ), "sampling:create-message:image-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#image-content", behavior="Sampling messages can carry image content: base64 data with a mimeType.", ), - "sampling:tools:server-gated-by-capability": Requirement( + "sampling:create-message:not-supported": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#capabilities", behavior=( - "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " - "by the server before anything reaches the wire, with an Invalid params error." 
+ "A sampling request to a client that did not declare the sampling capability fails with an " + "error rather than hanging or being silently dropped; the spec names no error code for this case." ), ), - "sampling:tool-result:no-mixed-content": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#message-content-constraints", + "sampling:error:user-rejected": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#error-handling", behavior=( - "A sampling request whose messages violate the tool_use/tool_result pairing rules is rejected " - "by the server-side validator before anything reaches the wire." + "A sampling request the user rejects is answered with a JSON-RPC error (the spec's code for " + "this case is -1, 'User rejected sampling request'), surfaced to the requesting handler as an MCPError." ), ), - "sampling:create:tools": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#sampling-with-tools", + "sampling:message:content-cardinality": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling", + behavior="A sampling message's content may be a single block or an array of blocks.", + deferred="Not yet covered here: planned gap test (list-valued sampling message content).", + ), + "sampling:result:no-tools-single-content": Requirement( + source="sdk", behavior=( - "A sampling request carrying tools and toolChoice reaches the client, and a tool_use response " - "with a toolUse stop reason returns to the requesting handler." + "When the request carries no tools, a sampling callback result whose content is an array is " + "rejected by the client." + ), + deferred="Not yet covered here: planned gap test (array content rejected for tool-free sampling).", + ), + "sampling:result:with-tools-array-content": Requirement( + source="sdk", + behavior=( + "When the request includes tools, the client accepts a callback result whose content is an " + "array including tool_use blocks." 
), deferred=( - "Not expressible through the public API: Client does not expose ClientSession's " - "sampling_capabilities parameter, so a client can never declare sampling.tools and the " - "server-side validator rejects every tool-enabled request before it is sent." + "Not implemented in the SDK: requires declaring sampling.tools, which the high-level client " + "cannot do (see sampling:create:tools)." ), ), - "sampling:error:user-rejected": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#error-handling", - behavior="A sampling callback that returns an error is surfaced to the requesting handler as an MCPError.", + "sampling:tool-result:no-mixed-content": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-result-messages", + behavior=( + "A user sampling message that carries tool_result content contains only tool_result blocks; " + "mixing tool_result with text, image, or audio content is rejected as invalid." + ), ), - "sampling:create-message:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/client/sampling#capabilities", + "sampling:tool-use:result-balance": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", + behavior=( + "Every assistant tool_use block in a sampling request must be matched by a tool_result with " + "the same id in the following user message; an unmatched tool_use is rejected with Invalid params." + ), + deferred="Not yet covered here: planned gap test (unmatched tool_use rejected by the validator).", + ), + "sampling:tools:server-gated-by-capability": Requirement( + source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", behavior=( - "A sampling request to a client that did not declare the sampling capability fails with the " - "client's default-callback error (-32600 Invalid request) rather than hanging or being " - "silently dropped; the spec names no error code for this case." 
+ "A tool-enabled sampling request to a client that did not declare sampling.tools is rejected " + "by the server before anything reaches the wire (the SDK surfaces this as an Invalid params error)." ), ), # ═══════════════════════════════════════════════════════════════════════════ # Elicitation (server → client) # ═══════════════════════════════════════════════════════════════════════════ - "elicitation:form:basic": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + "elicitation:capability:empty-is-form": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior="A client advertising an empty elicitation capability accepts form-mode elicitation requests.", + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares explicit " + "form and url sub-capabilities, so an empty elicitation capability cannot be produced through " + "the public API." + ), + ), + "elicitation:capability:mode-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", behavior=( - "A form-mode elicitation delivers the message and requested schema to the client callback " - "exactly as the server sent them." + "The client answers elicitation requests for a mode it did not advertise with JSON-RPC error " + "-32602 (Invalid params)." + ), + deferred=( + "Not implemented in the SDK: a client cannot be configured form-only or url-only, so the " + "per-mode mismatch error cannot arise (see elicitation:url:not-supported)." + ), + ), + "elicitation:capability:server-respects-mode": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + behavior=( + "The server refuses to send an elicitation request with a mode the connected client did not " + "declare in its capabilities." + ), + deferred=( + "Not implemented in the SDK: the server does not check the client's declared elicitation " + "modes before sending elicitation/create." 
), ), "elicitation:form:action:accept": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", behavior=( "A form-mode elicitation answered with action 'accept' returns the user's content to the " "requesting handler." ), ), - "elicitation:form:action:decline": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", - behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", - ), "elicitation:form:action:cancel": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", behavior="A form-mode elicitation answered with action 'cancel' returns no content to the handler.", ), - "elicitation:url:basic": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + "elicitation:form:action:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A form-mode elicitation answered with action 'decline' returns no content to the handler.", + ), + "elicitation:form:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", behavior=( - "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " - "the server sent them." + "A form-mode elicitation delivers the message and requested schema to the client callback " + "exactly as the server sent them." ), ), - "elicitation:url:action:accept-no-content": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + "elicitation:form:defaults": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", behavior=( - "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " - "response carries no content (accept means the user agreed to visit the URL, not that the " - "interaction completed)." 
+ "Optional default values declared in a form-mode requested schema are pre-populated into the " + "form presented to the user." + ), + deferred=( + "Not implemented in the SDK: there is no form-rendering layer that could pre-populate " + "defaults; client callbacks receive the requested schema as-is." ), ), - "elicitation:url:decline": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", - behavior="A URL-mode elicitation answered with decline returns the action with no content.", + "elicitation:form:mode-omitted-default": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", + behavior="An elicitation request with no mode field is treated as form mode by the client.", + deferred="Not yet covered here: planned gap test (mode-less elicitation request handled as form mode).", ), - "elicitation:url:cancel": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", - behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + "elicitation:form:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "An elicitation request to a client that did not declare the elicitation capability is " + "answered with -32602 Invalid params." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32602.", + ), ), - "elicitation:url:complete-notification": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + "elicitation:form:schema:enum-variants": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", behavior=( - "An elicitation/complete notification sent by the server after an out-of-band elicitation " - "finishes reaches the client carrying the elicitationId." + "Requested-schema enum fields (including titled and multi-select variants) reach the client " + "callback as sent." 
), + deferred="Not yet covered here: planned gap test (enum variants in the requested schema).", ), - "elicitation:url:required-error": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + "elicitation:form:schema:primitives": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", + behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", + deferred="Not yet covered here: planned gap test (full primitive-type coverage in the requested schema).", + ), + "elicitation:form:schema:restricted-subset": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", behavior=( - "A handler that cannot proceed without a URL elicitation rejects the request with error " - "-32042, carrying the pending elicitations in the error data." + "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " + "structures and arrays of objects are not used." + ), + deferred=( + "Not implemented in the SDK: nothing restricts or validates the requested-schema shape on the " + "sending side; hand-built lowlevel elicitation requests pass through unchecked." + ), + ), + "elicitation:form:response-validation": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", + behavior=( + "Accepted form-mode content is validated against the requested schema: the client validates " + "the response before sending and the server validates the content it receives." 
+ ), + deferred=("Not implemented in the SDK: accepted elicitation content passes through unvalidated on both sides."), + ), + "elicitation:url:action:accept-no-content": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior=( + "A URL-mode elicitation delivers the message, URL, and elicitationId to the client; an accept " + "response carries no content (accept means the user agreed to visit the URL, not that the " + "interaction completed)." + ), + ), + "elicitation:url:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-mode-elicitation-requests", + behavior=( + "A url-mode elicitation delivers the elicitation id and URL to the client callback exactly as " + "the server sent them." + ), + ), + "elicitation:url:cancel": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with cancel returns the action with no content.", + ), + "elicitation:url:complete-notification": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "An elicitation/complete notification sent by the server after an out-of-band elicitation " + "finishes reaches the client carrying the elicitationId." + ), + ), + "elicitation:url:complete-unknown-ignored": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#completion-notifications-for-url-mode-elicitation", + behavior=( + "The client ignores an elicitation/complete notification referencing an unknown or " + "already-completed elicitationId without error." 
+ ), + deferred="Not yet covered here: planned gap test (unknown elicitationId in a complete notification).", + ), + "elicitation:url:decline": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", + behavior="A URL-mode elicitation answered with decline returns the action with no content.", + ), + "elicitation:url:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + behavior=( + "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " + "Invalid params error." + ), + deferred=( + "Not implemented in the SDK: a Client with an elicitation callback always declares both the " + "form and url sub-capabilities, so a form-only client cannot be constructed." + ), + ), + "elicitation:url:required-error": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", + behavior=( + "A handler that cannot proceed without a URL elicitation rejects the request with error " + "-32042, carrying the pending elicitations in the error data." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Roots (server → client) + # ═══════════════════════════════════════════════════════════════════════════ + "roots:list-changed": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + ), + "roots:list-changed:client-emits": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", + behavior=( + "A client that declared roots.listChanged sends notifications/roots/list_changed when its set " + "of roots changes." + ), + deferred=( + "Not implemented in the SDK: the client keeps no managed roots store, so nothing fires " + "automatically when the configured roots change; emission is an explicit " + "send_roots_list_changed() call (pinned by roots:list-changed)." 
+ ), + ), + "roots:list:basic": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior=( + "A roots/list request from a server handler is answered by the client's roots callback, and " + "the returned roots (uri, name) reach the handler." + ), + ), + "roots:list:client-error": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + ), + "roots:list:empty": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + behavior="An empty roots list is a valid response and reaches the handler as such.", + ), + "roots:list:not-supported": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#error-handling", + behavior=( + "A roots/list request to a client that did not declare the roots capability is answered with " + "-32601 Method not found." + ), + divergence=Divergence( + note="The client's default callback answers with -32600 Invalid request instead of -32601.", + ), + ), + "roots:uri:file-scheme": Requirement( + source=f"{SPEC_BASE_URL}/client/roots#root", + behavior="Every root returned by the client identifies itself with a file:// URI.", + deferred=( + "Not yet covered here: planned gap test (the SDK's Root type enforces the file:// scheme; pin " + "it end-to-end through roots/list)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # list_changed & dynamic registration + # ═══════════════════════════════════════════════════════════════════════════ + "client:list-changed:auto-refresh": Requirement( + source="sdk", + behavior=( + "A client configured to react to list_changed notifications automatically re-fetches the " + "corresponding list and delivers the fresh result to its callback." + ), + deferred=( + "Not implemented in the SDK: the client has no list-changed auto-refresh mechanism; " + "notifications are only delivered to the message handler." 
+ ), + ), + "client:list-changed:capability-gated": Requirement( + source="sdk", + behavior=( + "The client does not activate list-changed handling for a kind the server did not advertise " + "with listChanged true." + ), + deferred="Not implemented in the SDK: no client-side list-changed handling exists to gate.", + ), + "client:list-changed:signal-only": Requirement( + source="sdk", + behavior="A client configured for signal-only list-changed handling is notified without auto-refreshing.", + deferred="Not implemented in the SDK: no client-side list-changed handling exists.", + ), + "mcpserver:list-changed:debounce": Requirement( + source="sdk", + behavior=( + "Bursts of registration changes on MCPServer are debounced into one list_changed notification per kind." + ), + deferred=( + "Not implemented in the SDK: MCPServer does not send list_changed notifications on " + "registration changes at all (see mcpserver:register:post-connect), so there is nothing to " + "debounce." + ), + ), + "mcpserver:register:post-connect": Requirement( + source="sdk", + behavior=( + "A tool, resource, or prompt registered or removed after the client connected appears in (or " + "disappears from) the corresponding list results, and the change is announced with a " + "list_changed notification." + ), + divergence=Divergence( + note=( + "MCPServer never sends list_changed notifications on registration changes, so a connected " + "client cannot learn that the set changed without polling." + ), + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Pagination + # ═══════════════════════════════════════════════════════════════════════════ + "pagination:exhaustion": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#response-format", + behavior=( + "Following nextCursor until it is absent yields every page exactly once; a result without " + "nextCursor ends the sequence." 
+ ), + ), + "pagination:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", + behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", + deferred="Not yet covered here: planned gap test (invalid pagination cursor rejected).", + ), + "pagination:client:cursor-handling": Requirement( + source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", + behavior=( + "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " + "and does not assume a fixed page size." + ), + deferred=( + "Not yet covered here: planned gap test (the client passes a server-issued cursor back " + "byte-for-byte and follows pages of varying sizes)." + ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Tasks (experimental) + # ═══════════════════════════════════════════════════════════════════════════ + "tasks:auth:context-isolation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-isolation-and-access-control", + behavior=( + "When an authorization context is available, task operations are scoped to the context that " + "created the task: other contexts cannot get it, retrieve its result, cancel it, or see it in " + "tasks/list." + ), + transports=("streamable-http",), + deferred=_TASKS_DEFERRAL, + ), + "tasks:bidirectional": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#definitions", + behavior="Task APIs are bidirectional: the server may create, get, list, and cancel tasks on the client.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:no-handler-abort": Requirement( + source="sdk", + behavior=( + "tasks/cancel marks the task cancelled without aborting the originating request handler " + "(the spec says receivers SHOULD attempt to stop execution)." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:remains-cancelled": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior=( + "After tasks/cancel, the task remains cancelled even if the underlying handler subsequently " + "completes or fails." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:terminal-rejected": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a task already in a terminal state returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:cancel:working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-cancellation", + behavior="tasks/cancel on a working task transitions it to cancelled and returns the updated task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:ttl-honored": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#ttl-and-resource-management", + behavior=( + "tasks/get responses include the actual ttl applied by the receiver (or null for unlimited); " + "the create-task result carries the same value." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:create:via-tool-call": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#creating-tasks", + behavior="A task-augmented tools/call returns a create-task result instead of the tool result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:get": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#getting-tasks", + behavior="tasks/get returns the task's current status, ttl, timestamps, and status message.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:initial-working": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-lifecycle", + behavior="A newly created task has status 'working'.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:lifecycle:input-required": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "While a task awaits a side-channel client response its status is input_required; once the " + "response arrives the task leaves input_required (typically returning to working)." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:invalid-cursor": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/list with an invalid cursor returns Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:list:pagination": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#listing-tasks", + behavior="tasks/list returns created tasks and supports cursor pagination.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:no-capability:ignore-task-param": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-support-and-handling", + behavior=( + "A receiver that did not declare task capability for a request type processes the request " + "normally and returns the ordinary result, ignoring the task augmentation." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:progress:after-create": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-progress-notifications", + behavior=( + "After the create-task result, progress notifications keyed to the original progress token " + "continue to reach the caller until the task is terminal." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:request-cancel:no-task-cancel": Requirement( + source="sdk", + behavior="A cancellation notification for the originating request does not auto-cancel the created task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:failed": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-execution-errors", + behavior="tasks/result for a failed task returns the failure result (isError true).", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:related-task-meta": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", + behavior="The tasks/result response carries related-task _meta naming the requested task.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:result:terminal": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#result-retrieval", + behavior="tasks/result for a completed task returns the stored result of the original request type.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drain-fifo": Requirement( + source="sdk", + behavior="tasks/result drains queued related-task messages in FIFO order before returning the final result.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:drop-on-cancel": Requirement( + source="sdk", + behavior="When a task is cancelled before tasks/result, queued related-task messages are dropped.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:elicitation": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "An elicitation issued mid-task is delivered through the tasks/result side-channel, and the " + "client's response routes back to the 
handler." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:queue": Requirement( + source="sdk", + behavior=( + "Server-to-client requests with related-task metadata sent while no tasks/result is open are queued." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:sampling": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#input-required-status", + behavior=( + "A sampling request issued mid-task is delivered through the tasks/result side-channel, and " + "the client's response routes back to the task." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:side-channel:stream": Requirement( + source="sdk", + behavior=( + "Calling tasks/result while the task is working streams related-task messages as they are " + "produced, then returns the result." + ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:status-notification": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#task-status-notification", + behavior="Task status notifications deliver status updates carrying the full task fields.", + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:forbidden-with-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=( + "A task-augmented tools/call on a tool that does not support tasks returns Method not found (-32601)." 
+ ), + deferred=_TASKS_DEFERRAL, + ), + "tasks:tool-level:required-no-task-32601": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#tool-level-negotiation", + behavior=("A plain tools/call on a tool that requires task augmentation returns Method not found (-32601)."), + deferred=_TASKS_DEFERRAL, + ), + "tasks:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/utilities/tasks#protocol-errors", + behavior="tasks/get, tasks/result, and tasks/cancel for an unknown task id return Invalid params (-32602).", + deferred=_TASKS_DEFERRAL, + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Transports (in-suite coverage) + # ═══════════════════════════════════════════════════════════════════════════ + "transport:streamable-http:stateful": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip (initialize, tool calls, tool errors) works through the " + "streamable HTTP framing in its default stateful SSE-response mode." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:json-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + transports=("streamable-http",), + ), + "transport:streamable-http:stateless": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "The interaction round trip works in stateless mode, where every request is served by a " + "fresh transport with no session id." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:notifications": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "Notifications emitted during a request are delivered on that request's SSE stream and reach " + "the client's callbacks, in order, before the response." 
+ ), + transports=("streamable-http",), + ), + "transport:streamable-http:stateless-restrictions": Requirement( + source="sdk", + behavior=( + "A handler that attempts a server-initiated request in stateless mode fails with an error " + "result, because there is no session to call back through." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:unrelated-messages": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-to-client message that is not related to an in-flight request is routed to the " + "standalone GET stream; a client that never opened one does not receive it." + ), + transports=("streamable-http",), + ), + "transport:streamable-http:server-to-client": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior=( + "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." + ), + transports=("streamable-http",), + deferred=( + "The in-process ASGI client buffers each response in full, which deadlocks on a " + "server-to-client request nested inside a still-open call. Covered over a real socket by " + "tests/shared/test_streamable_http.py." + ), + ), + "transport:streamable-http:resumability": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + transports=("streamable-http",), + deferred=( + "Replay requires dropping and re-establishing the SSE connection, which the in-process ASGI " + "client cannot express. Covered over a real socket by tests/shared/test_streamable_http.py." 
+ ), + ), + "transport:streamable-http:origin-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", + transports=("streamable-http",), + deferred=( + "Not yet covered here: the in-process fixture leaves the SDK's opt-in protection disabled (see " + "hosting:http:dns-rebinding); existing coverage in tests/server/test_streamable_http_security.py." + ), + ), + "transport:sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " + "requests, with server messages delivered on the SSE stream." + ), + transports=("sse",), + deferred=( + "The legacy SSE transport is covered by tests/shared/test_sse.py; in-process coverage in this " + "suite arrives with the transport fixture work." + ), + ), + "transport:stdio": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="The interaction round trip works over a stdio subprocess.", + transports=("stdio",), + deferred=( + "Not yet covered here: a single composed end-to-end stdio test is planned; process lifecycle " + "details are covered by tests/client/test_stdio.py." 
+ ), + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: session lifecycle + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:session:cors-expose": Requirement( + source="sdk", + behavior="CORS configuration exposes the Mcp-Session-Id header so browser clients can read it.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: CORS configuration is left to the hosting ASGI application.", + ), + "hosting:session:create": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "An initialize POST without a session id creates a session and returns Mcp-Session-Id in the " + "response headers." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/shared/test_streamable_http.py and " + "tests/server/test_streamable_http_manager.py." + ), + ), + "hosting:session:delete": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="DELETE with a valid Mcp-Session-Id terminates the session and removes its transport.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:session:id-charset": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:session:isolation": Requirement( + source="sdk", + behavior="Each session gets its own server instance; closing one session does not affect others.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/server/test_streamable_http_manager.py.", + ), + "hosting:session:missing-id": Requirement( + 
source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:session:reinitialize": Requirement( + source="sdk", + behavior="A second initialize on an already-initialized session transport is rejected.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:session:reuse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:session:unknown-id": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:stateless:concurrent-clients": Requirement( + source="sdk", + behavior="Multiple independent clients can connect to a stateless server concurrently.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:stateless:no-reuse": Requirement( + source="sdk", + behavior="A stateless per-request transport cannot be reused for a second request.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:stateless:no-session-id": Requirement( + source="sdk", + behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", + transports=("streamable-http",), + 
deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: auth + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:auth:as-router": Requirement( + source="sdk", + behavior=( + "The authorization-server routes expose the authorize, token, and registration endpoints " + "(and revocation when supported)." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:aud-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", + behavior="The resource server validates that the token audience matches its resource identifier.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "hosting:auth:authinfo-propagates": Requirement( + source="sdk", + behavior="A valid token's auth info is exposed to request handlers.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:expired-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="An expired token returns 401 invalid_token.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." 
+ ), + ), + "hosting:auth:invalid-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", + behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:metadata-endpoints": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The MCP server publishes protected-resource metadata at its well-known endpoint, and the " + "authorization server (which the SDK can also host) publishes authorization-server metadata " + "at its own." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:missing-401": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior=( + "A request without an Authorization header is rejected with 401; the WWW-Authenticate header " + "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " + "planned with the auth tests in this suite." + ), + ), + "hosting:auth:prm:authorization-servers-field": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", + behavior=( + "The protected-resource metadata document includes an authorization_servers array with at least one entry." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "hosting:auth:scope-403": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", + behavior=( + "A token lacking a required scope returns 403 with WWW-Authenticate carrying " + "insufficient_scope, the required scope, and resource_metadata." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: resumability + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:resume:bad-event-id": Requirement( + source="sdk", + behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:resume:buffered-replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:resume:close-stream": Requirement( + source="sdk", + behavior="Handlers can close an SSE stream cleanly when an event store is configured.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: handlers have no API to close SSE streams.", + ), + "hosting:resume:event-ids": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="With an event store configured, every SSE event carries an id field.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + 
"hosting:resume:priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A server-initiated SSE stream begins with a priming event carrying an event ID and an empty " + "data field; a server that closes the connection before terminating the stream sends an SSE " + "retry field first." + ), + transports=("streamable-http",), + deferred="Not yet covered here: whether the python server emits priming events has not been pinned.", + ), + "hosting:resume:replay": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:resume:stream-scoped": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", + behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Hosting: HTTP semantics + # ═══════════════════════════════════════════════════════════════════════════ + "hosting:http:accept-406": Requirement( + source="sdk", + behavior="A request whose Accept header does not allow the response representation returns 406.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:batch": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body is a single JSON-RPC message; batched arrays are rejected for protocol revisions " + "that forbid them." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:content-type-415": Requirement( + source="sdk", + behavior="A POST with a Content-Type other than application/json returns 415.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:disconnect-not-cancel": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A client connection drop during an in-flight request does not cancel the server-side " + "handler; the request continues and its result remains retrievable." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:http:dns-rebinding": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#security-warning", + behavior=( + "The Origin header is validated on every incoming connection; a request with an invalid " + "Origin is rejected with 403 Forbidden." + ), + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/server/test_streamable_http_security.py. " + "The SDK's protection is opt-in and disabled by default (no TransportSecuritySettings means " + "no Origin validation), and it also checks Host — the off-by-default gap is one to record as " + "a divergence when the transport conformance tests land." + ), + ), + "hosting:http:json-response-mode": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", + transports=("streamable-http",), + deferred=( + "Not yet covered here; existing coverage in tests/shared/test_streamable_http.py and the " + "json-response tests in this suite's transports directory." 
+ ), + ), + "hosting:http:method-405": Requirement( + source="sdk", + behavior="An unsupported HTTP method on the MCP endpoint returns 405.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:no-broadcast": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", + behavior=( + "When multiple SSE streams are open for a session, each server-originated message is sent on " + "exactly one stream, never duplicated." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:http:notifications-202": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A POST containing only notifications or responses returns 202 with no body.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:onerror": Requirement( + source="sdk", + behavior="Transport-level rejections are reported through an error callback on the server transport.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: the server transport has no error callback; rejections are logged.", + ), + "hosting:http:parse-error-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A POST body that is not valid JSON or not a valid JSON-RPC message is rejected with HTTP 400; " + "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." 
+ ), + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:protocol-version-400": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:protocol-version-default": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", + behavior=( + "When no MCP-Protocol-Version header is received and the version cannot be determined another " + "way, the server assumes protocol version 2025-03-26." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:http:response-same-connection": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior=( + "A response is delivered on the SSE stream opened by the POST that carried its request (or " + "that stream's resumed continuation), not on an unrelated stream." 
), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "elicitation:form:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", + "hosting:http:second-sse-rejected": Requirement( + source="sdk", + behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:http:sse-close-after-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "hosting:http:standalone-sse": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="GET opens a standalone SSE stream that receives server-initiated messages.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:http:standalone-sse-no-response": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", behavior=( - "An elicitation request to a client that did not declare the elicitation capability fails with " - "an error rather than hanging or being silently dropped." - ), - divergence=Divergence( - note=( - "The spec says a request for an elicitation mode the client has not declared MUST be " - "answered with -32602 Invalid params; the client's default callback answers with -32600 " - "Invalid request." - ), + "The standalone GET SSE stream carries server requests and notifications but never a JSON-RPC " + "response, except when resuming a prior request stream." 
), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "elicitation:url:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#capabilities", + # ═══════════════════════════════════════════════════════════════════════════ + # Client transport: streamable HTTP + # ═══════════════════════════════════════════════════════════════════════════ + "client-transport:http:404-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior=( - "A URL-mode elicitation to a client that declared only form-mode support is rejected with an " - "Invalid params error." + "A 404 in response to a request carrying a session ID makes the client start a new session " + "with a fresh InitializeRequest and no session ID attached." ), + transports=("streamable-http",), deferred=( - "Not expressible through the public API: a Client with an elicitation callback always declares " - "both the form and url sub-capabilities, so a form-only client cannot be constructed." + "Not implemented in the SDK: the client surfaces the 404 as an error to the caller instead of " + "re-initializing a new session." 
), ), - "elicitation:form:defaults": Requirement( - source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-elicitation-requests", - behavior="A client that declares the defaults capability receives requested schemas with defaults applied.", - deferred="The SDK does not implement the defaults sub-capability on either side.", + "client-transport:http:accept-header-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - # ═══════════════════════════════════════════════════════════════════════════ - # Roots (server → client) - # ═══════════════════════════════════════════════════════════════════════════ - "roots:list:basic": Requirement( - source=f"{SPEC_BASE_URL}/client/roots#listing-roots", + "client-transport:http:accept-header-post": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior=( - "A roots/list request from a server handler is answered by the client's roots callback, and " - "the returned roots (uri, name) reach the handler." + "Every client POST to the MCP endpoint includes an Accept header listing both application/json " + "and text/event-stream." 
), + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - "roots:list:empty": Requirement( - source=f"{SPEC_BASE_URL}/client/roots#listing-roots", - behavior="An empty roots list is a valid response and reaches the handler as such.", + "client-transport:http:concurrent-streams": Requirement( + source="sdk", + behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "roots:list:not-supported": Requirement( - source=f"{SPEC_BASE_URL}/client/roots#error-handling", + "client-transport:http:custom-client": Requirement( + source="sdk", behavior=( - "A roots/list request to a client that did not declare the roots capability fails with an " - "error rather than hanging or being silently dropped." - ), - divergence=Divergence( - note=( - "The spec says a client that does not support roots SHOULD answer with -32601 Method not " - "found; the client's default callback answers with -32600 Invalid request." - ), + "A caller-supplied HTTP client (and its event hooks and headers) is used for all MCP traffic, " + "including auth flows." 
), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "roots:list:client-error": Requirement( - source=f"{SPEC_BASE_URL}/client/roots#error-handling", - behavior="A roots callback that answers with an error surfaces to the requesting handler as an MCPError.", + "client-transport:http:custom-headers": Requirement( + source="sdk", + behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "roots:list-changed": Requirement( - source=f"{SPEC_BASE_URL}/client/roots#root-list-changes", - behavior="A roots/list_changed notification sent by the client is delivered to the server's handler.", + "client-transport:http:json-response-parsed": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - # ═══════════════════════════════════════════════════════════════════════════ - # Transports - # ═══════════════════════════════════════════════════════════════════════════ - "transport:streamable-http:stateful": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-transport:http:no-reconnect-after-close": Requirement( + source="sdk", + behavior="After the transport is closed, no further reconnection attempts are scheduled.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "client-transport:http:no-reconnect-after-response": Requirement( + source="sdk", + behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", + transports=("streamable-http",), 
+ deferred="Not yet covered here: planned with the transport conformance work.", + ), + "client-transport:http:protocol-version-header": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", behavior=( - "The interaction round trip (initialize, tool calls, tool errors) works through the " - "streamable HTTP framing in its default stateful SSE-response mode." + "After initialization, the client sends the negotiated MCP-Protocol-Version header on every " + "subsequent HTTP request." ), transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - "transport:streamable-http:json-response": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", - behavior="The interaction round trip works when the server answers with plain JSON instead of SSE.", + "client-transport:http:protocol-version-stored": Requirement( + source="sdk", + behavior="The client transport exposes the negotiated protocol version once initialization completes.", transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "transport:streamable-http:stateless": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-transport:http:reconnect-get": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior=( - "The interaction round trip works in stateless mode, where every request is served by a " - "fresh transport with no session id." + "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." 
), transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - "transport:streamable-http:notifications": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-transport:http:reconnect-post-priming": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior=( - "Notifications emitted during a request are delivered on that request's SSE stream and reach " - "the client's callbacks, in order, before the response." + "A POST-initiated SSE stream that errors before delivering its response is reconnected only " + "if a priming event (an event carrying an ID) was received on it." ), transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "transport:streamable-http:stateless-restrictions": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-transport:http:reconnect-retry-value": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", + behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "client-transport:http:resume-stream-api": Requirement( + source="sdk", behavior=( - "A handler that attempts a server-initiated request in stateless mode fails with an error " - "result, because there is no session to call back through." + "The client can capture a resumption token, reconnect with the same session id, and receive " + "the notifications it missed." 
), transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - "transport:streamable-http:unrelated-messages": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-transport:http:session-stored": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior=( - "A server-to-client message that is not related to an in-flight request is routed to the " - "standalone GET stream; a client that never opened one does not receive it." + "The Mcp-Session-Id returned by initialize is stored by the client transport and sent on " + "every subsequent request." ), transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - "transport:streamable-http:server-to-client": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-transport:http:sse-405-tolerated": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", + behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "client-transport:http:terminate-405-ok": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + # ═══════════════════════════════════════════════════════════════════════════ + # Client auth + # ═══════════════════════════════════════════════════════════════════════════ + "client-auth:401-after-auth-throws": Requirement( + source="sdk", behavior=( - "A server-initiated request nested inside an in-flight call round-trips 
over stateful streamable HTTP." + "If the server still returns 401 after a successful authorization, the client fails instead of looping." ), transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:401-triggers-flow": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", + behavior="A 401 on a request triggers the OAuth authorization flow once.", + transports=("streamable-http",), deferred=( - "The in-process ASGI client buffers each response in full, which deadlocks on a " - "server-to-client request nested inside a still-open call. Covered over a real socket by " - "tests/shared/test_streamable_http.py." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." ), ), - "transport:streamable-http:resumability": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", - behavior="A client that reconnects with Last-Event-ID receives the events it missed.", + "client-auth:403-scope-upgrade": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", + behavior=( + "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:as-metadata-discovery:priority-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client discovers authorization-server metadata by trying, in order, the OAuth " + "path-inserted, OIDC path-inserted, and OIDC path-appended well-known URLs (with the " + "root-path forms when the issuer URL has no path)." 
+ ), transports=("streamable-http",), deferred=( - "Replay requires dropping and re-establishing the SSE connection, which the in-process ASGI " - "client cannot express. Covered over a real socket by tests/shared/test_streamable_http.py." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." ), ), - "transport:streamable-http:origin-validation": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", - behavior="Requests with a disallowed Origin or Host header are rejected before reaching the session.", + "client-auth:bearer-header:every-request": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", + behavior=( + "Once authorized, the client sends the bearer token in the Authorization header on every HTTP " + "request to the MCP server, never in the query string." + ), transports=("streamable-http",), deferred=( - "The in-process fixture disables DNS-rebinding protection because no network attack surface " - "exists in-process. Covered by tests/server/test_streamable_http_security.py." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." 
), ), - "transport:streamable-http:session-management": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#session-management", + "client-auth:cimd": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", + behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: client-ID metadata documents are not supported.", + ), + "client-auth:client-credentials": Requirement( + source="sdk", behavior=( - "The server issues a session id on initialize, validates it on subsequent requests, isolates " - "sessions, and tears the session down on DELETE." + "A client-credentials provider obtains a token without user interaction and the resulting " + "bearer token authorizes subsequent requests." ), transports=("streamable-http",), deferred=( - "Covered at the wire level by tests/shared/test_streamable_http.py and " - "tests/server/test_streamable_http_manager.py; this suite drives sessions only through the " - "client API." + "Not yet covered here; existing coverage in tests/client/auth/; interaction-level coverage " + "planned with the auth tests in this suite." ), ), - "transport:streamable-http:wire-validation": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", + "client-auth:dcr": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", behavior=( - "The server validates Accept and Content-Type headers, the protocol-version header, and " - "malformed JSON bodies, answering with the documented HTTP status codes." + "The client performs dynamic client registration against the authorization server when no " + "client_id is preconfigured." ), transports=("streamable-http",), deferred=( - "Raw-HTTP request/response validation is covered by tests/shared/test_streamable_http.py; " - "this suite only sends well-formed traffic through the client." 
+ "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." ), ), - "transport:streamable-http:client-reconnect": Requirement( + "client-auth:invalid-client-clears-all": Requirement( source="sdk", behavior=( - "The HTTP client transport reconnects dropped SSE streams, honours the server-provided retry " - "interval, and resumes from the last event id." + "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." ), transports=("streamable-http",), - deferred=( - "Reconnection and resumption behaviour needs a droppable connection; covered by " - "tests/shared/test_streamable_http.py over a real socket." + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:invalid-grant-clears-tokens": Requirement( + source="sdk", + behavior="An invalid-grant error during authorization invalidates only the stored tokens.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:pkce:refuse-if-unsupported": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The client refuses to proceed when the authorization server's metadata does not include " + "code_challenge_methods_supported, since PKCE support cannot be verified." ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), - "transport:sse": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports", + "client-auth:pkce:s256": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", behavior=( - "A client connected over the legacy HTTP+SSE transport completes the handshake and round-trips " - "requests, with server messages delivered on the SSE stream." 
+ "The authorization request includes a PKCE S256 code challenge and the token request includes " + "the matching verifier." ), - transports=("sse",), + transports=("streamable-http",), deferred=( - "The legacy SSE transport is covered by tests/shared/test_sse.py; in-process coverage in this " - "suite arrives with the transport fixture work." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." ), ), - "transport:stdio": Requirement( - source=f"{SPEC_BASE_URL}/basic/transports#stdio", - behavior="The interaction round trip works over a stdio subprocess.", - transports=("stdio",), + "client-auth:pre-registration": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", + behavior=( + "A client with statically preconfigured credentials skips dynamic registration and uses them directly." + ), + transports=("streamable-http",), deferred=( - "Requires a real subprocess. Process lifecycle is covered by tests/client/test_stdio.py and " - "end-to-end stdio coverage belongs to the cross-SDK conformance suite." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." 
), ), - # ═══════════════════════════════════════════════════════════════════════════ - # Authorization - # ═══════════════════════════════════════════════════════════════════════════ - "auth:client-oauth": Requirement( - source=f"{SPEC_BASE_URL}/basic/authorization", + "client-auth:private-key-jwt": Requirement( + source="sdk", + behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", + transports=("streamable-http",), + deferred="Not implemented in the SDK: JWT-assertion client authentication is not supported.", + ), + "client-auth:prm-discovery:fallback-order": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", behavior=( - "The client performs the OAuth 2.1 authorization flow (metadata discovery, PKCE, dynamic " - "client registration, token refresh, resource parameter) when a server requires authorization." + "The client uses resource_metadata from WWW-Authenticate when present, then falls back to the " + "well-known protected-resource locations in the documented order." ), transports=("streamable-http",), deferred=( - "Authorization is out of scope for this suite. Client-side flow coverage lives in " - "tests/client/test_auth.py, tests/client/auth/, and tests/shared/test_auth_utils.py." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." ), ), - "auth:server-enforcement": Requirement( - source=f"{SPEC_BASE_URL}/basic/authorization", + "client-auth:prm-resource-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", behavior=( - "A server protecting its endpoints rejects missing, invalid, expired, or under-scoped tokens " - "with 401/403 and serves protected-resource metadata." 
+ "The client refuses to proceed when the protected-resource metadata's resource field does not " + "match the server URL it is connecting to." ), transports=("streamable-http",), - deferred=( - "Authorization is out of scope for this suite. Server-side enforcement coverage lives in " - "tests/server/auth/ and tests/shared/test_auth.py." - ), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), - # ═══════════════════════════════════════════════════════════════════════════ - # Tasks (experimental) - # ═══════════════════════════════════════════════════════════════════════════ - "tasks:experimental": Requirement( - source=f"{SPEC_BASE_URL}/basic/utilities/tasks", + "client-auth:resource-parameter": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", behavior=( - "Task-augmented requests (tasks/create, tasks/get, tasks/list, tasks/cancel, task-status " - "notifications and task-scoped side-channel requests) run the documented task lifecycle." + "The client includes the canonical server URI as the resource parameter in both the " + "authorization request and the token request." ), + transports=("streamable-http",), deferred=( - "Tasks are experimental and under active spec revision; the suite excludes them. Python task " - "behaviour is covered by tests/experimental/tasks/." + "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " + "coverage planned with the auth tests in this suite." 
), ), - # ═══════════════════════════════════════════════════════════════════════════ - # MCPServer behaviours - # ═══════════════════════════════════════════════════════════════════════════ - "mcpserver:tool:input-validation": Requirement( - source=f"{SPEC_BASE_URL}/server/tools#error-handling", + "client-auth:scope-selection:priority": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", behavior=( - "Arguments that fail the tool's input validation produce a tool execution error (isError true " - "with the validation failure described in content), not a protocol error." + "The client selects the requested scope from WWW-Authenticate when present, then from the " + "protected-resource metadata, and otherwise omits scope." ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), - "mcpserver:tool:output-schema:model": Requirement( - source="sdk", + "client-auth:state:verify": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", behavior=( - "A tool returning a typed model advertises a matching generated outputSchema and returns the " - "model's fields as structuredContent alongside a serialised text block." + "A state parameter is included in the authorization URL, and authorization results with a " + "missing or mismatched state are discarded." 
), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), - "mcpserver:tool:output-schema:wrapped": Requirement( + "client-auth:token-endpoint-auth-method": Requirement( source="sdk", + behavior="The client authenticates to the token endpoint using the auth method established at registration.", + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:token-provenance": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", behavior=( - "A tool returning a non-object type (primitive or list) wraps the value as {'result': ...} in " - "structuredContent, with a matching generated outputSchema." + "The client sends the MCP server only tokens issued by that server's authorization server, " + "never tokens obtained elsewhere." ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), - "mcpserver:resource:static": Requirement( - source="sdk", + # ═══════════════════════════════════════════════════════════════════════════ + # stdio transport + # ═══════════════════════════════════════════════════════════════════════════ + "transport:stdio:clean-shutdown": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", + behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", + transports=("stdio",), + deferred="Not yet covered here; existing coverage in tests/client/test_stdio.py.", + ), + "transport:stdio:stream-purity": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", behavior=( - "A function registered with @mcp.resource() for a fixed URI is listed by resources/list and " - "served by resources/read at that URI." 
+ "Nothing that is not a valid MCP message is written to the server's stdout, and nothing that " + "is not a valid MCP message is written to its stdin." ), + transports=("stdio",), + deferred="Not yet covered here: planned with the stdio end-to-end test.", ), - "mcpserver:resource:template": Requirement( - source="sdk", + "transport:stdio:no-embedded-newlines": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#stdio", + behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", + transports=("stdio",), + deferred="Not yet covered here: planned with the stdio end-to-end test.", + ), + "transport:stdio:shutdown-escalation": Requirement( + source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", behavior=( - "A function registered with a URI template is listed by resources/templates/list and matched " - "by resources/read, receiving the parameters extracted from the requested URI." + "If the server process does not exit after stdin is closed, the client transport terminates " + "it (and kills it if still alive) after a grace period." ), + transports=("stdio",), + deferred="Not yet covered here; existing coverage in tests/client/test_stdio.py.", ), - "mcpserver:resource:unknown-uri": Requirement( + "transport:stdio:stderr-passthrough": Requirement( source="sdk", - behavior="resources/read for a URI matching no registered resource returns a JSON-RPC error.", - divergence=Divergence( - note=( - "The spec reserves -32002 for resource-not-found; MCPServer raises ResourceError, which " - "the low-level server converts to error code 0." 
- ), - ), + behavior="Server stderr is available to the client and is not consumed by the transport.", + transports=("stdio",), + deferred="Not yet covered here; existing coverage in tests/client/test_stdio.py.", ), - "mcpserver:prompt:decorated": Requirement( - source="sdk", + # ═══════════════════════════════════════════════════════════════════════════ + # Composite end-to-end flows + # ═══════════════════════════════════════════════════════════════════════════ + "flow:compat:dual-transport-server": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", behavior=( - "A function registered with @mcp.prompt() is listed with arguments derived from its signature " - "and rendered into prompt messages by prompts/get." + "A single server instance can serve streamable HTTP and the legacy SSE transport " + "concurrently; clients on either transport can call the same tools." ), + transports=("streamable-http", "sse"), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "mcpserver:prompt:unknown-name": Requirement( - source="sdk", - behavior="prompts/get for a name that was never registered returns a JSON-RPC error.", - divergence=Divergence( - note=( - "The spec's example uses -32602 Invalid params for unknown prompts; MCPServer raises " - "ValueError, which the low-level server converts to error code 0." - ), + "flow:compat:streamable-then-sse-fallback": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "When a streamable HTTP initialize fails with 400, 404, or 405, falling back to the legacy " + "SSE client transport against the same server connects successfully." 
), + transports=("streamable-http", "sse"), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "mcpserver:context:logging": Requirement( + "flow:elicitation:multi-step-form": Requirement( source="sdk", behavior=( - "The Context logging helpers (debug/info/warning/error) send log message notifications at the " - "corresponding severity." + "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " + "and a decline or cancel at any step short-circuits to a final result." ), + deferred="Not yet covered here: planned gap test (multi-step elicitation flow).", ), - "mcpserver:context:progress": Requirement( + "flow:elicitation:url-at-session-init": Requirement( source="sdk", behavior=( - "Context.report_progress sends a progress notification against the requesting client's progress token." + "The server can issue a URL-mode elicitation over the standalone GET stream immediately after " + "session initialization, before any client request." ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "mcpserver:context:elicit": Requirement( - source="sdk", + "flow:elicitation:url-required-then-retry": Requirement( + source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", behavior=( - "Context.elicit sends a form elicitation built from a typed schema and returns a typed " - "accepted/declined/cancelled result." + "A tool call rejected with the URL-elicitation-required error can be retried successfully " + "after the client completes the URL flow and the server announces completion." 
), + deferred="Not yet covered here: planned gap test (full URL-elicitation-required retry flow).", ), - "mcpserver:context:read-resource": Requirement( + "flow:multi-client:stateful-isolation": Requirement( source="sdk", - behavior="Context.read_resource reads a resource registered on the same server from inside a tool.", + behavior=( + "Independent clients connected to one stateful server each receive a distinct session and " + "only the notifications produced by their own requests." + ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", ), - "mcpserver:register:post-connect": Requirement( - source="sdk", - behavior=("A tool added or removed after the client connected is reflected in subsequent tools/list results."), - divergence=Divergence( - note=( - "The spec provides notifications/tools/list_changed for exactly this case; MCPServer never " - "sends it, so a connected client cannot learn that the tool set changed without polling." - ), + "flow:oauth:authorization-code-roundtrip": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "Connecting to a protected server walks the authorization-code flow end to end: the first " + "attempt requires authorization, the code is exchanged, and a subsequent connection succeeds." ), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), - "mcpserver:tool:handler-throws": Requirement( - source="sdk", + "flow:resume:tool-call-resumption-token": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior=( - "An exception raised by a tool function (ToolError or otherwise) is caught and returned as a " - "tool result with isError true and the failure text in content; it does not become a JSON-RPC error." 
+ "A tool call interrupted mid-stream can be resumed with the captured resumption token, " + "delivering only the remaining notifications and the final result." ), + transports=("streamable-http",), + deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), - "mcpserver:tool:unknown-name": Requirement( - source="sdk", - behavior="Calling a tool name that was never registered returns a tool result with isError true.", - divergence=Divergence( - note=( - "The spec classifies unknown tools as a protocol error (its example uses -32602 Invalid " - "params); MCPServer reports a tool execution error instead. The low-level path follows the " - "spec example (see tools:call:unknown-name)." - ), + "flow:session:terminate-then-reconnect": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=("After terminating a session, a fresh connection obtains a new session id and operations succeed."), + transports=("streamable-http",), + deferred="Not yet covered here: planned with the transport conformance work.", + ), + "flow:tool-result:resource-link-follow": Requirement( + source=f"{SPEC_BASE_URL}/server/tools#resource-links", + behavior=( + "A resource_link returned by a tool call can be followed with resources/read on the linked " + "URI to retrieve the referenced contents." ), + deferred="Not yet covered here: planned gap test (follow a resource link returned by a tool).", ), } diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 591c66efa2..bbf984fc4a 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -20,6 +20,7 @@ @requirement("protocol:cancel:in-flight") +@requirement("protocol:cancel:handler-abort-propagates") async def test_cancellation_stops_in_flight_handler() -> None: """Cancelling an in-flight request interrupts its handler and fails the pending call. 
diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py index f5deaa89f6..ea2529169f 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -20,6 +20,7 @@ @requirement("completion:prompt-arg") +@requirement("completion:result-shape") async def test_complete_prompt_argument() -> None: """Completing a prompt argument delivers the ref, argument name, and current value to the handler. diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 074a8c12c2..16b943f960 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -46,6 +46,7 @@ pytestmark = pytest.mark.anyio +@requirement("lifecycle:initialize:basic") @requirement("lifecycle:initialize:server-info") async def test_initialize_returns_server_info() -> None: """Every identity field the server declares is returned to the client in server_info.""" diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py index 48dc2717de..6a82601d48 100644 --- a/tests/interaction/lowlevel/test_ping.py +++ b/tests/interaction/lowlevel/test_ping.py @@ -12,6 +12,7 @@ pytestmark = pytest.mark.anyio +@requirement("lifecycle:ping") @requirement("ping:client-to-server") async def test_client_ping_returns_empty_result() -> None: """A client ping is answered with an empty result, even by a server with no handlers.""" @@ -23,6 +24,7 @@ async def test_client_ping_returns_empty_result() -> None: assert result == snapshot(EmptyResult()) +@requirement("lifecycle:ping") @requirement("ping:server-to-client") async def test_server_ping_returns_empty_result() -> None: """A server-initiated ping sent while a request is in flight is answered by the client. 
diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 6ef499e8ab..aa576aaf84 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -2,11 +2,13 @@ The contract runs in both directions: every non-deferred entry in :data:`REQUIREMENTS` must be exercised by at least one test, and every test in the suite must carry at least one -`@requirement(...)` mark referencing a manifest entry. Test modules are imported directly +`@requirement(...)` mark referencing a manifest entry. Deferral reasons that point at coverage +elsewhere in the repo must point at paths that exist. Test modules are imported directly (rather than relying on pytest collection) so the check holds even when only this file is run. """ import importlib +import re from pathlib import Path from types import ModuleType @@ -15,6 +17,10 @@ from tests.interaction._requirements import REQUIREMENTS, Requirement, covered_by, requirement _SUITE_ROOT = Path(__file__).parent +_REPO_ROOT = _SUITE_ROOT.parent.parent + +# Repo paths cited inside deferral reasons ("Covered by tests/... "). +_CITED_PATH = re.compile(r"(?:tests|src)/[\w./-]*\w") # Tests that exercise the suite's own helpers rather than an interaction-model behaviour. # Anything listed here is exempt from the every-test-has-a-requirement check. 
@@ -70,6 +76,18 @@ def test_every_test_exercises_a_requirement() -> None: assert not stale_exemptions, f"Harness self-test exemptions that no longer exist: {stale_exemptions}" +def test_deferral_reasons_cite_existing_paths() -> None: + """Every repo path named in a deferral reason exists, so coverage pointers cannot rot.""" + missing = sorted( + f"{requirement_id}: {cited}" + for requirement_id, spec in REQUIREMENTS.items() + if spec.deferred is not None + for cited in _CITED_PATH.findall(spec.deferred) + if not (_REPO_ROOT / cited).exists() + ) + assert not missing, f"Deferral reasons citing paths that do not exist: {missing}" + + def test_unknown_requirement_id_is_rejected() -> None: """Marking a test with an ID that is not in the manifest fails at decoration time.""" with pytest.raises(KeyError, match="Unknown requirement id 'tools:call:does-not-exist'"): From c1eab9d9b634a90756966c0984002090ed49777b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 15:48:11 +0000 Subject: [PATCH 13/34] test: add an in-process streaming ASGI transport and cover server-initiated requests over streamable HTTP --- tests/interaction/README.md | 4 +- tests/interaction/_requirements.py | 8 +- tests/interaction/test_coverage.py | 3 + tests/interaction/transports/_bridge.py | 154 ++++++++++++++++++ tests/interaction/transports/test_bridge.py | 71 ++++++++ .../transports/test_streamable_http.py | 105 ++++++++---- 6 files changed, 310 insertions(+), 35 deletions(-) create mode 100644 tests/interaction/transports/_bridge.py create mode 100644 tests/interaction/transports/test_bridge.py diff --git a/tests/interaction/README.md b/tests/interaction/README.md index 487908ff0e..50bde98e41 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -27,7 +27,9 @@ The whole suite is in-memory and event-driven; it runs in about a second. SDK's own deliberate output. 
- **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that could hang is bounded by `anyio.fail_after(5)`. The streamable HTTP tests drive the Starlette - app in-process through `httpx.ASGITransport` — no sockets, threads, or subprocesses anywhere. + app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which + delivers each response chunk as the server produces it — full duplex, but still no sockets, + threads, or subprocesses anywhere. ## Layout diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 9249a51cf4..cdc97d4074 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1722,7 +1722,8 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", behavior=( "A server-to-client message that is not related to an in-flight request is routed to the " - "standalone GET stream; a client that never opened one does not receive it." + "standalone GET stream and delivered to the client listening on it, not to any request's " + "own stream." ), transports=("streamable-http",), ), @@ -1732,11 +1733,6 @@ def __post_init__(self) -> None: "A server-initiated request nested inside an in-flight call round-trips over stateful streamable HTTP." ), transports=("streamable-http",), - deferred=( - "The in-process ASGI client buffers each response in full, which deadlocks on a " - "server-to-client request nested inside a still-open call. Covered over a real socket by " - "tests/shared/test_streamable_http.py." 
- ), ), "transport:streamable-http:resumability": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index aa576aaf84..5a4f003101 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -26,6 +26,9 @@ # Anything listed here is exempt from the every-test-has-a-requirement check. _HARNESS_SELF_TESTS = { "tests.interaction.lowlevel.test_wire.test_recording_read_stream_ends_iteration_when_the_sender_closes", + "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", + "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", + "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", } diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py new file mode 100644 index 0000000000..254f1e00c1 --- /dev/null +++ b/tests/interaction/transports/_bridge.py @@ -0,0 +1,154 @@ +"""An in-process, full-duplex HTTP transport for driving ASGI applications from httpx. + +`httpx.ASGITransport` runs the application to completion and only then hands the buffered +response to the caller, so a server that streams its response — the streamable HTTP transport's +SSE responses — can never converse with the client mid-request: a server-initiated request +nested inside a still-open call deadlocks. `StreamingASGITransport` removes that limitation by +running the application as a background task and forwarding every `http.response.body` chunk to +the client the moment it is sent. Everything happens on the one event loop: no sockets, no +threads, no sleeps, no extra dependencies. 
+ +The behavioural contract, pinned by `test_bridge.py`: + +- The request body is buffered before the application is invoked (MCP requests are small JSON + documents); the response streams chunk by chunk. +- Closing the response — or the whole client — delivers `http.disconnect` to the application, + exactly as a real server sees when its peer goes away. +- An exception the application raises before sending `http.response.start` fails the originating + request with that same exception. After the response has started, a failure is visible to the + client only through the response itself (status code, truncated body) — the same signal a real + server over a real socket would give. + +The transport owns an anyio task group for the application tasks; it is opened and closed by +`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite +always does). +""" + +import math +from collections.abc import AsyncIterator +from types import TracebackType + +import anyio +import anyio.abc +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream +from starlette.types import ASGIApp, Message, Scope + + +class _StreamingResponseBody(httpx.AsyncByteStream): + """A response body that yields chunks as the application produces them. + + Closing it tells the application the client has gone away (`http.disconnect`), mirroring a + peer that drops the connection mid-response. 
+ """ + + def __init__(self, chunks: MemoryObjectReceiveStream[bytes], client_disconnected: anyio.Event) -> None: + self._chunks = chunks + self._client_disconnected = client_disconnected + + async def __aiter__(self) -> AsyncIterator[bytes]: + async for chunk in self._chunks: + yield chunk + + async def aclose(self) -> None: + self._client_disconnected.set() + await self._chunks.aclose() + + +class StreamingASGITransport(httpx.AsyncBaseTransport): + """Drive an ASGI application in-process, streaming each response as it is produced.""" + + _task_group: anyio.abc.TaskGroup + + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __aenter__(self) -> "StreamingASGITransport": + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + # Any application task still running at this point is serving a client that no longer + # exists; cancel rather than wait so harness teardown can never hang. 
+ self._task_group.cancel_scope.cancel() + await self._task_group.__aexit__(exc_type, exc_value, traceback) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + assert isinstance(request.stream, httpx.AsyncByteStream) + request_body = b"".join([chunk async for chunk in request.stream]) + + scope: Scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?", maxsplit=1)[0], + "query_string": request.url.query, + "root_path": "", + "headers": [(name.lower(), value) for name, value in request.headers.raw], + "server": (request.url.host, request.url.port), + "client": ("127.0.0.1", 1234), + } + + request_delivered = False + client_disconnected = anyio.Event() + response_started = anyio.Event() + response_status = 0 + response_headers: list[tuple[bytes, bytes]] = [] + application_error: Exception | None = None + chunk_writer, chunk_reader = anyio.create_memory_object_stream[bytes](math.inf) + + async def receive_request() -> Message: + nonlocal request_delivered + if not request_delivered: + request_delivered = True + return {"type": "http.request", "body": request_body, "more_body": False} + await client_disconnected.wait() + return {"type": "http.disconnect"} + + async def send_response(message: Message) -> None: + nonlocal response_status, response_headers + if message["type"] == "http.response.start": + response_status = message["status"] + response_headers = list(message.get("headers", [])) + response_started.set() + return + assert message["type"] == "http.response.body" + body: bytes = message.get("body", b"") + if body: + await chunk_writer.send(body) + if not message.get("more_body", False): + await chunk_writer.aclose() + + async def run_application() -> None: + nonlocal application_error + try: + await self._app(scope, receive_request, send_response) + except Exception as 
exc: # The bridge is the application's outermost boundary: a crash + # must fail the originating request (or show up in the already-started response), + # never tear down the task group shared with every other in-flight request. + application_error = exc + finally: + response_started.set() + await chunk_writer.aclose() + + self._task_group.start_soon(run_application) + await response_started.wait() + if application_error is not None: + # No response will be built, so close the reader the response body would have owned. + await chunk_reader.aclose() + raise application_error + return httpx.Response( + status_code=response_status, + headers=response_headers, + stream=_StreamingResponseBody(chunk_reader, client_disconnected), + request=request, + ) diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py new file mode 100644 index 0000000000..13389f8533 --- /dev/null +++ b/tests/interaction/transports/test_bridge.py @@ -0,0 +1,71 @@ +"""Contract tests for the suite's streaming ASGI bridge. + +These pin what `StreamingASGITransport` itself guarantees — chunk-by-chunk delivery, disconnect +propagation, and failure handling — against minimal hand-written ASGI applications, so the MCP +transport tests built on top of it never have to wonder what the harness provides. They are +harness self-tests, not interaction-model tests, and are exempted from the requirement-coverage +contract in `test_coverage.py`. 
+""" + +import anyio +import httpx +import pytest +from starlette.types import Message, Receive, Scope, Send + +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def test_response_chunks_arrive_as_the_application_sends_them() -> None: + """Each body chunk is delivered as sent, empty chunks are skipped, and the stream ends with the application.""" + + async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"text/plain")]}) + await send({"type": "http.response.body", "body": b"first", "more_body": True}) + await send({"type": "http.response.body", "body": b"", "more_body": True}) + await send({"type": "http.response.body", "body": b"second", "more_body": False}) + + async with httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http: + async with http.stream("GET", "/chunks") as response: + with anyio.fail_after(5): + chunks = [chunk async for chunk in response.aiter_raw()] + + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain" + assert chunks == [b"first", b"second"] + + +async def test_closing_the_response_delivers_a_disconnect_to_the_application() -> None: + """A client that closes the response early is seen by the application as an http.disconnect.""" + seen_after_request: list[Message] = [] + disconnect_seen = anyio.Event() + + async def waiting_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + assert (await receive())["type"] == "http.request" + await send({"type": "http.response.start", "status": 200, "headers": []}) + seen_after_request.append(await receive()) + disconnect_seen.set() + + async with httpx.AsyncClient(transport=StreamingASGITransport(waiting_app), 
base_url="http://bridge") as http: + async with http.stream("GET", "/wait") as response: + assert response.status_code == 200 + # Leaving the stream block closes the response while the application is still mid-response. + with anyio.fail_after(5): + await disconnect_seen.wait() + + assert seen_after_request == [{"type": "http.disconnect"}] + + +async def test_an_application_failure_before_the_response_starts_fails_the_request() -> None: + """An exception raised before http.response.start reaches the caller as that same exception.""" + + async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: + raise RuntimeError("the demo application is broken") + + async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: + with pytest.raises(RuntimeError, match="the demo application is broken"): + await http.get("/broken") diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index 4e5dd306c5..d2639266a5 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -1,37 +1,42 @@ -"""Smoke tests for the interaction model over the streamable HTTP transport, entirely in process. - -The Starlette app a real deployment would hand to uvicorn is driven through httpx's ASGI -transport instead: the full HTTP framing layer runs (session ids, SSE and JSON response -encoding, stateful and stateless session management) with no sockets, threads, or subprocesses, -so these tests are as deterministic as the in-memory ones. - -The ASGI client buffers each response in full before the client sees any of it. 
Request, -response, and notification flows are unaffected -- notifications are written to the request's -SSE stream before the response and arrive in order -- but a server-initiated request nested -inside a still-open call would deadlock, so that scenario is deferred to the real-socket -transport tests (see the `transport:streamable-http:server-to-client` requirement). +"""Tests for the interaction model over the streamable HTTP transport, entirely in process. + +The Starlette app a real deployment would hand to uvicorn is driven through the suite's +streaming ASGI bridge instead: the full HTTP framing layer runs (session ids, SSE and JSON +response encoding, stateful and stateless session management) with no sockets, threads, or +subprocesses, so these tests are as deterministic as the in-memory ones. Because the bridge +streams each response as the server produces it, full-duplex behaviour works too: a +server-initiated request nested inside a still-open call round-trips while that call's SSE +response remains open. 
""" from collections.abc import AsyncIterator from contextlib import asynccontextmanager +import anyio import httpx import pytest from inline_snapshot import snapshot from pydantic import BaseModel +from mcp.client import ClientRequestContext from mcp.client.client import Client from mcp.client.streamable_http import streamable_http_client +from mcp.server.elicitation import AcceptedElicitation from mcp.server.mcpserver import Context, MCPServer from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ( CallToolResult, + ElicitRequestParams, + ElicitResult, LoggingMessageNotification, LoggingMessageNotificationParams, + ResourceUpdatedNotification, + ResourceUpdatedNotificationParams, TextContent, ) from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport pytestmark = pytest.mark.anyio @@ -63,9 +68,11 @@ class Confirmation(BaseModel): @mcp.tool() async def ask(ctx: Context) -> str: - """Attempt a server-initiated elicitation.""" - await ctx.elicit("Proceed?", Confirmation) - raise NotImplementedError # only called in stateless mode, where the elicit cannot succeed + """Elicit a confirmation from the client and report the outcome.""" + answer = await ctx.elicit("Proceed?", Confirmation) + # In stateless mode the elicit raises before this point: there is no session to call back through. 
+ assert isinstance(answer, AcceptedElicitation) + return f"confirmed={answer.data.confirmed}" @mcp.tool() async def announce(ctx: Context) -> str: @@ -91,7 +98,7 @@ async def _connected( transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) async with mcp.session_manager.run(): - async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://127.0.0.1:8000") as http: + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) async with Client(transport) as client: yield client @@ -165,7 +172,7 @@ async def collect_progress(progress: float, total: float | None, message: str | transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) async with server.session_manager.run(): - async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://127.0.0.1:8000") as http: + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) async with Client(transport, logging_callback=collect_log) as client: result = await client.call_tool("narrate", {}, progress_callback=collect_progress) @@ -191,34 +198,76 @@ async def test_stateless_streamable_http_rejects_server_initiated_requests() -> @requirement("transport:streamable-http:unrelated-messages") -async def test_unrelated_server_messages_are_not_delivered_without_a_listening_stream() -> None: - """A server message with no related request goes to the standalone GET stream, not the call's stream. - - The client never opens the standalone stream, so the resource-updated notification is silently - dropped. 
The log notification sent by the same tool IS related to the call and does arrive, - proving the collector works and making the absence of the unrelated one meaningful. This is - the transport behaviour that makes `related_request_id` matter. +async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: + """A server message with no related request reaches the client through the standalone GET stream. + + The log notification is related to the tool call and travels on that call's own SSE stream; + the resource-updated notification is not related to any request, so the only way it can reach + the client is the standalone stream the client opens after initialization. Delivery order + across the two streams is not guaranteed, so the unrelated message is awaited rather than + assumed to beat the tool result. """ received: list[IncomingMessage] = [] + resource_update_seen = anyio.Event() async def collect(message: IncomingMessage) -> None: received.append(message) + if isinstance(message, ResourceUpdatedNotification): + resource_update_seen.set() server = _smoke_server() app = server.streamable_http_app( transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) ) async with server.session_manager.run(): - async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://127.0.0.1:8000") as http: + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) async with Client(transport, message_handler=collect) as client: result = await client.call_tool("announce", {}) + with anyio.fail_after(5): + await resource_update_seen.wait() assert result == snapshot( CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) ) - # Only the related log notification arrives; the resource-updated notification went to the - # standalone stream 
nobody is reading. - assert received == snapshot( + # The related log notification rides the call's stream; the unrelated resource-updated + # notification rides the standalone stream. Both arrive, nothing else does. + assert [message for message in received if isinstance(message, LoggingMessageNotification)] == snapshot( [LoggingMessageNotification(params=LoggingMessageNotificationParams(level="info", data="about to announce"))] ) + assert [message for message in received if isinstance(message, ResourceUpdatedNotification)] == snapshot( + [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] + ) + assert len(received) == 2 + + +@requirement("transport:streamable-http:server-to-client") +async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None: + """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP. + + The elicitation request travels on the still-open SSE response of the tool call that triggered + it, and the client's answer arrives as a separate POST -- the full-duplex exchange the + streamable HTTP transport exists to provide. + """ + asked: list[ElicitRequestParams] = [] + + async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult: + asked.append(params) + return ElicitResult(action="accept", content={"confirmed": True}) + + server = _smoke_server() + app = server.streamable_http_app( + transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: + transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) + async with Client(transport, elicitation_callback=answer) as client: + # Bounded because a harness regression here historically meant deadlock, not failure. 
+ with anyio.fail_after(5): + result = await client.call_tool("ask", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) + ) + assert [params.message for params in asked] == snapshot(["Proceed?"]) From d64f5259df771089ccc17074d002e133a1a687df Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 17:14:55 +0000 Subject: [PATCH 14/34] test: run the interaction suite over both in-memory and streamable HTTP transports --- src/mcp/server/streamable_http.py | 4 +- tests/interaction/README.md | 16 +- tests/interaction/_connect.py | 117 ++++++++++++++ tests/interaction/conftest.py | 22 +++ .../interaction/lowlevel/test_cancellation.py | 14 +- tests/interaction/lowlevel/test_completion.py | 18 +-- .../interaction/lowlevel/test_elicitation.py | 38 ++--- tests/interaction/lowlevel/test_initialize.py | 32 ++-- .../interaction/lowlevel/test_list_changed.py | 14 +- tests/interaction/lowlevel/test_logging.py | 14 +- tests/interaction/lowlevel/test_meta.py | 10 +- tests/interaction/lowlevel/test_pagination.py | 22 +-- tests/interaction/lowlevel/test_ping.py | 10 +- tests/interaction/lowlevel/test_progress.py | 18 +-- tests/interaction/lowlevel/test_prompts.py | 18 +-- tests/interaction/lowlevel/test_resources.py | 34 ++-- tests/interaction/lowlevel/test_roots.py | 22 +-- tests/interaction/lowlevel/test_sampling.py | 34 ++-- tests/interaction/lowlevel/test_tools.py | 42 ++--- tests/interaction/mcpserver/test_context.py | 26 +-- tests/interaction/mcpserver/test_prompts.py | 18 +-- tests/interaction/mcpserver/test_resources.py | 18 +-- tests/interaction/mcpserver/test_tools.py | 34 ++-- .../transports/test_streamable_http.py | 151 +++--------------- 24 files changed, 395 insertions(+), 351 deletions(-) create mode 100644 tests/interaction/_connect.py create mode 100644 tests/interaction/conftest.py diff --git 
a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index fbe3bd9676..a4cb5af03a 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -374,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # pragma: lax no cover # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -635,7 +635,7 @@ async def sse_writer(): # pragma: lax no cover finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", diff --git a/tests/interaction/README.md b/tests/interaction/README.md index 50bde98e41..df8f331596 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -37,10 +37,12 @@ The whole suite is in-memory and event-driven; it runs in about a second. 
tests/interaction/ _requirements.py the requirements manifest (see below) _helpers.py shared type aliases + the wire-recording transport + _connect.py the transport-parametrized connection factories + conftest.py the connect fixture (the transport matrix) test_coverage.py enforces the manifest ↔ test contract lowlevel/ one file per feature area, against the low-level Server mcpserver/ the same feature areas in MCPServer's natural idiom - transports/ a smoke subset over the streamable HTTP framing + transports/ behaviour specific to one transport (modes, streams, framing) ``` The two server APIs produce genuinely different wire output for the same conceptual feature @@ -48,6 +50,18 @@ The two server APIs produce genuinely different wire output for the same concept content), so they get parallel directories with mirrored file names rather than one parametrized test body — each directory pins its flavour's true output exactly. +### The transport matrix + +Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` +directly, and therefore run once per transport: over the in-memory transport and over the +server's real streamable HTTP app driven in process through the streaming bridge. A test connects +the same way in either case — `async with connect(server, ...) as client:` — and asserts the same +output, because the transport is not supposed to change observable behaviour. Tests that are tied +to one transport do not use the fixture: the wire-recording tests (their seam is the in-memory +stream pair), the bare-`ClientSession` lifecycle tests, the real-clock timeout tests (the timeout +machinery is transport-independent and must not race transport latency), and everything under +`transports/`, which pins behaviour only observable on that transport. 
+ ## The requirements manifest `_requirements.py` maps every behaviour the suite covers to the reason it must hold: diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py new file mode 100644 index 0000000000..a091e18d9a --- /dev/null +++ b/tests/interaction/_connect.py @@ -0,0 +1,117 @@ +"""Transport-parametrized connection factories for the interaction suite. + +The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body +runs over the in-memory transport and over streamable HTTP without naming either: the factory is a +drop-in replacement for constructing `Client(server, ...)` and yields the connected client. The +streamable HTTP factory drives the server's real Starlette app through the in-process streaming +bridge, so the full HTTP framing layer (session ids, SSE encoding, session management) runs with +no sockets, threads, or subprocesses. +""" + +from collections.abc import AsyncIterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Protocol + +import httpx + +from mcp.client.client import Client +from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.mcpserver import MCPServer +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import Implementation +from tests.interaction.transports._bridge import StreamingASGITransport + +# The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. +_BASE_URL = "http://127.0.0.1:8000" + + +class Connect(Protocol): + """Connect a Client to a server over the transport selected by the `connect` fixture. + + Accepts the same keyword arguments as `Client` and yields the connected client. 
+ """ + + def __call__( + self, + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, + ) -> AbstractAsyncContextManager[Client]: ... + + +@asynccontextmanager +async def connect_in_memory( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server over the in-memory transport.""" + async with Client( + server, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +@asynccontextmanager +async def connect_over_streamable_http( + server: Server | MCPServer, + *, + stateless_http: bool = False, + json_response: bool = False, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's streamable HTTP app, entirely in process. 
+ + With the defaults this is the matrix leg (stateful sessions, SSE responses); the + transport-specific tests pass `stateless_http` or `json_response` to select the other + server modes. + """ + # DNS-rebinding protection validates Host/Origin headers against a real network attack that + # cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation + # branch (deliberately uncovered in src) into coverage. + app = server.streamable_http_app( + stateless_http=stateless_http, + json_response=json_response, + transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http_client: + transport = streamable_http_client(f"{_BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py new file mode 100644 index 0000000000..f8960bd13b --- /dev/null +++ b/tests/interaction/conftest.py @@ -0,0 +1,22 @@ +"""Shared fixtures for the interaction suite.""" + +import pytest + +from tests.interaction._connect import Connect, connect_in_memory, connect_over_streamable_http + +_FACTORIES: dict[str, Connect] = { + "in-memory": connect_in_memory, + "streamable-http": connect_over_streamable_http, +} + + +@pytest.fixture(params=sorted(_FACTORIES)) +def connect(request: pytest.FixtureRequest) -> Connect: + """The transport-parametrized connection factory: a test using it runs once per transport. 
+ + Tests that are tied to one transport (the wire-recording tests, the bare-ClientSession tests, + the transport-specific tests under transports/) do not use this fixture and connect directly. + """ + transport_name = request.param + assert isinstance(transport_name, str) + return _FACTORIES[transport_name] diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index bbf984fc4a..eb07ef9404 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -11,9 +11,9 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import CallToolResult, ErrorData, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -21,7 +21,7 @@ @requirement("protocol:cancel:in-flight") @requirement("protocol:cancel:handler-abort-propagates") -async def test_cancellation_stops_in_flight_handler() -> None: +async def test_cancellation_stops_in_flight_handler(connect: Connect) -> None: """Cancelling an in-flight request interrupts its handler and fails the pending call. 
The server answers the cancelled request with an error response (the spec says it should @@ -47,7 +47,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: @@ -70,7 +70,7 @@ async def call_and_capture_error() -> None: @requirement("protocol:cancel:server-survives") -async def test_session_serves_requests_after_cancellation() -> None: +async def test_session_serves_requests_after_cancellation(connect: Connect) -> None: """A request cancelled mid-flight does not poison the session: the next request succeeds.""" started = anyio.Event() request_ids: list[types.RequestId] = [] @@ -96,7 +96,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: @@ -116,7 +116,7 @@ async def call_and_swallow_cancellation_error() -> None: @requirement("protocol:cancel:unknown-id-ignored") -async def test_cancellation_for_unknown_request_is_ignored() -> None: +async def test_cancellation_for_unknown_request_is_ignored(connect: Connect) -> None: """A cancellation referencing a request id that is not in flight is ignored without error.""" async def list_tools( @@ -130,7 +130,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("calm", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: await client.session.send_notification( types.CancelledNotification(params=types.CancelledNotificationParams(request_id=9999)) ) diff --git a/tests/interaction/lowlevel/test_completion.py 
b/tests/interaction/lowlevel/test_completion.py index ea2529169f..e036d48c3c 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -4,7 +4,6 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( METHOD_NOT_FOUND, @@ -14,6 +13,7 @@ PromptReference, ResourceTemplateReference, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -21,7 +21,7 @@ @requirement("completion:prompt-arg") @requirement("completion:result-shape") -async def test_complete_prompt_argument() -> None: +async def test_complete_prompt_argument(connect: Connect) -> None: """Completing a prompt argument delivers the ref, argument name, and current value to the handler. The returned values are filtered by the argument's value, proving the value reached the handler. 
@@ -37,7 +37,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar server = Server("completer", on_completion=completion) - async with Client(server) as client: + async with connect(server) as client: result = await client.complete( PromptReference(name="code_review"), argument={"name": "language", "value": "py"} ) @@ -48,7 +48,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar @requirement("completion:resource-template-arg") -async def test_complete_resource_template_variable() -> None: +async def test_complete_resource_template_variable(connect: Connect) -> None: """Completing a URI template variable delivers the template URI and variable name to the handler.""" async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: @@ -59,7 +59,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar server = Server("completer", on_completion=completion) - async with Client(server) as client: + async with connect(server) as client: result = await client.complete( ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), argument={"name": "owner", "value": "model"}, @@ -69,7 +69,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar @requirement("completion:context-arguments") -async def test_complete_receives_context_arguments() -> None: +async def test_complete_receives_context_arguments(connect: Connect) -> None: """Previously-resolved arguments passed as completion context reach the handler. The returned value is derived from the context, proving it arrived. 
@@ -83,7 +83,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar server = Server("completer", on_completion=completion) - async with Client(server) as client: + async with connect(server) as client: result = await client.complete( ResourceTemplateReference(uri="github://repos/{owner}/{repo}"), argument={"name": "repo", "value": ""}, @@ -95,11 +95,11 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar @requirement("completion:complete:not-supported") @requirement("protocol:error:method-not-found") -async def test_complete_without_handler_is_method_not_found() -> None: +async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None: """A server with no completion handler advertises no completions capability and rejects the request.""" server = Server("incomplete") - async with Client(server) as client: + async with connect(server) as client: assert client.initialize_result.capabilities.completions is None with pytest.raises(MCPError) as exc_info: diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index d46728be2e..d27613dd36 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -5,7 +5,6 @@ from mcp import MCPError, UrlElicitationRequiredError, types from mcp.client import ClientRequestContext -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( CallToolResult, @@ -17,6 +16,7 @@ ErrorData, TextContent, ) +from tests.interaction._connect import Connect from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement @@ -35,7 +35,7 @@ @requirement("elicitation:form:action:accept") @requirement("elicitation:form:basic") @requirement("tools:call:elicitation-roundtrip") -async def test_elicit_form_accepted_content_returns_to_handler() -> None: +async def 
test_elicit_form_accepted_content_returns_to_handler(connect: Connect) -> None: """An accepted form elicitation returns the user's content to the requesting handler. The tool reports the action as text and the received content as structured content, proving @@ -61,7 +61,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest received.append(params) return ElicitResult(action="accept", content={"username": "ada", "newsletter": True}) - async with Client(server, elicitation_callback=answer_form) as client: + async with connect(server, elicitation_callback=answer_form) as client: result = await client.call_tool("signup", {}) assert received == snapshot( @@ -89,7 +89,7 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest @requirement("elicitation:form:action:decline") -async def test_elicit_form_decline_returns_no_content() -> None: +async def test_elicit_form_decline_returns_no_content(connect: Connect) -> None: """A declined form elicitation returns the decline action to the handler with no content.""" async def list_tools( @@ -109,14 +109,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") - async with Client(server, elicitation_callback=answer_form) as client: + async with connect(server, elicitation_callback=answer_form) as client: result = await client.call_tool("confirm", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) @requirement("elicitation:form:action:cancel") -async def test_elicit_form_cancel_returns_no_content() -> None: +async def test_elicit_form_cancel_returns_no_content(connect: Connect) -> None: """A cancelled form elicitation returns the cancel action to the handler with no content.""" async def list_tools( @@ -136,14 +136,14 @@ async def call_tool(ctx: 
ServerRequestContext, params: types.CallToolRequestPara async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: return ElicitResult(action="cancel") - async with Client(server, elicitation_callback=answer_form) as client: + async with connect(server, elicitation_callback=answer_form) as client: result = await client.call_tool("confirm", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) @requirement("elicitation:form:not-supported") -async def test_elicit_form_without_callback_is_error() -> None: +async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: """Eliciting from a client that configured no elicitation callback fails with an error. The client's default callback answers with an Invalid request error, which the server-side @@ -168,7 +168,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("asker", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("ask", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Elicitation not supported")])) @@ -176,7 +176,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("elicitation:url:action:accept-no-content") @requirement("elicitation:url:basic") -async def test_elicit_url_delivers_url_and_returns_accept_without_content() -> None: +async def test_elicit_url_delivers_url_and_returns_accept_without_content(connect: Connect) -> None: """A URL elicitation delivers the message, URL, and elicitation id to the client; accepting it returns the action with no content. 
@@ -205,7 +205,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP received.append(params) return ElicitResult(action="accept") - async with Client(server, elicitation_callback=answer_url) as client: + async with connect(server, elicitation_callback=answer_url) as client: result = await client.call_tool("authorize", {}) assert received == snapshot( @@ -222,7 +222,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP @requirement("elicitation:url:decline") -async def test_elicit_url_decline_returns_no_content() -> None: +async def test_elicit_url_decline_returns_no_content(connect: Connect) -> None: """A declined URL elicitation returns the decline action to the handler with no content.""" async def list_tools( @@ -244,14 +244,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: return ElicitResult(action="decline") - async with Client(server, elicitation_callback=answer_url) as client: + async with connect(server, elicitation_callback=answer_url) as client: result = await client.call_tool("authorize", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="decline content=None")])) @requirement("elicitation:url:cancel") -async def test_elicit_url_cancel_returns_no_content() -> None: +async def test_elicit_url_cancel_returns_no_content(connect: Connect) -> None: """A cancelled URL elicitation returns the cancel action to the handler with no content.""" async def list_tools( @@ -273,14 +273,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def answer_url(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: return ElicitResult(action="cancel") - async with Client(server, elicitation_callback=answer_url) as client: + async with connect(server, elicitation_callback=answer_url) as 
client: result = await client.call_tool("authorize", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="cancel content=None")])) @requirement("elicitation:url:complete-notification") -async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client() -> None: +async def test_elicitation_complete_notification_carries_the_elicited_id_back_to_the_client(connect: Connect) -> None: """After a URL elicitation finishes, the server announces it with a notification carrying the same id. The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user @@ -319,7 +319,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP elicited_ids.append(params.elicitation_id) return ElicitResult(action="accept") - async with Client(server, message_handler=collect, elicitation_callback=answer_url) as client: + async with connect(server, message_handler=collect, elicitation_callback=answer_url) as client: await client.call_tool("link_account", {}) # The completion notification refers to the same elicitation the client accepted. @@ -330,7 +330,7 @@ async def answer_url(context: ClientRequestContext, params: types.ElicitRequestP @requirement("elicitation:url:required-error") -async def test_url_elicitation_required_error_carries_pending_elicitations() -> None: +async def test_url_elicitation_required_error_carries_pending_elicitations(connect: Connect) -> None: """A request that cannot proceed until a URL interaction completes is rejected with error -32042. 
This is the non-interactive alternative to elicit_url: instead of asking and waiting, the @@ -353,7 +353,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("authorizer", on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("read_files", {}) diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 16b943f960..32da2f3338 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -14,7 +14,6 @@ from mcp import MCPError, types from mcp.client import ClientRequestContext, ClientSession from mcp.client._memory import InMemoryTransport -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage @@ -41,6 +40,7 @@ TextContent, ToolsCapability, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -48,7 +48,7 @@ @requirement("lifecycle:initialize:basic") @requirement("lifecycle:initialize:server-info") -async def test_initialize_returns_server_info() -> None: +async def test_initialize_returns_server_info(connect: Connect) -> None: """Every identity field the server declares is returned to the client in server_info.""" server = Server( "greeter", @@ -59,7 +59,7 @@ async def test_initialize_returns_server_info() -> None: icons=[Icon(src="https://example.com/icon.png", mime_type="image/png", sizes=["48x48"])], ) - async with Client(server) as client: + async with connect(server) as client: server_info = client.initialize_result.server_info assert server_info == snapshot( @@ -75,12 +75,12 @@ async def test_initialize_returns_server_info() -> None: 
@requirement("lifecycle:initialize:instructions") -async def test_initialize_returns_instructions() -> None: +async def test_initialize_returns_instructions(connect: Connect) -> None: """Instructions are returned when the server declares them and omitted when it does not.""" - async with Client(Server("guided", instructions="Call the add tool.")) as client: + async with connect(Server("guided", instructions="Call the add tool.")) as client: assert client.initialize_result.instructions == snapshot("Call the add tool.") - async with Client(Server("unguided")) as client: + async with connect(Server("unguided")) as client: assert client.initialize_result.instructions is None @@ -89,7 +89,7 @@ async def test_initialize_returns_instructions() -> None: @requirement("resources:capability:declared") @requirement("prompts:capability:declared") @requirement("completion:capability:declared") -async def test_initialize_capabilities_reflect_registered_handlers() -> None: +async def test_initialize_capabilities_reflect_registered_handlers(connect: Connect) -> None: """Each feature area with a registered handler is advertised as a capability. 
The in-memory transport connects with default initialization options, so the @@ -136,7 +136,7 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar on_completion=completion, ) - async with Client(server) as client: + async with connect(server) as client: capabilities = client.initialize_result.capabilities assert capabilities == snapshot( @@ -152,16 +152,16 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar @requirement("lifecycle:initialize:capabilities:minimal") -async def test_initialize_minimal_server_advertises_no_capabilities() -> None: +async def test_initialize_minimal_server_advertises_no_capabilities(connect: Connect) -> None: """A server with no feature handlers advertises no feature capabilities.""" - async with Client(Server("bare")) as client: + async with connect(Server("bare")) as client: capabilities = client.initialize_result.capabilities assert capabilities == snapshot(ServerCapabilities(experimental={})) @requirement("lifecycle:initialize:client-info") -async def test_initialize_server_sees_client_info() -> None: +async def test_initialize_server_sees_client_info(connect: Connect) -> None: """The client identity supplied to Client is visible to server handlers after initialization.""" async def list_tools( @@ -178,16 +178,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara return CallToolResult(content=[TextContent(text=f"{client_info.name} {client_info.version}")]) server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) - client = Client(server, client_info=Implementation(name="acme-agent", version="9.9.9")) - - async with client: + async with connect(server, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: result = await client.call_tool("whoami", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="acme-agent 9.9.9")])) @requirement("lifecycle:initialize:client-capabilities") 
-async def test_initialize_server_sees_client_capabilities() -> None: +async def test_initialize_server_sees_client_capabilities(connect: Connect) -> None: """The client capabilities visible to the server reflect which callbacks the client configured.""" async def list_tools( @@ -219,11 +217,11 @@ async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("abilities", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="none")])) - async with Client(server, list_roots_callback=list_roots) as client: + async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("abilities", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="roots(list_changed=True)")])) diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index e06c6f33f6..eb20db207b 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -10,7 +10,6 @@ from inline_snapshot import snapshot from mcp import types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( CallToolResult, @@ -19,6 +18,7 @@ TextContent, ToolListChangedNotification, ) +from tests.interaction._connect import Connect from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement @@ -26,7 +26,7 @@ @requirement("tools:list-changed") -async def test_tool_list_changed_notification() -> None: +async def test_tool_list_changed_notification(connect: Connect) -> None: """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] @@ -45,14 +45,14 @@ async def 
call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server, message_handler=collect) as client: + async with connect(server, message_handler=collect) as client: await client.call_tool("install", {}) assert received == snapshot([ToolListChangedNotification()]) @requirement("resources:list-changed") -async def test_resource_list_changed_notification() -> None: +async def test_resource_list_changed_notification(connect: Connect) -> None: """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] @@ -71,14 +71,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server, message_handler=collect) as client: + async with connect(server, message_handler=collect) as client: await client.call_tool("mount", {}) assert received == snapshot([ResourceListChangedNotification()]) @requirement("prompts:list-changed") -async def test_prompt_list_changed_notification() -> None: +async def test_prompt_list_changed_notification(connect: Connect) -> None: """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] @@ -97,7 +97,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server, message_handler=collect) as client: + async with connect(server, message_handler=collect) as client: await client.call_tool("learn", {}) assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index 9f9110a3cf..792334ecd2 100644 --- 
a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -13,9 +13,9 @@ from inline_snapshot import snapshot from mcp import types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import CallToolResult, EmptyResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -33,7 +33,7 @@ @requirement("logging:set-level") -async def test_set_logging_level_reaches_handler() -> None: +async def test_set_logging_level_reaches_handler(connect: Connect) -> None: """The level requested by the client is delivered to the server's handler verbatim.""" async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: @@ -42,7 +42,7 @@ async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelReq server = Server("logger", on_set_logging_level=set_logging_level) - async with Client(server) as client: + async with connect(server) as client: result = await client.set_logging_level("warning") assert result == snapshot(EmptyResult()) @@ -50,7 +50,7 @@ async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelReq @requirement("logging:message:fields") @requirement("tools:call:logging-mid-execution") -async def test_log_messages_reach_logging_callback_in_order() -> None: +async def test_log_messages_reach_logging_callback_in_order(connect: Connect) -> None: """Log messages sent during a tool call arrive at the logging callback, in order, before the call returns. 
The two messages pin the full notification shape: severity, optional logger name, and both @@ -74,7 +74,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server, logging_callback=collect) as client: + async with connect(server, logging_callback=collect) as client: result = await client.call_tool("chatty", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="done")])) @@ -87,7 +87,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("logging:message:all-levels") -async def test_log_messages_at_every_severity_level() -> None: +async def test_log_messages_at_every_severity_level(connect: Connect) -> None: """Each of the eight RFC 5424 severity levels is deliverable as a log message notification.""" received: list[LoggingMessageNotificationParams] = [] @@ -107,7 +107,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server, logging_callback=collect) as client: + async with connect(server, logging_callback=collect) as client: await client.call_tool("siren", {}) assert [params.level for params in received] == list(ALL_LEVELS) diff --git a/tests/interaction/lowlevel/test_meta.py b/tests/interaction/lowlevel/test_meta.py index a63acbfa5c..a9e4f994d8 100644 --- a/tests/interaction/lowlevel/test_meta.py +++ b/tests/interaction/lowlevel/test_meta.py @@ -8,16 +8,16 @@ import pytest from mcp import types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import CallToolResult, RequestParamsMeta, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @requirement("meta:request-to-handler") -async def 
test_request_meta_reaches_handler() -> None: +async def test_request_meta_reaches_handler(connect: Connect) -> None: """The _meta object the client attaches to a request arrives at the tool handler unchanged.""" request_meta: RequestParamsMeta = {"example.com/trace": "abc-123"} observed_metas: list[dict[str, object]] = [] @@ -35,14 +35,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: await client.call_tool("traced", {}, meta=request_meta) assert observed_metas == [dict(request_meta)] @requirement("meta:result-to-client") -async def test_result_meta_reaches_client() -> None: +async def test_result_meta_reaches_client(connect: Connect) -> None: """The _meta object a handler attaches to its result is delivered to the client unchanged.""" result_meta = {"example.com/cost": 3} @@ -57,7 +57,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("observability", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("metered", {}) assert result == CallToolResult(content=[TextContent(text="done")], _meta=result_meta) diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 3304450a6f..1b6ac3e66a 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -9,7 +9,6 @@ from inline_snapshot import snapshot from mcp import types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( ListPromptsResult, @@ -21,13 +20,14 @@ ResourceTemplate, Tool, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = 
pytest.mark.anyio @requirement("tools:list:pagination") -async def test_next_cursor_round_trips_through_the_client() -> None: +async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> None: """The next_cursor a list handler returns reaches the client, and the cursor the client sends back on the following call reaches the handler verbatim. """ @@ -45,7 +45,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa server = Server("paginated", on_list_tools=list_tools) - async with Client(server) as client: + async with connect(server) as client: first_page = await client.list_tools() second_page = await client.list_tools(cursor="page-2") @@ -58,7 +58,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa @requirement("pagination:exhaustion") @requirement("tools:list:pagination") -async def test_paginating_until_next_cursor_is_absent_yields_every_page() -> None: +async def test_paginating_until_next_cursor_is_absent_yields_every_page(connect: Connect) -> None: """Following next_cursor until it is absent visits every page exactly once, in order.""" pages: dict[str | None, tuple[str, str | None]] = { None: ("alpha", "page-2"), @@ -76,7 +76,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa collected: list[str] = [] cursor: str | None = None requests_made = 0 - async with Client(server) as client: + async with connect(server) as client: while True: result = await client.list_tools(cursor=cursor) requests_made += 1 @@ -91,7 +91,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa @requirement("resources:list:pagination") -async def test_resources_list_supports_cursor_pagination() -> None: +async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: """resources/list round-trips the cursor like every other list operation.""" seen_cursors: list[str | None] = [] @@ -106,7 +106,7 @@ async def 
list_resources( server = Server("paginated", on_list_resources=list_resources) - async with Client(server) as client: + async with connect(server) as client: first_page = await client.list_resources() second_page = await client.list_resources(cursor="page-2") @@ -118,7 +118,7 @@ async def list_resources( @requirement("resources:templates:pagination") -async def test_resource_templates_list_supports_cursor_pagination() -> None: +async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: """resources/templates/list round-trips the cursor like every other list operation.""" seen_cursors: list[str | None] = [] @@ -138,7 +138,7 @@ async def list_resource_templates( server = Server("paginated", on_list_resource_templates=list_resource_templates) - async with Client(server) as client: + async with connect(server) as client: first_page = await client.list_resource_templates() second_page = await client.list_resource_templates(cursor="page-2") @@ -150,7 +150,7 @@ async def list_resource_templates( @requirement("prompts:list:pagination") -async def test_prompts_list_supports_cursor_pagination() -> None: +async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None: """prompts/list round-trips the cursor like every other list operation.""" seen_cursors: list[str | None] = [] @@ -163,7 +163,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest server = Server("paginated", on_list_prompts=list_prompts) - async with Client(server) as client: + async with connect(server) as client: first_page = await client.list_prompts() second_page = await client.list_prompts(cursor="page-2") diff --git a/tests/interaction/lowlevel/test_ping.py b/tests/interaction/lowlevel/test_ping.py index 6a82601d48..797e20dc35 100644 --- a/tests/interaction/lowlevel/test_ping.py +++ b/tests/interaction/lowlevel/test_ping.py @@ -4,9 +4,9 @@ from inline_snapshot import snapshot from mcp import types -from mcp.client.client 
import Client from mcp.server import Server, ServerRequestContext from mcp.types import CallToolResult, EmptyResult, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -14,11 +14,11 @@ @requirement("lifecycle:ping") @requirement("ping:client-to-server") -async def test_client_ping_returns_empty_result() -> None: +async def test_client_ping_returns_empty_result(connect: Connect) -> None: """A client ping is answered with an empty result, even by a server with no handlers.""" server = Server("silent") - async with Client(server) as client: + async with connect(server) as client: result = await client.send_ping() assert result == snapshot(EmptyResult()) @@ -26,7 +26,7 @@ async def test_client_ping_returns_empty_result() -> None: @requirement("lifecycle:ping") @requirement("ping:server-to-client") -async def test_server_ping_returns_empty_result() -> None: +async def test_server_ping_returns_empty_result(connect: Connect) -> None: """A server-initiated ping sent while a request is in flight is answered by the client. 
The tool returns the type of the ping response, proving the round trip completed inside @@ -47,7 +47,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("pinger", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("ping_back", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="EmptyResult")])) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index f39737a27f..56eae40d7d 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -12,9 +12,9 @@ from inline_snapshot import snapshot from mcp import types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import CallToolResult, ProgressNotificationParams, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -22,7 +22,7 @@ @requirement("protocol:progress:callback") @requirement("tools:call:progress") -async def test_progress_during_tool_call_reaches_callback_in_order() -> None: +async def test_progress_during_tool_call_reaches_callback_in_order(connect: Connect) -> None: """Progress notifications emitted by a tool handler reach the caller's progress callback in order.""" received: list[tuple[float, float | None, str | None]] = [] @@ -46,7 +46,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("download", {}, progress_callback=collect) assert result == snapshot(CallToolResult(content=[TextContent(text="downloaded")])) @@ -54,7 +54,7 @@ async def call_tool(ctx: 
ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:progress:token-injected") -async def test_progress_token_visible_to_handler() -> None: +async def test_progress_token_visible_to_handler(connect: Connect) -> None: """Supplying a progress callback attaches a progress token that the handler can read from the request meta.""" async def list_tools( @@ -73,7 +73,7 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N """A progress callback that is never invoked; the tool only inspects the token.""" raise NotImplementedError - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("inspect", {}, progress_callback=ignore) # The token is the request id of the tools/call request itself (initialize is request 0). @@ -81,7 +81,7 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N @requirement("protocol:progress:no-token") -async def test_no_progress_callback_means_no_token() -> None: +async def test_no_progress_callback_means_no_token(connect: Connect) -> None: """Without a progress callback the request carries no progress token. 
The low-level API has no way to report request-scoped progress without a token, so a handler @@ -100,14 +100,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("inspect", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="None")])) @requirement("protocol:progress:client-to-server") -async def test_client_progress_notification_reaches_server_handler() -> None: +async def test_client_progress_notification_reaches_server_handler(connect: Connect) -> None: """A progress notification sent by the client is delivered to the server's progress handler.""" received: list[ProgressNotificationParams] = [] delivered = anyio.Event() @@ -118,7 +118,7 @@ async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationPar server = Server("observer", on_progress=on_progress) - async with Client(server) as client: + async with connect(server) as client: await client.send_progress_notification("upload-1", 0.5, total=1.0, message="halfway") with anyio.fail_after(5): await delivered.wait() diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py index 52ef3a85d4..b09f765755 100644 --- a/tests/interaction/lowlevel/test_prompts.py +++ b/tests/interaction/lowlevel/test_prompts.py @@ -4,7 +4,6 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( INVALID_PARAMS, @@ -17,13 +16,14 @@ PromptMessage, TextContent, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @requirement("prompts:list:basic") -async def test_list_prompts_returns_registered_prompts() -> None: +async def 
test_list_prompts_returns_registered_prompts(connect: Connect) -> None: """The prompts returned by the handler reach the client with their argument declarations intact.""" async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: @@ -44,7 +44,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest server = Server("prompter", on_list_prompts=list_prompts) - async with Client(server) as client: + async with connect(server) as client: result = await client.list_prompts() assert result == snapshot( @@ -66,7 +66,7 @@ async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequest @requirement("prompts:get:with-args") -async def test_get_prompt_substitutes_arguments() -> None: +async def test_get_prompt_substitutes_arguments(connect: Connect) -> None: """Arguments supplied by the client reach the prompt handler; the templated message comes back.""" async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: @@ -79,7 +79,7 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa server = Server("prompter", on_get_prompt=get_prompt) - async with Client(server) as client: + async with connect(server) as client: result = await client.get_prompt("greet", {"name": "Ada"}) assert result == snapshot( @@ -91,7 +91,7 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa @requirement("prompts:get:multi-message") -async def test_get_prompt_multiple_messages_preserve_roles_and_order() -> None: +async def test_get_prompt_multiple_messages_preserve_roles_and_order(connect: Connect) -> None: """A prompt returning a user/assistant conversation reaches the client with roles and order intact.""" async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: @@ -106,7 +106,7 @@ async def get_prompt(ctx: ServerRequestContext, params: 
types.GetPromptRequestPa server = Server("prompter", on_get_prompt=get_prompt) - async with Client(server) as client: + async with connect(server) as client: result = await client.get_prompt("geography_quiz") assert result == snapshot( @@ -121,7 +121,7 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa @requirement("prompts:get:unknown-name") -async def test_get_prompt_unknown_name_is_protocol_error() -> None: +async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. The error's code and message chosen by the handler reach the client verbatim. @@ -132,7 +132,7 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa server = Server("prompter", on_get_prompt=get_prompt) - async with Client(server) as client: + async with connect(server) as client: with pytest.raises(MCPError) as exc_info: await client.get_prompt("nope") diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 5b02797020..1d29a62e07 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -6,7 +6,6 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( Annotations, @@ -25,6 +24,7 @@ TextContent, TextResourceContents, ) +from tests.interaction._connect import Connect from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement @@ -32,7 +32,7 @@ @requirement("resources:list:basic") -async def test_list_resources_returns_registered_resources() -> None: +async def test_list_resources_returns_registered_resources(connect: Connect) -> None: """Listed resources reach the client with their URIs, names, and optional descriptive fields intact.""" 
async def list_resources( @@ -56,7 +56,7 @@ async def list_resources( server = Server("library", on_list_resources=list_resources) - async with Client(server) as client: + async with connect(server) as client: result = await client.list_resources() assert result == snapshot( @@ -79,7 +79,7 @@ async def list_resources( @requirement("resources:read:text") -async def test_read_resource_text() -> None: +async def test_read_resource_text(connect: Connect) -> None: """Reading a text resource returns its contents with the URI, MIME type, and text supplied by the handler.""" async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: @@ -89,7 +89,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq server = Server("library", on_read_resource=read_resource) - async with Client(server) as client: + async with connect(server) as client: result = await client.read_resource("file:///greeting.txt") assert result == snapshot( @@ -100,7 +100,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq @requirement("resources:read:blob") -async def test_read_resource_binary() -> None: +async def test_read_resource_binary(connect: Connect) -> None: """Reading a binary resource returns its contents base64-encoded in the blob field.""" async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: @@ -116,7 +116,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq server = Server("library", on_read_resource=read_resource) - async with Client(server) as client: + async with connect(server) as client: result = await client.read_resource("file:///pixel.png") assert result == snapshot( @@ -127,7 +127,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq @requirement("resources:read:unknown-uri") -async def test_read_resource_unknown_uri_is_protocol_error() -> 
None: +async def test_read_resource_unknown_uri_is_protocol_error(connect: Connect) -> None: """A handler that rejects an unrecognised URI with MCPError produces a JSON-RPC error. The spec reserves -32002 for resource-not-found; the code is the handler's choice and reaches @@ -139,7 +139,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq server = Server("library", on_read_resource=read_resource) - async with Client(server) as client: + async with connect(server) as client: with pytest.raises(MCPError) as exc_info: await client.read_resource("file:///missing.txt") @@ -147,7 +147,7 @@ async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceReq @requirement("resources:templates:list") -async def test_list_resource_templates_returns_registered_templates() -> None: +async def test_list_resource_templates_returns_registered_templates(connect: Connect) -> None: """Listed resource templates reach the client with their URI templates and descriptive fields intact.""" async def list_resource_templates( @@ -169,7 +169,7 @@ async def list_resource_templates( server = Server("library", on_list_resource_templates=list_resource_templates) - async with Client(server) as client: + async with connect(server) as client: result = await client.list_resource_templates() assert result == snapshot( @@ -190,7 +190,7 @@ async def list_resource_templates( @requirement("resources:subscribe") -async def test_subscribe_resource_delivers_uri_to_handler() -> None: +async def test_subscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Subscribing to a resource delivers the URI to the server's subscribe handler and returns an empty result.""" async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: @@ -199,14 +199,14 @@ async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeR server = Server("library", on_subscribe_resource=subscribe_resource) - async 
with Client(server) as client: + async with connect(server) as client: result = await client.subscribe_resource("file:///watched.txt") assert result == snapshot(EmptyResult()) @requirement("resources:unsubscribe") -async def test_unsubscribe_resource_delivers_uri_to_handler() -> None: +async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" async def unsubscribe_resource(ctx: ServerRequestContext, params: types.UnsubscribeRequestParams) -> EmptyResult: @@ -215,14 +215,14 @@ async def unsubscribe_resource(ctx: ServerRequestContext, params: types.Unsubscr server = Server("library", on_unsubscribe_resource=unsubscribe_resource) - async with Client(server) as client: + async with connect(server) as client: result = await client.unsubscribe_resource("file:///watched.txt") assert result == snapshot(EmptyResult()) @requirement("resources:updated-notification") -async def test_resource_updated_notification_reaches_client() -> None: +async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: """A resources/updated notification sent during a tool call reaches the client with the resource URI. 
The collector records every message the handler receives, so the assertion also proves nothing @@ -245,7 +245,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("library", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server, message_handler=collect) as client: + async with connect(server, message_handler=collect) as client: await client.call_tool("touch", {}) assert received == snapshot( diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index 94cd1b9303..577b99819c 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -7,16 +7,16 @@ from mcp import MCPError, types from mcp.client import ClientRequestContext -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import INTERNAL_ERROR, CallToolResult, ErrorData, ListRootsResult, Root, TextContent +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @requirement("roots:list:basic") -async def test_list_roots_round_trip() -> None: +async def test_list_roots_round_trip(connect: Connect) -> None: """A roots/list request from a tool handler is answered by the client's roots callback. The tool reports the URIs and names it received, proving the client's roots reached the server. 
@@ -43,7 +43,7 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: ] ) - async with Client(server, list_roots_callback=list_roots) as client: + async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("show_roots", {}) assert result == snapshot( @@ -54,7 +54,7 @@ async def list_roots(context: ClientRequestContext) -> ListRootsResult: @requirement("roots:list:empty") -async def test_list_roots_empty() -> None: +async def test_list_roots_empty(connect: Connect) -> None: """A client with no roots to offer answers roots/list with an empty list, not an error.""" async def list_tools( @@ -72,14 +72,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def list_roots(context: ClientRequestContext) -> ListRootsResult: return ListRootsResult(roots=[]) - async with Client(server, list_roots_callback=list_roots) as client: + async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("count_roots", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="0")])) @requirement("roots:list:not-supported") -async def test_list_roots_without_callback_is_error() -> None: +async def test_list_roots_without_callback_is_error(connect: Connect) -> None: """A roots/list request to a client with no roots callback fails with an error the handler can observe. 
The client's default callback answers with INVALID_REQUEST rather than leaving the server @@ -101,14 +101,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("rooted", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("show_roots", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: List roots not supported")])) @requirement("roots:list:client-error") -async def test_list_roots_callback_error_surfaces_to_the_handler() -> None: +async def test_list_roots_callback_error_surfaces_to_the_handler(connect: Connect) -> None: """A roots callback that answers with an error fails the roots/list request with that exact error. The callback's code and message reach the requesting handler verbatim as an MCPError. @@ -132,14 +132,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def list_roots(context: ClientRequestContext) -> ErrorData: return ErrorData(code=INTERNAL_ERROR, message="roots provider crashed") - async with Client(server, list_roots_callback=list_roots) as client: + async with connect(server, list_roots_callback=list_roots) as client: result = await client.call_tool("show_roots", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-32603: roots provider crashed")])) @requirement("roots:list-changed") -async def test_roots_list_changed_reaches_server_handler() -> None: +async def test_roots_list_changed_reaches_server_handler(connect: Connect) -> None: """A roots/list_changed notification from the client is delivered to the server's handler. 
Unlike a request, a notification has no response to await: the handler sets an event and the @@ -154,7 +154,7 @@ async def roots_list_changed(ctx: ServerRequestContext, params: types.Notificati server = Server("rooted", on_roots_list_changed=roots_list_changed) - async with Client(server) as client: + async with connect(server) as client: await client.send_roots_list_changed() with anyio.fail_after(5): await delivered.wait() diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 6903f86abb..85eb8c3455 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -10,7 +10,6 @@ from mcp import MCPError, types from mcp.client import ClientRequestContext -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( CallToolResult, @@ -24,6 +23,7 @@ TextContent, ToolResultContent, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -31,7 +31,7 @@ @requirement("sampling:create:basic") @requirement("tools:call:sampling-roundtrip") -async def test_create_message_round_trip() -> None: +async def test_create_message_round_trip(connect: Connect) -> None: """A handler's sampling request is answered by the client callback, and the callback's result (role, content, model, stop reason) is returned to the handler. 
""" @@ -64,7 +64,7 @@ async def sampling_callback( stop_reason="endTurn", ) - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="mock-llm-1/endTurn: Hello to you too.")])) @@ -82,7 +82,7 @@ async def sampling_callback( @requirement("sampling:create:include-context") @requirement("sampling:create:model-preferences") @requirement("sampling:create:system-prompt") -async def test_create_message_params_reach_callback() -> None: +async def test_create_message_params_reach_callback(connect: Connect) -> None: """Every sampling parameter the handler supplies arrives at the client callback unchanged.""" received: list[CreateMessageRequestParams] = [] @@ -118,7 +118,7 @@ async def sampling_callback( received.append(params) return CreateMessageResult(role="assistant", content=TextContent(text="ok"), model="mock-llm-1") - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) @@ -144,7 +144,7 @@ async def sampling_callback( @requirement("sampling:create-message:image-content") -async def test_create_message_request_with_image_content_reaches_callback() -> None: +async def test_create_message_request_with_image_content_reaches_callback(connect: Connect) -> None: """A sampling request message carrying image content arrives at the client callback intact. 
This is the server-to-client direction: the server includes an image in the conversation it @@ -180,7 +180,7 @@ async def sampling_callback( model="mock-vision-1", ) - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("describe_image", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="described image/png (aW1n)")])) @@ -196,7 +196,7 @@ async def sampling_callback( @requirement("sampling:create-message:image-content") -async def test_create_message_result_with_image_content_returns_to_handler() -> None: +async def test_create_message_result_with_image_content_returns_to_handler(connect: Connect) -> None: """A sampling result whose content is an image is returned to the requesting handler intact. This is the client-to-server direction: the model's response is an image rather than text. @@ -228,14 +228,14 @@ async def sampling_callback( model="mock-vision-1", ) - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("draw", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="mock-vision-1: image/png Y2F0")])) @requirement("sampling:error:user-rejected") -async def test_create_message_callback_error() -> None: +async def test_create_message_callback_error(connect: Connect) -> None: """A sampling callback that answers with an error surfaces to the requesting handler as an MCPError. 
The error here is the spec's own example for a user rejecting a sampling request (code -1); @@ -263,14 +263,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> ErrorData: return ErrorData(code=-1, message="User rejected sampling request") - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-1: User rejected sampling request")])) @requirement("sampling:create-message:not-supported") -async def test_create_message_without_callback_is_error() -> None: +async def test_create_message_without_callback_is_error(connect: Connect) -> None: """A sampling request to a client with no sampling callback fails with the SDK's default error.""" async def list_tools( @@ -291,14 +291,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("ask_model", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="-32600: Sampling not supported")])) @requirement("sampling:tools:server-gated-by-capability") -async def test_create_message_with_tools_is_rejected_for_unsupporting_client() -> None: +async def test_create_message_with_tools_is_rejected_for_unsupporting_client(connect: Connect) -> None: """A tool-enabled sampling request to a client that has not declared sampling.tools never leaves the server. 
The client supports plain sampling but cannot declare the tools sub-capability (Client does not @@ -330,7 +330,7 @@ async def sampling_callback( """Declares the plain sampling capability; never invoked because the request is rejected first.""" raise NotImplementedError - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("ask_model", {}) assert result == snapshot( @@ -339,7 +339,7 @@ async def sampling_callback( @requirement("sampling:tool-result:no-mixed-content") -async def test_create_message_with_unbalanced_tool_messages_is_rejected() -> None: +async def test_create_message_with_unbalanced_tool_messages_is_rejected(connect: Connect) -> None: """A sampling request whose messages mix tool results with other content never leaves the server. The message-structure validation runs inside create_message before the request is sent, even @@ -379,7 +379,7 @@ async def sampling_callback( """Declares the sampling capability; never invoked because the request is rejected first.""" raise NotImplementedError - async with Client(server, sampling_callback=sampling_callback) as client: + async with connect(server, sampling_callback=sampling_callback) as client: result = await client.call_tool("summarise_tools", {}) assert result == snapshot( diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index a2ee65109c..49b04db2fa 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -5,7 +5,6 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client.client import Client from mcp.server import Server, ServerRequestContext from mcp.types import ( INVALID_PARAMS, @@ -22,13 +21,14 @@ Tool, ToolAnnotations, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = 
pytest.mark.anyio @requirement("tools:call:content:text") -async def test_call_tool_returns_text_content() -> None: +async def test_call_tool_returns_text_content(connect: Connect) -> None: """Arguments reach the tool handler; its content comes back as the call result.""" async def list_tools( @@ -45,14 +45,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("adder", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("add", {"a": 2, "b": 3}) assert result == snapshot(CallToolResult(content=[TextContent(text="5")])) @requirement("tools:call:is-error") -async def test_call_tool_execution_error_is_returned_as_result() -> None: +async def test_call_tool_execution_error_is_returned_as_result(connect: Connect) -> None: """A tool reporting its own failure with is_error=True reaches the client as a result, not an exception. Tool execution errors are part of the result so the caller (typically a model) can see @@ -65,7 +65,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("errors", on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("flux", {}) assert result == snapshot( @@ -74,7 +74,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("tools:call:unknown-name") -async def test_call_tool_unknown_tool_is_protocol_error() -> None: +async def test_call_tool_unknown_tool_is_protocol_error(connect: Connect) -> None: """A handler that rejects an unrecognised tool name with MCPError produces a JSON-RPC error. The error's code, message, and data chosen by the handler reach the client verbatim. 
@@ -85,7 +85,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("errors", on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("nope", {}) @@ -95,7 +95,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:error:internal-error") -async def test_call_tool_uncaught_exception_becomes_error_response() -> None: +async def test_call_tool_uncaught_exception_becomes_error_response(connect: Connect) -> None: """An uncaught exception in the tool handler surfaces to the client as a JSON-RPC error. The low-level server reports it with code 0 and the exception text as the message; see the @@ -108,7 +108,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("errors", on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: with pytest.raises(MCPError) as exc_info: await client.call_tool("explode", {}) @@ -116,7 +116,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("tools:list:basic") -async def test_list_tools_returns_registered_tools() -> None: +async def test_list_tools_returns_registered_tools(connect: Connect) -> None: """The tools advertised by the server's list handler arrive at the client unchanged.""" async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: @@ -137,7 +137,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa server = Server("calculator", on_list_tools=list_tools) - async with Client(server) as client: + async with connect(server) as client: result = await client.list_tools() assert result == snapshot( @@ -159,7 +159,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa 
@requirement("tools:list:metadata") -async def test_list_tools_optional_fields_round_trip() -> None: +async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: """Every optional Tool field the server supplies reaches the client unchanged.""" tool = Tool( @@ -178,7 +178,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa server = Server("annotated", on_list_tools=list_tools) - async with Client(server) as client: + async with connect(server) as client: result = await client.list_tools() assert result == snapshot( @@ -204,7 +204,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa @requirement("tools:call:content:audio") @requirement("tools:call:content:resource-link") @requirement("tools:call:content:embedded-resource") -async def test_call_tool_multiple_content_block_types() -> None: +async def test_call_tool_multiple_content_block_types(connect: Connect) -> None: """A tool result can mix every content block type; all of them arrive in order. 
The payloads are tiny fixed base64 strings ("aW1n" is b"img", "YXVk" is b"aud") so the @@ -230,7 +230,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("renderer", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("render", {}) assert result == snapshot( @@ -249,7 +249,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("tools:call:structured-content") -async def test_call_tool_structured_content() -> None: +async def test_call_tool_structured_content(connect: Connect) -> None: """A tool result carrying structured content alongside content delivers both to the client.""" async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: @@ -261,14 +261,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("calculator", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: result = await client.call_tool("sum", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="the sum is 5")], structured_content={"sum": 5})) @requirement("tools:call:concurrent") -async def test_concurrent_tool_calls_complete_independently() -> None: +async def test_concurrent_tool_calls_complete_independently(connect: Connect) -> None: """Two tool calls in flight at once run concurrently and each caller gets its own answer. 
Both handlers are held on a shared event after signalling that they have started, and the test @@ -295,7 +295,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: with anyio.fail_after(5): async with anyio.create_task_group() as task_group: @@ -320,7 +320,7 @@ async def call_and_record(tag: str) -> None: @requirement("client:output-schema:validate") -async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client() -> None: +async def test_call_tool_structured_content_violating_output_schema_is_rejected_by_the_client(connect: Connect) -> None: """A result whose structured content does not conform to the tool's declared output schema never reaches the caller: the client validates it against the schema cached from tools/list and raises. """ @@ -346,7 +346,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) - async with Client(server) as client: + async with connect(server) as client: await client.list_tools() with pytest.raises(RuntimeError) as exc_info: await client.call_tool("forecast", {}) diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index 9ccbd8fdd8..e7ae4b94d9 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -6,7 +6,6 @@ from mcp import MCPError from mcp.client import ClientRequestContext -from mcp.client.client import Client from mcp.server.elicitation import AcceptedElicitation from mcp.server.mcpserver import Context, MCPServer from mcp.types import ( @@ -20,6 +19,7 @@ LoggingMessageNotificationParams, TextContent, ) +from tests.interaction._connect import Connect from tests.interaction._helpers import IncomingMessage 
from tests.interaction._requirements import requirement @@ -28,7 +28,7 @@ @requirement("mcpserver:context:logging") @requirement("logging:capability:declared") -async def test_context_logging_helpers_send_log_notifications() -> None: +async def test_context_logging_helpers_send_log_notifications(connect: Connect) -> None: """Each Context logging helper sends a log message notification at the matching severity. All four notifications reach the client's logging callback before the tool call returns; none @@ -49,7 +49,7 @@ async def narrate(ctx: Context) -> str: async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) - async with Client(mcp, logging_callback=collect) as client: + async with connect(mcp, logging_callback=collect) as client: result = await client.call_tool("narrate", {}) advertised_logging = client.initialize_result.capabilities.logging @@ -67,7 +67,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: @requirement("mcpserver:context:progress") -async def test_context_report_progress_sends_progress_notifications() -> None: +async def test_context_report_progress_sends_progress_notifications(connect: Connect) -> None: """Context.report_progress sends progress notifications correlated to the calling request. The caller's progress callback receives each report, in order, before the tool call returns. 
@@ -84,7 +84,7 @@ async def crunch(ctx: Context) -> str: async def on_progress(progress: float, total: float | None, message: str | None) -> None: received.append((progress, total, message)) - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("crunch", {}, progress_callback=on_progress) assert result == snapshot( @@ -94,7 +94,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) @requirement("protocol:progress:no-token") -async def test_report_progress_without_a_progress_token_sends_nothing() -> None: +async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: """When the caller supplied no progress callback, Context.report_progress is a silent no-op. The tool also emits one log message as a sentinel: the message handler receives only that, @@ -113,7 +113,7 @@ async def mill(ctx: Context) -> str: async def collect(message: IncomingMessage) -> None: received.append(message) - async with Client(mcp, message_handler=collect) as client: + async with connect(mcp, message_handler=collect) as client: result = await client.call_tool("mill", {}) assert result == snapshot( @@ -126,7 +126,7 @@ async def collect(message: IncomingMessage) -> None: @requirement("mcpserver:context:elicit") @requirement("tools:call:elicitation-roundtrip") -async def test_context_elicit_returns_typed_result() -> None: +async def test_context_elicit_returns_typed_result(connect: Connect) -> None: """Context.elicit sends a form elicitation built from a pydantic schema and returns a typed result. 
The client sees the JSON schema generated from the model; the accepted content is validated @@ -149,7 +149,7 @@ async def answer_form(context: ClientRequestContext, params: ElicitRequestParams received.append(params) return ElicitResult(action="accept", content={"destination": "Lisbon", "window_seat": True}) - async with Client(mcp, elicitation_callback=answer_form) as client: + async with connect(mcp, elicitation_callback=answer_form) as client: result = await client.call_tool("book_flight", {}) assert received == snapshot( @@ -178,7 +178,7 @@ async def answer_form(context: ClientRequestContext, params: ElicitRequestParams @requirement("mcpserver:context:read-resource") -async def test_context_read_resource_reads_registered_resource() -> None: +async def test_context_read_resource_reads_registered_resource(connect: Connect) -> None: """Context.read_resource lets a tool read a resource registered on the same server. The tool reports the MIME type and content it read, proving the resource function ran and its @@ -196,7 +196,7 @@ async def show_config(ctx: Context) -> str: contents = list(await ctx.read_resource("config://app")) return "\n".join(f"{item.mime_type}: {item.content!r}" for item in contents) - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("show_config", {}) assert result == snapshot( @@ -208,7 +208,7 @@ async def show_config(ctx: Context) -> str: @requirement("logging:message:filtered") -async def test_set_logging_level_is_rejected_and_messages_are_never_filtered() -> None: +async def test_set_logging_level_is_rejected_and_messages_are_never_filtered(connect: Connect) -> None: """MCPServer does not support logging/setLevel, so log messages are never filtered by severity. 
The request is rejected with METHOD_NOT_FOUND because MCPServer registers no handler for it, @@ -228,7 +228,7 @@ async def chatter(ctx: Context) -> str: async def collect(params: LoggingMessageNotificationParams) -> None: received.append(params) - async with Client(mcp, logging_callback=collect) as client: + async with connect(mcp, logging_callback=collect) as client: with pytest.raises(MCPError) as exc_info: await client.set_logging_level("error") diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 62c7c33558..e4cb03d8f5 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -4,7 +4,6 @@ from inline_snapshot import snapshot from mcp import MCPError -from mcp.client.client import Client from mcp.server.mcpserver import MCPServer from mcp.types import ( ErrorData, @@ -15,13 +14,14 @@ PromptMessage, TextContent, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @requirement("mcpserver:prompt:decorated") -async def test_list_prompts_derives_arguments_from_signature() -> None: +async def test_list_prompts_derives_arguments_from_signature(connect: Connect) -> None: """A decorated prompt is listed with arguments derived from the function signature. Parameters without a default are required; the description comes from the docstring. 
@@ -33,7 +33,7 @@ def code_review(code: str, style_guide: str = "pep8") -> str: """Review a piece of code.""" raise NotImplementedError # registered for listing only; never rendered - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.list_prompts() assert result == snapshot( @@ -53,7 +53,7 @@ def code_review(code: str, style_guide: str = "pep8") -> str: @requirement("mcpserver:prompt:decorated") -async def test_get_prompt_renders_function_return() -> None: +async def test_get_prompt_renders_function_return(connect: Connect) -> None: """The decorated function's string return value is rendered as a single user message.""" mcp = MCPServer("prompter") @@ -62,7 +62,7 @@ def greet(name: str) -> str: """A personalised greeting.""" return f"Say hello to {name}." - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.get_prompt("greet", {"name": "Ada"}) assert result == snapshot( @@ -74,7 +74,7 @@ def greet(name: str) -> str: @requirement("mcpserver:prompt:unknown-name") -async def test_get_unknown_prompt_is_error() -> None: +async def test_get_unknown_prompt_is_error(connect: Connect) -> None: """Getting a prompt name that was never registered fails with a JSON-RPC error.""" mcp = MCPServer("prompter") @@ -83,7 +83,7 @@ def greet(name: str) -> str: """A registered prompt; the test requests a different name.""" raise NotImplementedError - async with Client(mcp) as client: + async with connect(mcp) as client: with pytest.raises(MCPError) as exc_info: await client.get_prompt("nope") @@ -91,7 +91,7 @@ def greet(name: str) -> str: @requirement("prompts:get:missing-required-args") -async def test_get_prompt_with_a_missing_required_argument_is_an_error() -> None: +async def test_get_prompt_with_a_missing_required_argument_is_an_error(connect: Connect) -> None: """Getting a prompt without one of its required arguments fails with a JSON-RPC error. 
The missing argument is detected before the prompt function is called, but the spec's -32602 @@ -105,7 +105,7 @@ def greet(name: str) -> str: """A registered prompt; validation rejects the call before the function runs.""" raise NotImplementedError - async with Client(mcp) as client: + async with connect(mcp) as client: with pytest.raises(MCPError) as exc_info: await client.get_prompt("greet") diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py index 4ad9ed356b..8960eb2be2 100644 --- a/tests/interaction/mcpserver/test_resources.py +++ b/tests/interaction/mcpserver/test_resources.py @@ -4,7 +4,6 @@ from inline_snapshot import snapshot from mcp import MCPError -from mcp.client.client import Client from mcp.server.mcpserver import MCPServer from mcp.types import ( ErrorData, @@ -15,13 +14,14 @@ ResourceTemplate, TextResourceContents, ) +from tests.interaction._connect import Connect from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @requirement("mcpserver:resource:static") -async def test_read_static_resource() -> None: +async def test_read_static_resource(connect: Connect) -> None: """A function registered for a fixed URI is served at that URI with its return value as text.""" mcp = MCPServer("library") @@ -30,7 +30,7 @@ def app_config() -> str: """The application configuration.""" return "theme = dark" - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.read_resource("config://app") assert result == snapshot( @@ -41,7 +41,7 @@ def app_config() -> str: @requirement("mcpserver:resource:static") -async def test_list_static_and_templated_resources() -> None: +async def test_list_static_and_templated_resources(connect: Connect) -> None: """Statically-registered resources appear in resources/list; templated ones only in templates/list. 
The name and description are derived from the function name and docstring; the MIME type @@ -59,7 +59,7 @@ def user_profile(user_id: str) -> str: """A user's profile.""" raise NotImplementedError # registered for listing only; never read - async with Client(mcp) as client: + async with connect(mcp) as client: resources = await client.list_resources() templates = await client.list_resource_templates() @@ -91,7 +91,7 @@ def user_profile(user_id: str) -> str: @requirement("mcpserver:resource:template") @requirement("resources:read:template-vars") -async def test_read_templated_resource() -> None: +async def test_read_templated_resource(connect: Connect) -> None: """Reading a URI that matches a registered template invokes the function with the extracted parameters.""" mcp = MCPServer("library") @@ -100,7 +100,7 @@ def user_profile(user_id: str) -> str: """A user's profile.""" return f"profile for {user_id}" - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.read_resource("users://42/profile") assert result == snapshot( @@ -111,7 +111,7 @@ def user_profile(user_id: str) -> str: @requirement("mcpserver:resource:unknown-uri") -async def test_read_unknown_uri_is_error() -> None: +async def test_read_unknown_uri_is_error(connect: Connect) -> None: """Reading a URI that matches no registered resource fails with a JSON-RPC error. The spec reserves -32002 for resource-not-found; see the divergence note on the requirement. 
@@ -123,7 +123,7 @@ def app_config() -> str: """A registered resource; the test reads a different URI.""" raise NotImplementedError - async with Client(mcp) as client: + async with connect(mcp) as client: with pytest.raises(MCPError) as exc_info: await client.read_resource("config://missing") diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index bd63fd5e61..ac6fd59650 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -4,7 +4,6 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from mcp.client.client import Client from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError from mcp.types import ( @@ -13,6 +12,7 @@ LoggingMessageNotificationParams, TextContent, ) +from tests.interaction._connect import Connect from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement @@ -20,7 +20,7 @@ @requirement("tools:call:content:text") -async def test_call_tool_returns_text_content() -> None: +async def test_call_tool_returns_text_content(connect: Connect) -> None: """Arguments reach the tool function; its return value comes back as text content. 
MCPServer also derives an output schema from the return annotation and attaches the @@ -32,14 +32,14 @@ async def test_call_tool_returns_text_content() -> None: def add(a: int, b: int) -> str: return str(a + b) - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("add", {"a": 2, "b": 3}) assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) @requirement("mcpserver:tool:handler-throws") -async def test_call_tool_function_exception_becomes_error_result() -> None: +async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" mcp = MCPServer("errors") @@ -47,7 +47,7 @@ async def test_call_tool_function_exception_becomes_error_result() -> None: def explode() -> str: raise ValueError("boom") - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("explode", {}) assert result == snapshot( @@ -56,7 +56,7 @@ def explode() -> str: @requirement("mcpserver:tool:handler-throws") -async def test_call_tool_tool_error_becomes_error_result() -> None: +async def test_call_tool_tool_error_becomes_error_result(connect: Connect) -> None: """A ToolError raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" mcp = MCPServer("errors") @@ -64,7 +64,7 @@ async def test_call_tool_tool_error_becomes_error_result() -> None: def flux() -> str: raise ToolError("flux capacitor offline") - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("flux", {}) assert result == snapshot( @@ -73,7 +73,7 @@ def flux() -> str: @requirement("mcpserver:tool:unknown-name") -async def test_call_tool_unknown_name_returns_error_result() -> None: +async def test_call_tool_unknown_name_returns_error_result(connect: Connect) -> None: """Calling a tool 
name that was never registered is reported as an is_error result. The spec classifies unknown tools as a protocol error; see the divergence note on the @@ -85,7 +85,7 @@ async def test_call_tool_unknown_name_returns_error_result() -> None: def add() -> None: """A registered tool; the test calls a different name.""" - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("nope", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="Unknown tool: nope")], is_error=True)) @@ -93,7 +93,7 @@ def add() -> None: @requirement("mcpserver:tool:output-schema:model") @requirement("tools:call:structured-content:text-mirror") -async def test_call_tool_model_return_becomes_structured_content() -> None: +async def test_call_tool_model_return_becomes_structured_content(connect: Connect) -> None: """A tool returning a pydantic model advertises the model's schema as the tool's output schema and returns the model's fields as structured content alongside a serialised text block. """ @@ -107,7 +107,7 @@ class Weather(BaseModel): def get_weather() -> Weather: return Weather(temperature=22.5, conditions="sunny") - async with Client(mcp) as client: + async with connect(mcp) as client: listed = await client.list_tools() result = await client.call_tool("get_weather", {}) @@ -140,7 +140,7 @@ def get_weather() -> Weather: @requirement("mcpserver:tool:output-schema:wrapped") -async def test_call_tool_list_return_is_wrapped_in_result_key() -> None: +async def test_call_tool_list_return_is_wrapped_in_result_key(connect: Connect) -> None: """A tool returning a list wraps the value under a "result" key in both the generated output schema and the structured content. 
""" @@ -150,7 +150,7 @@ async def test_call_tool_list_return_is_wrapped_in_result_key() -> None: def primes() -> list[int]: return [2, 3, 5] - async with Client(mcp) as client: + async with connect(mcp) as client: listed = await client.list_tools() result = await client.call_tool("primes", {}) @@ -171,7 +171,7 @@ def primes() -> list[int]: @requirement("mcpserver:tool:input-validation") -async def test_call_tool_invalid_arguments_become_error_result() -> None: +async def test_call_tool_invalid_arguments_become_error_result(connect: Connect) -> None: """Arguments that fail validation against the tool's signature are reported as an is_error result describing the failure, not as a protocol error. """ @@ -182,7 +182,7 @@ def add(a: int, b: int) -> str: """Validation rejects the arguments before the function is ever called.""" raise NotImplementedError - async with Client(mcp) as client: + async with connect(mcp) as client: result = await client.call_tool("add", {"b": 3}) # The description is raw pydantic output -- it embeds a pydantic-version-specific @@ -194,7 +194,7 @@ def add(a: int, b: int) -> str: @requirement("mcpserver:register:post-connect") -async def test_adding_and_removing_tools_does_not_notify_connected_clients() -> None: +async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: """Mutating the tool set on a running server changes tools/list but sends no notification. 
add_tool and remove_tool only update the registry: a connected client that listed the tools @@ -225,7 +225,7 @@ async def grow(ctx: Context) -> str: async def collect(message: IncomingMessage) -> None: received.append(message) - async with Client(mcp, message_handler=collect) as client: + async with connect(mcp, message_handler=collect) as client: before = await client.list_tools() await client.call_tool("grow", {}) after = await client.list_tools() diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index d2639266a5..f20fa44f05 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -1,29 +1,20 @@ -"""Tests for the interaction model over the streamable HTTP transport, entirely in process. - -The Starlette app a real deployment would hand to uvicorn is driven through the suite's -streaming ASGI bridge instead: the full HTTP framing layer runs (session ids, SSE and JSON -response encoding, stateful and stateless session management) with no sockets, threads, or -subprocesses, so these tests are as deterministic as the in-memory ones. Because the bridge -streams each response as the server produces it, full-duplex behaviour works too: a -server-initiated request nested inside a still-open call round-trips while that call's SSE -response remains open. -""" +"""Behaviour specific to the streamable HTTP transport, exercised entirely in process. -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file only pins what cannot be observed in memory: the +server's stateless and JSON-response modes, the standalone GET stream, and the full-duplex +server-initiated exchange on a still-open call. 
Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge — no sockets, threads, or subprocesses. +""" import anyio -import httpx import pytest from inline_snapshot import snapshot from pydantic import BaseModel from mcp.client import ClientRequestContext -from mcp.client.client import Client -from mcp.client.streamable_http import streamable_http_client from mcp.server.elicitation import AcceptedElicitation from mcp.server.mcpserver import Context, MCPServer -from mcp.server.transport_security import TransportSecuritySettings from mcp.types import ( CallToolResult, ElicitRequestParams, @@ -34,15 +25,15 @@ ResourceUpdatedNotificationParams, TextContent, ) +from tests.interaction._connect import connect_over_streamable_http from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement -from tests.interaction.transports._bridge import StreamingASGITransport pytestmark = pytest.mark.anyio def _smoke_server() -> MCPServer: - """A server exercising one example of each message shape the smoke tests need.""" + """A server exercising each message shape the transport-specific tests need.""" mcp = MCPServer("smoke", instructions="Talk to the smoke server.") @mcp.tool() @@ -50,19 +41,6 @@ def echo(text: str) -> str: """Echo the text back.""" return text - @mcp.tool() - def fail() -> str: - """Always fails.""" - raise ValueError("deliberately broken") - - @mcp.tool() - async def narrate(ctx: Context) -> str: - """Send a log message and a progress update, then return.""" - await ctx.info("starting") - await ctx.report_progress(1, 2) - await ctx.info("finishing") - return "narrated" - class Confirmation(BaseModel): confirmed: bool @@ -84,54 +62,10 @@ async def announce(ctx: Context) -> str: return mcp -@asynccontextmanager -async def _connected( - mcp: MCPServer, *, stateless_http: bool = False, json_response: bool = False -) -> AsyncIterator[Client]: - """Yield a Client connected to the server 
through the in-process streamable HTTP stack.""" - # DNS-rebinding protection validates Host/Origin headers against a real network attack that - # cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation - # branch (deliberately uncovered in src) into coverage. - app = mcp.streamable_http_app( - stateless_http=stateless_http, - json_response=json_response, - transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False), - ) - async with mcp.session_manager.run(): - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: - transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) - async with Client(transport) as client: - yield client - - -@requirement("transport:streamable-http:stateful") -async def test_initialize_and_call_a_tool_over_streamable_http() -> None: - """The handshake and a tool round trip work through the stateful SSE framing.""" - async with _connected(_smoke_server()) as client: - assert client.initialize_result.server_info.name == "smoke" - assert client.initialize_result.instructions == "Talk to the smoke server." 
- result = await client.call_tool("echo", {"text": "over http"}) - - assert result == snapshot( - CallToolResult(content=[TextContent(text="over http")], structured_content={"result": "over http"}) - ) - - -@requirement("transport:streamable-http:stateful") -async def test_tool_errors_round_trip_over_streamable_http() -> None: - """A tool execution error crosses the HTTP framing as an is_error result, not a transport failure.""" - async with _connected(_smoke_server()) as client: - result = await client.call_tool("fail", {}) - - assert result == snapshot( - CallToolResult(content=[TextContent(text="Error executing tool fail: deliberately broken")], is_error=True) - ) - - @requirement("transport:streamable-http:json-response") async def test_tool_call_over_streamable_http_with_json_responses() -> None: """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" - async with _connected(_smoke_server(), json_response=True) as client: + async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: assert client.initialize_result.server_info.name == "smoke" result = await client.call_tool("echo", {"text": "as json"}) @@ -143,7 +77,7 @@ async def test_tool_call_over_streamable_http_with_json_responses() -> None: @requirement("transport:streamable-http:stateless") async def test_tool_calls_over_stateless_streamable_http() -> None: """Consecutive requests each succeed against a stateless server with no session to share.""" - async with _connected(_smoke_server(), stateless_http=True) as client: + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: first = await client.call_tool("echo", {"text": "first"}) second = await client.call_tool("echo", {"text": "second"}) @@ -155,39 +89,10 @@ async def test_tool_calls_over_stateless_streamable_http() -> None: ) -@requirement("transport:streamable-http:notifications") -async def 
test_notifications_during_a_tool_call_arrive_before_the_response() -> None: - """Log and progress notifications emitted mid-call are delivered on the call's SSE stream in order.""" - logs: list[LoggingMessageNotificationParams] = [] - progress_updates: list[tuple[float, float | None, str | None]] = [] - - async def collect_log(params: LoggingMessageNotificationParams) -> None: - logs.append(params) - - async def collect_progress(progress: float, total: float | None, message: str | None) -> None: - progress_updates.append((progress, total, message)) - - server = _smoke_server() - app = server.streamable_http_app( - transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) - ) - async with server.session_manager.run(): - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: - transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) - async with Client(transport, logging_callback=collect_log) as client: - result = await client.call_tool("narrate", {}, progress_callback=collect_progress) - - assert result == snapshot( - CallToolResult(content=[TextContent(text="narrated")], structured_content={"result": "narrated"}) - ) - assert [params.data for params in logs] == snapshot(["starting", "finishing"]) - assert progress_updates == snapshot([(1.0, 2.0, None)]) - - @requirement("transport:streamable-http:stateless-restrictions") async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: """A handler that tries to call back to the client in stateless mode fails: there is no session.""" - async with _connected(_smoke_server(), stateless_http=True) as client: + async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: result = await client.call_tool("ask", {}) assert result.is_error is True @@ -197,6 +102,7 @@ async def test_stateless_streamable_http_rejects_server_initiated_requests() -> assert 
result.content[0].text.startswith("Error executing tool ask:") +@requirement("transport:streamable-http:notifications") @requirement("transport:streamable-http:unrelated-messages") async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: """A server message with no related request reaches the client through the standalone GET stream. @@ -215,17 +121,10 @@ async def collect(message: IncomingMessage) -> None: if isinstance(message, ResourceUpdatedNotification): resource_update_seen.set() - server = _smoke_server() - app = server.streamable_http_app( - transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) - ) - async with server.session_manager.run(): - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: - transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) - async with Client(transport, message_handler=collect) as client: - result = await client.call_tool("announce", {}) - with anyio.fail_after(5): - await resource_update_seen.wait() + async with connect_over_streamable_http(_smoke_server(), message_handler=collect) as client: + result = await client.call_tool("announce", {}) + with anyio.fail_after(5): + await resource_update_seen.wait() assert result == snapshot( CallToolResult(content=[TextContent(text="announced")], structured_content={"result": "announced"}) @@ -241,6 +140,7 @@ async def collect(message: IncomingMessage) -> None: assert len(received) == 2 +@requirement("transport:streamable-http:stateful") @requirement("transport:streamable-http:server-to-client") async def test_server_initiated_elicitation_round_trips_during_a_tool_call() -> None: """An elicitation issued mid-call reaches the client and its answer reaches the handler over stateful HTTP. 
@@ -255,17 +155,10 @@ async def answer(context: ClientRequestContext, params: ElicitRequestParams) -> asked.append(params) return ElicitResult(action="accept", content={"confirmed": True}) - server = _smoke_server() - app = server.streamable_http_app( - transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) - ) - async with server.session_manager.run(): - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url="http://127.0.0.1:8000") as http: - transport = streamable_http_client("http://127.0.0.1:8000/mcp", http_client=http) - async with Client(transport, elicitation_callback=answer) as client: - # Bounded because a harness regression here historically meant deadlock, not failure. - with anyio.fail_after(5): - result = await client.call_tool("ask", {}) + async with connect_over_streamable_http(_smoke_server(), elicitation_callback=answer) as client: + # Bounded because a harness regression here historically meant deadlock, not failure. + with anyio.fail_after(5): + result = await client.call_tool("ask", {}) assert result == snapshot( CallToolResult(content=[TextContent(text="confirmed=True")], structured_content={"result": "confirmed=True"}) From 8353a9bdc260af2742c17fb46ba8c0ebbedae4ff Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 17:46:53 +0000 Subject: [PATCH 15/34] test: run the interaction suite over the legacy SSE transport in-process --- src/mcp/server/sse.py | 12 +-- tests/interaction/_connect.py | 98 +++++++++++++++++++-- tests/interaction/_requirements.py | 19 +++- tests/interaction/conftest.py | 3 +- tests/interaction/test_coverage.py | 1 + tests/interaction/transports/_bridge.py | 22 +++-- tests/interaction/transports/test_bridge.py | 21 +++++ tests/interaction/transports/test_sse.py | 98 +++++++++++++++++++++ 8 files changed, 253 insertions(+), 21 deletions(-) create mode 100644 tests/interaction/transports/test_sse.py diff --git 
a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 48192ff612..3e5261896b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -116,15 +116,15 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover - if scope["type"] != "http": + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] != "http": # pragma: no cover logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") # Validate request headers for DNS rebinding protection request = Request(scope, receive) error_response = await self._security.validate_request(request, is_post=False) - if error_response: + if error_response: # pragma: no cover await error_response(scope, receive, send) raise ValueError("Request validation failed") @@ -190,13 +190,13 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): logger.debug("Yielding read and write streams") yield (read_stream, write_stream) - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: logger.debug("Handling POST message") request = Request(scope, receive) # Validate request headers for DNS rebinding protection error_response = await self._security.validate_request(request, is_post=True) - if error_response: + if error_response: # pragma: no cover return await error_response(scope, receive, send) session_id_param = request.query_params.get("session_id") @@ -225,7 +225,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) try: message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: 
{message}") - except ValidationError as err: + except ValidationError as err: # pragma: no cover logger.exception("Failed to parse message") response = Response("Could not parse message", status_code=400) await response(scope, receive, send) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index a091e18d9a..b553477f63 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -1,24 +1,31 @@ """Transport-parametrized connection factories for the interaction suite. The `connect` fixture (see conftest.py) hands tests one of these factories so the same test body -runs over the in-memory transport and over streamable HTTP without naming either: the factory is a -drop-in replacement for constructing `Client(server, ...)` and yields the connected client. The -streamable HTTP factory drives the server's real Starlette app through the in-process streaming -bridge, so the full HTTP framing layer (session ids, SSE encoding, session management) runs with -no sockets, threads, or subprocesses. +runs over each transport without naming any of them: the factory is a drop-in replacement for +constructing `Client(server, ...)` and yields the connected client. The HTTP factories drive the +server's real Starlette app through the in-process streaming bridge, so the full transport layer +(session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. 
""" +import gc +import warnings from collections.abc import AsyncIterator from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Protocol import httpx +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route from mcp.client.client import Client from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT +from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server from mcp.server.mcpserver import MCPServer +from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings from mcp.types import Implementation from tests.interaction.transports._bridge import StreamingASGITransport @@ -115,3 +122,84 @@ async def connect_over_streamable_http( elicitation_callback=elicitation_callback, ) as client: yield client + + +def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: + """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. + + `MCPServer.sse_app()` exists but does not expose the underlying `SseServerTransport`, which + the SSE-specific tests need; building the app explicitly here gives both server flavours the + same routing while keeping that handle. 
+ """ + sse = SseServerTransport( + "/messages/", security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse(request.scope, request.receive, request._send) as (read, write): + await lowlevel.run(read, write, lowlevel.create_initialization_options()) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ], + ) + return app, sse + + +@asynccontextmanager +async def connect_over_sse( + server: Server | MCPServer, + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: Implementation | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Yield a Client connected to the server's legacy SSE transport, entirely in process.""" + app, _ = build_sse_app(server) + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + # The SSE server transport's connect_sse runs the entire MCP session inside the GET + # request and only releases its streams after that request observes a disconnect, so the + # bridge must let the application drain rather than cancelling at close. 
+ return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=_BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client(f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + try: + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client + finally: + # SseServerTransport.connect_sse hands its internal SSE-chunk receive stream to + # sse_starlette's EventSourceResponse, which never closes it when its task group is + # cancelled on disconnect (see notes/findings.md). Collect the orphan here so its + # ResourceWarning fires deterministically inside this fixture instead of at an + # arbitrary later GC. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + gc.collect() diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index cdc97d4074..68171992b0 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1759,10 +1759,23 @@ def __post_init__(self) -> None: "requests, with server messages delivered on the SSE stream." ), transports=("sse",), - deferred=( - "The legacy SSE transport is covered by tests/shared/test_sse.py; in-process coverage in this " - "suite arrives with the transport fixture work." + ), + "transport:sse:endpoint-event": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", + behavior=( + "Opening the SSE stream delivers an `endpoint` event naming the message-POST URL and a fresh " + "session identifier; the server registers the session before the event is sent and releases it " + "when the stream disconnects." 
+ ), + transports=("sse",), + ), + "transport:sse:post:session-routing": Requirement( + source="sdk", + behavior=( + "A POST to the SSE message endpoint that names no session id, a malformed session id, or an " + "unknown session id is rejected (400/400/404) instead of being forwarded." ), + transports=("sse",), ), "transport:stdio": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#stdio", diff --git a/tests/interaction/conftest.py b/tests/interaction/conftest.py index f8960bd13b..c2ace45077 100644 --- a/tests/interaction/conftest.py +++ b/tests/interaction/conftest.py @@ -2,11 +2,12 @@ import pytest -from tests.interaction._connect import Connect, connect_in_memory, connect_over_streamable_http +from tests.interaction._connect import Connect, connect_in_memory, connect_over_sse, connect_over_streamable_http _FACTORIES: dict[str, Connect] = { "in-memory": connect_in_memory, "streamable-http": connect_over_streamable_http, + "sse": connect_over_sse, } diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 5a4f003101..47b1b95e71 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -29,6 +29,7 @@ "tests.interaction.transports.test_bridge.test_response_chunks_arrive_as_the_application_sends_them", "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", + "tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", } diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py index 254f1e00c1..6d0bfd62d4 100644 --- a/tests/interaction/transports/_bridge.py +++ b/tests/interaction/transports/_bridge.py @@ -21,7 +21,8 @@ The transport owns an anyio task group for the application tasks; it is opened and closed by 
`httpx.AsyncClient`'s own context manager, so use the client as a context manager (the suite -always does). +always does). Closing the transport cancels every running application task by default; set +`cancel_on_close=False` to wait for the application's own disconnect handling instead. """ import math @@ -56,12 +57,19 @@ async def aclose(self) -> None: class StreamingASGITransport(httpx.AsyncBaseTransport): - """Drive an ASGI application in-process, streaming each response as it is produced.""" + """Drive an ASGI application in-process, streaming each response as it is produced. + + With `cancel_on_close` (the default), closing the transport cancels every application task + still running so harness teardown can never hang. Setting it to False makes the transport wait + for the application's own disconnect handling to complete instead, which is the path the legacy + SSE server transport relies on for resource cleanup. + """ _task_group: anyio.abc.TaskGroup - def __init__(self, app: ASGIApp) -> None: + def __init__(self, app: ASGIApp, *, cancel_on_close: bool = True) -> None: self._app = app + self._cancel_on_close = cancel_on_close async def __aenter__(self) -> "StreamingASGITransport": self._task_group = anyio.create_task_group() @@ -74,9 +82,11 @@ async def __aexit__( exc_value: BaseException | None = None, traceback: TracebackType | None = None, ) -> None: - # Any application task still running at this point is serving a client that no longer - # exists; cancel rather than wait so harness teardown can never hang. - self._task_group.cancel_scope.cancel() + # httpx closes every streamed response before closing the transport, so by now each + # application task has been delivered `http.disconnect`. Either cancel immediately, or wait + # for the application's own disconnect handling to unwind. 
+ if self._cancel_on_close: + self._task_group.cancel_scope.cancel() await self._task_group.__aexit__(exc_type, exc_value, traceback) async def handle_async_request(self, request: httpx.Request) -> httpx.Response: diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py index 13389f8533..71be14ced0 100644 --- a/tests/interaction/transports/test_bridge.py +++ b/tests/interaction/transports/test_bridge.py @@ -69,3 +69,24 @@ async def broken_app(scope: Scope, receive: Receive, send: Send) -> None: async with httpx.AsyncClient(transport=StreamingASGITransport(broken_app), base_url="http://bridge") as http: with pytest.raises(RuntimeError, match="the demo application is broken"): await http.get("/broken") + + +async def test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect() -> None: + """With cancel_on_close=False, an application that runs cleanup after seeing http.disconnect + completes that cleanup before the transport finishes closing.""" + cleanup_ran = anyio.Event() + + async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + await receive() + await send({"type": "http.response.start", "status": 200, "headers": []}) + assert (await receive())["type"] == "http.disconnect" + cleanup_ran.set() + + transport = StreamingASGITransport(lingering_app, cancel_on_close=False) + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: + with anyio.fail_after(5): + async with http.stream("GET", "/linger") as response: + assert response.status_code == 200 + assert not cleanup_ran.is_set() + assert cleanup_ran.is_set() diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py new file mode 100644 index 0000000000..1d5434c160 --- /dev/null +++ b/tests/interaction/transports/test_sse.py @@ -0,0 +1,98 @@ +"""Behaviour specific to the legacy HTTP+SSE transport, exercised entirely in process. 
+ +Transport-agnostic behaviour is covered by the `connect`-fixture matrix, which runs the rest of +the suite over this transport as well; this file pins only what is observable on the SSE wiring +itself: the GET-then-POST connection lifecycle, the endpoint event, and how the message endpoint +rejects requests it cannot route to a session. Every test drives the server's real Starlette app +through the suite's streaming ASGI bridge. +""" + +import gc +import warnings +from uuid import UUID, uuid4 + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.sse import sse_client +from mcp.server import Server +from mcp.types import EmptyResult +from tests.interaction._connect import build_sse_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + +_BASE_URL = "http://127.0.0.1:8000" + + +@requirement("transport:sse") +@requirement("transport:sse:endpoint-event") +async def test_endpoint_event_names_the_message_endpoint_with_a_fresh_session_id() -> None: + """Connecting opens a GET stream whose first event names the POST endpoint and a fresh + session id; messages POSTed there are answered on that stream, and disconnecting releases the + server's session entry.""" + app, sse = build_sse_app(Server("legacy")) + captured_session_id: list[str] = [] + + def httpx_client_factory( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=StreamingASGITransport(app, cancel_on_close=False), + base_url=_BASE_URL, + headers=headers, + timeout=timeout, + auth=auth, + ) + + transport = sse_client( + f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + ) + with anyio.fail_after(5): + async with Client(transport) 
as client: + assert len(captured_session_id) == 1 + assert UUID(hex=captured_session_id[0]) in sse._read_stream_writers + assert await client.send_ping() == snapshot(EmptyResult()) + + assert sse._read_stream_writers == {} + # See connect_over_sse: collect the one stream sse_starlette never closes on disconnect. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ResourceWarning) + gc.collect() + + +@requirement("transport:sse:post:session-routing") +async def test_post_without_a_session_id_is_rejected() -> None: + """A POST to the message endpoint with no session_id query parameter is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http: + response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1}) + assert (response.status_code, response.text) == snapshot((400, "session_id is required")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_with_a_malformed_session_id_is_rejected() -> None: + """A POST whose session_id query parameter is not a UUID is answered 400.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((400, "Invalid session ID")) + + +@requirement("transport:sse:post:session-routing") +async def test_post_for_an_unknown_session_is_rejected() -> None: + """A POST naming a well-formed session_id that no SSE stream owns is answered 404.""" + app, _ = build_sse_app(Server("legacy")) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http: + response = await http.post( + "/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": 
"ping", "id": 1} + ) + assert (response.status_code, response.text) == snapshot((404, "Could not find session")) From 584e0989085d08df6691e55e443386f878b19dd8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Tue, 26 May 2026 18:02:45 +0000 Subject: [PATCH 16/34] test: add an SDK-client to SDK-server stdio end-to-end interaction test --- tests/interaction/_requirements.py | 10 +-- tests/interaction/transports/_stdio_server.py | 56 ++++++++++++++ tests/interaction/transports/test_stdio.py | 76 +++++++++++++++++++ 3 files changed, 136 insertions(+), 6 deletions(-) create mode 100644 tests/interaction/transports/_stdio_server.py create mode 100644 tests/interaction/transports/test_stdio.py diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 68171992b0..842fe7a199 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1779,12 +1779,11 @@ def __post_init__(self) -> None: ), "transport:stdio": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#stdio", - behavior="The interaction round trip works over a stdio subprocess.", - transports=("stdio",), - deferred=( - "Not yet covered here: a single composed end-to-end stdio test is planned; process lifecycle " - "details are covered by tests/client/test_stdio.py." + behavior=( + "A Client connected to a real SDK Server over stdio initializes, calls a tool with arguments, " + "and receives notifications and results over the child process's stdin/stdout." 
), + transports=("stdio",), ), # ═══════════════════════════════════════════════════════════════════════════ # Hosting: session lifecycle @@ -2494,7 +2493,6 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/lifecycle#shutdown", behavior="Closing the client transport closes the child process's stdin and the server exits cleanly.", transports=("stdio",), - deferred="Not yet covered here; existing coverage in tests/client/test_stdio.py.", ), "transport:stdio:stream-purity": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#stdio", diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py new file mode 100644 index 0000000000..fbe7e614f7 --- /dev/null +++ b/tests/interaction/transports/_stdio_server.py @@ -0,0 +1,56 @@ +"""A real low-level Server over the stdio transport, for the suite's one subprocess test. + +Runnable as `python -m tests.interaction.transports._stdio_server` from the repo root; the test +launches it that way via `stdio_client`. Kept separate from the test module so the server lives in +its own importable file (subprocess coverage applies) while the test file follows the suite's +test-only-functions convention. 
+""" + +import sys + +import anyio + +from mcp.server import Server, ServerRequestContext +from mcp.server.stdio import stdio_server +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="echo", + input_schema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"]}, + ) + ] + ) + + +async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + text = params.arguments["text"] + await ctx.session.send_log_message(level="info", data=f"echoing {text}", logger="echo") + return CallToolResult(content=[TextContent(text=text)]) + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool) + + +async def main() -> None: + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + # Reached only when the run loop exits because stdin closed; if the process were terminated + # the test's stderr capture would not see this line. + print("stdio-echo: clean exit", file=sys.stderr, flush=True) + + +if __name__ == "__main__": + anyio.run(main) diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py new file mode 100644 index 0000000000..e70a68225f --- /dev/null +++ b/tests/interaction/transports/test_stdio.py @@ -0,0 +1,76 @@ +"""The suite's one stdio end-to-end test: a real SDK Server in a subprocess, driven by Client. 
+ +Everything else in the suite runs in a single process; this test exists to prove the same +client↔server round trip works over the stdio transport's real boundary (a child process whose +stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in +`_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. + +stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test +would be slow, and the matrix already proves transport-agnosticism over three in-process +transports. Process-lifecycle edge cases (escalation to terminate/kill, stderr handling, parse +errors) are covered by `tests/client/test_stdio.py` and stay deferred here. +""" + +import os +import sys +import tempfile +from pathlib import Path + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp.client.client import Client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._requirements import requirement +from tests.interaction.transports import _stdio_server + +pytestmark = pytest.mark.anyio + +_REPO_ROOT = Path(__file__).parents[3] + + +@requirement("transport:stdio") +@requirement("transport:stdio:clean-shutdown") +async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: + """A Client connected over stdio initializes, calls a tool with arguments, receives the + server's log notification before the call returns, and the server exits when the transport + closes its stdin.""" + received: list[LoggingMessageNotificationParams] = [] + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params) + + with tempfile.TemporaryFile(mode="w+") as errlog: + transport = stdio_client( + StdioServerParameters( + command=sys.executable, + args=["-m", _stdio_server.__name__], + cwd=str(_REPO_ROOT), + # 
stdio_client deliberately filters the inherited environment to a safe minimum, + # which drops the variables coverage.py's subprocess support uses; pass them through + # so the server module is measured. Empty when not running under coverage. + env={key: value for key, value in os.environ.items() if key.startswith("COVERAGE_")}, + ), + errlog=errlog, + ) + + with anyio.fail_after(10): + async with Client(transport, logging_callback=collect) as client: + assert client.initialize_result.server_info.name == "stdio-echo" + result = await client.call_tool("echo", {"text": "across\nprocesses"}) + + errlog.seek(0) + captured_stderr = errlog.read() + + assert result == snapshot(CallToolResult(content=[TextContent(text="across\nprocesses")])) + # stdio carries one ordered server→client stream, so the same notification-before-response + # guarantee holds here as for the in-memory transport. + assert received == snapshot( + [LoggingMessageNotificationParams(level="info", logger="echo", data="echoing across\nprocesses")] + ) + # The server writes this line only after its run loop returns, which happens when stdin closes: + # seeing it proves the process exited on its own rather than via the transport's terminate + # escalation, without a timing-based assertion. 
+ assert captured_stderr == snapshot("stdio-echo: clean exit\n") From 538136a5c6bb1ce46e1995b9fa1c70a3cd0ec987 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 08:39:04 +0000 Subject: [PATCH 17/34] test: add streamable HTTP hosting, resumability, and client transport conformance tests --- src/mcp/client/streamable_http.py | 12 +- src/mcp/server/mcpserver/context.py | 2 +- src/mcp/server/streamable_http.py | 94 +++--- src/mcp/server/streamable_http_manager.py | 2 +- src/mcp/server/transport_security.py | 34 +- tests/interaction/README.md | 6 + tests/interaction/_connect.py | 177 ++++++++++- tests/interaction/_requirements.py | 122 +++---- tests/interaction/transports/_event_store.py | 55 ++++ .../transports/test_client_transport_http.py | 211 +++++++++++++ .../transports/test_hosting_http.py | 297 ++++++++++++++++++ .../transports/test_hosting_resume.py | 287 +++++++++++++++++ .../transports/test_hosting_session.py | 203 ++++++++++++ tests/interaction/transports/test_sse.py | 14 +- .../transports/test_streamable_http.py | 3 + 15 files changed, 1356 insertions(+), 163 deletions(-) create mode 100644 tests/interaction/transports/_event_store.py create mode 100644 tests/interaction/transports/test_client_transport_http.py create mode 100644 tests/interaction/transports/test_hosting_http.py create mode 100644 tests/interaction/transports/test_hosting_resume.py create mode 100644 tests/interaction/transports/test_hosting_session.py diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c6338..a6b4e6cfa0 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -210,7 +210,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: # Stream ended normally (server closed) - reset attempt counter attempt = 0 - except Exception: # pragma: lax no cover + except Exception: logger.debug("GET stream error", exc_info=True) 
attempt += 1 @@ -492,17 +492,17 @@ async def handle_request_async(): async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" - if not self.session_id: # pragma: lax no cover - return + if not self.session_id: + return # pragma: no cover try: headers = self._prepare_headers() response = await client.delete(self.url, headers=headers) - if response.status_code == 405: # pragma: lax no cover + if response.status_code == 405: logger.debug("Server does not allow session termination") - elif response.status_code not in (200, 204): # pragma: lax no cover - logger.warning(f"Session termination failed: {response.status_code}") + elif response.status_code not in (200, 204): + logger.warning(f"Session termination failed: {response.status_code}") # pragma: no cover except Exception as exc: # pragma: no cover logger.warning(f"Session termination failed: {exc}") diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index d4344daa92..1441649808 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -237,7 +237,7 @@ async def close_sse_stream(self) -> None: This is a no-op if not using StreamableHTTP transport with event_store. The callback is only available when event_store is configured. 
""" - if self._request_context and self._request_context.close_sse_stream: # pragma: no cover + if self._request_context and self._request_context.close_sse_stream: # pragma: no branch await self._request_context.close_sse_stream() async def close_standalone_sse_stream(self) -> None: diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index a4cb5af03a..c85eeeeadf 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -179,7 +179,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -198,11 +198,11 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover the disconnect. 
""" writer = self._sse_stream_writers.pop(request_id, None) - if writer: + if writer: # pragma: no branch writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() @@ -242,7 +242,7 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) async def close_standalone_stream_callback() -> None: # pragma: no cover @@ -293,7 +293,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -320,10 +320,10 @@ def _create_json_response( ) -> Response: """Create a JSON response from a JSONRPCMessage.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: lax no cover - response_headers.update(headers) + if headers: + response_headers.update(headers) # pragma: no cover - if self.mcp_session_id: # pragma: lax no cover + if self.mcp_session_id: response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id return Response( @@ -344,7 +344,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -374,7 +374,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: lax no cover + if 
self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -389,7 +389,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -421,7 +421,7 @@ async def _validate_accept_header(self, request: Request, scope: Scope, send: Se has_json, has_sse = self._check_accept_headers(request) if self.is_json_response_enabled: # For JSON-only responses, only require application/json - if not has_json: # pragma: lax no cover + if not has_json: # pragma: no cover response = self._create_error_response( "Not Acceptable: Client must accept application/json", HTTPStatus.NOT_ACCEPTABLE, @@ -469,7 +469,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -495,7 +495,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -579,7 +579,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Store writer reference so close_sse_stream() can close it self._sse_stream_writers[request_id] = sse_stream_writer - async def sse_writer(): # pragma: lax no cover + async def 
sse_writer(): # Get the request ID from the incoming request message try: async with sse_stream_writer, request_stream_reader: @@ -595,10 +595,10 @@ async def sse_writer(): # pragma: lax no cover # If response, remove from pending streams and close if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in SSE writer") finally: logger.debug("Closing SSE writer") @@ -628,14 +628,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: lax no cover + except Exception as err: logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -643,9 +643,9 @@ async def sse_writer(): # pragma: lax no cover INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: + if writer: # pragma: no cover await writer.send(Exception(err)) - return + return # pragma: no cover async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. 
@@ -661,7 +661,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -673,7 +673,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -683,11 +683,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -707,7 +707,7 @@ async def standalone_sse_writer(): async with sse_stream_writer, standalone_stream_reader: # Process messages from the standalone stream - async for event_message in standalone_stream_reader: # pragma: lax no cover + async for event_message in standalone_stream_reader: # For the standalone stream, we handle: # - JSONRPCNotification (server sends notifications to client) # - JSONRPCRequest (server sends requests to client) @@ -716,8 +716,8 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover - logger.exception("Error in standalone SSE writer") + 
except Exception: + logger.exception("Error in standalone SSE writer") # pragma: no cover finally: logger.debug("Closing standalone SSE writer") await self._clean_up_memory_streams(GET_STREAM_KEY) @@ -775,7 +775,7 @@ async def terminate(self) -> None: request_stream_keys = list(self._request_streams.keys()) # Close all request streams asynchronously - for key in request_stream_keys: # pragma: lax no cover + for key in request_stream_keys: await self._clean_up_memory_streams(key) # Clear the request streams dictionary immediately @@ -793,13 +793,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -809,7 +809,7 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non ) await response(request.scope, request.receive, send) - async def _validate_request_headers(self, request: Request, send: Send) -> bool: # pragma: lax no cover + async def _validate_request_headers(self, request: Request, send: Send) -> bool: if not await self._validate_session(request, send): return False if not await self._validate_protocol_version(request, send): @@ -826,7 +826,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing 
session ID", HTTPStatus.BAD_REQUEST, @@ -851,11 +851,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -867,14 +867,14 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. 
""" event_store = self._event_store if not event_store: - return + return # pragma: no cover try: headers = { @@ -883,7 +883,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -921,10 +921,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -936,13 +936,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -1009,7 +1009,7 @@ async def message_router(): # regardless of whether a client is connected # messages will be replayed on the re-connect event_id = None - if self._event_store: # pragma: lax no cover + if self._event_store: event_id = await self._event_store.store_event(request_stream_id, message) logger.debug(f"Stored {event_id} from {request_stream_id}") @@ -1020,14 +1020,14 @@ async def message_router(): except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover # Stream might be closed, 
remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client might reconnect and replay.""" ) except anyio.ClosedResourceError: - if self._terminated: + if self._terminated: # pragma: lax no cover logger.debug("Read stream closed by client") else: logger.exception("Unexpected closure of read stream in message router") @@ -1041,8 +1041,8 @@ async def message_router(): # Yield the streams for the caller to use yield read_stream, write_stream finally: - for stream_id in list(self._request_streams.keys()): # pragma: lax no cover - await self._clean_up_memory_streams(stream_id) + for stream_id in list(self._request_streams.keys()): + await self._clean_up_memory_streams(stream_id) # pragma: no cover self._request_streams.clear() # Clean up the read and write streams diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index c25314eab6..39d434505c 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -173,7 +173,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") # Assert task group is not None for type checking diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..707d4b61dd 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,19 +40,19 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def 
_validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" - if not host: + if not host: # pragma: no cover logger.warning("Missing Host header in request") return False # Check exact match first - if host in self.settings.allowed_hosts: + if host in self.settings.allowed_hosts: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base host from pattern base_host = allowed[:-2] # Check if the actual host starts with base host and has a port @@ -62,19 +62,19 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests - if not origin: + if not origin: # pragma: no cover return True # Check exact match first - if origin in self.settings.allowed_origins: + if origin in self.settings.allowed_origins: # pragma: no cover return True # Check wildcard port patterns for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): + if allowed.endswith(":*"): # pragma: no branch # Extract base origin from pattern base_origin = allowed[:-2] # Check if the actual origin starts with base origin and has a port @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # 
Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None diff --git a/tests/interaction/README.md b/tests/interaction/README.md index df8f331596..e1341806c6 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -62,6 +62,12 @@ stream pair), the bare-`ClientSession` lifecycle tests, the real-clock timeout t machinery is transport-independent and must not race transport latency), and everything under `transports/`, which pins behaviour only observable on that transport. +A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app +**only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, +response headers, SSE event fields, which stream a message travels on. Any other behaviour is +asserted through a `Client`, connected to the mounted app via `client_via_http(http)` so several +clients can share one session manager. 
+ ## The requirements manifest `_requirements.py` maps every behaviour the suite covers to the reason it must hold: diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index b553477f63..baca975917 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -9,11 +9,12 @@ import gc import warnings -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Protocol import httpx +from httpx_sse import ServerSentEvent, aconnect_sse from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import Response @@ -26,12 +27,29 @@ from mcp.server import Server from mcp.server.mcpserver import MCPServer from mcp.server.sse import SseServerTransport +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.types import Implementation +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + jsonrpc_message_adapter, +) from tests.interaction.transports._bridge import StreamingASGITransport # The in-process app is mounted at this origin purely so URLs are well-formed; nothing listens here. -_BASE_URL = "http://127.0.0.1:8000" +BASE_URL = "http://127.0.0.1:8000" + +# DNS-rebinding protection validates Host/Origin headers against a real network attack that cannot +# exist for an in-process ASGI app, so the in-process factories disable it; tests that exercise the +# protection itself pass explicit settings (or transport_security=None to get the localhost +# auto-enable behaviour). 
+NO_DNS_REBINDING_PROTECTION = TransportSecuritySettings(enable_dns_rebinding_protection=False) class Connect(Protocol): @@ -86,6 +104,8 @@ async def connect_over_streamable_http( *, stateless_http: bool = False, json_response: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, @@ -98,19 +118,19 @@ async def connect_over_streamable_http( With the defaults this is the matrix leg (stateful sessions, SSE responses); the transport-specific tests pass `stateless_http` or `json_response` to select the other - server modes. + server modes, and the resumability tests pass an `event_store` (with `retry_interval=0` so + the client's reconnection wait is a no-op). """ - # DNS-rebinding protection validates Host/Origin headers against a real network attack that - # cannot exist for an in-process ASGI app; leaving it on would also pull the origin-validation - # branch (deliberately uncovered in src) into coverage. 
app = server.streamable_http_app( stateless_http=stateless_http, json_response=json_response, - transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False), + event_store=event_store, + retry_interval=retry_interval, + transport_security=NO_DNS_REBINDING_PROTECTION, ) async with server.session_manager.run(): - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http_client: - transport = streamable_http_client(f"{_BASE_URL}/mcp", http_client=http_client) + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client: + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) async with Client( transport, read_timeout_seconds=read_timeout_seconds, @@ -124,6 +144,139 @@ async def connect_over_streamable_http( yield client +@asynccontextmanager +async def mounted_app( + server: Server | MCPServer, + *, + stateless_http: bool = False, + event_store: EventStore | None = None, + retry_interval: int | None = None, + transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, + on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, + headers: dict[str, str] | None = None, +) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: + """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. + + Yields the httpx client (rooted at the in-process origin) and the live session manager. Tests + use this in two ways: for raw-httpx assertions (status codes, headers, SSE bytes) the test + speaks HTTP through the yielded client directly; for client-driven assertions the test wraps + that client in `client_via_http(http)`, which lets several `Client`s share the one mounted + session manager. `on_request` records every outgoing HTTP request before it leaves the + yielded client. 
+ + DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the + localhost auto-enable behaviour) to test the protection itself. + """ + app = server.streamable_http_app( + stateless_http=stateless_http, + event_store=event_store, + retry_interval=retry_interval, + transport_security=transport_security, + ) + event_hooks = {"request": [on_request]} if on_request is not None else None + async with server.session_manager.run(): + async with httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers + ) as http_client: + yield http_client, server.session_manager + + +@asynccontextmanager +async def client_via_http( + http_client: httpx.AsyncClient, + *, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncIterator[Client]: + """Connect a `Client` over an already-mounted streamable HTTP app. + + Use with `mounted_app(...)` so several `Client`s share the one session manager, or so a + client-driven assertion can sit alongside raw-httpx assertions in the same test. The + underlying `httpx.AsyncClient` is left open when the `Client` exits. 
+ """ + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client( + transport, + logging_callback=logging_callback, + message_handler=message_handler, + elicitation_callback=elicitation_callback, + ) as client: + yield client + + +def parse_sse_messages(events: Iterable[ServerSentEvent]) -> list[JSONRPCMessage]: + """Decode SSE events into JSON-RPC messages, skipping priming events that carry no data.""" + return [jsonrpc_message_adapter.validate_json(event.data) for event in events if event.data] + + +async def post_jsonrpc( + http: httpx.AsyncClient, body: dict[str, object], *, session_id: str | None = None +) -> tuple[httpx.Response, list[JSONRPCMessage]]: + """POST a JSON-RPC body and read its SSE response stream to completion. + + Returns the HTTP response (for header/status assertions) and the parsed JSON-RPC messages + that arrived on the response's SSE stream. Only meaningful for requests the server answers + with `text/event-stream`; for error responses or 202 notification acknowledgements, use + `httpx.AsyncClient.post` directly and assert on the response. + """ + async with aconnect_sse(http, "POST", "/mcp", json=body, headers=base_headers(session_id=session_id)) as source: + events = [event async for event in source.aiter_sse()] + return source.response, parse_sse_messages(events) + + +def base_headers(*, session_id: str | None = None) -> dict[str, str]: + """Standard request headers for raw-httpx streamable-HTTP tests. + + Every well-formed request carries these (Accept covering both response representations, + Content-Type for POST bodies, MCP-Protocol-Version at the latest revision, and the session + ID once one exists), so a test that wants to assert a specific rejection only varies the one + header under test. 
+ """ + headers = { + "accept": "application/json, text/event-stream", + "content-type": "application/json", + "mcp-protocol-version": LATEST_PROTOCOL_VERSION, + } + if session_id is not None: + headers["mcp-session-id"] = session_id + return headers + + +def initialize_body(request_id: int = 1) -> dict[str, object]: + """A wire-level initialize JSON-RPC request body, exactly as an SDK client would send it.""" + params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="raw", version="0.0.0"), + ) + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="initialize", params=params.model_dump(by_alias=True, exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) + + +async def initialize_via_http(http: httpx.AsyncClient) -> str: + """Perform the initialize handshake over a raw `httpx.AsyncClient` and return the session ID. + + Validates the SSE response and sends the `notifications/initialized` follow-up, so the server + is fully ready for subsequent feature requests when this returns. + """ + async with aconnect_sse(http, "POST", "/mcp", json=initialize_body(), headers=base_headers()) as source: + assert source.response.status_code == 200 + # An event-store-backed server opens the stream with a priming event (empty data); skip it. + events = [event async for event in source.aiter_sse() if event.data] + assert len(events) == 1 + assert JSONRPCResponse.model_validate_json(events[0].data).id == 1 + session_id = source.response.headers["mcp-session-id"] + initialized = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, + headers=base_headers(session_id=session_id), + ) + assert initialized.status_code == 202 + return session_id + + def build_sse_app(server: Server | MCPServer) -> tuple[Starlette, SseServerTransport]: """Mount a server on a Starlette app exposing the legacy SSE transport at /sse and /messages/. 
@@ -175,13 +328,13 @@ def httpx_client_factory( # bridge must let the application drain rather than cancelling at close. return httpx.AsyncClient( transport=StreamingASGITransport(app, cancel_on_close=False), - base_url=_BASE_URL, + base_url=BASE_URL, headers=headers, timeout=timeout, auth=auth, ) - transport = sse_client(f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory) + transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) try: async with Client( transport, diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 842fe7a199..91ab5375d9 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1738,19 +1738,11 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/transports#streamable-http", behavior="A client that reconnects with Last-Event-ID receives the events it missed.", transports=("streamable-http",), - deferred=( - "Replay requires dropping and re-establishing the SSE connection, which the in-process ASGI " - "client cannot express. Covered over a real socket by tests/shared/test_streamable_http.py." - ), ), "transport:streamable-http:origin-validation": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#security-warning", behavior="Requests with an invalid Origin header are rejected with 403 before reaching the session.", transports=("streamable-http",), - deferred=( - "Not yet covered here: the in-process fixture leaves the SDK's opt-in protection disabled (see " - "hosting:http:dns-rebinding); existing coverage in tests/server/test_streamable_http_security.py." - ), ), "transport:sse": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", @@ -1801,70 +1793,70 @@ def __post_init__(self) -> None: "response headers." 
), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/shared/test_streamable_http.py and " - "tests/server/test_streamable_http_manager.py." - ), ), "hosting:session:delete": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior="DELETE with a valid Mcp-Session-Id terminates the session and removes its transport.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:session:id-charset": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior="Generated Mcp-Session-Id values contain only visible ASCII characters.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:session:isolation": Requirement( source="sdk", behavior="Each session gets its own server instance; closing one session does not affect others.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/server/test_streamable_http_manager.py.", ), "hosting:session:missing-id": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior="A non-initialize POST without Mcp-Session-Id in stateful mode returns 400.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + ), + "hosting:session:post-termination-404": Requirement( + source=f"{SPEC_BASE_URL}/basic/transports#session-management", + behavior=( + "After a session is terminated, any further request carrying that session ID is answered with " + "404 Not Found." 
+ ), + transports=("streamable-http",), ), "hosting:session:reinitialize": Requirement( source="sdk", behavior="A second initialize on an already-initialized session transport is rejected.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", + divergence=Divergence( + note=( + "The transport forwards a second initialize carrying the existing session ID to the running " + "server, which answers it as a fresh handshake; nothing rejects re-initialization." + ), + ), ), "hosting:session:reuse": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior="A POST carrying a valid Mcp-Session-Id routes to that session's transport with state preserved.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:session:unknown-id": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior="A POST, GET, or DELETE with an unknown Mcp-Session-Id returns 404.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:stateless:concurrent-clients": Requirement( source="sdk", behavior="Multiple independent clients can connect to a stateless server concurrently.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:stateless:no-reuse": Requirement( source="sdk", behavior="A stateless per-request transport cannot be reused for a second request.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:stateless:no-session-id": Requirement( source="sdk", behavior="In stateless mode no Mcp-Session-Id is emitted and no session validation is performed.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in 
tests/shared/test_streamable_http.py.", ), # ═══════════════════════════════════════════════════════════════════════════ # Hosting: auth @@ -1963,25 +1955,28 @@ def __post_init__(self) -> None: source="sdk", behavior="A Last-Event-ID that cannot be mapped to a stream is rejected.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", + divergence=Divergence( + note=( + "The replay path returns an empty SSE stream rather than rejecting an unknown " + "Last-Event-ID; the client cannot tell an unknown ID apart from a stream with no missed " + "events." + ), + ), ), "hosting:resume:buffered-replay": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior="Notifications emitted while no client is connected are replayed in order on reconnect.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:resume:close-stream": Requirement( source="sdk", behavior="Handlers can close an SSE stream cleanly when an event store is configured.", transports=("streamable-http",), - deferred="Not implemented in the SDK: handlers have no API to close SSE streams.", ), "hosting:resume:event-ids": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior="With an event store configured, every SSE event carries an id field.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:resume:priming": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", @@ -1991,19 +1986,23 @@ def __post_init__(self) -> None: "retry field first." 
), transports=("streamable-http",), - deferred="Not yet covered here: whether the python server emits priming events has not been pinned.", + divergence=Divergence( + note=( + "The retry hint is attached to the priming event itself rather than sent as a separate " + "event before the connection closes, and a priming event is only sent when an event store " + "is configured and the negotiated protocol version is at least 2025-11-25." + ), + ), ), "hosting:resume:replay": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior="GET with Last-Event-ID replays stored events for that stream after the given id.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:resume:stream-scoped": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior="Replay via Last-Event-ID returns only messages from the stream that event id belongs to.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), # ═══════════════════════════════════════════════════════════════════════════ # Hosting: HTTP semantics @@ -2012,7 +2011,6 @@ def __post_init__(self) -> None: source="sdk", behavior="A request whose Accept header does not allow the response representation returns 406.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:batch": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", @@ -2021,13 +2019,18 @@ def __post_init__(self) -> None: "that forbid them." 
), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:content-type-415": Requirement( source="sdk", behavior="A POST with a Content-Type other than application/json returns 415.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + divergence=Divergence( + note=( + "The transport-security middleware rejects a non-JSON Content-Type with 400 'Invalid " + "Content-Type header' before the request reaches the transport, so the transport's own 415 " + "path is unreachable through any public entry point." + ), + ), ), "hosting:http:disconnect-not-cancel": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", @@ -2036,7 +2039,6 @@ def __post_init__(self) -> None: "handler; the request continues and its result remains retrievable." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:http:dns-rebinding": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#security-warning", @@ -2045,27 +2047,24 @@ def __post_init__(self) -> None: "Origin is rejected with 403 Forbidden." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/test_streamable_http_security.py. " - "The SDK's protection is opt-in and disabled by default (no TransportSecuritySettings means " - "no Origin validation), and it also checks Host — the off-by-default gap is one to record as " - "a divergence when the transport conformance tests land." 
+ divergence=Divergence( + note=( + "The spec's Origin validation is an unconditional MUST; the SDK enables it only when the " + "host is a localhost address or explicit TransportSecuritySettings are passed (with no " + "settings, no Origin validation runs), and additionally validates the Host header " + "(returning 421 on mismatch), which the spec does not require." + ), ), ), "hosting:http:json-response-mode": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior="With JSON response mode enabled, POST returns application/json instead of SSE.", transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/shared/test_streamable_http.py and the " - "json-response tests in this suite's transports directory." - ), ), "hosting:http:method-405": Requirement( source="sdk", behavior="An unsupported HTTP method on the MCP endpoint returns 405.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:no-broadcast": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#multiple-connections", @@ -2074,13 +2073,11 @@ def __post_init__(self) -> None: "exactly one stream, never duplicated." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:http:notifications-202": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior="A POST containing only notifications or responses returns 202 with no body.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:onerror": Requirement( source="sdk", @@ -2095,13 +2092,11 @@ def __post_init__(self) -> None: "the body may carry a JSON-RPC error response (the SDK sends a Parse error body)." 
), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:protocol-version-400": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", behavior="An invalid or unsupported MCP-Protocol-Version header returns 400 Bad Request.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:protocol-version-default": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", @@ -2110,7 +2105,6 @@ def __post_init__(self) -> None: "way, the server assumes protocol version 2025-03-26." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:http:response-same-connection": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", @@ -2119,25 +2113,21 @@ def __post_init__(self) -> None: "that stream's resumed continuation), not on an unrelated stream." 
), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:http:second-sse-rejected": Requirement( source="sdk", behavior="A second concurrent standalone GET SSE stream on the same session is rejected.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:http:sse-close-after-response": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior="The server terminates a POST-initiated SSE stream after writing the JSON-RPC response.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "hosting:http:standalone-sse": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", behavior="GET opens a standalone SSE stream that receives server-initiated messages.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "hosting:http:standalone-sse-no-response": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", @@ -2146,7 +2136,6 @@ def __post_init__(self) -> None: "response, except when resuming a prior request stream." 
), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), # ═══════════════════════════════════════════════════════════════════════════ # Client transport: streamable HTTP @@ -2167,7 +2156,6 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", behavior="The client GET to the MCP endpoint includes an Accept header listing text/event-stream.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "client-transport:http:accept-header-post": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", @@ -2176,13 +2164,11 @@ def __post_init__(self) -> None: "and text/event-stream." ), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "client-transport:http:concurrent-streams": Requirement( source="sdk", behavior="Multiple concurrent POST-initiated SSE streams each deliver their response to the right caller.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:custom-client": Requirement( source="sdk", @@ -2191,31 +2177,26 @@ def __post_init__(self) -> None: "including auth flows." 
), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:custom-headers": Requirement( source="sdk", behavior="Caller-supplied headers are sent on every POST, GET, and DELETE to the MCP endpoint.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:json-response-parsed": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior="A Content-Type application/json response is parsed as a single JSON-RPC message.", transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "client-transport:http:no-reconnect-after-close": Requirement( source="sdk", behavior="After the transport is closed, no further reconnection attempts are scheduled.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:no-reconnect-after-response": Requirement( source="sdk", behavior="A POST-initiated stream that already delivered its response is not reconnected when it closes.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:protocol-version-header": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#protocol-version-header", @@ -2224,13 +2205,13 @@ def __post_init__(self) -> None: "subsequent HTTP request." ), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "client-transport:http:protocol-version-stored": Requirement( source="sdk", - behavior="The client transport exposes the negotiated protocol version once initialization completes.", + behavior=( + "The client transport stores the negotiated protocol version and sends it on every subsequent request." 
+ ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:reconnect-get": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", @@ -2238,7 +2219,12 @@ def __post_init__(self) -> None: "A standalone GET SSE stream that errors is reconnected with the Last-Event-ID of the last received event." ), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", + deferred=( + "Not yet covered here: the standalone GET stream emits no priming event or retry hint, so " + "the client's reconnection path always sleeps the hard-coded 1 s default; a deterministic " + "in-process test would inject real-time delay or require an SDK change. The POST-stream " + "reconnection path is covered by client-transport:http:reconnect-post-priming." + ), ), "client-transport:http:reconnect-post-priming": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", @@ -2247,13 +2233,11 @@ def __post_init__(self) -> None: "if a priming event (an event carrying an ID) was received on it." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:reconnect-retry-value": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#sending-messages-to-the-server", behavior="Reconnection delay honours the server-provided SSE retry value when one was sent.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:resume-stream-api": Requirement( source="sdk", @@ -2262,7 +2246,6 @@ def __post_init__(self) -> None: "the notifications it missed." 
), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "client-transport:http:session-stored": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", @@ -2271,19 +2254,16 @@ def __post_init__(self) -> None: "every subsequent request." ), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "client-transport:http:sse-405-tolerated": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", behavior="Opening the standalone GET SSE stream tolerates a 405 response without failing the connection.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "client-transport:http:terminate-405-ok": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior="Session termination succeeds without error if the server answers 405 (termination unsupported).", transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), # ═══════════════════════════════════════════════════════════════════════════ # Client auth diff --git a/tests/interaction/transports/_event_store.py b/tests/interaction/transports/_event_store.py new file mode 100644 index 0000000000..84d1a2646a --- /dev/null +++ b/tests/interaction/transports/_event_store.py @@ -0,0 +1,55 @@ +"""A predictable event store for resumability tests. + +The SDK's `EventStore` interface lets a streamable-HTTP server stamp every SSE event with an ID +and replay missed events when a client reconnects with `Last-Event-ID`. This implementation +issues sequential integer IDs starting at "1" so tests can assert exact IDs (the example store +uses uuid4, which cannot be snapshotted) and is small enough that every line is exercised by the +resumability tests themselves. 
+""" + +import anyio + +from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId +from mcp.types import JSONRPCMessage + + +class SequencedEventStore(EventStore): + """Stores every event in order and replays the same-stream tail after a given ID.""" + + def __init__(self) -> None: + self._events: list[tuple[StreamId, JSONRPCMessage | None]] = [] + self._milestones: dict[int, anyio.Event] = {} + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self._events.append((stream_id, message)) + count = len(self._events) + milestone = self._milestones.pop(count, None) + if milestone is not None: + milestone.set() + return str(count) + + async def wait_until_stored(self, count: int) -> None: + """Block until at least `count` events have been stored. + + Tests use this to wait for the server's message router (which runs in another task) to + finish storing a known set of events before issuing a replay, so the replay's content is + deterministic rather than depending on task scheduling order. 
+ """ + if len(self._events) >= count: + return + milestone = self._milestones.setdefault(count, anyio.Event()) + await milestone.wait() + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + try: + cursor = int(last_event_id) + except ValueError: + return None + if not 0 < cursor <= len(self._events): + return None + stream_id, _ = self._events[cursor - 1] + for index in range(cursor, len(self._events)): + event_stream_id, message = self._events[index] + if event_stream_id == stream_id and message is not None: + await send_callback(EventMessage(message, str(index + 1))) + return stream_id diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py new file mode 100644 index 0000000000..604f08a8f2 --- /dev/null +++ b/tests/interaction/transports/test_client_transport_http.py @@ -0,0 +1,211 @@ +"""Behaviour of the streamable-HTTP client transport itself, observed at the wire. + +These tests connect a real `Client` to a real server over the in-process bridge, recording every +HTTP request the SDK client issues, so the assertions are about what the transport sends (headers, +methods, ordering) rather than what the protocol layer on top of it returns. The recording is the +wire-level instrument; the SDK client never exposes these details. 
+""" + +from collections.abc import AsyncIterator + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from starlette.types import Receive, Scope, Send + +from mcp import types +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.transports._bridge import StreamingASGITransport +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _tooled_server() -> Server: + """A low-level server with one echo tool, used by every test in this file.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", description="Echo text.", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "echo" + assert params.arguments is not None + return CallToolResult(content=[TextContent(text=str(params.arguments["text"]))]) + + return Server("echoer", on_list_tools=list_tools, on_call_tool=call_tool) + + +@pytest.fixture +async def recorded() -> AsyncIterator[list[httpx.Request]]: + """Connect a `Client` over a recording HTTP client, list tools, exit, and yield every request sent. + + The HTTP client carries one caller-supplied header (`x-trace`) so its propagation can be + asserted; the recording captures the closing DELETE because it is read after the `Client` has + fully exited. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record, headers={"x-trace": "abc"}) as (http, _): + async with client_via_http(http) as client: + result = await client.list_tools() + assert [tool.name for tool in result.tools] == ["echo"] + + yield requests + + +def _after_initialize(recorded: list[httpx.Request]) -> list[httpx.Request]: + """Every recorded request after the initialize POST (which carries no session yet).""" + assert recorded[0].method == "POST" + assert "mcp-session-id" not in recorded[0].headers + return recorded[1:] + + +@requirement("client-transport:http:custom-client") +@requirement("client-transport:http:custom-headers") +async def test_the_client_uses_the_supplied_http_client_and_propagates_its_headers( + recorded: list[httpx.Request], +) -> None: + """A caller-supplied `httpx.AsyncClient` is used for every request and carries its own headers. + + The recording itself proves the supplied client is the one in use; the propagated header + proves the SDK transport does not replace the caller's client configuration. + """ + # Exact ordering past the first request is not guaranteed (the standalone GET stream is + # scheduled concurrently with later POSTs), so methods are asserted as a multiset. 
+ assert sorted(request.method for request in recorded) == snapshot(["DELETE", "GET", "POST", "POST", "POST"]) + assert all(request.headers["x-trace"] == "abc" for request in recorded) + + +@requirement("client-transport:http:session-stored") +async def test_every_request_after_initialize_carries_the_issued_session_id(recorded: list[httpx.Request]) -> None: + """The session id from the initialize response is sent on every subsequent request.""" + session_ids = {request.headers["mcp-session-id"] for request in _after_initialize(recorded)} + assert len(session_ids) == 1 + (session_id,) = session_ids + assert session_id + + +@requirement("client-transport:http:protocol-version-stored") +@requirement("client-transport:http:protocol-version-header") +async def test_every_request_after_initialize_carries_the_negotiated_protocol_version( + recorded: list[httpx.Request], +) -> None: + """The negotiated protocol version is sent on every subsequent request (and not on initialize).""" + assert "mcp-protocol-version" not in recorded[0].headers + versions = {request.headers["mcp-protocol-version"] for request in _after_initialize(recorded)} + assert versions == snapshot({"2025-11-25"}) + + +@requirement("client-transport:http:accept-header-post") +@requirement("client-transport:http:accept-header-get") +async def test_accept_headers_cover_the_response_representations_the_transport_handles( + recorded: list[httpx.Request], +) -> None: + """POSTs accept both JSON and SSE; the standalone GET stream accepts SSE.""" + for request in recorded: + if request.method == "POST": + assert "application/json" in request.headers["accept"] + assert "text/event-stream" in request.headers["accept"] + if request.method == "GET": + assert "text/event-stream" in request.headers["accept"] + + +@requirement("client-transport:http:no-reconnect-after-close") +async def test_closing_the_client_sends_delete_and_does_not_reconnect(recorded: list[httpx.Request]) -> None: + """Client teardown sends DELETE 
and issues no further requests (no resumption GET).""" + assert recorded[-1].method == "DELETE" + assert all("last-event-id" not in request.headers for request in recorded) + + +@requirement("client-transport:http:concurrent-streams") +async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_own_response() -> None: + """Three tool calls issued at once each open their own POST stream and get the right answer.""" + requests: list[httpx.Request] = [] + results: dict[int, CallToolResult] = {} + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_tooled_server(), on_request=record) as (http, _): + async with client_via_http(http) as client: + + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + for n in (1, 2, 3): + tg.start_soon(call, n) + + assert results == snapshot( + { + 1: CallToolResult(content=[TextContent(text="1")]), + 2: CallToolResult(content=[TextContent(text="2")]), + 3: CallToolResult(content=[TextContent(text="3")]), + } + ) + tools_call_posts = [r for r in requests if r.method == "POST" and b'"tools/call"' in r.content] + assert len(tools_call_posts) == 3 + + +@requirement("client-transport:http:sse-405-tolerated") +@requirement("client-transport:http:terminate-405-ok") +async def test_client_tolerates_405_on_get_and_delete() -> None: + """A 405 on the standalone GET stream or the closing DELETE does not fail the connection. + + The GET-stream task swallows the failure and schedules a reconnect that the closing cancel + interrupts before it ever sleeps the full default delay; the DELETE 405 is logged and ignored. + Neither surfaces to the caller. 
+ """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + + async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] in ("GET", "DELETE"): + await send({"type": "http.response.start", "status": 405, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await real_app(scope, receive, send) + + async with server.session_manager.run(): + http_client = httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) + async with http_client: + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): + async with Client(transport) as client: + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + + +@requirement("client-transport:http:no-reconnect-after-response") +async def test_a_completed_post_stream_is_not_reconnected() -> None: + """A POST stream that delivered its response closes without a resumption GET. + + With an event store the server stamps every SSE event with an ID, so the client transport has a + Last-Event-ID it could resume from -- the test proves it does not, because the response arrived + and the stream completed normally. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + server = _tooled_server() + async with mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _): + async with client_via_http(http) as client: + with anyio.fail_after(5): + result = await client.list_tools() + + assert [tool.name for tool in result.tools] == ["echo"] + resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] + assert resumption_gets == [] diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py new file mode 100644 index 0000000000..aa9beee067 --- /dev/null +++ b/tests/interaction/transports/test_hosting_http.py @@ -0,0 +1,297 @@ +"""Streamable HTTP semantics: status codes, header validation, message routing, and security. + +These tests speak HTTP directly to the server's mounted ASGI app via the in-process bridge, +asserting the wire contract -- which status code answers which condition, which stream a message +travels on -- that the SDK client never exposes. Transport-agnostic behaviour is covered by the +`connect`-fixture matrix. 
+""" + +import anyio +import pytest +from anyio.lowlevel import checkpoint +from httpx_sse import ServerSentEvent, aconnect_sse +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.server.transport_security import TransportSecuritySettings +from mcp.types import ( + INVALID_PARAMS, + PARSE_ERROR, + CallToolRequestParams, + CallToolResult, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ListToolsResult, + PaginatedRequestParams, + TextContent, +) +from tests.interaction._connect import ( + base_headers, + initialize_body, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A low-level server with one tool that emits a related and an unrelated notification.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + """Registered only so the tools capability is advertised; never called.""" + raise NotImplementedError + + async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + assert params.name == "narrate" + await ctx.session.send_log_message(level="info", data="related", logger=None, related_request_id=ctx.request_id) + await ctx.session.send_resource_updated("file:///watched.txt") + return CallToolResult(content=[TextContent(text="done")]) + + return Server("hosted", on_list_tools=list_tools, on_call_tool=call_tool) + + +@requirement("hosting:http:method-405") +async def test_unsupported_http_methods_return_405() -> None: + """PUT and PATCH on the MCP endpoint return 405 with an Allow header naming the supported methods.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + put = await http.put("/mcp", json={}, headers=base_headers(session_id=session_id)) + patch = await http.patch("/mcp", json={}, 
headers=base_headers(session_id=session_id)) + + assert (put.status_code, put.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + assert (patch.status_code, patch.headers.get("allow")) == snapshot((405, "GET, POST, DELETE")) + + +@requirement("hosting:http:accept-406") +async def test_missing_accept_media_types_return_406() -> None: + """A POST whose Accept header lacks both required types, or a GET lacking text/event-stream, returns 406.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", json=initialize_body(), headers={"accept": "text/plain", "mcp-protocol-version": "2025-11-25"} + ) + session_id = await initialize_via_http(http) + get = await http.get( + "/mcp", + headers={"accept": "application/json", "mcp-protocol-version": "2025-11-25", "mcp-session-id": session_id}, + ) + + assert (post.status_code, post.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept both application/json and text/event-stream") + ) + assert (get.status_code, get.json()["error"]["message"]) == snapshot( + (406, "Not Acceptable: Client must accept text/event-stream") + ) + + +@requirement("hosting:http:content-type-415") +async def test_non_json_content_type_is_rejected() -> None: + """A POST with a non-JSON Content-Type is rejected before reaching the transport. + + See the divergence on the requirement: the security middleware rejects with 400, so the + transport's own 415 path is unreachable through any public entry point. 
+ """ + async with mounted_app(_server()) as (http, _): + response = await http.post( + "/mcp", content=b"", headers=base_headers() | {"content-type": "text/plain"} + ) + + assert (response.status_code, response.text) == snapshot((400, "Invalid Content-Type header")) + + +@requirement("hosting:http:parse-error-400") +@requirement("hosting:http:batch") +async def test_malformed_and_batched_bodies_return_400() -> None: + """A non-JSON body returns 400 Parse error; a JSON array of requests returns 400 Invalid params.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + not_json = await http.post( + "/mcp", + content=b"this is not json", + headers=base_headers(session_id=session_id) | {"content-type": "application/json"}, + ) + batched = await http.post( + "/mcp", + json=[ + {"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + ], + headers=base_headers(session_id=session_id), + ) + + assert not_json.status_code == 400 + assert JSONRPCError.model_validate_json(not_json.text).error.code == PARSE_ERROR + assert batched.status_code == 400 + assert JSONRPCError.model_validate_json(batched.text).error.code == INVALID_PARAMS + + +@requirement("hosting:http:protocol-version-400") +@requirement("hosting:http:protocol-version-default") +async def test_protocol_version_header_is_validated() -> None: + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + bad = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + ) + # Only Accept and the session ID -- no MCP-Protocol-Version header at all. 
+ defaulted = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers={"accept": "application/json, text/event-stream", "mcp-session-id": session_id}, + ) + + assert bad.status_code == 400 + assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( + "Bad Request: Unsupported protocol version: 1991-01-01." + ) + # 202 proves the request was accepted under the assumed default version (2025-03-26). + assert defaulted.status_code == 202 + + +@requirement("hosting:http:notifications-202") +async def test_notification_post_returns_202_with_no_body() -> None: + """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + response = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progressToken": 0, "progress": 1}}, + headers=base_headers(session_id=session_id), + ) + + assert (response.status_code, response.content) == snapshot((202, b"")) + + +@requirement("hosting:http:second-sse-rejected") +async def test_a_second_standalone_get_stream_on_the_same_session_returns_409() -> None: + """Opening a second standalone GET SSE stream while one is already established returns 409 Conflict.""" + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as first: + assert first.response.status_code == 200 + # The standalone-stream writer registers its key as its first action, then parks + # awaiting messages; one yield to the loop lets that registration complete before the + # second GET is dispatched. 
+ await checkpoint() + second = await http.get("/mcp", headers=base_headers(session_id=session_id)) + + assert (second.status_code, second.json()["error"]["message"]) == snapshot( + (409, "Conflict: Only one SSE stream is allowed per session") + ) + + +@requirement("hosting:http:standalone-sse") +@requirement("hosting:http:standalone-sse-no-response") +@requirement("hosting:http:response-same-connection") +@requirement("hosting:http:sse-close-after-response") +@requirement("hosting:http:no-broadcast") +async def test_messages_are_routed_to_exactly_one_stream() -> None: + """Each server message travels on exactly one SSE stream and is never broadcast. + + A streamable-HTTP session has two kinds of server-to-client SSE stream: one short-lived stream + per POST request, carrying that request's response and any notifications related to it, and one + long-lived standalone stream (opened by GET) for notifications not tied to any request. The + spec's routing rule is that the POST stream delivers the response (and its related + notifications) and then closes, the standalone stream carries only unrelated notifications and + never a JSON-RPC response, and no message appears on both. The test opens both streams, calls a + tool whose handler emits one related and one unrelated notification, and asserts each message's + routing. 
+ """ + async with mounted_app(_server()) as (http, _): + session_id = await initialize_via_http(http) + post_events: list[ServerSentEvent] = [] + get_events: list[ServerSentEvent] = [] + + async def read_standalone_stream() -> None: + async with aconnect_sse(http, "GET", "/mcp", headers=base_headers(session_id=session_id)) as get: + assert get.response.status_code == 200 + standalone_ready.set() + async for event in get.aiter_sse(): + get_events.append(event) + seen_on_standalone.set() + + standalone_ready = anyio.Event() + seen_on_standalone = anyio.Event() + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(read_standalone_stream) + await standalone_ready.wait() + + params = CallToolRequestParams(name="narrate", arguments={}) + body = JSONRPCRequest(jsonrpc="2.0", id=5, method="tools/call", params=params.model_dump()) + async with aconnect_sse( + http, + "POST", + "/mcp", + json=body.model_dump(by_alias=True, exclude_none=True), + headers=base_headers(session_id=session_id), + ) as post: + assert post.response.status_code == 200 + # The POST stream iterator ends when the server closes the stream after the response. + post_events = [event async for event in post.aiter_sse()] + + await seen_on_standalone.wait() + tg.cancel_scope.cancel() + + post_messages = parse_sse_messages(post_events) + get_messages = parse_sse_messages(get_events) + + # POST stream: the related log notification, then the response, then the iterator ends (close). + assert [type(m).__name__ for m in post_messages] == snapshot(["JSONRPCNotification", "JSONRPCResponse"]) + assert isinstance(post_messages[0], JSONRPCNotification) + assert (post_messages[0].method, post_messages[0].params) == snapshot( + ("notifications/message", {"level": "info", "data": "related"}) + ) + assert isinstance(post_messages[1], JSONRPCResponse) + assert post_messages[1].id == 5 + + # Standalone stream: only the unrelated resource-updated notification, never a response. 
+ assert [type(m).__name__ for m in get_messages] == snapshot(["JSONRPCNotification"]) + assert isinstance(get_messages[0], JSONRPCNotification) + assert get_messages[0].method == snapshot("notifications/resources/updated") + + +@requirement("hosting:http:dns-rebinding") +@requirement("transport:streamable-http:origin-validation") +async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> None: + """A disallowed Origin returns 403 (and Host 421) with protection enabled; disabled lets both through. + + See the divergence on hosting:http:dns-rebinding: the spec's Origin validation is an + unconditional MUST, but the SDK enables it only when the host is localhost (or settings are + passed explicitly) and additionally checks the Host header (returning 421), which the spec + does not require. + """ + # transport_security=None triggers the localhost auto-enable behaviour. + async with mounted_app(Server("guarded"), transport_security=None) as (http, _): + bad_origin = await http.post( + "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) + bad_host = await http.post("/mcp", json=initialize_body(), headers=base_headers() | {"host": "evil.example"}) + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://127.0.0.1:8000"} + ) as ok: + assert ok.response.status_code == 200 + assert [event async for event in ok.aiter_sse()] + + assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) + assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + + async with mounted_app( + Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) + ) as (http, _): + async with aconnect_sse( + http, "POST", "/mcp", json=initialize_body(), headers=base_headers() | {"origin": "http://evil.example"} + ) as unguarded: + status = unguarded.response.status_code + 
assert [event async for event in unguarded.aiter_sse()] + + assert status == 200 diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py new file mode 100644 index 0000000000..6abeb5d8ed --- /dev/null +++ b/tests/interaction/transports/test_hosting_resume.py @@ -0,0 +1,287 @@ +"""Resumability over the streamable HTTP transport, exercised entirely in process. + +These tests configure the server with an event store, so every SSE event is stamped with an ID +and a client that loses its connection can resume by sending `Last-Event-ID`. The wire-level +tests (`mounted_app` + raw httpx) assert exactly what travels on the wire; the end-to-end test +drives the SDK client through a server-initiated stream close and proves the call still +completes. The bridge's `aclose()` delivers `http.disconnect` to the running application, so +closing a streaming response mid-read is a deterministic in-process disconnect -- no sockets, +no real time. Every server here uses `retry_interval=0` so reconnection waits are no-ops. 
+""" + +import json + +import anyio +import httpx +import pytest +from httpx_sse import EventSource, ServerSentEvent +from inline_snapshot import snapshot + +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, + jsonrpc_message_adapter, +) +from tests.interaction._connect import ( + base_headers, + connect_over_streamable_http, + initialize_via_http, + mounted_app, + parse_sse_messages, +) +from tests.interaction._requirements import requirement +from tests.interaction.transports._event_store import SequencedEventStore + +pytestmark = pytest.mark.anyio + + +def _counting_server() -> MCPServer: + """A server with one tool that emits related notifications and one unrelated notification.""" + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context, n: int) -> str: + """Emit n log notifications related to this call, plus one unrelated resource update.""" + for i in range(1, n + 1): + await ctx.info(f"tick {i}") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return f"counted to {n}" + + return mcp + + +def _tools_call(request_id: int, name: str, arguments: dict[str, object]) -> str: + """A serialized tools/call JSON-RPC request body.""" + return JSONRPCRequest( + jsonrpc="2.0", id=request_id, method="tools/call", params={"name": name, "arguments": arguments} + ).model_dump_json(by_alias=True, exclude_none=True) + + +async def _read_events(response: httpx.Response, count: int) -> list[ServerSentEvent]: + """Read exactly `count` SSE events from a streaming response without closing it.""" + source = EventSource(response).aiter_sse() + return [await anext(source) for _ in range(count)] + + +@requirement("hosting:resume:event-ids") +@requirement("hosting:resume:priming") +async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_event() -> None: + """A request's 
SSE stream opens with a priming event (id, empty data, retry) then stamps each message.""" + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {"n": 2}), headers=base_headers(session_id=session_id) + ) as response: + assert response.status_code == 200 + events = await _read_events(response, 4) + + priming, first, second, result = events + # The priming event is the only event a client could have seen before any work happened, so it + # is the resumption anchor: it carries an ID and empty data. The SDK attaches the retry hint + # to this event (see the divergence on hosting:resume:priming). + assert (priming.id, priming.data, priming.retry) == snapshot(("3", "", 0)) + assert priming.event == snapshot("message") + # Every subsequent event carries an event-store ID; the related notifications and the response + # all ride this stream and close it after the response. + assert [event.id for event in (first, second, result)] == snapshot(["4", "5", "7"]) + assert [json.loads(event.data)["method"] for event in (first, second)] == snapshot( + ["notifications/message", "notifications/message"] + ) + assert jsonrpc_message_adapter.validate_json(result.data) == snapshot( + JSONRPCResponse( + jsonrpc="2.0", + id=1, + result={ + "content": [{"type": "text", "text": "counted to 2"}], + "structuredContent": {"result": "counted to 2"}, + "isError": False, + }, + ) + ) + + +@requirement("hosting:resume:replay") +@requirement("hosting:resume:stream-scoped") +@requirement("hosting:resume:buffered-replay") +async def test_get_with_last_event_id_replays_only_that_streams_missed_events() -> None: + """Reconnecting with Last-Event-ID returns the missed events from that one stream, in order. 
+ + The handler also emits an unrelated notification (which the server stores under the + standalone-stream key); replay must not return it, proving replay is scoped to the stream + the given event ID belongs to. + + Steps: (1) initialize; (2) POST a tool call and read events until the first notification is + captured; (3) close the response mid-stream -- the bridge delivers `http.disconnect`, the + handler keeps running; (4) release the handler so it emits the remaining messages, which the + server buffers in the event store; (5) wait on the event store for the handler's response to + be stored, so the replay's content is independent of task scheduling; (6) GET with + `Last-Event-ID` and assert the replay is exactly the missed events from this request's stream. + """ + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context) -> str: + """Emit one related notification, wait for the test, then emit two more plus an unrelated one.""" + await ctx.info("tick 1") + await release.wait() + await ctx.info("tick 2") + await ctx.info("tick 3") + await ctx.session.send_resource_updated("file:///elsewhere.txt") + return "counted" + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id) + ) as response: + # Read the priming event and the first notification, then drop the connection. + priming, first = await _read_events(response, 2) + assert (priming.id, first.id) == snapshot(("3", "4")) + last_seen = first.id + release.set() + # The handler keeps running after the disconnect; its remaining messages are stored. 
+ # The first wait returns immediately (the priming and first tick are already stored); + # the second blocks until the response itself is stored so the replay content is fixed. + await store.wait_until_stored(4) + await store.wait_until_stored(8) + replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen} + async with http.stream("GET", "/mcp", headers=replay_headers) as replay: + assert replay.status_code == 200 + missed = await _read_events(replay, 3) + + decoded = parse_sse_messages(missed) + # Exactly the two remaining related notifications and the response, with their original IDs. + assert [event.id for event in missed] == snapshot(["5", "6", "8"]) + assert [type(message).__name__ for message in decoded] == snapshot( + ["JSONRPCNotification", "JSONRPCNotification", "JSONRPCResponse"] + ) + assert isinstance(decoded[2], JSONRPCResponse) + assert decoded[2].id == 1 + # The unrelated resource-updated notification was stored under the standalone-stream key, not + # this request's stream, so it must not appear in the replay. + assert all( + not (isinstance(message, JSONRPCNotification) and message.method == "notifications/resources/updated") + for message in decoded + ) + + +@requirement("hosting:resume:bad-event-id") +async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None: + """A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error. + + See the divergence on hosting:resume:bad-event-id: this pins current behaviour. 
+ """ + async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + for unknown in ("no-such-event", "0"): + headers = base_headers(session_id=session_id) | {"last-event-id": unknown} + async with http.stream("GET", "/mcp", headers=headers) as replay: + assert replay.status_code == 200 + assert replay.headers["content-type"].startswith("text/event-stream") + events = [event async for event in EventSource(replay).aiter_sse()] + assert events == [] + + +@requirement("hosting:http:disconnect-not-cancel") +async def test_dropping_the_connection_mid_request_does_not_cancel_the_handler() -> None: + """Closing the request's SSE connection while the handler is running leaves the handler running. + + The handler signals when it has started and when it has finished; the test drops the + connection in between and then releases the handler. If the disconnect cancelled the handler, + `finished` would never be set and the test would time out. 
+ """ + started = anyio.Event() + release = anyio.Event() + finished = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Signal start, wait for the test, signal completion.""" + started.set() + await release.wait() + await ctx.info("released") + finished.set() + return "held" + + async with mounted_app(mcp, event_store=SequencedEventStore(), retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "hold", {}), headers=base_headers(session_id=session_id) + ) as response: + await _read_events(response, 1) + await started.wait() + assert not finished.is_set() + release.set() + await finished.wait() + + +@requirement("hosting:resume:close-stream") +@requirement("transport:streamable-http:resumability") +@requirement("client-transport:http:reconnect-post-priming") +@requirement("client-transport:http:reconnect-retry-value") +@requirement("client-transport:http:resume-stream-api") +async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: + """A server-closed request stream is reconnected by the client and the call completes. + + The handler emits one notification, closes its own SSE stream, then (once released) emits + another and returns. The client observed the priming event (so it has a Last-Event-ID and a + retry hint of 0ms), sees the stream end, reconnects via GET with Last-Event-ID, and receives + the post-close notification and the result over the replay stream. The shared events make the + test deterministic: the handler only proceeds once the test knows the first notification has + arrived (and so the client's reconnection has begun). 
+ """ + received: list[object] = [] + before_seen = anyio.Event() + gate = anyio.Event() + done = anyio.Event() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def interrupt(ctx: Context) -> str: + """Emit, close this call's SSE stream, then emit again after the test releases the gate.""" + await ctx.info("before close") + await ctx.close_sse_stream() + await gate.wait() + await ctx.info("after close") + done.set() + return "resumed" + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + if params.data == "before close": + before_seen.set() + + result: list[CallToolResult] = [] + async with connect_over_streamable_http( + mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect + ) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call() -> None: + result.append(await client.call_tool("interrupt", {})) + + tg.start_soon(call) + await before_seen.wait() + gate.set() + await done.wait() + + assert result == snapshot( + [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] + ) + assert received == snapshot(["before close", "after close"]) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py new file mode 100644 index 0000000000..561fbf251a --- /dev/null +++ b/tests/interaction/transports/test_hosting_session.py @@ -0,0 +1,203 @@ +"""Streamable HTTP session lifecycle: creation, routing, termination, and stateless mode. + +A test here speaks raw HTTP only when its assertion is the wire contract -- which header is +issued, which status code answers which condition -- that the SDK `Client` cannot observe. +Everything else is `Client`-driven against the same mounted session manager. Transport-agnostic +behaviour is covered by the `connect`-fixture matrix. 
+""" + +import re + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server, ServerRequestContext +from mcp.types import JSONRPCResponse, ListToolsResult, PaginatedRequestParams, Tool +from tests.interaction._connect import ( + base_headers, + client_via_http, + initialize_body, + initialize_via_http, + mounted_app, + post_jsonrpc, +) +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +def _server() -> Server: + """A minimal low-level server with one tool, so subsequent-request routing can be observed.""" + + async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="noop", description="Does nothing.", input_schema={"type": "object"})]) + + return Server("hosted", on_list_tools=list_tools) + + +@requirement("hosting:session:create") +@requirement("hosting:session:id-charset") +async def test_initialize_issues_a_visible_ascii_session_id() -> None: + """An initialize POST without a session ID creates a session and returns a visible-ASCII Mcp-Session-Id.""" + async with mounted_app(_server()) as (http, _): + response, messages = await post_jsonrpc(http, initialize_body()) + + assert response.status_code == 200 + session_id = response.headers.get("mcp-session-id") + assert session_id is not None + # The spec requires the session ID to consist only of visible ASCII (0x21-0x7E). 
+ assert re.fullmatch(r"[\x21-\x7E]+", session_id) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 1 + + +@requirement("hosting:session:reuse") +async def test_subsequent_requests_with_the_session_id_route_to_the_same_session() -> None: + """Requests carrying the issued Mcp-Session-Id reuse that session's transport rather than creating another.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as client: + await client.list_tools() + await client.list_tools() + # The session count is the only signal that distinguishes routing-to-existing from + # silently creating a second session: both produce a successful result. + assert len(manager._server_instances) == 1 + + +@requirement("hosting:session:unknown-id") +async def test_requests_with_an_unknown_session_id_return_404() -> None: + """POST, GET, and DELETE each carrying an unknown Mcp-Session-Id are answered 404 by the manager.""" + async with mounted_app(_server()) as (http, _): + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, + headers=base_headers(session_id="not-a-session"), + ) + get = await http.get("/mcp", headers=base_headers(session_id="not-a-session")) + delete = await http.delete("/mcp", headers=base_headers(session_id="not-a-session")) + + assert (post.status_code, post.json()) == snapshot( + (404, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Session not found"}}) + ) + assert (get.status_code, delete.status_code) == (404, 404) + + +@requirement("hosting:session:missing-id") +async def test_non_initialize_post_without_a_session_id_returns_400() -> None: + """A non-initialize POST that omits Mcp-Session-Id in stateful mode is rejected with 400.""" + async with mounted_app(_server()) as (http, _): + await initialize_via_http(http) + response = await http.post( + "/mcp", json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, headers=base_headers() + ) + + assert 
(response.status_code, response.json()) == snapshot( + (400, {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "Bad Request: Missing session ID"}}) + ) + + +@requirement("hosting:session:delete") +@requirement("hosting:session:post-termination-404") +async def test_delete_terminates_the_session_and_subsequent_requests_return_404() -> None: + """DELETE with a valid Mcp-Session-Id terminates the session; further requests on that ID return 404.""" + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + + delete = await http.delete("/mcp", headers=base_headers(session_id=session_id)) + assert delete.status_code == 200 + + # The manager keeps the terminated transport registered, so the next request reaches the + # transport's own _terminated check rather than the manager's unknown-session path. + assert session_id in manager._server_instances + post = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, + headers=base_headers(session_id=session_id), + ) + + assert (post.status_code, post.json()) == snapshot( + ( + 404, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, + }, + ) + ) + + +@requirement("hosting:session:isolation") +async def test_terminating_one_session_leaves_others_working() -> None: + """Terminating one session on a manager does not disturb a concurrent session on the same manager.""" + async with mounted_app(_server()) as (http, manager): + async with client_via_http(http) as survivor: + async with client_via_http(http) as terminated: + await terminated.list_tools() + assert len(manager._server_instances) == 2 + # `terminated` has exited (its DELETE has been sent); `survivor` still answers. 
+ result = await survivor.list_tools() + + assert result.tools[0].name == "noop" + + +@requirement("hosting:session:reinitialize") +async def test_second_initialize_on_an_existing_session_is_accepted() -> None: + """A second initialize POST carrying an existing session ID is processed rather than rejected. + + See the divergence on the requirement: the entry expects a rejection, but the SDK forwards the + second initialize to the running server, which answers it as a fresh handshake. + """ + async with mounted_app(_server()) as (http, manager): + session_id = await initialize_via_http(http) + response, messages = await post_jsonrpc(http, initialize_body(request_id=2), session_id=session_id) + assert len(manager._server_instances) == 1 + + assert response.status_code == snapshot(200) + assert isinstance(messages[0], JSONRPCResponse) + assert messages[0].id == 2 + + +@requirement("hosting:stateless:no-session-id") +@requirement("hosting:stateless:no-reuse") +async def test_stateless_mode_never_issues_a_session_id() -> None: + """A stateless server issues no Mcp-Session-Id and creates no persistent transport. + + The recording proves no request the SDK client sent carried an Mcp-Session-Id (the server + cannot have issued one, or the client would echo it); the empty instance map proves the + manager kept no transport between requests. 
+ """ + requests: list[httpx.Request] = [] + + async def record(request: httpx.Request) -> None: + requests.append(request) + + async with mounted_app(_server(), stateless_http=True, on_request=record) as (http, manager): + async with client_via_http(http) as client: + result = await client.list_tools() + assert manager._server_instances == {} + + assert result.tools[0].name == "noop" + assert all("mcp-session-id" not in request.headers for request in requests) + assert "DELETE" not in {request.method for request in requests} + + +@requirement("hosting:stateless:concurrent-clients") +async def test_stateless_mode_serves_concurrent_clients_independently() -> None: + """Two clients connected concurrently to the same stateless app each complete a round trip.""" + results: dict[str, ListToolsResult] = {} + + async with mounted_app(_server(), stateless_http=True) as (http, _): + + async def list_via(label: str) -> None: + async with client_via_http(http) as client: + results[label] = await client.list_tools() + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(list_via, "a") + tg.start_soon(list_via, "b") + + assert results["a"].tools[0].name == "noop" + assert results["b"].tools[0].name == "noop" diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index 1d5434c160..4facadec73 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -20,14 +20,12 @@ from mcp.client.sse import sse_client from mcp.server import Server from mcp.types import EmptyResult -from tests.interaction._connect import build_sse_app +from tests.interaction._connect import BASE_URL, build_sse_app from tests.interaction._requirements import requirement from tests.interaction.transports._bridge import StreamingASGITransport pytestmark = pytest.mark.anyio -_BASE_URL = "http://127.0.0.1:8000" - @requirement("transport:sse") @requirement("transport:sse:endpoint-event") @@ 
-45,14 +43,14 @@ def httpx_client_factory( ) -> httpx.AsyncClient: return httpx.AsyncClient( transport=StreamingASGITransport(app, cancel_on_close=False), - base_url=_BASE_URL, + base_url=BASE_URL, headers=headers, timeout=timeout, auth=auth, ) transport = sse_client( - f"{_BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append + f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory, on_session_created=captured_session_id.append ) with anyio.fail_after(5): async with Client(transport) as client: @@ -71,7 +69,7 @@ def httpx_client_factory( async def test_post_without_a_session_id_is_rejected() -> None: """A POST to the message endpoint with no session_id query parameter is answered 400.""" app, _ = build_sse_app(Server("legacy")) - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http: + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: response = await http.post("/messages/", json={"jsonrpc": "2.0", "method": "ping", "id": 1}) assert (response.status_code, response.text) == snapshot((400, "session_id is required")) @@ -80,7 +78,7 @@ async def test_post_without_a_session_id_is_rejected() -> None: async def test_post_with_a_malformed_session_id_is_rejected() -> None: """A POST whose session_id query parameter is not a UUID is answered 400.""" app, _ = build_sse_app(Server("legacy")) - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http: + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: response = await http.post( "/messages/", params={"session_id": "not-a-uuid"}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} ) @@ -91,7 +89,7 @@ async def test_post_with_a_malformed_session_id_is_rejected() -> None: async def test_post_for_an_unknown_session_is_rejected() -> None: """A POST naming a well-formed session_id that no SSE stream owns is 
answered 404.""" app, _ = build_sse_app(Server("legacy")) - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=_BASE_URL) as http: + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: response = await http.post( "/messages/", params={"session_id": uuid4().hex}, json={"jsonrpc": "2.0", "method": "ping", "id": 1} ) diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index f20fa44f05..72af075770 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -63,6 +63,8 @@ async def announce(ctx: Context) -> str: @requirement("transport:streamable-http:json-response") +@requirement("hosting:http:json-response-mode") +@requirement("client-transport:http:json-response-parsed") async def test_tool_call_over_streamable_http_with_json_responses() -> None: """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" async with connect_over_streamable_http(_smoke_server(), json_response=True) as client: @@ -104,6 +106,7 @@ async def test_stateless_streamable_http_rejects_server_initiated_requests() -> @requirement("transport:streamable-http:notifications") @requirement("transport:streamable-http:unrelated-messages") +@requirement("hosting:http:standalone-sse") async def test_unrelated_server_messages_arrive_on_the_standalone_stream() -> None: """A server message with no related request reaches the client through the standalone GET stream. 
From c13d6ae121391031cd9e1d8150d9f01e1ea31f77 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 09:57:16 +0000 Subject: [PATCH 18/34] test: cover protocol/lifecycle gap requirements and refine the divergence model --- tests/interaction/README.md | 5 + tests/interaction/_requirements.py | 121 +++++++++---- .../interaction/lowlevel/test_cancellation.py | 95 +++++++++- tests/interaction/lowlevel/test_initialize.py | 43 +++++ tests/interaction/lowlevel/test_progress.py | 163 +++++++++++++++++- tests/interaction/lowlevel/test_wire.py | 124 ++++++++++++- 6 files changed, 517 insertions(+), 34 deletions(-) diff --git a/tests/interaction/README.md b/tests/interaction/README.md index e1341806c6..ba08fa564e 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -113,6 +113,11 @@ entry without a test cannot be silently aspirational. spec-correct output, and deletes the `Divergence`. 3. An empty divergence list means the SDK is spec-conformant on every behaviour the suite covers. +A requirement may carry both `divergence` and `deferred`: the divergence records that the SDK falls +short of the spec, and the deferral records why no test pins it (typically because the divergent +behaviour cannot be driven through the public API). Divergence alone implies a test pins the +divergent behaviour; divergence plus deferred means the gap is known but unpinned. + This is also the triage key for any rewrite: a test that fails on the new code path either has a divergence note (the rewrite accidentally fixed a known gap — decide whether to keep the fix) or it does not (the rewrite broke something that was correct — fix the rewrite). diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 91ab5375d9..ccb2b15eae 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -15,7 +15,9 @@ own contract) says should happen. 
Tests always pin the SDK's current behaviour. Where current behaviour falls short of `behavior`, the gap is recorded as data: `divergence` on entries whose tests pin the divergent behaviour, or `deferred` on entries that are tracked but not yet covered -by a test in this suite. `issue` carries the tracking link for a recorded gap once one is filed. +by a test in this suite. An entry may carry both: `divergence` records the spec-compliance gap +(issue-able) and `deferred` records why no test exists; `divergence` alone implies a test pins +the divergent behaviour. `issue` carries the tracking link for a recorded gap once one is filed. `deferred` reasons take one of three shapes: where the behaviour is exercised elsewhere in this repo the reason names the covering test path; where the SDK does not implement the behaviour at @@ -85,6 +87,12 @@ def __post_init__(self) -> None: behavior=( "The client rejects sending notifications or registering handlers for capabilities it did not declare." ), + divergence=Divergence( + note=( + "The client does not check its own declared capabilities before sending notifications or " + "serving callbacks; nothing prevents a caller from violating the spec's SHOULD." + ), + ), deferred=( "Not implemented in the SDK: the client does not check its own declared capabilities before " "sending notifications or serving callbacks." @@ -95,6 +103,12 @@ def __post_init__(self) -> None: behavior=( "The client rejects calls to methods (e.g. resources/list) for capabilities the server did not advertise." ), + divergence=Divergence( + note=( + "The client sends any request regardless of the server's advertised capabilities and " + "surfaces whatever the server answers; the spec's SHOULD is not enforced." + ), + ), deferred=( "Not implemented in the SDK: the client sends any request regardless of the server's " "advertised capabilities and surfaces whatever the server answers." 
@@ -168,9 +182,19 @@ def __post_init__(self) -> None: "Before initialization completes, the client sends no requests other than pings, and the " "server sends no requests other than pings and logging." ), + divergence=Divergence( + note=( + "The server's send methods (create_message / elicit_form / list_roots) do not check " + "initialization state before sending; on the client side, Client always completes the " + "handshake before any caller code runs." + ), + ), deferred=( - "Not yet covered here: the sender-side restraint (especially the server half — no sampling, " - "elicitation, or roots requests before the initialized notification) has no test yet." + "Not implemented in the SDK: neither side enforces sender-side restraint. The server's send " + "methods (create_message / elicit_form / list_roots) do not check initialization state before " + "sending, and there is no natural hook to issue a server-to-client request between the " + "initialize response and the initialized notification through the public API; on the client " + "side, Client always completes the handshake before any caller code runs." ), ), "lifecycle:version:downgrade": Requirement( @@ -179,12 +203,6 @@ def __post_init__(self) -> None: "When the server returns an older supported protocol version, the client downgrades to it " "and the connection succeeds at that version." ), - transports=("streamable-http",), - deferred=( - "Not yet covered here: observing the negotiated version requires the MCP-Protocol-Version " - "request header, which only exists on the HTTP transport; planned with the transport " - "conformance work." - ), ), "lifecycle:version:match": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#version-negotiation", @@ -268,9 +286,13 @@ def __post_init__(self) -> None: "A response that arrives after the sender issued notifications/cancelled is ignored; the " "request stays failed and no error is raised." 
), - deferred=( - "Not yet covered here: needs the scripted-peer wire pattern to deliver a response after a " - "cancellation; today the receive loop logs an unknown-request-id error for such responses." + divergence=Divergence( + note=( + "A response whose id matches no in-flight request is delivered to the message handler " + "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " + "same code path; tested in its unknown-id form because that is deterministic without the " + "client-side cancellation API the SDK does not yet provide." + ), ), ), "protocol:cancel:server-survives": Requirement( @@ -283,6 +305,13 @@ def __post_init__(self) -> None: "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " "cancels it, and the client stops processing the cancelled request." ), + divergence=Divergence( + note=( + "Abandoning a server-side send_request emits no cancellation notification, and the client " + "could not act on one anyway: client callbacks run inline in the receive loop, so a " + "cancellation is not even read until the callback has finished." + ), + ), deferred=( "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " @@ -311,10 +340,6 @@ def __post_init__(self) -> None: "protocol:error:connection-closed": Requirement( source="sdk", behavior="Closing the transport fails all in-flight requests with a connection-closed error.", - deferred=( - "Not yet covered here: planned gap test (close the transport while a request is in flight and " - "pin the error the caller receives)." 
- ), ), "protocol:error:internal-error": Requirement( source=f"{SPEC_BASE_URL}/basic#responses", @@ -332,10 +357,6 @@ def __post_init__(self) -> None: "protocol:error:invalid-params": Requirement( source=f"{SPEC_BASE_URL}/basic#responses", behavior="A request with malformed params is answered with JSON-RPC error -32602 Invalid params.", - deferred=( - "Not yet covered here: the typed client API cannot send malformed params; needs a request " - "driven one level below it (planned gap test)." - ), ), "protocol:error:method-not-found": Requirement( source=f"{SPEC_BASE_URL}/basic#responses", @@ -371,27 +392,34 @@ def __post_init__(self) -> None: "protocol:progress:token-unique": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=("Concurrent in-flight requests that each supply a progress callback carry distinct progress tokens."), - deferred=( - "Not yet covered here: planned gap test (two concurrent requests with progress callbacks, " - "asserting their tokens differ and each callback only sees its own notifications)." - ), ), "protocol:progress:monotonic": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", behavior=( "The progress value increases with each notification for a given token, even when the total is unknown." ), - deferred=( - "Not implemented in the SDK: progress values are not validated anywhere; a handler can emit " - "non-increasing values and they are forwarded as-is." + divergence=Divergence( + note=( + "The spec MUST is not enforced: progress values are not validated on either side, so a " + "handler that emits non-increasing values has them forwarded to the callback unchanged." 
+ ), ), ), "protocol:progress:stops-after-completion": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#behavior-requirements", behavior="Progress notifications for a token stop once the associated request completes.", - deferred=( - "Not yet covered here: needs a test that a handler reporting progress after its request " - "completed produces no further notifications for the caller." + divergence=Divergence( + note=( + "send_progress_notification does not check whether the token's request has already " + "completed; the late notification is sent and reaches the client." + ), + ), + ), + "protocol:progress:late-dropped-by-client": Requirement( + source="sdk", + behavior=( + "A progress notification that arrives after its request has completed is not delivered to the " + "original progress callback." ), ), "protocol:progress:no-token": Requirement( @@ -415,6 +443,12 @@ def __post_init__(self) -> None: "protocol:timeout:max-total": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", behavior="A maximum total timeout is enforced even when progress notifications keep arriving.", + divergence=Divergence( + note=( + "There is no maximum-total-timeout option; only the per-request read timeout exists, so the " + "spec's SHOULD that an overall maximum is always enforced cannot be satisfied." + ), + ), deferred=( "Not implemented in the SDK: there is no maximum-total-timeout option; only the per-request " "read timeout exists." @@ -1097,6 +1131,12 @@ def __post_init__(self) -> None: "The server does not use includeContext values thisServer or allServers unless the client " "declared the sampling.context capability." ), + divergence=Divergence( + note=( + "include_context is forwarded regardless of the client's declared sampling.context " + "capability; the server-side validator only checks tools/tool_choice." 
+ ), + ), deferred=( "Not implemented in the SDK: include_context is forwarded regardless of the client's declared " "sampling.context capability (unlike tools, which are gated by the server-side validator)." @@ -1223,6 +1263,12 @@ def __post_init__(self) -> None: "The server refuses to send an elicitation request with a mode the connected client did not " "declare in its capabilities." ), + divergence=Divergence( + note=( + "The server does not check the client's declared elicitation modes before sending " + "elicitation/create; the spec's SHOULD is not enforced." + ), + ), deferred=( "Not implemented in the SDK: the server does not check the client's declared elicitation " "modes before sending elicitation/create." @@ -1295,6 +1341,12 @@ def __post_init__(self) -> None: "Form-mode requested schemas are flat objects with primitive-typed properties only; nested " "structures and arrays of objects are not used." ), + divergence=Divergence( + note=( + "Nothing restricts or validates the requested-schema shape on the sending side; a server " + "can send nested or non-primitive schemas and the SDK forwards them unchanged." + ), + ), deferred=( "Not implemented in the SDK: nothing restricts or validates the requested-schema shape on the " "sending side; hand-built lowlevel elicitation requests pass through unchecked." @@ -1306,6 +1358,9 @@ def __post_init__(self) -> None: "Accepted form-mode content is validated against the requested schema: the client validates " "the response before sending and the server validates the content it receives." ), + divergence=Divergence( + note="Accepted elicitation content passes through unvalidated on both sides.", + ), deferred=("Not implemented in the SDK: accepted elicitation content passes through unvalidated on both sides."), ), "elicitation:url:action:accept-no-content": Requirement( @@ -2147,6 +2202,12 @@ def __post_init__(self) -> None: "with a fresh InitializeRequest and no session ID attached." 
), transports=("streamable-http",), + divergence=Divergence( + note=( + "The client surfaces the 404 as an error to the caller instead of re-initializing a new " + "session; the spec's MUST is not satisfied." + ), + ), deferred=( "Not implemented in the SDK: the client surfaces the 404 as an error to the caller instead of " "re-initializing a new session." diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index eb07ef9404..f39b2014cf 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -11,9 +11,25 @@ from inline_snapshot import snapshot from mcp import MCPError, types +from mcp.client import ClientSession from mcp.server import Server, ServerRequestContext -from mcp.types import CallToolResult, ErrorData, TextContent +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + EmptyResult, + ErrorData, + Implementation, + InitializeResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + PingRequest, + ServerCapabilities, + TextContent, +) from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -137,3 +153,80 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara result = await client.call_tool("echo", {}) assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) + + +@requirement("protocol:cancel:late-response-ignored") +async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: + """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. 
+ + The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; + that is the same client-side code path as any response with an unknown id, and that form is + deterministic to test without depending on the cancellation API the SDK does not yet provide. + See the divergence note on the requirement. + + A real Server cannot be made to answer with a fabricated id, so the test plays the server's + side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The + other tests in this file run over the transport matrix; this one is in-memory only because the + scripted-peer mechanism is the in-memory stream pair, not because the behaviour is + transport-specific. + """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: + return SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. 
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + assert init.message.method == "initialize" + await server_write.send( + respond( + init.message.id, + InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="scripted", version="0.0.1"), + ), + ) + ) + + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + + ping = await server_read.receive() + assert isinstance(ping, SessionMessage) + assert isinstance(ping.message, JSONRPCRequest) + assert ping.message.method == "ping" + # First answer with a fabricated id that matches nothing in flight, then the real id. + await server_write.send(respond(9999, EmptyResult())) + await server_write.send(respond(ping.message.id, EmptyResult())) + + incoming: list[IncomingMessage] = [] + + async def message_handler(message: IncomingMessage) -> None: + incoming.append(message) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(scripted_server) + async with ClientSession(client_read, client_write, message_handler=message_handler) as session: + with anyio.fail_after(5): + await session.initialize() + pong = await session.send_request(PingRequest(), EmptyResult) + + assert pong == snapshot(EmptyResult()) + assert len(incoming) == 1 + assert isinstance(incoming[0], RuntimeError) + # The full message embeds the response object's repr; only the prefix is stable. 
+ assert str(incoming[0]).startswith("Received response with an unknown request ID:") diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 32da2f3338..027c80505d 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -331,3 +331,46 @@ async def scripted_server() -> None: await session.initialize() assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + + +@requirement("lifecycle:version:downgrade") +async def test_an_older_supported_protocol_version_from_the_server_is_accepted() -> None: + """An initialize response carrying an older supported protocol version completes the handshake at that version. + + A real Server answers with the version the client requested (or its own latest), so this test + plays the server's side of the wire by hand to return a fixed older version regardless of what + was requested. Reserve this pattern for behaviour no real server can be made to produce. + """ + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. 
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with anyio.create_task_group() as tg: + tg.start_soon(scripted_server) + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + initialize_result = await session.initialize() + + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 56eae40d7d..54faf85888 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -13,8 +13,11 @@ from mcp import types from mcp.server import Server, ServerRequestContext -from mcp.types import CallToolResult, ProgressNotificationParams, TextContent +from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT +from mcp.types import CallToolResult, ProgressNotification, ProgressNotificationParams, ProgressToken, TextContent from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -126,3 +129,161 @@ async def on_progress(ctx: ServerRequestContext, params: ProgressNotificationPar assert received == snapshot( [ProgressNotificationParams(progress_token="upload-1", progress=0.5, total=1.0, message="halfway")] ) + + +@requirement("protocol:progress:token-unique") +async def test_concurrent_requests_carry_distinct_progress_tokens(connect: Connect) -> None: + """Two concurrent requests carry distinct progress tokens, and each callback sees only its own progress. + + Without the barrier the first call could run to completion before the second starts, so only one + token would be live at a time and the demultiplexing would never be exercised. 
The handlers each + block until both have started and then hand control back and forth so the four progress + notifications are emitted in strict a, b, a, b order on the wire. The two handlers send different + progress values so a stream swap (token A delivered to callback B and vice versa) would fail: each + callback receiving exactly its own values proves notifications are routed by token, not by arrival + order or by chance. + """ + progress_values = {"a": (1.0, 2.0), "b": (10.0, 20.0)} + tokens: dict[str, ProgressToken] = {} + entered = {"a": anyio.Event(), "b": anyio.Event()} + # turns[n] is set to release the nth emission; each emission releases the next. + turns = [anyio.Event() for _ in range(4)] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert params.arguments is not None + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + label = params.arguments["label"] + tokens[label] = token + entered[label].set() + # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. 
+ first, second = (0, 2) if label == "a" else (1, 3) + await turns[first].wait() + await ctx.session.send_progress_notification(token, progress_values[label][0]) + turns[first + 1].set() + await turns[second].wait() + await ctx.session.send_progress_notification(token, progress_values[label][1]) + if second + 1 < len(turns): + turns[second + 1].set() + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received_a: list[float] = [] + received_b: list[float] = [] + + async def collect_a(progress: float, total: float | None, message: str | None) -> None: + received_a.append(progress) + + async def collect_b(progress: float, total: float | None, message: str | None) -> None: + received_b.append(progress) + + async with connect(server) as client: + + async def call(label: str, collect: ProgressFnT) -> None: + await client.call_tool("report", {"label": label}, progress_callback=collect) + + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: + task_group.start_soon(call, "a", collect_a) + task_group.start_soon(call, "b", collect_b) + await entered["a"].wait() + await entered["b"].wait() + turns[0].set() + + assert tokens["a"] != tokens["b"] + assert received_a == [1.0, 2.0] + assert received_b == [10.0, 20.0] + + +@requirement("protocol:progress:stops-after-completion") +@requirement("protocol:progress:late-dropped-by-client") +async def test_progress_sent_after_the_response_is_not_delivered_to_the_callback(connect: Connect) -> None: + """A progress notification sent after the response is emitted, and the client drops it from the callback. 
+ + This single body proves both halves: the server's `send_progress_notification` happily sends for + a token whose request has already completed (the spec MUST that progress stops is not enforced; + see the divergence on `stops-after-completion`), and the client, having removed the callback when + the call returned, does not deliver the late notification to it. The message handler observes the + late notification arriving so the test knows when to assert without polling. + """ + captured: list[tuple[ServerSession, ProgressToken]] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="report", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "report" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + captured.append((ctx.session, token)) + await ctx.session.send_progress_notification(token, 0.5) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[float] = [] + late_progress_arrived = anyio.Event() + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def message_handler(message: IncomingMessage) -> None: + if isinstance(message, ProgressNotification) and message.params.progress == 1.0: + late_progress_arrived.set() + + async with connect(server, message_handler=message_handler) as client: + with anyio.fail_after(5): + await client.call_tool("report", {}, progress_callback=collect) + assert received == [0.5] + + server_session, token = captured[0] + await server_session.send_progress_notification(token, 1.0) + await late_progress_arrived.wait() + + assert received == [0.5] + + 
+@requirement("protocol:progress:monotonic") +async def test_non_increasing_progress_values_are_forwarded_unchanged(connect: Connect) -> None: + """A handler that emits non-increasing progress values has them forwarded to the callback unchanged. + + The spec says progress MUST increase with each notification; the SDK does not enforce that on + either side. See the divergence note on the requirement. + """ + received: list[float] = [] + + async def collect(progress: float, total: float | None, message: str | None) -> None: + received.append(progress) + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="zigzag", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "zigzag" + assert ctx.meta is not None + token = ctx.meta.get("progress_token") + assert token is not None + await ctx.session.send_progress_notification(token, 0.5) + await ctx.session.send_progress_notification(token, 0.3) + await ctx.session.send_progress_notification(token, 0.9) + return CallToolResult(content=[TextContent(text="done")]) + + server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.call_tool("zigzag", {}, progress_callback=collect) + + assert received == snapshot([0.5, 0.3, 0.9]) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index f7e55ecaf3..62a2032ac1 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -4,18 +4,36 @@ The tests wrap the in-memory transport in a RecordingTransport, which tees every message crossing the transport seam into a list without touching the session, so the assertions hold for whatever the session implementation sends rather than for what its API 
returns. + +The final two tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and one sends a deliberately malformed +JSON-RPC request that the typed client API cannot produce. """ import anyio import pytest from inline_snapshot import snapshot -from mcp import types +from mcp import MCPError, types +from mcp.client import ClientSession from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams from mcp.shared.message import SessionMessage -from mcp.types import CallToolResult, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, TextContent +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CallToolRequest, + CallToolRequestParams, + CallToolResult, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + TextContent, +) from tests.interaction._helpers import RecordingTransport, _RecordingReadStream from tests.interaction._requirements import requirement @@ -119,3 +137,105 @@ async def test_exactly_one_initialized_notification_is_sent_after_the_handshake( ] assert sent_methods.count("notifications/initialized") == 1 assert sent_methods == snapshot(["initialize", "notifications/initialized", "tools/list"]) + + +@requirement("protocol:error:connection-closed") +async def test_closing_the_transport_fails_in_flight_requests_with_connection_closed() -> None: + """When the server-to-client stream closes, every in-flight client request fails with CONNECTION_CLOSED. 
+ + Driven over a bare ClientSession against a real Server so the test holds the transport stream + pair directly: once the request is in flight (the server handler signals it has started) the + test closes the server's write stream, which ends the client's receive loop and triggers the + teardown that fails the pending request. + """ + handler_started = anyio.Event() + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "block" + handler_started.set() + await anyio.Event().wait() # blocks until cancelled; nothing ever sets this event + raise NotImplementedError # unreachable: the wait above never completes normally + + server = Server("blocker", on_call_tool=call_tool) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + errors: list[ErrorData] = [] + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + async with ClientSession(client_read, client_write) as session: + with anyio.fail_after(5): + await session.initialize() + + async def call_and_capture_error() -> None: + with pytest.raises(MCPError) as exc_info: + await session.send_request( + CallToolRequest(params=CallToolRequestParams(name="block")), CallToolResult + ) + errors.append(exc_info.value.error) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(call_and_capture_error) + await handler_started.wait() + await server_write.aclose() + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=CONNECTION_CLOSED, message="Connection closed")]) + + +@requirement("protocol:error:invalid-params") +async def test_malformed_request_params_are_answered_with_invalid_params() -> None: + """A request whose params fail validation is 
answered with -32602 Invalid params. + + The typed client API cannot construct a request with the wrong parameter types, so the test + plays the client's side of the wire by hand against a real Server: it completes the + initialization handshake at the JSON-RPC layer and then sends a tools/call whose `name` is an + integer. Reserve this pattern for behaviour the typed API cannot produce. + """ + server = Server("strict") + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params={"name": 42})) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) From 01f6a636fc70ed5ac643c2f4928326e3476223fc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 
May 2026 10:12:12 +0000 Subject: [PATCH 19/34] test: cover sampling, client output-schema, and mcpserver gap requirements --- src/mcp/client/session.py | 4 +- src/mcp/server/mcpserver/prompts/base.py | 2 +- tests/interaction/_requirements.py | 96 ++++-- tests/interaction/lowlevel/test_sampling.py | 297 +++++++++++++++++- tests/interaction/lowlevel/test_tools.py | 111 +++++++ .../interaction/mcpserver/test_completion.py | 38 +++ tests/interaction/mcpserver/test_prompts.py | 45 +++ tests/interaction/mcpserver/test_resources.py | 21 ++ tests/interaction/mcpserver/test_tools.py | 134 +++++++- 9 files changed, 709 insertions(+), 39 deletions(-) create mode 100644 tests/interaction/mcpserver/test_completion.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index cf92696682..86113874be 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -337,9 +337,7 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - from jsonschema import SchemaError, ValidationError, validate if result.structured_content is None: - raise RuntimeError( - f"Tool {name} has an output schema but did not return structured content" - ) # pragma: no cover + raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") try: validate(result.structured_content, output_schema) except ValidationError as e: diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index e5b2af7d82..2f778eb514 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -185,5 +185,5 @@ async def render( raise ValueError(f"Could not convert prompt result to message: {msg}") return messages - except Exception as e: # pragma: no cover + except Exception as e: raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index ccb2b15eae..d64f79a93d 100644 --- 
a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -633,10 +633,6 @@ def __post_init__(self) -> None: "client:output-schema:skip-on-error": Requirement( source="sdk", behavior="The client skips structured-content validation when the tool result has isError true.", - deferred=( - "Not yet covered here: planned gap test (an isError result with mismatching structuredContent " - "is returned to the caller rather than rejected)." - ), ), "client:output-schema:validate": Requirement( source=f"{SPEC_BASE_URL}/server/tools#output-schema", @@ -645,10 +641,27 @@ def __post_init__(self) -> None: "is rejected by the client: the call raises instead of returning the invalid result." ), ), + "client:output-schema:missing-structured": Requirement( + source="sdk", + behavior="A tool that declares an output schema but returns no structuredContent fails client-side validation.", + ), + "client:output-schema:auto-list": Requirement( + source="sdk", + behavior=( + "Calling a tool whose output schema is not yet cached issues an implicit tools/list to " + "populate the cache; subsequent calls of the same tool do not." + ), + divergence=Divergence( + note=( + "Design concern rather than spec violation: the implicit request is invisible to the " + "caller, and against a server that registers only on_call_tool a successful call surfaces " + "as METHOD_NOT_FOUND from a tools/list the caller never asked for." 
+ ), + ), + ), "mcpserver:output-schema:missing-structured": Requirement( source=f"{SPEC_BASE_URL}/server/tools#output-schema", behavior="A tool with an output schema whose function returns no structured content produces a server error.", - deferred="Not yet covered here: planned gap test (output schema declared but no structured content returned).", ), "mcpserver:output-schema:server-validate": Requirement( source=f"{SPEC_BASE_URL}/server/tools#output-schema", @@ -656,17 +669,21 @@ def __post_init__(self) -> None: "MCPServer validates structured content against the tool's output schema before returning; a " "mismatch produces a server error." ), - deferred="Not yet covered here: planned gap test (server-side output schema validation failure).", ), "mcpserver:output-schema:skip-on-error": Requirement( source="sdk", behavior="Server-side output schema validation is skipped when the tool returns an isError result.", - deferred="Not yet covered here: planned gap test (isError results bypass server-side schema validation).", ), "mcpserver:tool:duplicate-name": Requirement( source=f"{SPEC_BASE_URL}/server/tools#tool-names", behavior="Registering a tool with a name already in use is rejected at registration time.", - deferred="Not yet covered here: planned gap test (duplicate tool registration).", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; " + "warn_on_duplicate_tools defaults to True and warning is the only effect -- there is " + "no rejection mode." 
+ ), + ), ), "mcpserver:tool:extra": Requirement( source="sdk", @@ -693,7 +710,10 @@ def __post_init__(self) -> None: "mcpserver:tool:naming-validation": Requirement( source="sdk", behavior="Tool names that violate the spec's naming rules are rejected at registration time.", - deferred="Not yet covered here: tool-name validation at registration has not been pinned yet.", + deferred=( + "Not implemented in the SDK: MCPServer accepts any string as a tool name; there is no " + "spec-naming-rules check at registration time." + ), ), "mcpserver:tool:output-schema:model": Requirement( source="sdk", @@ -737,10 +757,6 @@ def __post_init__(self) -> None: "A tool function that raises the URL-elicitation-required error surfaces to the caller as " "error -32042 with the elicitation parameters intact." ), - deferred=( - "Not yet covered here: the low-level equivalent is pinned by elicitation:url:required-error; " - "the MCPServer-decorated path is a planned gap test." - ), ), # ═══════════════════════════════════════════════════════════════════════════ # MCPServer: Context helpers (SDK) @@ -882,12 +898,20 @@ def __post_init__(self) -> None: "mcpserver:resource:duplicate-name": Requirement( source="sdk", behavior="Registering a resource or template with a duplicate identifier is rejected at registration time.", - deferred="Not yet covered here: planned gap test (duplicate resource registration).", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." + ), + ), + deferred=( + "Not yet covered here: mechanical sibling of mcpserver:tool:duplicate-name (same " + "warn-and-ignore behaviour); planned as a small follow-on to that test." 
+ ), ), "mcpserver:resource:read-throws-surfaced": Requirement( source="sdk", behavior="A resource function that raises is surfaced to the caller as a JSON-RPC error response.", - deferred="Not yet covered here: planned gap test (resource function raising during read).", ), "mcpserver:resource:static": Requirement( source="sdk", @@ -983,7 +1007,6 @@ def __post_init__(self) -> None: "mcpserver:prompt:args-validation": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#implementation-considerations", behavior="prompts/get arguments that fail the prompt's argument schema are rejected before the function runs.", - deferred="Not yet covered here: planned gap test (argument validation on decorated prompts).", ), "mcpserver:prompt:decorated": Requirement( source="sdk", @@ -995,12 +1018,20 @@ def __post_init__(self) -> None: "mcpserver:prompt:duplicate-name": Requirement( source="sdk", behavior="Registering a duplicate prompt name is rejected at registration time.", - deferred="Not yet covered here: planned gap test (duplicate prompt registration).", + divergence=Divergence( + note=( + "MCPServer logs a warning and keeps the first registration instead of rejecting; same " + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." + ), + ), + deferred=( + "Not yet covered here: mechanical sibling of mcpserver:tool:duplicate-name (same " + "warn-and-ignore behaviour); planned as a small follow-on to that test." + ), ), "mcpserver:prompt:optional-args": Requirement( source="sdk", behavior="A prompt with optional arguments can be fetched without supplying them.", - deferred="Not yet covered here: planned gap test (optional prompt arguments omitted).", ), "mcpserver:prompt:unknown-name": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#error-handling", @@ -1056,7 +1087,6 @@ def __post_init__(self) -> None: "MCPServer advertises the completions capability when at least one completion source is " "registered, and omits it otherwise." 
), - deferred="Not yet covered here: planned gap test (automatic completions capability derivation).", ), # ═══════════════════════════════════════════════════════════════════════════ # Logging @@ -1112,7 +1142,6 @@ def __post_init__(self) -> None: behavior=( "A client that handles sampling requests advertises the sampling capability in its initialize request." ), - deferred="Not yet covered here: planned gap test (positive sampling capability declaration).", ), "sampling:create:basic": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#creating-messages", @@ -1137,10 +1166,6 @@ def __post_init__(self) -> None: "capability; the server-side validator only checks tools/tool_choice." ), ), - deferred=( - "Not implemented in the SDK: include_context is forwarded regardless of the client's declared " - "sampling.context capability (unlike tools, which are gated by the server-side validator)." - ), ), "sampling:create:model-preferences": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#model-preferences", @@ -1168,7 +1193,6 @@ def __post_init__(self) -> None: "sampling:create-message:audio-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#audio-content", behavior="Sampling messages can carry audio content: base64 data with a mimeType.", - deferred="Not yet covered here: planned gap test (audio content in sampling messages, both directions).", ), "sampling:create-message:image-content": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#image-content", @@ -1191,7 +1215,6 @@ def __post_init__(self) -> None: "sampling:message:content-cardinality": Requirement( source=f"{SPEC_BASE_URL}/client/sampling", behavior="A sampling message's content may be a single block or an array of blocks.", - deferred="Not yet covered here: planned gap test (list-valued sampling message content).", ), "sampling:result:no-tools-single-content": Requirement( source="sdk", @@ -1199,7 +1222,13 @@ def __post_init__(self) -> None: "When the request carries no tools, a 
sampling callback result whose content is an array is " "rejected by the client." ), - deferred="Not yet covered here: planned gap test (array content rejected for tool-free sampling).", + divergence=Divergence( + note=( + "The client does not validate the callback result against the request shape; an array-content " + "result for a tool-free request is accepted client-side and surfaces as a raw " + "pydantic.ValidationError from the server's response parsing (send_request) instead." + ), + ), ), "sampling:result:with-tools-array-content": Requirement( source="sdk", @@ -1225,7 +1254,6 @@ def __post_init__(self) -> None: "Every assistant tool_use block in a sampling request must be matched by a tool_result with " "the same id in the following user message; an unmatched tool_use is rejected with Invalid params." ), - deferred="Not yet covered here: planned gap test (unmatched tool_use rejected by the validator).", ), "sampling:tools:server-gated-by-capability": Requirement( source=f"{SPEC_BASE_URL}/client/sampling#tools-in-sampling", @@ -1433,9 +1461,10 @@ def __post_init__(self) -> None: "of roots changes." ), deferred=( - "Not implemented in the SDK: the client keeps no managed roots store, so nothing fires " - "automatically when the configured roots change; emission is an explicit " - "send_roots_list_changed() call (pinned by roots:list-changed)." + "Not implemented in the SDK: the client does not own the root set (it calls back to the host " + "via list_roots_callback), so there is no mutation it could observe to auto-emit on; the SDK " + "provides send_roots_list_changed() for the host to call when its roots change, and that " + "emission path is covered by roots:list-changed." 
), ), "roots:list:basic": Requirement( @@ -1467,8 +1496,9 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/client/roots#root", behavior="Every root returned by the client identifies itself with a file:// URI.", deferred=( - "Not yet covered here: planned gap test (the SDK's Root type enforces the file:// scheme; pin " - "it end-to-end through roots/list)." + "Schema-level validation: the FileUrl type on Root.uri rejects any non-file:// scheme at " + "construction and at parse, so a non-conforming root cannot reach the wire from either side; " + "type-level coverage belongs in tests/test_types.py rather than this interaction suite." ), ), # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 85eb8c3455..53a246b2e8 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -5,6 +5,7 @@ round-trips what it received back to the test through its tool result. 
""" +import pydantic import pytest from inline_snapshot import snapshot @@ -12,16 +13,20 @@ from mcp.client import ClientRequestContext from mcp.server import Server, ServerRequestContext from mcp.types import ( + AudioContent, CallToolResult, CreateMessageRequestParams, CreateMessageResult, + CreateMessageResultWithTools, ErrorData, ImageContent, ModelHint, ModelPreferences, + SamplingCapability, SamplingMessage, TextContent, ToolResultContent, + ToolUseContent, ) from tests.interaction._connect import Connect from tests.interaction._requirements import requirement @@ -82,8 +87,14 @@ async def sampling_callback( @requirement("sampling:create:include-context") @requirement("sampling:create:model-preferences") @requirement("sampling:create:system-prompt") +@requirement("sampling:context:server-gated-by-capability") async def test_create_message_params_reach_callback(connect: Connect) -> None: - """Every sampling parameter the handler supplies arrives at the client callback unchanged.""" + """Every sampling parameter the handler supplies arrives at the client callback unchanged. + + The client has not declared the sampling.context capability (Client cannot declare it), yet + include_context="thisServer" reaches the callback regardless: the spec's SHOULD NOT is not + enforced. See the divergence note on `sampling:context:server-gated-by-capability`. + """ received: list[CreateMessageRequestParams] = [] async def list_tools( @@ -389,3 +400,287 @@ async def sampling_callback( ] ) ) + + +@requirement("sampling:capability:declare") +async def test_a_client_with_a_sampling_callback_declares_the_sampling_capability(connect: Connect) -> None: + """A client connecting with a sampling callback advertises the sampling capability to the server. + + Client cannot declare any sub-capabilities (it does not expose ClientSession's + sampling_capabilities parameter), so the snapshot pins an empty SamplingCapability. 
+ """ + captured: list[SamplingCapability | None] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="capabilities", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "capabilities" + assert ctx.session.client_params is not None + captured.append(ctx.session.client_params.capabilities.sampling) + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("introspector", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Registered only so the sampling capability is advertised; never called.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + await client.call_tool("capabilities", {}) + + assert captured == snapshot([SamplingCapability()]) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_request_with_audio_content_reaches_callback(connect: Connect) -> None: + """A sampling request message carrying audio content arrives at the client callback intact. + + This is the server-to-client direction: the server includes audio in the conversation it asks + the client to sample from. 
+ """ + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="transcribe", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "transcribe" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + audio = params.messages[0].content + assert isinstance(audio, AudioContent) + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"transcribed {audio.mime_type} ({audio.data})"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("transcribe", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="transcribed audio/wav (c25k)")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[SamplingMessage(role="user", content=AudioContent(data="c25k", mime_type="audio/wav"))], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:create-message:audio-content") +async def test_create_message_result_with_audio_content_returns_to_handler(connect: Connect) -> None: + """A sampling result whose content is audio is returned to the requesting handler intact. + + This is the client-to-server direction: the model's response is audio rather than text. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="speak", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "speak" + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Say hello, aloud."))], + max_tokens=100, + ) + audio = result.content + assert isinstance(audio, AudioContent) + return CallToolResult(content=[TextContent(text=f"{result.model}: {audio.mime_type} {audio.data}")]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + return CreateMessageResult( + role="assistant", + content=AudioContent(data="aGVsbG8=", mime_type="audio/wav"), + model="mock-audio-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("speak", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mock-audio-1: audio/wav aGVsbG8=")])) + + +@requirement("sampling:message:content-cardinality") +async def test_create_message_with_list_valued_message_content_reaches_callback(connect: Connect) -> None: + """A sampling message whose content is a list of blocks arrives at the client callback as a list.""" + received: list[CreateMessageRequestParams] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="caption", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "caption" + result = await 
ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + assert isinstance(result.content, TextContent) + return CallToolResult(content=[TextContent(text=result.content.text)]) + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + received.append(params) + content = params.messages[0].content + assert isinstance(content, list) + return CreateMessageResult( + role="assistant", content=TextContent(text=f"{len(content)} blocks"), model="mock-llm-1" + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("caption", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="2 blocks")])) + assert received == snapshot( + [ + CreateMessageRequestParams( + _meta={}, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(text="Caption this image."), + ImageContent(data="aW1n", mime_type="image/png"), + ], + ) + ], + max_tokens=100, + ) + ] + ) + + +@requirement("sampling:tool-use:result-balance") +async def test_create_message_with_mismatched_tool_use_and_result_ids_is_rejected(connect: Connect) -> None: + """A sampling request whose tool_result ids do not match the preceding tool_use ids never leaves the server. + + The message-structure validation runs inside create_message before the request is sent, so the + client callback is never invoked and the handler observes the ValueError directly. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="continue_tools", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "continue_tools" + try: + await ctx.session.create_message( + messages=[ + SamplingMessage( + role="assistant", + content=[ToolUseContent(id="call-1", name="weather", input={})], + ), + SamplingMessage( + role="user", + content=[ToolResultContent(tool_use_id="call-WRONG", content=[TextContent(text="42")])], + ), + ], + max_tokens=100, + ) + except ValueError as exc: + return CallToolResult(content=[TextContent(text=f"{type(exc).__name__}: {exc}")]) + raise NotImplementedError # the validator rejects the malformed messages before sending + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + """Declares the sampling capability; never invoked because the request is rejected first.""" + raise NotImplementedError + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("continue_tools", {}) + + assert result == snapshot( + CallToolResult( + content=[ + TextContent( + text="ValueError: ids of tool_result blocks and tool_use blocks from previous message do not match" + ) + ] + ) + ) + + +@requirement("sampling:result:no-tools-single-content") +async def test_array_content_result_for_a_tool_free_request_surfaces_as_a_validation_error(connect: Connect) -> None: + """An array-content sampling result for a tool-free request is accepted by the client and fails server-side. + + Only the exception type is asserted: the message is pydantic's, which changes across releases. 
+ See the divergence note on the requirement: the intended behaviour is that the client rejects + the result; instead the client accepts it and the server's response parsing raises. + """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="ask_model", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "ask_model" + try: + await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text="Two thoughts, please."))], + max_tokens=100, + ) + except pydantic.ValidationError as exc: + return CallToolResult(content=[TextContent(text=type(exc).__name__)]) + raise NotImplementedError # the array-content result fails server-side parsing every time + + server = Server("sampler", on_list_tools=list_tools, on_call_tool=call_tool) + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResultWithTools: + return CreateMessageResultWithTools( + role="assistant", + content=[TextContent(text="First thought."), TextContent(text="Second thought.")], + model="mock-llm-1", + ) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("ask_model", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ValidationError")])) diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 49b04db2fa..7ab93d7655 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -353,3 +353,114 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara # The message embeds the jsonschema validation error, so only the SDK-authored prefix is pinned. 
assert str(exc_info.value).startswith("Invalid structured content returned by tool forecast") + + +@requirement("client:output-schema:skip-on-error") +async def test_is_error_result_bypasses_client_output_schema_validation(connect: Connect) -> None: + """A tool result with isError true is returned as-is even when its structured content violates the schema. + + The schema is cached up front so the client could validate, proving the bypass is specifically the + isError flag and not an empty cache. + """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={ + "type": "object", + "properties": {"temperature": {"type": "number"}}, + "required": ["temperature"], + }, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult( + content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True + ) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + result = await client.call_tool("forecast", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="boom")], structured_content={"temperature": "warm"}, is_error=True) + ) + + +@requirement("client:output-schema:missing-structured") +async def test_declared_output_schema_with_no_structured_content_is_rejected_by_the_client(connect: Connect) -> None: + """A tool that declared an output schema but returned no structuredContent fails the client-side check. + + The error is the SDK's own message, so the full text is snapshotted. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="warm")]) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + await client.list_tools() + with pytest.raises(RuntimeError) as exc_info: + await client.call_tool("forecast", {}) + + assert str(exc_info.value) == snapshot("Tool forecast has an output schema but did not return structured content") + + +@requirement("client:output-schema:auto-list") +async def test_call_tool_populates_the_output_schema_cache_via_an_implicit_tools_list(connect: Connect) -> None: + """Calling a tool whose schema is not cached issues exactly one implicit tools/list to populate it. + + The first call_tool of an uncached tool triggers a tools/list the caller never asked for; the + second call hits the cache and does not. This is the SDK's chosen cache strategy and the cause of + the surprising behaviour where a server with only on_call_tool sees a successful call answered + with METHOD_NOT_FOUND from a request the caller never made; see the divergence on the requirement. 
+ """ + list_calls: list[str] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + list_calls.append("called") + return ListToolsResult( + tools=[ + Tool( + name="forecast", + input_schema={"type": "object"}, + output_schema={"type": "object", "properties": {"temperature": {"type": "number"}}}, + ) + ] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "forecast" + return CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21}) + + server = Server("weather", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + first = await client.call_tool("forecast", {}) + assert list_calls == ["called"] + second = await client.call_tool("forecast", {}) + + assert list_calls == ["called"] + assert first == snapshot(CallToolResult(content=[TextContent(text="21 C")], structured_content={"temperature": 21})) + assert second == first diff --git a/tests/interaction/mcpserver/test_completion.py b/tests/interaction/mcpserver/test_completion.py new file mode 100644 index 0000000000..7761066e94 --- /dev/null +++ b/tests/interaction/mcpserver/test_completion.py @@ -0,0 +1,38 @@ +"""Completion behaviour against MCPServer, driven through the public Client API.""" + +import pytest + +from mcp.server.mcpserver import MCPServer +from mcp.types import ( + Completion, + CompletionArgument, + CompletionContext, + CompletionsCapability, + PromptReference, + ResourceTemplateReference, +) +from tests.interaction._connect import Connect +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("mcpserver:completion:capability-auto") +async def test_completion_capability_is_advertised_only_when_a_handler_is_registered(connect: Connect) -> None: + """An MCPServer with a registered completion handler advertises the completions 
capability; one without does not.""" + with_handler = MCPServer("completer") + + @with_handler.completion() + async def complete( + ref: PromptReference | ResourceTemplateReference, + argument: CompletionArgument, + context: CompletionContext | None, + ) -> Completion | None: + """Registered only so the completions capability is advertised; never called.""" + raise NotImplementedError + + async with connect(with_handler) as client: + assert client.initialize_result.capabilities.completions == CompletionsCapability() + + async with connect(MCPServer("plain")) as client: + assert client.initialize_result.capabilities.completions is None diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index e4cb03d8f5..8c7a653af2 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -110,3 +110,48 @@ def greet(name: str) -> str: await client.get_prompt("greet") assert exc_info.value.error == snapshot(ErrorData(code=0, message="Missing required arguments: {'name'}")) + + +@requirement("mcpserver:prompt:args-validation") +async def test_get_prompt_with_a_wrong_type_argument_is_rejected_before_the_function_runs(connect: Connect) -> None: + """An argument that fails the function signature's type validation is rejected before the function runs. + + The decorated function is wrapped in pydantic's validate_call, so a value that cannot be + coerced to the parameter's annotation fails before the body executes. The function body + raises NotImplementedError to prove it never ran. The error is wrapped in the SDK's stable + rendering-error prefix; the body of the message is raw pydantic output and is not asserted. 
+ """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def repeat(phrase: str, count: int) -> str: + """A registered prompt; type validation rejects the call before the function runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.get_prompt("repeat", {"phrase": "hi", "count": "many"}) + + assert exc_info.value.error.code == 0 + assert exc_info.value.error.message.startswith("Error rendering prompt repeat: 1 validation error") + + +@requirement("mcpserver:prompt:optional-args") +async def test_get_prompt_with_an_optional_argument_omitted_uses_the_default(connect: Connect) -> None: + """A prompt rendered without one of its optional arguments uses that parameter's default value.""" + mcp = MCPServer("prompter") + + @mcp.prompt() + def review(code: str, style: str = "pep8") -> str: + """Review a snippet of code against a style guide.""" + return f"Review {code} per {style}." + + async with connect(mcp) as client: + result = await client.get_prompt("review", {"code": "x = 1"}) + + assert result == snapshot( + GetPromptResult( + description="Review a snippet of code against a style guide.", + messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py index 8960eb2be2..d208857fe5 100644 --- a/tests/interaction/mcpserver/test_resources.py +++ b/tests/interaction/mcpserver/test_resources.py @@ -128,3 +128,24 @@ def app_config() -> str: await client.read_resource("config://missing") assert exc_info.value.error == snapshot(ErrorData(code=0, message="Unknown resource: config://missing")) + + +@requirement("mcpserver:resource:read-throws-surfaced") +async def test_resource_function_that_raises_is_surfaced_as_a_jsonrpc_error(connect: Connect) -> None: + """An exception raised by a resource function reaches the caller as a JSON-RPC error. 
+ + MCPServer wraps the failure in a generic error that names only the URI, so the original + exception text is not leaked to the client. The wrapped exception becomes error code 0 the + same way every other unhandled server-side exception does. + """ + mcp = MCPServer("library") + + @mcp.resource("res://boom") + def boom() -> str: + raise RuntimeError("nope") + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("res://boom") + + assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index ac6fd59650..e66538ce09 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -1,13 +1,20 @@ """Tool interactions against MCPServer, driven through the public Client API.""" +from typing import Annotated + import pytest from inline_snapshot import snapshot from pydantic import BaseModel +from mcp import MCPError from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.exceptions import ToolError +from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.types import ( + URL_ELICITATION_REQUIRED, CallToolResult, + ElicitRequestURLParams, + ErrorData, LoggingMessageNotification, LoggingMessageNotificationParams, TextContent, @@ -39,8 +46,14 @@ def add(a: int, b: int) -> str: @requirement("mcpserver:tool:handler-throws") +@requirement("mcpserver:output-schema:skip-on-error") async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: - """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error.""" + """An exception raised by a tool function is returned as an is_error result, not a JSON-RPC error. 
+ + The function's `-> str` annotation gives the tool a derived output schema, but the error + result is built before any schema validation runs, so no validation failure is layered on + top of the original exception. + """ mcp = MCPServer("errors") @mcp.tool() @@ -193,6 +206,125 @@ def add(a: int, b: int) -> str: assert result.content[0].text.startswith("Error executing tool add: 1 validation error") +@requirement("mcpserver:output-schema:server-validate") +@requirement("mcpserver:output-schema:missing-structured") +async def test_tool_with_output_schema_returning_mismatched_structured_content_is_an_error_result( + connect: Connect, +) -> None: + """Structured content that fails the tool's own output schema is rejected on the server side. + + A tool annotated `Annotated[CallToolResult, Model]` returns a hand-built CallToolResult while + declaring `Model` as its output schema; MCPServer validates the supplied structured_content + against that schema before returning. The two cases -- a content shape that does not match, + and no structured content at all -- both fail that validation and are reported as is_error + results carrying the (raw pydantic) validation error wrapped in the SDK's stable prefix. 
+ """ + mcp = MCPServer("forecaster") + + class Weather(BaseModel): + temperature: float + conditions: str + + @mcp.tool() + def mismatched() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")], structured_content={"nope": True}) + + @mcp.tool() + def missing() -> Annotated[CallToolResult, Weather]: + return CallToolResult(content=[TextContent(text="oops")]) + + async with connect(mcp) as client: + mismatched_result = await client.call_tool("mismatched", {}) + missing_result = await client.call_tool("missing", {}) + + # The body of each message is raw pydantic ValidationError output (model name, field paths, + # an errors.pydantic.dev URL) and changes across pydantic versions, so only the SDK's stable + # prefix is asserted. + assert mismatched_result.is_error is True + assert isinstance(mismatched_result.content[0], TextContent) + assert mismatched_result.content[0].text.startswith("Error executing tool mismatched: 2 validation errors") + + assert missing_result.is_error is True + assert isinstance(missing_result.content[0], TextContent) + assert missing_result.content[0].text.startswith("Error executing tool missing: 1 validation error") + + +@requirement("mcpserver:tool:duplicate-name") +async def test_registering_a_duplicate_tool_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second tool with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via add_tool with an explicit name so the test does not + redefine the same function name in this scope. 
+ """ + mcp = MCPServer("duplicates") + + @mcp.tool() + def echo() -> str: + return "first" + + def echo_second() -> str: + """Passed to add_tool with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + mcp.add_tool(echo_second, name="echo") + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("echo", {}) + + assert [tool.name for tool in listed.tools] == ["echo"] + assert result == snapshot( + CallToolResult(content=[TextContent(text="first")], structured_content={"result": "first"}) + ) + + +@requirement("mcpserver:tool:url-elicitation-error") +async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: + """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. + + MCPServer wraps every other tool exception as an is_error result; this error is special-cased + so it propagates as the JSON-RPC error the client needs in order to present the listed URL + interactions and retry the call. 
+ """ + mcp = MCPServer("authorizer") + + @mcp.tool() + def read_files() -> str: + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorization required for your files.", + url="https://example.com/oauth/authorize", + elicitation_id="auth-001", + ) + ] + ) + + async with connect(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + assert exc_info.value.error == snapshot( + ErrorData( + code=-32042, + message="URL elicitation required", + data={ + "elicitations": [ + { + "mode": "url", + "message": "Authorization required for your files.", + "url": "https://example.com/oauth/authorize", + "elicitationId": "auth-001", + } + ] + }, + ) + ) + + @requirement("mcpserver:register:post-connect") async def test_adding_and_removing_tools_does_not_notify_connected_clients(connect: Connect) -> None: """Mutating the tool set on a running server changes tools/list but sends no notification. 
From 1e0d4f65efee131b4954341545daea6afc26ec89 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 10:26:28 +0000 Subject: [PATCH 20/34] test: cover server-feature, pagination, elicitation, and mcpserver gap requirements --- tests/interaction/_requirements.py | 66 +--- tests/interaction/lowlevel/test_completion.py | 23 ++ .../interaction/lowlevel/test_elicitation.py | 292 +++++++++++++++++- tests/interaction/lowlevel/test_pagination.py | 67 +++- tests/interaction/lowlevel/test_prompts.py | 70 +++++ tests/interaction/lowlevel/test_resources.py | 30 +- tests/interaction/lowlevel/test_tools.py | 46 +++ tests/interaction/lowlevel/test_wire.py | 68 +++- tests/interaction/mcpserver/test_prompts.py | 34 ++ tests/interaction/mcpserver/test_resources.py | 32 ++ 10 files changed, 661 insertions(+), 67 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index d64f79a93d..1c2cd21f36 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -333,8 +333,8 @@ def __post_init__(self) -> None: "direction and are believed to still be in flight." ), deferred=( - "Not yet covered here: there is no public client-side cancel API to drive (see " - "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin yet." + "Not implemented in the SDK: there is no public client-side cancel API to drive (see " + "protocol:cancel:abort-signal), so the sender-side targeting rule has nothing to pin." ), ), "protocol:error:connection-closed": Requirement( @@ -573,34 +573,18 @@ def __post_init__(self) -> None: "A tool registered with a JSON Schema 2020-12 inputSchema (nested objects, $defs references) " "is discoverable and callable." ), - deferred=( - "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " - "interaction-level passthrough test is planned with the gap batch." 
- ), ), "tools:input-schema:preserve-additional-properties": Requirement( source=f"{SPEC_BASE_URL}/server/tools#tool", behavior="tools/list preserves inputSchema additionalProperties as registered.", - deferred=( - "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " - "interaction-level passthrough test is planned with the gap batch." - ), ), "tools:input-schema:preserve-defs": Requirement( source=f"{SPEC_BASE_URL}/server/tools#tool", behavior="tools/list preserves inputSchema $defs as registered.", - deferred=( - "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " - "interaction-level passthrough test is planned with the gap batch." - ), ), "tools:input-schema:preserve-schema-dialect": Requirement( source=f"{SPEC_BASE_URL}/server/tools#tool", behavior="tools/list preserves the inputSchema $schema dialect URI as registered.", - deferred=( - "Not yet covered here; existing coverage in tests/test_types.py at the type level; an " - "interaction-level passthrough test is planned with the gap batch." - ), ), "tools:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", @@ -791,10 +775,9 @@ def __post_init__(self) -> None: "resources:annotations": Requirement( source=f"{SPEC_BASE_URL}/server/resources#annotations", behavior=( - "Resource annotations (audience, priority, lastModified) supplied by the server round-trip to " - "the client in list and read results." + "Resource annotations (audience, priority) supplied by the server round-trip to the client " + "in the list result." ), - deferred="Not yet covered here: planned gap test (annotations passthrough on list and read results).", ), "resources:capability:declared": Requirement( source=f"{SPEC_BASE_URL}/server/resources#capabilities", @@ -846,10 +829,6 @@ def __post_init__(self) -> None: behavior=( "resources/subscribe to a server that did not advertise the subscribe capability is rejected with an error." 
), - deferred=( - "Not yet covered here: planned gap test (subscribe rejected with METHOD_NOT_FOUND when no " - "subscribe handler is registered)." - ), ), "resources:subscribe:updated": Requirement( source=f"{SPEC_BASE_URL}/server/resources#subscriptions", @@ -901,13 +880,10 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "MCPServer logs a warning and keeps the first registration instead of rejecting; same " - "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." + "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name). " + "Templates differ: a duplicate uri_template silently replaces the first with no warning." ), ), - deferred=( - "Not yet covered here: mechanical sibling of mcpserver:tool:duplicate-name (same " - "warn-and-ignore behaviour); planned as a small follow-on to that test." - ), ), "mcpserver:resource:read-throws-surfaced": Requirement( source="sdk", @@ -947,17 +923,14 @@ def __post_init__(self) -> None: "prompts:get:content:audio": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#audio-content", behavior="Prompt messages may contain audio content with base64 data and a mimeType.", - deferred="Not yet covered here: planned gap test (audio content in prompt messages).", ), "prompts:get:content:embedded-resource": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#embedded-resources", behavior="Prompt messages may contain embedded resource content.", - deferred="Not yet covered here: planned gap test (embedded resources in prompt messages).", ), "prompts:get:content:image": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#image-content", behavior="Prompt messages may contain image content.", - deferred="Not yet covered here: planned gap test (image content in prompt messages).", ), "prompts:get:missing-required-args": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#error-handling", @@ -976,7 +949,6 @@ def __post_init__(self) -> None: "prompts:get:no-args": 
Requirement( source=f"{SPEC_BASE_URL}/server/prompts#getting-a-prompt", behavior="prompts/get with no arguments returns the prompt's messages.", - deferred="Not yet covered here: planned gap test (argument-free prompt fetched without arguments).", ), "prompts:get:unknown-name": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#error-handling", @@ -1024,10 +996,6 @@ def __post_init__(self) -> None: "warn-and-ignore behaviour as duplicate tool names (mcpserver:tool:duplicate-name)." ), ), - deferred=( - "Not yet covered here: mechanical sibling of mcpserver:tool:duplicate-name (same " - "warn-and-ignore behaviour); planned as a small follow-on to that test." - ), ), "mcpserver:prompt:optional-args": Requirement( source="sdk", @@ -1067,7 +1035,6 @@ def __post_init__(self) -> None: "completion/complete with a ref naming an unknown prompt or non-matching resource URI returns " "JSON-RPC error -32602 (Invalid params)." ), - deferred="Not yet covered here: planned gap test (completion against an unknown ref).", ), "completion:prompt-arg": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/completion#reference-types", @@ -1132,7 +1099,6 @@ def __post_init__(self) -> None: "logging:set-level:invalid-level": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/logging#error-handling", behavior="logging/setLevel with an invalid level value returns JSON-RPC error -32602 (Invalid params).", - deferred="Not yet covered here: planned gap test (invalid level value on setLevel).", ), # ═══════════════════════════════════════════════════════════════════════════ # Sampling (server → client) @@ -1294,13 +1260,9 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The server does not check the client's declared elicitation modes before sending " - "elicitation/create; the spec's SHOULD is not enforced." + "elicitation/create; the spec's MUST NOT is not enforced." 
), ), - deferred=( - "Not implemented in the SDK: the server does not check the client's declared elicitation " - "modes before sending elicitation/create." - ), ), "elicitation:form:action:accept": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", @@ -1338,7 +1300,6 @@ def __post_init__(self) -> None: "elicitation:form:mode-omitted-default": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#elicitation-requests", behavior="An elicitation request with no mode field is treated as form mode by the client.", - deferred="Not yet covered here: planned gap test (mode-less elicitation request handled as form mode).", ), "elicitation:form:not-supported": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#error-handling", @@ -1356,12 +1317,10 @@ def __post_init__(self) -> None: "Requested-schema enum fields (including titled and multi-select variants) reach the client " "callback as sent." ), - deferred="Not yet covered here: planned gap test (enum variants in the requested schema).", ), "elicitation:form:schema:primitives": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", behavior="Requested-schema fields may be string (with format), number or integer, or boolean.", - deferred="Not yet covered here: planned gap test (full primitive-type coverage in the requested schema).", ), "elicitation:form:schema:restricted-subset": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#requested-schema", @@ -1375,10 +1334,6 @@ def __post_init__(self) -> None: "can send nested or non-primitive schemas and the SDK forwards them unchanged." ), ), - deferred=( - "Not implemented in the SDK: nothing restricts or validates the requested-schema shape on the " - "sending side; hand-built lowlevel elicitation requests pass through unchecked." 
- ), ), "elicitation:form:response-validation": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#form-mode-security", @@ -1389,7 +1344,6 @@ def __post_init__(self) -> None: divergence=Divergence( note="Accepted elicitation content passes through unvalidated on both sides.", ), - deferred=("Not implemented in the SDK: accepted elicitation content passes through unvalidated on both sides."), ), "elicitation:url:action:accept-no-content": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", @@ -1423,7 +1377,6 @@ def __post_init__(self) -> None: "The client ignores an elicitation/complete notification referencing an unknown or " "already-completed elicitationId without error." ), - deferred="Not yet covered here: planned gap test (unknown elicitationId in a complete notification).", ), "elicitation:url:decline": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#response-actions", @@ -1566,7 +1519,6 @@ def __post_init__(self) -> None: "pagination:invalid-cursor": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#error-handling", behavior="A list request with an invalid cursor returns JSON-RPC error -32602 (Invalid params).", - deferred="Not yet covered here: planned gap test (invalid pagination cursor rejected).", ), "pagination:client:cursor-handling": Requirement( source=f"{SPEC_BASE_URL}/server/utilities/pagination#implementation-guidelines", @@ -1574,10 +1526,6 @@ def __post_init__(self) -> None: "The client treats cursors as opaque tokens — it does not parse, modify, or persist them — " "and does not assume a fixed page size." ), - deferred=( - "Not yet covered here: planned gap test (the client passes a server-issued cursor back " - "byte-for-byte and follows pages of varying sizes)." 
- ), ), # ═══════════════════════════════════════════════════════════════════════════ # Tasks (experimental) diff --git a/tests/interaction/lowlevel/test_completion.py b/tests/interaction/lowlevel/test_completion.py index e036d48c3c..6a35404df3 100644 --- a/tests/interaction/lowlevel/test_completion.py +++ b/tests/interaction/lowlevel/test_completion.py @@ -6,6 +6,7 @@ from mcp import MCPError, types from mcp.server import Server, ServerRequestContext from mcp.types import ( + INVALID_PARAMS, METHOD_NOT_FOUND, CompleteResult, Completion, @@ -93,6 +94,28 @@ async def completion(ctx: ServerRequestContext, params: types.CompleteRequestPar assert result == snapshot(CompleteResult(completion=Completion(values=["modelcontextprotocol/python-sdk"]))) +@requirement("completion:error:invalid-ref") +async def test_completion_against_an_unknown_ref_is_rejected_with_invalid_params(connect: Connect) -> None: + """completion/complete with a ref naming an unknown prompt is answered with -32602 Invalid params. + + The lowlevel server does not validate refs itself (it has no prompt/template registry to check + against); rejecting an unknown ref is the handler's job, and this test pins the spec-recommended + way to do it. 
+ """ + + async def completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> CompleteResult: + assert isinstance(params.ref, PromptReference) + raise MCPError(code=INVALID_PARAMS, message=f"Unknown prompt: {params.ref.name!r}") + + server = Server("completer", on_completion=completion) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.complete(PromptReference(name="ghost"), argument={"name": "x", "value": ""}) + + assert exc_info.value.error.code == INVALID_PARAMS + + @requirement("completion:complete:not-supported") @requirement("protocol:error:method-not-found") async def test_complete_without_handler_is_method_not_found(connect: Connect) -> None: diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index d27613dd36..83a77592a9 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -1,19 +1,34 @@ -"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API.""" +"""Form- and URL-mode elicitation against the low-level Server, driven through the public Client API. +The final test plays the server's side of the wire by hand to issue an elicitation request with no +mode field, because the typed server API (`elicit_form`/`elicit_url`) always serializes one. 
+""" + +import anyio import pytest from inline_snapshot import snapshot from mcp import MCPError, UrlElicitationRequiredError, types -from mcp.client import ClientRequestContext +from mcp.client import ClientRequestContext, ClientSession from mcp.server import Server, ServerRequestContext +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage from mcp.types import ( CallToolResult, ElicitCompleteNotification, ElicitCompleteNotificationParams, + ElicitRequestedSchema, ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, ErrorData, + Implementation, + InitializeResult, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ServerCapabilities, TextContent, ) from tests.interaction._connect import Connect @@ -143,12 +158,15 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest @requirement("elicitation:form:not-supported") +@requirement("elicitation:capability:server-respects-mode") async def test_elicit_form_without_callback_is_error(connect: Connect) -> None: """Eliciting from a client that configured no elicitation callback fails with an error. The client's default callback answers with an Invalid request error, which the server-side elicit call raises as an MCPError; the tool reports the code and message it caught. The spec - requires -32602 for an undeclared mode (see the divergence note on the requirement). + requires -32602 for an undeclared mode (see the divergence note on the requirement). The + request reaching the client also shows the server does not check the client's declared + elicitation capability before sending (see the divergence on `server-respects-mode`). 
""" async def list_tools( @@ -373,3 +391,271 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara }, ) ) + + +@requirement("elicitation:form:schema:primitives") +@requirement("elicitation:form:schema:enum-variants") +async def test_elicit_form_schema_with_every_primitive_and_enum_type_reaches_the_callback_as_sent( + connect: Connect, +) -> None: + """A requested schema covering every spec-listed property kind is delivered to the callback unchanged. + + One schema with one property per kind: a formatted string, an integer with bounds, a number, + a boolean, a plain enum, a oneOf-const titled enum, and a multi-select array-of-enum. The + callback observing the same schema as the handler sent proves both the primitive coverage and + the enum-variant coverage in one snapshot. + """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "email": {"type": "string", "format": "email", "title": "Email", "description": "Contact address."}, + "age": {"type": "integer", "minimum": 0, "maximum": 150}, + "score": {"type": "number"}, + "subscribe": {"type": "boolean", "default": False}, + "tier": {"type": "string", "enum": ["free", "pro", "team"]}, + "region": { + "oneOf": [ + {"const": "eu", "title": "Europe"}, + {"const": "na", "title": "North America"}, + ], + }, + "channels": {"type": "array", "items": {"type": "string", "enum": ["email", "sms", "push"]}}, + }, + "required": ["email"], + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="onboard", description="Onboard the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + answer = await ctx.session.elicit_form("Tell us about yourself.", schema) + return 
CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("onboarder", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={"email": "ada@example.com"}) + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("onboard", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:schema:restricted-subset") +async def test_elicit_form_with_a_nested_schema_is_forwarded_unchanged(connect: Connect) -> None: + """A requested schema with nested-object and array-of-object properties passes through unchanged. + + The spec restricts form-mode requested schemas to flat objects with primitive-typed properties; + this test pins that the SDK does not enforce that restriction on either side (see the + divergence on the requirement). 
+ """ + schema: ElicitRequestedSchema = { + "type": "object", + "properties": { + "address": { + "type": "object", + "properties": {"street": {"type": "string"}, "city": {"type": "string"}}, + }, + "contacts": { + "type": "array", + "items": {"type": "object", "properties": {"name": {"type": "string"}}}, + }, + }, + } + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="profile", description="Collect a profile.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "profile" + answer = await ctx.session.elicit_form("Profile details.", schema) + return CallToolResult(content=[TextContent(text=answer.action)]) + + server = Server("profiler", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[types.ElicitRequestParams] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="decline") + + async with connect(server, elicitation_callback=answer_form) as client: + await client.call_tool("profile", {}) + + assert len(received) == 1 + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].requested_schema == schema + + +@requirement("elicitation:form:response-validation") +async def test_accepted_elicitation_content_that_violates_the_schema_reaches_the_handler_unchanged( + connect: Connect, +) -> None: + """Accepted form content that contradicts the requested schema is delivered to the handler unchanged. + + The schema requires a string `name`; the callback answers with a wrong-type value and an extra + field. Nothing on either side validates the response against the schema (see the divergence on + the requirement), so the handler observes exactly what the callback sent. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="signup", description="Register the user.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "signup" + answer = await ctx.session.elicit_form( + "Choose a name.", + {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}, + ) + return CallToolResult(content=[TextContent(text=answer.action)], structured_content=answer.content) + + server = Server("registrar", on_list_tools=list_tools, on_call_tool=call_tool) + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + return ElicitResult(action="accept", content={"name": 42, "extra": "field"}) + + async with connect(server, elicitation_callback=answer_form) as client: + result = await client.call_tool("signup", {}) + + assert result == snapshot( + CallToolResult(content=[TextContent(text="accept")], structured_content={"name": 42, "extra": "field"}) + ) + + +@requirement("elicitation:url:complete-unknown-ignored") +async def test_elicitation_complete_for_an_unknown_id_is_received_without_error(connect: Connect) -> None: + """An elicitation/complete for an id the client never elicited is delivered and does not fail anything. + + No URL elicitation precedes the notification; the client neither tracks elicitation ids nor + rejects unknown ones, so the call completes normally and the message handler observes the + notification as-is. 
+ """ + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult( + tools=[types.Tool(name="noop", description="Send a stray complete.", input_schema={"type": "object"})] + ) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "noop" + await ctx.session.send_elicit_complete("never-elicited") + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) + + received: list[IncomingMessage] = [] + + async def collect(message: IncomingMessage) -> None: + received.append(message) + + async with connect(server, message_handler=collect) as client: + result = await client.call_tool("noop", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")])) + assert received == snapshot( + [ElicitCompleteNotification(params=ElicitCompleteNotificationParams(elicitation_id="never-elicited"))] + ) + + +@requirement("elicitation:form:mode-omitted-default") +async def test_a_mode_less_elicitation_request_is_treated_as_form_mode() -> None: + """An elicitation/create request with no mode field reaches the client callback as form-mode. + + The typed server API always serializes a mode (`elicit_form` writes 'form', `elicit_url` writes + 'url'), so this test plays the server's side of the wire by hand to send a request body without + one. Reserve this pattern for behaviour the typed server API cannot produce. 
+ """ + received: list[types.ElicitRequestParams] = [] + answered = anyio.Event() + server_received: list[JSONRPCMessage] = [] + + async def answer_form(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + received.append(params) + return ElicitResult(action="accept", content={}) + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def scripted_server() -> None: + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. 
+ await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, + ) + ) + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(scripted_server) + async with ClientSession(client_read, client_write, elicitation_callback=answer_form) as session: + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", "properties": {}}, + ) + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 1b6ac3e66a..0c2a0b1588 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -8,9 +8,10 @@ import pytest from inline_snapshot import snapshot -from mcp import types +from mcp import MCPError, types from mcp.server import Server, ServerRequestContext from mcp.types import ( + INVALID_PARAMS, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -90,6 +91,70 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa assert requests_made == len(pages) +@requirement("pagination:client:cursor-handling") +async def test_the_client_follows_opaque_cursors_through_pages_of_varying_sizes(connect: Connect) -> None: + """The client passes a server-issued cursor back byte-for-byte and follows pages of varying sizes. 
+ + The cursors are deliberately base64-looking strings (with padding and URL-unsafe characters) to + show the client treats them as opaque tokens; the page sizes [3, 1, 2] show the loop relies only + on next_cursor, not on a fixed page size. + """ + cursor_to_page_2 = "YWxwaGE+YnJhdm8/Y2hhcmxpZQ==" + cursor_to_page_3 = "ZGVsdGE=" + pages: dict[str | None, tuple[list[str], str | None]] = { + None: (["alpha", "beta", "gamma"], cursor_to_page_2), + cursor_to_page_2: (["delta"], cursor_to_page_3), + cursor_to_page_3: (["epsilon", "zeta"], None), + } + received_cursors: list[str | None] = [] + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + received_cursors.append(params.cursor) + names, next_cursor = pages[params.cursor] + return ListToolsResult( + tools=[Tool(name=name, input_schema={"type": "object"}) for name in names], next_cursor=next_cursor + ) + + server = Server("paginated", on_list_tools=list_tools) + + page_sizes: list[int] = [] + cursor: str | None = None + async with connect(server) as client: + while True: + result = await client.list_tools(cursor=cursor) + page_sizes.append(len(result.tools)) + if result.next_cursor is None: + break + cursor = result.next_cursor + + # Identity, not a snapshot: what arrived at the handler is exactly what the handler issued. + assert received_cursors == [None, cursor_to_page_2, cursor_to_page_3] + assert page_sizes == [3, 1, 2] + + +@requirement("pagination:invalid-cursor") +async def test_an_unrecognized_pagination_cursor_is_rejected_with_invalid_params(connect: Connect) -> None: + """A list request with a cursor the server did not issue is answered with -32602 Invalid params. + + The lowlevel server does not validate cursors itself (they are opaque to it); rejecting an + unrecognized cursor is the handler's job, and this test pins the spec-recommended way to do it. 
+ """ + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + assert params is not None + assert params.cursor == "never-issued" + raise MCPError(code=INVALID_PARAMS, message=f"Unknown cursor: {params.cursor!r}") + + server = Server("paginated", on_list_tools=list_tools) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools(cursor="never-issued") + + assert exc_info.value.error.code == INVALID_PARAMS + + @requirement("resources:list:pagination") async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: """resources/list round-trips the cursor like every other list operation.""" diff --git a/tests/interaction/lowlevel/test_prompts.py b/tests/interaction/lowlevel/test_prompts.py index b09f765755..868b82692c 100644 --- a/tests/interaction/lowlevel/test_prompts.py +++ b/tests/interaction/lowlevel/test_prompts.py @@ -7,14 +7,18 @@ from mcp.server import Server, ServerRequestContext from mcp.types import ( INVALID_PARAMS, + AudioContent, + EmbeddedResource, ErrorData, GetPromptResult, Icon, + ImageContent, ListPromptsResult, Prompt, PromptArgument, PromptMessage, TextContent, + TextResourceContents, ) from tests.interaction._connect import Connect from tests.interaction._requirements import requirement @@ -120,6 +124,72 @@ async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestPa ) +@requirement("prompts:get:no-args") +async def test_get_prompt_without_arguments_returns_the_messages(connect: Connect) -> None: + """A prompt fetched with no arguments delivers None as the handler's arguments and returns its messages.""" + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "static" + assert params.arguments is None + return GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + + server 
= Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("static") + + assert result == snapshot( + GetPromptResult(messages=[PromptMessage(role="user", content=TextContent(text="Say hello."))]) + ) + + +@requirement("prompts:get:content:image") +@requirement("prompts:get:content:audio") +@requirement("prompts:get:content:embedded-resource") +async def test_get_prompt_with_non_text_content_round_trips(connect: Connect) -> None: + """Prompt messages can carry image, audio, and embedded-resource content; all reach the client. + + A single full-result snapshot proves all three content types round-trip: each block in the result + is one of the three behaviours under test. Tiny fixed base64 payloads ("aW1n" is b"img", "YXVk" + is b"aud") so the snapshot pins the exact bytes. + """ + + async def get_prompt(ctx: ServerRequestContext, params: types.GetPromptRequestParams) -> GetPromptResult: + assert params.name == "media" + return GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + + server = Server("prompter", on_get_prompt=get_prompt) + + async with connect(server) as client: + result = await client.get_prompt("media", {}) + + assert result == snapshot( + GetPromptResult( + messages=[ + PromptMessage(role="user", content=ImageContent(data="aW1n", mime_type="image/png")), + PromptMessage(role="assistant", content=AudioContent(data="YXVk", mime_type="audio/wav")), + PromptMessage( + role="user", + content=EmbeddedResource( + resource=TextResourceContents(uri="resource://notes/1", mime_type="text/plain", text="attached") + ), + ), + ] + ) + ) + + 
@requirement("prompts:get:unknown-name") async def test_get_prompt_unknown_name_is_protocol_error(connect: Connect) -> None: """A handler that rejects an unrecognised prompt name with MCPError produces a JSON-RPC error. diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 1d29a62e07..b6bed63a9c 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -8,6 +8,7 @@ from mcp import MCPError, types from mcp.server import Server, ServerRequestContext from mcp.types import ( + METHOD_NOT_FOUND, Annotations, BlobResourceContents, CallToolResult, @@ -32,8 +33,12 @@ @requirement("resources:list:basic") +@requirement("resources:annotations") async def test_list_resources_returns_registered_resources(connect: Connect) -> None: - """Listed resources reach the client with their URIs, names, and optional descriptive fields intact.""" + """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. + + The fully-populated entry includes annotations, so the snapshot also proves they round-trip. + """ async def list_resources( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -205,6 +210,29 @@ async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeR assert result == snapshot(EmptyResult()) +@requirement("resources:subscribe:capability-required") +async def test_subscribe_without_a_subscribe_handler_is_method_not_found(connect: Connect) -> None: + """Subscribing to a server that registered no subscribe handler is rejected with METHOD_NOT_FOUND. + + The rejection comes from no handler being registered, not from any capability check; see the + divergence on lifecycle:capability:server-not-advertised. 
+ """ + + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered only so the resources capability is advertised; never called.""" + raise NotImplementedError + + server = Server("library", on_list_resources=list_resources) + + async with connect(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.subscribe_resource("file:///watched.txt") + + assert exc_info.value.error == snapshot(ErrorData(code=METHOD_NOT_FOUND, message="Method not found")) + + @requirement("resources:unsubscribe") async def test_unsubscribe_resource_delivers_uri_to_handler(connect: Connect) -> None: """Unsubscribing from a resource delivers the URI to the server's unsubscribe handler.""" diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 7ab93d7655..95bb6bd790 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -158,6 +158,52 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa ) +@requirement("tools:input-schema:json-schema-2020-12") +@requirement("tools:input-schema:preserve-additional-properties") +@requirement("tools:input-schema:preserve-defs") +@requirement("tools:input-schema:preserve-schema-dialect") +async def test_tools_list_preserves_arbitrary_input_schema_keywords(connect: Connect) -> None: + """A rich JSON Schema 2020-12 inputSchema reaches the client unchanged and the tool is callable. + + The single identity assertion below proves all four pass-through behaviours at once: the same + dict literal that was registered is the dict that arrives, so $schema, $defs, the nested object + property, and additionalProperties are each preserved by virtue of the whole schema being + preserved. The follow-up call proves the rich-schema tool is callable end to end. 
+ """ + schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "$defs": {"positive": {"type": "integer", "exclusiveMinimum": 0}}, + "properties": { + "count": {"$ref": "#/$defs/positive"}, + "options": { + "type": "object", + "properties": {"verbose": {"type": "boolean"}}, + "additionalProperties": False, + }, + }, + "required": ["count"], + "additionalProperties": False, + } + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="typed", input_schema=schema)]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "typed" + assert params.arguments == {"count": 3, "options": {"verbose": True}} + return CallToolResult(content=[TextContent(text="ok")]) + + server = Server("typed", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server) as client: + listed = await client.list_tools() + called = await client.call_tool("typed", {"count": 3, "options": {"verbose": True}}) + + assert listed.tools[0].input_schema == schema + assert called == snapshot(CallToolResult(content=[TextContent(text="ok")])) + + @requirement("tools:list:metadata") async def test_list_tools_optional_fields_round_trip(connect: Connect) -> None: """Every optional Tool field the server supplies reaches the client unchanged.""" diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index 62a2032ac1..a3453b7b2a 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -5,9 +5,9 @@ the transport seam into a list without touching the session, so the assertions hold for whatever the session implementation sends rather than for what its API returns. 
-The final two tests drive the wire by hand instead: one closes the server-to-client stream while a -request is in flight to pin the connection-closed teardown, and one sends a deliberately malformed -JSON-RPC request that the typed client API cannot produce. +The later tests drive the wire by hand instead: one closes the server-to-client stream while a +request is in flight to pin the connection-closed teardown, and the last two send deliberately +malformed JSON-RPC requests that the typed client API cannot produce. """ import anyio @@ -27,6 +27,7 @@ CallToolRequest, CallToolRequestParams, CallToolResult, + EmptyResult, ErrorData, JSONRPCError, JSONRPCNotification, @@ -239,3 +240,64 @@ async def test_malformed_request_params_are_answered_with_invalid_params() -> No server_task_group.cancel_scope.cancel() assert errors == snapshot([ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="")]) + + +@requirement("logging:set-level:invalid-level") +async def test_set_level_with_an_unrecognized_value_is_answered_with_invalid_params() -> None: + """logging/setLevel with a value outside the spec's level enum is answered with -32602 Invalid params. + + The typed client API cannot construct a setLevel request with an unrecognized level (pyright and + the client-side model both reject it), so the test plays the client's side of the wire by hand + against a real Server. Reserve this pattern for behaviour the typed API cannot produce. 
+ """ + + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; never called -- params validation fails first.""" + raise NotImplementedError + + server = Server("logger", on_set_logging_level=set_logging_level) + errors: list[ErrorData] = [] + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async with anyio.create_task_group() as server_task_group: + server_task_group.start_soon(server.run, server_read, server_write, server.create_initialization_options()) + + with anyio.fail_after(5): + await client_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=0, + method="initialize", + params={ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": {"name": "raw", "version": "0.0.1"}, + }, + ) + ) + ) + init_response = await client_read.receive() + assert isinstance(init_response, SessionMessage) + assert isinstance(init_response.message, JSONRPCResponse) + await client_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) + ) + + await client_write.send( + SessionMessage( + JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "loud"}) + ) + ) + error_response = await client_read.receive() + assert isinstance(error_response, SessionMessage) + assert isinstance(error_response.message, JSONRPCError) + errors.append(error_response.message.error) + + server_task_group.cancel_scope.cancel() + + assert len(errors) == 1 + assert errors[0].code == INVALID_PARAMS diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index 8c7a653af2..ddea4d8278 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -155,3 +155,37 @@ def 
review(code: str, style: str = "pep8") -> str: messages=[PromptMessage(role="user", content=TextContent(text="Review x = 1 per pep8."))], ) ) + + +@requirement("mcpserver:prompt:duplicate-name") +async def test_registering_a_duplicate_prompt_name_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second prompt with an already-used name keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The + second function is registered via the decorator with an explicit name so the test does not + redefine the same function name in this scope. + """ + mcp = MCPServer("prompter") + + @mcp.prompt() + def greet() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.prompt(name="greet") + def greet_second() -> str: + """Registered with a duplicate name; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_prompts() + result = await client.get_prompt("greet") + + assert [prompt.name for prompt in listed.prompts] == ["greet"] + assert result == snapshot( + GetPromptResult( + description="The first registration; this is the one that wins.", + messages=[PromptMessage(role="user", content=TextContent(text="first"))], + ) + ) diff --git a/tests/interaction/mcpserver/test_resources.py b/tests/interaction/mcpserver/test_resources.py index d208857fe5..57b0fdc86d 100644 --- a/tests/interaction/mcpserver/test_resources.py +++ b/tests/interaction/mcpserver/test_resources.py @@ -149,3 +149,35 @@ def boom() -> str: await client.read_resource("res://boom") assert exc_info.value.error == snapshot(ErrorData(code=0, message="Error reading resource res://boom")) + + +@requirement("mcpserver:resource:duplicate-name") +async def 
test_registering_a_duplicate_resource_uri_warns_and_keeps_the_first(connect: Connect) -> None: + """Registering a second static resource at an already-used URI keeps the first registration. + + The intended behaviour is rejection at registration time; MCPServer instead logs a warning + and discards the second registration (see the divergence note on the requirement). The two + registrations use different function names so the test does not redefine a name in this scope; + the resource decorator keys on the URI, not the function name. + """ + mcp = MCPServer("library") + + @mcp.resource("config://app") + def config_first() -> str: + """The first registration; this is the one that wins.""" + return "first" + + @mcp.resource("config://app") + def config_second() -> str: + """Registered at a duplicate URI; the registration is discarded so this never runs.""" + raise NotImplementedError + + async with connect(mcp) as client: + listed = await client.list_resources() + result = await client.read_resource("config://app") + + assert [resource.uri for resource in listed.resources] == ["config://app"] + assert listed.resources[0].name == "config_first" + assert result == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="config://app", mime_type="text/plain", text="first")]) + ) From 4a7d563690c1c5d55cd490632ae558a7a3b5b942 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 12:06:15 +0000 Subject: [PATCH 21/34] test: cover composed flow scenarios and stdio framing requirements --- src/mcp/client/streamable_http.py | 4 +- tests/interaction/_requirements.py | 62 +++--- tests/interaction/lowlevel/test_flows.py | 193 ++++++++++++++++++ tests/interaction/mcpserver/test_context.py | 28 +++ tests/interaction/mcpserver/test_tools.py | 32 ++- .../transports/test_client_transport_http.py | 37 +++- tests/interaction/transports/test_flows.py | 127 ++++++++++++ .../transports/test_hosting_resume.py | 3 + 
tests/interaction/transports/test_stdio.py | 76 ++++++- 9 files changed, 524 insertions(+), 38 deletions(-) create mode 100644 tests/interaction/lowlevel/test_flows.py create mode 100644 tests/interaction/transports/test_flows.py diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index a6b4e6cfa0..aa3e50e07e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -267,8 +267,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: # pragma: no branch - if isinstance(message, JSONRPCRequest): # pragma: no branch + if response.status_code == 404: + if isinstance(message, JSONRPCRequest): error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 1c2cd21f36..b5897ee46d 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -672,10 +672,9 @@ def __post_init__(self) -> None: "mcpserver:tool:extra": Requirement( source="sdk", behavior=( - "Tool functions can access request metadata (request id, client params, session, lifespan " - "state) through the Context parameter." + "Tool functions can access request metadata (request id, client params, session) through the " + "Context parameter." ), - deferred="Not yet covered here: planned gap test (Context request-metadata access from inside a tool).", ), "mcpserver:tool:handler-throws": Requirement( source="sdk", @@ -719,10 +718,6 @@ def __post_init__(self) -> None: "Tool input schemas generated from complex parameter types (unions, nested models, " "constrained types) validate and coerce arguments before the function runs." 
), - deferred=( - "Not yet covered here: planned gap test (complex parameter types validated and coerced before " - "the function runs)." - ), ), "mcpserver:tool:unknown-name": Requirement( source=f"{SPEC_BASE_URL}/server/tools#error-handling", @@ -2186,10 +2181,6 @@ def __post_init__(self) -> None: "session; the spec's MUST is not satisfied." ), ), - deferred=( - "Not implemented in the SDK: the client surfaces the 404 as an error to the caller instead of " - "re-initializing a new session." - ), ), "client-transport:http:accept-header-get": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#listening-for-messages-from-the-server", @@ -2520,13 +2511,18 @@ def __post_init__(self) -> None: "is not a valid MCP message is written to its stdin." ), transports=("stdio",), - deferred="Not yet covered here: planned with the stdio end-to-end test.", + divergence=Divergence( + note=( + "stdio_server's own writes satisfy this, but it does not redirect or guard sys.stdout: " + "handler code that calls print() writes directly to the protocol stream and corrupts the " + "framing. The spec MUST is satisfied only as long as application code behaves." + ), + ), ), "transport:stdio:no-embedded-newlines": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#stdio", behavior="Serialized JSON-RPC messages on stdio contain no embedded newlines; one message per line.", transports=("stdio",), - deferred="Not yet covered here: planned with the stdio end-to-end test.", ), "transport:stdio:shutdown-escalation": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#stdio", @@ -2535,13 +2531,17 @@ def __post_init__(self) -> None: "it (and kills it if still alive) after a grace period." 
), transports=("stdio",), - deferred="Not yet covered here; existing coverage in tests/client/test_stdio.py.", + deferred=( + "Not yet covered here: a server that ignores stdin close takes the full " + "PROCESS_TERMINATION_TIMEOUT (2.0 s) grace period plus up to a further 2.0 s for " + "SIGTERM/SIGKILL escalation; a robust test of that path is real-time-bound and the constant " + "is module-level (no public override). Covered by tests/client/test_stdio.py." + ), ), "transport:stdio:stderr-passthrough": Requirement( source="sdk", behavior="Server stderr is available to the client and is not consumed by the transport.", transports=("stdio",), - deferred="Not yet covered here; existing coverage in tests/client/test_stdio.py.", ), # ═══════════════════════════════════════════════════════════════════════════ # Composite end-to-end flows @@ -2553,7 +2553,6 @@ def __post_init__(self) -> None: "concurrently; clients on either transport can call the same tools." ), transports=("streamable-http", "sse"), - deferred="Not yet covered here: planned with the transport conformance work.", ), "flow:compat:streamable-then-sse-fallback": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", @@ -2562,7 +2561,18 @@ def __post_init__(self) -> None: "SSE client transport against the same server connects successfully." ), transports=("streamable-http", "sse"), - deferred="Not yet covered here: planned with the transport conformance work.", + divergence=Divergence( + note=( + "The SDK provides no automatic streamable-HTTP-to-SSE client fallback; the spec's " + "client-side SHOULD is left to the application to compose from streamable_http_client " + "and sse_client. Both halves are independently proven by the matrix." 
+ ), + ), + deferred=( + "A demonstration test would only re-prove what the matrix already covers (an SSE-only " + "server is reachable via sse_client; an unmounted route returns 404), with the application " + "doing the fallback in between rather than the SDK." + ), ), "flow:elicitation:multi-step-form": Requirement( source="sdk", @@ -2570,7 +2580,6 @@ def __post_init__(self) -> None: "A single tool handler issues sequential elicitations; an accept on one step feeds the next, " "and a decline or cancel at any step short-circuits to a final result." ), - deferred="Not yet covered here: planned gap test (multi-step elicitation flow).", ), "flow:elicitation:url-at-session-init": Requirement( source="sdk", @@ -2579,7 +2588,13 @@ def __post_init__(self) -> None: "session initialization, before any client request." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", + deferred=( + "No public per-session post-initialization hook exists on either server flavour " + "(Server.lifespan runs at server startup, not per session; ServerSession handles the " + "initialized notification internally with no callback). Driving 'before any client " + "request' deterministically would also require knowing the standalone GET stream is " + "established, which has no synchronization signal." + ), ), "flow:elicitation:url-required-then-retry": Requirement( source=f"{SPEC_BASE_URL}/client/elicitation#url-elicitation-required-error", @@ -2587,7 +2602,6 @@ def __post_init__(self) -> None: "A tool call rejected with the URL-elicitation-required error can be retried successfully " "after the client completes the URL flow and the server announces completion." 
), - deferred="Not yet covered here: planned gap test (full URL-elicitation-required retry flow).", ), "flow:multi-client:stateful-isolation": Requirement( source="sdk", @@ -2596,7 +2610,6 @@ def __post_init__(self) -> None: "only the notifications produced by their own requests." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "flow:oauth:authorization-code-roundtrip": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", @@ -2610,17 +2623,15 @@ def __post_init__(self) -> None: "flow:resume:tool-call-resumption-token": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", behavior=( - "A tool call interrupted mid-stream can be resumed with the captured resumption token, " - "delivering only the remaining notifications and the final result." + "A tool call interrupted mid-stream is transparently resumed by the client transport using " + "the last-seen event id, delivering only the remaining notifications and the final result." ), transports=("streamable-http",), - deferred="Not yet covered here; existing coverage in tests/shared/test_streamable_http.py.", ), "flow:session:terminate-then-reconnect": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", behavior=("After terminating a session, a fresh connection obtains a new session id and operations succeed."), transports=("streamable-http",), - deferred="Not yet covered here: planned with the transport conformance work.", ), "flow:tool-result:resource-link-follow": Requirement( source=f"{SPEC_BASE_URL}/server/tools#resource-links", @@ -2628,7 +2639,6 @@ def __post_init__(self) -> None: "A resource_link returned by a tool call can be followed with resources/read on the linked " "URI to retrieve the referenced contents." 
), - deferred="Not yet covered here: planned gap test (follow a resource link returned by a tool).", ), } diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py new file mode 100644 index 0000000000..8ff9dd4f1d --- /dev/null +++ b/tests/interaction/lowlevel/test_flows.py @@ -0,0 +1,193 @@ +"""Composed multi-feature flows against the low-level Server, driven through the public Client API. + +Each test reads as the scenario it proves: the steps run top to bottom in the order a real client +would perform them, composing two or more feature areas (a tool call followed by a resource read; +a chain of elicitations inside one tool call; the full URL-elicitation-required retry loop). The +individual features are pinned by their own tests; these prove they compose. +""" + +from collections.abc import Awaitable, Callable + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import MCPError, UrlElicitationRequiredError, types +from mcp.client import ClientRequestContext +from mcp.server import Server, ServerRequestContext +from mcp.server.session import ServerSession +from mcp.types import ( + URL_ELICITATION_REQUIRED, + CallToolResult, + ElicitCompleteNotification, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + ListToolsResult, + ReadResourceResult, + ResourceLink, + TextContent, + TextResourceContents, + Tool, +) +from tests.interaction._connect import Connect +from tests.interaction._helpers import IncomingMessage +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + +ListToolsHandler = Callable[ + [ServerRequestContext, types.PaginatedRequestParams | None], Awaitable[types.ListToolsResult] +] + + +def _list_tools(*names: str) -> ListToolsHandler: + """A list_tools handler advertising the named tools, so call_tool's implicit list succeeds.""" + + async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> 
ListToolsResult: + return ListToolsResult(tools=[Tool(name=name, input_schema={"type": "object"}) for name in names]) + + return list_tools + + +@requirement("flow:tool-result:resource-link-follow") +async def test_a_resource_link_returned_by_a_tool_can_be_followed_with_read(connect: Connect) -> None: + """A tool returns a resource_link; reading that link's URI returns the referenced contents. + + Steps: (1) call the tool, (2) extract the link from its content, (3) read_resource on the + link's URI, (4) the read result carries the linked contents. + """ + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "generate" + return CallToolResult(content=[ResourceLink(uri="file:///report.txt", name="report")]) + + async def read_resource(ctx: ServerRequestContext, params: types.ReadResourceRequestParams) -> ReadResourceResult: + assert str(params.uri) == "file:///report.txt" + return ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + + server = Server( + "linker", on_list_tools=_list_tools("generate"), on_call_tool=call_tool, on_read_resource=read_resource + ) + + async with connect(server) as client: + called = await client.call_tool("generate", {}) + link = called.content[0] + assert isinstance(link, ResourceLink) + read = await client.read_resource(link.uri) + + assert called == snapshot(CallToolResult(content=[ResourceLink(name="report", uri="file:///report.txt")])) + assert read == snapshot( + ReadResourceResult(contents=[TextResourceContents(uri="file:///report.txt", text="generated")]) + ) + + +@requirement("flow:elicitation:multi-step-form") +async def test_a_tool_handler_chains_form_elicitations_feeding_each_answer_forward(connect: Connect) -> None: + """Sequential form elicitations inside one tool call: each accepted answer feeds the next step. 
+ + Steps: (1) call the tool, (2) the handler issues a step-one form elicitation that the client + accepts with content, (3) the handler issues a step-two elicitation whose message references + the step-one answer, (4) the client accepts step two, (5) the tool result summarises both + answers. The callback is invoked exactly twice with the expected messages and schemas. The + short-circuit on decline is the application's choice (proven separately by the per-action + elicitation tests); what this flow pins is that the chain itself works end to end. + """ + received: list[ElicitRequestFormParams] = [] + answers: list[dict[str, str | int | float | bool | list[str] | None]] = [{"name": "ada"}, {"age": 37}] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "onboard" + first = await ctx.session.elicit_form( + "Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}} + ) + assert first.action == "accept" and first.content is not None + second = await ctx.session.elicit_form( + f"Step 2: confirm age for {first.content['name']}.", + {"type": "object", "properties": {"age": {"type": "integer"}}}, + ) + assert second.action == "accept" and second.content is not None + return CallToolResult(content=[TextContent(text=f"{first.content['name']} is {second.content['age']}")]) + + server = Server("onboarder", on_list_tools=_list_tools("onboard"), on_call_tool=call_tool) + + async def answer(context: ClientRequestContext, params: types.ElicitRequestParams) -> ElicitResult: + assert isinstance(params, ElicitRequestFormParams) + received.append(params) + return ElicitResult(action="accept", content=answers[len(received) - 1]) + + async with connect(server, elicitation_callback=answer) as client: + result = await client.call_tool("onboard", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="ada is 37")])) + assert [(p.message, p.requested_schema) 
for p in received] == snapshot( + [ + ("Step 1: choose a username.", {"type": "object", "properties": {"name": {"type": "string"}}}), + ("Step 2: confirm age for ada.", {"type": "object", "properties": {"age": {"type": "integer"}}}), + ] + ) + + +@requirement("flow:elicitation:url-required-then-retry") +async def test_a_tool_rejected_with_url_elicitation_required_succeeds_on_retry_after_completion( + connect: Connect, +) -> None: + """The full URL-elicitation-required retry loop: -32042, completion announced, retry succeeds. + + Steps: (1) the first call is rejected with -32042 carrying the required URL elicitation in + its error data, (2) the client extracts the elicitation id from the error, (3) the server + announces completion via the elicitation/complete notification (driven via the captured + session, the same way a real out-of-band callback would reach a held session reference), + (4) the client observes the matching completion notification and retries, (5) the retry + succeeds. The handler distinguishes the two calls by a closure flag the test flips between + them; the test waits on the completion notification with an event so the retry only happens + after the announcement has arrived. + """ + elicitation_id = "auth-001" + authorised: list[bool] = [False] + captured: list[ServerSession] = [] + completed = anyio.Event() + notifications: list[ElicitCompleteNotification] = [] + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "read_files" + captured.append(ctx.session) + if not authorised[0]: + # The log line gives the message handler a non-completion notification, so the test's + # filtering branch is exercised in both directions and the wait remains specific. 
+ await ctx.session.send_log_message(level="warning", data="authorisation required", logger="gate") + raise UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + message="Authorize file access.", + url="https://example.com/oauth/authorize", + elicitation_id=elicitation_id, + ) + ] + ) + return CallToolResult(content=[TextContent(text="contents")]) + + server = Server("gatekeeper", on_list_tools=_list_tools("read_files"), on_call_tool=call_tool) + + async def collect(message: IncomingMessage) -> None: + if isinstance(message, ElicitCompleteNotification): + notifications.append(message) + completed.set() + + async with connect(server, message_handler=collect) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("read_files", {}) + assert exc_info.value.error.code == URL_ELICITATION_REQUIRED + required = UrlElicitationRequiredError.from_error(exc_info.value.error) + assert [e.elicitation_id for e in required.elicitations] == [elicitation_id] + + # The out-of-band interaction completes; the server announces it on the same session. 
+ await captured[0].send_elicit_complete(elicitation_id) + with anyio.fail_after(5): + await completed.wait() + assert notifications[0].params.elicitation_id == elicitation_id + + authorised[0] = True + result = await client.call_tool("read_files", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="contents")])) diff --git a/tests/interaction/mcpserver/test_context.py b/tests/interaction/mcpserver/test_context.py index e7ae4b94d9..26556fea7a 100644 --- a/tests/interaction/mcpserver/test_context.py +++ b/tests/interaction/mcpserver/test_context.py @@ -15,6 +15,7 @@ ElicitRequestParams, ElicitResult, ErrorData, + Implementation, LoggingMessageNotification, LoggingMessageNotificationParams, TextContent, @@ -93,6 +94,33 @@ async def on_progress(progress: float, total: float | None, message: str | None) assert received == snapshot([(1.0, 3.0, None), (2.0, 3.0, "halfway there")]) +@requirement("mcpserver:tool:extra") +async def test_context_exposes_request_id_and_client_info_to_a_tool(connect: Connect) -> None: + """A tool can read the per-request id and the connecting client's identity through Context. + + The request id is non-empty (its concrete value depends on transport-level sequencing, so the + test asserts the value the tool saw is the one returned, rather than pinning the literal); the + client info reflects what the caller passed to `Client`. 
+ """ + mcp = MCPServer("introspector") + + @mcp.tool() + async def whoami(ctx: Context) -> str: + client_params = ctx.session.client_params + assert client_params is not None + return f"request {ctx.request_id} from {client_params.client_info.name} {client_params.client_info.version}" + + async with connect(mcp, client_info=Implementation(name="acme-agent", version="9.9.9")) as client: + result = await client.call_tool("whoami", {}) + + assert isinstance(result.content[0], TextContent) + text = result.content[0].text + assert text.startswith("request ") + assert text.endswith(" from acme-agent 9.9.9") + request_id = text.removeprefix("request ").removesuffix(" from acme-agent 9.9.9") + assert request_id + + @requirement("protocol:progress:no-token") async def test_report_progress_without_a_progress_token_sends_nothing(connect: Connect) -> None: """When the caller supplied no progress callback, Context.report_progress is a silent no-op. diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index e66538ce09..f8aa208d7f 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -1,10 +1,10 @@ """Tool interactions against MCPServer, driven through the public Client API.""" -from typing import Annotated +from typing import Annotated, Literal import pytest from inline_snapshot import snapshot -from pydantic import BaseModel +from pydantic import BaseModel, Field from mcp import MCPError from mcp.server.mcpserver import Context, MCPServer @@ -45,6 +45,34 @@ def add(a: int, b: int) -> str: assert result == snapshot(CallToolResult(content=[TextContent(text="5")], structured_content={"result": "5"})) +@requirement("mcpserver:tool:schema-variants") +async def test_complex_parameter_types_are_validated_and_coerced_before_the_tool_runs(connect: Connect) -> None: + """Literal, nested-model, and constrained parameters are validated and coerced from the wire arguments. 
+ + The string "3" is coerced to `int` and the `point` dict to a `Point` instance before the function + body sees them, proving the generated input schema and validation pipeline cover non-trivial types. + """ + mcp = MCPServer("typed") + + class Point(BaseModel): + x: int + y: int + + @mcp.tool() + def place(mode: Literal["fast", "slow"], point: Point, count: Annotated[int, Field(ge=1, le=10)]) -> str: + assert isinstance(point, Point) + return f"{mode} at ({point.x}, {point.y}) x{count}" + + async with connect(mcp) as client: + result = await client.call_tool("place", {"mode": "fast", "point": {"x": "3", "y": 4}, "count": 5}) + + assert result == snapshot( + CallToolResult( + content=[TextContent(text="fast at (3, 4) x5")], structured_content={"result": "fast at (3, 4) x5"} + ) + ) + + @requirement("mcpserver:tool:handler-throws") @requirement("mcpserver:output-schema:skip-on-error") async def test_call_tool_function_exception_becomes_error_result(connect: Connect) -> None: diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py index 604f08a8f2..2d9d0c42b6 100644 --- a/tests/interaction/transports/test_client_transport_http.py +++ b/tests/interaction/transports/test_client_transport_http.py @@ -14,11 +14,11 @@ from inline_snapshot import snapshot from starlette.types import Receive, Scope, Send -from mcp import types +from mcp import MCPError, types from mcp.client.client import Client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext -from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from mcp.types import INVALID_REQUEST, CallToolResult, ErrorData, ListToolsResult, TextContent, Tool from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION, client_via_http, mounted_app from tests.interaction._requirements import requirement from tests.interaction.transports._bridge import 
StreamingASGITransport @@ -209,3 +209,36 @@ async def record(request: httpx.Request) -> None: assert [tool.name for tool in result.tools] == ["echo"] resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] assert resumption_gets == [] + + +@requirement("client-transport:http:404-surfaces") +async def test_a_404_mid_session_surfaces_as_a_session_terminated_error() -> None: + """A 404 in response to a request after initialization is reported to the caller as an MCP error. + + The spec says the client MUST start a new session in this situation; the SDK instead surfaces a + `Session terminated` error to the caller (see the divergence on the requirement). This test pins + that current behaviour. + """ + server = _tooled_server() + real_app = server.streamable_http_app(transport_security=NO_DNS_REBINDING_PROTECTION) + initialize_seen = anyio.Event() + + async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["method"] == "POST" and initialize_seen.is_set(): + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + if scope["type"] == "http" and scope["method"] == "POST": + initialize_seen.set() + await real_app(scope, receive, send) + + async with server.session_manager.run(): + http_client = httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) + async with http_client: + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): + async with Client(transport) as client: + with pytest.raises(MCPError) as exc_info: + await client.list_tools() + + assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py new file mode 100644 index 0000000000..6e3d787356 --- 
/dev/null +++ b/tests/interaction/transports/test_flows.py @@ -0,0 +1,127 @@ +"""Transport-level composed flows: multi-client isolation, reconnection, and dual-transport hosting. + +These scenarios are about how the transport layer holds together across more than one connection +or more than one transport, so they connect real `Client`s against one mounted server rather than +running over the matrix. +""" + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.client.session import LoggingFnT +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from tests.interaction._connect import client_via_http, connect_over_sse, mounted_app +from tests.interaction._requirements import requirement + +pytestmark = pytest.mark.anyio + + +@requirement("flow:multi-client:stateful-isolation") +async def test_concurrent_clients_on_one_stateful_server_receive_only_their_own_notifications() -> None: + """Two clients on one stateful manager each receive only the notifications their own request produced. + + Complements `test_terminating_one_session_leaves_others_working` (which proves session + independence under termination) with the notification-isolation dimension: a notification + emitted by one session's handler does not leak to another session's client. 
+ """ + mcp = MCPServer("multi") + + @mcp.tool() + async def announce(label: str, ctx: Context) -> str: + """Emit one info-level log carrying the caller's label, then return it.""" + await ctx.info(label) + return label + + received_a: list[object] = [] + received_b: list[object] = [] + + async def collect_a(params: LoggingMessageNotificationParams) -> None: + received_a.append(params.data) + + async def collect_b(params: LoggingMessageNotificationParams) -> None: + received_b.append(params.data) + + async with mounted_app(mcp) as (http, _): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + + async def call(label: str, collect: LoggingFnT) -> None: + async with client_via_http(http, logging_callback=collect) as client: + await client.call_tool("announce", {"label": label}) + + tg.start_soon(call, "a", collect_a) + tg.start_soon(call, "b", collect_b) + + assert received_a == ["a"] + assert received_b == ["b"] + + +@requirement("flow:session:terminate-then-reconnect") +async def test_a_fresh_connection_after_termination_obtains_a_new_session_and_operates() -> None: + """After a client terminates, a fresh connection to the same manager gets a distinct session. + + Steps: (1) connect a client and call list_tools, (2) the client exits (its DELETE fires), + (3) connect a second client to the same mounted app, (4) the second client's call_tool + succeeds and the recorded session ids show two distinct sessions were issued. 
+ """ + mcp = MCPServer("reconnectable") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + session_ids: list[str] = [] + + async def record(request: httpx.Request) -> None: + session_id = request.headers.get("mcp-session-id") + if session_id is not None: + session_ids.append(session_id) + + async with mounted_app(mcp, on_request=record) as (http, _): + async with client_via_http(http) as first: + first_result = await first.list_tools() + async with client_via_http(http) as second: + second_result = await second.call_tool("echo", {"text": "again"}) + + assert {tool.name for tool in first_result.tools} == {"echo"} + assert second_result == snapshot( + CallToolResult(content=[TextContent(text="again")], structured_content={"result": "again"}) + ) + distinct = set(session_ids) + assert len(distinct) == 2, f"expected two distinct session ids across the two connections, saw {distinct}" + + +@requirement("flow:compat:dual-transport-server") +async def test_one_server_serves_streamable_http_and_sse_clients_concurrently() -> None: + """One MCPServer instance serves a streamable-HTTP client and a legacy-SSE client at the same time. + + The two transports have independent connection management (the streamable-HTTP session manager + versus a per-connection SSE handler), but both dispatch into the same server's request + handlers. The test connects one client over each transport against the same instance and + proves both reach the same tool. Uses MCPServer because the low-level Server has no SSE + convenience; the entry is about hosting composition, not the low-level API. 
+ """ + mcp = MCPServer("dual") + + @mcp.tool() + def echo(text: str) -> str: + """Return the input unchanged.""" + return text + + async with mounted_app(mcp) as (http, _): + async with connect_over_sse(mcp) as sse_client: + async with client_via_http(http) as shttp_client: + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + + assert shttp_result == snapshot( + CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) + ) + assert sse_result == snapshot( + CallToolResult(content=[TextContent(text="via sse")], structured_content={"result": "via sse"}) + ) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index 6abeb5d8ed..bb98a96e7a 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -229,11 +229,14 @@ async def hold(ctx: Context) -> str: await finished.wait() +# This test intentionally carries every resumability requirement: the close-then-resume +# scenario is indivisible, so splitting it would mean six near-identical bodies. @requirement("hosting:resume:close-stream") @requirement("transport:streamable-http:resumability") @requirement("client-transport:http:reconnect-post-priming") @requirement("client-transport:http:reconnect-retry-value") @requirement("client-transport:http:resume-stream-api") +@requirement("flow:resume:tool-call-resumption-token") async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: """A server-closed request stream is reconnected by the client and the call completes. 
diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index e70a68225f..2d15d61ff8 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -1,16 +1,22 @@ -"""The suite's one stdio end-to-end test: a real SDK Server in a subprocess, driven by Client. +"""The stdio transport: one subprocess end-to-end test and one in-process framing test. -Everything else in the suite runs in a single process; this test exists to prove the same +Everything else in the suite runs in a single process; the subprocess test exists to prove the same client↔server round trip works over the stdio transport's real boundary (a child process whose stdin/stdout carry one newline-delimited JSON-RPC message per line). The server lives in `_stdio_server.py` and is launched via `python -m` so subprocess coverage measurement applies. +The framing test drives `stdio_server` in-process by passing it injected text streams instead of the +real stdin/stdout, so the raw lines the transport writes can be asserted directly without a process +boundary. + stdio is deliberately not a leg of the `connect`-fixture matrix: spawning a subprocess per test would be slow, and the matrix already proves transport-agnosticism over three in-process -transports. Process-lifecycle edge cases (escalation to terminate/kill, stderr handling, parse -errors) are covered by `tests/client/test_stdio.py` and stay deferred here. +transports. Process-lifecycle edge cases (escalation to terminate/kill, parse errors) are covered by +`tests/client/test_stdio.py` and stay deferred here. 
""" +import io +import json import os import sys import tempfile @@ -22,7 +28,18 @@ from mcp.client.client import Client from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.types import CallToolResult, LoggingMessageNotificationParams, TextContent +from mcp.server.stdio import stdio_server +from mcp.shared.message import SessionMessage +from mcp.types import ( + CallToolResult, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + LoggingMessageNotificationParams, + TextContent, +) +from mcp.types.jsonrpc import jsonrpc_message_adapter +from tests.interaction._connect import initialize_body from tests.interaction._requirements import requirement from tests.interaction.transports import _stdio_server @@ -33,6 +50,7 @@ @requirement("transport:stdio") @requirement("transport:stdio:clean-shutdown") +@requirement("transport:stdio:stderr-passthrough") async def test_tool_call_and_notification_round_trip_over_a_stdio_subprocess() -> None: """A Client connected over stdio initializes, calls a tool with arguments, receives the server's log notification before the call returns, and the server exits when the transport @@ -72,5 +90,51 @@ async def collect(params: LoggingMessageNotificationParams) -> None: ) # The server writes this line only after its run loop returns, which happens when stdin closes: # seeing it proves the process exited on its own rather than via the transport's terminate - # escalation, without a timing-based assertion. + # escalation, without a timing-based assertion. The capture itself proves stderr passthrough: + # the transport routes the child's stderr to the caller's `errlog` without consuming it. 
assert captured_stderr == snapshot("stdio-echo: clean exit\n") + + +@requirement("transport:stdio:stream-purity") +@requirement("transport:stdio:no-embedded-newlines") +async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: + """Everything `stdio_server` writes is a valid JSON-RPC message on its own line, and nothing else. + + The transport's stdin/stdout parameters are public, so the test injects in-process text streams + instead of the real process handles and drives the read/write streams directly: a JSON-RPC line on + stdin is parsed and delivered, and every message sent on the write stream appears as exactly one + newline-terminated line whose payload newlines are JSON-escaped. This proves the transport's own + framing; it does not guard `sys.stdout` against handler code that prints to it directly (see the + divergence on `transport:stdio:stream-purity`). + """ + captured = io.StringIO() + sent_line = json.dumps(initialize_body(request_id=1)) + "\n" + + with anyio.fail_after(5): + async with stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + read_stream, + write_stream, + ): + async with read_stream, write_stream: + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "initialize" + + response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) + notification = JSONRPCNotification( + jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} + ) + await write_stream.send(SessionMessage(response)) + await write_stream.send(SessionMessage(notification)) + + output = captured.getvalue() + assert output.endswith("\n") + lines = output.removesuffix("\n").split("\n") + assert len(lines) == 2 + messages = [jsonrpc_message_adapter.validate_json(line) for line in lines] + assert [type(message).__name__ for message in messages] 
== snapshot(["JSONRPCResponse", "JSONRPCNotification"]) + # The newline inside the payload is JSON-escaped on the wire, not a literal newline that would + # break the one-message-per-line framing. + assert r"line\nbreak" in lines[0] + assert r"two\nlines" in lines[1] From 9fb50a1d6058c8127ba5f26fdfb0ad846ed1b606 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 14:22:28 +0000 Subject: [PATCH 22/34] test: add end-to-end OAuth authorization tests with an in-process AS/RS harness --- src/mcp/client/auth/oauth2.py | 20 +- src/mcp/server/auth/handlers/authorize.py | 2 +- src/mcp/server/auth/middleware/bearer_auth.py | 2 +- src/mcp/server/lowlevel/server.py | 8 +- src/mcp/shared/auth.py | 2 +- src/mcp/shared/session.py | 2 +- tests/interaction/_connect.py | 13 +- tests/interaction/_requirements.py | 254 +++++++--- tests/interaction/auth/__init__.py | 0 tests/interaction/auth/_harness.py | 461 ++++++++++++++++++ tests/interaction/auth/_provider.py | 187 +++++++ tests/interaction/auth/test_as_handlers.py | 300 ++++++++++++ .../interaction/auth/test_authorize_token.py | 399 +++++++++++++++ tests/interaction/auth/test_bearer.py | 189 +++++++ tests/interaction/auth/test_discovery.py | 333 +++++++++++++ tests/interaction/auth/test_flow.py | 239 +++++++++ tests/interaction/auth/test_lifecycle.py | 445 +++++++++++++++++ tests/interaction/test_coverage.py | 1 + 18 files changed, 2767 insertions(+), 90 deletions(-) create mode 100644 tests/interaction/auth/__init__.py create mode 100644 tests/interaction/auth/_harness.py create mode 100644 tests/interaction/auth/_provider.py create mode 100644 tests/interaction/auth/test_as_handlers.py create mode 100644 tests/interaction/auth/test_authorize_token.py create mode 100644 tests/interaction/auth/test_bearer.py create mode 100644 tests/interaction/auth/test_discovery.py create mode 100644 tests/interaction/auth/test_flow.py create mode 100644 
tests/interaction/auth/test_lifecycle.py diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f5775..3c546fda2b 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -360,10 +360,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_code, returned_state = await self.context.callback_handler() if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") if not auth_code: - raise OAuthFlowError("No authorization code received") # pragma: no cover + raise OAuthFlowError("No authorization code received") # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier @@ -452,7 +452,7 @@ async def _refresh_token(self) -> httpx.Request: return httpx.Request("POST", token_url, data=refresh_data, headers=headers) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. 
Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -468,12 +468,12 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p await self.context.storage.set_tokens(token_response) return True - except ValidationError: + except ValidationError: # pragma: no cover logger.exception("Invalid refresh response") self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover + async def _initialize(self) -> None: """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() @@ -507,17 +507,17 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() # pragma: no cover - refresh_response = yield refresh_request # pragma: no cover + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request - if not await self._handle_refresh_response(refresh_response): # pragma: no cover + if not await self._handle_refresh_response(refresh_response): # Refresh failed, need full re-authentication self._initialized = False @@ -612,7 +612,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. 
# Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py index dec6713b13..5cf93cf8c2 100644 --- a/src/mcp/server/auth/handlers/authorize.py +++ b/src/mcp/server/auth/handlers/authorize.py @@ -117,7 +117,7 @@ async def error_response( pass # the error response MUST contain the state specified by the client, if any - if state is None: # pragma: no cover + if state is None: # make last-ditch effort to load state state = best_effort_extract_string("state", params) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py index 6825c00b9e..2eafdc793e 100644 --- a/src/mcp/server/auth/middleware/bearer_auth.py +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -95,7 +95,7 @@ async def _send_auth_error(self, send: Send, status_code: int, error: str, descr """Send an authentication error response with WWW-Authenticate header.""" # Build WWW-Authenticate header value www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: # pragma: no cover + if self.resource_metadata_url: www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') www_authenticate = f"Bearer {', '.join(www_auth_parts)}" diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index d1a15120af..5e4e2e6f5b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -603,7 +603,7 @@ def streamable_http_app( required_scopes: list[str] = [] # Set up auth if configured - if auth: # pragma: no cover + if auth: required_scopes = auth.required_scopes or [] # Add auth middleware if token verifier is available @@ -629,10 +629,10 @@ def 
streamable_http_app( ) # Set up routes with or without auth - if token_verifier: # pragma: no cover + if token_verifier: # Determine resource metadata URL resource_metadata_url = None - if auth and auth.resource_server_url: + if auth and auth.resource_server_url: # pragma: no branch # Build compliant metadata URL for WWW-Authenticate header resource_metadata_url = build_resource_metadata_url(auth.resource_server_url) @@ -652,7 +652,7 @@ def streamable_http_app( ) # Add protected resource metadata endpoint if configured as RS - if auth and auth.resource_server_url: # pragma: no cover + if auth and auth.resource_server_url: routes.extend( create_protected_resource_routes( resource_url=auth.resource_server_url, diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index ebf534d792..dd93ad7e17 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -93,7 +93,7 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: for scope in requested_scopes: if scope not in allowed_scopes: # pragma: no branch raise InvalidScopeError(f"Client was not registered with scope {scope}") - return requested_scopes # pragma: no cover + return requested_scopes def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: if redirect_uri is not None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae6..9c72a23844 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -451,7 +451,7 @@ async def _handle_session_message(message: SessionMessage) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # Stream might already be closed pass self._response_streams.clear() diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index baca975917..3dda864cd5 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -11,7 +11,7 @@ 
import warnings from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import Protocol +from typing import Any, Protocol import httpx from httpx_sse import ServerSentEvent, aconnect_sse @@ -25,6 +25,8 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.server import Server +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier +from mcp.server.auth.settings import AuthSettings from mcp.server.mcpserver import MCPServer from mcp.server.sse import SseServerTransport from mcp.server.streamable_http import EventStore @@ -154,6 +156,9 @@ async def mounted_app( transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, on_request: Callable[[httpx.Request], Awaitable[None]] | None = None, headers: dict[str, str] | None = None, + auth: AuthSettings | None = None, + token_verifier: TokenVerifier | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, ) -> AsyncIterator[tuple[httpx.AsyncClient, StreamableHTTPSessionManager]]: """Mount the server's streamable HTTP app on the in-process bridge and yield an httpx client. @@ -167,11 +172,15 @@ async def mounted_app( DNS-rebinding protection is disabled by default; pass explicit settings (or `None` for the localhost auto-enable behaviour) to test the protection itself. 
""" - app = server.streamable_http_app( + lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server + app = lowlevel.streamable_http_app( stateless_http=stateless_http, event_store=event_store, retry_interval=retry_interval, transport_security=transport_security, + auth=auth, + token_verifier=token_verifier, + auth_server_provider=auth_server_provider, ) event_hooks = {"request": [on_request]} if on_request is not None else None async with server.session_manager.run(): diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index b5897ee46d..4e072ae254 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -1896,42 +1896,37 @@ def __post_init__(self) -> None: "(and revocation when supported)." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " - "planned with the auth tests in this suite." - ), ), "hosting:auth:aud-validation": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#access-token-usage", behavior="The resource server validates that the token audience matches its resource identifier.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + divergence=Divergence( + note=( + "BearerAuthBackend never inspects AccessToken.resource; a token issued for a different " + "resource is accepted. Spec MUST." + ), + ), ), "hosting:auth:authinfo-propagates": Requirement( source="sdk", behavior="A valid token's auth info is exposed to request handlers.", transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " - "planned with the auth tests in this suite." 
- ), ), "hosting:auth:expired-401": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", behavior="An expired token returns 401 invalid_token.", transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " - "planned with the auth tests in this suite." + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", ), ), "hosting:auth:invalid-401": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", behavior="A malformed bearer token or token-verification failure returns 401 with WWW-Authenticate.", transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " - "planned with the auth tests in this suite." + divergence=Divergence( + note="The challenge carries no `scope` parameter; see the note on hosting:auth:missing-401.", ), ), "hosting:auth:metadata-endpoints": Requirement( @@ -1942,10 +1937,6 @@ def __post_init__(self) -> None: "at its own." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " - "planned with the auth tests in this suite." - ), ), "hosting:auth:missing-401": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", @@ -1954,9 +1945,14 @@ def __post_init__(self) -> None: "carries resource_metadata (one of the spec's two permitted discovery mechanisms)." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/server/auth/; interaction-level coverage " - "planned with the auth tests in this suite." 
+ divergence=Divergence( + note=( + "The SDK never emits a `scope` parameter in any WWW-Authenticate challenge — neither the " + "discovery-time 401 (#protected-resource-metadata-discovery-requirements SHOULD) nor the " + "runtime 403 (#runtime-insufficient-scope-errors SHOULD); and for the no-credentials case " + 'it emits error="invalid_token", which RFC 6750 Section 3.1 says SHOULD NOT appear when no ' + "authentication information was presented." + ), ), ), "hosting:auth:prm:authorization-servers-field": Requirement( @@ -1965,7 +1961,14 @@ def __post_init__(self) -> None: "The protected-resource metadata document includes an authorization_servers array with at least one entry." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "hosting:auth:query-token-ignored": Requirement( + source="sdk", + behavior=( + "An access token presented in the URI query string is not accepted; the request is treated as " + "unauthenticated." + ), + transports=("streamable-http",), ), "hosting:auth:scope-403": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#runtime-insufficient-scope-errors", @@ -1974,7 +1977,81 @@ def __post_init__(self) -> None: "insufficient_scope, the required scope, and resource_metadata." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + divergence=Divergence( + note=( + 'The SDK emits error="insufficient_scope" and error_description but never the `scope` ' + "parameter the spec SHOULD include; the SDK client reads `scope` from this header to drive " + "step-up (utils.py extract_scope_from_www_auth) — a resource-server/client asymmetry." 
+ ), + ), + ), + "hosting:auth:as:authorize-requires-pkce": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled authorization endpoint rejects an authorize request that omits " + "`code_challenge` with `invalid_request`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:verifier-mismatch": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `code_verifier` " + "does not hash to the stored `code_challenge` with `invalid_grant`." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:code-single-use": Requirement( + source="sdk", + behavior=( + "An authorization code can be exchanged exactly once; a second exchange of the same code " + "is rejected with `invalid_grant`. Enforced by the provider deleting the code on first use; " + "the handler relies on `load_authorization_code` returning None." + ), + transports=("streamable-http",), + ), + "hosting:auth:as:redirect-uri-binding": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", + behavior=( + "The bundled token endpoint rejects an authorization-code exchange whose `redirect_uri` " + "differs from the one used at authorize; the bundled authorize endpoint rejects a " + "`redirect_uri` not in the client's registered list without redirecting to it." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "RFC 6749 §5.2 assigns redirect_uri mismatch at the token endpoint to invalid_grant; " + "the SDK's TokenHandler returns invalid_request (src/mcp/server/auth/handlers/token.py:157). " + "The rejection itself is the security-relevant property and is correct." 
+ ), + ), + ), + "hosting:auth:as:redirect-uri-scheme": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#communication-security", + behavior=( + "The bundled registration endpoint accepts only redirect URIs that use HTTPS or target a loopback host." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "Not enforced: the registration handler models redirect_uris as AnyUrl with no scheme or " + "host check, so http://evil.example/callback is accepted and registered. The spec's " + "localhost-or-HTTPS rule is left to the provider implementation." + ), + ), + ), + "hosting:auth:as:token-cache-headers": Requirement( + source="sdk", + behavior=("Every token-endpoint response carries `Cache-Control: no-store` and `Pragma: no-cache`."), + transports=("streamable-http",), + ), + "hosting:auth:as:register-error-response": Requirement( + source="sdk", + behavior=( + "The bundled registration endpoint answers invalid client metadata with HTTP 400 and an " + "RFC 7591 error body." + ), + transports=("streamable-http",), ), # ═══════════════════════════════════════════════════════════════════════════ # Hosting: resumability @@ -2304,16 +2381,11 @@ def __post_init__(self) -> None: "If the server still returns 401 after a successful authorization, the client fails instead of looping." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "client-auth:401-triggers-flow": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", behavior="A 401 on a request triggers the OAuth authorization flow once.", transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." 
- ), ), "client-auth:403-scope-upgrade": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#step-up-authorization-flow", @@ -2321,7 +2393,6 @@ def __post_init__(self) -> None: "A 403 with WWW-Authenticate triggers a scope-upgrade authorization attempt; repeated 403s do not loop." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "client-auth:as-metadata-discovery:priority-order": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", @@ -2331,10 +2402,45 @@ def __post_init__(self) -> None: "root-path forms when the issuer URL has no path)." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." + ), + "client-auth:as-metadata-discovery:issuer-validation": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-metadata-discovery", + behavior=( + "The client rejects authorization-server metadata whose issuer does not match the URL the " + "metadata was retrieved from (RFC 8414 section 3.3)." + ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The SDK parses authorization-server metadata without comparing issuer to the discovery " + "URL; a mismatched issuer is accepted and the flow proceeds. The SDK also does not " + "validate that the document's authorization_endpoint, token_endpoint, and " + "registration_endpoint use http(s) schemes." + ), + ), + ), + "client-auth:authorize:error-surfaces": Requirement( + source=f"{SPEC_BASE_URL}/basic/authorization#authorization-flow-steps", + behavior=( + "An OAuth error redirect from the authorize endpoint aborts the flow before any token " + "request is issued, surfacing as an error to the caller." 
+ ), + transports=("streamable-http",), + divergence=Divergence( + note=( + "The callback contract has no error form, so the client surfaces 'No authorization code " + "received' rather than the redirect's `error`/`error_description` values." + ), + ), + ), + "client-auth:authorize:offline-access-consent": Requirement( + source="sdk", + behavior=( + "When the authorization server's metadata advertises offline_access in scopes_supported and " + "the client uses the refresh_token grant, offline_access is appended to the requested scope " + "and prompt=consent is added to the authorize request." ), + transports=("streamable-http",), ), "client-auth:bearer-header:every-request": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#token-requirements", @@ -2343,16 +2449,11 @@ def __post_init__(self) -> None: "request to the MCP server, never in the query string." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." - ), ), "client-auth:cimd": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#client-id-metadata-documents", behavior="The client can use a client-ID metadata document URL as its OAuth client_id instead of registration.", transports=("streamable-http",), - deferred="Not implemented in the SDK: client-ID metadata documents are not supported.", ), "client-auth:client-credentials": Requirement( source="sdk", @@ -2361,10 +2462,14 @@ def __post_init__(self) -> None: "bearer token authorizes subsequent requests." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/auth/; interaction-level coverage " - "planned with the auth tests in this suite." 
+ ), + "client-auth:dcr:registration-error-surfaces": Requirement( + source="sdk", + behavior=( + "A 400 from the registration endpoint surfaces to the caller as an OAuthRegistrationError " + "carrying the status and the server's RFC 7591 error body." ), + transports=("streamable-http",), ), "client-auth:dcr": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#dynamic-client-registration", @@ -2373,10 +2478,6 @@ def __post_init__(self) -> None: "client_id is preconfigured." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." - ), ), "client-auth:invalid-client-clears-all": Requirement( source="sdk", @@ -2384,13 +2485,22 @@ def __post_init__(self) -> None: "An invalid-client or unauthorized-client error during authorization invalidates all stored credentials." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + divergence=Divergence( + note=( + "The token-response handlers do not parse the error body; an invalid_client or " + "unauthorized_client response leaves stored client_info untouched. The TypeScript SDK " + "clears it." + ), + ), + deferred=( + "Not implemented in the SDK: no token-response path inspects the error code to decide " + "whether to clear client_info." + ), ), "client-auth:invalid-grant-clears-tokens": Requirement( source="sdk", behavior="An invalid-grant error during authorization invalidates only the stored tokens.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "client-auth:pkce:refuse-if-unsupported": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", @@ -2399,7 +2509,12 @@ def __post_init__(self) -> None: "code_challenge_methods_supported, since PKCE support cannot be verified." 
), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + divergence=Divergence( + note=( + "The client never inspects code_challenge_methods_supported and proceeds with PKCE S256 " + "regardless; the spec MUST is not enforced." + ), + ), ), "client-auth:pkce:s256": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-code-protection", @@ -2408,10 +2523,6 @@ def __post_init__(self) -> None: "the matching verifier." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." - ), ), "client-auth:pre-registration": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#preregistration", @@ -2419,16 +2530,11 @@ def __post_init__(self) -> None: "A client with statically preconfigured credentials skips dynamic registration and uses them directly." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." - ), ), "client-auth:private-key-jwt": Requirement( source="sdk", behavior="The client can authenticate the client-credentials grant with a signed JWT assertion.", transports=("streamable-http",), - deferred="Not implemented in the SDK: JWT-assertion client authentication is not supported.", ), "client-auth:prm-discovery:fallback-order": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#protected-resource-metadata-discovery-requirements", @@ -2437,10 +2543,15 @@ def __post_init__(self) -> None: "well-known protected-resource locations in the documented order." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." 
+ ), + "client-auth:prm-discovery:no-prm-fallback": Requirement( + source="sdk", + behavior=( + "When every protected-resource metadata probe fails, the client falls back to discovering " + "authorization-server metadata directly at the MCP server's origin (the legacy 2025-03-26 path) " + "rather than aborting." ), + transports=("streamable-http",), ), "client-auth:prm-resource-mismatch": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#authorization-server-location", @@ -2449,7 +2560,15 @@ def __post_init__(self) -> None: "match the server URL it is connecting to." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + ), + "client-auth:refresh:transparent": Requirement( + source="sdk", + behavior=( + "An access token the client considers expired is transparently refreshed before the next " + "request, using the stored refresh token; the refresh request includes the resource indicator " + "and the new token is persisted." + ), + transports=("streamable-http",), ), "client-auth:resource-parameter": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#resource-parameter-implementation", @@ -2458,10 +2577,6 @@ def __post_init__(self) -> None: "authorization request and the token request." ), transports=("streamable-http",), - deferred=( - "Not yet covered here; existing coverage in tests/client/test_auth.py; interaction-level " - "coverage planned with the auth tests in this suite." - ), ), "client-auth:scope-selection:priority": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", @@ -2470,7 +2585,6 @@ def __post_init__(self) -> None: "protected-resource metadata, and otherwise omits scope." 
), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "client-auth:state:verify": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#open-redirection", @@ -2479,13 +2593,11 @@ def __post_init__(self) -> None: "missing or mismatched state are discarded." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "client-auth:token-endpoint-auth-method": Requirement( source="sdk", behavior="The client authenticates to the token endpoint using the auth method established at registration.", transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "client-auth:token-provenance": Requirement( source=f"{SPEC_BASE_URL}/basic/authorization#token-handling", @@ -2494,7 +2606,10 @@ def __post_init__(self) -> None: "never tokens obtained elsewhere." ), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", + deferred=( + "Untestable negative through the public API: there is no path to inject a token obtained " + "elsewhere into the auth provider's state, so the absence cannot be observed end to end." + ), ), # ═══════════════════════════════════════════════════════════════════════════ # stdio transport @@ -2618,7 +2733,6 @@ def __post_init__(self) -> None: "attempt requires authorization, the code is exchanged, and a subsequent connection succeeds." 
), transports=("streamable-http",), - deferred="Not yet covered here: planned with the auth interaction tests in this suite.", ), "flow:resume:tool-call-resumption-token": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#resumability-and-redelivery", diff --git a/tests/interaction/auth/__init__.py b/tests/interaction/auth/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py new file mode 100644 index 0000000000..8ee8263c6f --- /dev/null +++ b/tests/interaction/auth/_harness.py @@ -0,0 +1,461 @@ +"""In-process harness for the auth interaction tests. + +Co-hosts the SDK's authorization-server routes, protected-resource metadata route, and the +bearer-gated MCP endpoint on one Starlette app via `Server.streamable_http_app(auth=..., +token_verifier=..., auth_server_provider=...)`, drives that app through the streaming bridge +on a single `httpx.AsyncClient` carrying `auth=OAuthClientProvider(...)`, and completes the +authorize redirect headlessly by GETing the URL through the same bridge and parsing the code +from the 302 `Location`. The whole authorization-code flow runs in one event loop with no +sockets, no threads, and no real time. 
+""" + +import json +from collections.abc import AsyncIterator, Callable, Mapping, Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import parse_qs, parse_qsl, urlsplit + +import httpx +from pydantic import AnyHttpUrl, AnyUrl, BaseModel +from starlette.types import ASGIApp, Receive, Scope, Send + +from mcp.client.auth import OAuthClientProvider +from mcp.client.client import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server +from mcp.server.auth.provider import AccessToken, ProviderTokenVerifier +from mcp.server.auth.settings import AuthSettings, ClientRegistrationOptions, RevocationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from tests.interaction._connect import BASE_URL, NO_DNS_REBINDING_PROTECTION +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +REDIRECT_URI = f"{BASE_URL}/oauth/callback" + +AppShim = Callable[[ASGIApp], ASGIApp] + + +@dataclass +class RecordedRequest: + """A snapshot of an `httpx.Request` at the moment it was sent. + + The auth flow re-yields the same `httpx.Request` object after mutating its headers in + place for the retry, so tests that need to assert on the first attempt's headers must + capture a copy rather than a live reference. `record_requests` produces these. 
+ """ + + method: str + url: httpx.URL + headers: dict[str, str] + content: bytes + + @property + def path(self) -> str: + return self.url.path + + +def record_requests() -> tuple[list[RecordedRequest], Callable[[httpx.Request], None]]: + """Build an `on_request` callback that snapshots each request, and the list it appends to.""" + recorded: list[RecordedRequest] = [] + + def on_request(request: httpx.Request) -> None: + recorded.append( + RecordedRequest( + method=request.method, + url=request.url, + headers=dict(request.headers), + content=bytes(request.content), + ) + ) + + return recorded, on_request + + +def metadata_body(model: BaseModel, **extra: object) -> bytes: + """Serialize a metadata model to a JSON body for `shimmed_app(serve=...)`. + + `extra` keys are merged into the serialized object so a test can inject fields the model + does not declare (e.g. an unknown extension field, to prove the client's parser tolerates + unrecognized members per RFC 8414/9728 §3.2). The model itself would silently drop such + fields at construction, so they have to be added after serialization. + """ + document = model.model_dump(by_alias=True, mode="json", exclude_none=True) + document.update(extra) + return json.dumps(document).encode() + + +class StaticTokenVerifier: + """A `TokenVerifier` backed by a fixed token→`AccessToken` mapping. + + Any token string not in the mapping verifies to `None`, which the bearer middleware treats + as an unrecognized token. Tests seed the mapping with the exact token shapes (valid, expired, + wrong scope, wrong audience) they need so the resource-server gate's behaviour is asserted in + isolation from the authorization-server provider. 
+ """ + + def __init__(self, tokens: Mapping[str, AccessToken]) -> None: + self._tokens = dict(tokens) + + async def verify_token(self, token: str) -> AccessToken | None: + return self._tokens.get(token) + + +class InMemoryTokenStorage: + """A `TokenStorage` that holds tokens and client info as instance attributes. + + Tests pre-seed `client_info` (via the constructor or by assignment) to drive the + pre-registered path, and read both attributes after the flow to assert what the SDK + persisted. + """ + + def __init__(self, *, client_info: OAuthClientInformationFull | None = None) -> None: + self.tokens: OAuthToken | None = None + self.client_info: OAuthClientInformationFull | None = client_info + + async def get_tokens(self) -> OAuthToken | None: + return self.tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self.tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self.client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self.client_info = client_info + + +class HeadlessOAuth: + """Completes the authorize step in-process by following the redirect through the bridge. + + `redirect_handler` GETs the authorize URL on the bound client (with `auth=None` so the + request does not re-enter the locked auth flow), parses `code` and `state` from the 302 + `Location`, and stashes them; `callback_handler` returns the stashed pair. Tests inspect + `authorize_url` to assert what the SDK put on the authorize request. + + `state_override`: when set, `callback_handler` returns this value as the state instead of + the one parsed from the redirect, so tests can drive the state-mismatch path. 
+ """ + + def __init__(self, *, state_override: str | None = None) -> None: + self.authorize_url: str | None = None + self.authorize_urls: list[str] = [] + self.error: str | None = None + self._state_override = state_override + self._http: httpx.AsyncClient | None = None + self._code: str = "" + self._state: str | None = None + + def bind(self, http_client: httpx.AsyncClient) -> None: + self._http = http_client + + async def redirect_handler(self, authorization_url: str) -> None: + assert self._http is not None + self.authorize_url = authorization_url + self.authorize_urls.append(authorization_url) + # auth=None is load-bearing: without it the GET re-enters OAuthClientProvider.async_auth_flow + # through its context lock and the flow deadlocks. + response = await self._http.get(authorization_url, follow_redirects=False, auth=None) + assert response.status_code == 302, f"authorize endpoint returned {response.status_code}: {response.text}" + params = parse_qs(urlsplit(response.headers["location"]).query) + self._code = params.get("code", [""])[0] + self._state = params.get("state", [None])[0] + self.error = params.get("error", [None])[0] + + async def callback_handler(self) -> tuple[str, str | None]: + return self._code, self._state_override if self._state_override is not None else self._state + + +def auth_settings( + *, required_scopes: Sequence[str] = ("mcp",), valid_scopes: Sequence[str] | None = None +) -> AuthSettings: + """Build `AuthSettings` for the co-hosted authorization + resource server. + + The issuer and resource URLs use the suite's loopback origin, which `validate_issuer_url` + accepts in lieu of HTTPS. Dynamic client registration is enabled. 
`valid_scopes` defaults + to `required_scopes` so a client requesting exactly those passes registration scope + validation; tests pass a wider set when they need the protected-resource metadata's + `scopes_supported` (which mirrors `required_scopes`) to differ from what the client may + register or when AS metadata should advertise additional scopes such as `offline_access`. + """ + required = list(required_scopes) + valid = list(valid_scopes) if valid_scopes is not None else required + return AuthSettings( + issuer_url=AnyHttpUrl(BASE_URL), + resource_server_url=AnyHttpUrl(f"{BASE_URL}/mcp"), + required_scopes=required, + client_registration_options=ClientRegistrationOptions( + enabled=True, valid_scopes=valid, default_scopes=required + ), + revocation_options=RevocationOptions(enabled=False), + ) + + +def oauth_client_metadata() -> OAuthClientMetadata: + """Build the client's registration metadata. + + `scope` is left unset so the SDK's scope-selection strategy chooses one from the server's + metadata before registration. + """ + return OAuthClientMetadata( + client_name="interaction-suite", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + ) + + +def shimmed_app( + app: ASGIApp, + *, + not_found: frozenset[str] = frozenset(), + serve: Mapping[str, bytes | tuple[int, bytes]] | None = None, +) -> ASGIApp: + """Wrap an ASGI app so specific paths return canned responses before reaching the real app. + + Paths in `serve` return the given body as `application/json` (status 200, or the supplied + status when the value is a `(status, body)` pair); paths in `not_found` return 404; + everything else reaches the wrapped app unchanged. Used by the discovery tests to make a + well-known endpoint 404 or return alternate metadata while keeping the real authorization + and MCP endpoints behind it. 
+ """ + overrides: dict[str, tuple[int, bytes]] = { + path: value if isinstance(value, tuple) else (200, value) for path, value in (serve or {}).items() + } + + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + path = scope["path"] + if path in overrides: + status, body = overrides[path] + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode()), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + return + if path in not_found: + await send({"type": "http.response.start", "status": 404, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + +def shim( + *, not_found: frozenset[str] = frozenset(), serve: Mapping[str, bytes | tuple[int, bytes]] | None = None +) -> AppShim: + """Build an `app_shim` for `connect_with_oauth` that applies `shimmed_app` with these overrides.""" + return lambda app: shimmed_app(app, not_found=not_found, serve=serve) + + +@dataclass +class _FirstChallenge: + """ASGI shim that answers the first request to a path with 401 + a given WWW-Authenticate. + + Subsequent requests pass through to the wrapped app. Used to make the initial 401 carry + parameters (such as `scope=`) that the SDK's own bearer middleware cannot be configured + to emit, so client behaviour driven by those parameters is reachable end to end. Reserve + this pattern for behaviour the real server cannot be made to produce. 
+ """ + + app: ASGIApp + path: str + www_authenticate: str + _seen: set[str] = field(default_factory=set[str]) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == self.path and self.path not in self._seen: + self._seen.add(self.path) + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [(b"www-authenticate", self.www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await self.app(scope, receive, send) + + +def first_challenge_shim(www_authenticate: str, *, path: str = "/mcp") -> Callable[[ASGIApp], ASGIApp]: + """Build an `app_shim` that 401s the first request to `path` with the given header value.""" + return lambda app: _FirstChallenge(app, path, www_authenticate) + + +def step_up_shim(www_authenticate: str, *, on_nth_authenticated_post: int = 2) -> AppShim: + """Build an `app_shim` that 403s the Nth authenticated POST to `/mcp` with the given challenge. + + Subsequent requests pass through. Used to drive the client's `insufficient_scope` step-up + handling: the SDK's bearer middleware never emits `scope=` in its 403 challenge (see the + divergence on `hosting:auth:scope-403`), so the test supplies the 403 itself. Reserve this + pattern for behaviour the real server cannot be made to produce. + + The default `on_nth_authenticated_post=2` targets the `notifications/initialized` POST: the + first authenticated POST is the auth flow's retry of the original initialize request (yielded + after the 401 branch, where the generator ends without inspecting the response), so a 403 + there would not reach the step-up handler. 
+ """ + seen = 0 + fired = False + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + nonlocal seen, fired + if ( + not fired + and scope["type"] == "http" + and scope["path"] == "/mcp" + and scope["method"] == "POST" + and any(name == b"authorization" for name, _ in scope["headers"]) + ): + seen += 1 + if seen < on_nth_authenticated_post: + await app(scope, receive, send) + return + fired = True + await send( + { + "type": "http.response.start", + "status": 403, + "headers": [(b"www-authenticate", www_authenticate.encode())], + } + ) + await send({"type": "http.response.body", "body": b""}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +def m2m_token_shim(provider: InMemoryAuthorizationServerProvider, *, scopes: list[str]) -> AppShim: + """Build an `app_shim` that handles `grant_type=client_credentials` at `/token`. + + The SDK server's `TokenHandler` only routes `authorization_code` and `refresh_token`, so a + `client_credentials` request would fail discriminator validation. This shim mints a token via + `provider.mint_access_token` so the M2M client providers can complete e2e against the real + bearer middleware. The shim is harness; the SDK-under-test is the client provider, whose + outbound `/token` body the test asserts. The shim does not authenticate the client (no + credential check) because the test asserts the credentials on the recorded request, not on + the server's acceptance. + """ + + def factory(app: ASGIApp) -> ASGIApp: + async def wrapped(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http" and scope["path"] == "/token" and scope["method"] == "POST": + # The streaming bridge buffers the request body and delivers it in a single + # http.request event, so one receive is sufficient. 
+ message = await receive() + assert not message.get("more_body", False) + form = dict(parse_qsl(message.get("body", b"").decode())) + assert form.get("grant_type") == "client_credentials", ( + f"m2m_token_shim only handles client_credentials; got {form.get('grant_type')!r}" + ) + access = provider.mint_access_token(client_id="m2m", scopes=scopes, resource=form.get("resource")) + token = OAuthToken(access_token=access, token_type="Bearer", expires_in=3600, scope=" ".join(scopes)) + response_body = token.model_dump_json(exclude_none=True).encode() + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(response_body)).encode()), + (b"cache-control", b"no-store"), + ], + } + ) + await send({"type": "http.response.body", "body": response_body}) + return + await app(scope, receive, send) + + return wrapped + + return factory + + +@asynccontextmanager +async def connect_with_oauth( + server: Server, + *, + provider: InMemoryAuthorizationServerProvider, + settings: AuthSettings | None = None, + storage: InMemoryTokenStorage | None = None, + client_metadata: OAuthClientMetadata | None = None, + client_metadata_url: str | None = None, + headless: HeadlessOAuth | None = None, + auth: httpx.Auth | None = None, + verify_tokens: bool = True, + app_shim: Callable[[ASGIApp], ASGIApp] | None = None, + on_request: Callable[[httpx.Request], None] | None = None, +) -> AsyncIterator[tuple[Client, HeadlessOAuth]]: + """Connect a `Client` to a server's bearer-gated streamable-HTTP app, completing OAuth in process. + + Yields the connected `Client` and the `HeadlessOAuth` whose `authorize_url` records what the + SDK put on the authorize request. `on_request` records every HTTP request the underlying + `httpx.AsyncClient` issues, including those yielded from inside the auth flow. 
+ + `headless`: supply a pre-configured `HeadlessOAuth` to override the callback behaviour + (state mismatch, error redirects). `verify_tokens=False` mounts the MCP endpoint without + the bearer middleware so a flow driven by a shimmed 401 completes regardless of the granted + scopes. `app_shim` wraps the built Starlette app before it reaches the bridge transport, + for tests that need to intercept or rewrite specific server responses. + + `auth`: supply a pre-built `httpx.Auth` (such as `ClientCredentialsOAuthProvider`) to use + instead of constructing the default `OAuthClientProvider`; in that case `storage`, + `client_metadata`, `client_metadata_url`, and `headless` are unused (the yielded + `HeadlessOAuth` is never invoked and its `authorize_url` stays None). + """ + settings = settings if settings is not None else auth_settings() + storage = storage if storage is not None else InMemoryTokenStorage() + client_metadata = client_metadata if client_metadata is not None else oauth_client_metadata() + headless = headless if headless is not None else HeadlessOAuth() + + oauth = ( + auth + if auth is not None + else OAuthClientProvider( + server_url=f"{BASE_URL}/mcp", + client_metadata=client_metadata, + storage=storage, + redirect_handler=headless.redirect_handler, + callback_handler=headless.callback_handler, + client_metadata_url=client_metadata_url, + ) + ) + + app: ASGIApp = server.streamable_http_app( + auth=settings, + token_verifier=ProviderTokenVerifier(provider) if verify_tokens else None, + auth_server_provider=provider, + transport_security=NO_DNS_REBINDING_PROTECTION, + ) + if app_shim is not None: + app = app_shim(app) + + event_hooks: dict[str, list[Callable[..., Any]]] | None = None + if on_request is not None: + record = on_request + + async def hook(request: httpx.Request) -> None: + record(request) + + event_hooks = {"request": [hook]} + + async with server.session_manager.run(): + async with httpx.AsyncClient( + 
transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + ) as http_client: + headless.bind(http_client) + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + async with Client(transport) as client: + yield client, headless diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py new file mode 100644 index 0000000000..34b434e4a9 --- /dev/null +++ b/tests/interaction/auth/_provider.py @@ -0,0 +1,187 @@ +"""An in-memory implementation of the SDK's OAuth authorization-server provider protocol. + +The provider holds clients, authorization codes, refresh tokens and access tokens in plain +instance dicts so tests can inspect them; tokens are minted from `secrets.token_hex` so the +values are unique without being predictable. The behaviour mirrors what the SDK's authorization +handlers expect: `authorize` immediately mints a code and returns the redirect, `exchange_*` +issue and rotate tokens, and `load_*` are simple lookups. Only the parts the auth interaction +suite drives are implemented; methods the tests do not yet reach raise `NotImplementedError` +and are filled in by the chunk that first exercises them. +""" + +import secrets +import time + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthAuthorizationServerProvider, + RefreshToken, + TokenError, + construct_redirect_uri, +) +from mcp.shared.auth import OAuthClientInformationFull, OAuthToken + +_TOKEN_LIFETIME_SECONDS = 3600 + + +class InMemoryAuthorizationServerProvider( + OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken] +): + """An OAuth authorization-server provider backed by in-memory dicts. + + Holds registered clients, issued codes, refresh tokens and access tokens as instance state + so tests can both drive the SDK's authorization handlers and inspect what was issued. 
+ + Knobs: + `default_scopes`: scopes granted when an authorize request supplies none. + `deny_authorize`: every authorize request returns an `error=access_denied` redirect. + `issue_expired_first`: the first issued token's `expires_in` is in the past so the + client immediately considers it expired and refreshes; the server-side + `AccessToken.expires_at` stays in the future so the bearer middleware accepts it + on the retry that completes the connect. + `fail_next_refresh`: the next refresh-token exchange raises `invalid_grant` once. + `reject_all_tokens`: `load_access_token` returns None for every token, so the bearer + middleware 401s every authenticated request. + """ + + def __init__( + self, + *, + default_scopes: list[str] | None = None, + deny_authorize: bool = False, + issue_expired_first: bool = False, + fail_next_refresh: bool = False, + reject_all_tokens: bool = False, + ) -> None: + self._default_scopes = list(default_scopes) if default_scopes is not None else ["mcp"] + self._issuer = "http://127.0.0.1:8000" + self._deny_authorize = deny_authorize + self._issue_expired_first = issue_expired_first + self._fail_next_refresh = fail_next_refresh + self._reject_all_tokens = reject_all_tokens + self._tokens_issued = 0 + self.clients: dict[str, OAuthClientInformationFull] = {} + self.codes: dict[str, AuthorizationCode] = {} + self.refresh_tokens: dict[str, RefreshToken] = {} + self.access_tokens: dict[str, AccessToken] = {} + + def _next_expires_in(self) -> int: + self._tokens_issued += 1 + if self._issue_expired_first and self._tokens_issued == 1: + return -_TOKEN_LIFETIME_SECONDS + return _TOKEN_LIFETIME_SECONDS + + def mint_access_token(self, *, client_id: str, scopes: list[str], resource: str | None = None) -> str: + """Mint and store an access token, returning its value. + + Used by the auth-code and refresh exchanges and by the M2M `/token` shim. 
The + server-side `expires_at` is always in the future regardless of `issue_expired_first`, + which only affects what the client is told. + """ + access = f"access_{secrets.token_hex(16)}" + self.access_tokens[access] = AccessToken( + token=access, + client_id=client_id, + scopes=scopes, + expires_at=int(time.time()) + _TOKEN_LIFETIME_SECONDS, + resource=resource, + ) + return access + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + assert client_info.client_id is not None + self.clients[client_info.client_id] = client_info + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """Mint an authorization code immediately and return the redirect carrying it. + + A real provider would interpose user consent here; the test provider grants + unconditionally so the headless redirect handler can complete the flow in-process. + When `deny_authorize` is set, returns an `error=access_denied` redirect instead. + """ + assert client.client_id is not None + if self._deny_authorize: + return construct_redirect_uri( + str(params.redirect_uri), error="access_denied", error_description="user denied", state=params.state + ) + code = AuthorizationCode( + code=f"code_{secrets.token_hex(16)}", + client_id=client.client_id, + scopes=params.scopes or self._default_scopes, + expires_at=time.time() + 300, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, + redirect_uri_provided_explicitly=params.redirect_uri_provided_explicitly, + resource=params.resource, + ) + self.codes[code.code] = code + # `iss` is RFC 9207's authorization-response issuer identifier — an extra parameter many + # real authorization servers send. 
Including it on every success redirect proves the + # client tolerates unrecognized callback parameters (RFC 6749 §4.1.2 MUST) by virtue of + # every flow test passing unchanged. + return construct_redirect_uri(str(params.redirect_uri), code=code.code, state=params.state, iss=self._issuer) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self.codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + """Mint an access token and a refresh token for a valid authorization code, then consume the code.""" + assert client.client_id is not None + access = self.mint_access_token( + client_id=client.client_id, scopes=authorization_code.scopes, resource=authorization_code.resource + ) + refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[refresh] = RefreshToken( + token=refresh, + client_id=client.client_id, + scopes=authorization_code.scopes, + ) + del self.codes[authorization_code.code] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(authorization_code.scopes), + refresh_token=refresh, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + if self._reject_all_tokens: + return None + return self.access_tokens.get(token) + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None: + return self.refresh_tokens.get(refresh_token) + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: list[str] + ) -> OAuthToken: + """Mint a new access token and rotate the refresh token, consuming the old one.""" + assert client.client_id is not None + if self._fail_next_refresh: + self._fail_next_refresh = False + raise 
TokenError(error="invalid_grant", error_description="refresh denied by harness") + access = self.mint_access_token(client_id=client.client_id, scopes=scopes) + new_refresh = f"refresh_{secrets.token_hex(16)}" + self.refresh_tokens[new_refresh] = RefreshToken(token=new_refresh, client_id=client.client_id, scopes=scopes) + del self.refresh_tokens[refresh_token.token] + return OAuthToken( + access_token=access, + token_type="Bearer", + expires_in=self._next_expires_in(), + scope=" ".join(scopes), + refresh_token=new_refresh, + ) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Implemented when the bearer/lifecycle tests first exercise revocation.""" + raise NotImplementedError diff --git a/tests/interaction/auth/test_as_handlers.py b/tests/interaction/auth/test_as_handlers.py new file mode 100644 index 0000000000..5cb4e92d86 --- /dev/null +++ b/tests/interaction/auth/test_as_handlers.py @@ -0,0 +1,300 @@ +"""Error-plane behaviour of the SDK's bundled OAuth authorization-server handlers. + +The end-to-end OAuth tests prove the handlers' happy paths; these tests drive the same +mounted authorization server directly with raw httpx so the assertions are the HTTP +semantics (status, redirect target, error body, headers) the OAuth RFCs mandate. Almost +every behaviour here is enforced by the SDK's own handlers; where the pinned output +deviates from the RFC, the manifest entry carries the divergence. 
+""" + +import base64 +import hashlib +import secrets +from collections.abc import AsyncIterator +from urllib.parse import parse_qs, urlsplit + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import ProviderTokenVerifier +from mcp.shared.auth import OAuthClientInformationFull +from tests.interaction._connect import mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import REDIRECT_URI, auth_settings, oauth_client_metadata +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +async def as_app() -> AsyncIterator[tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider]]: + """Co-host the SDK's authorization-server routes and yield a raw httpx client against them.""" + provider = InMemoryAuthorizationServerProvider() + settings = auth_settings() + async with mounted_app( + Server("guarded"), + auth=settings, + token_verifier=ProviderTokenVerifier(provider), + auth_server_provider=provider, + ) as (http, _): + yield http, provider + + +def _pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair the same way the SDK client does.""" + verifier = secrets.token_urlsafe(48)[:64] + challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") + return verifier, challenge + + +async def _register_client(http: httpx.AsyncClient) -> OAuthClientInformationFull: + """Dynamically register a client and return its full credentials.""" + response = await http.post("/register", content=oauth_client_metadata().model_dump_json()) + assert response.status_code == 201 + return OAuthClientInformationFull.model_validate_json(response.content) + + +async def _mint_code(http: httpx.AsyncClient) -> tuple[OAuthClientInformationFull, str, str]: + """Register a client, complete a valid authorize step, and 
return (client_info, code, verifier).""" + client_info = await _register_client(http) + assert client_info.client_id is not None + verifier, challenge = _pkce_pair() + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "code_challenge": challenge, + "code_challenge_method": "S256", + "state": "s", + }, + follow_redirects=False, + ) + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + code = parse_qs(redirect.query)["code"][0] + return client_info, code, verifier + + +def _token_form(client_info: OAuthClientInformationFull, **overrides: str) -> dict[str, str]: + """Build the form body for an authorization-code token request, with the defaults a real client would send.""" + assert client_info.client_id is not None + assert client_info.client_secret is not None + form = { + "grant_type": "authorization_code", + "client_id": client_info.client_id, + "client_secret": client_info.client_secret, + "redirect_uri": REDIRECT_URI, + } + form.update(overrides) + return form + + +@requirement("hosting:auth:as:authorize-requires-pkce") +async def test_authorize_without_a_code_challenge_is_rejected_with_invalid_request( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request omitting `code_challenge` is redirected back with `error=invalid_request`. + + PKCE is mandatory: the bundled authorize handler models `code_challenge` as a required field, so + a code without a stored challenge can never be issued. That makes the PKCE-downgrade attack (a + token request carrying a verifier for a code minted without a challenge) structurally impossible + through these handlers, so no separate downgrade-guard test is needed. 
+ """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": REDIRECT_URI, + "state": "abc", + }, + follow_redirects=False, + ) + + assert response.status_code == 302 + redirect = urlsplit(response.headers["location"]) + assert f"{redirect.scheme}://{redirect.netloc}{redirect.path}" == REDIRECT_URI + params = parse_qs(redirect.query) + assert params["error"] == ["invalid_request"] + assert params["state"] == ["abc"] + assert "code_challenge" in params["error_description"][0] + + +@requirement("hosting:auth:as:verifier-mismatch") +async def test_a_mismatched_code_verifier_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `code_verifier` does not hash to the stored challenge is rejected.""" + http, _ = as_app + client_info, code, _ = await _mint_code(http) + + response = await http.post("/token", data=_token_form(client_info, code=code, code_verifier="0" * 64)) + + assert response.status_code == 400 + assert response.json() == snapshot({"error": "invalid_grant", "error_description": "incorrect code_verifier"}) + + +@requirement("hosting:auth:as:code-single-use") +async def test_reusing_an_authorization_code_is_rejected_with_invalid_grant( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorization code can be exchanged exactly once; a second exchange is `invalid_grant`. + + The handler does not track used codes itself: it returns `invalid_grant` whenever the provider's + `load_authorization_code` returns None, and the in-memory provider deletes the code on first + exchange. The test proves the combination enforces single-use; a provider that did not consume + codes would not get this guarantee from the handler. 
+ """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + first = await http.post("/token", data=form) + assert first.status_code == 200 + assert first.json()["token_type"] == "Bearer" + + second = await http.post("/token", data=form) + assert second.status_code == 400 + assert second.json() == snapshot( + {"error": "invalid_grant", "error_description": "authorization code does not exist"} + ) + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_a_redirect_uri_differing_from_authorize_is_rejected_at_the_token_endpoint( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A token exchange whose `redirect_uri` differs from the one used at authorize is rejected. + + This is the security-critical half of redirect-URI binding: a code intercepted via redirect + substitution cannot be redeemed because the attacker cannot reproduce the original authorize + redirect URI at the token endpoint. RFC 6749 §5.2 specifies `invalid_grant` for this case; + the SDK returns `invalid_request` (see the divergence on the requirement). The rejection + itself is the security property and is correct. 
+ """ + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + + response = await http.post( + "/token", + data=_token_form(client_info, code=code, code_verifier=verifier, redirect_uri=f"{REDIRECT_URI}/different"), + ) + + assert response.status_code == 400 + assert response.json() == snapshot( + { + "error": "invalid_request", + "error_description": "redirect_uri did not match the one used when creating auth code", + } + ) + + +@requirement("hosting:auth:as:token-cache-headers") +async def test_token_responses_carry_cache_control_no_store( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Every token-endpoint response (success and error) carries `Cache-Control: no-store`.""" + http, _ = as_app + client_info, code, verifier = await _mint_code(http) + form = _token_form(client_info, code=code, code_verifier=verifier) + + success = await http.post("/token", data=form) + assert success.status_code == 200 + assert success.headers["cache-control"] == "no-store" + assert success.headers["pragma"] == "no-cache" + + failure = await http.post("/token", data=form) + assert failure.status_code == 400 + assert failure.headers["cache-control"] == "no-store" + assert failure.headers["pragma"] == "no-cache" + + +@requirement("hosting:auth:as:register-error-response") +async def test_registration_with_invalid_metadata_is_rejected_with_400( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """Invalid client metadata at the registration endpoint returns 400 with an RFC 7591 error body.""" + http, _ = as_app + + malformed = await http.post("/register", json={"redirect_uris": ["not-a-url"]}) + assert malformed.status_code == 400 + assert malformed.json()["error"] == "invalid_client_metadata" + + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + + no_auth_code = await http.post("/register", json=body | {"grant_types": ["refresh_token"]}) + assert 
no_auth_code.status_code == 400 + assert no_auth_code.json() == snapshot( + {"error": "invalid_client_metadata", "error_description": "grant_types must include 'authorization_code'"} + ) + + bad_scope = await http.post("/register", json=body | {"scope": "forbidden"}) + assert bad_scope.status_code == 400 + body = bad_scope.json() + assert body["error"] == "invalid_client_metadata" + # The description embeds a set difference whose ordering is not stable, so assert the prefix. + assert body["error_description"].startswith("Requested scopes are not valid: ") + + +@requirement("hosting:auth:as:redirect-uri-binding") +async def test_authorize_with_an_unregistered_redirect_uri_is_rejected_directly( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """An authorize request naming an unregistered `redirect_uri` returns 400 without redirecting to it. + + The security property is that the authorization server never redirects to an unvalidated URI: + the response is a direct JSON error to the user agent, not a 302 to the attacker's host. 
+ """ + http, _ = as_app + client_info = await _register_client(http) + assert client_info.client_id is not None + _, challenge = _pkce_pair() + + response = await http.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info.client_id, + "redirect_uri": "http://127.0.0.1:8000/evil", + "code_challenge": challenge, + "code_challenge_method": "S256", + }, + follow_redirects=False, + ) + + assert response.status_code == 400 + assert "location" not in response.headers + body = response.json() + assert body["error"] == "invalid_request" + assert "not registered" in body["error_description"] + + +@requirement("hosting:auth:as:redirect-uri-scheme") +async def test_a_non_loopback_http_redirect_uri_is_accepted_at_registration( + as_app: tuple[httpx.AsyncClient, InMemoryAuthorizationServerProvider], +) -> None: + """A registration carrying a non-HTTPS, non-loopback redirect URI is accepted. + + The spec requires every redirect URI to be either HTTPS or a loopback host; the bundled + registration handler does not enforce this and registers `http://evil.example/callback` + successfully. See the divergence on the requirement. + """ + http, provider = as_app + body = oauth_client_metadata().model_dump(mode="json", exclude_none=True) + body["redirect_uris"] = ["http://evil.example/callback"] + + response = await http.post("/register", json=body) + + assert response.status_code == 201 + info = OAuthClientInformationFull.model_validate_json(response.content) + assert [str(u) for u in (info.redirect_uris or [])] == ["http://evil.example/callback"] + assert info.client_id in provider.clients diff --git a/tests/interaction/auth/test_authorize_token.py b/tests/interaction/auth/test_authorize_token.py new file mode 100644 index 0000000000..cb8524c097 --- /dev/null +++ b/tests/interaction/auth/test_authorize_token.py @@ -0,0 +1,399 @@ +"""Authorization-request, token-request, and PKCE wire-level invariants of the SDK's OAuth client. 
+ +Every test connects a real `Client` end to end via `connect_with_oauth`; the assertions are on +the parsed authorize URL and the recorded `/token` form body, because those wire shapes are what +the spec mandates and `Client` cannot observe them. The recording uses `record_requests`, which +snapshots each request at send time so the auth flow's in-place header mutation on retry never +affects what was captured for the first attempt. + +Tests #1/#2/#4/#5 share one `recorded_oauth_flow` fixture (one connect, several disjoint +assertions on its recording); the others connect fresh because each needs a different harness +configuration. +""" + +import base64 +import hashlib +import json +import re +from collections.abc import AsyncIterator +from dataclasses import dataclass +from urllib.parse import parse_qsl, quote, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + HeadlessOAuth, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + first_challenge_shim, + record_requests, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def authorize_params(authorize_url: str) -> 
dict[str, str]: + """Parse the authorize URL's query string into a flat dict (one value per key).""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + """Filter recorded requests by method and exact path.""" + return [r for r in recorded if r.method == method and r.path == path] + + +@dataclass +class RecordedFlow: + """One completed OAuth connect: every recorded request, plus the parsed authorize URL params.""" + + requests: list[RecordedRequest] + authorize_url: str + + @property + def authorize(self) -> dict[str, str]: + return authorize_params(self.authorize_url) + + @property + def token_request(self) -> RecordedRequest: + token_posts = find(self.requests, "POST", "/token") + assert len(token_posts) == 1 + return token_posts[0] + + +@pytest.fixture +async def recorded_oauth_flow() -> AsyncIterator[RecordedFlow]: + """Run one full OAuth connect with default configuration and yield its recorded wire traffic. + + `valid_scopes` includes `offline_access` so the AS metadata advertises it and the SDK's + SEP-2207 auto-append (and the resulting `prompt=consent`) is exercised; `required_scopes` + stays at `["mcp"]` so the issued token still passes the bearer middleware. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "offline_access"]) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, settings=settings, on_request=on_request) as ( + client, + headless, + ): + await client.list_tools() + + assert headless.authorize_url is not None + yield RecordedFlow(requests=recorded, authorize_url=headless.authorize_url) + + +@requirement("client-auth:pkce:s256") +@requirement("client-auth:resource-parameter") +@requirement("client-auth:authorize:offline-access-consent") +async def test_the_authorize_url_carries_s256_pkce_and_the_resource_indicator( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every spec-mandated parameter appears on the authorize URL with the right value. + + The full key set is snapshotted so a parameter added or dropped fails the test. The + `code_challenge` length bound is the RFC 7636 §4.2 grammar; an S256 challenge is in + practice always 43 characters, so the upper bound is never approached. + """ + params = recorded_oauth_flow.authorize + + assert sorted(params) == snapshot( + [ + "client_id", + "code_challenge", + "code_challenge_method", + "prompt", + "redirect_uri", + "resource", + "response_type", + "scope", + "state", + ] + ) + assert params["response_type"] == "code" + assert params["code_challenge_method"] == "S256" + assert 43 <= len(params["code_challenge"]) <= 128 + # The exact resource value depends on canonical-URI normalisation (a spec ambiguity); pin + # the stable prefix so the test does not lock in a trailing-slash decision. 
+ assert params["resource"].startswith(BASE_URL) + assert params["state"] != "" + + assert params["scope"].split(" ") == snapshot(["mcp", "offline_access"]) + assert params["prompt"] == "consent" + + +@requirement("client-auth:pkce:s256") +async def test_the_code_verifier_on_the_token_request_hashes_to_the_code_challenge( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The PKCE verifier sent on /token is the S256 pre-image of the challenge sent on /authorize. + + The verifier is also checked against RFC 7636 §4.1's length and `unreserved` charset. + """ + challenge = recorded_oauth_flow.authorize["code_challenge"] + verifier = form_body(recorded_oauth_flow.token_request)["code_verifier"] + + assert re.fullmatch(r"[A-Za-z0-9._~-]{43,128}", verifier) + assert base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=") == challenge + + +@requirement("client-auth:state:verify") +async def test_a_mismatched_state_on_the_callback_aborts_the_flow() -> None: + """A callback whose state does not match the value sent on /authorize raises and stops the flow. + + The auth flow runs inside the streamable-HTTP client's task group, so the `OAuthFlowError` + reaches the test wrapped in nested single-element exception groups; `pytest.RaisesGroup` + asserts the leaf type and the SDK-authored message prefix (the full message embeds two + random tokens). + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth(state_override="wrong-state") + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^State parameter mismatch:"), flatten_subgroups=True + ): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. 
+ await connect_with_oauth(server, provider=provider, headless=headless).__aenter__() + + +@requirement("client-auth:resource-parameter") +async def test_the_authorization_code_token_request_carries_grant_type_code_redirect_and_resource( + recorded_oauth_flow: RecordedFlow, +) -> None: + """The /token form body has exactly the auth-code grant fields, with redirect_uri and resource matching /authorize. + + `client_secret` is present because the SDK's dynamic-registration handler issues a secret + and the client defaults to `client_secret_post`. + """ + token_req = recorded_oauth_flow.token_request + body = form_body(token_req) + + assert sorted(body) == snapshot( + ["client_id", "client_secret", "code", "code_verifier", "grant_type", "redirect_uri", "resource"] + ) + assert body["grant_type"] == "authorization_code" + assert body["code"] != "" + assert body["redirect_uri"] == recorded_oauth_flow.authorize["redirect_uri"] + assert body["resource"] == recorded_oauth_flow.authorize["resource"] + assert token_req.headers["content-type"] == "application/x-www-form-urlencoded" + + +@requirement("client-auth:bearer-header:every-request") +async def test_every_mcp_request_after_auth_carries_the_bearer_header_and_never_a_query_token( + recorded_oauth_flow: RecordedFlow, +) -> None: + """Every MCP request after the flow has `Authorization: Bearer ...` and never `?access_token=`. + + The first /mcp POST is the unauthenticated trigger and is asserted to carry no Authorization + header; that assertion is only meaningful because the recording snapshots requests at send + time (the SDK mutates the same request object in place for the retry). 
+ """ + mcp_posts = find(recorded_oauth_flow.requests, "POST", "/mcp") + assert len(mcp_posts) >= 3 + + assert "authorization" not in mcp_posts[0].headers + for r in mcp_posts[1:]: + assert r.headers["authorization"].startswith("Bearer ") + assert r.headers["authorization"] != "Bearer " + assert "access_token" not in dict(r.url.params) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_a_client_with_a_secret_authenticates_the_token_request_with_http_basic() -> None: + """A `client_secret_basic` client sends URL-encoded credentials in HTTP Basic, not the body. + + Credentials are URL-encoded before base64 per RFC 6749 §2.3.1; the secret contains `/` so + the encoding is observable. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="cid", + client_secret="s/cret", + token_endpoint_auth_method="client_secret_basic", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage = InMemoryTokenStorage(client_info=client_info) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + [token_req] = find(recorded, "POST", "/token") + + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == f"{quote('cid', safe='')}:{quote('s/cret', safe='')}" + assert "client_secret" not in form_body(token_req) + + +@requirement("client-auth:token-endpoint-auth-method") +async def test_the_registered_auth_method_is_used_regardless_of_as_metadata_advertised_methods() -> None: + """The token-endpoint auth method comes from the registered client info, not from AS metadata. 
+ + The shim serves AS metadata advertising only `client_secret_basic`; the client dynamically + registers and the SDK's registration handler issues `client_secret_post`. The client uses + `client_secret_post` (secret in the body, no Basic header) because the SDK reads the + registered `token_endpoint_auth_method`, not `token_endpoint_auth_methods_supported`. Other + SDKs (TypeScript, Go) do consult the AS metadata; this test pins where the python SDK's + selection point lives. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve), on_request=on_request + ) as (client, _): + await client.list_tools() + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content).get("token_endpoint_auth_method") is None + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert "client_secret" in body + assert body["client_secret"] != "" + assert "authorization" not in token_req.headers + + +@requirement("client-auth:scope-selection:priority") +async def test_scope_is_selected_from_the_www_authenticate_challenge_over_prm_metadata() -> None: + """When the 401 challenge carries `scope=`, that value is requested instead of the PRM scopes. 
+ + The SDK's bearer middleware never emits `scope=` in WWW-Authenticate (see the divergence + on `hosting:auth:scope-403`), so the test supplies the first 401 itself via + `first_challenge_shim` and disables token verification so the post-auth retry succeeds + regardless of the granted scope. PRM advertises `["from-prm"]` (it mirrors + `required_scopes`); the challenge says `from-header`; the authorize URL must carry + `from-header`. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(default_scopes=["from-header"]) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["from-prm"], valid_scopes=["from-header", "from-prm"]) + challenge = f'Bearer scope="from-header", resource_metadata="{BASE_URL}{PRM_PATH}"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + settings=settings, + verify_tokens=False, + app_shim=first_challenge_shim(challenge), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["scope"] == "from-header" + + [register] = find(recorded, "POST", "/register") + assert json.loads(register.content)["scope"] == "from-header" + + +@requirement("client-auth:pkce:refuse-if-unsupported") +async def test_pkce_is_still_sent_when_as_metadata_omits_code_challenge_methods_supported() -> None: + """AS metadata without `code_challenge_methods_supported` does not stop the client sending PKCE. + + The spec says the client MUST refuse to proceed in this case; the SDK proceeds and the flow + completes. See the divergence on the requirement. 
+ """ + override = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + ) + assert override.code_challenge_methods_supported is None + serve = {ASM_PATH: override.model_dump_json(exclude_none=True).encode()} + + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, app_shim=lambda app: shimmed_app(app, serve=serve) + ) as (client, headless): + result = await client.list_tools() + + assert headless.authorize_url is not None + params = authorize_params(headless.authorize_url) + assert params["code_challenge_method"] == "S256" + assert params["code_challenge"] != "" + assert result.tools[0].name == "echo" + + +@requirement("client-auth:authorize:error-surfaces") +async def test_an_authorize_error_on_the_callback_aborts_the_flow_before_the_token_request() -> None: + """An `error=` redirect from /authorize aborts the flow with no /token request issued. + + The SDK's callback contract is `() -> (code, state)` with no error form, so the failure is + observed as an empty code reaching the SDK and `OAuthFlowError("No authorization code + received")` being raised. The actual `error` value from the redirect is not surfaced to the + caller; that gap is noted in the manifest. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(deny_authorize=True) + server = Server("guarded", on_list_tools=list_tools) + headless = HeadlessOAuth() + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^No authorization code received$"), flatten_subgroups=True + ): + await connect_with_oauth(server, provider=provider, headless=headless, on_request=on_request).__aenter__() + + assert headless.error == "access_denied" + assert find(recorded, "POST", "/token") == [] diff --git a/tests/interaction/auth/test_bearer.py b/tests/interaction/auth/test_bearer.py new file mode 100644 index 0000000000..341a8e0db9 --- /dev/null +++ b/tests/interaction/auth/test_bearer.py @@ -0,0 +1,189 @@ +"""Resource-server bearer-token gate: status codes and `WWW-Authenticate` for each token shape. + +These tests mount only the resource-server side of the auth wiring (a `StaticTokenVerifier` +seeded with hand-built tokens, no authorization-server provider) and speak raw HTTP, since +every assertion is about HTTP semantics the SDK `Client` cannot observe: the 401/403 status, +the `WWW-Authenticate` header structure, and that a wrong-audience token reaches the MCP +endpoint behind the gate. The flow side of the same 401 is `test_flow.py`'s flagship test. 
+""" + +import time +from collections.abc import AsyncIterator + +import httpx +import pytest +from inline_snapshot import snapshot + +from mcp.server import Server +from mcp.server.auth.provider import AccessToken +from mcp.types import JSONRPCResponse +from tests.interaction._connect import base_headers, initialize_body, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import StaticTokenVerifier, auth_settings + +pytestmark = pytest.mark.anyio + +REQUIRED_SCOPE = "mcp:read" +RESOURCE_METADATA_URL = "http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp" + +_FUTURE = int(time.time()) + 3600 +_PAST = int(time.time()) - 3600 + +TOKENS = { + "tok-valid": AccessToken(token="tok-valid", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_FUTURE), + "tok-expired": AccessToken(token="tok-expired", client_id="c", scopes=[REQUIRED_SCOPE], expires_at=_PAST), + "tok-noscope": AccessToken(token="tok-noscope", client_id="c", scopes=["other:thing"], expires_at=_FUTURE), + "tok-wrong-aud": AccessToken( + token="tok-wrong-aud", + client_id="c", + scopes=[REQUIRED_SCOPE], + expires_at=_FUTURE, + resource="https://other.example/mcp", + ), +} + + +@pytest.fixture +async def protected() -> AsyncIterator[httpx.AsyncClient]: + """A bearer-gated streamable-HTTP app (resource server only) on the in-process bridge.""" + server = Server("rs") + settings = auth_settings(required_scopes=[REQUIRED_SCOPE]) + async with mounted_app(server, auth=settings, token_verifier=StaticTokenVerifier(TOKENS)) as (http, _): + yield http + + +async def post_mcp( + http: httpx.AsyncClient, *, bearer: str | None = None, query: dict[str, str] | None = None +) -> httpx.Response: + """POST an initialize body to `/mcp`, optionally with a bearer token and/or a query string.""" + headers = base_headers() + if bearer is not None: + headers["authorization"] = f"Bearer {bearer}" + return await http.post("/mcp", headers=headers, params=query, 
json=initialize_body()) + + +def parse_www_authenticate(value: str) -> dict[str, str]: + """Parse a `Bearer k="v", k="v"` challenge into a dict. + + The SDK emits each parameter exactly once, comma-space separated, with double-quoted + values that contain no quotes themselves; this helper relies on that and would fail + visibly if the format changed. + """ + scheme, _, params = value.partition(" ") + assert scheme == "Bearer" + return {key: quoted.strip('"') for key, _, quoted in (pair.partition("=") for pair in params.split(", "))} + + +@requirement("hosting:auth:missing-401") +async def test_a_request_with_no_authorization_header_is_challenged_with_resource_metadata( + protected: httpx.AsyncClient, +) -> None: + """No `Authorization` header → 401 with a `WWW-Authenticate` carrying `resource_metadata`. + + The snapshot pins current behaviour: the SDK collapses the no-header, unknown-token, and + expired-token cases into one challenge (`error="invalid_token"`, no `scope` parameter). The + spec says the discovery-time challenge SHOULD include `scope` and RFC 6750 says the + no-credentials case SHOULD NOT carry an error code; both gaps are recorded as the divergence + on this requirement. Asserting the dict equals an exact key set also pins that no parameter + appears twice. 
+ """ + response = await post_mcp(protected) + + assert response.status_code == 401 + assert response.headers["www-authenticate"] == snapshot( + 'Bearer error="invalid_token", error_description="Authentication required", ' + 'resource_metadata="http://127.0.0.1:8000/.well-known/oauth-protected-resource/mcp"' + ) + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert response.json() == snapshot({"error": "invalid_token", "error_description": "Authentication required"}) + + +@requirement("hosting:auth:invalid-401") +async def test_an_unrecognized_bearer_token_is_answered_401_invalid_token(protected: httpx.AsyncClient) -> None: + """A token the verifier does not recognize is answered 401 `invalid_token`. + + The challenge is identical to the no-header case (the backend returns `None` for both); the + missing `scope` parameter is the recorded divergence on this requirement. + """ + response = await post_mcp(protected, bearer="tok-unknown") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"]) == { + "error": "invalid_token", + "error_description": "Authentication required", + "resource_metadata": RESOURCE_METADATA_URL, + } + + +@requirement("hosting:auth:expired-401") +async def test_an_expired_token_is_answered_401(protected: httpx.AsyncClient) -> None: + """A token whose `expires_at` is in the past is answered 401 `invalid_token`. + + The expiry check is the bearer backend's, against the wall clock; the test seeds a concrete + past timestamp so no time mocking is involved. The missing `scope` parameter is the recorded + divergence on this requirement. 
+ """ + response = await post_mcp(protected, bearer="tok-expired") + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" + + +@requirement("hosting:auth:scope-403") +async def test_a_token_missing_a_required_scope_is_answered_403_insufficient_scope_without_a_scope_param( + protected: httpx.AsyncClient, +) -> None: + """A token lacking the required scope is answered 403 `insufficient_scope`, with no `scope` parameter. + + The spec's runtime-insufficient-scope guidance says the challenge SHOULD include `scope` + naming the required scope; the SDK never emits it, recorded as the divergence on this + requirement. The SDK client reads `scope` from this header to drive step-up, so the gap is + a resource-server/client asymmetry. + """ + response = await post_mcp(protected, bearer="tok-noscope") + + assert response.status_code == 403 + parsed = parse_www_authenticate(response.headers["www-authenticate"]) + assert parsed == { + "error": "insufficient_scope", + "error_description": f"Required scope: {REQUIRED_SCOPE}", + "resource_metadata": RESOURCE_METADATA_URL, + } + assert "scope" not in parsed + + +@requirement("hosting:auth:aud-validation") +async def test_a_token_with_a_mismatched_audience_is_accepted(protected: httpx.AsyncClient) -> None: + """A token whose `resource` does not match the server's resource identifier is accepted. + + The spec mandates the resource server validate the token's audience; the bearer backend + never inspects `AccessToken.resource`, so the request passes the gate and the MCP endpoint + serves it. This pins current behaviour with the divergence recorded on the requirement. + """ + response = await post_mcp(protected, bearer="tok-wrong-aud") + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + # The body is finite SSE: a result event followed by stream close. 
Pull the JSON-RPC response + # out of the buffered text to prove the MCP endpoint actually answered the initialize request. + [data] = [line.removeprefix("data: ") for line in response.text.splitlines() if line.startswith("data: ")] + assert "protocolVersion" in JSONRPCResponse.model_validate_json(data).result + + +@requirement("hosting:auth:query-token-ignored") +async def test_an_access_token_in_the_query_string_is_not_accepted(protected: httpx.AsyncClient) -> None: + """A valid token presented in the URI query string is treated as no authentication. + + The bearer backend reads only the `Authorization` header, so `?access_token=...` is never + consulted; the request is treated as unauthenticated and answered 401. This satisfies, by + absence, the security best-practice that resource servers must not accept query-string + tokens. + """ + response = await post_mcp(protected, query={"access_token": "tok-valid"}) + + assert response.status_code == 401 + assert parse_www_authenticate(response.headers["www-authenticate"])["error"] == "invalid_token" diff --git a/tests/interaction/auth/test_discovery.py b/tests/interaction/auth/test_discovery.py new file mode 100644 index 0000000000..68c33c8a2d --- /dev/null +++ b/tests/interaction/auth/test_discovery.py @@ -0,0 +1,333 @@ +"""Protected-resource and authorization-server metadata discovery, end to end. + +Every client-side test connects a real `Client` via `connect_with_oauth` and asserts on the +recorded request paths the discovery probes produced; the discovery URL ordering is a wire +detail `Client` cannot observe directly but the recording can. Tests that need a metadata +endpoint to 404 or return alternate content wrap the SDK's app in `shimmed_app` while leaving +the real authorize and token endpoints behind it, so the rest of the flow runs unaltered. 
+ +The two server-side tests (#5, #6) drive raw httpx against `mounted_app` because their +assertions are the metadata response bodies and headers, which `Client` does not surface. +""" + +import json + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl + +from mcp import types +from mcp.client.auth import OAuthFlowError, OAuthRegistrationError +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata +from mcp.types import ListToolsResult, Tool +from tests.interaction._connect import BASE_URL, mounted_app +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + RecordedRequest, + auth_settings, + connect_with_oauth, + metadata_body, + record_requests, + shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH_SUFFIXED = "/.well-known/oauth-protected-resource/mcp" +PRM_ROOT = "/.well-known/oauth-protected-resource" +ASM_ROOT = "/.well-known/oauth-authorization-server" +OIDC_ROOT = "/.well-known/openid-configuration" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="probe", input_schema={"type": "object"})]) + + +def discovery_gets(recorded: list[RecordedRequest]) -> list[str]: + """Return the well-known GET paths in recorded order, ignoring everything else.""" + return [r.path for r in recorded if r.method == "GET" and "/.well-known/" in r.path] + + +def real_asm() -> OAuthMetadata: + """Build an authorization-server metadata document pointing at the real co-hosted endpoints.""" + return OAuthMetadata( + issuer=AnyHttpUrl(BASE_URL), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + 
scopes_supported=["mcp"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + ) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_uses_the_resource_metadata_url_from_www_authenticate() -> None: + """The first protected-resource probe is the URL the 401's `WWW-Authenticate` header supplied. + + With co-hosted defaults the header carries the path-suffixed well-known URL; the client + fetches that one first and, because it succeeds, never falls back. The single-probe + sequence proves priority 1. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, on_request=on_request) as (client, _): + await client.list_tools() + + assert discovery_gets(recorded) == snapshot([PRM_PATH_SUFFIXED, ASM_ROOT]) + assert (recorded[0].method, recorded[0].path) == ("POST", "/mcp") + assert (recorded[1].method, recorded[1].path) == ("GET", PRM_PATH_SUFFIXED) + + +@requirement("client-auth:prm-discovery:fallback-order") +async def test_prm_discovery_falls_back_from_path_well_known_to_root_on_404() -> None: + """When the path-suffixed PRM well-known 404s, the client falls back to the root well-known. + + The exact GET count is not asserted: the WWW-Authenticate URL equals the path well-known + here, so the SDK probes it twice (once as priority 1, once as priority 2) before reaching + root. Asserting "path before root, root reached, then the flow proceeds" pins the spec + invariant; the duplicate probe is an implementation detail. The served PRM body carries an + unrecognized field to prove the client's parser ignores unknown members (RFC 9728 §3.2). 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim( + not_found=frozenset({PRM_PATH_SUFFIXED}), + serve={PRM_ROOT: metadata_body(prm, x_unknown_extension="ignored")}, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known.index(PRM_PATH_SUFFIXED) < well_known.index(PRM_ROOT) + assert any(r.path == "/authorize" for r in recorded) + + +@requirement("client-auth:prm-discovery:no-prm-fallback") +async def test_when_every_prm_probe_fails_the_client_discovers_as_metadata_at_the_server_origin() -> None: + """When every protected-resource metadata probe 404s, the client falls back to the legacy path. + + The legacy 2025-03-26 behaviour: with no PRM document available, treat the MCP server's + origin as the authorization server and fetch its `/.well-known/oauth-authorization-server` + directly. The real co-hosted ASM endpoint is at exactly that location, so the flow completes. + The recorded sequence shows both PRM well-known paths probed (and failed) before ASM_ROOT. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + app_shim = shim(not_found=frozenset({PRM_PATH_SUFFIXED, PRM_ROOT})) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + result = await client.list_tools() + + well_known = discovery_gets(recorded) + assert PRM_PATH_SUFFIXED in well_known + assert PRM_ROOT in well_known + assert well_known[-1] == ASM_ROOT + assert all(well_known.index(prm) < well_known.index(ASM_ROOT) for prm in (PRM_PATH_SUFFIXED, PRM_ROOT)) + assert result.tools[0].name == "probe" + + +@requirement("client-auth:dcr:registration-error-surfaces") +async def test_a_400_from_the_registration_endpoint_surfaces_as_a_registration_error() -> None: + """A 400 from `/register` surfaces as `OAuthRegistrationError` carrying the server's body. + + The shim makes `/register` return RFC 7591's `invalid_client_metadata`; the SDK reads the + body and raises with the status and text in the message, before any authorize or token + request is made. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + error_body = json.dumps({"error": "invalid_client_metadata", "error_description": "no"}).encode() + app_shim = shim(serve={"/register": (400, error_body)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthRegistrationError, match=r"^Registration failed: 400 .*invalid_client_metadata"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:prm-resource-mismatch") +async def test_prm_with_a_mismatched_resource_aborts_the_flow_before_authorize() -> None: + """A PRM document whose `resource` does not cover the server URL aborts the flow. + + The shim serves PRM at the URL the WWW-Authenticate header supplies, but with a `resource` + on a different path; `check_resource_allowed` rejects it and `OAuthFlowError` is raised + before any authorize or token request is made. The error reaches the test wrapped in nested + single-element exception groups by the streamable-HTTP client's task group. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/other"), authorization_servers=[AnyHttpUrl(BASE_URL)] + ) + app_shim = shim(serve={PRM_PATH_SUFFIXED: metadata_body(prm)}) + + with anyio.fail_after(5): + with pytest.RaisesGroup( + pytest.RaisesExc(OAuthFlowError, match="^Protected resource .* does not match expected"), + flatten_subgroups=True, + ): + await connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request).__aenter__() + + assert [r.path for r in recorded if r.path in ("/authorize", "/token")] == [] + + +@requirement("client-auth:as-metadata-discovery:priority-order") +@pytest.mark.parametrize( + ("authorization_server", "not_found", "serve_at", "expected_order"), + [ + pytest.param( + f"{BASE_URL}/", + frozenset({ASM_ROOT}), + OIDC_ROOT, + [ASM_ROOT, OIDC_ROOT], + id="root-issuer", + ), + pytest.param( + f"{BASE_URL}/tenant", + frozenset({f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant"}), + "/tenant/.well-known/openid-configuration", + [f"{ASM_ROOT}/tenant", f"{OIDC_ROOT}/tenant", "/tenant/.well-known/openid-configuration"], + id="path-issuer", + ), + ], +) +async def test_as_metadata_discovery_falls_back_through_the_spec_endpoint_order( + authorization_server: str, not_found: frozenset[str], serve_at: str, expected_order: list[str] +) -> None: + """Authorization-server metadata is fetched at the spec's endpoints in the spec's order. + + The shim 404s every endpoint before the last so the recording proves each probe and its + position. For an issuer URL with no path the order is OAuth root then OIDC root; for an + issuer URL with a path component it is OAuth path-inserted, OIDC path-inserted, then OIDC + path-appended (the spec's three-endpoint MUST). 
The path-issuer case is driven by serving + a PRM whose `authorization_servers` carries the path; the SDK's own AS routes stay at root + (the served body points at the real `/authorize` and `/token`). The served bodies carry an + unrecognized field to prove the client's parser ignores unknown members (RFC 8414 §3.2). + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + prm = ProtectedResourceMetadata( + resource=AnyHttpUrl(f"{BASE_URL}/mcp"), authorization_servers=[AnyHttpUrl(authorization_server)] + ) + app_shim = shim( + not_found=not_found, + serve={ + PRM_PATH_SUFFIXED: metadata_body(prm), + serve_at: metadata_body(real_asm(), x_unknown_extension="ignored"), + }, + ) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim, on_request=on_request) as ( + client, + _, + ): + await client.list_tools() + + assert discovery_gets(recorded) == [PRM_PATH_SUFFIXED, *expected_order] + + +@requirement("hosting:auth:metadata-endpoints") +@requirement("hosting:auth:prm:authorization-servers-field") +async def test_the_prm_endpoint_serves_the_resource_url_and_at_least_one_authorization_server() -> None: + """The protected-resource metadata document the SDK serves identifies the resource and an authorization server. + + Also asserts the response is `application/json` (RFC 9728 §3.2) and that fields the SDK has + no value for are absent rather than null (`PydanticJSONResponse` serializes with + `exclude_none=True`, satisfying RFC 9728 §3.2's omit-zero-value rule). 
+ """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(PRM_PATH_SUFFIXED) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + document = json.loads(response.content) + assert "resource_documentation" not in document + assert "scopes_supported" in document + + metadata = ProtectedResourceMetadata.model_validate(document) + assert str(metadata.resource).rstrip("/") == f"{BASE_URL}/mcp" + assert len(metadata.authorization_servers) >= 1 + assert metadata.bearer_methods_supported == ["header"] + + +@requirement("hosting:auth:as-router") +async def test_as_metadata_advertises_authorize_token_registration_and_s256() -> None: + """The authorization-server metadata document the SDK serves names the required endpoints and S256.""" + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + + async with mounted_app(server, auth=auth_settings(), auth_server_provider=provider) as (http, _): + response = await http.get(ASM_ROOT) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + + metadata = OAuthMetadata.model_validate_json(response.content) + assert str(metadata.issuer).rstrip("/") == BASE_URL + assert str(metadata.authorization_endpoint) == f"{BASE_URL}/authorize" + assert str(metadata.token_endpoint) == f"{BASE_URL}/token" + assert str(metadata.registration_endpoint) == f"{BASE_URL}/register" + assert metadata.response_types_supported == ["code"] + assert metadata.code_challenge_methods_supported is not None + assert "S256" in metadata.code_challenge_methods_supported + + +@requirement("client-auth:as-metadata-discovery:issuer-validation") +async def test_as_metadata_with_a_mismatched_issuer_is_accepted_and_the_flow_proceeds() -> None: + """Authorization-server metadata whose 
`issuer` does not match the discovery URL is accepted. + + RFC 8414 §3.3 requires the client to reject the document; the SDK parses and uses it + without comparing `issuer` to the URL it was fetched from. See the divergence on the + requirement. The served body carries an unrecognized field as a fold-in proof of + unknown-field tolerance. + """ + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + metadata = real_asm() + metadata.issuer = AnyHttpUrl(f"{BASE_URL}/wrong-issuer") + app_shim = shim(serve={ASM_ROOT: metadata_body(metadata, x_unknown_extension="ignored")}) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, app_shim=app_shim) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "probe" diff --git a/tests/interaction/auth/test_flow.py b/tests/interaction/auth/test_flow.py new file mode 100644 index 0000000000..968fc5f980 --- /dev/null +++ b/tests/interaction/auth/test_flow.py @@ -0,0 +1,239 @@ +"""End-to-end OAuth authorization-code flow against the SDK's own server, fully in process. + +Auth is HTTP-only so these tests are not transport-parametrized; each connects via +`connect_with_oauth`, which co-hosts the SDK's authorization server, protected-resource +metadata, and bearer-gated MCP endpoint on one bridge-backed Starlette app and drives the +whole flow through one `httpx.AsyncClient` carrying the SDK's `OAuthClientProvider`. The +authorize redirect completes headlessly through the same bridge, so every request the flow +makes is observable via `on_request`. 
+""" + +import json +from collections import Counter +from urllib.parse import parse_qs, urlsplit + +import anyio +import httpx +import pytest +from inline_snapshot import snapshot +from pydantic import AnyUrl + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import get_access_token +from mcp.shared.auth import OAuthClientInformationFull +from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + auth_settings, + connect_with_oauth, + oauth_client_metadata, + shimmed_app, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider +from tests.interaction.transports._bridge import StreamingASGITransport + +pytestmark = pytest.mark.anyio + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})]) + + +@requirement("flow:oauth:authorization-code-roundtrip") +@requirement("client-auth:401-triggers-flow") +@requirement("hosting:auth:missing-401") +async def test_an_unauthenticated_request_is_challenged_then_the_full_oauth_flow_connects() -> None: + """Connecting to a bearer-gated server walks the full authorization-code flow and succeeds. + + Three requirements are proven by one connect: the flow runs end to end (authorization-code + roundtrip), it was triggered by a 401 on the first MCP request (401-triggers-flow), and + that 401 carried `resource_metadata` in `WWW-Authenticate` for discovery (missing-401). + The flagship test pins the recorded request sequence so the discovery → registration → + authorize → token → retry order is asserted explicitly. + + Steps the SDK is expected to perform: + 1. 
POST /mcp without a token → 401 with `WWW-Authenticate: Bearer resource_metadata=...`. + 2. GET the protected-resource metadata. + 3. GET the authorization-server metadata. + 4. POST /register (dynamic client registration). + 5. GET /authorize → 302 with code+state (completed by the headless redirect). + 6. POST /token (authorization-code exchange). + 7. Retry POST /mcp with `Authorization: Bearer <token>` → succeeds. + """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + headless, + ): + result = await client.list_tools() + + assert result == snapshot(ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object"})])) + assert headless.authorize_url is not None + + paths = [(r.method, r.url.path) for r in requests] + assert Counter(paths) == snapshot( + Counter( + { + ("POST", "/mcp"): 4, + ("GET", "/.well-known/oauth-protected-resource/mcp"): 1, + ("GET", "/.well-known/oauth-authorization-server"): 1, + ("POST", "/register"): 1, + ("GET", "/authorize"): 1, + ("POST", "/token"): 1, + ("GET", "/mcp"): 1, + ("DELETE", "/mcp"): 1, + } + ) + ) + + assert (requests[0].method, requests[0].url.path) == ("POST", "/mcp") + # The recorded Request objects are live references: the auth flow mutates the original + # request's headers in place when it adds the bearer token for the retry, so the first + # entry's headers cannot be used to assert "no Authorization on the first attempt". The + # path multiset above proving discovery happened is the evidence the first attempt was 401. + + # The first PRM discovery GET carries the protocol-version header (an SDK behaviour, not a + # spec requirement on discovery requests). 
+ prm_get = next(r for r in requests if r.url.path == "/.well-known/oauth-protected-resource/mcp") + assert prm_get.headers.get("mcp-protocol-version") == snapshot("2025-11-25") + + authorize = parse_qs(urlsplit(headless.authorize_url).query) + assert authorize["response_type"] == ["code"] + assert authorize["code_challenge_method"] == ["S256"] + assert authorize["client_id"][0] in provider.clients + + assert storage.tokens is not None + bearer = f"Bearer {storage.tokens.access_token}" + authed_mcp = [r for r in requests if r.url.path == "/mcp" and r.headers.get("authorization") == bearer] + assert len(authed_mcp) > 0 + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("hosting:auth:authinfo-propagates") +async def test_the_access_token_reaches_the_tool_handler_via_get_access_token() -> None: + """A tool handler reads the request's access token through `get_access_token()`.""" + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "whoami" + token = get_access_token() + assert token is not None + return CallToolResult(content=[TextContent(text=" ".join(token.scopes))]) + + server = Server("guarded", on_list_tools=list_tools, on_call_tool=call_tool) + provider = InMemoryAuthorizationServerProvider() + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider) as (client, _): + result = await client.call_tool("whoami", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="mcp")])) + + +@requirement("client-auth:pre-registration") +async def test_a_preregistered_client_skips_registration() -> None: + """A client whose storage already holds client info uses it instead of registering. + + The provider holds the same registration server-side so the authorize and token steps + accept it; the recorded requests prove no `/register` call was made. 
+ """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_info = OAuthClientInformationFull( + client_id="preregistered", + client_secret="s3cret", + token_endpoint_auth_method="client_secret_post", + redirect_uris=[AnyUrl(REDIRECT_URI)], + grant_types=["authorization_code", "refresh_token"], + scope="mcp", + ) + await provider.register_client(client_info) + storage.client_info = client_info + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=requests.append) as ( + client, + _, + ): + await client.list_tools() + + assert [r.url.path for r in requests].count("/register") == 0 + assert list(provider.clients) == ["preregistered"] + + +@requirement("client-auth:dcr") +async def test_the_dcr_request_carries_the_client_metadata() -> None: + """Dynamic registration sends the client's metadata and persists what the server issued. + + The body of the recorded `/register` POST carries the metadata the test supplied (with the + scope filled in from server discovery), and the server's issued client_id and secret are + persisted to storage and held by the provider. 
+ """ + requests: list[httpx.Request] = [] + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + client_metadata = oauth_client_metadata() + client_metadata.software_id = "interaction-test-suite" + + with anyio.fail_after(5): + async with connect_with_oauth( + server, provider=provider, storage=storage, client_metadata=client_metadata, on_request=requests.append + ) as (client, _): + await client.list_tools() + + register = next(r for r in requests if r.url.path == "/register") + assert register.headers["content-type"] == "application/json" + body = json.loads(register.content) + assert body == snapshot( + { + "redirect_uris": ["http://127.0.0.1:8000/oauth/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "mcp", + "client_name": "interaction-suite", + "software_id": "interaction-test-suite", + } + ) + + assert storage.client_info is not None + assert storage.client_info.client_id is not None + assert storage.client_info.client_secret is not None + assert list(provider.clients) == [storage.client_info.client_id] + + +async def test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app() -> None: + """Harness self-test: `shimmed_app` serves canned bodies, 404s, and forwards everything else. + + Wraps a real auth-hosting Starlette app so the forward path is exercised against the SDK's + own routing; provided here so the discovery tests can rely on the shim without each adding + their own contract test. 
+ """ + server = Server("bare") + provider = InMemoryAuthorizationServerProvider() + real_app = server.streamable_http_app(auth=auth_settings(), auth_server_provider=provider) + app = shimmed_app(real_app, not_found=frozenset({"/missing"}), serve={"/override": b'{"shimmed": true}'}) + async with server.session_manager.run(): + async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http: + served = await http.get("/override") + assert served.status_code == 200 + assert served.headers["content-type"] == "application/json" + assert served.json() == {"shimmed": True} + + assert (await http.get("/missing")).status_code == 404 + + forwarded = await http.get("/.well-known/oauth-authorization-server") + assert forwarded.status_code == 200 + assert forwarded.json()["issuer"] == "http://127.0.0.1:8000/" diff --git a/tests/interaction/auth/test_lifecycle.py b/tests/interaction/auth/test_lifecycle.py new file mode 100644 index 0000000000..aa552ae8a6 --- /dev/null +++ b/tests/interaction/auth/test_lifecycle.py @@ -0,0 +1,445 @@ +"""Token lifecycle, step-up, and registration-variant flows of the SDK's OAuth client. + +Every test connects end to end via `connect_with_oauth`; the assertions are recording-first +(the recorded request sequence is asserted before, or independently of, the call result), so a +surprise in the refresh or step-up paths produces a readable diff of what fired rather than an +opaque failure. The provider knobs that drive each scenario are documented per test. 
+""" + +import base64 +from collections import Counter +from urllib.parse import parse_qsl, urlsplit + +import anyio +import pytest +from inline_snapshot import snapshot +from pydantic import AnyHttpUrl, AnyUrl + +from mcp import MCPError, types +from mcp.client.auth.extensions.client_credentials import ClientCredentialsOAuthProvider, PrivateKeyJWTOAuthProvider +from mcp.server import Server, ServerRequestContext +from mcp.shared.auth import OAuthClientInformationFull, OAuthMetadata +from mcp.types import INTERNAL_ERROR, ListToolsResult, Tool +from tests.interaction._connect import BASE_URL +from tests.interaction._requirements import requirement +from tests.interaction.auth._harness import ( + REDIRECT_URI, + InMemoryTokenStorage, + RecordedRequest, + auth_settings, + connect_with_oauth, + m2m_token_shim, + metadata_body, + record_requests, + shim, + step_up_shim, +) +from tests.interaction.auth._provider import InMemoryAuthorizationServerProvider + +pytestmark = pytest.mark.anyio + +PRM_PATH = "/.well-known/oauth-protected-resource/mcp" +ASM_PATH = "/.well-known/oauth-authorization-server" +CIMD_URL = "https://client.example/.well-known/mcp-client" + + +async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="echo", input_schema={"type": "object"})]) + + +def form_body(request: RecordedRequest) -> dict[str, str]: + """Parse an `application/x-www-form-urlencoded` request body into a flat dict.""" + return dict(parse_qsl(request.content.decode())) + + +def authorize_params(authorize_url: str) -> dict[str, str]: + """Parse the authorize URL's query string into a flat dict.""" + return dict(parse_qsl(urlsplit(authorize_url).query)) + + +def find(recorded: list[RecordedRequest], method: str, path: str) -> list[RecordedRequest]: + return [r for r in recorded if r.method == method and r.path == path] + + +def path_counts(recorded: list[RecordedRequest]) -> 
Counter[tuple[str, str]]: + return Counter((r.method, r.path) for r in recorded) + + +def cimd_supported_metadata() -> bytes: + """AS metadata advertising `client_id_metadata_document_supported: true` (the SDK server never sets it).""" + metadata = OAuthMetadata( + issuer=AnyHttpUrl(f"{BASE_URL}/"), + authorization_endpoint=AnyHttpUrl(f"{BASE_URL}/authorize"), + token_endpoint=AnyHttpUrl(f"{BASE_URL}/token"), + registration_endpoint=AnyHttpUrl(f"{BASE_URL}/register"), + scopes_supported=["mcp"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + code_challenge_methods_supported=["S256"], + client_id_metadata_document_supported=True, + ) + return metadata_body(metadata) + + +def seeded_client(provider: InMemoryAuthorizationServerProvider, **kwargs: object) -> OAuthClientInformationFull: + """Register a client with the provider and return its info, for pre-registration and CIMD scenarios.""" + base: dict[str, object] = { + "client_id": "preregistered", + "token_endpoint_auth_method": "none", + "redirect_uris": [AnyUrl(REDIRECT_URI)], + "grant_types": ["authorization_code", "refresh_token"], + "scope": "mcp", + } + base.update(kwargs) + info = OAuthClientInformationFull.model_validate(base) + assert info.client_id is not None + provider.clients[info.client_id] = info + return info + + +@requirement("client-auth:refresh:transparent") +async def test_an_expired_access_token_is_transparently_refreshed_before_the_next_request() -> None: + """An access token the client considers expired is refreshed and the new bearer is used. + + The provider tells the client `expires_in=-3600` for the first token while keeping the + server-side `expires_at` in the future, so the connect's retry succeeds and the next + request finds the token expired and refreshes. 
The recorded requests prove exactly one + `grant_type=refresh_token` exchange carrying the resource indicator, and the bearer used + after the refresh is the second access token, which is the one persisted to storage. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + bodies = [form_body(r) for r in token_posts] + assert [b["grant_type"] for b in bodies] == snapshot(["authorization_code", "refresh_token"]) + + refresh_body = bodies[1] + assert sorted(refresh_body) == snapshot(["client_id", "client_secret", "grant_type", "refresh_token", "resource"]) + assert refresh_body["refresh_token"].startswith("refresh_") + assert refresh_body["resource"].startswith(BASE_URL) + + bearers = {r.headers["authorization"] for r in recorded if r.path == "/mcp" and "authorization" in r.headers} + assert len(bearers) == 2 + assert storage.tokens is not None + assert f"Bearer {storage.tokens.access_token}" in bearers + assert storage.tokens.expires_in == 3600 + + +@requirement("client-auth:403-scope-upgrade") +async def test_a_403_insufficient_scope_triggers_one_reauthorize_with_the_challenged_scope() -> None: + """A 403 `insufficient_scope` challenge is answered by one re-authorize with the challenge's scope. + + The shim 403s the second authenticated `/mcp` POST (the `notifications/initialized` request, + which reaches the auth flow's step-up handler; the first authenticated POST is the post-401 + retry, after which the generator ends without inspecting the response). 
The challenge names a + wider scope; step-up reuses cached metadata and the existing client registration, + re-authorizes with the new scope, and the connect completes. The client is pre-registered + with both scopes so the server's authorize handler accepts the wider second request. One + re-authorize, one retry; the spec's SHOULD-retry-limit ("a few") is not enforced. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + storage = InMemoryTokenStorage(client_info=seeded_client(provider, scope="mcp write")) + server = Server("guarded", on_list_tools=list_tools) + settings = auth_settings(required_scopes=["mcp"], valid_scopes=["mcp", "write"]) + challenge = 'Bearer error="insufficient_scope", scope="mcp write"' + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + settings=settings, + app_shim=step_up_shim(challenge), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + assert len(headless.authorize_urls) == 2 + assert authorize_params(headless.authorize_urls[0])["scope"] == "mcp" + assert authorize_params(headless.authorize_urls[1])["scope"] == "mcp write" + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 0 + assert counts[("GET", "/authorize")] == 2 + assert counts[("POST", "/token")] == 2 + + +@requirement("client-auth:401-after-auth-throws") +async def test_a_second_401_after_a_completed_oauth_flow_surfaces_without_looping() -> None: + """A 401 on the post-auth retry surfaces as an error rather than re-entering discovery. + + The provider rejects every token at verification, so the full flow runs once and the retry + is 401'd. 
The auth-flow generator ends after that retry, so the 401 propagates and the + transport converts it to an INTERNAL_ERROR result, raising during connect. Discovery, + registration, authorize, and token each ran exactly once: no loop. + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(reject_all_tokens=True) + server = Server("guarded", on_list_tools=list_tools) + + def is_internal_error(error: MCPError) -> bool: + return error.error.code == INTERNAL_ERROR + + with anyio.fail_after(5): + with pytest.RaisesGroup(pytest.RaisesExc(MCPError, check=is_internal_error), flatten_subgroups=True): + # Entering the connect raises during the OAuth handshake (inside `Client.__aenter__`), + # so an `async with` body would be unreachable; entering explicitly avoids dead code. + await connect_with_oauth(server, provider=provider, on_request=on_request).__aenter__() + + counts = path_counts(recorded) + assert counts[("GET", PRM_PATH)] == 1 + assert counts[("GET", ASM_PATH)] == 1 + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 1 + assert counts[("POST", "/token")] == 1 + assert counts[("POST", "/mcp")] == 2 + + +@requirement("client-auth:cimd") +async def test_cimd_is_selected_when_the_as_advertises_support_and_a_metadata_url_is_supplied() -> None: + """A client-ID metadata-document URL is used as `client_id` instead of registering. + + AS metadata is shimmed to advertise `client_id_metadata_document_supported: true`; the + provider is pre-seeded so the server's authorize and token handlers accept the URL as a + client_id (the SDK server has no CIMD-aware client lookup of its own). The recorded + requests prove no `/register` call, the authorize URL's `client_id` is the CIMD URL, the + token request uses `token_endpoint_auth_method=none`, and storage persists the URL as + `client_id`. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + seeded_client(provider, client_id=CIMD_URL) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=shim(serve={ASM_PATH: cimd_supported_metadata()}), + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert find(recorded, "POST", "/register") == [] + assert headless.authorize_url is not None + assert authorize_params(headless.authorize_url)["client_id"] == CIMD_URL + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body["client_id"] == CIMD_URL + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + assert storage.client_info is not None + assert storage.client_info.client_id == CIMD_URL + assert storage.client_info.token_endpoint_auth_method == "none" + + +@requirement("client-auth:invalid-grant-clears-tokens") +async def test_a_failed_refresh_clears_stored_tokens_and_restarts_the_full_flow() -> None: + """A non-200 refresh response clears the in-memory tokens and the flow re-runs from discovery. + + The first token is reported expired so the next request refreshes; the provider denies the + refresh once with `invalid_grant`, the auth flow clears its tokens, the unauthenticated + request 401s, and discovery, authorize, and token run again. The original registration is + preserved (`client_info` is not cleared). The SDK clears tokens on any non-200 refresh + response, not specifically `error=invalid_grant`; `source="sdk"` so this is a precision + note rather than a divergence. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider(issue_expired_first=True, fail_next_refresh=True) + storage = InMemoryTokenStorage() + server = Server("guarded", on_list_tools=list_tools) + + with anyio.fail_after(5): + async with connect_with_oauth(server, provider=provider, storage=storage, on_request=on_request) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + + token_posts = find(recorded, "POST", "/token") + assert [form_body(r)["grant_type"] for r in token_posts] == snapshot( + ["authorization_code", "refresh_token", "authorization_code"] + ) + + counts = path_counts(recorded) + assert counts[("POST", "/register")] == 1 + assert counts[("GET", "/authorize")] == 2 + assert counts[("GET", PRM_PATH)] == 2 + assert counts[("GET", ASM_PATH)] == 2 + + assert storage.client_info is not None + assert storage.tokens is not None + assert storage.tokens.access_token in provider.access_tokens + + +@requirement("client-auth:client-credentials") +async def test_client_credentials_provider_obtains_a_token_without_an_authorize_step() -> None: + """The client-credentials provider connects with no authorize step and a `client_credentials` grant. + + The SDK server's `TokenHandler` does not route `client_credentials`, so the harness shim + handles it (the shim is harness; the SDK-under-test is the client provider). The recorded + `/token` body proves the grant type, scope, resource indicator, and HTTP-Basic client + authentication; no `/authorize` or `/register` request was made. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + auth = ClientCredentialsOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-client", + client_secret="m2m-secret", + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, headless): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert headless.authorize_url is None + assert find(recorded, "GET", "/authorize") == [] + assert find(recorded, "POST", "/register") == [] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + {"grant_type": "client_credentials", "resource": "http://127.0.0.1:8000/mcp", "scope": "mcp"} + ) + decoded = base64.b64decode(token_req.headers["authorization"].removeprefix("Basic ")).decode() + assert decoded == "m2m-client:m2m-secret" + + +@requirement("client-auth:private-key-jwt") +async def test_private_key_jwt_provider_authenticates_the_token_request_with_an_assertion() -> None: + """The private-key-JWT provider sends a `client_assertion` on the token request, with the issuer as audience. + + The assertion provider is a closure that records the audience it was called with and returns + a fixed opaque value (the JWT contents are not the SDK's concern here); the test asserts the + `client_assertion`/`client_assertion_type` form fields and that the audience matches the AS + metadata's issuer. 
+ """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + + audiences: list[str] = [] + + async def assertion_provider(audience: str) -> str: + audiences.append(audience) + return "header.payload.sig" + + auth = PrivateKeyJWTOAuthProvider( + server_url=f"{BASE_URL}/mcp", + storage=InMemoryTokenStorage(), + client_id="m2m-jwt-client", + assertion_provider=assertion_provider, + scopes="mcp", + ) + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + auth=auth, + app_shim=m2m_token_shim(provider, scopes=["mcp"]), + on_request=on_request, + ) as (client, _): + result = await client.list_tools() + + assert result.tools[0].name == "echo" + assert audiences == [f"{BASE_URL}/"] + + [token_req] = find(recorded, "POST", "/token") + body = form_body(token_req) + assert body == snapshot( + { + "grant_type": "client_credentials", + "client_assertion": "header.payload.sig", + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "resource": "http://127.0.0.1:8000/mcp", + "scope": "mcp", + } + ) + assert "client_secret" not in body + assert "authorization" not in token_req.headers + + +@pytest.mark.parametrize( + ("case", "preseed_storage", "advertise_cimd"), + [("cimd_unsupported_falls_through_to_dcr", False, False), ("preregistered_beats_cimd", True, True)], + ids=["cimd_unsupported_falls_through_to_dcr", "preregistered_beats_cimd"], +) +@requirement("client-auth:cimd") +async def test_registration_priority_prefers_preregistered_then_cimd_then_dcr( + case: str, preseed_storage: bool, advertise_cimd: bool +) -> None: + """The client picks pre-registration over CIMD over DCR, falling through when each is unavailable. 
+ + Two priority edges are exercised: with a CIMD URL configured but no AS support, DCR runs and + the registered `client_id` is used; with a CIMD URL configured and AS support but a + pre-registered client in storage, the stored `client_id` is used and neither CIMD nor DCR + runs. (The positive CIMD case and pre-registration over DCR are covered by their own tests.) + """ + recorded, on_request = record_requests() + provider = InMemoryAuthorizationServerProvider() + server = Server("guarded", on_list_tools=list_tools) + storage = InMemoryTokenStorage() + + expected_client_id: str + if preseed_storage: + info = seeded_client(provider) + storage.client_info = info + assert info.client_id is not None + expected_client_id = info.client_id + else: + expected_client_id = "" + + app_shim = shim(serve={ASM_PATH: cimd_supported_metadata()}) if advertise_cimd else None + + with anyio.fail_after(5): + async with connect_with_oauth( + server, + provider=provider, + storage=storage, + client_metadata_url=CIMD_URL, + app_shim=app_shim, + on_request=on_request, + ) as (client, headless): + await client.list_tools() + + assert headless.authorize_url is not None + chosen_client_id = authorize_params(headless.authorize_url)["client_id"] + assert chosen_client_id != CIMD_URL + + if case == "cimd_unsupported_falls_through_to_dcr": + assert len(find(recorded, "POST", "/register")) == 1 + assert chosen_client_id in provider.clients + else: + assert find(recorded, "POST", "/register") == [] + assert chosen_client_id == expected_client_id diff --git a/tests/interaction/test_coverage.py b/tests/interaction/test_coverage.py index 47b1b95e71..7821c1eed5 100644 --- a/tests/interaction/test_coverage.py +++ b/tests/interaction/test_coverage.py @@ -30,6 +30,7 @@ "tests.interaction.transports.test_bridge.test_closing_the_response_delivers_a_disconnect_to_the_application", "tests.interaction.transports.test_bridge.test_an_application_failure_before_the_response_starts_fails_the_request", 
"tests.interaction.transports.test_bridge.test_disabling_cancel_on_close_lets_the_application_finish_after_disconnect", + "tests.interaction.auth.test_flow.test_shimmed_app_serves_overrides_404s_and_otherwise_forwards_to_the_wrapped_app", } From cec4a2d895caf6984d0c910dc3b29e53981e9960 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 14:26:54 +0000 Subject: [PATCH 23/34] test: tighten remaining deferral reasons to reflect SDK feature gaps --- tests/interaction/_requirements.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 4e072ae254..526af032b2 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -854,9 +854,9 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/server/resources#subscriptions", behavior="After resources/unsubscribe the server stops sending updated notifications for that URI.", deferred=( - "The SDK keeps no subscription state -- emitting updated notifications is entirely handler " - "code -- so there is no SDK behaviour to pin beyond the unsubscribe request reaching the " - "handler (covered by resources:unsubscribe)." + "Not implemented in the SDK: the server keeps no subscription state, so whether updated " + "notifications stop after unsubscribe is entirely handler code; there is no SDK behaviour to " + "pin beyond the unsubscribe request reaching the handler (covered by resources:unsubscribe)." ), ), "resources:updated-notification": Requirement( @@ -2327,9 +2327,9 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), deferred=( - "Not yet covered here: the standalone GET stream emits no priming event or retry hint, so " - "the client's reconnection path always sleeps the hard-coded 1 s default; a deterministic " - "in-process test would inject real-time delay or require an SDK change. 
The POST-stream " + "Not implemented in the SDK: the server's standalone GET stream emits no priming event or " + "retry hint, so the client's reconnection path always sleeps the hard-coded 1 s default; a " + "deterministic in-process test would require accepting that real-time wait. The POST-stream " "reconnection path is covered by client-transport:http:reconnect-post-priming." ), ), @@ -2647,10 +2647,10 @@ def __post_init__(self) -> None: ), transports=("stdio",), deferred=( - "Not yet covered here: a server that ignores stdin close takes the full " - "PROCESS_TERMINATION_TIMEOUT (2.0 s) grace period plus up to a further 2.0 s for " - "SIGTERM/SIGKILL escalation; a robust test of that path is real-time-bound and the constant " - "is module-level (no public override). Covered by tests/client/test_stdio.py." + "A server that ignores stdin close takes the full PROCESS_TERMINATION_TIMEOUT (2.0 s) grace " + "period plus up to a further 2.0 s for SIGTERM/SIGKILL escalation; testing that path is " + "real-time-bound (the constant is module-level with no public override) and so is deliberately " + "excluded from this suite. Covered by tests/client/test_stdio.py." ), ), "transport:stdio:stderr-passthrough": Requirement( @@ -2704,10 +2704,10 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), deferred=( - "No public per-session post-initialization hook exists on either server flavour " - "(Server.lifespan runs at server startup, not per session; ServerSession handles the " - "initialized notification internally with no callback). Driving 'before any client " - "request' deterministically would also require knowing the standalone GET stream is " + "Not implemented in the SDK: no public per-session post-initialization hook exists on either " + "server flavour (Server.lifespan runs at server startup, not per session; ServerSession " + "handles the initialized notification internally with no callback). 
Driving 'before any " + "client request' deterministically would also require knowing the standalone GET stream is " "established, which has no synchronization signal." ), ), From 0157444277a4ea2a9b8e41b50a3b4cd618a4691d Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 14:30:43 +0000 Subject: [PATCH 24/34] docs: update interaction suite README for transports, auth, and decorator stacking --- tests/interaction/README.md | 43 ++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tests/interaction/README.md b/tests/interaction/README.md index ba08fa564e..1245eae30e 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -10,7 +10,8 @@ running the suite before and after. uv run --frozen pytest tests/interaction/ ``` -The whole suite is in-memory and event-driven; it runs in about a second. +The whole suite is in-process and event-driven — including the streamable HTTP, SSE, and OAuth +flows — with a single subprocess test for stdio. ## Ground rules @@ -26,10 +27,10 @@ The whole suite is in-memory and event-driven; it runs in about a second. the constants in `mcp.types`; error *message strings* are pinned only where they are the SDK's own deliberate output. - **No sleeps, no real I/O.** Concurrency is coordinated with `anyio.Event`; every wait that - could hang is bounded by `anyio.fail_after(5)`. The streamable HTTP tests drive the Starlette + could hang is bounded by `anyio.fail_after(5)`. The HTTP and OAuth tests drive the Starlette app in-process through the suite's streaming ASGI bridge (`transports/_bridge.py`), which delivers each response chunk as the server produces it — full duplex, but still no sockets, - threads, or subprocesses anywhere. + threads, or subprocesses anywhere outside the one stdio test. 
## Layout @@ -42,7 +43,8 @@ tests/interaction/ test_coverage.py enforces the manifest ↔ test contract lowlevel/ one file per feature area, against the low-level Server mcpserver/ the same feature areas in MCPServer's natural idiom - transports/ behaviour specific to one transport (modes, streams, framing) + transports/ behaviour specific to one transport (sessions, resumability, framing) + auth/ OAuth flows against an in-process authorization server ``` The two server APIs produce genuinely different wire output for the same conceptual feature @@ -53,14 +55,15 @@ test body — each directory pins its flavour's true output exactly. ### The transport matrix Transport-agnostic tests take the `connect` fixture instead of constructing `Client(server)` -directly, and therefore run once per transport: over the in-memory transport and over the -server's real streamable HTTP app driven in process through the streaming bridge. A test connects -the same way in either case — `async with connect(server, ...) as client:` — and asserts the same -output, because the transport is not supposed to change observable behaviour. Tests that are tied -to one transport do not use the fixture: the wire-recording tests (their seam is the in-memory -stream pair), the bare-`ClientSession` lifecycle tests, the real-clock timeout tests (the timeout -machinery is transport-independent and must not race transport latency), and everything under -`transports/`, which pins behaviour only observable on that transport. +directly, and therefore run once per transport: over the in-memory transport, over the server's +real streamable HTTP app driven in-process through the streaming bridge, and over the legacy SSE +transport the same way. A test connects with `async with connect(server, ...) as client:` and +asserts the same output on every leg, because the transport is not supposed to change observable +behaviour. 
Tests that are tied to one transport do not use the fixture: the wire-recording tests +(their seam is the in-memory stream pair), the bare-`ClientSession` lifecycle tests, the +real-clock timeout tests (the timeout machinery is transport-independent and must not race +transport latency), and everything under `transports/`, which pins behaviour only observable on +that transport. A transport conformance test in `transports/` speaks raw `httpx` against the mounted ASGI app **only** when its assertion is about HTTP semantics that `Client` cannot observe — status codes, @@ -86,9 +89,10 @@ clients can share one session manager. contract) says should happen. Tests always pin the SDK's current behaviour; where that falls short of `behavior`, the gap is recorded as data rather than hidden in the test. - **`divergence`** records that gap for entries whose tests pin the divergent current behaviour. -- **`deferred`** marks a behaviour that is tracked but not yet covered by a test in this suite. - The reason names the covering tests elsewhere in the repo, starts with "Not implemented in the - SDK" for genuine feature gaps, or starts with "Not yet covered here" for tests that are planned. +- **`deferred`** marks a behaviour that is tracked but has no test in this suite, with a precise + reason: the SDK does not implement it, the negative cannot be observed, the assertion is + schema-level rather than interaction-level, the feature is experimental (tasks), or the test + would require real-time waits the suite refuses. - **`transports`** names the transports a behaviour applies to; omitted means transport-independent. - **`issue`** carries the tracking link for a recorded gap once one is filed. @@ -168,6 +172,15 @@ async def test_call_tool_returns_text_content() -> None: act → assert. The test reads in the order the conversation happens. - A registered handler or tool that a test never invokes gets a `raise NotImplementedError` body so it cannot silently become load-bearing. 
+- A test that needs a peer no real `Server` or `Client` can play (a server that answers initialize + with an unsupported version, a client that sends malformed params) plays that side of the wire by + hand over `create_client_server_memory_streams()`. This scripted-peer pattern is the suite's only + way to drive behaviour the typed API cannot produce, and the docstring of every such test says so. + +Stack a second `@requirement` decorator only when a test's natural assertions incidentally prove +another behaviour — one capabilities snapshot proving four `*:capability:declared` entries, one +input-schema identity check proving each preserved keyword. Do not build a test around covering +many requirements at once; if the assertions would be separate, write separate tests. ### Choosing an assertion From e0e8e57ec41f788c1d2b51cc91f85f11c1a9ba97 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 15:05:34 +0000 Subject: [PATCH 25/34] test: fix interaction suite for 3.10/3.11/3.14 and lowest-direct CI legs - raise pytest floor to 8.4.0 (RaisesGroup, used in the auth tests) - collect twice in connect_over_sse so the unclosed sse_stream_reader is finalized on 3.10 (PEP 442 cycle needs a second pass) - collapse stacked async-with statements into comma form where it reads cleanly, and apply # pragma: no branch on the remaining sync-with + async-with shapes coverage.py mis-traces on 3.11/3.14 --- pyproject.toml | 2 +- tests/interaction/README.md | 7 +- tests/interaction/_connect.py | 51 +++--- tests/interaction/auth/_harness.py | 22 ++- .../interaction/lowlevel/test_cancellation.py | 116 ++++++------- .../interaction/lowlevel/test_elicitation.py | 122 +++++++------- tests/interaction/lowlevel/test_initialize.py | 158 +++++++++--------- tests/interaction/lowlevel/test_progress.py | 2 +- tests/interaction/lowlevel/test_tools.py | 2 +- tests/interaction/lowlevel/test_wire.py | 2 +- 
tests/interaction/transports/test_bridge.py | 10 +- .../transports/test_client_transport_http.py | 57 ++++--- tests/interaction/transports/test_flows.py | 16 +- .../transports/test_hosting_http.py | 2 +- .../transports/test_hosting_resume.py | 6 +- .../transports/test_hosting_session.py | 2 +- tests/interaction/transports/test_stdio.py | 29 ++-- uv.lock | 2 +- 18 files changed, 320 insertions(+), 288 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b98e64a487..6d2319621a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ dev = [ # We add mcp[cli,ws] so `uv sync` considers the extras. "mcp[cli,ws]", "pyright>=1.1.400", - "pytest>=8.3.4", + "pytest>=8.4.0", "ruff>=0.8.5", "trio>=0.26.2", "pytest-flakefinder>=1.1.0", diff --git a/tests/interaction/README.md b/tests/interaction/README.md index 1245eae30e..3afac44cfb 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -208,5 +208,8 @@ assert after the call, with no synchronisation. The exceptions: CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a -pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# pragma`, -`# type: ignore`, or `# noqa` comments; restructure instead. +pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# type: ignore` or +`# noqa` comments; restructure instead. The one sanctioned pragma is `# pragma: no branch` on a +`with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested +async context — restructure first, and reserve the pragma for shapes that cannot collapse (a sync +`with` adjacent to an `async with`). 
diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 3dda864cd5..26fa1c42ce 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -130,20 +130,21 @@ async def connect_over_streamable_http( retry_interval=retry_interval, transport_security=NO_DNS_REBINDING_PROTECTION, ) - async with server.session_manager.run(): - async with httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client: - transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) - async with Client( - transport, - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as client: - yield client + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(app), base_url=BASE_URL) as http_client, + Client( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client), + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client, + ): + yield client @asynccontextmanager @@ -183,11 +184,13 @@ async def mounted_app( auth_server_provider=auth_server_provider, ) event_hooks = {"request": [on_request]} if on_request is not None else None - async with server.session_manager.run(): - async with httpx.AsyncClient( + async with ( + server.session_manager.run(), + httpx.AsyncClient( transport=StreamingASGITransport(app), base_url=BASE_URL, event_hooks=event_hooks, headers=headers - ) as http_client: - yield http_client, server.session_manager + ) as http_client, + ): + yield http_client, server.session_manager @asynccontextmanager 
@@ -357,11 +360,13 @@ def httpx_client_factory( ) as client: yield client finally: - # SseServerTransport.connect_sse hands its internal SSE-chunk receive stream to - # sse_starlette's EventSourceResponse, which never closes it when its task group is - # cancelled on disconnect (see notes/findings.md). Collect the orphan here so its - # ResourceWarning fires deterministically inside this fixture instead of at an - # arbitrary later GC. + # SseServerTransport.connect_sse never closes its sse_stream_reader (handed to + # sse_starlette.EventSourceResponse, which does not aclose() its content on cancel). + # After teardown that reader is held only by a reference cycle through the connect_sse + # frame and its task objects; collecting twice runs the cycle's finalizers and then + # frees the reader while ResourceWarning is suppressed, instead of at an arbitrary + # later GC under pytest's error filter. One pass suffices on 3.11+; 3.10 needs both. with warnings.catch_warnings(): warnings.simplefilter("ignore", ResourceWarning) gc.collect() + gc.collect() diff --git a/tests/interaction/auth/_harness.py b/tests/interaction/auth/_harness.py index 8ee8263c6f..d013364f33 100644 --- a/tests/interaction/auth/_harness.py +++ b/tests/interaction/auth/_harness.py @@ -11,7 +11,7 @@ import json from collections.abc import AsyncIterator, Callable, Mapping, Sequence -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass, field from typing import Any from urllib.parse import parse_qs, parse_qsl, urlsplit @@ -451,11 +451,15 @@ async def hook(request: httpx.Request) -> None: event_hooks = {"request": [hook]} - async with server.session_manager.run(): - async with httpx.AsyncClient( - transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks - ) as http_client: - headless.bind(http_client) - transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) - 
async with Client(transport) as client: - yield client, headless + async with AsyncExitStack() as stack: + await stack.enter_async_context(server.session_manager.run()) + http_client = await stack.enter_async_context( + httpx.AsyncClient( + transport=StreamingASGITransport(app), base_url=BASE_URL, auth=oauth, event_hooks=event_hooks + ) + ) + headless.bind(http_client) + client = await stack.enter_async_context( + Client(streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client)) + ) + yield client, headless diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index f39b2014cf..6f1454e58a 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -13,7 +13,7 @@ from mcp import MCPError, types from mcp.client import ClientSession from mcp.server import Server, ServerRequestContext -from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( CallToolResult, @@ -170,63 +170,65 @@ async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_hand scripted-peer mechanism is the in-memory stream pair, not because the behaviour is transport-specific. """ - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def scripted_server() -> None: - def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: - return SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request_id, - # Serialized exactly as a real server serializes results onto the wire. 
- result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - init = await server_read.receive() - assert isinstance(init, SessionMessage) - assert isinstance(init.message, JSONRPCRequest) - assert init.message.method == "initialize" - await server_write.send( - respond( - init.message.id, - InitializeResult( - protocol_version="2025-11-25", - capabilities=ServerCapabilities(), - server_info=Implementation(name="scripted", version="0.0.1"), - ), + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage: + return SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) - initialized = await server_read.receive() - assert isinstance(initialized, SessionMessage) - assert isinstance(initialized.message, JSONRPCNotification) - assert initialized.message.method == "notifications/initialized" - - ping = await server_read.receive() - assert isinstance(ping, SessionMessage) - assert isinstance(ping.message, JSONRPCRequest) - assert ping.message.method == "ping" - # First answer with a fabricated id that matches nothing in flight, then the real id. 
- await server_write.send(respond(9999, EmptyResult())) - await server_write.send(respond(ping.message.id, EmptyResult())) - - incoming: list[IncomingMessage] = [] - - async def message_handler(message: IncomingMessage) -> None: - incoming.append(message) - - async with anyio.create_task_group() as task_group: - task_group.start_soon(scripted_server) - async with ClientSession(client_read, client_write, message_handler=message_handler) as session: - with anyio.fail_after(5): - await session.initialize() - pong = await session.send_request(PingRequest(), EmptyResult) - - assert pong == snapshot(EmptyResult()) - assert len(incoming) == 1 - assert isinstance(incoming[0], RuntimeError) - # The full message embeds the response object's repr; only the prefix is stable. - assert str(incoming[0]).startswith("Received response with an unknown request ID:") + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + assert init.message.method == "initialize" + await server_write.send( + respond( + init.message.id, + InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="scripted", version="0.0.1"), + ), + ) + ) + + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + + ping = await server_read.receive() + assert isinstance(ping, SessionMessage) + assert isinstance(ping.message, JSONRPCRequest) + assert ping.message.method == "ping" + # First answer with a fabricated id that matches nothing in flight, then the real id. 
+ await server_write.send(respond(9999, EmptyResult())) + await server_write.send(respond(ping.message.id, EmptyResult())) + + incoming: list[IncomingMessage] = [] + + async def message_handler(message: IncomingMessage) -> None: + incoming.append(message) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + ClientSession(client_read, client_write, message_handler=message_handler) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + pong = await session.send_request(PingRequest(), EmptyResult) + + assert pong == snapshot(EmptyResult()) + assert len(incoming) == 1 + assert isinstance(incoming[0], RuntimeError) + # The full message embeds the response object's repr; only the prefix is stable. + assert str(incoming[0]).startswith("Received response with an unknown request ID:") diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index 83a77592a9..2b264dbebe 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -11,7 +11,7 @@ from mcp import MCPError, UrlElicitationRequiredError, types from mcp.client import ClientRequestContext, ClientSession from mcp.server import Server, ServerRequestContext -from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( CallToolResult, @@ -594,68 +594,68 @@ async def answer_form(context: ClientRequestContext, params: types.ElicitRequest received.append(params) return ElicitResult(action="accept", content={}) - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - 
async def scripted_server() -> None: - initialize = await server_read.receive() - assert isinstance(initialize, SessionMessage) - request = initialize.message - assert isinstance(request, JSONRPCRequest) - assert request.method == "initialize" - result = InitializeResult( - protocol_version="2025-11-25", - capabilities=ServerCapabilities(), - server_info=Implementation(name="legacy", version="0.0.1"), - ) - await server_write.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + initialize = await server_read.receive() + assert isinstance(initialize, SessionMessage) + request = initialize.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-11-25", + capabilities=ServerCapabilities(), + server_info=Implementation(name="legacy", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) - initialized = await server_read.receive() - assert isinstance(initialized, SessionMessage) - assert isinstance(initialized.message, JSONRPCNotification) - assert initialized.message.method == "notifications/initialized" - # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. 
- await server_write.send( - SessionMessage( - JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="elicitation/create", - params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, - ) + ) + initialized = await server_read.receive() + assert isinstance(initialized, SessionMessage) + assert isinstance(initialized.message, JSONRPCNotification) + assert initialized.message.method == "notifications/initialized" + # No mode key: a server speaking a pre-mode revision of the spec sends only message + schema. + await server_write.send( + SessionMessage( + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="elicitation/create", + params={"message": "Legacy ask.", "requestedSchema": {"type": "object", "properties": {}}}, ) ) - response = await server_read.receive() - assert isinstance(response, SessionMessage) - server_received.append(response.message) - answered.set() - - async with anyio.create_task_group() as tg: - tg.start_soon(scripted_server) - async with ClientSession(client_read, client_write, elicitation_callback=answer_form) as session: - with anyio.fail_after(5): - await session.initialize() - await answered.wait() - - assert received == snapshot( - [ - ElicitRequestFormParams( - _meta=None, - message="Legacy ask.", - requested_schema={"type": "object", "properties": {}}, - ) - ] + ) + response = await server_read.receive() + assert isinstance(response, SessionMessage) + server_received.append(response.message) + answered.set() + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write, elicitation_callback=answer_form) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + await session.initialize() + await answered.wait() + + assert received == snapshot( + [ + ElicitRequestFormParams( + _meta=None, + message="Legacy ask.", + requested_schema={"type": "object", 
"properties": {}}, ) - assert isinstance(received[0], ElicitRequestFormParams) - assert received[0].mode == "form" - assert len(server_received) == 1 - assert isinstance(server_received[0], JSONRPCResponse) - assert server_received[0].id == 2 + ] + ) + assert isinstance(received[0], ElicitRequestFormParams) + assert received[0].mode == "form" + assert len(server_received) == 1 + assert isinstance(server_received[0], JSONRPCResponse) + assert server_received[0].id == 2 diff --git a/tests/interaction/lowlevel/test_initialize.py b/tests/interaction/lowlevel/test_initialize.py index 027c80505d..91adbf5611 100644 --- a/tests/interaction/lowlevel/test_initialize.py +++ b/tests/interaction/lowlevel/test_initialize.py @@ -15,7 +15,7 @@ from mcp.client import ClientRequestContext, ClientSession from mcp.client._memory import InMemoryTransport from mcp.server import Server, ServerRequestContext -from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( INVALID_PARAMS, @@ -241,14 +241,16 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa server = Server("strict", on_list_tools=list_tools) - async with InMemoryTransport(server) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - with anyio.fail_after(5): - with pytest.raises(MCPError) as exc_info: - await session.send_request(ListToolsRequest(), ListToolsResult) + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.send_request(ListToolsRequest(), ListToolsResult) - # Ping is explicitly permitted before initialization completes. 
- pong = await session.send_ping() + # Ping is explicitly permitted before initialization completes. + pong = await session.send_ping() assert exc_info.value.error == snapshot( ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") @@ -275,16 +277,20 @@ def initialize_request(protocol_version: str) -> InitializeRequest: ) ) - async with InMemoryTransport(server) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - with anyio.fail_after(5): - result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("2025-03-26"), InitializeResult) assert result.protocol_version == snapshot("2025-03-26") - async with InMemoryTransport(server) as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - with anyio.fail_after(5): - result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) + async with ( + InMemoryTransport(server) as (read_stream, write_stream), + ClientSession(read_stream, write_stream) as session, + ): + with anyio.fail_after(5): + result = await session.send_request(initialize_request("1999-01-01"), InitializeResult) assert result.protocol_version == snapshot("2025-11-25") @@ -297,40 +303,41 @@ async def test_unsupported_server_protocol_version_fails_initialization() -> Non answers it with a hand-built result. Reserve this pattern for behaviour no real server can be made to produce. 
""" - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def scripted_server() -> None: - message = await server_read.receive() - assert isinstance(message, SessionMessage) - request = message.message - assert isinstance(request, JSONRPCRequest) - assert request.method == "initialize" - result = InitializeResult( - protocol_version="1991-08-06", - capabilities=ServerCapabilities(), - server_info=Implementation(name="relic", version="0.0.1"), - ) - await server_write.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request.id, - # Serialized exactly as a real server serializes results onto the wire. - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="1991-08-06", + capabilities=ServerCapabilities(), + server_info=Implementation(name="relic", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. 
+ result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) + ) - async with anyio.create_task_group() as tg: - tg.start_soon(scripted_server) - async with ClientSession(client_read, client_write) as session: - with anyio.fail_after(5): - with pytest.raises(RuntimeError) as exc_info: - await session.initialize() + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(RuntimeError) as exc_info: + await session.initialize() - assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") + assert str(exc_info.value) == snapshot("Unsupported protocol version from the server: 1991-08-06") @requirement("lifecycle:version:downgrade") @@ -341,36 +348,37 @@ async def test_an_older_supported_protocol_version_from_the_server_is_accepted() plays the server's side of the wire by hand to return a fixed older version regardless of what was requested. Reserve this pattern for behaviour no real server can be made to produce. """ - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def scripted_server() -> None: - message = await server_read.receive() - assert isinstance(message, SessionMessage) - request = message.message - assert isinstance(request, JSONRPCRequest) - assert request.method == "initialize" - result = InitializeResult( - protocol_version="2025-06-18", - capabilities=ServerCapabilities(), - server_info=Implementation(name="conservative", version="0.0.1"), - ) - await server_write.send( - SessionMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=request.id, - # Serialized exactly as a real server serializes results onto the wire. 
- result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + message = await server_read.receive() + assert isinstance(message, SessionMessage) + request = message.message + assert isinstance(request, JSONRPCRequest) + assert request.method == "initialize" + result = InitializeResult( + protocol_version="2025-06-18", + capabilities=ServerCapabilities(), + server_info=Implementation(name="conservative", version="0.0.1"), + ) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=request.id, + # Serialized exactly as a real server serializes results onto the wire. + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) + ) - async with anyio.create_task_group() as tg: - tg.start_soon(scripted_server) - async with ClientSession(client_read, client_write) as session: - with anyio.fail_after(5): - initialize_result = await session.initialize() + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as tg, + ClientSession(client_read, client_write) as session, + ): + tg.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + initialize_result = await session.initialize() - assert initialize_result.protocol_version == snapshot("2025-06-18") + assert initialize_result.protocol_version == snapshot("2025-06-18") diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 54faf85888..db44deb091 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -191,7 +191,7 @@ async def call(label: str, collect: ProgressFnT) -> None: await client.call_tool("report", {"label": label}, progress_callback=collect) with anyio.fail_after(5): - async with anyio.create_task_group() as task_group: + async with 
anyio.create_task_group() as task_group: # pragma: no branch task_group.start_soon(call, "a", collect_a) task_group.start_soon(call, "b", collect_b) await entered["a"].wait() diff --git a/tests/interaction/lowlevel/test_tools.py b/tests/interaction/lowlevel/test_tools.py index 95bb6bd790..e8053fbaa7 100644 --- a/tests/interaction/lowlevel/test_tools.py +++ b/tests/interaction/lowlevel/test_tools.py @@ -343,7 +343,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server) as client: with anyio.fail_after(5): - async with anyio.create_task_group() as task_group: + async with anyio.create_task_group() as task_group: # pragma: no branch async def call_and_record(tag: str) -> None: results[tag] = await client.call_tool("echo", {"tag": tag}) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index a3453b7b2a..1a5d32129d 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -178,7 +178,7 @@ async def call_and_capture_error() -> None: ) errors.append(exc_info.value.error) - async with anyio.create_task_group() as task_group: + async with anyio.create_task_group() as task_group: # pragma: no branch task_group.start_soon(call_and_capture_error) await handler_started.wait() await server_write.aclose() diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py index 71be14ced0..b1a42543f4 100644 --- a/tests/interaction/transports/test_bridge.py +++ b/tests/interaction/transports/test_bridge.py @@ -28,10 +28,12 @@ async def chunked_app(scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "http.response.body", "body": b"", "more_body": True}) await send({"type": "http.response.body", "body": b"second", "more_body": False}) - async with httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http: - async with http.stream("GET", 
"/chunks") as response: - with anyio.fail_after(5): - chunks = [chunk async for chunk in response.aiter_raw()] + async with ( + httpx.AsyncClient(transport=StreamingASGITransport(chunked_app), base_url="http://bridge") as http, + http.stream("GET", "/chunks") as response, + ): + with anyio.fail_after(5): + chunks = [chunk async for chunk in response.aiter_raw()] assert response.status_code == 200 assert response.headers["content-type"] == "text/plain" diff --git a/tests/interaction/transports/test_client_transport_http.py b/tests/interaction/transports/test_client_transport_http.py index 2d9d0c42b6..1c9de371ac 100644 --- a/tests/interaction/transports/test_client_transport_http.py +++ b/tests/interaction/transports/test_client_transport_http.py @@ -135,16 +135,15 @@ async def test_concurrent_tool_calls_each_open_a_post_stream_and_receive_their_o async def record(request: httpx.Request) -> None: requests.append(request) - async with mounted_app(_tooled_server(), on_request=record) as (http, _): - async with client_via_http(http) as client: + async with mounted_app(_tooled_server(), on_request=record) as (http, _), client_via_http(http) as client: - async def call(n: int) -> None: - results[n] = await client.call_tool("echo", {"text": str(n)}) + async def call(n: int) -> None: + results[n] = await client.call_tool("echo", {"text": str(n)}) - with anyio.fail_after(5): - async with anyio.create_task_group() as tg: - for n in (1, 2, 3): - tg.start_soon(call, n) + with anyio.fail_after(5): # pragma: no branch + async with anyio.create_task_group() as tg: # pragma: no branch + for n in (1, 2, 3): + tg.start_soon(call, n) assert results == snapshot( { @@ -176,13 +175,14 @@ async def filter_methods(scope: Scope, receive: Receive, send: Send) -> None: return await real_app(scope, receive, send) - async with server.session_manager.run(): - http_client = httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) - async with http_client: - transport 
= streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) - with anyio.fail_after(5): - async with Client(transport) as client: - result = await client.list_tools() + async with ( + server.session_manager.run(), + httpx.AsyncClient(transport=StreamingASGITransport(filter_methods), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + result = await client.list_tools() assert [tool.name for tool in result.tools] == ["echo"] @@ -201,10 +201,12 @@ async def record(request: httpx.Request) -> None: requests.append(request) server = _tooled_server() - async with mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _): - async with client_via_http(http) as client: - with anyio.fail_after(5): - result = await client.list_tools() + async with ( + mounted_app(server, event_store=SequencedEventStore(), retry_interval=0, on_request=record) as (http, _), + client_via_http(http) as client, + ): + with anyio.fail_after(5): + result = await client.list_tools() assert [tool.name for tool in result.tools] == ["echo"] resumption_gets = [r for r in requests if r.method == "GET" and "last-event-id" in r.headers] @@ -232,13 +234,14 @@ async def first_post_then_404(scope: Scope, receive: Receive, send: Send) -> Non initialize_seen.set() await real_app(scope, receive, send) - async with server.session_manager.run(): - http_client = httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) - async with http_client: - transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) - with anyio.fail_after(5): - async with Client(transport) as client: - with pytest.raises(MCPError) as exc_info: - await client.list_tools() + async with ( + server.session_manager.run(), + 
httpx.AsyncClient(transport=StreamingASGITransport(first_post_then_404), base_url=BASE_URL) as http_client, + ): + transport = streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) + with anyio.fail_after(5): # pragma: no branch + async with Client(transport) as client: # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch + await client.list_tools() assert exc_info.value.error == snapshot(ErrorData(code=INVALID_REQUEST, message="Session terminated")) diff --git a/tests/interaction/transports/test_flows.py b/tests/interaction/transports/test_flows.py index 6e3d787356..c428fe2d68 100644 --- a/tests/interaction/transports/test_flows.py +++ b/tests/interaction/transports/test_flows.py @@ -46,7 +46,7 @@ async def collect_b(params: LoggingMessageNotificationParams) -> None: async with mounted_app(mcp) as (http, _): with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch async def call(label: str, collect: LoggingFnT) -> None: async with client_via_http(http, logging_callback=collect) as client: @@ -112,12 +112,14 @@ def echo(text: str) -> str: """Return the input unchanged.""" return text - async with mounted_app(mcp) as (http, _): - async with connect_over_sse(mcp) as sse_client: - async with client_via_http(http) as shttp_client: - with anyio.fail_after(5): - shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) - sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) + async with ( + mounted_app(mcp) as (http, _), + connect_over_sse(mcp) as sse_client, + client_via_http(http) as shttp_client, + ): + with anyio.fail_after(5): + shttp_result = await shttp_client.call_tool("echo", {"text": "via http"}) + sse_result = await sse_client.call_tool("echo", {"text": "via sse"}) assert shttp_result == snapshot( CallToolResult(content=[TextContent(text="via http")], structured_content={"result": "via http"}) diff --git 
a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index aa9beee067..f842f4083e 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -222,7 +222,7 @@ async def read_standalone_stream() -> None: standalone_ready = anyio.Event() seen_on_standalone = anyio.Event() with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch tg.start_soon(read_standalone_stream) await standalone_ready.wait() diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index bb98a96e7a..06bffed27c 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -75,7 +75,7 @@ async def test_a_post_sse_stream_begins_with_a_priming_event_and_stamps_every_ev async with mounted_app(_counting_server(), event_store=SequencedEventStore(), retry_interval=0) as (http, _): session_id = await initialize_via_http(http) with anyio.fail_after(5): - async with http.stream( + async with http.stream( # pragma: no branch "POST", "/mcp", content=_tools_call(1, "count", {"n": 2}), headers=base_headers(session_id=session_id) ) as response: assert response.status_code == 200 @@ -155,7 +155,7 @@ async def count(ctx: Context) -> str: await store.wait_until_stored(4) await store.wait_until_stored(8) replay_headers = base_headers(session_id=session_id) | {"last-event-id": last_seen} - async with http.stream("GET", "/mcp", headers=replay_headers) as replay: + async with http.stream("GET", "/mcp", headers=replay_headers) as replay: # pragma: no branch assert replay.status_code == 200 missed = await _read_events(replay, 3) @@ -274,7 +274,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: mcp, event_store=SequencedEventStore(), retry_interval=0, logging_callback=collect ) as 
client: with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch async def call() -> None: result.append(await client.call_tool("interrupt", {})) diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py index 561fbf251a..da1f5626a7 100644 --- a/tests/interaction/transports/test_hosting_session.py +++ b/tests/interaction/transports/test_hosting_session.py @@ -195,7 +195,7 @@ async def list_via(label: str) -> None: results[label] = await client.list_tools() with anyio.fail_after(5): - async with anyio.create_task_group() as tg: + async with anyio.create_task_group() as tg: # pragma: no branch tg.start_soon(list_via, "a") tg.start_soon(list_via, "b") diff --git a/tests/interaction/transports/test_stdio.py b/tests/interaction/transports/test_stdio.py index 2d15d61ff8..27cc65de42 100644 --- a/tests/interaction/transports/test_stdio.py +++ b/tests/interaction/transports/test_stdio.py @@ -111,22 +111,25 @@ async def test_stdio_server_writes_one_jsonrpc_message_per_line() -> None: sent_line = json.dumps(initialize_body(request_id=1)) + "\n" with anyio.fail_after(5): - async with stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + async with ( + stdio_server(stdin=anyio.wrap_file(io.StringIO(sent_line)), stdout=anyio.wrap_file(captured)) as ( + read_stream, + write_stream, + ), read_stream, write_stream, ): - async with read_stream, write_stream: - received = await read_stream.receive() - assert isinstance(received, SessionMessage) - assert isinstance(received.message, JSONRPCRequest) - assert received.message.method == "initialize" - - response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) - notification = JSONRPCNotification( - jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} - ) - await 
write_stream.send(SessionMessage(response)) - await write_stream.send(SessionMessage(notification)) + received = await read_stream.receive() + assert isinstance(received, SessionMessage) + assert isinstance(received.message, JSONRPCRequest) + assert received.message.method == "initialize" + + response = JSONRPCResponse(jsonrpc="2.0", id=1, result={"text": "line\nbreak"}) + notification = JSONRPCNotification( + jsonrpc="2.0", method="notifications/message", params={"level": "info", "data": "two\nlines"} + ) + await write_stream.send(SessionMessage(response)) + await write_stream.send(SessionMessage(notification)) output = captured.getvalue() assert output.endswith("\n") diff --git a/uv.lock b/uv.lock index b396898b66..5b72e97fce 100644 --- a/uv.lock +++ b/uv.lock @@ -939,7 +939,7 @@ dev = [ { name = "mcp", extras = ["cli", "ws"], editable = "." }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, - { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest", specifier = ">=8.4.0" }, { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" }, { name = "pytest-pretty", specifier = ">=1.2.0" }, From 9f2b105d635af6d507279505f5da067cc5e62f9b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 15:24:59 +0000 Subject: [PATCH 26/34] test: close leaked SSE receive streams instead of gc-collecting them SseServerTransport.connect_sse never closed sse_stream_reader after EventSourceResponse returned; the in-process bridge dropped its chunk reader when a request was cancelled before the response started. Close both at source so the interaction suite no longer needs a gc.collect() workaround, and pull one assert inside its async-with body to clear the last 3.11 coverage gap. 
--- src/mcp/server/sse.py | 1 + tests/interaction/_connect.py | 36 ++++++------------- tests/interaction/transports/_bridge.py | 13 ++++--- .../transports/test_hosting_session.py | 19 +++++----- tests/interaction/transports/test_sse.py | 6 ---- 5 files changed, 30 insertions(+), 45 deletions(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 3e5261896b..be8e979c9d 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -179,6 +179,7 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( scope, receive, send ) + await sse_stream_reader.aclose() await read_stream_writer.aclose() await write_stream_reader.aclose() self._read_stream_writers.pop(session_id, None) diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 26fa1c42ce..9c71acee9b 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -7,8 +7,6 @@ (session ids, SSE encoding, session management) runs with no sockets, threads, or subprocesses. """ -import gc -import warnings from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, Protocol @@ -347,26 +345,14 @@ def httpx_client_factory( ) transport = sse_client(f"{BASE_URL}/sse", httpx_client_factory=httpx_client_factory) - try: - async with Client( - transport, - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as client: - yield client - finally: - # SseServerTransport.connect_sse never closes its sse_stream_reader (handed to - # sse_starlette.EventSourceResponse, which does not aclose() its content on cancel). 
- # After teardown that reader is held only by a reference cycle through the connect_sse - # frame and its task objects; collecting twice runs the cycle's finalizers and then - # frees the reader while ResourceWarning is suppressed, instead of at an arbitrary - # later GC under pytest's error filter. One pass suffices on 3.11+; 3.10 needs both. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", ResourceWarning) - gc.collect() - gc.collect() + async with Client( + transport, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client: + yield client diff --git a/tests/interaction/transports/_bridge.py b/tests/interaction/transports/_bridge.py index 6d0bfd62d4..f78c6d14b5 100644 --- a/tests/interaction/transports/_bridge.py +++ b/tests/interaction/transports/_bridge.py @@ -151,11 +151,16 @@ async def run_application() -> None: await chunk_writer.aclose() self._task_group.start_soon(run_application) - await response_started.wait() - if application_error is not None: - # No response will be built, so close the reader the response body would have owned. + try: + await response_started.wait() + if application_error is not None: + raise application_error + except BaseException: + # No response will be built, so close the reader the response body would have owned + # and tell the application its peer has gone away. 
+ client_disconnected.set() await chunk_reader.aclose() - raise application_error + raise return httpx.Response( status_code=response_status, headers=response_headers, diff --git a/tests/interaction/transports/test_hosting_session.py b/tests/interaction/transports/test_hosting_session.py index da1f5626a7..a926c3e8a2 100644 --- a/tests/interaction/transports/test_hosting_session.py +++ b/tests/interaction/transports/test_hosting_session.py @@ -115,17 +115,16 @@ async def test_delete_terminates_the_session_and_subsequent_requests_return_404( json={"jsonrpc": "2.0", "id": 2, "method": "tools/list"}, headers=base_headers(session_id=session_id), ) - - assert (post.status_code, post.json()) == snapshot( - ( - 404, - { - "jsonrpc": "2.0", - "id": None, - "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, - }, + assert (post.status_code, post.json()) == snapshot( + ( + 404, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32600, "message": "Not Found: Session has been terminated"}, + }, + ) ) - ) @requirement("hosting:session:isolation") diff --git a/tests/interaction/transports/test_sse.py b/tests/interaction/transports/test_sse.py index 4facadec73..9c7353dda5 100644 --- a/tests/interaction/transports/test_sse.py +++ b/tests/interaction/transports/test_sse.py @@ -7,8 +7,6 @@ through the suite's streaming ASGI bridge. """ -import gc -import warnings from uuid import UUID, uuid4 import anyio @@ -59,10 +57,6 @@ def httpx_client_factory( assert await client.send_ping() == snapshot(EmptyResult()) assert sse._read_stream_writers == {} - # See connect_over_sse: collect the one stream sse_starlette never closes on disconnect. 
- with warnings.catch_warnings(): - warnings.simplefilter("ignore", ResourceWarning) - gc.collect() @requirement("transport:sse:post:session-routing") From 7a026f282045bbaab02b9f9c5b3d4279fa3aa9b1 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 16:57:38 +0000 Subject: [PATCH 27/34] test: correct manifest divergence notes and route in-handler notifications deterministically MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four manifest fixes from spec/SDK re-verification: - lifecycle:capability:* divergence notes used SHOULD; spec basic/lifecycle#operation has been MUST since 2025-06-18 - mcpserver:tool:naming-validation deferred reason claimed no naming check exists; Tool.from_function calls validate_and_warn_tool_name (warns, doesn't reject) - converted to a Divergence with a pinning test - client-auth:...issuer-validation divergence's second sentence is false (OAuthMetadata types the endpoints AnyHttpUrl, so scheme is validated) - resources:annotations now records that the SDK Annotations model lacks lastModified; the round-trip test sends it via model_validate so the snapshot pins the drop Twelve lowlevel tests sent notifications from inside a tool handler without related_request_id, so on the streamable-HTTP leg they routed to the standalone GET stream and the assertion relied on cross-stream ordering the suite documents as not guaranteed. Eight now pass related_request_id; four whose senders don't accept it use anyio.Event with the snapshot still proving the delivered set. The module docstrings that overstated the ordering guarantee are corrected. README §Coverage now documents the four lax-no-cover teardown markers and the sse.py aclose() fix that landed alongside this suite. 
--- tests/interaction/README.md | 17 +++++++-- tests/interaction/_requirements.py | 26 +++++++------ .../interaction/lowlevel/test_elicitation.py | 9 +++-- .../interaction/lowlevel/test_list_changed.py | 23 +++++++++-- tests/interaction/lowlevel/test_logging.py | 22 +++++++---- tests/interaction/lowlevel/test_progress.py | 38 ++++++++++++------- tests/interaction/lowlevel/test_resources.py | 18 +++++++-- tests/interaction/mcpserver/test_tools.py | 35 +++++++++++++++++ 8 files changed, 141 insertions(+), 47 deletions(-) diff --git a/tests/interaction/README.md b/tests/interaction/README.md index 3afac44cfb..23a308a0ea 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -209,7 +209,16 @@ assert after the call, with no synchronisation. The exceptions: CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a pragma'd line in `src/`, delete the pragma in the same change. Do not add new `# type: ignore` or -`# noqa` comments; restructure instead. The one sanctioned pragma is `# pragma: no branch` on a -`with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested -async context — restructure first, and reserve the pragma for shapes that cannot collapse (a sync -`with` adjacent to an `async with`). +`# noqa` comments; restructure instead. The one sanctioned pragma in this suite's test code is +`# pragma: no branch` on a `with`/`async with` line whose only fault is coverage.py mis-tracing +the exit arc of a nested async context — restructure first, and reserve the pragma for shapes +that cannot collapse (a sync `with` adjacent to an `async with`). 
+ +A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose +execution is timing-dependent under the in-process HTTP bridge — the POST-stream and +stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` +check in `message_router`, and the response-stream double-close guard in +`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to +strict `no cover` without first making the teardown ordering deterministic. The suite also relies +on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the +SSE leg would otherwise leak. diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 526af032b2..e0bbd52b4b 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -90,7 +90,7 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The client does not check its own declared capabilities before sending notifications or " - "serving callbacks; nothing prevents a caller from violating the spec's SHOULD." + "serving callbacks; nothing prevents a caller from violating the spec's MUST." ), ), deferred=( @@ -106,7 +106,7 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The client sends any request regardless of the server's advertised capabilities and " - "surfaces whatever the server answers; the spec's SHOULD is not enforced." + "surfaces whatever the server answers; the spec's MUST is not enforced." ), ), deferred=( @@ -693,9 +693,12 @@ def __post_init__(self) -> None: "mcpserver:tool:naming-validation": Requirement( source="sdk", behavior="Tool names that violate the spec's naming rules are rejected at registration time.", - deferred=( - "Not implemented in the SDK: MCPServer accepts any string as a tool name; there is no " - "spec-naming-rules check at registration time." 
+ divergence=Divergence( + note=( + "MCPServer runs the SEP-986 naming check at registration (validate_and_warn_tool_name at " + "tools/base.py) and logs a warning for non-conforming names, but does not reject them; the " + "bool result is discarded and registration proceeds." + ), ), ), "mcpserver:tool:output-schema:model": Requirement( @@ -769,9 +772,12 @@ def __post_init__(self) -> None: # ═══════════════════════════════════════════════════════════════════════════ "resources:annotations": Requirement( source=f"{SPEC_BASE_URL}/server/resources#annotations", - behavior=( - "Resource annotations (audience, priority) supplied by the server round-trip to the client " - "in the list result." + behavior="Resource annotations supplied by the server round-trip to the client in the list result.", + divergence=Divergence( + note=( + "The SDK Annotations model is missing the schema's lastModified field; MCPModel uses the " + "pydantic default extra='ignore', so the value is silently dropped on parse." + ), ), ), "resources:capability:declared": Requirement( @@ -2413,9 +2419,7 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The SDK parses authorization-server metadata without comparing issuer to the discovery " - "URL; a mismatched issuer is accepted and the flow proceeds. The SDK also does not " - "validate that the document's authorization_endpoint, token_endpoint, and " - "registration_endpoint use http(s) schemes." + "URL; a mismatched issuer is accepted and the flow proceeds." 
), ), ), diff --git a/tests/interaction/lowlevel/test_elicitation.py b/tests/interaction/lowlevel/test_elicitation.py index 2b264dbebe..b8edf601d0 100644 --- a/tests/interaction/lowlevel/test_elicitation.py +++ b/tests/interaction/lowlevel/test_elicitation.py @@ -304,8 +304,9 @@ async def test_elicitation_complete_notification_carries_the_elicited_id_back_to The lifecycle under test: the tool elicits a URL interaction with an elicitationId, the user agrees to visit the URL, the out-of-band interaction finishes, and the server emits elicitation/complete so the client can correlate the completion with the elicitation it - accepted earlier. Both messages arrive before the tool call returns, so a plain collected - list needs no synchronisation. + accepted earlier. The completion notification carries ``related_request_id`` so over + streamable HTTP it rides the tool call's own stream and reaches the client before the call + returns; the same ordering already holds on in-memory and SSE transports. 
""" elicitation_id = "auth-001" elicited_ids: list[str] = [] @@ -327,7 +328,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara "Authorize access to your files.", "https://example.com/oauth/authorize", elicitation_id ) assert answer.action == "accept" - await ctx.session.send_elicit_complete(elicitation_id) + await ctx.session.send_elicit_complete(elicitation_id, related_request_id=ctx.request_id) return CallToolResult(content=[TextContent(text="linked")]) server = Server("authorizer", on_list_tools=list_tools, on_call_tool=call_tool) @@ -559,7 +560,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "noop" - await ctx.session.send_elicit_complete("never-elicited") + await ctx.session.send_elicit_complete("never-elicited", related_request_id=ctx.request_id) return CallToolResult(content=[TextContent(text="ok")]) server = Server("notifier", on_list_tools=list_tools, on_call_tool=call_tool) diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index eb20db207b..0a681fffa7 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -1,11 +1,14 @@ """List-changed notifications from the low-level Server, driven through the public Client API. -The notifications are emitted from inside a tool call, so the ordering guarantee described in -test_logging.py applies: they reach the client's message handler before the tool call returns, -and the tests assert on a plain collected list with no synchronisation. The collector records -every message the handler receives, so the assertions also prove nothing else was delivered. 
+``send_*_list_changed`` does not take a ``related_request_id``, so over streamable HTTP the +notification routes to the standalone GET stream and is not guaranteed to arrive before the tool +result on its POST stream. Tests therefore wait on an event the collector sets, the same pattern +as ``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. +The collector still records every message it receives, so the snapshot also proves nothing else +was delivered. """ +import anyio import pytest from inline_snapshot import snapshot @@ -29,9 +32,11 @@ async def test_tool_list_changed_notification(connect: Connect) -> None: """A tools/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] + seen = anyio.Event() async def collect(message: IncomingMessage) -> None: received.append(message) + seen.set() async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -47,6 +52,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server, message_handler=collect) as client: await client.call_tool("install", {}) + with anyio.fail_after(5): + await seen.wait() assert received == snapshot([ToolListChangedNotification()]) @@ -55,9 +62,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_resource_list_changed_notification(connect: Connect) -> None: """A resources/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] + seen = anyio.Event() async def collect(message: IncomingMessage) -> None: received.append(message) + seen.set() async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -73,6 +82,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server, 
message_handler=collect) as client: await client.call_tool("mount", {}) + with anyio.fail_after(5): + await seen.wait() assert received == snapshot([ResourceListChangedNotification()]) @@ -81,9 +92,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_prompt_list_changed_notification(connect: Connect) -> None: """A prompts/list_changed notification sent during a tool call reaches the client's message handler.""" received: list[IncomingMessage] = [] + seen = anyio.Event() async def collect(message: IncomingMessage) -> None: received.append(message) + seen.set() async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -99,5 +112,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server, message_handler=collect) as client: await client.call_tool("learn", {}) + with anyio.fail_after(5): + await seen.wait() assert received == snapshot([PromptListChangedNotification()]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index 792334ecd2..a7b2372083 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -2,11 +2,11 @@ Notification ordering: the in-memory transport delivers every server-to-client message on one ordered stream, and the client's receive loop dispatches each incoming message to completion -before reading the next one. Together these guarantee that every notification the server sends -before its response reaches the client callback before the originating request returns, so tests -collect notifications into a plain list and assert after the request completes -- no events, no -waiting. This does not generalise to transports that split messages across streams (the -streamable HTTP standalone GET stream); tests over those transports must synchronise explicitly. +before reading the next one. 
Over streamable HTTP that ordered single-stream guarantee holds +only for messages that carry a ``related_request_id`` (they ride the originating request's POST +stream); without it the message routes to the standalone GET stream and may arrive after the +response. These tests pass ``related_request_id`` so they can collect into a plain list and +assert after the request completes on every transport leg -- no events, no waiting. """ import pytest @@ -68,8 +68,12 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "chatty" - await ctx.session.send_log_message(level="info", data="starting up", logger="app.lifecycle") - await ctx.session.send_log_message(level="error", data={"code": 502, "retryable": True}) + await ctx.session.send_log_message( + level="info", data="starting up", logger="app.lifecycle", related_request_id=ctx.request_id + ) + await ctx.session.send_log_message( + level="error", data={"code": 502, "retryable": True}, related_request_id=ctx.request_id + ) return CallToolResult(content=[TextContent(text="done")]) server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) @@ -102,7 +106,9 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "siren" for level in ALL_LEVELS: - await ctx.session.send_log_message(level=level, data=f"a {level} message") + await ctx.session.send_log_message( + level=level, data=f"a {level} message", related_request_id=ctx.request_id + ) return CallToolResult(content=[TextContent(text="logged")]) server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index db44deb091..6350c33a33 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -1,10 +1,12 
@@ """Progress interactions against the low-level Server, driven through the public Client API. Server-to-client progress emitted during a request follows the same ordering guarantee as -logging notifications (see test_logging.py): everything the server sends before its response is -dispatched to the progress callback before the request returns, so no synchronisation is needed. -The client-to-server direction is a standalone notification with no response to await, so that -test waits on an event set by the server's handler. +logging notifications (see test_logging.py) -- on the in-memory transport unconditionally, and +over streamable HTTP only when sent with ``related_request_id`` so the notification rides the +originating request's POST stream rather than the standalone GET stream. These tests pass +``related_request_id`` so no synchronisation is needed. The client-to-server direction is a +standalone notification with no response to await, so that test waits on an event set by the +server's handler. 
""" import anyio @@ -42,9 +44,15 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert ctx.meta is not None token = ctx.meta.get("progress_token") assert token is not None - await ctx.session.send_progress_notification(token, 1.0, total=3.0, message="first chunk") - await ctx.session.send_progress_notification(token, 2.0, total=3.0, message="second chunk") - await ctx.session.send_progress_notification(token, 3.0, total=3.0, message="done") + await ctx.session.send_progress_notification( + token, 1.0, total=3.0, message="first chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 2.0, total=3.0, message="second chunk", related_request_id=str(ctx.request_id) + ) + await ctx.session.send_progress_notification( + token, 3.0, total=3.0, message="done", related_request_id=str(ctx.request_id) + ) return CallToolResult(content=[TextContent(text="downloaded")]) server = Server("downloader", on_list_tools=list_tools, on_call_tool=call_tool) @@ -166,10 +174,14 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara # The two handlers interleave by waiting on alternating turns: a takes 0 and 2, b takes 1 and 3. 
first, second = (0, 2) if label == "a" else (1, 3) await turns[first].wait() - await ctx.session.send_progress_notification(token, progress_values[label][0]) + await ctx.session.send_progress_notification( + token, progress_values[label][0], related_request_id=str(ctx.request_id) + ) turns[first + 1].set() await turns[second].wait() - await ctx.session.send_progress_notification(token, progress_values[label][1]) + await ctx.session.send_progress_notification( + token, progress_values[label][1], related_request_id=str(ctx.request_id) + ) if second + 1 < len(turns): turns[second + 1].set() return CallToolResult(content=[TextContent(text="done")]) @@ -227,7 +239,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara token = ctx.meta.get("progress_token") assert token is not None captured.append((ctx.session, token)) - await ctx.session.send_progress_notification(token, 0.5) + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) return CallToolResult(content=[TextContent(text="done")]) server = Server("reporter", on_list_tools=list_tools, on_call_tool=call_tool) @@ -276,9 +288,9 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert ctx.meta is not None token = ctx.meta.get("progress_token") assert token is not None - await ctx.session.send_progress_notification(token, 0.5) - await ctx.session.send_progress_notification(token, 0.3) - await ctx.session.send_progress_notification(token, 0.9) + await ctx.session.send_progress_notification(token, 0.5, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.3, related_request_id=str(ctx.request_id)) + await ctx.session.send_progress_notification(token, 0.9, related_request_id=str(ctx.request_id)) return CallToolResult(content=[TextContent(text="done")]) server = Server("zigzagger", on_list_tools=list_tools, on_call_tool=call_tool) diff --git 
a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index b6bed63a9c..9c25404e32 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ b/tests/interaction/lowlevel/test_resources.py @@ -2,6 +2,7 @@ import base64 +import anyio import pytest from inline_snapshot import snapshot @@ -38,6 +39,9 @@ async def test_list_resources_returns_registered_resources(connect: Connect) -> """Listed resources reach the client with their URIs, names, and optional descriptive fields intact. The fully-populated entry includes annotations, so the snapshot also proves they round-trip. + The SDK's Annotations model omits the schema's lastModified field (see the divergence on + resources:annotations); the input is built via model_validate with lastModified set so the + snapshot pins the drop and will fail once the SDK adds the field. """ async def list_resources( @@ -53,7 +57,9 @@ async def list_resources( description="The project's front page.", mime_type="text/markdown", size=1024, - annotations=Annotations(audience=["user", "assistant"], priority=0.8), + annotations=Annotations.model_validate( + {"audience": ["user", "assistant"], "priority": 0.8, "lastModified": "2025-01-01T00:00:00Z"} + ), icons=[Icon(src="https://example.com/readme.png", mime_type="image/png", sizes=["48x48"])], ), ] @@ -253,13 +259,17 @@ async def unsubscribe_resource(ctx: ServerRequestContext, params: types.Unsubscr async def test_resource_updated_notification_reaches_client(connect: Connect) -> None: """A resources/updated notification sent during a tool call reaches the client with the resource URI. - The collector records every message the handler receives, so the assertion also proves nothing - else was delivered. + ``send_resource_updated`` does not take a ``related_request_id``, so over streamable HTTP the + notification routes to the standalone GET stream and is not guaranteed to arrive before the + tool result; the test waits on an event the collector sets. 
The collector records every + message the handler receives, so the assertion also proves nothing else was delivered. """ received: list[IncomingMessage] = [] + seen = anyio.Event() async def collect(message: IncomingMessage) -> None: received.append(message) + seen.set() async def list_tools( ctx: ServerRequestContext, params: types.PaginatedRequestParams | None @@ -275,6 +285,8 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async with connect(server, message_handler=collect) as client: await client.call_tool("touch", {}) + with anyio.fail_after(5): + await seen.wait() assert received == snapshot( [ResourceUpdatedNotification(params=ResourceUpdatedNotificationParams(uri="file:///watched.txt"))] diff --git a/tests/interaction/mcpserver/test_tools.py b/tests/interaction/mcpserver/test_tools.py index f8aa208d7f..05135c1286 100644 --- a/tests/interaction/mcpserver/test_tools.py +++ b/tests/interaction/mcpserver/test_tools.py @@ -1,5 +1,6 @@ """Tool interactions against MCPServer, driven through the public Client API.""" +import logging from typing import Annotated, Literal import pytest @@ -308,6 +309,40 @@ def echo_second() -> str: ) +@requirement("mcpserver:tool:naming-validation") +async def test_registering_a_tool_with_a_spec_invalid_name_warns_but_does_not_reject( + connect: Connect, caplog: pytest.LogCaptureFixture +) -> None: + """A tool name that violates the SEP-986 rules logs a warning at registration but is still registered. + + The intended behaviour is rejection at registration time; MCPServer instead logs the + naming-rule violation and proceeds (see the divergence note on the requirement). The warning + spans several SDK-authored log records, so only the stable prefix and inclusion of the + offending name are asserted. 
+ """ + mcp = MCPServer("naming") + + with caplog.at_level(logging.WARNING, logger="mcp.shared.tool_name_validation"): + + @mcp.tool(name="bad name!") + def bad() -> str: + return "ok" + + assert any( + rec.levelno == logging.WARNING + and rec.message.startswith("Tool name validation warning") + and "bad name!" in rec.message + for rec in caplog.records + ) + + async with connect(mcp) as client: + listed = await client.list_tools() + result = await client.call_tool("bad name!", {}) + + assert [tool.name for tool in listed.tools] == ["bad name!"] + assert result == snapshot(CallToolResult(content=[TextContent(text="ok")], structured_content={"result": "ok"})) + + @requirement("mcpserver:tool:url-elicitation-error") async def test_decorated_tool_raising_url_elicitation_required_surfaces_as_error_32042(connect: Connect) -> None: """A decorated tool raising the URL-elicitation-required error reaches the client as error -32042. From 171a01f47c1128247d408c5c498463385fd9cee7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 17:20:05 +0000 Subject: [PATCH 28/34] test: mark replay_sender's stream-id check no-branch for 3.10 coverage The False arc exits through the enclosing async-with and is dropped by coverage.py on 3.10 only; 3.11+ record both arcs. 
--- src/mcp/server/streamable_http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index c85eeeeadf..8b8441e968 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -904,7 +904,7 @@ async def send_event(event_message: EventMessage) -> None: stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: + if stream_id and stream_id not in self._request_streams: # pragma: no branch # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer From 05a41e1fc31772d4fb7d3ca4f6c899de7b147b16 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 17:35:31 +0000 Subject: [PATCH 29/34] test: tighten manifest wording and assertion conventions from review pass - _provider.py docstrings now describe the provider, not its authoring history - eight manifest behavior/divergence notes scoped to what the spec/SDK actually state (logging field-shape only, sse endpoint-event spec-only, scope-selection AS-fallback step, sampling result-balance ValueError, hosting:session:delete doesn't remove transport, low-level elicitation validation/restriction scope, progress no-token second clause) - client-transport:http:reconnect-get deferred reason no longer claims the feature is unimplemented - rename the sampling mixed-content rejection test to match what it asserts - pagination cursor pass-through asserted by identity per the suite convention - mcpserver:prompt:unknown-name docstring acknowledges the code-0 divergence - test_bridge cancel_on_close test bounds the transport-close wait - drop now-stale no-branch pragma at shared/auth.py:94 --- src/mcp/shared/auth.py | 2 +- tests/interaction/_requirements.py | 45 ++++++++++--------- 
tests/interaction/auth/_provider.py | 5 +-- tests/interaction/lowlevel/test_pagination.py | 39 ++++++++-------- tests/interaction/lowlevel/test_sampling.py | 4 +- tests/interaction/mcpserver/test_prompts.py | 6 ++- tests/interaction/transports/test_bridge.py | 4 +- 7 files changed, 56 insertions(+), 49 deletions(-) diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index dd93ad7e17..3b48152d5b 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -91,7 +91,7 @@ def validate_scope(self, requested_scope: str | None) -> list[str] | None: requested_scopes = requested_scope.split(" ") allowed_scopes = [] if self.scope is None else self.scope.split(" ") for scope in requested_scopes: - if scope not in allowed_scopes: # pragma: no branch + if scope not in allowed_scopes: raise InvalidScopeError(f"Client was not registered with scope {scope}") return requested_scopes diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index e0bbd52b4b..d5f16185cc 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -424,10 +424,7 @@ def __post_init__(self) -> None: ), "protocol:progress:no-token": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", - behavior=( - "Without a progress callback no token is attached, and a handler that reports progress anyway " - "sends nothing." - ), + behavior="Without a progress callback the request carries no progress token.", ), "protocol:progress:client-to-server": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/progress#progress-flow", @@ -1079,7 +1076,7 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/server/utilities/logging#log-message-notifications", behavior=( "A log message sent by a server handler is delivered to the client's logging callback with its " - "severity level, logger name, and data, in the order the server sent them." + "severity level, logger name, and data." 
), ), "logging:message:filtered": Requirement( @@ -1219,7 +1216,8 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/client/sampling#tool-use-and-result-balance", behavior=( "Every assistant tool_use block in a sampling request must be matched by a tool_result with " - "the same id in the following user message; an unmatched tool_use is rejected with Invalid params." + "the same id in the following user message; an unmatched tool_use is rejected with a ValueError " + "before the request is sent." ), ), "sampling:tools:server-gated-by-capability": Requirement( @@ -1331,8 +1329,9 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "Nothing restricts or validates the requested-schema shape on the sending side; a server " - "can send nested or non-primitive schemas and the SDK forwards them unchanged." + "ServerSession.elicit_form forwards an arbitrary dict[str, Any] schema unchanged; no shape " + "validation at the low-level session layer (the high-level Context.elicit / " + "elicit_with_validation helper enforces primitive-only fields before generating the schema)." ), ), ), @@ -1343,7 +1342,11 @@ def __post_init__(self) -> None: "the response before sending and the server validates the content it receives." ), divergence=Divergence( - note="Accepted elicitation content passes through unvalidated on both sides.", + note=( + "The client never validates outbound content; ServerSession.elicit_form returns received " + "content unvalidated (the high-level Context.elicit / elicit_with_validation helper " + "validates server-side, but the low-level session API does not)." 
+ ), ), ), "elicitation:url:action:accept-no-content": Requirement( @@ -1788,18 +1791,15 @@ def __post_init__(self) -> None: ), "transport:sse:endpoint-event": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#backwards-compatibility", - behavior=( - "Opening the SSE stream delivers an `endpoint` event naming the message-POST URL and a fresh " - "session identifier; the server registers the session before the event is sent and releases it " - "when the stream disconnects." - ), + behavior="Opening the SSE stream delivers an `endpoint` event naming the message-POST URL as the first event.", transports=("sse",), ), "transport:sse:post:session-routing": Requirement( source="sdk", behavior=( - "A POST to the SSE message endpoint that names no session id, a malformed session id, or an " - "unknown session id is rejected (400/400/404) instead of being forwarded." + "The endpoint URL carries a fresh session identifier; the server registers the session before " + "the endpoint event is sent and releases it when the stream disconnects, and a POST that names " + "no session id, a malformed session id, or an unknown session id is rejected (400/400/404)." ), transports=("sse",), ), @@ -1830,7 +1830,7 @@ def __post_init__(self) -> None: ), "hosting:session:delete": Requirement( source=f"{SPEC_BASE_URL}/basic/transports#session-management", - behavior="DELETE with a valid Mcp-Session-Id terminates the session and removes its transport.", + behavior="DELETE with a valid Mcp-Session-Id terminates the session.", transports=("streamable-http",), ), "hosting:session:id-charset": Requirement( @@ -2333,10 +2333,10 @@ def __post_init__(self) -> None: ), transports=("streamable-http",), deferred=( - "Not implemented in the SDK: the server's standalone GET stream emits no priming event or " - "retry hint, so the client's reconnection path always sleeps the hard-coded 1 s default; a " - "deterministic in-process test would require accepting that real-time wait. 
The POST-stream " - "reconnection path is covered by client-transport:http:reconnect-post-priming." + "The server's standalone GET stream emits no priming event or retry hint, so the client's " + "reconnection path always sleeps the hard-coded 1 s default; a deterministic in-process test " + "would require accepting that real-time wait. The POST-stream reconnection path is covered " + "by client-transport:http:reconnect-post-priming." ), ), "client-transport:http:reconnect-post-priming": Requirement( @@ -2586,7 +2586,8 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic/authorization#scope-selection-strategy", behavior=( "The client selects the requested scope from WWW-Authenticate when present, then from the " - "protected-resource metadata, and otherwise omits scope." + "protected-resource metadata, then (as an SDK addition beyond the spec's chain) from the " + "AS metadata's scopes_supported, and otherwise omits scope." ), transports=("streamable-http",), ), diff --git a/tests/interaction/auth/_provider.py b/tests/interaction/auth/_provider.py index 34b434e4a9..5c88995a30 100644 --- a/tests/interaction/auth/_provider.py +++ b/tests/interaction/auth/_provider.py @@ -5,8 +5,7 @@ values are unique without being predictable. The behaviour mirrors what the SDK's authorization handlers expect: `authorize` immediately mints a code and returns the redirect, `exchange_*` issue and rotate tokens, and `load_*` are simple lookups. Only the parts the auth interaction -suite drives are implemented; methods the tests do not yet reach raise `NotImplementedError` -and are filled in by the chunk that first exercises them. +suite drives are implemented; methods the suite does not exercise raise `NotImplementedError`. 
""" import secrets @@ -183,5 +182,5 @@ async def exchange_refresh_token( ) async def revoke_token(self, token: AccessToken | RefreshToken) -> None: - """Implemented when the bearer/lifecycle tests first exercise revocation.""" + """Not exercised by this suite; revocation is out of scope for the interaction tests.""" raise NotImplementedError diff --git a/tests/interaction/lowlevel/test_pagination.py b/tests/interaction/lowlevel/test_pagination.py index 0c2a0b1588..77db90401e 100644 --- a/tests/interaction/lowlevel/test_pagination.py +++ b/tests/interaction/lowlevel/test_pagination.py @@ -32,6 +32,7 @@ async def test_next_cursor_round_trips_through_the_client(connect: Connect) -> N """The next_cursor a list handler returns reaches the client, and the cursor the client sends back on the following call reaches the handler verbatim. """ + cursor = "page-2" seen_cursors: list[str | None] = [] async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListToolsResult: @@ -40,7 +41,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa if params.cursor is None: return ListToolsResult( tools=[Tool(name="alpha", input_schema={"type": "object"})], - next_cursor="page-2", + next_cursor=cursor, ) return ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})]) @@ -48,13 +49,12 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa async with connect(server) as client: first_page = await client.list_tools() - second_page = await client.list_tools(cursor="page-2") + second_page = await client.list_tools(cursor=first_page.next_cursor) - assert first_page == snapshot( - ListToolsResult(tools=[Tool(name="alpha", input_schema={"type": "object"})], next_cursor="page-2") - ) + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] + assert [tool.name for tool in first_page.tools] == ["alpha"] assert second_page == 
snapshot(ListToolsResult(tools=[Tool(name="beta", input_schema={"type": "object"})])) - assert seen_cursors == snapshot([None, "page-2"]) @requirement("pagination:exhaustion") @@ -158,6 +158,7 @@ async def list_tools(ctx: ServerRequestContext, params: types.PaginatedRequestPa @requirement("resources:list:pagination") async def test_resources_list_supports_cursor_pagination(connect: Connect) -> None: """resources/list round-trips the cursor like every other list operation.""" + cursor = "page-2" seen_cursors: list[str | None] = [] async def list_resources( @@ -166,18 +167,18 @@ async def list_resources( assert params is not None seen_cursors.append(params.cursor) if params.cursor is None: - return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor="page-2") + return ListResourcesResult(resources=[Resource(uri="memo://1", name="first")], next_cursor=cursor) return ListResourcesResult(resources=[Resource(uri="memo://2", name="second")]) server = Server("paginated", on_list_resources=list_resources) async with connect(server) as client: first_page = await client.list_resources() - second_page = await client.list_resources(cursor="page-2") + second_page = await client.list_resources(cursor=first_page.next_cursor) - assert seen_cursors == snapshot([None, "page-2"]) + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] assert [resource.name for resource in first_page.resources] == ["first"] - assert first_page.next_cursor == "page-2" assert [resource.name for resource in second_page.resources] == ["second"] assert second_page.next_cursor is None @@ -185,6 +186,7 @@ async def list_resources( @requirement("resources:templates:pagination") async def test_resource_templates_list_supports_cursor_pagination(connect: Connect) -> None: """resources/templates/list round-trips the cursor like every other list operation.""" + cursor = "page-2" seen_cursors: list[str | None] = [] async def list_resource_templates( @@ -195,7 
+197,7 @@ async def list_resource_templates( if params.cursor is None: return ListResourceTemplatesResult( resource_templates=[ResourceTemplate(name="first", uri_template="users://{id}")], - next_cursor="page-2", + next_cursor=cursor, ) return ListResourceTemplatesResult( resource_templates=[ResourceTemplate(name="second", uri_template="teams://{id}")] @@ -205,11 +207,11 @@ async def list_resource_templates( async with connect(server) as client: first_page = await client.list_resource_templates() - second_page = await client.list_resource_templates(cursor="page-2") + second_page = await client.list_resource_templates(cursor=first_page.next_cursor) - assert seen_cursors == snapshot([None, "page-2"]) + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] assert [template.name for template in first_page.resource_templates] == ["first"] - assert first_page.next_cursor == "page-2" assert [template.name for template in second_page.resource_templates] == ["second"] assert second_page.next_cursor is None @@ -217,23 +219,24 @@ async def list_resource_templates( @requirement("prompts:list:pagination") async def test_prompts_list_supports_cursor_pagination(connect: Connect) -> None: """prompts/list round-trips the cursor like every other list operation.""" + cursor = "page-2" seen_cursors: list[str | None] = [] async def list_prompts(ctx: ServerRequestContext, params: types.PaginatedRequestParams | None) -> ListPromptsResult: assert params is not None seen_cursors.append(params.cursor) if params.cursor is None: - return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor="page-2") + return ListPromptsResult(prompts=[Prompt(name="first")], next_cursor=cursor) return ListPromptsResult(prompts=[Prompt(name="second")]) server = Server("paginated", on_list_prompts=list_prompts) async with connect(server) as client: first_page = await client.list_prompts() - second_page = await client.list_prompts(cursor="page-2") + second_page = await 
client.list_prompts(cursor=first_page.next_cursor) - assert seen_cursors == snapshot([None, "page-2"]) + assert first_page.next_cursor == cursor + assert seen_cursors == [None, cursor] assert [prompt.name for prompt in first_page.prompts] == ["first"] - assert first_page.next_cursor == "page-2" assert [prompt.name for prompt in second_page.prompts] == ["second"] assert second_page.next_cursor is None diff --git a/tests/interaction/lowlevel/test_sampling.py b/tests/interaction/lowlevel/test_sampling.py index 53a246b2e8..8efd5e4c31 100644 --- a/tests/interaction/lowlevel/test_sampling.py +++ b/tests/interaction/lowlevel/test_sampling.py @@ -350,8 +350,8 @@ async def sampling_callback( @requirement("sampling:tool-result:no-mixed-content") -async def test_create_message_with_unbalanced_tool_messages_is_rejected(connect: Connect) -> None: - """A sampling request whose messages mix tool results with other content never leaves the server. +async def test_create_message_with_mixed_tool_result_content_is_rejected(connect: Connect) -> None: + """A sampling request whose user message mixes tool_result with other content never leaves the server. The message-structure validation runs inside create_message before the request is sent, even when no tools are passed, so the client callback is never invoked and the handler observes the diff --git a/tests/interaction/mcpserver/test_prompts.py b/tests/interaction/mcpserver/test_prompts.py index ddea4d8278..2095f086d4 100644 --- a/tests/interaction/mcpserver/test_prompts.py +++ b/tests/interaction/mcpserver/test_prompts.py @@ -75,7 +75,11 @@ def greet(name: str) -> str: @requirement("mcpserver:prompt:unknown-name") async def test_get_unknown_prompt_is_error(connect: Connect) -> None: - """Getting a prompt name that was never registered fails with a JSON-RPC error.""" + """Getting a prompt name that was never registered fails with a JSON-RPC error. 
+ + The spec reserves -32602 for this case; the SDK reports code 0 (see the divergence note on + the requirement). + """ mcp = MCPServer("prompter") @mcp.prompt() diff --git a/tests/interaction/transports/test_bridge.py b/tests/interaction/transports/test_bridge.py index b1a42543f4..7420b9d902 100644 --- a/tests/interaction/transports/test_bridge.py +++ b/tests/interaction/transports/test_bridge.py @@ -86,8 +86,8 @@ async def lingering_app(scope: Scope, receive: Receive, send: Send) -> None: cleanup_ran.set() transport = StreamingASGITransport(lingering_app, cancel_on_close=False) - async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: - with anyio.fail_after(5): + with anyio.fail_after(5): + async with httpx.AsyncClient(transport=transport, base_url="http://bridge") as http: async with http.stream("GET", "/linger") as response: assert response.status_code == 200 assert not cleanup_ran.is_set() From ca0ba11ba2bf440c4e4fb33e7156970a92a38bac Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 17:46:56 +0000 Subject: [PATCH 30/34] test: declare capabilities the notification tests rely on Nine notification tests sent capability-gated messages (log, resource_updated, roots/list_changed, *_list_changed) from a peer that had not declared the capability, which only works because the SDK does not yet enforce capability gating. Adding stub handlers / list_roots_callback so the capability is advertised makes the tests match their requirement preconditions and survive the gate fix unchanged. The three list_changed tests cannot set listChanged=True without threading NotificationOptions through the in-memory and HTTP-manager connection paths; the requirement behavior text now describes what the tests prove (send -> arrives) and the module docstring records the remaining coupling on lifecycle:capability:server-not-advertised. 
--- tests/interaction/_requirements.py | 12 ++++----- tests/interaction/lowlevel/test_flows.py | 12 ++++++++- .../interaction/lowlevel/test_list_changed.py | 22 ++++++++++++++-- tests/interaction/lowlevel/test_logging.py | 12 +++++++-- tests/interaction/lowlevel/test_resources.py | 18 ++++++++++++- tests/interaction/lowlevel/test_roots.py | 6 ++++- tests/interaction/lowlevel/test_wire.py | 10 ++++++-- tests/interaction/transports/_stdio_server.py | 9 ++++++- .../transports/test_hosting_http.py | 25 ++++++++++++++++++- 9 files changed, 109 insertions(+), 17 deletions(-) diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index d5f16185cc..f11d64abfa 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -586,8 +586,8 @@ def __post_init__(self) -> None: "tools:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/tools#list-changed-notification", behavior=( - "When the tool set changes, a server that declared the tools listChanged capability sends " - "notifications/tools/list_changed and it reaches the client's handler." + "When the tool set changes, the server sends notifications/tools/list_changed and it reaches " + "the client's handler." ), ), "tools:list:basic": Requirement( @@ -787,8 +787,8 @@ def __post_init__(self) -> None: "resources:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/resources#list-changed-notification", behavior=( - "When the resource set changes, a server that declared the resources listChanged capability " - "sends notifications/resources/list_changed and it reaches the client's handler." + "When the resource set changes, the server sends notifications/resources/list_changed and it " + "reaches the client's handler." 
), ), "resources:list:basic": Requirement( @@ -959,8 +959,8 @@ def __post_init__(self) -> None: "prompts:list-changed": Requirement( source=f"{SPEC_BASE_URL}/server/prompts#list-changed-notification", behavior=( - "When the prompt set changes, a server that declared the prompts listChanged capability sends " - "notifications/prompts/list_changed and it reaches the client's handler." + "When the prompt set changes, the server sends notifications/prompts/list_changed and it " + "reaches the client's handler." ), ), "prompts:list:basic": Requirement( diff --git a/tests/interaction/lowlevel/test_flows.py b/tests/interaction/lowlevel/test_flows.py index 8ff9dd4f1d..8d96582341 100644 --- a/tests/interaction/lowlevel/test_flows.py +++ b/tests/interaction/lowlevel/test_flows.py @@ -23,6 +23,7 @@ ElicitRequestFormParams, ElicitRequestURLParams, ElicitResult, + EmptyResult, ListToolsResult, ReadResourceResult, ResourceLink, @@ -167,7 +168,16 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) return CallToolResult(content=[TextContent(text="contents")]) - server = Server("gatekeeper", on_list_tools=_list_tools("read_files"), on_call_tool=call_tool) + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server( + "gatekeeper", + on_list_tools=_list_tools("read_files"), + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + ) async def collect(message: IncomingMessage) -> None: if isinstance(message, ElicitCompleteNotification): diff --git a/tests/interaction/lowlevel/test_list_changed.py b/tests/interaction/lowlevel/test_list_changed.py index 0a681fffa7..a2f85eeacf 100644 --- a/tests/interaction/lowlevel/test_list_changed.py +++ b/tests/interaction/lowlevel/test_list_changed.py @@ -6,6 +6,12 @@ as 
``transports/test_streamable_http.py::test_unrelated_server_messages_arrive_on_the_standalone_stream``. The collector still records every message it receives, so the snapshot also proves nothing else was delivered. + +The servers register the parent capability (resources/prompts) so that part of the spec's +precondition holds, but the ``listChanged`` sub-capability stays ``False``: ``NotificationOptions`` +is not threaded through any of the suite's connection paths. The tests therefore rely on the +recorded ``lifecycle:capability:server-not-advertised`` divergence and will need updating +alongside the fix that introduces capability gating. """ import anyio @@ -78,7 +84,13 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_resource_list_changed() return CallToolResult(content=[TextContent(text="mounted")]) - server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool, on_list_resources=list_resources) async with connect(server, message_handler=collect) as client: await client.call_tool("mount", {}) @@ -108,7 +120,13 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_prompt_list_changed() return CallToolResult(content=[TextContent(text="learned")]) - server = Server("registry", on_list_tools=list_tools, on_call_tool=call_tool) + async def list_prompts( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + """Registered so the prompts capability is advertised; the client never lists prompts.""" + raise NotImplementedError + + server = Server("registry", 
on_list_tools=list_tools, on_call_tool=call_tool, on_list_prompts=list_prompts) async with connect(server, message_handler=collect) as client: await client.call_tool("learn", {}) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index a7b2372083..fba632ef4d 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -76,7 +76,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) return CallToolResult(content=[TextContent(text="done")]) - server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) async with connect(server, logging_callback=collect) as client: result = await client.call_tool("chatty", {}) @@ -111,7 +115,11 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) return CallToolResult(content=[TextContent(text="logged")]) - server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool) + async def set_logging_level(ctx: ServerRequestContext, params: types.SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + server = Server("logger", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) async with connect(server, logging_callback=collect) as client: await client.call_tool("siren", {}) diff --git a/tests/interaction/lowlevel/test_resources.py b/tests/interaction/lowlevel/test_resources.py index 9c25404e32..4e369d3645 100644 --- a/tests/interaction/lowlevel/test_resources.py +++ 
b/tests/interaction/lowlevel/test_resources.py @@ -281,7 +281,23 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara await ctx.session.send_resource_updated("file:///watched.txt") return CallToolResult(content=[TextContent(text="touched")]) - server = Server("library", on_list_tools=list_tools, on_call_tool=call_tool) + async def list_resources( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: types.SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + server = Server( + "library", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) async with connect(server, message_handler=collect) as client: await client.call_tool("touch", {}) diff --git a/tests/interaction/lowlevel/test_roots.py b/tests/interaction/lowlevel/test_roots.py index 577b99819c..8149e0befb 100644 --- a/tests/interaction/lowlevel/test_roots.py +++ b/tests/interaction/lowlevel/test_roots.py @@ -154,7 +154,11 @@ async def roots_list_changed(ctx: ServerRequestContext, params: types.Notificati server = Server("rooted", on_roots_list_changed=roots_list_changed) - async with connect(server) as client: + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + + async with connect(server, list_roots_callback=list_roots) as client: await client.send_roots_list_changed() with anyio.fail_after(5): await delivered.wait() diff --git a/tests/interaction/lowlevel/test_wire.py 
b/tests/interaction/lowlevel/test_wire.py index 1a5d32129d..0f9c58aa7a 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -15,7 +15,7 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client import ClientSession +from mcp.client import ClientRequestContext, ClientSession from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext @@ -33,6 +33,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ListRootsResult, TextContent, ) from tests.interaction._helpers import RecordingTransport, _RecordingReadStream @@ -87,9 +88,14 @@ async def test_notifications_are_never_answered() -> None: the messages received from the server must be exactly one response per request, each carrying the id of the request it answers, and nothing else. """ + + async def list_roots(context: ClientRequestContext) -> ListRootsResult: + """Registered so the client declares the roots capability; the server never asks for roots.""" + raise NotImplementedError + recording = RecordingTransport(InMemoryTransport(_echo_server())) - async with Client(recording) as client: + async with Client(recording, list_roots_callback=list_roots) as client: await client.send_roots_list_changed() await client.send_ping() diff --git a/tests/interaction/transports/_stdio_server.py b/tests/interaction/transports/_stdio_server.py index fbe7e614f7..5977cc3e99 100644 --- a/tests/interaction/transports/_stdio_server.py +++ b/tests/interaction/transports/_stdio_server.py @@ -15,8 +15,10 @@ from mcp.types import ( CallToolRequestParams, CallToolResult, + EmptyResult, ListToolsResult, PaginatedRequestParams, + SetLevelRequestParams, TextContent, Tool, ) @@ -41,7 +43,12 @@ async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> return CallToolResult(content=[TextContent(text=text)]) -server = Server("stdio-echo", on_list_tools=list_tools, 
on_call_tool=call_tool) +async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + +server = Server("stdio-echo", on_list_tools=list_tools, on_call_tool=call_tool, on_set_logging_level=set_logging_level) async def main() -> None: diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index f842f4083e..62fb04b914 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -19,12 +19,16 @@ PARSE_ERROR, CallToolRequestParams, CallToolResult, + EmptyResult, JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ListResourcesResult, ListToolsResult, PaginatedRequestParams, + SetLevelRequestParams, + SubscribeRequestParams, TextContent, ) from tests.interaction._connect import ( @@ -52,7 +56,26 @@ async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> await ctx.session.send_resource_updated("file:///watched.txt") return CallToolResult(content=[TextContent(text="done")]) - return Server("hosted", on_list_tools=list_tools, on_call_tool=call_tool) + async def set_logging_level(ctx: ServerRequestContext, params: SetLevelRequestParams) -> EmptyResult: + """Registered so the logging capability is advertised; the client never sets a level.""" + raise NotImplementedError + + async def list_resources(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListResourcesResult: + """Registered so the resources capability is advertised; the client never lists resources.""" + raise NotImplementedError + + async def subscribe_resource(ctx: ServerRequestContext, params: SubscribeRequestParams) -> EmptyResult: + """Registered so the resources subscribe sub-capability is advertised; the client never subscribes.""" + raise NotImplementedError + + 
return Server( + "hosted", + on_list_tools=list_tools, + on_call_tool=call_tool, + on_set_logging_level=set_logging_level, + on_list_resources=list_resources, + on_subscribe_resource=subscribe_resource, + ) @requirement("hosting:http:method-405") From 2d621242e49de38e8a6c39a4575de50e5122848f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 20:25:28 +0000 Subject: [PATCH 31/34] test: prove json-response Content-Type and explicit resumption-token API at the wire --- src/mcp/server/streamable_http.py | 2 +- tests/interaction/_connect.py | 2 + .../transports/test_hosting_http.py | 24 +++++ .../transports/test_hosting_resume.py | 87 ++++++++++++++++++- .../transports/test_streamable_http.py | 1 - 5 files changed, 111 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8b8441e968..f2f4407cea 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -1042,7 +1042,7 @@ async def message_router(): yield read_stream, write_stream finally: for stream_id in list(self._request_streams.keys()): - await self._clean_up_memory_streams(stream_id) # pragma: no cover + await self._clean_up_memory_streams(stream_id) self._request_streams.clear() # Clean up the read and write streams diff --git a/tests/interaction/_connect.py b/tests/interaction/_connect.py index 9c71acee9b..1faf4aa8d6 100644 --- a/tests/interaction/_connect.py +++ b/tests/interaction/_connect.py @@ -150,6 +150,7 @@ async def mounted_app( server: Server | MCPServer, *, stateless_http: bool = False, + json_response: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, transport_security: TransportSecuritySettings | None = NO_DNS_REBINDING_PROTECTION, @@ -174,6 +175,7 @@ async def mounted_app( lowlevel = server._lowlevel_server if isinstance(server, MCPServer) else server app = lowlevel.streamable_http_app( 
stateless_http=stateless_http, + json_response=json_response, event_store=event_store, retry_interval=retry_interval, transport_security=transport_security, diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 62fb04b914..85e64ded42 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -179,6 +179,30 @@ async def test_protocol_version_header_is_validated() -> None: assert defaulted.status_code == 202 +@requirement("hosting:http:json-response-mode") +async def test_json_response_mode_answers_with_application_json_not_sse() -> None: + """With JSON response mode enabled, request POSTs are answered with a single application/json body. + + Asserted at the wire level because the SDK client parses either representation, so a + Client-driven round trip cannot distinguish a JSON response from an SSE one. + """ + async with mounted_app(_server(), json_response=True) as (http, _): + initialized = await http.post("/mcp", json=initialize_body(), headers=base_headers()) + session_id = initialized.headers["mcp-session-id"] + ping = await http.post( + "/mcp", + json={"jsonrpc": "2.0", "id": 2, "method": "ping"}, + headers=base_headers(session_id=session_id), + ) + + assert initialized.status_code == 200 + assert initialized.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(initialized.json()).id == 1 + assert ping.status_code == 200 + assert ping.headers["content-type"].split(";", 1)[0] == "application/json" + assert JSONRPCResponse.model_validate(ping.json()).id == 2 + + @requirement("hosting:http:notifications-202") async def test_notification_post_returns_202_with_no_body() -> None: """A POST containing only a notification (no request ID) returns 202 Accepted with no body.""" diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py 
index 06bffed27c..6ab9ff4b3f 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -17,8 +17,14 @@ from httpx_sse import EventSource, ServerSentEvent from inline_snapshot import snapshot +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamable_http_client from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.message import ClientMessageMetadata from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolRequest, + CallToolRequestParams, CallToolResult, JSONRPCNotification, JSONRPCRequest, @@ -28,6 +34,7 @@ jsonrpc_message_adapter, ) from tests.interaction._connect import ( + BASE_URL, base_headers, connect_over_streamable_http, initialize_via_http, @@ -229,13 +236,12 @@ async def hold(ctx: Context) -> str: await finished.wait() -# This test intentionally carries every resumability requirement: the close-then-resume -# scenario is indivisible, so splitting it would mean six near-identical bodies. +# This test intentionally carries every automatic-reconnection requirement: the +# close-then-resume scenario is indivisible, so splitting it would mean five near-identical bodies. @requirement("hosting:resume:close-stream") @requirement("transport:streamable-http:resumability") @requirement("client-transport:http:reconnect-post-priming") @requirement("client-transport:http:reconnect-retry-value") -@requirement("client-transport:http:resume-stream-api") @requirement("flow:resume:tool-call-resumption-token") async def test_a_call_whose_stream_the_server_closes_is_resumed_by_the_client() -> None: """A server-closed request stream is reconnected by the client and the call completes. 
@@ -288,3 +294,78 @@ async def call() -> None: [CallToolResult(content=[TextContent(text="resumed")], structured_content={"result": "resumed"})] ) assert received == snapshot(["before close", "after close"]) + + +@requirement("client-transport:http:resume-stream-api") +async def test_a_captured_resumption_token_replays_missed_messages_on_a_new_connection() -> None: + """A resumption token captured via on_resumption_token_update on one connection lets a fresh + connection retrieve the messages it missed by passing resumption_token to send_request. + + This is the explicit ClientMessageMetadata API, distinct from the automatic reconnection the + previous test covers: the transport dispatches a resumption_token request as a GET with + Last-Event-ID instead of POSTing the body, and remaps the replayed response onto the new + request's id. Client.call_tool does not expose ClientMessageMetadata, so the test drives a + bare ClientSession via session.send_request -- the sanctioned drop-down for behaviour Client + cannot express. The second connection carries the original session id but does not initialize + (the server-side session already is), modelling a caller that resumes after a process restart. 
+ """ + captured: list[str] = [] + received: list[object] = [] + first_seen = anyio.Event() + token_seen = anyio.Event() + release = anyio.Event() + store = SequencedEventStore() + + mcp = MCPServer("resumable") + + @mcp.tool() + async def hold(ctx: Context) -> str: + """Emit one notification, wait for the test, emit another, return.""" + await ctx.info("first") + await release.wait() + await ctx.info("second") + return "done" + + async def on_token(token: str) -> None: + captured.append(token) + if len(captured) >= 2: + token_seen.set() + + async def collect(params: LoggingMessageNotificationParams) -> None: + received.append(params.data) + first_seen.set() + + call = CallToolRequest(params=CallToolRequestParams(name="hold", arguments={})) + capture = ClientMessageMetadata(on_resumption_token_update=on_token) + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): + with anyio.fail_after(5): # pragma: no branch + async with ( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), + ClientSession(r1, w1, logging_callback=collect) as first, + anyio.create_task_group() as tg, + ): + await first.initialize() + tg.start_soon(first.send_request, call, CallToolResult, None, capture) + await first_seen.wait() + await token_seen.wait() + tg.cancel_scope.cancel() + assert captured == snapshot(["3", "4"]) + assert received == snapshot(["first"]) + # The session id is only observable via the manager (the client transport does not expose it). + (session_id,) = manager._server_instances + + release.set() + # init priming + init response + call priming + "first" + "second" + result = 6 stored events. 
+ await store.wait_until_stored(6) + http.headers["mcp-session-id"] = session_id + http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + async with ( + streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), + ClientSession(r2, w2, logging_callback=collect) as second, + ): + result = await second.send_request( + call, CallToolResult, metadata=ClientMessageMetadata(resumption_token=captured[-1]) + ) + assert result == snapshot(CallToolResult(content=[TextContent(text="done")], structured_content={"result": "done"})) + assert received == snapshot(["first", "second"]) diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index 72af075770..d38e2a0bb3 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -63,7 +63,6 @@ async def announce(ctx: Context) -> str: @requirement("transport:streamable-http:json-response") -@requirement("hosting:http:json-response-mode") @requirement("client-transport:http:json-response-parsed") async def test_tool_call_over_streamable_http_with_json_responses() -> None: """The round trip works when the server answers with a single JSON body instead of an SSE stream.""" From 5e129bf58bcc65a4b807050d9fca6e428b28ce86 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 20:41:39 +0000 Subject: [PATCH 32/34] test: cancel only the abandoned call so 3.11/3.14 trace the resumption test cleanly --- .../interaction/transports/test_hosting_resume.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index 6ab9ff4b3f..dcab8ae3b5 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -346,10 +346,18 @@ async def collect(params: 
LoggingMessageNotificationParams) -> None: anyio.create_task_group() as tg, ): await first.initialize() - tg.start_soon(first.send_request, call, CallToolResult, None, capture) + # The call is abandoned via its own scope so the task group exits cleanly: cancelling + # the whole group propagates Cancelled through this frame, which 3.11's tracer mishandles. + call_scope = anyio.CancelScope() + + async def issue_call() -> None: + with call_scope: + await first.send_request(call, CallToolResult, metadata=capture) + + tg.start_soon(issue_call) await first_seen.wait() await token_seen.wait() - tg.cancel_scope.cancel() + call_scope.cancel() assert captured == snapshot(["3", "4"]) assert received == snapshot(["first"]) # The session id is only observable via the manager (the client transport does not expose it). @@ -360,7 +368,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: await store.wait_until_stored(6) http.headers["mcp-session-id"] = session_id http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION - async with ( + async with ( # pragma: no branch streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), ClientSession(r2, w2, logging_callback=collect) as second, ): From 2ee59b7b88e16c1bb3b1c13f7e50865fadc5f165 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 20:59:01 +0000 Subject: [PATCH 33/34] test: restructure resumption test so 3.11 traces every line The previous attempt cancelled only the child task, but ClientSession and streamable_http_client both cancel internal task groups in __aexit__, so the comma-form unwind still tripped 3.11's tracer dead-zone. Hoist the phase-1 assertions and header setup inside the block before cancelling, and split fail_after into two phases so no sync statements sit between the cancel-on-exit unwind and the next await. 
--- .../transports/test_hosting_resume.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index dcab8ae3b5..b02927a17e 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -340,34 +340,27 @@ async def collect(params: LoggingMessageNotificationParams) -> None: async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, manager): with anyio.fail_after(5): # pragma: no branch - async with ( + async with ( # pragma: no branch streamable_http_client(f"{BASE_URL}/mcp", http_client=http, terminate_on_close=False) as (r1, w1), ClientSession(r1, w1, logging_callback=collect) as first, anyio.create_task_group() as tg, ): await first.initialize() - # The call is abandoned via its own scope so the task group exits cleanly: cancelling - # the whole group propagates Cancelled through this frame, which 3.11's tracer mishandles. - call_scope = anyio.CancelScope() - - async def issue_call() -> None: - with call_scope: - await first.send_request(call, CallToolResult, metadata=capture) - - tg.start_soon(issue_call) + tg.start_soon(first.send_request, call, CallToolResult, None, capture) await first_seen.wait() await token_seen.wait() - call_scope.cancel() - assert captured == snapshot(["3", "4"]) - assert received == snapshot(["first"]) - # The session id is only observable via the manager (the client transport does not expose it). - (session_id,) = manager._server_instances + assert captured == snapshot(["3", "4"]) + assert received == snapshot(["first"]) + # The session id is only observable via the manager (the client transport does not expose it). 
+ (session_id,) = manager._server_instances + http.headers["mcp-session-id"] = session_id + http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION + tg.cancel_scope.cancel() + with anyio.fail_after(5): # pragma: no branch release.set() # init priming + init response + call priming + "first" + "second" + result = 6 stored events. await store.wait_until_stored(6) - http.headers["mcp-session-id"] = session_id - http.headers["mcp-protocol-version"] = LATEST_PROTOCOL_VERSION async with ( # pragma: no branch streamable_http_client(f"{BASE_URL}/mcp", http_client=http) as (r2, w2), ClientSession(r2, w2, logging_callback=collect) as second, From 93cc828074b8da7a46ab29bcbd43a47bd905383b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 27 May 2026 21:35:33 +0000 Subject: [PATCH 34/34] test: mark the resumption-token test's post-exit line lax-no-cover for 3.11 CPython gh-106749 (wontfix on 3.11): coro.throw() omits the call trace event, desyncing coverage.py's CTracer so the first plain statement after a ClientSession/streamable_http_client __aexit__ is dropped. Restructures only relocate the miss (verified empirically across five variants); name the upstream bug and sanction lax-no-cover for this case in the README. --- tests/interaction/README.md | 12 ++++++++---- tests/interaction/transports/test_hosting_resume.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/interaction/README.md b/tests/interaction/README.md index 23a308a0ea..be68c3b0f1 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -209,10 +209,14 @@ assert after the call, with no synchronisation. The exceptions: CI requires 100% line and branch coverage, including `tests/`, and `strict-no-cover` fails the build if a line marked `# pragma: no cover` is ever executed. When a new test starts covering a pragma'd line in `src/`, delete the pragma in the same change. 
Do not add new `# type: ignore` or -`# noqa` comments; restructure instead. The one sanctioned pragma in this suite's test code is -`# pragma: no branch` on a `with`/`async with` line whose only fault is coverage.py mis-tracing -the exit arc of a nested async context — restructure first, and reserve the pragma for shapes -that cannot collapse (a sync `with` adjacent to an `async with`). +`# noqa` comments; restructure instead. Two pragmas are sanctioned in this suite's test code, both +for known-upstream tracer bugs and only after restructuring has been tried: `# pragma: no branch` +on a `with`/`async with` line whose only fault is coverage.py mis-tracing the exit arc of a nested +async context (reserve it for shapes that cannot collapse — a sync `with` adjacent to an +`async with`); and `# pragma: lax no cover` on a single statement that 3.11's tracer drops because +the preceding `async with` unwinds via `coro.throw()` (python/cpython#106749, wontfix on 3.11) — +this hits any test that must run statements after a `ClientSession`/`streamable_http_client` exits +but still inside an outer `async with`, and no restructure can avoid it. A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose execution is timing-dependent under the in-process HTTP bridge — the POST-stream and diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index b02927a17e..c7945d56c3 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -358,7 +358,7 @@ async def collect(params: LoggingMessageNotificationParams) -> None: tg.cancel_scope.cancel() with anyio.fail_after(5): # pragma: no branch - release.set() + release.set() # pragma: lax no cover — python/cpython#106749: 3.11 drops this line event # init priming + init response + call priming + "first" + "second" + result = 6 stored events. 
await store.wait_until_stored(6) async with ( # pragma: no branch