@@ -68,9 +68,10 @@ async def main():
68
68
import logging
69
69
import warnings
70
70
from collections .abc import Awaitable , Callable
71
- from contextlib import AbstractAsyncContextManager , asynccontextmanager
71
+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
72
72
from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
73
73
74
+ import anyio
74
75
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
75
76
from pydantic import AnyUrl
76
77
@@ -458,6 +459,30 @@ async def handler(req: types.CompleteRequest):
458
459
459
460
return decorator
460
461
462
+ async def _handle_message (
463
+ self ,
464
+ message : RequestResponder [types .ClientRequest , types .ServerResult ]
465
+ | types .ClientNotification
466
+ | Exception ,
467
+ session : ServerSession ,
468
+ lifespan_context : LifespanResultT ,
469
+ raise_exceptions : bool = False ,
470
+ ):
471
+ with warnings .catch_warnings (record = True ) as w :
472
+ match message :
473
+ case (
474
+ RequestResponder (request = types .ClientRequest (root = req )) as responder
475
+ ):
476
+ with responder :
477
+ await self ._handle_request (
478
+ message , req , session , lifespan_context , raise_exceptions
479
+ )
480
+ case types .ClientNotification (root = notify ):
481
+ await self ._handle_notification (notify )
482
+
483
+ for warning in w :
484
+ logger .info (f"Warning: { warning .category .__name__ } : { warning .message } " )
485
+
461
486
async def run (
462
487
self ,
463
488
read_stream : MemoryObjectReceiveStream [types .JSONRPCMessage | Exception ],
@@ -469,41 +494,23 @@ async def run(
469
494
# in-process servers.
470
495
raise_exceptions : bool = False ,
471
496
):
472
- with warnings .catch_warnings (record = True ) as w :
473
- from contextlib import AsyncExitStack
474
-
475
- async with AsyncExitStack () as stack :
476
- lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477
- session = await stack .enter_async_context (
478
- ServerSession (read_stream , write_stream , initialization_options )
479
- )
497
+ async with AsyncExitStack () as stack :
498
+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
499
+ session = await stack .enter_async_context (
500
+ ServerSession (read_stream , write_stream , initialization_options )
501
+ )
480
502
503
+ async with anyio .create_task_group () as tg :
481
504
async for message in session .incoming_messages :
482
505
logger .debug (f"Received message: { message } " )
483
506
484
- match message :
485
- case (
486
- RequestResponder (
487
- request = types .ClientRequest (root = req )
488
- ) as responder
489
- ):
490
- with responder :
491
- await self ._handle_request (
492
- message ,
493
- req ,
494
- session ,
495
- lifespan_context ,
496
- raise_exceptions ,
497
- )
498
- case types .ClientNotification (root = notify ):
499
- await self ._handle_notification (notify )
500
-
501
- for warning in w :
502
- logger .info (
503
- "Warning: %s: %s" ,
504
- warning .category .__name__ ,
505
- warning .message ,
506
- )
507
+ tg .start_soon (
508
+ self ._handle_message ,
509
+ message ,
510
+ session ,
511
+ lifespan_context ,
512
+ raise_exceptions ,
513
+ )
507
514
508
515
async def _handle_request (
509
516
self ,
0 commit comments