
Commit 2912670

Improve error handling in model proxy (#1973)
1 parent: fc5286d

File tree: 3 files changed, +145 -57 lines


src/dstack/_internal/proxy/routers/model_proxy.py

Lines changed: 37 additions & 6 deletions

@@ -1,4 +1,4 @@
-from typing import AsyncIterator
+from typing import AsyncIterator, Optional
 
 from fastapi import APIRouter, Depends, status
 from fastapi.responses import StreamingResponse
@@ -55,13 +55,44 @@ async def post_chat_completions(
         return await client.generate(body)
     else:
         return StreamingResponse(
-            stream_chunks(client.stream(body)),
+            await StreamingAdaptor(client.stream(body)).get_stream(),
             media_type="text/event-stream",
             headers={"X-Accel-Buffering": "no"},
         )
 
 
-async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
-    async for chunk in chunks:
-        yield f"data:{chunk.json()}\n\n".encode()
-    yield "data: [DONE]\n\n".encode()
+class StreamingAdaptor:
+    """
+    Converts a stream of ChatCompletionsChunk to an SSE stream.
+    Also pre-fetches the first chunk **before** starting streaming to downstream,
+    so that upstream request errors can propagate to the downstream client.
+    """
+
+    def __init__(self, stream: AsyncIterator[ChatCompletionsChunk]) -> None:
+        self._stream = stream
+
+    async def get_stream(self) -> AsyncIterator[bytes]:
+        try:
+            first_chunk = await self._stream.__anext__()
+        except StopAsyncIteration:
+            first_chunk = None
+        return self._adaptor(first_chunk)
+
+    async def _adaptor(self, first_chunk: Optional[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
+        if first_chunk is not None:
+            yield self._encode_chunk(first_chunk)
+
+        try:
+            async for chunk in self._stream:
+                yield self._encode_chunk(chunk)
+        except ProxyError as e:
+            # No standard way to report errors while streaming,
+            # but we'll at least send them as comments
+            yield f": {e.detail!r}\n\n".encode()  # !r to avoid line breaks
+            return
+
+        yield "data: [DONE]\n\n".encode()
+
+    @staticmethod
+    def _encode_chunk(chunk: ChatCompletionsChunk) -> bytes:
+        return f"data:{chunk.json()}\n\n".encode()
Lines changed: 46 additions & 16 deletions

@@ -1,6 +1,8 @@
 from typing import AsyncIterator
 
 import httpx
+from fastapi import status
+from pydantic import ValidationError
 
 from dstack._internal.proxy.errors import ProxyError
 from dstack._internal.proxy.schemas.model_proxy import (
@@ -17,21 +19,49 @@ def __init__(self, http_client: httpx.AsyncClient, prefix: str):
         self._prefix = prefix
 
     async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
-        resp = await self._http.post(
-            f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
-        )
-        if resp.status_code != 200:
-            raise ProxyError(resp.text)
-        return ChatCompletionsResponse.__response__.parse_raw(resp.content)
+        try:
+            resp = await self._http.post(
+                f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
+            )
+            await self._propagate_error(resp)
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
+        try:
+            return ChatCompletionsResponse.__response__.parse_raw(resp.content)
+        except ValidationError as e:
+            raise ProxyError(f"Invalid response from model: {e}", status.HTTP_502_BAD_GATEWAY)
 
     async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
-        async with self._http.stream(
-            "POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
-        ) as resp:
-            async for line in resp.aiter_lines():
-                if not line.startswith("data:"):
-                    continue
-                data = line[len("data:") :].strip()
-                if data == "[DONE]":
-                    break
-                yield ChatCompletionsChunk.__response__.parse_raw(data)
+        try:
+            async with self._http.stream(
+                "POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
+            ) as resp:
+                await self._propagate_error(resp)
+
+                async for line in resp.aiter_lines():
+                    if not line.startswith("data:"):
+                        continue
+                    data = line[len("data:") :].strip()
+                    if data == "[DONE]":
+                        break
+                    yield self._parse_chunk_data(data)
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
+    @staticmethod
+    def _parse_chunk_data(data: str) -> ChatCompletionsChunk:
+        try:
+            return ChatCompletionsChunk.__response__.parse_raw(data)
+        except ValidationError as e:
+            raise ProxyError(f"Invalid chunk in model stream: {e}", status.HTTP_502_BAD_GATEWAY)
+
+    @staticmethod
+    async def _propagate_error(resp: httpx.Response) -> None:
+        """
+        Propagates HTTP error by raising ProxyError if status is not 200.
+        May also raise httpx.RequestError if there are issues reading the response.
+        """
+        if resp.status_code != 200:
+            resp_body = await resp.aread()
+            raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)
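Note: both clients now share one translation scheme: httpx.RequestError (the upstream never answered) and ValidationError (it answered 200 with an unexpected body) map to 502 Bad Gateway, while a non-200 upstream response is re-raised with its original status code and body. A self-contained sketch of that scheme, using a hypothetical Completion model and FastAPI's HTTPException in place of dstack's ProxyError:

import httpx
from fastapi import HTTPException, status
from pydantic import BaseModel, ValidationError


class Completion(BaseModel):  # hypothetical response schema
    id: str
    content: str


async def fetch_completion(client: httpx.AsyncClient, url: str, payload: dict) -> Completion:
    try:
        resp = await client.post(url, json=payload)
    except httpx.RequestError as e:
        # Connection refused, DNS failure, timeout, etc.: the upstream never
        # answered, so report 502 Bad Gateway.
        raise HTTPException(status.HTTP_502_BAD_GATEWAY, f"Error requesting model: {e!r}")
    if resp.status_code != 200:
        # The upstream did answer: propagate its status code and body verbatim.
        raise HTTPException(resp.status_code, resp.text)
    try:
        return Completion.parse_raw(resp.content)  # pydantic v1-style parsing
    except ValidationError as e:
        # A 200 with an unparsable body is still a gateway-level failure.
        raise HTTPException(status.HTTP_502_BAD_GATEWAY, f"Invalid response from model: {e}")
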

src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py

Lines changed: 62 additions & 35 deletions

@@ -6,6 +6,7 @@
 import httpx
 import jinja2
 import jinja2.sandbox
+from fastapi import status
 
 from dstack._internal.proxy.errors import ProxyError
 from dstack._internal.proxy.schemas.model_proxy import (
@@ -38,9 +39,12 @@ def __init__(self, http_client: httpx.AsyncClient, chat_template: str, eos_token
 
     async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
         payload = self.get_payload(request)
-        resp = await self.client.post("/generate", json=payload)
-        if resp.status_code != 200:
-            raise ProxyError(resp.text)  # TODO(egor-s)
+        try:
+            resp = await self.client.post("/generate", json=payload)
+            await self.propagate_error(resp)
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
         data = resp.json()
 
         choices = [
@@ -91,38 +95,51 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
         created = int(datetime.datetime.utcnow().timestamp())
 
         payload = self.get_payload(request)
-        async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
-            async for line in resp.aiter_lines():
-                if line.startswith("data:"):
-                    data = json.loads(line[len("data:") :].strip("\n"))
-                    if "error" in data:
-                        raise ProxyError(data["error"])
-                    chunk = ChatCompletionsChunk(
-                        id=completion_id,
-                        choices=[],
-                        created=created,
-                        model=request.model,
-                        system_fingerprint="",
-                    )
-                    if data["details"] is not None:
-                        chunk.choices = [
-                            ChatCompletionsChunkChoice(
-                                delta={},
-                                logprobs=None,
-                                finish_reason=self.finish_reason(data["details"]["finish_reason"]),
-                                index=0,
-                            )
-                        ]
-                    else:
-                        chunk.choices = [
-                            ChatCompletionsChunkChoice(
-                                delta={"content": data["token"]["text"], "role": "assistant"},
-                                logprobs=None,
-                                finish_reason=None,
-                                index=0,
-                            )
-                        ]
-                    yield chunk
+        try:
+            async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
+                await self.propagate_error(resp)
+                async for line in resp.aiter_lines():
+                    if line.startswith("data:"):
+                        yield self.parse_chunk(
+                            data=json.loads(line[len("data:") :].strip("\n")),
+                            model=request.model,
+                            completion_id=completion_id,
+                            created=created,
+                        )
+        except httpx.RequestError as e:
+            raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)
+
+    def parse_chunk(
+        self, data: dict, model: str, completion_id: str, created: int
+    ) -> ChatCompletionsChunk:
+        if "error" in data:
+            raise ProxyError(data["error"])
+        chunk = ChatCompletionsChunk(
+            id=completion_id,
+            choices=[],
+            created=created,
+            model=model,
+            system_fingerprint="",
+        )
+        if data["details"] is not None:
+            chunk.choices = [
+                ChatCompletionsChunkChoice(
+                    delta={},
+                    logprobs=None,
+                    finish_reason=self.finish_reason(data["details"]["finish_reason"]),
+                    index=0,
+                )
+            ]
+        else:
+            chunk.choices = [
+                ChatCompletionsChunkChoice(
+                    delta={"content": data["token"]["text"], "role": "assistant"},
+                    logprobs=None,
+                    finish_reason=None,
+                    index=0,
+                )
+            ]
+        return chunk
 
     def get_payload(self, request: ChatCompletionsRequest) -> Dict:
         try:
@@ -177,6 +194,16 @@ def trim_stop_tokens(text: str, stop_tokens: List[str]) -> str:
             return text[: -len(stop_token)]
         return text
 
+    @staticmethod
+    async def propagate_error(resp: httpx.Response) -> None:
+        """
+        Propagates HTTP error by raising ProxyError if status is not 200.
+        May also raise httpx.RequestError if there are issues reading the response.
+        """
+        if resp.status_code != 200:
+            resp_body = await resp.aread()
+            raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)
+
 
 def raise_exception(message: str):
     raise jinja2.TemplateError(message)
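Note: the refactoring splits the TGI stream handling into two concerns: scanning SSE lines for "data:" payloads, and converting each JSON event (which may carry an in-band "error" key) into a chunk via parse_chunk. A small synchronous sketch of the scanning side, with illustrative sample events:

import json
from typing import Iterator


def iter_events(lines: Iterator[str]) -> Iterator[dict]:
    # Scan SSE lines: only "data:" payloads matter; comments (": ...")
    # and blank keep-alive lines are skipped.
    for line in lines:
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":  # OpenAI-style terminator, absent in raw TGI streams
            break
        event = json.loads(data)
        if "error" in event:
            # TGI reports failures in-band; the client raises on these.
            raise RuntimeError(event["error"])
        yield event


# Example: two token events around a keep-alive comment.
raw = [
    'data:{"token": {"text": "Hel"}, "details": null}',
    ": keep-alive",
    'data:{"token": {"text": "lo"}, "details": null}',
    "data: [DONE]",
]
print([e["token"]["text"] for e in iter_events(iter(raw))])  # ['Hel', 'lo']
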
