Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
MCP_METHOD = "mcp-method"
MCP_NAME = "mcp-name"
LAST_EVENT_ID = "last-event-id"

# Reconnection defaults
Expand Down Expand Up @@ -82,7 +84,7 @@ def __init__(self, url: str) -> None:
self.session_id: str | None = None
self.protocol_version: str | None = None

def _prepare_headers(self) -> dict[str, str]:
def _prepare_headers(self, message: JSONRPCMessage | None = None) -> dict[str, str]:
"""Build MCP-specific request headers.

These headers will be merged with the httpx.AsyncClient's default headers,
Expand All @@ -97,8 +99,28 @@ def _prepare_headers(self) -> dict[str, str]:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
if isinstance(message, JSONRPCRequest | JSONRPCNotification):
headers[MCP_METHOD] = message.method
if mcp_name := self._get_mcp_name(message):
headers[MCP_NAME] = mcp_name
return headers

def _get_mcp_name(self, message: JSONRPCRequest | JSONRPCNotification) -> str | None:
params = message.params
if not isinstance(params, dict):
return None

if message.method in {"tools/call", "prompts/get"}:
value = params.get("name")
elif message.method in {"resources/read", "resources/subscribe", "resources/unsubscribe"}:
value = params.get("uri")
else:
return None

if value is None:
return None
return str(value)

def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message, JSONRPCRequest) and message.method == "initialize"
Expand Down Expand Up @@ -253,8 +275,8 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_headers()
message = ctx.session_message.message
headers = self._prepare_headers(message)
is_initialization = self._is_initialization_request(message)

async with ctx.client.stream(
Expand Down
27 changes: 27 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1718,6 +1718,33 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv
assert response.status_code == 200


@pytest.mark.parametrize(
("method", "params", "expected_name"),
[
("tools/call", {"name": "echo_headers"}, "echo_headers"),
("prompts/get", {"name": "summarize"}, "summarize"),
("resources/read", {"uri": "file:///tmp/readme.md"}, "file:///tmp/readme.md"),
("resources/subscribe", {"uri": "file:///tmp/readme.md"}, "file:///tmp/readme.md"),
("resources/unsubscribe", {"uri": "file:///tmp/readme.md"}, "file:///tmp/readme.md"),
("tools/call", {}, None),
("resources/read", {}, None),
("tools/list", {}, None),
],
)
def test_streamable_http_client_adds_sep_2243_headers(method: str, params: dict[str, Any], expected_name: str | None):
"""POST requests include SEP-2243 method/name headers."""
transport = StreamableHTTPTransport("https://example.com/mcp")
message = JSONRPCRequest(jsonrpc="2.0", id=1, method=method, params=params)

headers = transport._prepare_headers(message)

assert headers["mcp-method"] == method
if expected_name is None:
assert "mcp-name" not in headers
else:
assert headers["mcp-name"] == expected_name


def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str):
"""Test server accepts requests without protocol version header."""
# First initialize a session to get a valid session ID
Expand Down
Loading