diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index b61bf0b03..9f5bb1774 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,4 +1,5 @@ import logging +from types import TracebackType from typing import Any, Protocol, overload import anyio.lowlevel @@ -108,6 +109,8 @@ class ClientSession( types.ServerNotification, ] ): + _entered: bool + def __init__( self, read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], @@ -140,11 +143,31 @@ def __init__( self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._server_capabilities: types.ServerCapabilities | None = None self._experimental_features: ExperimentalClientFeatures | None = None + self._entered = False # Experimental: Task handlers (use defaults if not provided) self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() + async def __aenter__(self) -> "ClientSession": + self._entered = True + await super().__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self._entered = False + await super().__aexit__(exc_type, exc_value, traceback) + + def _check_is_active(self) -> None: + if not self._entered: + raise RuntimeError("ClientSession must be used within an 'async with' block.") + async def initialize(self) -> types.InitializeResult: + self._check_is_active() sampling = ( (self._sampling_capabilities or types.SamplingCapability()) if self._sampling_callback is not _default_sampling_callback diff --git a/tests/client/test_session.py b/tests/client/test_session.py index eb2683fbd..316462ca1 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, cast import anyio import pytest @@ -768,3 +768,20 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_initialize_without_context_manager_raises_error(): + """ + Test that calling initialize() without entering the context manager raises RuntimeError. + """ + send_stream, receive_stream = anyio.create_memory_object_stream[Any](0) + + read_stream = cast(Any, receive_stream) + write_stream = cast(Any, send_stream) + + async with send_stream, receive_stream: + session = ClientSession(read_stream, write_stream) + + with pytest.raises(RuntimeError, match="must be used within"): + await session.initialize()