Improve error handling in model proxy #1973

Merged
merged 1 commit into from Nov 7, 2024

43 changes: 37 additions & 6 deletions src/dstack/_internal/proxy/routers/model_proxy.py
@@ -1,4 +1,4 @@
from typing import AsyncIterator
from typing import AsyncIterator, Optional

from fastapi import APIRouter, Depends, status
from fastapi.responses import StreamingResponse
@@ -55,13 +55,44 @@ async def post_chat_completions(
return await client.generate(body)
else:
return StreamingResponse(
stream_chunks(client.stream(body)),
await StreamingAdaptor(client.stream(body)).get_stream(),
media_type="text/event-stream",
headers={"X-Accel-Buffering": "no"},
)


async def stream_chunks(chunks: AsyncIterator[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
async for chunk in chunks:
yield f"data:{chunk.json()}\n\n".encode()
yield "data: [DONE]\n\n".encode()
class StreamingAdaptor:
"""
Converts a stream of ChatCompletionsChunk to an SSE stream.
Also pre-fetches the first chunk **before** starting streaming to downstream,
so that upstream request errors can propagate to the downstream client.
"""

def __init__(self, stream: AsyncIterator[ChatCompletionsChunk]) -> None:
self._stream = stream

async def get_stream(self) -> AsyncIterator[bytes]:
try:
first_chunk = await self._stream.__anext__()
except StopAsyncIteration:
first_chunk = None
return self._adaptor(first_chunk)

async def _adaptor(self, first_chunk: Optional[ChatCompletionsChunk]) -> AsyncIterator[bytes]:
if first_chunk is not None:
yield self._encode_chunk(first_chunk)

try:
async for chunk in self._stream:
yield self._encode_chunk(chunk)
except ProxyError as e:
# No standard way to report errors while streaming,
# but we'll at least send them as comments
yield f": {e.detail!r}\n\n".encode() # !r to avoid line breaks
return

yield "data: [DONE]\n\n".encode()

@staticmethod
def _encode_chunk(chunk: ChatCompletionsChunk) -> bytes:
return f"data:{chunk.json()}\n\n".encode()
62 changes: 46 additions & 16 deletions src/dstack/_internal/proxy/services/model_proxy/clients/openai.py
@@ -1,6 +1,8 @@
from typing import AsyncIterator

import httpx
from fastapi import status
from pydantic import ValidationError

from dstack._internal.proxy.errors import ProxyError
from dstack._internal.proxy.schemas.model_proxy import (
@@ -17,21 +19,49 @@ def __init__(self, http_client: httpx.AsyncClient, prefix: str):
self._prefix = prefix

async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
resp = await self._http.post(
f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
)
if resp.status_code != 200:
raise ProxyError(resp.text)
return ChatCompletionsResponse.__response__.parse_raw(resp.content)
try:
resp = await self._http.post(
f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
)
await self._propagate_error(resp)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

try:
return ChatCompletionsResponse.__response__.parse_raw(resp.content)
except ValidationError as e:
raise ProxyError(f"Invalid response from model: {e}", status.HTTP_502_BAD_GATEWAY)

async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCompletionsChunk]:
async with self._http.stream(
"POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
) as resp:
async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
data = line[len("data:") :].strip()
if data == "[DONE]":
break
yield ChatCompletionsChunk.__response__.parse_raw(data)
try:
async with self._http.stream(
"POST", f"{self._prefix}/chat/completions", json=request.dict(exclude_unset=True)
) as resp:
await self._propagate_error(resp)

async for line in resp.aiter_lines():
if not line.startswith("data:"):
continue
data = line[len("data:") :].strip()
if data == "[DONE]":
break
yield self._parse_chunk_data(data)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

@staticmethod
def _parse_chunk_data(data: str) -> ChatCompletionsChunk:
try:
return ChatCompletionsChunk.__response__.parse_raw(data)
except ValidationError as e:
raise ProxyError(f"Invalid chunk in model stream: {e}", status.HTTP_502_BAD_GATEWAY)

@staticmethod
async def _propagate_error(resp: httpx.Response) -> None:
"""
Propagates HTTP error by raising ProxyError if status is not 200.
May also raise httpx.RequestError if there are issues reading the response.
"""
if resp.status_code != 200:
resp_body = await resp.aread()
raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)
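
The client now funnels both transport failures (httpx.RequestError) and non-200 upstream responses into ProxyError, with a 502 or the upstream status code respectively. Below is a hedged, test-style sketch of how that translation could be exercised with httpx.MockTransport; ProxyError is stubbed to keep the snippet self-contained, and the URL and payload are placeholders rather than the real proxy routes.

```python
import asyncio

import httpx


class ProxyError(Exception):
    # Simplified stand-in for dstack._internal.proxy.errors.ProxyError.
    def __init__(self, detail: str, code: int = 400) -> None:
        super().__init__(detail)
        self.detail = detail
        self.code = code


async def generate(client: httpx.AsyncClient) -> dict:
    # Mirrors the structure of the patched generate(): transport errors and
    # non-200 responses are both translated into ProxyError.
    try:
        resp = await client.post("http://model/v1/chat/completions", json={})
        if resp.status_code != 200:
            raise ProxyError(resp.text, code=resp.status_code)
    except httpx.RequestError as e:
        raise ProxyError(f"Error requesting model: {e!r}", code=502)
    return resp.json()


async def main() -> None:
    def upstream_500(request: httpx.Request) -> httpx.Response:
        return httpx.Response(500, text="upstream exploded")

    def connection_refused(request: httpx.Request) -> httpx.Response:
        raise httpx.ConnectError("connection refused", request=request)

    for handler in (upstream_500, connection_refused):
        async with httpx.AsyncClient(transport=httpx.MockTransport(handler)) as client:
            try:
                await generate(client)
            except ProxyError as e:
                print(e.code, e.detail)
    # Expected output (roughly):
    # 500 upstream exploded
    # 502 Error requesting model: ConnectError('connection refused')


asyncio.run(main())
```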
97 changes: 62 additions & 35 deletions src/dstack/_internal/proxy/services/model_proxy/clients/tgi.py
@@ -6,6 +6,7 @@
import httpx
import jinja2
import jinja2.sandbox
from fastapi import status

from dstack._internal.proxy.errors import ProxyError
from dstack._internal.proxy.schemas.model_proxy import (
@@ -38,9 +39,12 @@ def __init__(self, http_client: httpx.AsyncClient, chat_template: str, eos_token

async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
payload = self.get_payload(request)
resp = await self.client.post("/generate", json=payload)
if resp.status_code != 200:
raise ProxyError(resp.text) # TODO(egor-s)
try:
resp = await self.client.post("/generate", json=payload)
await self.propagate_error(resp)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

data = resp.json()

choices = [
@@ -91,38 +95,51 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
created = int(datetime.datetime.utcnow().timestamp())

payload = self.get_payload(request)
async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
async for line in resp.aiter_lines():
if line.startswith("data:"):
data = json.loads(line[len("data:") :].strip("\n"))
if "error" in data:
raise ProxyError(data["error"])
chunk = ChatCompletionsChunk(
id=completion_id,
choices=[],
created=created,
model=request.model,
system_fingerprint="",
)
if data["details"] is not None:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={},
logprobs=None,
finish_reason=self.finish_reason(data["details"]["finish_reason"]),
index=0,
)
]
else:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={"content": data["token"]["text"], "role": "assistant"},
logprobs=None,
finish_reason=None,
index=0,
)
]
yield chunk
try:
async with self.client.stream("POST", "/generate_stream", json=payload) as resp:
await self.propagate_error(resp)
async for line in resp.aiter_lines():
if line.startswith("data:"):
yield self.parse_chunk(
data=json.loads(line[len("data:") :].strip("\n")),
model=request.model,
completion_id=completion_id,
created=created,
)
except httpx.RequestError as e:
raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY)

def parse_chunk(
self, data: dict, model: str, completion_id: str, created: int
) -> ChatCompletionsChunk:
if "error" in data:
raise ProxyError(data["error"])
chunk = ChatCompletionsChunk(
id=completion_id,
choices=[],
created=created,
model=model,
system_fingerprint="",
)
if data["details"] is not None:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={},
logprobs=None,
finish_reason=self.finish_reason(data["details"]["finish_reason"]),
index=0,
)
]
else:
chunk.choices = [
ChatCompletionsChunkChoice(
delta={"content": data["token"]["text"], "role": "assistant"},
logprobs=None,
finish_reason=None,
index=0,
)
]
return chunk

def get_payload(self, request: ChatCompletionsRequest) -> Dict:
try:
@@ -177,6 +194,16 @@ def trim_stop_tokens(text: str, stop_tokens: List[str]) -> str:
return text[: -len(stop_token)]
return text

@staticmethod
async def propagate_error(resp: httpx.Response) -> None:
"""
Propagates HTTP error by raising ProxyError if status is not 200.
May also raise httpx.RequestError if there are issues reading the response.
"""
if resp.status_code != 200:
resp_body = await resp.aread()
raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code)


def raise_exception(message: str):
raise jinja2.TemplateError(message)
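
As a reading aid for parse_chunk, the snippet below lists the approximate shapes of TGI /generate_stream events it distinguishes, inferred from the code above rather than from TGI documentation, together with the OpenAI-style chunk each one maps to.

```python
# Illustrative payloads only (field shapes inferred from parse_chunk, not from TGI docs).

# Intermediate token event: "details" is None, so the token text becomes the delta content.
intermediate_event = {
    "token": {"text": "Hello"},
    "details": None,
}
# -> ChatCompletionsChunkChoice(delta={"content": "Hello", "role": "assistant"}, finish_reason=None)

# Final event: "details" carries the finish reason and the delta is empty.
final_event = {
    "token": {"text": ""},
    "details": {"finish_reason": "eos_token"},
}
# -> ChatCompletionsChunkChoice(delta={}, finish_reason=self.finish_reason("eos_token"))

# Error event: parse_chunk raises ProxyError with the upstream message.
error_event = {"error": "Request failed during generation"}
```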