|
6 | 6 | import httpx
|
7 | 7 | import jinja2
|
8 | 8 | import jinja2.sandbox
|
| 9 | +from fastapi import status |
9 | 10 |
|
10 | 11 | from dstack._internal.proxy.errors import ProxyError
|
11 | 12 | from dstack._internal.proxy.schemas.model_proxy import (
|
@@ -38,9 +39,12 @@ def __init__(self, http_client: httpx.AsyncClient, chat_template: str, eos_token
|
38 | 39 |
|
39 | 40 | async def generate(self, request: ChatCompletionsRequest) -> ChatCompletionsResponse:
|
40 | 41 | payload = self.get_payload(request)
|
41 |
| - resp = await self.client.post("/generate", json=payload) |
42 |
| - if resp.status_code != 200: |
43 |
| - raise ProxyError(resp.text) # TODO(egor-s) |
| 42 | + try: |
| 43 | + resp = await self.client.post("/generate", json=payload) |
| 44 | + await self.propagate_error(resp) |
| 45 | + except httpx.RequestError as e: |
| 46 | + raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY) |
| 47 | + |
44 | 48 | data = resp.json()
|
45 | 49 |
|
46 | 50 | choices = [
|
@@ -91,38 +95,51 @@ async def stream(self, request: ChatCompletionsRequest) -> AsyncIterator[ChatCom
|
91 | 95 | created = int(datetime.datetime.utcnow().timestamp())
|
92 | 96 |
|
93 | 97 | payload = self.get_payload(request)
|
94 |
| - async with self.client.stream("POST", "/generate_stream", json=payload) as resp: |
95 |
| - async for line in resp.aiter_lines(): |
96 |
| - if line.startswith("data:"): |
97 |
| - data = json.loads(line[len("data:") :].strip("\n")) |
98 |
| - if "error" in data: |
99 |
| - raise ProxyError(data["error"]) |
100 |
| - chunk = ChatCompletionsChunk( |
101 |
| - id=completion_id, |
102 |
| - choices=[], |
103 |
| - created=created, |
104 |
| - model=request.model, |
105 |
| - system_fingerprint="", |
106 |
| - ) |
107 |
| - if data["details"] is not None: |
108 |
| - chunk.choices = [ |
109 |
| - ChatCompletionsChunkChoice( |
110 |
| - delta={}, |
111 |
| - logprobs=None, |
112 |
| - finish_reason=self.finish_reason(data["details"]["finish_reason"]), |
113 |
| - index=0, |
114 |
| - ) |
115 |
| - ] |
116 |
| - else: |
117 |
| - chunk.choices = [ |
118 |
| - ChatCompletionsChunkChoice( |
119 |
| - delta={"content": data["token"]["text"], "role": "assistant"}, |
120 |
| - logprobs=None, |
121 |
| - finish_reason=None, |
122 |
| - index=0, |
123 |
| - ) |
124 |
| - ] |
125 |
| - yield chunk |
| 98 | + try: |
| 99 | + async with self.client.stream("POST", "/generate_stream", json=payload) as resp: |
| 100 | + await self.propagate_error(resp) |
| 101 | + async for line in resp.aiter_lines(): |
| 102 | + if line.startswith("data:"): |
| 103 | + yield self.parse_chunk( |
| 104 | + data=json.loads(line[len("data:") :].strip("\n")), |
| 105 | + model=request.model, |
| 106 | + completion_id=completion_id, |
| 107 | + created=created, |
| 108 | + ) |
| 109 | + except httpx.RequestError as e: |
| 110 | + raise ProxyError(f"Error requesting model: {e!r}", status.HTTP_502_BAD_GATEWAY) |
| 111 | + |
| 112 | + def parse_chunk( |
| 113 | + self, data: dict, model: str, completion_id: str, created: int |
| 114 | + ) -> ChatCompletionsChunk: |
| 115 | + if "error" in data: |
| 116 | + raise ProxyError(data["error"]) |
| 117 | + chunk = ChatCompletionsChunk( |
| 118 | + id=completion_id, |
| 119 | + choices=[], |
| 120 | + created=created, |
| 121 | + model=model, |
| 122 | + system_fingerprint="", |
| 123 | + ) |
| 124 | + if data["details"] is not None: |
| 125 | + chunk.choices = [ |
| 126 | + ChatCompletionsChunkChoice( |
| 127 | + delta={}, |
| 128 | + logprobs=None, |
| 129 | + finish_reason=self.finish_reason(data["details"]["finish_reason"]), |
| 130 | + index=0, |
| 131 | + ) |
| 132 | + ] |
| 133 | + else: |
| 134 | + chunk.choices = [ |
| 135 | + ChatCompletionsChunkChoice( |
| 136 | + delta={"content": data["token"]["text"], "role": "assistant"}, |
| 137 | + logprobs=None, |
| 138 | + finish_reason=None, |
| 139 | + index=0, |
| 140 | + ) |
| 141 | + ] |
| 142 | + return chunk |
126 | 143 |
|
127 | 144 | def get_payload(self, request: ChatCompletionsRequest) -> Dict:
|
128 | 145 | try:
|
@@ -177,6 +194,16 @@ def trim_stop_tokens(text: str, stop_tokens: List[str]) -> str:
|
177 | 194 | return text[: -len(stop_token)]
|
178 | 195 | return text
|
179 | 196 |
|
| 197 | + @staticmethod |
| 198 | + async def propagate_error(resp: httpx.Response) -> None: |
| 199 | + """ |
| 200 | + Propagates HTTP error by raising ProxyError if status is not 200. |
| 201 | + May also raise httpx.RequestError if there are issues reading the response. |
| 202 | + """ |
| 203 | + if resp.status_code != 200: |
| 204 | + resp_body = await resp.aread() |
| 205 | + raise ProxyError(resp_body.decode(errors="replace"), code=resp.status_code) |
| 206 | + |
180 | 207 |
|
def raise_exception(message: str):
    """Raise a jinja2.TemplateError carrying *message*.

    NOTE(review): presumably registered into the sandboxed jinja2
    environment so chat templates can abort rendering — confirm at the
    environment setup site (not visible in this chunk).
    """
    raise jinja2.TemplateError(message)
|
0 commit comments