Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/mcp/server/auth/middleware/bearer_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import time
from typing import Any
from typing import Any, TypedDict

from pydantic import AnyHttpUrl
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
Expand All @@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
self.scopes = auth_info.scopes


class AuthorizationContext(TypedDict):
client_id: str
issuer: str | None
subject: str | None


def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
"""Identify the principal `user` represents, for transports to compare
against the principal that created a session. Components the token
verifier does not supply are `None`, so the comparison degrades to the
remaining components.

See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
a verifier that populates `subject` and `claims` from an introspection
response."""
token = user.access_token
issuer = (token.claims or {}).get("iss")
return AuthorizationContext(
client_id=token.client_id,
issuer=str(issuer) if issuer is not None else None,
subject=token.subject,
)


class BearerAuthBackend(AuthenticationBackend):
"""Authentication backend that validates Bearer tokens using a TokenVerifier."""

Expand Down
66 changes: 43 additions & 23 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def handle_sse(request):
from starlette.types import Receive, Scope, Send

from mcp import types
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
from mcp.server.transport_security import (
TransportSecurityMiddleware,
TransportSecuritySettings,
Expand All @@ -73,6 +74,9 @@ class SseServerTransport:

_endpoint: str
_read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]]
# Identity of the credential that created each session; requests for a
# session must present the same credential.
_session_owners: dict[UUID, AuthorizationContext]
_security: TransportSecurityMiddleware

def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
Expand Down Expand Up @@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |

self._endpoint = endpoint
self._read_stream_writers = {}
self._session_owners = {}
self._security = TransportSecurityMiddleware(security_settings)
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")

@asynccontextmanager
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] != "http": # pragma: no cover
if scope["type"] != "http":
logger.error("connect_sse received non-HTTP request")
raise ValueError("connect_sse can only handle HTTP requests")

# Validate request headers for DNS rebinding protection
request = Request(scope, receive)
error_response = await self._security.validate_request(request, is_post=False)
if error_response: # pragma: no cover
if error_response:
await error_response(scope, receive, send)
raise ValueError("Request validation failed")

Expand All @@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)

session_id = uuid4()
user = scope.get("user")
if isinstance(user, AuthenticatedUser):
self._session_owners[session_id] = authorization_context(user)
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

Expand Down Expand Up @@ -169,35 +177,38 @@ async def sse_writer():
}
)

async with anyio.create_task_group() as tg:

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""The EventSourceResponse returning signals a client close / disconnect.
In this case we close our side of the streams to signal the client that
the connection has been closed.
"""
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
scope, receive, send
)
await sse_stream_reader.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
self._read_stream_writers.pop(session_id, None)
logging.debug(f"Client session disconnected {session_id}")
try:
async with anyio.create_task_group() as tg:

async def response_wrapper(scope: Scope, receive: Receive, send: Send):
"""The EventSourceResponse returning signals a client close / disconnect.
In this case we close our side of the streams to signal the client that
the connection has been closed.
"""
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
scope, receive, send
)
await read_stream_writer.aclose()
await write_stream_reader.aclose()
await sse_stream_reader.aclose()
logging.debug(f"Client session disconnected {session_id}")

logger.debug("Starting SSE response task")
tg.start_soon(response_wrapper, scope, receive, send)
logger.debug("Starting SSE response task")
tg.start_soon(response_wrapper, scope, receive, send)

logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
logger.debug("Yielding read and write streams")
yield (read_stream, write_stream)
finally:
self._read_stream_writers.pop(session_id, None)
self._session_owners.pop(session_id, None)

async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
logger.debug("Handling POST message")
request = Request(scope, receive)

# Validate request headers for DNS rebinding protection
error_response = await self._security.validate_request(request, is_post=True)
if error_response: # pragma: no cover
if error_response:
return await error_response(scope, receive, send)

session_id_param = request.query_params.get("session_id")
Expand All @@ -220,13 +231,22 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)

user = scope.get("user")
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
if requestor != self._session_owners.get(session_id):
# A session can only be used with the credential that created it.
# Respond exactly as if the session did not exist.
logger.warning("Rejecting message for session %s: credential does not match", session_id)
response = Response("Could not find session", status_code=404)
return await response(scope, receive, send)

body = await request.body()
logger.debug(f"Received JSON: {body}")

try:
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
logger.debug(f"Validated client message: {message}")
except ValidationError as err: # pragma: no cover
except ValidationError as err:
logger.exception("Failed to parse message")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
Expand Down
40 changes: 32 additions & 8 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import contextlib
import logging
from collections.abc import AsyncIterator
from http import HTTPStatus
from typing import TYPE_CHECKING, Any
from uuid import uuid4

Expand All @@ -15,6 +14,7 @@
from starlette.responses import Response
from starlette.types import Receive, Scope, Send

from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
from mcp.server.streamable_http import (
MCP_SESSION_ID_HEADER,
EventStore,
Expand Down Expand Up @@ -89,6 +89,9 @@ def __init__(
# Session tracking (only used if not stateless)
self._session_creation_lock = anyio.Lock()
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
# Identity of the credential that created each session; requests for a
# session must present the same credential.
self._session_owners: dict[str, AuthorizationContext] = {}

# The task group will be set during lifespan
self._task_group = None
Expand Down Expand Up @@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()
self._session_owners.clear()

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process ASGI request with proper session handling and transport setup.
Expand Down Expand Up @@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
request = Request(scope, receive)
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

user = scope.get("user")
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None

# Existing session case
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
transport = self._server_instances[request_mcp_session_id]
if requestor != self._session_owners.get(request_mcp_session_id):
# A session can only be used with the credential that created
# it. Respond exactly as if the session did not exist.
logger.warning(
"Rejecting request for session %s: credential does not match the one that created the session",
request_mcp_session_id[:64],
)
body = JSONRPCError(
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
)
response = Response(
body.model_dump_json(by_alias=True, exclude_unset=True),
status_code=404,
media_type="application/json",
)
await response(scope, receive, send)
return
logger.debug("Session already exists, handling request directly")
# Push back idle deadline on activity
if transport.idle_scope is not None and self.session_idle_timeout is not None:
Expand All @@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
)

assert http_transport.mcp_session_id is not None
if requestor is not None:
self._session_owners[http_transport.mcp_session_id] = requestor
self._server_instances[http_transport.mcp_session_id] = http_transport
logger.info(f"Created new transport with session ID: {new_session_id}")

Expand Down Expand Up @@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
assert http_transport.mcp_session_id is not None
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
self._server_instances.pop(http_transport.mcp_session_id, None)
self._session_owners.pop(http_transport.mcp_session_id, None)
await http_transport.terminate()
except Exception:
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
Expand All @@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
f"{http_transport.mcp_session_id} from active instances."
)
del self._server_instances[http_transport.mcp_session_id]
self._session_owners.pop(http_transport.mcp_session_id, None)

# Assert task group is not None for type checking
assert self._task_group is not None
Expand All @@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
# TODO: Align error code once spec clarifies
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}")
error_response = JSONRPCError(
jsonrpc="2.0",
id=None,
error=ErrorData(code=INVALID_REQUEST, message="Session not found"),
body = JSONRPCError(
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
)
response = Response(
content=error_response.model_dump_json(by_alias=True, exclude_unset=True),
status_code=HTTPStatus.NOT_FOUND,
media_type="application/json",
body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json"
)
await response(scope, receive, send)

Expand Down
14 changes: 7 additions & 7 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None):

def _validate_host(self, host: str | None) -> bool:
"""Validate the Host header against allowed values."""
if not host: # pragma: no cover
if not host:
logger.warning("Missing Host header in request")
return False

# Check exact match first
if host in self.settings.allowed_hosts: # pragma: no cover
if host in self.settings.allowed_hosts:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_hosts:
if allowed.endswith(":*"): # pragma: no branch
if allowed.endswith(":*"):
# Extract base host from pattern
base_host = allowed[:-2]
# Check if the actual host starts with base host and has a port
Expand All @@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool:
def _validate_origin(self, origin: str | None) -> bool:
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin: # pragma: no cover
if not origin:
return True

# Check exact match first
if origin in self.settings.allowed_origins: # pragma: no cover
if origin in self.settings.allowed_origins:
return True

# Check wildcard port patterns
for allowed in self.settings.allowed_origins:
if allowed.endswith(":*"): # pragma: no branch
if allowed.endswith(":*"):
# Extract base origin from pattern
base_origin = allowed[:-2]
# Check if the actual origin starts with base origin and has a port
Expand All @@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
Returns None if validation passes, or an error Response if validation fails.
"""
# Always validate Content-Type for POST requests
if is_post: # pragma: no branch
if is_post:
content_type = request.headers.get("content-type")
if not self._validate_content_type(content_type):
return Response("Invalid Content-Type header", status_code=400)
Expand Down
Loading
Loading