@@ -68,7 +68,8 @@ async def main():
68
68
import logging
69
69
import warnings
70
70
from collections .abc import Awaitable , Callable
71
- from typing import Any , Sequence
71
+ from contextlib import AbstractAsyncContextManager , asynccontextmanager
72
+ from typing import Any , AsyncIterator , Generic , Sequence , TypeVar
72
73
73
74
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
74
75
from pydantic import AnyUrl
@@ -84,7 +85,10 @@ async def main():
84
85
85
86
logger = logging .getLogger (__name__ )
86
87
87
- request_ctx : contextvars .ContextVar [RequestContext [ServerSession ]] = (
88
+ LifespanResultT = TypeVar ("LifespanResultT" )
89
+
90
+ # This will be properly typed in each Server instance's context
91
+ request_ctx : contextvars .ContextVar [RequestContext [ServerSession , Any ]] = (
88
92
contextvars .ContextVar ("request_ctx" )
89
93
)
90
94
@@ -101,13 +105,33 @@ def __init__(
101
105
self .tools_changed = tools_changed
102
106
103
107
104
- class Server :
108
+ @asynccontextmanager
109
+ async def lifespan (server : "Server" ) -> AsyncIterator [object ]:
110
+ """Default lifespan context manager that does nothing.
111
+
112
+ Args:
113
+ server: The server instance this lifespan is managing
114
+
115
+ Returns:
116
+ An empty context object
117
+ """
118
+ yield {}
119
+
120
+
121
+ class Server (Generic [LifespanResultT ]):
105
122
def __init__ (
106
- self , name : str , version : str | None = None , instructions : str | None = None
123
+ self ,
124
+ name : str ,
125
+ version : str | None = None ,
126
+ instructions : str | None = None ,
127
+ lifespan : Callable [
128
+ ["Server" ], AbstractAsyncContextManager [LifespanResultT ]
129
+ ] = lifespan ,
107
130
):
108
131
self .name = name
109
132
self .version = version
110
133
self .instructions = instructions
134
+ self .lifespan = lifespan
111
135
self .request_handlers : dict [
112
136
type , Callable [..., Awaitable [types .ServerResult ]]
113
137
] = {
@@ -188,7 +212,7 @@ def get_capabilities(
188
212
)
189
213
190
214
@property
191
- def request_context (self ) -> RequestContext [ServerSession ]:
215
+ def request_context (self ) -> RequestContext [ServerSession , LifespanResultT ]:
192
216
"""If called outside of a request context, this will raise a LookupError."""
193
217
return request_ctx .get ()
194
218
@@ -446,9 +470,14 @@ async def run(
446
470
raise_exceptions : bool = False ,
447
471
):
448
472
with warnings .catch_warnings (record = True ) as w :
449
- async with ServerSession (
450
- read_stream , write_stream , initialization_options
451
- ) as session :
473
+ from contextlib import AsyncExitStack
474
+
475
+ async with AsyncExitStack () as stack :
476
+ lifespan_context = await stack .enter_async_context (self .lifespan (self ))
477
+ session = await stack .enter_async_context (
478
+ ServerSession (read_stream , write_stream , initialization_options )
479
+ )
480
+
452
481
async for message in session .incoming_messages :
453
482
logger .debug (f"Received message: { message } " )
454
483
@@ -460,21 +489,28 @@ async def run(
460
489
):
461
490
with responder :
462
491
await self ._handle_request (
463
- message , req , session , raise_exceptions
492
+ message ,
493
+ req ,
494
+ session ,
495
+ lifespan_context ,
496
+ raise_exceptions ,
464
497
)
465
498
case types .ClientNotification (root = notify ):
466
499
await self ._handle_notification (notify )
467
500
468
501
for warning in w :
469
502
logger .info (
470
- f"Warning: { warning .category .__name__ } : { warning .message } "
503
+ "Warning: %s: %s" ,
504
+ warning .category .__name__ ,
505
+ warning .message ,
471
506
)
472
507
473
508
async def _handle_request (
474
509
self ,
475
510
message : RequestResponder ,
476
511
req : Any ,
477
512
session : ServerSession ,
513
+ lifespan_context : LifespanResultT ,
478
514
raise_exceptions : bool ,
479
515
):
480
516
logger .info (f"Processing request of type { type (req ).__name__ } " )
@@ -491,6 +527,7 @@ async def _handle_request(
491
527
message .request_id ,
492
528
message .request_meta ,
493
529
session ,
530
+ lifespan_context ,
494
531
)
495
532
)
496
533
response = await handler (req )
0 commit comments