diff --git a/pyproject.toml b/pyproject.toml index e400ad7d..d014bf0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,6 @@ include = ["src/mcp", "tests"] venvPath = "." venv = ".venv" strict = ["src/mcp/**/*.py"] -exclude = ["src/mcp/types.py"] [tool.ruff.lint] select = ["E", "F", "I", "UP"] diff --git a/src/mcp/types.py b/src/mcp/types.py index f043fb10..4ef11106 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -64,8 +64,10 @@ class Meta(BaseModel): """ -RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams) -NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams) +RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) +NotificationParamsT = TypeVar( + "NotificationParamsT", bound=NotificationParams | dict[str, Any] | None +) MethodT = TypeVar("MethodT", bound=str) @@ -113,15 +115,16 @@ class PaginatedResult(Result): """ -class JSONRPCRequest(Request): +class JSONRPCRequest(Request[dict[str, Any] | None, str]): """A request that expects a response.""" jsonrpc: Literal["2.0"] id: RequestId + method: str params: dict[str, Any] | None = None -class JSONRPCNotification(Notification): +class JSONRPCNotification(Notification[dict[str, Any] | None, str]): """A notification which does not expect a response.""" jsonrpc: Literal["2.0"] @@ -277,7 +280,7 @@ class InitializeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class InitializeRequest(Request): +class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]): """ This request is sent from the client to the server when it first connects, asking it to begin initialization. @@ -298,7 +301,9 @@ class InitializeResult(Result): """Instructions describing how to use the server and its features.""" -class InitializedNotification(Notification): +class InitializedNotification( + Notification[NotificationParams | None, Literal["notifications/initialized"]] +): """ This notification is sent from the client to the server after initialization has finished. @@ -308,7 +313,7 @@ class InitializedNotification(Notification): params: NotificationParams | None = None -class PingRequest(Request): +class PingRequest(Request[RequestParams | None, Literal["ping"]]): """ A ping, issued by either the server or the client, to check that the other party is still alive. @@ -336,7 +341,9 @@ class ProgressNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class ProgressNotification(Notification): +class ProgressNotification( + Notification[ProgressNotificationParams, Literal["notifications/progress"]] +): """ An out-of-band notification used to inform the receiver of a progress update for a long-running request. @@ -346,7 +353,9 @@ class ProgressNotification(Notification): params: ProgressNotificationParams -class ListResourcesRequest(PaginatedRequest): +class ListResourcesRequest( + PaginatedRequest[RequestParams | None, Literal["resources/list"]] +): """Sent from the client to request a list of resources the server has.""" method: Literal["resources/list"] @@ -408,7 +417,9 @@ class ListResourcesResult(PaginatedResult): resources: list[Resource] -class ListResourceTemplatesRequest(PaginatedRequest): +class ListResourceTemplatesRequest( + PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]] +): """Sent from the client to request a list of resource templates the server has.""" method: Literal["resources/templates/list"] @@ -432,7 +443,9 @@ class ReadResourceRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class ReadResourceRequest(Request): +class ReadResourceRequest( + Request[ReadResourceRequestParams, Literal["resources/read"]] +): """Sent from the client to the server, to read a specific resource URI.""" method: Literal["resources/read"] @@ -472,7 +485,11 @@ class ReadResourceResult(Result): contents: list[TextResourceContents | BlobResourceContents] -class ResourceListChangedNotification(Notification): +class ResourceListChangedNotification( + Notification[ + NotificationParams | None, Literal["notifications/resources/list_changed"] + ] +): """ An optional notification from the server to the client, informing it that the list of resources it can read from has changed. @@ -493,7 +510,7 @@ class SubscribeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class SubscribeRequest(Request): +class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]): """ Sent from the client to request resources/updated notifications from the server whenever a particular resource changes. @@ -511,7 +528,9 @@ class UnsubscribeRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class UnsubscribeRequest(Request): +class UnsubscribeRequest( + Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]] +): """ Sent from the client to request cancellation of resources/updated notifications from the server. @@ -532,7 +551,11 @@ class ResourceUpdatedNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class ResourceUpdatedNotification(Notification): +class ResourceUpdatedNotification( + Notification[ + ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"] + ] +): """ A notification from the server to the client, informing it that a resource has changed and may need to be read again. @@ -542,7 +565,9 @@ class ResourceUpdatedNotification(Notification): params: ResourceUpdatedNotificationParams -class ListPromptsRequest(PaginatedRequest): +class ListPromptsRequest( + PaginatedRequest[RequestParams | None, Literal["prompts/list"]] +): """Sent from the client to request a list of prompts and prompt templates.""" method: Literal["prompts/list"] @@ -589,7 +614,7 @@ class GetPromptRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class GetPromptRequest(Request): +class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): """Used by the client to get a prompt provided by the server.""" method: Literal["prompts/get"] @@ -659,7 +684,11 @@ class GetPromptResult(Result): messages: list[PromptMessage] -class PromptListChangedNotification(Notification): +class PromptListChangedNotification( + Notification[ + NotificationParams | None, Literal["notifications/prompts/list_changed"] + ] +): """ An optional notification from the server to the client, informing it that the list of prompts it offers has changed. @@ -669,7 +698,7 @@ class PromptListChangedNotification(Notification): params: NotificationParams | None = None -class ListToolsRequest(PaginatedRequest): +class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]): """Sent from the client to request a list of tools the server has.""" method: Literal["tools/list"] @@ -702,7 +731,7 @@ class CallToolRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CallToolRequest(Request): +class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): """Used by the client to invoke a tool provided by the server.""" method: Literal["tools/call"] @@ -716,7 +745,9 @@ class CallToolResult(Result): isError: bool = False -class ToolListChangedNotification(Notification): +class ToolListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]] +): """ An optional notification from the server to the client, informing it that the list of tools it offers has changed. @@ -739,7 +770,7 @@ class SetLevelRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class SetLevelRequest(Request): +class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): """A request from the client to the server, to enable or adjust logging.""" method: Literal["logging/setLevel"] @@ -761,7 +792,9 @@ class LoggingMessageNotificationParams(NotificationParams): model_config = ConfigDict(extra="allow") -class LoggingMessageNotification(Notification): +class LoggingMessageNotification( + Notification[LoggingMessageNotificationParams, Literal["notifications/message"]] +): """Notification of a log message passed from server to client.""" method: Literal["notifications/message"] @@ -856,7 +889,9 @@ class CreateMessageRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CreateMessageRequest(Request): +class CreateMessageRequest( + Request[CreateMessageRequestParams, Literal["sampling/createMessage"]] +): """A request from the server to sample an LLM via the client.""" method: Literal["sampling/createMessage"] @@ -913,7 +948,7 @@ class CompleteRequestParams(RequestParams): model_config = ConfigDict(extra="allow") -class CompleteRequest(Request): +class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): """A request from the client to the server, to ask for completion options.""" method: Literal["completion/complete"] @@ -944,7 +979,7 @@ class CompleteResult(Result): completion: Completion -class ListRootsRequest(Request): +class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): """ Sent from the server to request a list of root URIs from the client. Roots allow servers to ask for specific directories or files to operate on. A common example @@ -987,7 +1022,9 @@ class ListRootsResult(Result): roots: list[Root] -class RootsListChangedNotification(Notification): +class RootsListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] +): """ A notification from the client to the server, informing it that the list of roots has changed. diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 43107b59..f5158c3c 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 2aca97e1..1381c815 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield