Skip to content

Commit 69e3773

Browse files
committed
Fix iterator support for replicate.run()
Prior to 1.0.0 `replicate.run()` would return an iterator for cog models that output a type of `Iterator[Any]`. This would poll the `predictions.get` endpoint for the in-progress prediction and yield any new output. When implementing the new file interface we introduced two bugs: 1. The iterator didn't convert URLs returned by the model into `FileOutput` types, making it inconsistent with the non-iterator interface. This is controlled by the `use_file_outputs` argument. 2. The iterator was returned without checking if we are using the new blocking API introduced by default and controlled by the `wait` argument. This commit fixes these two issues, consistently applying the `transform_output` function to the output of the iterator as well as returning the polling iterator (`prediction.output_iterator`) if the blocking API has not successfully returned a completed prediction. The tests have been updated to exercise both of these code paths.
1 parent 23bd903 commit 69e3773

File tree

2 files changed

+492
-219
lines changed

2 files changed

+492
-219
lines changed

replicate/run.py

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,27 @@ def run(
5959
if not version and (owner and name and version_id):
6060
version = Versions(client, model=(owner, name)).get(version_id)
6161

62-
if version and (iterator := _make_output_iterator(version, prediction)):
63-
return iterator
64-
62+
# Currently the "Prefer: wait" interface will return a prediction with a status
63+
# of "processing" rather than a terminal state because it returns before the
64+
# prediction has been fully processed. If the request exceeds the wait time, the
65+
# prediction will be in a "starting" state.
6566
if not (is_blocking and prediction.status != "starting"):
67+
# Return a "polling" iterator if the model has an output iterator array type.
68+
print("polling why?")
69+
if version and (iterator := _make_output_iterator(client, version, prediction)):
70+
print("return iterator", iterator)
71+
return iterator
72+
6673
prediction.wait()
6774

6875
if prediction.status == "failed":
6976
raise ModelError(prediction)
7077

78+
# Return an iterator for the completed prediction when needed.
79+
if version and (iterator := _make_output_iterator(client, version, prediction)):
80+
print("iterator why?")
81+
return iterator
82+
7183
if use_file_output:
7284
return transform_output(prediction.output, client)
7385

@@ -108,12 +120,25 @@ async def async_run(
108120
if not version and (owner and name and version_id):
109121
version = await Versions(client, model=(owner, name)).async_get(version_id)
110122

111-
if version and (iterator := _make_async_output_iterator(version, prediction)):
112-
return iterator
113-
123+
# Currently the "Prefer: wait" interface will return a prediction with a status
124+
# of "processing" rather than a terminal state because it returns before the
125+
# prediction has been fully processed. If the request exceeds the wait time, the
126+
# prediction will be in a "starting" state.
114127
if not (is_blocking and prediction.status != "starting"):
128+
# Return a "polling" iterator if the model has an output iterator array type.
129+
if version and (
130+
iterator := _make_async_output_iterator(client, version, prediction)
131+
):
132+
return iterator
133+
115134
await prediction.async_wait()
116135

136+
# Return an iterator for completed output if the model has an output iterator array type.
137+
if version and (
138+
iterator := _make_async_output_iterator(client, version, prediction)
139+
):
140+
return iterator
141+
117142
if prediction.status == "failed":
118143
raise ModelError(prediction)
119144

@@ -134,21 +159,48 @@ def _has_output_iterator_array_type(version: Version) -> bool:
134159

135160

136161
def _make_output_iterator(
    client: "Client", version: Version, prediction: Prediction
) -> Optional[Iterator[Any]]:
    """Build an iterator over a prediction's output, or return None.

    Returns None unless the model version declares an iterator/array output
    type. While the prediction is still in the "starting" state, output is
    streamed by polling via ``prediction.output_iterator()``; if output is
    already materialized, it is iterated directly. Every chunk is passed
    through ``transform_output`` so URL outputs are converted to file
    objects, keeping this path consistent with the non-iterator interface.
    """
    if not _has_output_iterator_array_type(version):
        return None

    if prediction.status == "starting":
        # Prediction has not completed yet: poll for output as it arrives.
        source = prediction.output_iterator()
    elif prediction.output is not None:
        # Prediction already has output: iterate the materialized list.
        source = iter(prediction.output)
    else:
        return None

    # Renamed parameter from `iter` to `items`: the old name shadowed the
    # builtin `iter` used above.
    def _iterate(items: Iterator[Any]) -> Iterator[Any]:
        for chunk in items:
            yield transform_output(chunk, client)

    return _iterate(source)
143179

144180

145181
def _make_async_output_iterator(
    client: "Client", version: Version, prediction: Prediction
) -> Optional[AsyncIterator[Any]]:
    """Async counterpart of ``_make_output_iterator``.

    Returns None unless the model version declares an iterator/array output
    type. While the prediction is still "starting", output is streamed via
    ``prediction.async_output_iterator()``; already-materialized output is
    wrapped in an async generator so both branches share the same
    transformation path. Each chunk goes through ``transform_output`` for
    consistency with the non-iterator interface.
    """
    if not _has_output_iterator_array_type(version):
        return None

    if prediction.status == "starting":
        source = prediction.async_output_iterator()
    elif prediction.output is not None:

        async def _list_to_aiter(items: list) -> AsyncIterator[Any]:
            # Adapt the plain list into an async iterator.
            for item in items:
                yield item

        source = _list_to_aiter(prediction.output)
    else:
        return None

    # Renamed parameter from `iter` to `items` to stop shadowing the builtin.
    async def _transform(items: AsyncIterator[Any]) -> AsyncIterator[Any]:
        async for chunk in items:
            yield transform_output(chunk, client)

    return _transform(source)
152204

153205

154206
__all__: List = []

0 commit comments

Comments
 (0)