-
Notifications
You must be signed in to change notification settings - Fork 829
/
Copy pathsse.py
146 lines (127 loc) · 6.6 KB
/
sse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
import mcp.types as types
# Module logger, named after this module so its output can be filtered and
# configured independently of other loggers.
logger = logging.getLogger(name=__name__)
def remove_request_params(url: str) -> str:
    """Return *url* with its path params, query string and fragment removed.

    Used when logging the SSE connection URL, so that query-string values
    (which may carry credentials) are not written to the log.

    Args:
        url: The URL to sanitize.

    Returns:
        The URL reduced to scheme, netloc and path.
    """
    parsed = urlparse(url)
    # Rebuild from the parsed components rather than ``urljoin(url, path)``:
    # when the path is empty (e.g. ``http://host?token=x``) ``urljoin``
    # returns the base URL unchanged, query string included.
    return parsed._replace(params="", query="", fragment="").geturl()
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx.AsyncClient(headers=headers) as client:
async with aconnect_sse(
client,
"GET",
url,
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse():
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme
!= endpoint_parsed.scheme
):
error_msg = (
"Endpoint origin does not match "
f"connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
task_status.started(endpoint_url)
case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(
f"Received server message: {message}"
)
except Exception as exc:
logger.error(
f"Error parsing server message: {exc}"
)
await read_stream_writer.send(exc)
continue
await read_stream_writer.send(message)
case _:
logger.warning(
f"Unknown SSE event: {sse.event}"
)
except Exception as exc:
logger.error(f"Error in sse_reader: {exc}")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for message in write_stream_reader:
logger.debug(f"Sending client message: {message}")
response = await client.post(
endpoint_url,
json=message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(
"Client message sent successfully: "
f"{response.status_code}"
)
except Exception as exc:
logger.error(f"Error in post_writer: {exc}")
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.info(
f"Starting post writer with endpoint URL: {endpoint_url}"
)
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()