Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,12 @@ def __init__(self, url: str) -> None:
url: The endpoint URL.
"""
self.url = url
parsed_url = httpx.URL(url)
self.origin = f"{parsed_url.scheme}://{parsed_url.netloc.decode()}" if parsed_url.netloc else None
self.session_id: str | None = None
self.protocol_version: str | None = None

def _prepare_headers(self) -> dict[str, str]:
def _prepare_headers(self, client: httpx.AsyncClient | None = None) -> dict[str, str]:
"""Build MCP-specific request headers.

These headers will be merged with the httpx.AsyncClient's default headers,
Expand All @@ -92,6 +94,8 @@ def _prepare_headers(self) -> dict[str, str]:
"accept": "application/json, text/event-stream",
"content-type": "application/json",
}
if self.origin and (client is None or "origin" not in client.headers):
headers["origin"] = self.origin
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
Expand Down Expand Up @@ -189,7 +193,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
if not self.session_id:
return

headers = self._prepare_headers()
headers = self._prepare_headers(client)
if last_event_id:
headers[LAST_EVENT_ID] = last_event_id

Expand Down Expand Up @@ -225,7 +229,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:

async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._prepare_headers()
headers = self._prepare_headers(ctx.client)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
Expand Down Expand Up @@ -253,7 +257,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_headers()
headers = self._prepare_headers(ctx.client)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

Expand Down Expand Up @@ -388,7 +392,7 @@ async def _handle_reconnection(
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
await anyio.sleep(delay_ms / 1000.0)

headers = self._prepare_headers()
headers = self._prepare_headers(ctx.client)
headers[LAST_EVENT_ID] = last_event_id

# Extract original request ID to map responses
Expand Down Expand Up @@ -496,7 +500,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
return # pragma: no cover

try:
headers = self._prepare_headers()
headers = self._prepare_headers(client)
response = await client.delete(self.url, headers=headers)

if response.status_code == 405:
Expand Down
36 changes: 36 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,3 +2318,39 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(

assert "content-type" in headers_data
assert headers_data["content-type"] == "application/json"


@pytest.mark.anyio
async def test_streamable_http_client_adds_origin_header(context_aware_server: None, basic_server_url: str) -> None:
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
await session.initialize()

tool_result = await session.call_tool("echo_headers", {})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
headers_data = json.loads(tool_result.content[0].text)

assert headers_data["origin"] == basic_server_url


@pytest.mark.anyio
async def test_streamable_http_client_preserves_custom_origin_header(
context_aware_server: None, basic_server_url: str
) -> None:
custom_origin = "https://proxy.example"

async with create_mcp_http_client(headers={"Origin": custom_origin}) as httpx_client:
async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as (
read_stream,
write_stream,
):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
await session.initialize()

tool_result = await session.call_tool("echo_headers", {})
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
headers_data = json.loads(tool_result.content[0].text)

assert headers_data["origin"] == custom_origin
Loading