diff --git a/docs/concepts.md b/docs/concepts.md index a2d6eb8d3a..7354a9b30a 100644 --- a/docs/concepts.md +++ b/docs/concepts.md @@ -11,3 +11,98 @@ - Context and sessions - Lifecycle and state --> + +## Transport Security + +MCP servers that use HTTP transports (SSE or Streamable HTTP) include DNS rebinding +protection via `TransportSecuritySettings`. This guards against attacks where a malicious +page tricks a browser into making requests to a locally running MCP server by spoofing the +`Host` header. + +### Default behavior + +- **Streamable HTTP** (`streamable_http_app()`) enables protection by default. +- **SSE** (`sse_app()`) disables protection by default for backwards compatibility. +- **stdio** transport is unaffected — it has no network surface. + +### Configuring allowed hosts + +Set `allowed_hosts` to the hostname(s) your server is reachable at: + +```python +from mcp.server.mcpserver import MCPServer +from mcp.server.transport_security import TransportSecuritySettings + +mcp = MCPServer("My Server") + +security = TransportSecuritySettings( + allowed_hosts=["myserver.example.com"], +) + +app = mcp.streamable_http_app(transport_security=security) +``` + +If `allowed_hosts` is empty while protection is enabled, **all requests will be rejected +with HTTP 421**. A warning is logged at startup to make this misconfiguration visible. + +### Wildcard port matching + +The `Host` header includes a port when the client connects on a non-default port +(e.g., `myserver.example.com:8080`). Use a `:*` suffix to allow any port for a given +hostname: + +```python +security = TransportSecuritySettings( + allowed_hosts=["localhost:*", "myserver.example.com:*"], +) +``` + +### TLS termination and reverse proxies + +Behind a reverse proxy (nginx, Caddy, an AWS load balancer, etc.), the port that appears +in the `Host` header depends on how the proxy is configured. Common variants: + +| Proxy configuration | `Host` header seen by MCP server | +|---|---| +| Proxy strips port (default for HTTPS) | `myserver.example.com` | +| Proxy preserves port | `myserver.example.com:443` | +| Local development | `localhost:8000` | + +Because the behavior varies, the safest production setting is the `:*` wildcard: + +```python +security = TransportSecuritySettings( + allowed_hosts=["myserver.example.com:*", "myserver.example.com"], +) +``` + +Or, if you only need to match any port: + +```python +security = TransportSecuritySettings( + allowed_hosts=["myserver.example.com:*"], + # "myserver.example.com" (no port) won't match "myserver.example.com:*" + # Add the bare hostname too if your proxy strips the port +) +``` + +### Restricting origins + +For browser-based MCP clients, you can also restrict which origins are allowed to connect. +Requests without an `Origin` header (e.g., from non-browser clients) are always allowed: + +```python +security = TransportSecuritySettings( + allowed_hosts=["myserver.example.com:*"], + allowed_origins=["https://myapp.example.com:*"], +) +``` + +### Disabling protection + +Protection can be turned off entirely, for example during local development with a client +that sends unusual headers: + +```python +security = TransportSecuritySettings(enable_dns_rebinding_protection=False) +``` diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0e..1ee4828c5d 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -2,9 +2,10 @@ import logging -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from starlette.requests import Request from starlette.responses import Response +from typing_extensions import Self logger = logging.getLogger(__name__) @@ -31,6 +32,17 @@ class TransportSecuritySettings(BaseModel): Only applies when `enable_dns_rebinding_protection` is `True`. """ + @model_validator(mode="after") + def _warn_if_protection_enabled_with_empty_allowlist(self) -> Self: + if self.enable_dns_rebinding_protection and not self.allowed_hosts: + logger.warning( + "TransportSecuritySettings has DNS rebinding protection enabled but " + "allowed_hosts is empty — all requests will be rejected with HTTP 421. " + "Set allowed_hosts to your server's hostname(s), e.g. " + 'TransportSecuritySettings(allowed_hosts=["your-host.example.com:*"])' + ) + return self + # TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this. class TransportSecurityMiddleware: @@ -39,8 +51,10 @@ class TransportSecurityMiddleware: def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) + self._warned_hosts: set[str | None] = set() + self._warned_origins: set[str] = set() - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -59,10 +73,12 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover if host.startswith(base_host + ":"): return True - logger.warning(f"Invalid Host header: {host}") + if host not in self._warned_hosts: + self._warned_hosts.add(host) + logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -81,7 +97,9 @@ def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover if origin.startswith(base_origin + ":"): return True - logger.warning(f"Invalid Origin header: {origin}") + if origin not in self._warned_origins: + self._warned_origins.add(origin) + logger.warning(f"Invalid Origin header: {origin}") return False def _validate_content_type(self, content_type: str | None) -> bool: @@ -94,7 +112,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res Returns None if validation passes, or an error Response if validation fails. """ # Always validate Content-Type for POST requests - if is_post: # pragma: no branch + if is_post: content_type = request.headers.get("content-type") if not self._validate_content_type(content_type): return Response("Invalid Content-Type header", status_code=400) @@ -103,14 +121,22 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover - - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover - - return None # pragma: no cover + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response( + f"Invalid Host header: {host!r}. " + "Configure TransportSecuritySettings(allowed_hosts=[...]) with your server's hostname.", + status_code=421, + ) + + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response( + f"Invalid Origin header: {origin!r}. " + "Configure TransportSecuritySettings(allowed_origins=[...]) with your server's origin.", + status_code=403, + ) + + return None diff --git a/tests/server/test_sse_security.py b/tests/server/test_sse_security.py index 010eaf6a25..d184c67c76 100644 --- a/tests/server/test_sse_security.py +++ b/tests/server/test_sse_security.py @@ -105,7 +105,7 @@ async def test_sse_security_invalid_host_header(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text finally: process.terminate() @@ -128,7 +128,7 @@ async def test_sse_security_invalid_origin_header(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 403 - assert response.text == "Invalid Origin header" + assert "Invalid Origin header" in response.text finally: process.terminate() @@ -215,7 +215,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int): async with httpx.AsyncClient() as client: response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text finally: process.terminate() diff --git a/tests/server/test_streamable_http_security.py b/tests/server/test_streamable_http_security.py index 897555353e..abec538abc 100644 --- a/tests/server/test_streamable_http_security.py +++ b/tests/server/test_streamable_http_security.py @@ -126,7 +126,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int): headers=headers, ) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text finally: process.terminate() @@ -154,7 +154,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int): headers=headers, ) assert response.status_code == 403 - assert response.text == "Invalid Origin header" + assert "Invalid Origin header" in response.text finally: process.terminate() @@ -269,7 +269,7 @@ async def test_streamable_http_security_get_request(server_port: int): async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers) assert response.status_code == 421 - assert response.text == "Invalid Host header" + assert "Invalid Host header" in response.text # Test GET request with valid host header headers = { diff --git a/tests/server/test_transport_security.py b/tests/server/test_transport_security.py new file mode 100644 index 0000000000..487d67ffd7 --- /dev/null +++ b/tests/server/test_transport_security.py @@ -0,0 +1,203 @@ +"""Unit tests for TransportSecuritySettings and TransportSecurityMiddleware.""" + +import logging + +import pytest +from starlette.requests import Request + +from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings + + +def make_request(headers: dict[str, str], method: str = "GET") -> Request: + scope = { + "type": "http", + "method": method, + "headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()], + "path": "/", + "query_string": b"", + } + return Request(scope) + + +# --------------------------------------------------------------------------- +# TransportSecuritySettings — construction-time warning +# --------------------------------------------------------------------------- + + +def test_no_warning_when_protection_disabled(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + TransportSecuritySettings(enable_dns_rebinding_protection=False) + assert not caplog.records + + +def test_no_warning_when_allowed_hosts_populated(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=["example.com"], + ) + assert not caplog.records + + +def test_warning_when_protection_enabled_with_empty_allowed_hosts(caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + TransportSecuritySettings(enable_dns_rebinding_protection=True) + assert len(caplog.records) == 1 + assert "allowed_hosts is empty" in caplog.records[0].message + assert "HTTP 421" in caplog.records[0].message + assert "allowed_hosts=" in caplog.records[0].message + + +# --------------------------------------------------------------------------- +# TransportSecurityMiddleware._validate_host +# --------------------------------------------------------------------------- + + +def test_validate_host_missing_host() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + assert m._validate_host(None) is False + + +def test_validate_host_exact_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + assert m._validate_host("example.com") is True + + +def test_validate_host_exact_no_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + assert m._validate_host("other.com") is False + + +def test_validate_host_port_wildcard_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"])) + assert m._validate_host("localhost:8080") is True + + +def test_validate_host_port_wildcard_different_base() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"])) + assert m._validate_host("other:8080") is False + + +def test_validate_host_port_wildcard_no_port() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["localhost:*"])) + assert m._validate_host("localhost") is False + + +def test_validate_host_logs_once_per_unique_host(caplog: pytest.LogCaptureFixture) -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + m._validate_host("evil.com") + m._validate_host("evil.com") + m._validate_host("evil.com") + m._validate_host("other.com") + host_records = [r for r in caplog.records if "Invalid Host header" in r.message] + assert len(host_records) == 2 # one for evil.com, one for other.com + + +# --------------------------------------------------------------------------- +# TransportSecurityMiddleware._validate_origin +# --------------------------------------------------------------------------- + + +def test_validate_origin_absent_is_allowed() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + assert m._validate_origin(None) is True + + +def test_validate_origin_exact_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + assert m._validate_origin("http://example.com") is True + + +def test_validate_origin_exact_no_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + assert m._validate_origin("http://other.com") is False + + +def test_validate_origin_port_wildcard_match() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"])) + assert m._validate_origin("http://localhost:3000") is True + + +def test_validate_origin_port_wildcard_different_base() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://localhost:*"])) + assert m._validate_origin("http://other:3000") is False + + +def test_validate_origin_logs_once_per_unique_origin(caplog: pytest.LogCaptureFixture) -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_origins=["http://example.com"])) + with caplog.at_level(logging.WARNING, logger="mcp.server.transport_security"): + m._validate_origin("http://evil.com") + m._validate_origin("http://evil.com") + m._validate_origin("http://other.com") + origin_records = [r for r in caplog.records if "Invalid Origin header" in r.message] + assert len(origin_records) == 2 # one for evil.com, one for other.com + + +# --------------------------------------------------------------------------- +# TransportSecurityMiddleware.validate_request +# --------------------------------------------------------------------------- + + +@pytest.mark.anyio +async def test_validate_request_post_valid_content_type() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({"content-type": "application/json"}, method="POST") + assert await m.validate_request(request, is_post=True) is None + + +@pytest.mark.anyio +async def test_validate_request_post_invalid_content_type() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({"content-type": "text/plain"}, method="POST") + response = await m.validate_request(request, is_post=True) + assert response is not None + assert response.status_code == 400 + + +@pytest.mark.anyio +async def test_validate_request_get_skips_content_type() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({}) + assert await m.validate_request(request, is_post=False) is None + + +@pytest.mark.anyio +async def test_validate_request_protection_disabled_allows_any_host() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False)) + request = make_request({"host": "attacker.example.com"}) + assert await m.validate_request(request) is None + + +@pytest.mark.anyio +async def test_validate_request_valid_host_and_no_origin() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + request = make_request({"host": "example.com"}) + assert await m.validate_request(request) is None + + +@pytest.mark.anyio +async def test_validate_request_invalid_host_returns_421_with_detail() -> None: + m = TransportSecurityMiddleware(TransportSecuritySettings(allowed_hosts=["example.com"])) + request = make_request({"host": "attacker.com"}) + response = await m.validate_request(request) + assert response is not None + assert response.status_code == 421 + assert b"attacker.com" in response.body + assert b"allowed_hosts" in response.body + + +@pytest.mark.anyio +async def test_validate_request_invalid_origin_returns_403_with_detail() -> None: + m = TransportSecurityMiddleware( + TransportSecuritySettings( + allowed_hosts=["example.com"], + allowed_origins=["http://example.com"], + ) + ) + request = make_request({"host": "example.com", "origin": "http://attacker.com"}) + response = await m.validate_request(request) + assert response is not None + assert response.status_code == 403 + assert b"attacker.com" in response.body + assert b"allowed_origins" in response.body