Skip to content

Commit 9a2bb6a

Browse files
authored
refactor: Make types.py strictly typechecked. (#336)
1 parent df2d3a5 commit 9a2bb6a

File tree

4 files changed

+66
-34
lines changed

4 files changed

+66
-34
lines changed

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ include = ["src/mcp", "tests"]
8787
venvPath = "."
8888
venv = ".venv"
8989
strict = ["src/mcp/**/*.py"]
90-
exclude = ["src/mcp/types.py"]
9190

9291
[tool.ruff.lint]
9392
select = ["E", "F", "I", "UP"]

src/mcp/types.py

+64-27
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ class Meta(BaseModel):
6464
"""
6565

6666

67-
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams)
68-
NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams)
67+
RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
68+
NotificationParamsT = TypeVar(
69+
"NotificationParamsT", bound=NotificationParams | dict[str, Any] | None
70+
)
6971
MethodT = TypeVar("MethodT", bound=str)
7072

7173

@@ -113,15 +115,16 @@ class PaginatedResult(Result):
113115
"""
114116

115117

116-
class JSONRPCRequest(Request):
118+
class JSONRPCRequest(Request[dict[str, Any] | None, str]):
117119
"""A request that expects a response."""
118120

119121
jsonrpc: Literal["2.0"]
120122
id: RequestId
123+
method: str
121124
params: dict[str, Any] | None = None
122125

123126

124-
class JSONRPCNotification(Notification):
127+
class JSONRPCNotification(Notification[dict[str, Any] | None, str]):
125128
"""A notification which does not expect a response."""
126129

127130
jsonrpc: Literal["2.0"]
@@ -277,7 +280,7 @@ class InitializeRequestParams(RequestParams):
277280
model_config = ConfigDict(extra="allow")
278281

279282

280-
class InitializeRequest(Request):
283+
class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]):
281284
"""
282285
This request is sent from the client to the server when it first connects, asking it
283286
to begin initialization.
@@ -298,7 +301,9 @@ class InitializeResult(Result):
298301
"""Instructions describing how to use the server and its features."""
299302

300303

301-
class InitializedNotification(Notification):
304+
class InitializedNotification(
305+
Notification[NotificationParams | None, Literal["notifications/initialized"]]
306+
):
302307
"""
303308
This notification is sent from the client to the server after initialization has
304309
finished.
@@ -308,7 +313,7 @@ class InitializedNotification(Notification):
308313
params: NotificationParams | None = None
309314

310315

311-
class PingRequest(Request):
316+
class PingRequest(Request[RequestParams | None, Literal["ping"]]):
312317
"""
313318
A ping, issued by either the server or the client, to check that the other party is
314319
still alive.
@@ -336,7 +341,9 @@ class ProgressNotificationParams(NotificationParams):
336341
model_config = ConfigDict(extra="allow")
337342

338343

339-
class ProgressNotification(Notification):
344+
class ProgressNotification(
345+
Notification[ProgressNotificationParams, Literal["notifications/progress"]]
346+
):
340347
"""
341348
An out-of-band notification used to inform the receiver of a progress update for a
342349
long-running request.
@@ -346,7 +353,9 @@ class ProgressNotification(Notification):
346353
params: ProgressNotificationParams
347354

348355

349-
class ListResourcesRequest(PaginatedRequest):
356+
class ListResourcesRequest(
357+
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
358+
):
350359
"""Sent from the client to request a list of resources the server has."""
351360

352361
method: Literal["resources/list"]
@@ -408,7 +417,9 @@ class ListResourcesResult(PaginatedResult):
408417
resources: list[Resource]
409418

410419

411-
class ListResourceTemplatesRequest(PaginatedRequest):
420+
class ListResourceTemplatesRequest(
421+
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
422+
):
412423
"""Sent from the client to request a list of resource templates the server has."""
413424

414425
method: Literal["resources/templates/list"]
@@ -432,7 +443,9 @@ class ReadResourceRequestParams(RequestParams):
432443
model_config = ConfigDict(extra="allow")
433444

434445

435-
class ReadResourceRequest(Request):
446+
class ReadResourceRequest(
447+
Request[ReadResourceRequestParams, Literal["resources/read"]]
448+
):
436449
"""Sent from the client to the server, to read a specific resource URI."""
437450

438451
method: Literal["resources/read"]
@@ -472,7 +485,11 @@ class ReadResourceResult(Result):
472485
contents: list[TextResourceContents | BlobResourceContents]
473486

474487

475-
class ResourceListChangedNotification(Notification):
488+
class ResourceListChangedNotification(
489+
Notification[
490+
NotificationParams | None, Literal["notifications/resources/list_changed"]
491+
]
492+
):
476493
"""
477494
An optional notification from the server to the client, informing it that the list
478495
of resources it can read from has changed.
@@ -493,7 +510,7 @@ class SubscribeRequestParams(RequestParams):
493510
model_config = ConfigDict(extra="allow")
494511

495512

496-
class SubscribeRequest(Request):
513+
class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]):
497514
"""
498515
Sent from the client to request resources/updated notifications from the server
499516
whenever a particular resource changes.
@@ -511,7 +528,9 @@ class UnsubscribeRequestParams(RequestParams):
511528
model_config = ConfigDict(extra="allow")
512529

513530

514-
class UnsubscribeRequest(Request):
531+
class UnsubscribeRequest(
532+
Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]
533+
):
515534
"""
516535
Sent from the client to request cancellation of resources/updated notifications from
517536
the server.
@@ -532,7 +551,11 @@ class ResourceUpdatedNotificationParams(NotificationParams):
532551
model_config = ConfigDict(extra="allow")
533552

534553

535-
class ResourceUpdatedNotification(Notification):
554+
class ResourceUpdatedNotification(
555+
Notification[
556+
ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]
557+
]
558+
):
536559
"""
537560
A notification from the server to the client, informing it that a resource has
538561
changed and may need to be read again.
@@ -542,7 +565,9 @@ class ResourceUpdatedNotification(Notification):
542565
params: ResourceUpdatedNotificationParams
543566

544567

545-
class ListPromptsRequest(PaginatedRequest):
568+
class ListPromptsRequest(
569+
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
570+
):
546571
"""Sent from the client to request a list of prompts and prompt templates."""
547572

548573
method: Literal["prompts/list"]
@@ -589,7 +614,7 @@ class GetPromptRequestParams(RequestParams):
589614
model_config = ConfigDict(extra="allow")
590615

591616

592-
class GetPromptRequest(Request):
617+
class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
593618
"""Used by the client to get a prompt provided by the server."""
594619

595620
method: Literal["prompts/get"]
@@ -659,7 +684,11 @@ class GetPromptResult(Result):
659684
messages: list[PromptMessage]
660685

661686

662-
class PromptListChangedNotification(Notification):
687+
class PromptListChangedNotification(
688+
Notification[
689+
NotificationParams | None, Literal["notifications/prompts/list_changed"]
690+
]
691+
):
663692
"""
664693
An optional notification from the server to the client, informing it that the list
665694
of prompts it offers has changed.
@@ -669,7 +698,7 @@ class PromptListChangedNotification(Notification):
669698
params: NotificationParams | None = None
670699

671700

672-
class ListToolsRequest(PaginatedRequest):
701+
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
673702
"""Sent from the client to request a list of tools the server has."""
674703

675704
method: Literal["tools/list"]
@@ -702,7 +731,7 @@ class CallToolRequestParams(RequestParams):
702731
model_config = ConfigDict(extra="allow")
703732

704733

705-
class CallToolRequest(Request):
734+
class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
706735
"""Used by the client to invoke a tool provided by the server."""
707736

708737
method: Literal["tools/call"]
@@ -716,7 +745,9 @@ class CallToolResult(Result):
716745
isError: bool = False
717746

718747

719-
class ToolListChangedNotification(Notification):
748+
class ToolListChangedNotification(
749+
Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]
750+
):
720751
"""
721752
An optional notification from the server to the client, informing it that the list
722753
of tools it offers has changed.
@@ -739,7 +770,7 @@ class SetLevelRequestParams(RequestParams):
739770
model_config = ConfigDict(extra="allow")
740771

741772

742-
class SetLevelRequest(Request):
773+
class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
743774
"""A request from the client to the server, to enable or adjust logging."""
744775

745776
method: Literal["logging/setLevel"]
@@ -761,7 +792,9 @@ class LoggingMessageNotificationParams(NotificationParams):
761792
model_config = ConfigDict(extra="allow")
762793

763794

764-
class LoggingMessageNotification(Notification):
795+
class LoggingMessageNotification(
796+
Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]
797+
):
765798
"""Notification of a log message passed from server to client."""
766799

767800
method: Literal["notifications/message"]
@@ -856,7 +889,9 @@ class CreateMessageRequestParams(RequestParams):
856889
model_config = ConfigDict(extra="allow")
857890

858891

859-
class CreateMessageRequest(Request):
892+
class CreateMessageRequest(
893+
Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]
894+
):
860895
"""A request from the server to sample an LLM via the client."""
861896

862897
method: Literal["sampling/createMessage"]
@@ -913,7 +948,7 @@ class CompleteRequestParams(RequestParams):
913948
model_config = ConfigDict(extra="allow")
914949

915950

916-
class CompleteRequest(Request):
951+
class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
917952
"""A request from the client to the server, to ask for completion options."""
918953

919954
method: Literal["completion/complete"]
@@ -944,7 +979,7 @@ class CompleteResult(Result):
944979
completion: Completion
945980

946981

947-
class ListRootsRequest(Request):
982+
class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
948983
"""
949984
Sent from the server to request a list of root URIs from the client. Roots allow
950985
servers to ask for specific directories or files to operate on. A common example
@@ -987,7 +1022,9 @@ class ListRootsResult(Result):
9871022
roots: list[Root]
9881023

9891024

990-
class RootsListChangedNotification(Notification):
1025+
class RootsListChangedNotification(
1026+
Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]]
1027+
):
9911028
"""
9921029
A notification from the client to the server, informing it that the list of
9931030
roots has changed.

tests/shared/test_sse.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]:
138138
time.sleep(0.1)
139139
attempt += 1
140140
else:
141-
raise RuntimeError(
142-
f"Server failed to start after {max_attempts} attempts"
143-
)
141+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
144142

145143
yield
146144

tests/shared/test_ws.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]:
134134
time.sleep(0.1)
135135
attempt += 1
136136
else:
137-
raise RuntimeError(
138-
f"Server failed to start after {max_attempts} attempts"
139-
)
137+
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")
140138

141139
yield
142140

0 commit comments

Comments
 (0)