diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 04aed345e..5aa80a75f 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -619,14 +619,14 @@ async def sse_writer(): # pragma: lax no cover # Then send the message to be processed by the server session_message = self._create_session_message(message, request, request_id, protocol_version) await writer.send(session_message) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("SSE response error") await sse_stream_writer.aclose() await self._clean_up_memory_streams(request_id) finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -809,7 +809,7 @@ async def _validate_request_headers(self, request: Request, send: Send) -> bool: async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" - if not self.mcp_session_id: # pragma: no cover + if not self.mcp_session_id: # pragma: lax no cover # If we're not using session IDs, return True return True @@ -1019,7 +1019,7 @@ async def message_router(): ) except anyio.ClosedResourceError: if self._terminated: - logger.debug("Read stream closed by client") + logger.debug("Read stream closed by client") # pragma: lax no cover else: logger.exception("Unexpected closure of read stream in message router") except Exception: # pragma: lax no cover diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 50bcd5e79..718e0d847 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -151,7 +151,12 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_stateful_request(scope, receive, send) async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None: - """Process request in stateless mode - creating a new transport for each request.""" + """Process request in stateless mode - creating a new transport for each request. + + Uses a request-scoped task group so the server task is automatically + cancelled when the request completes, preventing task accumulation in + the manager's global task group. + """ logger.debug("Stateless mode: Creating new transport for this request") # No session ID needed in stateless mode http_transport = StreamableHTTPServerTransport( @@ -173,18 +178,24 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA self.app.create_initialization_options(), stateless=True, ) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") - # Assert task group is not None for type checking - assert self._task_group is not None - # Start the server task - await self._task_group.start(run_stateless_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - - # Terminate the transport after the request is handled + # Use a request-scoped task group instead of the global one. + # This ensures the server task is cancelled when the request + # finishes, preventing zombie tasks from accumulating. + # See: https://github.com/modelcontextprotocol/python-sdk/issues/1764 + async with anyio.create_task_group() as request_tg: + await request_tg.start(run_stateless_server) + # Handle the HTTP request directly in the caller's context + # (not as a child task) so execution flows back naturally. + await http_transport.handle_request(scope, receive, send) + # Cancel the request-scoped task group to stop the server task. + request_tg.cancel_scope.cancel() + + # Terminate after the task group exits — the server task is already + # cancelled at this point, so this is just cleanup (sets _terminated + # flag and closes any remaining streams). await http_transport.terminate() async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: Send) -> None: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b617d702f..fe768a6ad 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -423,7 +423,7 @@ async def _receive_loop(self) -> None: try: await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) await stream.aclose() - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover # Stream might already be closed pass self._response_streams.clear() diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 54a898cc5..a8c6ce6d3 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -268,6 +268,83 @@ async def mock_receive(): assert len(transport._request_streams) == 0, "Transport should have no active request streams" +@pytest.mark.anyio +async def test_stateless_requests_task_leak_on_client_disconnect(): + """Test that stateless tasks don't leak when clients disconnect mid-request. + + Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1764 + + Reproduces the production memory leak: a client sends a tool call, the tool + handler takes some time, and the client disconnects before the response is + delivered. The SSE response pipeline detects the disconnect but app.run() + continues in the background. After the tool finishes, the response has + nowhere to go, and app.run() blocks on ``async for message in + session.incoming_messages`` forever — leaking the task in the global + task group. + + The test uses real Server.run() with a real tool handler, real SSE streaming + via httpx.ASGITransport, and simulates client disconnect by cancelling the + request task. + """ + from mcp.types import CallToolResult, TextContent + + tool_started = anyio.Event() + tool_gate = anyio.Event() + + async def handle_call_tool(ctx: ServerRequestContext, params: Any) -> CallToolResult: + tool_started.set() + await tool_gate.wait() + return CallToolResult(content=[TextContent(type="text", text="done")]) # pragma: no cover + + app = Server( + "test-stateless-leak", + on_call_tool=handle_call_tool, + ) + + host = "testserver" + mcp_app = app.streamable_http_app(host=host, stateless_http=True) + + async with ( + mcp_app.router.lifespan_context(mcp_app), + httpx.ASGITransport(mcp_app) as transport, + ): + session_manager = app._session_manager + assert session_manager is not None + + async def make_and_abandon_tool_call(): + async with httpx.AsyncClient(transport=transport, base_url=f"http://{host}", timeout=30.0) as http_client: + async with Client(streamable_http_client(f"http://{host}/mcp", http_client=http_client)) as client: + # Start tool call — this will block until tool completes + # We'll cancel it from outside to simulate disconnect + await client.call_tool("slow_tool", {}) + + num_requests = 3 + for _ in range(num_requests): + async with anyio.create_task_group() as tg: + tg.start_soon(make_and_abandon_tool_call) + # Wait for the tool handler to actually start + await tool_started.wait() + tool_started = anyio.Event() # Reset for next iteration + # Simulate client disconnect by cancelling the request + tg.cancel_scope.cancel() + + # Let the tool finish now (response has nowhere to go) + tool_gate.set() + tool_gate = anyio.Event() # Reset for next iteration + + # Give tasks a chance to settle + await anyio.sleep(0.1) + + # Check for leaked tasks in the session manager's global task group + await anyio.sleep(0.1) + assert session_manager._task_group is not None + leaked = len(session_manager._task_group._tasks) # type: ignore[attr-defined] + + assert leaked == 0, ( + f"Expected 0 lingering tasks but found {leaked}. Stateless request tasks are leaking after client disconnect." + ) + + @pytest.mark.anyio async def test_unknown_session_id_returns_404(): """Test that requests with unknown session IDs return HTTP 404 per MCP spec."""