Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,14 +619,14 @@ async def sse_writer(): # pragma: lax no cover
# Then send the message to be processed by the server
session_message = self._create_session_message(message, request, request_id, protocol_version)
await writer.send(session_message)
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("SSE response error")
await sse_stream_writer.aclose()
await self._clean_up_memory_streams(request_id)
finally:
await sse_stream_reader.aclose()

except Exception as err: # pragma: no cover
except Exception as err: # pragma: lax no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
Expand Down Expand Up @@ -809,7 +809,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool:

async def _validate_session(self, request: Request, send: Send) -> bool:
"""Validate the session ID in the request."""
if not self.mcp_session_id: # pragma: no cover
if not self.mcp_session_id: # pragma: lax no cover
# If we're not using session IDs, return True
return True

Expand Down Expand Up @@ -1019,7 +1019,7 @@ async def message_router():
)
except anyio.ClosedResourceError:
if self._terminated:
logger.debug("Read stream closed by client")
logger.debug("Read stream closed by client") # pragma: lax no cover
else:
logger.exception("Unexpected closure of read stream in message router")
except Exception: # pragma: lax no cover
Expand Down
38 changes: 28 additions & 10 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,12 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await self._handle_stateful_request(scope, receive, send)

async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Process request in stateless mode - creating a new transport for each request."""
"""Process request in stateless mode - creating a new transport for each request.

Uses a request-scoped task group so the server task is automatically
cancelled when the request completes, preventing task accumulation in
the manager's global task group.
"""
logger.debug("Stateless mode: Creating new transport for this request")
# No session ID needed in stateless mode
http_transport = StreamableHTTPServerTransport(
Expand All @@ -173,18 +178,31 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
self.app.create_initialization_options(),
stateless=True,
)
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("Stateless session crashed")

# Assert task group is not None for type checking
assert self._task_group is not None
# Start the server task
await self._task_group.start(run_stateless_server)

# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
# Use a request-scoped task group instead of the global one.
# This ensures the server task is cancelled when the request
# finishes, preventing zombie tasks from accumulating.
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1764
async with anyio.create_task_group() as request_tg:

# Terminate the transport after the request is handled
async def run_request_handler(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
task_status.started()
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
# Cancel the request-scoped task group to stop the server task.
# This ensures the Cancelled exception reaches the server task
# before terminate() closes the streams, avoiding a race between
# Cancelled and ClosedResourceError in the message router.
request_tg.cancel_scope.cancel()

await request_tg.start(run_stateless_server)
await request_tg.start(run_request_handler)

# Terminate after the task group exits — the server task is already
# cancelled at this point, so this is just cleanup (sets _terminated
# flag and closes any remaining streams).
await http_transport.terminate()

async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ async def _receive_loop(self) -> None:
try:
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
await stream.aclose()
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
# Stream might already be closed
pass
self._response_streams.clear()
Expand Down
89 changes: 85 additions & 4 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,96 @@ async def mock_receive():
await manager.handle_request(scope, mock_receive, mock_send)

# Verify transport was created
assert len(created_transports) == 1, "Should have created one transport"
assert len(created_transports) == 1, "Should have created one transport" # pragma: lax no cover

transport = created_transports[0]
transport = created_transports[0] # pragma: lax no cover

# The key assertion - transport should be terminated
assert transport._terminated, "Transport should be terminated after stateless request"
assert transport._terminated, (
"Transport should be terminated after stateless request"
) # pragma: lax no cover

# Verify internal state is cleaned up
assert len(transport._request_streams) == 0, "Transport should have no active request streams"
assert len(transport._request_streams) == 0, (
"Transport should have no active request streams"
) # pragma: lax no cover


@pytest.mark.anyio
async def test_stateless_requests_task_leak_on_client_disconnect():
    """Test that stateless tasks don't leak when clients disconnect mid-request.

    Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1764

    Reproduces the production memory leak: a client sends a tool call, the tool
    handler takes some time, and the client disconnects before the response is
    delivered. The SSE response pipeline detects the disconnect but app.run()
    continues in the background. After the tool finishes, the response has
    nowhere to go, and app.run() blocks on ``async for message in
    session.incoming_messages`` forever — leaking the task in the global
    task group.

    The test uses real Server.run() with a real tool handler, real SSE streaming
    via httpx.ASGITransport, and simulates client disconnect by cancelling the
    request task.
    """
    # Function-scope import keeps these test-only types out of module scope.
    from mcp.types import CallToolResult, TextContent

    # tool_started: set by the handler once it is actually running, so the
    # test knows it is safe to "disconnect" the client.
    # tool_gate: held open to keep the handler blocked; released only after
    # the simulated disconnect so the response has nowhere to go.
    tool_started = anyio.Event()
    tool_gate = anyio.Event()

    async def handle_call_tool(ctx: ServerRequestContext, params: Any) -> CallToolResult:
        # Reads tool_started/tool_gate through the closure at call time, so
        # the per-iteration rebinding below is observed by later invocations.
        tool_started.set()
        await tool_gate.wait()
        return CallToolResult(content=[TextContent(type="text", text="done")])  # pragma: no cover

    app = Server(
        "test-stateless-leak",
        on_call_tool=handle_call_tool,
    )

    host = "testserver"
    # stateless_http=True is the mode under test: each request gets its own
    # transport/server task (see _handle_stateless_request).
    mcp_app = app.streamable_http_app(host=host, stateless_http=True)

    async with (
        mcp_app.router.lifespan_context(mcp_app),
        httpx.ASGITransport(mcp_app) as transport,
    ):
        session_manager = app._session_manager
        assert session_manager is not None

        async def make_and_abandon_tool_call():
            async with httpx.AsyncClient(transport=transport, base_url=f"http://{host}", timeout=30.0) as http_client:
                async with Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client:
                    # Start tool call — this will block until tool completes
                    # We'll cancel it from outside to simulate disconnect
                    await client.call_tool("slow_tool", {})

        # Repeat a few times so a leak would accumulate one task per request.
        num_requests = 3
        for _ in range(num_requests):
            async with anyio.create_task_group() as tg:
                tg.start_soon(make_and_abandon_tool_call)
                # Wait for the tool handler to actually start
                await tool_started.wait()
                tool_started = anyio.Event()  # Reset for next iteration
                # Simulate client disconnect by cancelling the request
                tg.cancel_scope.cancel()

            # Let the tool finish now (response has nowhere to go)
            tool_gate.set()
            tool_gate = anyio.Event()  # Reset for next iteration

            # Give tasks a chance to settle
            await anyio.sleep(0.1)

        # Check for leaked tasks in the session manager's global task group
        await anyio.sleep(0.1)
        assert session_manager._task_group is not None
        # NOTE(review): inspects anyio's private TaskGroup._tasks attribute —
        # may break across anyio versions; confirm there is no public API.
        leaked = len(session_manager._task_group._tasks)  # type: ignore[attr-defined]

        assert leaked == 0, (
            f"Expected 0 lingering tasks but found {leaked}. Stateless request tasks are leaking after client disconnect."
        )


@pytest.mark.anyio
Expand Down