Skip to content

Commit b2535ac

Browse files
authored
Update async_run to use async output iterator (#230)
Resolves #229 --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent a3784c4 commit b2535ac

File tree

4 files changed

+72
-13
lines changed

4 files changed

+72
-13
lines changed

.vscode/settings.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
"editor.defaultFormatter": "charliermarsh.ruff",
1111
"editor.formatOnSave": true,
1212
"editor.codeActionsOnSave": {
13-
"source.fixAll": true,
14-
"source.organizeImports": true
13+
"source.fixAll": "explicit",
14+
"source.organizeImports": "explicit"
1515
}
1616
},
1717
"python.languageServer": "Pylance",

replicate/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import os
23
import random
34
import time
@@ -151,7 +152,7 @@ async def async_run(
151152
ref: str,
152153
input: Optional[Dict[str, Any]] = None,
153154
**params: Unpack["Predictions.CreatePredictionParams"],
154-
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
155+
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
155156
"""
156157
Run a model and wait for its output asynchronously.
157158
"""
@@ -298,7 +299,7 @@ async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
298299
response.close()
299300

300301
sleep_for = self._calculate_sleep(attempts_made, response.headers)
301-
time.sleep(sleep_for)
302+
await asyncio.sleep(sleep_for)
302303

303304
response = await self._wrapped_transport.handle_async_request(request) # type: ignore
304305

replicate/prediction.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,17 @@
22
import re
33
import time
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Union
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
AsyncIterator,
9+
Dict,
10+
Iterator,
11+
List,
12+
Literal,
13+
Optional,
14+
Union,
15+
)
616

717
from typing_extensions import NotRequired, TypedDict, Unpack
818

@@ -208,6 +218,30 @@ def output_iterator(self) -> Iterator[Any]:
208218
for output in new_output:
209219
yield output
210220

221+
async def async_output_iterator(self) -> AsyncIterator[Any]:
222+
"""
223+
Return an asynchronous iterator of the prediction output.
224+
"""
225+
226+
# TODO: check output is list
227+
previous_output = self.output or []
228+
while self.status not in ["succeeded", "failed", "canceled"]:
229+
output = self.output or []
230+
new_output = output[len(previous_output) :]
231+
for item in new_output:
232+
yield item
233+
previous_output = output
234+
await asyncio.sleep(self._client.poll_interval) # pylint: disable=no-member
235+
await self.async_reload()
236+
237+
if self.status == "failed":
238+
raise ModelError(self.error)
239+
240+
output = self.output or []
241+
new_output = output[len(previous_output) :]
242+
for output in new_output:
243+
yield output
244+
211245

212246
class Predictions(Namespace):
213247
"""

replicate/run.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
1-
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
1+
from typing import (
2+
TYPE_CHECKING,
3+
Any,
4+
AsyncIterator,
5+
Dict,
6+
Iterator,
7+
List,
8+
Optional,
9+
Union,
10+
)
211

312
from typing_extensions import Unpack
413

@@ -59,7 +68,7 @@ async def async_run(
5968
ref: Union["Model", "Version", "ModelVersionIdentifier", str],
6069
input: Optional[Dict[str, Any]] = None,
6170
**params: Unpack["Predictions.CreatePredictionParams"],
62-
) -> Union[Any, Iterator[Any]]: # noqa: ANN401
71+
) -> Union[Any, AsyncIterator[Any]]: # noqa: ANN401
6372
"""
6473
Run a model and wait for its output asynchronously.
6574
"""
@@ -82,7 +91,7 @@ async def async_run(
8291
if not version and (owner and name and version_id):
8392
version = await Versions(client, model=(owner, name)).async_get(version_id)
8493

85-
if version and (iterator := _make_output_iterator(version, prediction)):
94+
if version and (iterator := _make_async_output_iterator(version, prediction)):
8695
return iterator
8796

8897
await prediction.async_wait()
@@ -93,17 +102,32 @@ async def async_run(
93102
return prediction.output
94103

95104

96-
def _make_output_iterator(
97-
version: Version, prediction: Prediction
98-
) -> Optional[Iterator[Any]]:
105+
def _has_output_iterator_array_type(version: Version) -> bool:
99106
schema = make_schema_backwards_compatible(
100107
version.openapi_schema, version.cog_version
101108
)
102-
output = schema["components"]["schemas"]["Output"]
103-
if output.get("type") == "array" and output.get("x-cog-array-type") == "iterator":
109+
output = schema.get("components", {}).get("schemas", {}).get("Output", {})
110+
return (
111+
output.get("type") == "array" and output.get("x-cog-array-type") == "iterator"
112+
)
113+
114+
115+
def _make_output_iterator(
116+
version: Version, prediction: Prediction
117+
) -> Optional[Iterator[Any]]:
118+
if _has_output_iterator_array_type(version):
104119
return prediction.output_iterator()
105120

106121
return None
107122

108123

124+
def _make_async_output_iterator(
125+
version: Version, prediction: Prediction
126+
) -> Optional[AsyncIterator[Any]]:
127+
if _has_output_iterator_array_type(version):
128+
return prediction.async_output_iterator()
129+
130+
return None
131+
132+
109133
__all__: List = []

0 commit comments

Comments
 (0)