Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Make types.py strictly typechecked. #336

Merged
merged 1 commit into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ include = ["src/mcp", "tests"]
venvPath = "."
venv = ".venv"
strict = ["src/mcp/**/*.py"]
exclude = ["src/mcp/types.py"]

[tool.ruff.lint]
select = ["E", "F", "I", "UP"]
Expand Down
91 changes: 64 additions & 27 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ class Meta(BaseModel):
"""


RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams)
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams)
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
NotificationParamsT = TypeVar(
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
)
Comment on lines +67 to +70
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit weird tho... I'd like to spend some minutes checking if we can be a bit stricter here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do it tomorrow.

MethodT = TypeVar("MethodT", bound=str)


Expand Down Expand Up @@ -113,15 +115,16 @@ class PaginatedResult(Result):
"""


class JSONRPCRequest(Request):
class JSONRPCRequest(Request[dict[str, Any] | None, str]):
"""A request that expects a response."""

jsonrpc: Literal["2.0"]
id: RequestId
method: str
params: dict[str, Any] | None = None


class JSONRPCNotification(Notification):
class JSONRPCNotification(Notification[dict[str, Any] | None, str]):
"""A notification which does not expect a response."""

jsonrpc: Literal["2.0"]
Expand Down Expand Up @@ -277,7 +280,7 @@ class InitializeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class InitializeRequest(Request):
class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]):
"""
This request is sent from the client to the server when it first connects, asking it
to begin initialization.
Expand All @@ -298,7 +301,9 @@ class InitializeResult(Result):
"""Instructions describing how to use the server and its features."""


class InitializedNotification(Notification):
class InitializedNotification(
Notification[NotificationParams | None, Literal["notifications/initialized"]]
):
"""
This notification is sent from the client to the server after initialization has
finished.
Expand All @@ -308,7 +313,7 @@ class InitializedNotification(Notification):
params: NotificationParams | None = None


class PingRequest(Request):
class PingRequest(Request[RequestParams | None, Literal["ping"]]):
"""
A ping, issued by either the server or the client, to check that the other party is
still alive.
Expand Down Expand Up @@ -336,7 +341,9 @@ class ProgressNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")


class ProgressNotification(Notification):
class ProgressNotification(
Notification[ProgressNotificationParams, Literal["notifications/progress"]]
):
"""
An out-of-band notification used to inform the receiver of a progress update for a
long-running request.
Expand All @@ -346,7 +353,9 @@ class ProgressNotification(Notification):
params: ProgressNotificationParams


class ListResourcesRequest(PaginatedRequest):
class ListResourcesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
):
"""Sent from the client to request a list of resources the server has."""

method: Literal["resources/list"]
Expand Down Expand Up @@ -408,7 +417,9 @@ class ListResourcesResult(PaginatedResult):
resources: list[Resource]


class ListResourceTemplatesRequest(PaginatedRequest):
class ListResourceTemplatesRequest(
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
):
"""Sent from the client to request a list of resource templates the server has."""

method: Literal["resources/templates/list"]
Expand All @@ -432,7 +443,9 @@ class ReadResourceRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class ReadResourceRequest(Request):
class ReadResourceRequest(
Request[ReadResourceRequestParams, Literal["resources/read"]]
):
"""Sent from the client to the server, to read a specific resource URI."""

method: Literal["resources/read"]
Expand Down Expand Up @@ -472,7 +485,11 @@ class ReadResourceResult(Result):
contents: list[TextResourceContents | BlobResourceContents]


class ResourceListChangedNotification(Notification):
class ResourceListChangedNotification(
Notification[
NotificationParams | None, Literal["notifications/resources/list_changed"]
]
):
"""
An optional notification from the server to the client, informing it that the list
of resources it can read from has changed.
Expand All @@ -493,7 +510,7 @@ class SubscribeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class SubscribeRequest(Request):
class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]):
"""
Sent from the client to request resources/updated notifications from the server
whenever a particular resource changes.
Expand All @@ -511,7 +528,9 @@ class UnsubscribeRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class UnsubscribeRequest(Request):
class UnsubscribeRequest(
Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]
):
"""
Sent from the client to request cancellation of resources/updated notifications from
the server.
Expand All @@ -532,7 +551,11 @@ class ResourceUpdatedNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")


class ResourceUpdatedNotification(Notification):
class ResourceUpdatedNotification(
Notification[
ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]
]
):
"""
A notification from the server to the client, informing it that a resource has
changed and may need to be read again.
Expand All @@ -542,7 +565,9 @@ class ResourceUpdatedNotification(Notification):
params: ResourceUpdatedNotificationParams


class ListPromptsRequest(PaginatedRequest):
class ListPromptsRequest(
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
):
"""Sent from the client to request a list of prompts and prompt templates."""

method: Literal["prompts/list"]
Expand Down Expand Up @@ -589,7 +614,7 @@ class GetPromptRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class GetPromptRequest(Request):
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
"""Used by the client to get a prompt provided by the server."""

method: Literal["prompts/get"]
Expand Down Expand Up @@ -659,7 +684,11 @@ class GetPromptResult(Result):
messages: list[PromptMessage]


class PromptListChangedNotification(Notification):
class PromptListChangedNotification(
Notification[
NotificationParams | None, Literal["notifications/prompts/list_changed"]
]
):
"""
An optional notification from the server to the client, informing it that the list
of prompts it offers has changed.
Expand All @@ -669,7 +698,7 @@ class PromptListChangedNotification(Notification):
params: NotificationParams | None = None


class ListToolsRequest(PaginatedRequest):
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
"""Sent from the client to request a list of tools the server has."""

method: Literal["tools/list"]
Expand Down Expand Up @@ -702,7 +731,7 @@ class CallToolRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class CallToolRequest(Request):
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
"""Used by the client to invoke a tool provided by the server."""

method: Literal["tools/call"]
Expand All @@ -716,7 +745,9 @@ class CallToolResult(Result):
isError: bool = False


class ToolListChangedNotification(Notification):
class ToolListChangedNotification(
Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]
):
"""
An optional notification from the server to the client, informing it that the list
of tools it offers has changed.
Expand All @@ -739,7 +770,7 @@ class SetLevelRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class SetLevelRequest(Request):
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
"""A request from the client to the server, to enable or adjust logging."""

method: Literal["logging/setLevel"]
Expand All @@ -761,7 +792,9 @@ class LoggingMessageNotificationParams(NotificationParams):
model_config = ConfigDict(extra="allow")


class LoggingMessageNotification(Notification):
class LoggingMessageNotification(
Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]
):
"""Notification of a log message passed from server to client."""

method: Literal["notifications/message"]
Expand Down Expand Up @@ -856,7 +889,9 @@ class CreateMessageRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class CreateMessageRequest(Request):
class CreateMessageRequest(
Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]
):
"""A request from the server to sample an LLM via the client."""

method: Literal["sampling/createMessage"]
Expand Down Expand Up @@ -913,7 +948,7 @@ class CompleteRequestParams(RequestParams):
model_config = ConfigDict(extra="allow")


class CompleteRequest(Request):
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
"""A request from the client to the server, to ask for completion options."""

method: Literal["completion/complete"]
Expand Down Expand Up @@ -944,7 +979,7 @@ class CompleteResult(Result):
completion: Completion


class ListRootsRequest(Request):
class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
"""
Sent from the server to request a list of root URIs from the client. Roots allow
servers to ask for specific directories or files to operate on. A common example
Expand Down Expand Up @@ -987,7 +1022,9 @@ class ListRootsResult(Result):
roots: list[Root]


class RootsListChangedNotification(Notification):
class RootsListChangedNotification(
Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]]
):
"""
A notification from the client to the server, informing it that the list of
roots has changed.
Expand Down
4 changes: 1 addition & 3 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

yield

Expand Down
4 changes: 1 addition & 3 deletions tests/shared/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(
f"Server failed to start after {max_attempts} attempts"
)
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

yield

Expand Down