Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions docs/concepts.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,98 @@
- Context and sessions
- Lifecycle and state
-->

## Transport Security

MCP servers that use HTTP transports (SSE or Streamable HTTP) include DNS rebinding
protection via `TransportSecuritySettings`. This guards against attacks where a malicious
page tricks a browser into making requests to a locally running MCP server by spoofing the
`Host` header.

### Default behavior

- **Streamable HTTP** (`streamable_http_app()`) enables protection by default.
- **SSE** (`sse_app()`) disables protection by default for backwards compatibility.
- **stdio** transport is unaffected — it has no network surface.

### Configuring allowed hosts

Set `allowed_hosts` to the hostname(s) your server is reachable at:

```python
from mcp.server.mcpserver import MCPServer
from mcp.server.transport_security import TransportSecuritySettings

mcp = MCPServer("My Server")

security = TransportSecuritySettings(
allowed_hosts=["myserver.example.com"],
)

app = mcp.streamable_http_app(transport_security=security)
```

If `allowed_hosts` is empty while protection is enabled, **all requests will be rejected
with HTTP 421**. A warning is logged at startup to make this misconfiguration visible.

### Wildcard port matching

The `Host` header includes a port when the client connects on a non-default port
(e.g., `myserver.example.com:8080`). Use a `:*` suffix to allow any port for a given
hostname:

```python
security = TransportSecuritySettings(
allowed_hosts=["localhost:*", "myserver.example.com:*"],
)
```

### TLS termination and reverse proxies

Behind a reverse proxy (nginx, Caddy, an AWS load balancer, etc.), the port that appears
in the `Host` header depends on how the proxy is configured. Common variants:

| Proxy configuration | `Host` header seen by MCP server |
|---|---|
| Proxy strips port (default for HTTPS) | `myserver.example.com` |
| Proxy preserves port | `myserver.example.com:443` |
| Local development | `localhost:8000` |

Because the behavior varies, the safest production setting is the `:*` wildcard:

```python
security = TransportSecuritySettings(
allowed_hosts=["myserver.example.com:*", "myserver.example.com"],
)
```

Or, if you only need to match any port:

```python
security = TransportSecuritySettings(
allowed_hosts=["myserver.example.com:*"],
# "myserver.example.com" (no port) won't match "myserver.example.com:*"
# Add the bare hostname too if your proxy strips the port
)
```

### Restricting origins

For browser-based MCP clients, you can also restrict which origins are allowed to connect.
Requests without an `Origin` header (e.g., from non-browser clients) are always allowed:

```python
security = TransportSecuritySettings(
allowed_hosts=["myserver.example.com:*"],
allowed_origins=["https://myapp.example.com:*"],
)
```

### Disabling protection

Protection can be turned off entirely, for example during local development with a client
that sends unusual headers:

```python
security = TransportSecuritySettings(enable_dns_rebinding_protection=False)
```
60 changes: 43 additions & 17 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import logging

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from starlette.requests import Request
from starlette.responses import Response
from typing_extensions import Self

logger = logging.getLogger(__name__)

Expand All @@ -31,6 +32,17 @@ class TransportSecuritySettings(BaseModel):
Only applies when `enable_dns_rebinding_protection` is `True`.
"""

@model_validator(mode="after")
def _warn_if_protection_enabled_with_empty_allowlist(self) -> Self:
if self.enable_dns_rebinding_protection and not self.allowed_hosts:
logger.warning(
"TransportSecuritySettings has DNS rebinding protection enabled but "
"allowed_hosts is empty — all requests will be rejected with HTTP 421. "
"Set allowed_hosts to your server's hostname(s), e.g. "
'TransportSecuritySettings(allowed_hosts=["your-host.example.com:*"])'
)
return self


# TODO(Marcelo): This should be a proper ASGI middleware. I'm sad to see this.
class TransportSecurityMiddleware:
Expand All @@ -39,8 +51,10 @@ class TransportSecurityMiddleware:
def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
self._warned_hosts: set[str | None] = set()
self._warned_origins: set[str] = set()

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not host:
logger.warning("Missing Host header in request")
Expand All @@ -59,10 +73,12 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
if host.startswith(base_host + ":"):
return True

logger.warning(f"Invalid Host header: {host}")
if host not in self._warned_hosts:
self._warned_hosts.add(host)
logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin:
Expand All @@ -81,7 +97,9 @@ def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
if origin.startswith(base_origin + ":"):
return True

logger.warning(f"Invalid Origin header: {origin}")
if origin not in self._warned_origins:
self._warned_origins.add(origin)
logger.warning(f"Invalid Origin header: {origin}")
return False

def _validate_content_type(self, content_type: str | None) -> bool:
Expand All @@ -94,7 +112,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
Returns None if validation passes, or an error Response if validation fails.
"""
# Always validate Content-Type for POST requests
if is_post: # pragma: no branch
if is_post:
content_type = request.headers.get("content-type")
if not self._validate_content_type(content_type):
return Response("Invalid Content-Type header", status_code=400)
Expand All @@ -103,14 +121,22 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
if not self.settings.enable_dns_rebinding_protection:
return None

# Validate Host header # pragma: no cover
host = request.headers.get("host") # pragma: no cover
if not self._validate_host(host): # pragma: no cover
return Response("Invalid Host header", status_code=421) # pragma: no cover

# Validate Origin header # pragma: no cover
origin = request.headers.get("origin") # pragma: no cover
if not self._validate_origin(origin): # pragma: no cover
return Response("Invalid Origin header", status_code=403) # pragma: no cover

return None # pragma: no cover
# Validate Host header
host = request.headers.get("host")
if not self._validate_host(host):
return Response(
f"Invalid Host header: {host!r}. "
"Configure TransportSecuritySettings(allowed_hosts=[...]) with your server's hostname.",
status_code=421,
)

# Validate Origin header
origin = request.headers.get("origin")
if not self._validate_origin(origin):
return Response(
f"Invalid Origin header: {origin!r}. "
"Configure TransportSecuritySettings(allowed_origins=[...]) with your server's origin.",
status_code=403,
)

return None
6 changes: 3 additions & 3 deletions tests/server/test_sse_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def test_sse_security_invalid_host_header(server_port: int):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

finally:
process.terminate()
Expand All @@ -128,7 +128,7 @@ async def test_sse_security_invalid_origin_header(server_port: int):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
assert response.status_code == 403
assert response.text == "Invalid Origin header"
assert "Invalid Origin header" in response.text

finally:
process.terminate()
Expand Down Expand Up @@ -215,7 +215,7 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
async with httpx.AsyncClient() as client:
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

finally:
process.terminate()
Expand Down
6 changes: 3 additions & 3 deletions tests/server/test_streamable_http_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def test_streamable_http_security_invalid_host_header(server_port: int):
headers=headers,
)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

finally:
process.terminate()
Expand Down Expand Up @@ -154,7 +154,7 @@ async def test_streamable_http_security_invalid_origin_header(server_port: int):
headers=headers,
)
assert response.status_code == 403
assert response.text == "Invalid Origin header"
assert "Invalid Origin header" in response.text

finally:
process.terminate()
Expand Down Expand Up @@ -269,7 +269,7 @@ async def test_streamable_http_security_get_request(server_port: int):
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"http://127.0.0.1:{server_port}/", headers=headers)
assert response.status_code == 421
assert response.text == "Invalid Host header"
assert "Invalid Host header" in response.text

# Test GET request with valid host header
headers = {
Expand Down
Loading
Loading