@@ -68,9 +68,10 @@ async def main():
68
68
import logging
69
69
import warnings
70
70
from collections .abc import Awaitable , Callable
71
- from contextlib import AbstractAsyncContextManager , asynccontextmanager
71
+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
72
72
from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
73
73
74
+ import anyio
74
75
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
75
76
from pydantic import AnyUrl
76
77
@@ -469,41 +470,47 @@ async def run(
469
470
# in-process servers.
470
471
raise_exceptions : bool = False ,
471
472
):
472
- with warnings .catch_warnings (record = True ) as w :
473
- from contextlib import AsyncExitStack
474
-
475
- async with AsyncExitStack () as stack :
476
- lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477
- session = await stack .enter_async_context (
478
- ServerSession (read_stream , write_stream , initialization_options )
479
- )
473
+ async with AsyncExitStack () as stack :
474
+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
475
+ session = await stack .enter_async_context (
476
+ ServerSession (read_stream , write_stream , initialization_options )
477
+ )
480
478
479
+ async with anyio .create_task_group () as tg :
481
480
async for message in session .incoming_messages :
482
481
logger .debug (f"Received message: { message } " )
483
482
484
- match message :
485
- case (
486
- RequestResponder (
487
- request = types .ClientRequest (root = req )
488
- ) as responder
489
- ):
490
- with responder :
491
- await self ._handle_request (
492
- message ,
493
- req ,
494
- session ,
495
- lifespan_context ,
496
- raise_exceptions ,
497
- )
498
- case types .ClientNotification (root = notify ):
499
- await self ._handle_notification (notify )
500
-
501
- for warning in w :
502
- logger .info (
503
- "Warning: %s: %s" ,
504
- warning .category .__name__ ,
505
- warning .message ,
483
+ tg .start_soon (
484
+ self ._handle_message ,
485
+ message ,
486
+ session ,
487
+ lifespan_context ,
488
+ raise_exceptions ,
489
+ )
490
+
491
+ async def _handle_message (
492
+ self ,
493
+ message : RequestResponder [types .ClientRequest , types .ServerResult ]
494
+ | types .ClientNotification
495
+ | Exception ,
496
+ session : ServerSession ,
497
+ lifespan_context : LifespanResultT ,
498
+ raise_exceptions : bool = False ,
499
+ ):
500
+ with warnings .catch_warnings (record = True ) as w :
501
+ match message :
502
+ case (
503
+ RequestResponder (request = types .ClientRequest (root = req )) as responder
504
+ ):
505
+ with responder :
506
+ await self ._handle_request (
507
+ message , req , session , lifespan_context , raise_exceptions
506
508
)
509
+ case types .ClientNotification (root = notify ):
510
+ await self ._handle_notification (notify )
511
+
512
+ for warning in w :
513
+ logger .info (f"Warning: { warning .category .__name__ } : { warning .message } " )
507
514
508
515
async def _handle_request (
509
516
self ,
0 commit comments