7
7
import anyio .lowlevel
8
8
import httpx
9
9
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10
- from pydantic import BaseModel
10
+ from pydantic import BaseModel , RootModel
11
11
12
12
from mcp .shared .exceptions import McpError
13
13
from mcp .types import (
27
27
ServerResult ,
28
28
)
29
29
30
+ RawT = TypeVar ("RawT" )
31
+
32
+
33
+ class ParsedMessage (RootModel [JSONRPCMessage ], Generic [RawT ]):
34
+ root : JSONRPCMessage
35
+ raw : RawT | None = None
36
+
37
+ class Config :
38
+ arbitrary_types_allowed = True
39
+
40
+
41
+ ReadStream = MemoryObjectReceiveStream [ParsedMessage [RawT ] | Exception ]
42
+ ReadStreamWriter = MemoryObjectSendStream [ParsedMessage [RawT ] | Exception ]
43
+ WriteStream = MemoryObjectSendStream [ParsedMessage [RawT ]]
44
+ WriteStreamReader = MemoryObjectReceiveStream [ParsedMessage [RawT ]]
45
+
30
46
SendRequestT = TypeVar ("SendRequestT" , ClientRequest , ServerRequest )
31
47
SendResultT = TypeVar ("SendResultT" , ClientResult , ServerResult )
32
48
SendNotificationT = TypeVar ("SendNotificationT" , ClientNotification , ServerNotification )
@@ -159,8 +175,8 @@ class BaseSession(
159
175
160
176
def __init__ (
161
177
self ,
162
- read_stream : MemoryObjectReceiveStream [ JSONRPCMessage | Exception ] ,
163
- write_stream : MemoryObjectSendStream [ JSONRPCMessage ] ,
178
+ read_stream : ReadStream ,
179
+ write_stream : WriteStream ,
164
180
receive_request_type : type [ReceiveRequestT ],
165
181
receive_notification_type : type [ReceiveNotificationT ],
166
182
# If none, reading will never time out
@@ -225,7 +241,9 @@ async def send_request(
225
241
226
242
# TODO: Support progress callbacks
227
243
228
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_request ))
244
+ await self ._write_stream .send (
245
+ ParsedMessage (JSONRPCMessage (jsonrpc_request ), None )
246
+ )
229
247
230
248
try :
231
249
with anyio .fail_after (
@@ -261,14 +279,16 @@ async def send_notification(self, notification: SendNotificationT) -> None:
261
279
** notification .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
262
280
)
263
281
264
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_notification ))
282
+ await self ._write_stream .send (
283
+ ParsedMessage (JSONRPCMessage (jsonrpc_notification ))
284
+ )
265
285
266
286
async def _send_response (
267
287
self , request_id : RequestId , response : SendResultT | ErrorData
268
288
) -> None :
269
289
if isinstance (response , ErrorData ):
270
290
jsonrpc_error = JSONRPCError (jsonrpc = "2.0" , id = request_id , error = response )
271
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_error ))
291
+ await self ._write_stream .send (ParsedMessage ( JSONRPCMessage (jsonrpc_error ) ))
272
292
else :
273
293
jsonrpc_response = JSONRPCResponse (
274
294
jsonrpc = "2.0" ,
@@ -277,18 +297,23 @@ async def _send_response(
277
297
by_alias = True , mode = "json" , exclude_none = True
278
298
),
279
299
)
280
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_response ))
300
+ await self ._write_stream .send (
301
+ ParsedMessage (JSONRPCMessage (jsonrpc_response ))
302
+ )
281
303
282
304
async def _receive_loop (self ) -> None :
283
305
async with (
284
306
self ._read_stream ,
285
307
self ._write_stream ,
286
308
self ._incoming_message_stream_writer ,
287
309
):
288
- async for message in self ._read_stream :
289
- if isinstance (message , Exception ):
290
- await self ._incoming_message_stream_writer .send (message )
291
- elif isinstance (message .root , JSONRPCRequest ):
310
+ async for raw_message in self ._read_stream :
311
+ if isinstance (raw_message , Exception ):
312
+ await self ._incoming_message_stream_writer .send (raw_message )
313
+ continue
314
+
315
+ message = raw_message .root
316
+ if isinstance (message .root , JSONRPCRequest ):
292
317
validated_request = self ._receive_request_type .model_validate (
293
318
message .root .model_dump (
294
319
by_alias = True , mode = "json" , exclude_none = True
0 commit comments