7
7
import anyio .lowlevel
8
8
import httpx
9
9
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10
- from pydantic import BaseModel
10
+ from pydantic import BaseModel , RootModel
11
11
from typing_extensions import Self
12
12
13
13
from mcp .shared .exceptions import McpError
28
28
ServerResult ,
29
29
)
30
30
31
+ RawT = TypeVar ("RawT" )
32
+
33
+
34
+ class ParsedMessage (RootModel [JSONRPCMessage ], Generic [RawT ]):
35
+ root : JSONRPCMessage
36
+ raw : RawT | None = None
37
+
38
+ class Config :
39
+ arbitrary_types_allowed = True
40
+
41
+
42
+ ReadStream = MemoryObjectReceiveStream [ParsedMessage [RawT ] | Exception ]
43
+ ReadStreamWriter = MemoryObjectSendStream [ParsedMessage [RawT ] | Exception ]
44
+ WriteStream = MemoryObjectSendStream [ParsedMessage [RawT ]]
45
+ WriteStreamReader = MemoryObjectReceiveStream [ParsedMessage [RawT ]]
46
+
31
47
SendRequestT = TypeVar ("SendRequestT" , ClientRequest , ServerRequest )
32
48
SendResultT = TypeVar ("SendResultT" , ClientResult , ServerResult )
33
49
SendNotificationT = TypeVar ("SendNotificationT" , ClientNotification , ServerNotification )
@@ -165,8 +181,8 @@ class BaseSession(
165
181
166
182
def __init__ (
167
183
self ,
168
- read_stream : MemoryObjectReceiveStream [ JSONRPCMessage | Exception ] ,
169
- write_stream : MemoryObjectSendStream [ JSONRPCMessage ] ,
184
+ read_stream : ReadStream ,
185
+ write_stream : WriteStream ,
170
186
receive_request_type : type [ReceiveRequestT ],
171
187
receive_notification_type : type [ReceiveNotificationT ],
172
188
# If none, reading will never time out
@@ -242,7 +258,9 @@ async def send_request(
242
258
243
259
# TODO: Support progress callbacks
244
260
245
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_request ))
261
+ await self ._write_stream .send (
262
+ ParsedMessage (JSONRPCMessage (jsonrpc_request ), None )
263
+ )
246
264
247
265
try :
248
266
with anyio .fail_after (
@@ -278,14 +296,16 @@ async def send_notification(self, notification: SendNotificationT) -> None:
278
296
** notification .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
279
297
)
280
298
281
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_notification ))
299
+ await self ._write_stream .send (
300
+ ParsedMessage (JSONRPCMessage (jsonrpc_notification ))
301
+ )
282
302
283
303
async def _send_response (
284
304
self , request_id : RequestId , response : SendResultT | ErrorData
285
305
) -> None :
286
306
if isinstance (response , ErrorData ):
287
307
jsonrpc_error = JSONRPCError (jsonrpc = "2.0" , id = request_id , error = response )
288
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_error ))
308
+ await self ._write_stream .send (ParsedMessage ( JSONRPCMessage (jsonrpc_error ) ))
289
309
else :
290
310
jsonrpc_response = JSONRPCResponse (
291
311
jsonrpc = "2.0" ,
@@ -294,18 +314,23 @@ async def _send_response(
294
314
by_alias = True , mode = "json" , exclude_none = True
295
315
),
296
316
)
297
- await self ._write_stream .send (JSONRPCMessage (jsonrpc_response ))
317
+ await self ._write_stream .send (
318
+ ParsedMessage (JSONRPCMessage (jsonrpc_response ))
319
+ )
298
320
299
321
async def _receive_loop (self ) -> None :
300
322
async with (
301
323
self ._read_stream ,
302
324
self ._write_stream ,
303
325
self ._incoming_message_stream_writer ,
304
326
):
305
- async for message in self ._read_stream :
306
- if isinstance (message , Exception ):
307
- await self ._incoming_message_stream_writer .send (message )
308
- elif isinstance (message .root , JSONRPCRequest ):
327
+ async for raw_message in self ._read_stream :
328
+ if isinstance (raw_message , Exception ):
329
+ await self ._incoming_message_stream_writer .send (raw_message )
330
+ continue
331
+
332
+ message = raw_message .root
333
+ if isinstance (message .root , JSONRPCRequest ):
309
334
validated_request = self ._receive_request_type .model_validate (
310
335
message .root .model_dump (
311
336
by_alias = True , mode = "json" , exclude_none = True
0 commit comments