Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,10 +767,19 @@ async def terminate(self) -> None:

Once terminated, all requests with this session ID will receive 404 Not Found.
"""
if self._terminated:
return

self._terminated = True
logger.info(f"Terminating session: {self.mcp_session_id}")

# Close active SSE responses so ASGI response tasks can finish before
# the session manager cancels the owning task group.
sse_stream_writers = list(self._sse_stream_writers.values())
self._sse_stream_writers.clear()
for writer in sse_stream_writers:
writer.close()

# We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys())

Expand Down
23 changes: 17 additions & 6 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,23 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
yield # Let the application run
finally:
logger.info("StreamableHTTP session manager shutting down")
# Cancel task group to stop all spawned tasks
tg.cancel_scope.cancel()
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()
self._session_owners.clear()
try:
await self._terminate_active_sessions()
finally:
# Cancel task group to stop all spawned tasks
tg.cancel_scope.cancel()
self._task_group = None
# Clear any remaining server instances
self._server_instances.clear()
self._session_owners.clear()

async def _terminate_active_sessions(self) -> None:
"""Terminate tracked transports before cancelling their task group."""
for transport in list(self._server_instances.values()):
try:
await transport.terminate()
except Exception:
logger.exception("Error terminating StreamableHTTP session during shutdown")

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process ASGI request with proper session handling and transport setup.
Expand Down
84 changes: 83 additions & 1 deletion tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import json
import logging
from typing import Any
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import AsyncMock, patch

import anyio
Expand Down Expand Up @@ -64,6 +65,50 @@ async def try_run():
assert "StreamableHTTPSessionManager .run() can only be called once per instance" in str(errors[0])


@pytest.mark.anyio
async def test_run_terminates_active_streaming_session_before_shutdown():
"""run() should close active SSE transports before task cancellation."""
app = Server("test-shutdown-cleanup")
manager = StreamableHTTPSessionManager(app=app)
transport = StreamableHTTPServerTransport(mcp_session_id="session-id")
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1)

try:
transport._sse_stream_writers["request-id"] = sse_stream_writer

async with manager.run():
manager._server_instances["session-id"] = transport

assert transport.is_terminated
assert transport._sse_stream_writers == {}
assert manager._server_instances == {}
with pytest.raises(anyio.ClosedResourceError):
await sse_stream_writer.send({"data": "still-open"})
finally:
await sse_stream_reader.aclose()


@pytest.mark.anyio
async def test_run_terminates_remaining_sessions_if_one_shutdown_fails(caplog: pytest.LogCaptureFixture):
"""One failed transport shutdown should not skip later active sessions."""
app = Server("test-shutdown-cleanup-error")
manager = StreamableHTTPSessionManager(app=app)
failing_terminate = AsyncMock(side_effect=RuntimeError("terminate failed"))
healthy_terminate = AsyncMock()
failing_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=failing_terminate))
healthy_transport = cast(StreamableHTTPServerTransport, SimpleNamespace(terminate=healthy_terminate))

with caplog.at_level(logging.ERROR):
async with manager.run():
manager._server_instances["bad-session"] = failing_transport
manager._server_instances["healthy-session"] = healthy_transport

failing_terminate.assert_awaited_once_with()
healthy_terminate.assert_awaited_once_with()
assert "Error terminating StreamableHTTP session during shutdown" in caplog.text
assert manager._server_instances == {}


@pytest.mark.anyio
async def test_handle_request_without_run_raises_error():
"""Test that handle_request raises error if run() hasn't been called."""
Expand Down Expand Up @@ -271,6 +316,43 @@ async def mock_receive():
assert len(transport._request_streams) == 0, "Transport should have no active request streams"


@pytest.mark.anyio
async def test_transport_terminate_closes_sse_stream_writers():
"""terminate() should close active SSE writers so streaming responses can finish."""
transport = StreamableHTTPServerTransport(mcp_session_id="test-session")
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](1)

try:
transport._sse_stream_writers["request-id"] = sse_stream_writer

await transport.terminate()

assert transport._sse_stream_writers == {}
with pytest.raises(anyio.ClosedResourceError):
await sse_stream_writer.send({"data": "still-open"})

await transport.terminate()
finally:
await sse_stream_reader.aclose()


@pytest.mark.anyio
async def test_transport_connect_cleans_request_streams_on_exit():
"""connect() should close registered request streams when the transport exits."""
transport = StreamableHTTPServerTransport(mcp_session_id="test-session")
request_stream_writer, request_stream_reader = anyio.create_memory_object_stream[Any](1)

transport._request_streams["request-id"] = (request_stream_writer, request_stream_reader)

async with transport.connect():
assert "request-id" in transport._request_streams
transport._terminated = True

assert transport._request_streams == {}
with pytest.raises(anyio.ClosedResourceError):
await request_stream_writer.send(cast(Any, object()))


@pytest.mark.anyio
async def test_unknown_session_id_returns_404(caplog: pytest.LogCaptureFixture):
"""Test that requests with unknown session IDs return HTTP 404 per MCP spec."""
Expand Down
Loading