From 6c427f71c7f45d71f6fcf807c107257d271efff5 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Thu, 24 Oct 2024 12:23:37 +0100 Subject: [PATCH 1/6] Fix iterator support for replicate.run() Prior to 1.0.0 `replicate.run()` would return an iterator for cog models that output a type of `Iterator[Any]`. This would poll the `predictions.get` endpoint for the in progress prediction and yield any new output. When implementing the new file interface we introduced two bugs: 1. The iterator didn't convert URLs returned by the model into `FileOutput` types making it inconsistent with the non-iterator interface. This is controlled by the `use_file_output` argument. 2. The iterator was returned without checking if we are using the new blocking API introduced by default and controlled by the `wait` argument. This commit fixes these two issues, consistently applying the `transform_output` function to the output of the iterator as well as returning the polling iterator (`prediction.output_iterator`) if the blocking API has not successfully returned a completed prediction. The tests have been updated to exercise both of these code paths. --- replicate/run.py | 77 +++++- tests/test_run.py | 644 +++++++++++++++++++++++++++++++--------------- 2 files changed, 502 insertions(+), 219 deletions(-) diff --git a/replicate/run.py b/replicate/run.py index 3b6bddb..3f94311 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -59,15 +59,24 @@ def run( if not version and (owner and name and version_id): version = Versions(client, model=(owner, name)).get(version_id) - if version and (iterator := _make_output_iterator(version, prediction)): - return iterator - + # Currently the "Prefer: wait" interface will return a prediction with a status + # of "processing" rather than a terminal state because it returns before the + # prediction has been fully processed. If request exceeds the wait time, the + # prediction will be in a "starting" state. 
if not (is_blocking and prediction.status != "starting"): + # Return a "polling" iterator if the model has an output iterator array type. + if version and (iterator := _make_output_iterator(client, version, prediction)): + return iterator + prediction.wait() if prediction.status == "failed": raise ModelError(prediction) + # Return an iterator for the completed prediction when needed. + if version and (iterator := _make_output_iterator(client, version, prediction)): + return iterator + if use_file_output: return transform_output(prediction.output, client) @@ -108,12 +117,25 @@ async def async_run( if not version and (owner and name and version_id): version = await Versions(client, model=(owner, name)).async_get(version_id) - if version and (iterator := _make_async_output_iterator(version, prediction)): - return iterator - + # Currently the "Prefer: wait" interface will return a prediction with a status + # of "processing" rather than a terminal state because it returns before the + # prediction has been fully processed. If request exceeds the wait time, the + # prediction will be in a "starting" state. if not (is_blocking and prediction.status != "starting"): + # Return a "polling" iterator if the model has an output iterator array type. + if version and ( + iterator := _make_async_output_iterator(client, version, prediction) + ): + return iterator + await prediction.async_wait() + # Return an iterator for completed output if the model has an output iterator array type. 
+ if version and ( + iterator := _make_async_output_iterator(client, version, prediction) + ): + return iterator + if prediction.status == "failed": raise ModelError(prediction) @@ -134,21 +156,48 @@ def _has_output_iterator_array_type(version: Version) -> bool: def _make_output_iterator( - version: Version, prediction: Prediction + client: "Client", version: Version, prediction: Prediction ) -> Optional[Iterator[Any]]: - if _has_output_iterator_array_type(version): - return prediction.output_iterator() + if not _has_output_iterator_array_type(version): + return None + + if prediction.status == "starting": + iterator = prediction.output_iterator() + elif prediction.output is not None: + iterator = iter(prediction.output) + else: + return None - return None + def _iterate(iter: Iterator[Any]) -> Iterator[Any]: + for chunk in iter: + yield transform_output(chunk, client) + + return _iterate(iterator) def _make_async_output_iterator( - version: Version, prediction: Prediction + client: "Client", version: Version, prediction: Prediction ) -> Optional[AsyncIterator[Any]]: - if _has_output_iterator_array_type(version): - return prediction.async_output_iterator() + if not _has_output_iterator_array_type(version): + return None + + if prediction.status == "starting": + iterator = prediction.async_output_iterator() + elif prediction.output is not None: + + async def _list_to_aiter(lst: list) -> AsyncIterator: + for item in lst: + yield item + + iterator = _list_to_aiter(prediction.output) + else: + return None + + async def _transform(iter: AsyncIterator[Any]) -> AsyncIterator: + async for chunk in iter: + yield transform_output(chunk, client) - return None + return _transform(iterator) __all__: List = [] diff --git a/tests/test_run.py b/tests/test_run.py index 8eac091..c847c71 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,6 @@ import asyncio import sys -from typing import cast +from typing import AsyncIterator, Iterator, Optional, cast import httpx 
import pytest @@ -12,6 +12,19 @@ from replicate.helpers import FileOutput +async def anext(async_iterator, default=None): + """ + `anext` is only available from Python 3.10 onwards so here + we use an equivalent to ensure tests work in earlier versions. + """ + try: + return await async_iterator.__anext__() + except StopAsyncIteration: + if default is None: + raise + return default + + @pytest.mark.vcr("run.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) @@ -48,6 +61,274 @@ async def test_run(async_flag, record_mode): assert output[0].url.startswith("https://") +@pytest.mark.asyncio +async def test_run_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + wait=False, + ), + ) + + output1 = next(stream) + output2 = next(stream) + with pytest.raises(StopIteration): + next(stream) + + assert output1 == "Hello, " + assert output2 == "world!" 
+ + +@pytest.mark.asyncio +async def test_async_run_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[FileOutput], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + wait=False, + ), + ) + + output1 = await anext(stream) + output2 = await anext(stream) + with pytest.raises(StopAsyncIteration): + await anext(stream) + + assert output1 == "Hello, " + assert output2 == "world!" 
+ + +@pytest.mark.asyncio +async def test_run_blocking_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + "world!", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output1 = next(stream) + output2 = next(stream) + with pytest.raises(StopIteration): + next(stream) + + assert output1 == "Hello, " + assert output2 == "world!" 
+ + +@pytest.mark.asyncio +async def test_async_run_blocking_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + "world!", + ], + ), + ) + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[FileOutput], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output1 = await anext(stream) + output2 = await anext(stream) + with pytest.raises(StopAsyncIteration): + await anext(stream) + + assert output1 == "Hello, " + assert output2 == "world!" + + @pytest.mark.vcr("run__concurrently.yaml") @pytest.mark.asyncio @pytest.mark.skipif( @@ -104,35 +385,17 @@ async def test_run_with_invalid_token(): @pytest.mark.asyncio async def test_run_version_with_invalid_cog_version(mock_replicate_api_token): - def prediction_with_status(status: str) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": "Hello, world!" 
if status == "succeeded" else None, - "error": None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status("succeeded"), + json=_prediction_with_status("succeeded", "Hello, world!"), ) ) router.route( @@ -141,37 +404,7 @@ def prediction_with_status(status: str) -> dict: ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2022-03-16T00:35:56.210272Z", - "cog_version": "dev", - "openapi_schema": { - "openapi": "3.0.2", - "info": {"title": "Cog", "version": "0.1.0"}, - "paths": {}, - "components": { - "schemas": { - "Input": { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "The text input", - }, - }, - }, - "Output": { - "type": "string", - "title": "Output", - }, - } - }, - }, - }, + json=_version_with_schema(), ) ) router.route(host="api.replicate.com").pass_through() @@ -193,35 +426,17 @@ def prediction_with_status(status: str) -> dict: @pytest.mark.asyncio async def test_run_with_model_error(mock_replicate_api_token): - def prediction_with_status(status: str) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": None, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = 
respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status("failed"), + json=_prediction_with_status("failed"), ) ) router.route( @@ -230,14 +445,7 @@ def prediction_with_status(status: str) -> dict: ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(host="api.replicate.com").pass_through() @@ -262,37 +470,17 @@ def prediction_with_status(status: str) -> dict: @pytest.mark.asyncio async def test_run_with_file_output(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "https://api.replicate.com/v1/assets/output.txt" ), ) @@ -303,14 +491,7 @@ def 
prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(method="GET", path="/assets/output.txt").mock( @@ -347,31 +528,11 @@ def prediction_with_status( @pytest.mark.asyncio async def test_run_with_file_output_blocking(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") predictions_create_route = router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status( + json=_prediction_with_status( "processing", "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==" ), ) @@ -379,7 +540,7 @@ def prediction_with_status( predictions_get_route = router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "https://api.replicate.com/v1/assets/output.txt" ), ) @@ -387,26 +548,14 @@ def prediction_with_status( router.route( method="GET", path="/models/test/example/versions/v1", - ).mock( - return_value=httpx.Response( - 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", 
- }, - }, - ) - ) + ).mock(return_value=httpx.Response(201, json=_version_with_schema())) client = Client( api_token="test-token", transport=httpx.MockTransport(router.handler) ) client.poll_interval = 0.001 output = cast( - list[FileOutput], + FileOutput, client.run( "test/example:v1", input={ @@ -434,37 +583,17 @@ def prediction_with_status( @pytest.mark.asyncio async def test_run_with_file_output_array(mock_replicate_api_token): - def prediction_with_status( - status: str, output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", - }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } - router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", [ "https://api.replicate.com/v1/assets/hello.txt", @@ -479,14 +608,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) router.route(method="GET", path="/assets/hello.txt").mock( @@ -521,38 +643,103 @@ def prediction_with_status( @pytest.mark.asyncio -async def test_run_with_file_output_data_uri(mock_replicate_api_token): - def prediction_with_status( - status: str, 
output: str | list[str] | None = None - ) -> dict: - return { - "id": "p1", - "model": "test/example", - "version": "v1", - "urls": { - "get": "https://api.replicate.com/v1/predictions/p1", - "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", +async def test_run_with_file_output_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + router.route(method="POST", path="/predictions").mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status("starting"), + ) + ) + router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "https://api.replicate.com/v1/assets/hello.txt", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "https://api.replicate.com/v1/assets/hello.txt", + "https://api.replicate.com/v1/assets/world.txt", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + "format": "uri", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + router.route(method="GET", path="/assets/hello.txt").mock( + return_value=httpx.Response(200, content=b"Hello,") + ) + router.route(method="GET", path="/assets/world.txt").mock( + return_value=httpx.Response(200, content=b" world!") + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[FileOutput], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", }, - "created_at": "2023-10-05T12:00:00.000000Z", - "source": "api", - "status": status, - "input": {"text": "world"}, - "output": output, - "error": "OOM" if status == "failed" else None, - "logs": "", - } + use_file_output=True, + wait=False, + ), + ) + + output1 
= next(stream) + output2 = next(stream) + + assert output1.url == "https://api.replicate.com/v1/assets/hello.txt" + assert output2.url == "https://api.replicate.com/v1/assets/world.txt" + + assert output1.read() == b"Hello," + assert output2.read() == b" world!" + +@pytest.mark.asyncio +async def test_run_with_file_output_data_uri(mock_replicate_api_token): router = respx.Router(base_url="https://api.replicate.com/v1") router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("starting"), + json=_prediction_with_status("starting"), ) ) router.route(method="GET", path="/predictions/p1").mock( return_value=httpx.Response( 200, - json=prediction_with_status( + json=_prediction_with_status( "succeeded", "data:text/plain;base64,SGVsbG8sIHdvcmxkIQ==", ), @@ -564,14 +751,7 @@ def prediction_with_status( ).mock( return_value=httpx.Response( 201, - json={ - "id": "f2d6b24e6002f25f77ae89c2b0a5987daa6d0bf751b858b94b8416e8542434d1", - "created_at": "2024-07-18T00:35:56.210272Z", - "cog_version": "0.9.10", - "openapi_schema": { - "openapi": "3.0.2", - }, - }, + json=_version_with_schema(), ) ) @@ -600,3 +780,57 @@ def prediction_with_status( assert await output.aread() == b"Hello, world!" async for chunk in output: assert chunk == b"Hello, world!" 
+ + +def _prediction_with_status(status: str, output: str | list[str] | None = None) -> dict: + return { + "id": "p1", + "model": "test/example", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2023-10-05T12:00:00.000000Z", + "source": "api", + "status": status, + "input": {"text": "world"}, + "output": output, + "error": "OOM" if status == "failed" else None, + "logs": "", + } + + +def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None): + return { + "id": id, + "created_at": "2022-03-16T00:35:56.210272Z", + "cog_version": "dev", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "paths": {}, + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "The text input", + }, + }, + }, + "Output": output_schema + or { + "type": "string", + "title": "Output", + }, + } + }, + }, + } From 796ec8f4e44620dec8f3c86d1a7a438e1233cd05 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 25 Oct 2024 12:21:32 +0100 Subject: [PATCH 2/6] Improve readability of conditional that determines if polling is required --- replicate/run.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/replicate/run.py b/replicate/run.py index 3f94311..f25dace 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -61,9 +61,14 @@ def run( # Currently the "Prefer: wait" interface will return a prediction with a status # of "processing" rather than a terminal state because it returns before the - # prediction has been fully processed. If request exceeds the wait time, the - # prediction will be in a "starting" state. - if not (is_blocking and prediction.status != "starting"): + # prediction has been fully processed. 
If request exceeds the wait time, even if + # it is actually processing, the prediction will be in a "starting" state. + # + # We should fix this in the blocking API itself. Predictions that are done should + # be in a terminal state and predictions that are processing should be in state + # "processing". + in_terminal_state = is_blocking and prediction.status != "starting" + if not in_terminal_state: # Return a "polling" iterator if the model has an output iterator array type. if version and (iterator := _make_output_iterator(client, version, prediction)): return iterator @@ -119,9 +124,14 @@ async def async_run( # Currently the "Prefer: wait" interface will return a prediction with a status # of "processing" rather than a terminal state because it returns before the - # prediction has been fully processed. If request exceeds the wait time, the - # prediction will be in a "starting" state. - if not (is_blocking and prediction.status != "starting"): + # prediction has been fully processed. If request exceeds the wait time, even if + # it is actually processing, the prediction will be in a "starting" state. + # + # We should fix this in the blocking API itself. Predictions that are done should + # be in a terminal state and predictions that are processing should be in state + # "processing". + in_terminal_state = is_blocking and prediction.status != "starting" + if not in_terminal_state: # Return a "polling" iterator if the model has an output iterator array type. 
if version and ( iterator := _make_async_output_iterator(client, version, prediction) From 147a886a294dee9dfc0fc40e878d565408b53b92 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 25 Oct 2024 12:22:09 +0100 Subject: [PATCH 3/6] Separate iterator type check from iterator creation --- replicate/run.py | 87 +++++++++++++++++------------------------------- 1 file changed, 30 insertions(+), 57 deletions(-) diff --git a/replicate/run.py b/replicate/run.py index f25dace..3aa4bb8 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -70,8 +70,11 @@ def run( in_terminal_state = is_blocking and prediction.status != "starting" if not in_terminal_state: # Return a "polling" iterator if the model has an output iterator array type. - if version and (iterator := _make_output_iterator(client, version, prediction)): - return iterator + if version and _has_output_iterator_array_type(version): + return ( + transform_output(chunk, client) + for chunk in prediction.output_iterator() + ) prediction.wait() @@ -79,8 +82,12 @@ def run( raise ModelError(prediction) # Return an iterator for the completed prediction when needed. - if version and (iterator := _make_output_iterator(client, version, prediction)): - return iterator + if ( + version + and _has_output_iterator_array_type(version) + and prediction.output is not None + ): + return (transform_output(chunk, client) for chunk in prediction.output) if use_file_output: return transform_output(prediction.output, client) @@ -133,22 +140,28 @@ async def async_run( in_terminal_state = is_blocking and prediction.status != "starting" if not in_terminal_state: # Return a "polling" iterator if the model has an output iterator array type. 
- if version and ( - iterator := _make_async_output_iterator(client, version, prediction) - ): - return iterator + if version and _has_output_iterator_array_type(version): + return ( + transform_output(chunk, client) + async for chunk in prediction.async_output_iterator() + ) await prediction.async_wait() - # Return an iterator for completed output if the model has an output iterator array type. - if version and ( - iterator := _make_async_output_iterator(client, version, prediction) - ): - return iterator - if prediction.status == "failed": raise ModelError(prediction) + # Return an iterator for completed output if the model has an output iterator array type. + if ( + version + and _has_output_iterator_array_type(version) + and prediction.output is not None + ): + return ( + transform_output(chunk, client) + async for chunk in _make_async_iterator(prediction.output) + ) + if use_file_output: return transform_output(prediction.output, client) @@ -165,49 +178,9 @@ def _has_output_iterator_array_type(version: Version) -> bool: ) -def _make_output_iterator( - client: "Client", version: Version, prediction: Prediction -) -> Optional[Iterator[Any]]: - if not _has_output_iterator_array_type(version): - return None - - if prediction.status == "starting": - iterator = prediction.output_iterator() - elif prediction.output is not None: - iterator = iter(prediction.output) - else: - return None - - def _iterate(iter: Iterator[Any]) -> Iterator[Any]: - for chunk in iter: - yield transform_output(chunk, client) - - return _iterate(iterator) - - -def _make_async_output_iterator( - client: "Client", version: Version, prediction: Prediction -) -> Optional[AsyncIterator[Any]]: - if not _has_output_iterator_array_type(version): - return None - - if prediction.status == "starting": - iterator = prediction.async_output_iterator() - elif prediction.output is not None: - - async def _list_to_aiter(lst: list) -> AsyncIterator: - for item in lst: - yield item - - iterator = 
_list_to_aiter(prediction.output) - else: - return None - - async def _transform(iter: AsyncIterator[Any]) -> AsyncIterator: - async for chunk in iter: - yield transform_output(chunk, client) - - return _transform(iterator) +async def _make_async_iterator(list: list) -> AsyncIterator: + for item in list: + yield item __all__: List = [] From 18f49ec0e07843de3c1671fe5a5f0d38c1da157e Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 25 Oct 2024 12:22:58 +0100 Subject: [PATCH 4/6] Cleanup iterator assertions in test_run.py --- tests/test_run.py | 66 ++++++++++++----------------------------------- 1 file changed, 17 insertions(+), 49 deletions(-) diff --git a/tests/test_run.py b/tests/test_run.py index c847c71..fd20338 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -12,19 +12,6 @@ from replicate.helpers import FileOutput -async def anext(async_iterator, default=None): - """ - `anext` is only available from Python 3.10 onwards so here - we use an equivalent to ensure tests work in earlier versions. - """ - try: - return await async_iterator.__anext__() - except StopAsyncIteration: - if default is None: - raise - return default - - @pytest.mark.vcr("run.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) @@ -118,7 +105,7 @@ async def test_run_with_iterator(mock_replicate_api_token): client.poll_interval = 0.001 stream = cast( - Iterator[FileOutput], + Iterator[str], client.run( "test/example:v1", input={ @@ -128,13 +115,8 @@ async def test_run_with_iterator(mock_replicate_api_token): ), ) - output1 = next(stream) - output2 = next(stream) - with pytest.raises(StopIteration): - next(stream) - - assert output1 == "Hello, " - assert output2 == "world!" 
+ output = [chunk for chunk in stream] + assert output == ["Hello, ", "world!"] @pytest.mark.asyncio @@ -204,13 +186,8 @@ async def test_async_run_with_iterator(mock_replicate_api_token): ), ) - output1 = await anext(stream) - output2 = await anext(stream) - with pytest.raises(StopAsyncIteration): - await anext(stream) - - assert output1 == "Hello, " - assert output2 == "world!" + output = [chunk async for chunk in stream] + assert output == ["Hello, ", "world!"] @pytest.mark.asyncio @@ -253,7 +230,7 @@ async def test_run_blocking_with_iterator(mock_replicate_api_token): client.poll_interval = 0.001 stream = cast( - Iterator[FileOutput], + Iterator[str], client.run( "test/example:v1", input={ @@ -262,13 +239,9 @@ async def test_run_blocking_with_iterator(mock_replicate_api_token): ), ) - output1 = next(stream) - output2 = next(stream) - with pytest.raises(StopIteration): - next(stream) + assert list(stream) == ["Hello, ", "world!"] + - assert output1 == "Hello, " - assert output2 == "world!" @pytest.mark.asyncio @@ -320,13 +293,8 @@ async def test_async_run_blocking_with_iterator(mock_replicate_api_token): ), ) - output1 = await anext(stream) - output2 = await anext(stream) - with pytest.raises(StopAsyncIteration): - await anext(stream) - - assert output1 == "Hello, " - assert output2 == "world!" 
+ output = [chunk async for chunk in stream] + assert output == ["Hello, ", "world!"] @pytest.mark.vcr("run__concurrently.yaml") @@ -717,14 +685,14 @@ async def test_run_with_file_output_iterator(mock_replicate_api_token): ), ) - output1 = next(stream) - output2 = next(stream) - - assert output1.url == "https://api.replicate.com/v1/assets/hello.txt" - assert output2.url == "https://api.replicate.com/v1/assets/world.txt" + expected = [ + {"url": "https://api.replicate.com/v1/assets/hello.txt", "content": b"Hello,"}, + {"url": "https://api.replicate.com/v1/assets/world.txt", "content": b" world!"}, + ] - assert output1.read() == b"Hello," - assert output2.read() == b" world!" + for output, expect in zip(stream, expected): + assert output.url == expect["url"] + assert output.read() == expect["content"] @pytest.mark.asyncio From f5c88b80c4e67dc2d72e158c815411242fd46385 Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 25 Oct 2024 12:23:11 +0100 Subject: [PATCH 5/6] Add tests for blocking fallback state --- tests/test_run.py | 145 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/tests/test_run.py b/tests/test_run.py index fd20338..beb7f6e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -242,6 +242,151 @@ async def test_run_blocking_with_iterator(mock_replicate_api_token): assert list(stream) == ["Hello, ", "world!"] +@pytest.mark.asyncio +async def test_run_blocking_timeout_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + # Initial request times out and returns "starting" state. + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "starting", + ), + ) + ) + # Client should start polling for the prediction. 
+ router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + Iterator[str], + client.run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + assert list(stream) == ["Hello, ", "world!"] + + +@pytest.mark.asyncio +async def test_async_run_blocking_timeout_with_iterator(mock_replicate_api_token): + router = respx.Router(base_url="https://api.replicate.com/v1") + # Initial request times out and returns "starting" state. + router.route(method="POST", path="/predictions", headers={"Prefer": "wait"}).mock( + return_value=httpx.Response( + 201, + json=_prediction_with_status( + "starting", + ), + ) + ) + # Client should start polling for the prediction. 
+ router.route(method="GET", path="/predictions/p1").mock( + side_effect=[ + httpx.Response( + 200, + json=_prediction_with_status( + "processing", + [ + "Hello, ", + ], + ), + ), + httpx.Response( + 200, + json=_prediction_with_status( + "succeeded", + [ + "Hello, ", + "world!", + ], + ), + ), + ] + ) + router.route( + method="GET", + path="/models/test/example/versions/v1", + ).mock( + return_value=httpx.Response( + 201, + json=_version_with_schema( + "p1", + { + "type": "array", + "items": { + "type": "string", + }, + "x-cog-array-type": "iterator", + }, + ), + ) + ) + + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + client.poll_interval = 0.001 + + stream = cast( + AsyncIterator[str], + await client.async_run( + "test/example:v1", + input={ + "text": "Hello, world!", + }, + ), + ) + + output = [chunk async for chunk in stream] + assert output == ["Hello, ", "world!"] @pytest.mark.asyncio From d8033fa11b405523d1f514ccc80be2a483b8384f Mon Sep 17 00:00:00 2001 From: Aron Carroll Date: Fri, 25 Oct 2024 12:27:51 +0100 Subject: [PATCH 6/6] Linting --- replicate/run.py | 1 - 1 file changed, 1 deletion(-) diff --git a/replicate/run.py b/replicate/run.py index 3aa4bb8..19db492 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -15,7 +15,6 @@ from replicate.exceptions import ModelError from replicate.helpers import transform_output from replicate.model import Model -from replicate.prediction import Prediction from replicate.schema import make_schema_backwards_compatible from replicate.version import Version, Versions