Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ async def sse_client(
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
read_stream_writer, read_stream = anyio.create_memory_object_stream(1)
write_stream, write_stream_reader = anyio.create_memory_object_stream(1)

async with anyio.create_task_group() as tg:
try:
Expand Down Expand Up @@ -113,19 +113,34 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
logger.debug(f"Received server message: {message}")
except Exception as exc: # pragma: no cover
logger.exception("Error parsing server message") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
try: # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
except ( # pragma: no cover
anyio.ClosedResourceError,
anyio.BrokenResourceError,
):
return # pragma: no cover
continue # pragma: no cover

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
try:
await read_stream_writer.send(session_message)
except (
anyio.ClosedResourceError,
anyio.BrokenResourceError,
): # pragma: no cover
return # pragma: no cover
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: lax no cover
logger.exception("Encountered SSE exception")
raise sse_exc
except Exception as exc: # pragma: lax no cover
logger.exception("Error in sse_reader")
await read_stream_writer.send(exc)
try:
await read_stream_writer.send(exc)
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
pass
finally:
await read_stream_writer.aclose()

Expand Down Expand Up @@ -156,6 +171,8 @@ async def post_writer(endpoint_url: str):
try:
yield read_stream, write_stream
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
Expand Down
49 changes: 39 additions & 10 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ async def _handle_sse_event(
message.id = original_request_id

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
try:
await read_stream_writer.send(session_message)
except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover
logger.debug("Read stream closed, stopping SSE event handling")
return True

# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
Expand All @@ -170,9 +174,15 @@ async def _handle_sse_event(
if original_request_id is not None:
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
await read_stream_writer.send(error_msg)
try:
await read_stream_writer.send(error_msg)
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
pass
return True
await read_stream_writer.send(exc)
try:
await read_stream_writer.send(exc)
except (anyio.ClosedResourceError, anyio.BrokenResourceError):
pass
return False
else: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}")
Expand Down Expand Up @@ -271,14 +281,20 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
if isinstance(message, JSONRPCRequest): # pragma: no branch
error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated")
session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
await ctx.read_stream_writer.send(session_message)
try:
await ctx.read_stream_writer.send(session_message)
except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover
pass
return

if response.status_code >= 400:
if isinstance(message, JSONRPCRequest):
error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response")
session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
await ctx.read_stream_writer.send(session_message)
try:
await ctx.read_stream_writer.send(session_message)
except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover
pass
return

if is_initialization:
Expand All @@ -298,7 +314,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
logger.error(f"Unexpected content type: {content_type}")
error_data = ErrorData(code=INVALID_REQUEST, message=f"Unexpected content type: {content_type}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
await ctx.read_stream_writer.send(error_msg)
try:
await ctx.read_stream_writer.send(error_msg)
except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover
pass

async def _handle_json_response(
self,
Expand All @@ -318,12 +337,18 @@ async def _handle_json_response(
self._maybe_extract_protocol_version_from_message(message)

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
try:
await read_stream_writer.send(session_message)
except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover
return
except (httpx.StreamError, ValidationError) as exc:
logger.exception("Error parsing JSON response")
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse JSON response: {exc}")
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
await read_stream_writer.send(error_msg)
try:
await read_stream_writer.send(error_msg)
except (anyio.ClosedResourceError, anyio.BrokenResourceError): # pragma: no cover
return

async def _handle_sse_response(
self,
Expand Down Expand Up @@ -533,8 +558,8 @@ async def streamable_http_client(
Example:
See examples/snippets/clients/ for usage patterns.
"""
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](1)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](1)

# Determine if we need to create and manage the client
client_provided = http_client is not None
Expand Down Expand Up @@ -573,6 +598,10 @@ def start_get_stream() -> None:
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
# Close streams before cancelling to unblock tasks
# waiting on stream send/receive during shutdown.
await read_stream_writer.aclose()
await write_stream.aclose()
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
Expand Down
7 changes: 4 additions & 3 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ async def __aexit__(
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
# Close streams first so _receive_loop exits cooperatively,
# then cancel the task group as a fallback.
await self._read_stream.aclose()
await self._write_stream.aclose()
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

Expand Down
42 changes: 42 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,45 @@ async def make_request(client_session: ClientSession):
# Pending request completed successfully
assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)


@pytest.mark.anyio
async def test_session_exit_closes_streams_before_cancel():
"""Verify BaseSession.__aexit__ closes streams before cancelling task group.

The receive loop should exit via ClosedResourceError on the read stream,
not via forced task group cancellation. This prevents AnyIO cancellation
busy-loops when tasks are blocked on stream operations.
"""
async with create_client_server_memory_streams() as (client_streams, server_streams):
client_read, client_write = client_streams
server_read, _server_write = server_streams

async def slow_server():
"""Read a request but never respond, keeping the session busy."""
try:
await server_read.receive()
# Hold the connection open
await anyio.sleep(60)
except (anyio.ClosedResourceError, anyio.get_cancelled_exc_class()):
pass

async with anyio.create_task_group() as outer_tg:
outer_tg.start_soon(slow_server)

with anyio.fail_after(5): # pragma: no branch
async with ClientSession(read_stream=client_read, write_stream=client_write) as client_session:
# Fire a request in a background task (will never get a response)
async with anyio.create_task_group() as inner_tg: # pragma: no branch

async def send_and_ignore():
try:
await client_session.send_ping()
except (MCPError, anyio.get_cancelled_exc_class()):
pass

inner_tg.start_soon(send_and_ignore)
await anyio.sleep(0.1)
inner_tg.cancel_scope.cancel()

outer_tg.cancel_scope.cancel() # pragma: lax no cover
49 changes: 49 additions & 0 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2247,3 +2247,52 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(

assert "content-type" in headers_data
assert headers_data["content-type"] == "application/json"


@pytest.mark.anyio
async def test_streamable_http_client_exit_with_pending_requests(basic_server: None, basic_server_url: str):
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1805.

Sends tool calls to a server-side handler that blocks indefinitely (lock
never released), then exits the client context while responses are still
pending. Verifies that shutdown completes within the timeout and does not
hang or busy-loop in AnyIO cancellation delivery.
"""
with anyio.fail_after(10): # pragma: no branch
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
await session.initialize()

async with anyio.create_task_group() as tg: # pragma: no branch

async def call_blocked_tool():
try:
await session.call_tool("wait_for_lock_with_notification", {})
except (MCPError, anyio.get_cancelled_exc_class()):
pass

# Fire off multiple requests that will block server-side
for _ in range(3):
tg.start_soon(call_blocked_tool)

# Give the server a moment to receive them, then bail out
await anyio.sleep(0.2)
tg.cancel_scope.cancel()

# If we reach here, shutdown completed without hanging.
await anyio.sleep(0.1)


@pytest.mark.anyio
async def test_streamable_http_client_rapid_connect_disconnect(basic_server: None, basic_server_url: str):
"""Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1805.

Rapidly connects, initializes, and disconnects multiple times. Verifies no
resource leak or cancellation busy-loop across iterations.
"""
for _ in range(5): # pragma: no branch
with anyio.fail_after(10): # pragma: no branch
async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
await session.initialize()
await anyio.sleep(0.1)