Skip to content

Commit 6c427f7

Browse files
committed
Fix iterator support for replicate.run()
Prior to 1.0.0 `replicate.run()` would return an iterator for cog models that output a type of `Iterator[Any]`. This would poll the `predictions.get` endpoint for the in-progress prediction and yield any new output. When implementing the new file interface we introduced two bugs: 1. The iterator didn't convert URLs returned by the model into `FileOutput` types, making it inconsistent with the non-iterator interface. This is controlled by the `use_file_outputs` argument. 2. The iterator was returned without checking if we are using the new blocking API, which is enabled by default and controlled by the `wait` argument. This commit fixes these two issues, consistently applying the `transform_output` function to the output of the iterator as well as returning the polling iterator (`prediction.output_iterator`) if the blocking API has not successfully returned a completed prediction. The tests have been updated to exercise both of these code paths.
1 parent 23bd903 commit 6c427f7

File tree

2 files changed

+502
-219
lines changed

2 files changed

+502
-219
lines changed

replicate/run.py

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,24 @@ def run(
5959
if not version and (owner and name and version_id):
6060
version = Versions(client, model=(owner, name)).get(version_id)
6161

62-
if version and (iterator := _make_output_iterator(version, prediction)):
63-
return iterator
64-
62+
# Currently the "Prefer: wait" interface will return a prediction with a status
63+
# of "processing" rather than a terminal state because it returns before the
64+
# prediction has been fully processed. If request exceeds the wait time, the
65+
# prediction will be in a "starting" state.
6566
if not (is_blocking and prediction.status != "starting"):
67+
# Return a "polling" iterator if the model has an output iterator array type.
68+
if version and (iterator := _make_output_iterator(client, version, prediction)):
69+
return iterator
70+
6671
prediction.wait()
6772

6873
if prediction.status == "failed":
6974
raise ModelError(prediction)
7075

76+
# Return an iterator for the completed prediction when needed.
77+
if version and (iterator := _make_output_iterator(client, version, prediction)):
78+
return iterator
79+
7180
if use_file_output:
7281
return transform_output(prediction.output, client)
7382

@@ -108,12 +117,25 @@ async def async_run(
108117
if not version and (owner and name and version_id):
109118
version = await Versions(client, model=(owner, name)).async_get(version_id)
110119

111-
if version and (iterator := _make_async_output_iterator(version, prediction)):
112-
return iterator
113-
120+
# Currently the "Prefer: wait" interface will return a prediction with a status
121+
# of "processing" rather than a terminal state because it returns before the
122+
# prediction has been fully processed. If request exceeds the wait time, the
123+
# prediction will be in a "starting" state.
114124
if not (is_blocking and prediction.status != "starting"):
125+
# Return a "polling" iterator if the model has an output iterator array type.
126+
if version and (
127+
iterator := _make_async_output_iterator(client, version, prediction)
128+
):
129+
return iterator
130+
115131
await prediction.async_wait()
116132

133+
# Return an iterator for completed output if the model has an output iterator array type.
134+
if version and (
135+
iterator := _make_async_output_iterator(client, version, prediction)
136+
):
137+
return iterator
138+
117139
if prediction.status == "failed":
118140
raise ModelError(prediction)
119141

@@ -134,21 +156,48 @@ def _has_output_iterator_array_type(version: Version) -> bool:
134156

135157

136158
def _make_output_iterator(
137-
version: Version, prediction: Prediction
159+
client: "Client", version: Version, prediction: Prediction
138160
) -> Optional[Iterator[Any]]:
139-
if _has_output_iterator_array_type(version):
140-
return prediction.output_iterator()
161+
if not _has_output_iterator_array_type(version):
162+
return None
163+
164+
if prediction.status == "starting":
165+
iterator = prediction.output_iterator()
166+
elif prediction.output is not None:
167+
iterator = iter(prediction.output)
168+
else:
169+
return None
141170

142-
return None
171+
def _iterate(iter: Iterator[Any]) -> Iterator[Any]:
172+
for chunk in iter:
173+
yield transform_output(chunk, client)
174+
175+
return _iterate(iterator)
143176

144177

145178
def _make_async_output_iterator(
146-
version: Version, prediction: Prediction
179+
client: "Client", version: Version, prediction: Prediction
147180
) -> Optional[AsyncIterator[Any]]:
148-
if _has_output_iterator_array_type(version):
149-
return prediction.async_output_iterator()
181+
if not _has_output_iterator_array_type(version):
182+
return None
183+
184+
if prediction.status == "starting":
185+
iterator = prediction.async_output_iterator()
186+
elif prediction.output is not None:
187+
188+
async def _list_to_aiter(lst: list) -> AsyncIterator:
189+
for item in lst:
190+
yield item
191+
192+
iterator = _list_to_aiter(prediction.output)
193+
else:
194+
return None
195+
196+
async def _transform(iter: AsyncIterator[Any]) -> AsyncIterator:
197+
async for chunk in iter:
198+
yield transform_output(chunk, client)
150199

151-
return None
200+
return _transform(iterator)
152201

153202

154203
__all__: List = []

0 commit comments

Comments
 (0)