diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3d3988ce..da826d63 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -8,6 +8,7 @@ import httpx from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel +from typing_extensions import Self from mcp.shared.exceptions import McpError from mcp.types import ( @@ -60,7 +61,7 @@ def __init__( request_id: RequestId, request_meta: RequestParams.Meta | None, request: ReceiveRequestT, - session: "BaseSession", + session: "BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ) -> None: self.request_id = request_id @@ -134,7 +135,6 @@ def cancelled(self) -> bool: class BaseSession( - AbstractAsyncContextManager, Generic[ SendRequestT, SendNotificationT, @@ -183,7 +183,7 @@ def __init__( ]() ) - async def __aenter__(self): + async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() self._task_group.start_soon(self._receive_loop)