diff --git a/README.v2.md b/README.v2.md index d0851c04e..1a6b4562c 100644 --- a/README.v2.md +++ b/README.v2.md @@ -1371,6 +1371,28 @@ This configuration is necessary because: - Browsers restrict access to response headers unless explicitly exposed via CORS - Without this configuration, browser-based clients won't be able to read the session ID from initialization responses +#### Reverse Proxy Host Headers + +DNS rebinding protection checks the incoming `Host` header when transport security is enabled. If your server is behind +nginx, Cloudflare, or another reverse proxy, include the public hostname in `TransportSecuritySettings.allowed_hosts`. +Some proxies preserve the port, so include both forms when needed: + +```python +from mcp.server.transport_security import TransportSecuritySettings + +transport_security = TransportSecuritySettings( + allowed_hosts=[ + "mcp.example.com", + "mcp.example.com:443", + ], +) + +mcp_app = server.streamable_http_app(transport_security=transport_security) +``` + +If a request is rejected by this check, the server returns HTTP 421 with `host_not_allowed`, the received host, and the +setting to configure. + ### Mounting to an Existing ASGI Server By default, SSE servers are mounted at `/sse` and Streamable HTTP servers are mounted at `/mcp`. You can customize these paths using the methods described below. diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index d9e9f965b..0229d8745 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import JSONResponse, Response logger = logging.getLogger(__name__) @@ -106,7 +106,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res # Validate Host header host = request.headers.get("host") if not self._validate_host(host): - return Response("Invalid Host header", status_code=421) + return JSONResponse( + { + "error": "host_not_allowed", + "received_host": host, + "configure": "TransportSecuritySettings.allowed_hosts", + }, + status_code=421, + ) # Validate Origin header origin = request.headers.get("origin") diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 85e64ded4..fa0d96ece 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -330,7 +330,16 @@ async def test_origin_validation_rejects_disallowed_origins_when_enabled() -> No assert [event async for event in ok.aiter_sse()] assert (bad_origin.status_code, bad_origin.text) == snapshot((403, "Invalid Origin header")) - assert (bad_host.status_code, bad_host.text) == snapshot((421, "Invalid Host header")) + assert (bad_host.status_code, bad_host.json()) == snapshot( + ( + 421, + { + "error": "host_not_allowed", + "received_host": "evil.example", + "configure": "TransportSecuritySettings.allowed_hosts", + }, + ) + ) async with mounted_app( Server("unguarded"), transport_security=TransportSecuritySettings(enable_dns_rebinding_protection=False) diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index e95dc51b3..338338e78 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -122,7 +122,11 @@ async def test_sse_security_invalid_host_header(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } finally: process.terminate() @@ -232,7 +236,11 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } finally: process.terminate() diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353..ed5e20e3a 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -126,7 +126,11 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): headers=headers, ) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } finally: process.terminate() @@ -269,7 +273,11 @@ async def test_streamable_http_security_get_request(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert response.json() == { + "error": "host_not_allowed", + "received_host": "evil.com", + "configure": "TransportSecuritySettings.allowed_hosts", + } # Test GET request with valid host header headers = { diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py index be28980b5..cf9870d49 100644 --- a/tests/server/test_transport_security.py +++ b/tests/server/test_transport_security.py @@ -48,6 +48,20 @@ async def test_validate_request_checks_host_then_origin( assert (None if response is None else response.status_code) == expected +@pytest.mark.anyio +async def test_validate_request_explains_host_rejection() -> None: + middleware = TransportSecurityMiddleware(SETTINGS) + response = await middleware.validate_request(_request("evil.example", None)) + + assert response is not None + assert response.status_code == 421 + assert response.media_type == "application/json" + assert response.body == ( + b'{"error":"host_not_allowed","received_host":"evil.example",' + b'"configure":"TransportSecuritySettings.allowed_hosts"}' + ) + + @pytest.mark.anyio async def test_validate_request_skips_host_and_origin_when_protection_is_disabled() -> None: """With DNS-rebinding protection off, any Host/Origin is accepted."""