@@ -68,8 +68,10 @@ async def main():
68
68
import logging
69
69
import warnings
70
70
from collections .abc import Awaitable , Callable
71
+ from contextlib import AsyncExitStack
71
72
from typing import Any , Sequence
72
73
74
+ import anyio
73
75
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
74
76
from pydantic import AnyUrl
75
77
@@ -538,30 +540,41 @@ async def run(
538
540
# in-process servers.
539
541
raise_exceptions : bool = False ,
540
542
):
541
- with warnings .catch_warnings (record = True ) as w :
542
- async with ServerSession (
543
- read_stream , write_stream , initialization_options
544
- ) as session :
543
+ async with AsyncExitStack () as stack :
544
+ session = await stack .enter_async_context (
545
+ ServerSession (read_stream , write_stream , initialization_options )
546
+ )
547
+
548
+ async with anyio .create_task_group () as tg :
545
549
async for message in session .incoming_messages :
546
550
logger .debug (f"Received message: { message } " )
547
551
548
- match message :
549
- case (
550
- RequestResponder (
551
- request = types .ClientRequest (root = req )
552
- ) as responder
553
- ):
554
- with responder :
555
- await self ._handle_request (
556
- message , req , session , raise_exceptions
557
- )
558
- case types .ClientNotification (root = notify ):
559
- await self ._handle_notification (notify )
560
-
561
- for warning in w :
562
- logger .info (
563
- f"Warning: { warning .category .__name__ } : { warning .message } "
552
+ tg .start_soon (
553
+ self ._handle_message , message , session , raise_exceptions
554
+ )
555
+
556
+ async def _handle_message (
557
+ self ,
558
+ message : RequestResponder [types .ClientRequest , types .ServerResult ]
559
+ | types .ClientNotification
560
+ | Exception ,
561
+ session : ServerSession ,
562
+ raise_exceptions : bool = False ,
563
+ ):
564
+ with warnings .catch_warnings (record = True ) as w :
565
+ match message :
566
+ case (
567
+ RequestResponder (request = types .ClientRequest (root = req )) as responder
568
+ ):
569
+ with responder :
570
+ await self ._handle_request (
571
+ message , req , session , raise_exceptions
564
572
)
573
+ case types .ClientNotification (root = notify ):
574
+ await self ._handle_notification (notify )
575
+
576
+ for warning in w :
577
+ logger .info (f"Warning: { warning .category .__name__ } : { warning .message } " )
565
578
566
579
async def _handle_request (
567
580
self ,
0 commit comments