diff --git a/replicate/prediction.py b/replicate/prediction.py index 2d59791e..b1825c30 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -11,7 +11,9 @@ List, Literal, Optional, + Tuple, Union, + overload, ) from typing_extensions import NotRequired, TypedDict, Unpack @@ -31,6 +33,8 @@ if TYPE_CHECKING: from replicate.client import Client + from replicate.deployment import Deployment + from replicate.model import Model from replicate.stream import ServerSentEvent @@ -380,21 +384,82 @@ class CreatePredictionParams(TypedDict): stream: NotRequired[bool] """Enable streaming of prediction output.""" + @overload def create( self, version: Union[Version, str], input: Optional[Dict[str, Any]], **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + def create( + self, + *, + model: Union[str, Tuple[str, str], "Model"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + def create( + self, + *, + deployment: Union[str, Tuple[str, str], "Deployment"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + def create( # type: ignore + self, + *args, + model: Optional[Union[str, Tuple[str, str], "Model"]] = None, + version: Optional[Union[Version, str, "Version"]] = None, + deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None, + input: Optional[Dict[str, Any]] = None, + **params: Unpack["Predictions.CreatePredictionParams"], ) -> Prediction: """ - Create a new prediction for the specified model version. + Create a new prediction for the specified model, version, or deployment. """ + if args: + version = args[0] if len(args) > 0 else None + input = args[1] if len(args) > 1 else input + + if sum(bool(x) for x in [model, version, deployment]) != 1: + raise ValueError( + "Exactly one of 'model', 'version', or 'deployment' must be specified." + ) + + if model is not None: + from replicate.model import ( # pylint: disable=import-outside-toplevel + Models, + ) + + return Models(self._client).predictions.create( + model=model, + input=input or {}, + **params, + ) + + if deployment is not None: + from replicate.deployment import ( # pylint: disable=import-outside-toplevel + Deployments, + ) + + return Deployments(self._client).predictions.create( + deployment=deployment, + input=input or {}, + **params, + ) + body = _create_prediction_body( version, input, **params, ) + resp = self._client._request( "POST", "/v1/predictions", @@ -403,21 +468,82 @@ def create( return _json_to_prediction(self._client, resp.json()) + @overload async def async_create( self, version: Union[Version, str], input: Optional[Dict[str, Any]], **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + async def async_create( + self, + *, + model: Union[str, Tuple[str, str], "Model"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... + + @overload + async def async_create( + self, + *, + deployment: Union[str, Tuple[str, str], "Deployment"], + input: Optional[Dict[str, Any]], + **params: Unpack["Predictions.CreatePredictionParams"], + ) -> Prediction: ... 
+ + async def async_create( # type: ignore + self, + *args, + model: Optional[Union[str, Tuple[str, str], "Model"]] = None, + version: Optional[Union[Version, str, "Version"]] = None, + deployment: Optional[Union[str, Tuple[str, str], "Deployment"]] = None, + input: Optional[Dict[str, Any]] = None, + **params: Unpack["Predictions.CreatePredictionParams"], ) -> Prediction: """ - Create a new prediction for the specified model version. + Create a new prediction for the specified model, version, or deployment. """ + if args: + version = args[0] if len(args) > 0 else None + input = args[1] if len(args) > 1 else input + + if sum(bool(x) for x in [model, version, deployment]) != 1: + raise ValueError( + "Exactly one of 'model', 'version', or 'deployment' must be specified." + ) + + if model is not None: + from replicate.model import ( # pylint: disable=import-outside-toplevel + Models, + ) + + return await Models(self._client).predictions.async_create( + model=model, + input=input or {}, + **params, + ) + + if deployment is not None: + from replicate.deployment import ( # pylint: disable=import-outside-toplevel + Deployments, + ) + + return await Deployments(self._client).predictions.async_create( + deployment=deployment, + input=input or {}, + **params, + ) + body = _create_prediction_body( version, input, **params, ) + resp = await self._client._async_request( "POST", "/v1/predictions", diff --git a/tests/cassettes/predictions-cancel.yaml b/tests/cassettes/predictions-cancel.yaml index f67e671e..cdbcf9f3 100644 --- a/tests/cassettes/predictions-cancel.yaml +++ b/tests/cassettes/predictions-cancel.yaml @@ -1,183 +1,6 @@ interactions: - request: - body: '' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - host: - - api.replicate.com - user-agent: - - replicate-python/0.11.0 - method: GET - uri: https://api.replicate.com/v1/models/stability-ai/sdxl - response: - content: "{\"url\":\"https://replicate.com/stability-ai/sdxl\",\"owner\":\"stability-ai\",\"name\":\"sdxl\",\"description\":\"A - text-to-image generative AI model that creates beautiful 1024x1024 images\",\"visibility\":\"public\",\"github_url\":\"https://github.com/Stability-AI/generative-models\",\"paper_url\":\"https://arxiv.org/abs/2307.01952\",\"license_url\":\"https://github.com/Stability-AI/generative-models/blob/main/model_licenses/LICENSE-SDXL1.0\",\"run_count\":918101,\"cover_image_url\":\"https://tjzk.replicate.delivery/models_models_cover_image/61004930-fb88-4e09-9bd4-74fd8b4aa677/sdxl_cover.png\",\"default_example\":{\"completed_at\":\"2023-07-26T21:04:37.933562Z\",\"created_at\":\"2023-07-26T21:04:23.762683Z\",\"error\":null,\"id\":\"vu42q7dbkm6iicbpal4v6uvbqm\",\"input\":{\"width\":1024,\"height\":1024,\"prompt\":\"An - astronaut riding a rainbow unicorn, cinematic, dramatic\",\"refine\":\"expert_ensemble_refiner\",\"scheduler\":\"DDIM\",\"num_outputs\":1,\"guidance_scale\":7.5,\"high_noise_frac\":0.8,\"prompt_strength\":0.8,\"num_inference_steps\":50},\"logs\":\"Using - seed: 12103\\ntxt2img mode\\n 0%| | 0/40 [00:00"}' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '148' - content-type: - - application/json - host: - - api.replicate.com - user-agent: - - replicate-python/0.11.0 - method: POST - uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5/trainings - response: - content: '{"detail":"The specified training 
destination does not exist","status":404} + - request: + body: + '{"input": {"input_images": "https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip"}, + "destination": ""}' + headers: + accept: + - "*/*" + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - "148" + content-type: + - application/json + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: POST + uri: https://api.replicate.com/v1/models/stability-ai/sdxl/versions/39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b/trainings + response: + content: + '{"detail":"The specified training destination does not exist","status":404} - ' - headers: - CF-Cache-Status: - - DYNAMIC - CF-RAY: - - 7f7c2190ed8c281a-SEA - Connection: - - keep-alive - Content-Length: - - '76' - Content-Type: - - application/problem+json - Date: - - Wed, 16 Aug 2023 19:37:18 GMT - NEL: - - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' - Report-To: - - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' - Server: - - cloudflare - Strict-Transport-Security: - - max-age=15552000 - ratelimit-remaining: - - '2999' - ratelimit-reset: - - '1' - via: - - 1.1 google - http_version: HTTP/1.1 - status_code: 404 + ' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7f7c2190ed8c281a-SEA + Connection: + - keep-alive + Content-Length: + - "76" + Content-Type: + - application/problem+json + Date: + - Wed, 16 Aug 2023 19:37:18 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=0vMWFlGDyffyF0A%2FL4%2FH830OVHnZd0gZDww4oocSSHq7eMAt327ut6v%2B2qAda7fThmH4WcElLTM%2B3PFyrsa1w1SHgfEdWyJSv8TYYi2nWXMqeP5EJc1SDjV958HGKSKDnjH5"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + ratelimit-remaining: + - "2999" + ratelimit-reset: + - "1" + via: + - 1.1 google + http_version: HTTP/1.1 + status_code: 404 version: 1 diff --git a/tests/cassettes/trainings-get.yaml b/tests/cassettes/trainings-get.yaml index 3777f743..933591e8 100644 --- a/tests/cassettes/trainings-get.yaml +++ b/tests/cassettes/trainings-get.yaml @@ -1,73 +1,73 @@ interactions: -- request: - body: '' - headers: - accept: - - '*/*' - accept-encoding: - - gzip, deflate - connection: - - keep-alive - host: - - api.replicate.com - user-agent: - - replicate-python/0.11.0 - method: GET - uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte - response: - content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"model":"stability-ai/sdxl","version":"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5","webhook_completed":null}' - headers: - CF-Cache-Status: - - DYNAMIC - CF-RAY: - - 7f7c1beaedff279c-SEA - Connection: - - keep-alive - 
Content-Encoding: - - gzip - Content-Type: - - application/json - Date: - - Wed, 16 Aug 2023 19:33:26 GMT - NEL: - - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' - Report-To: - - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' - Server: - - cloudflare - Strict-Transport-Security: - - max-age=15552000 - Transfer-Encoding: - - chunked - allow: - - OPTIONS, GET - content-security-policy-report-only: - - 'style-src ''report-sample'' ''self'' ''unsafe-inline'' https://fonts.googleapis.com; - img-src ''report-sample'' ''self'' data: https://replicate.delivery https://*.replicate.delivery - https://*.githubusercontent.com https://github.com; worker-src ''none''; media-src - ''report-sample'' ''self'' https://replicate.delivery https://*.replicate.delivery - https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src ''report-sample'' - ''self'' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com - https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src - ''report-sample'' ''self'' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; - font-src ''report-sample'' ''self'' data: https://fonts.replicate.ai https://fonts.gstatic.com; - default-src ''self''; report-uri' - cross-origin-opener-policy: - - same-origin - ratelimit-remaining: - - '2999' - ratelimit-reset: - - '1' - referrer-policy: - - same-origin - vary: - - Cookie, origin - via: - - 1.1 vegur, 1.1 google - x-content-type-options: - - nosniff - x-frame-options: - - DENY - http_version: HTTP/1.1 - status_code: 200 + - request: + body: "" + headers: + accept: + - "*/*" + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - api.replicate.com + user-agent: + - replicate-python/0.11.0 + method: GET + uri: https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte + response: + content: '{"completed_at":null,"created_at":"2023-08-16T19:33:26.906823Z","error":null,"id":"medrnz3bm5dd6ultvad2tejrte","input":{"input_images":"https://replicate.delivery/pbxt/JMV5OrEWpBAC5gO8rre0tPOyJIOkaXvG0TWfVJ9b4zhLeEUY/data.zip","use_face_detection_instead":true},"logs":null,"metrics":{},"output":null,"started_at":"2023-08-16T19:33:42.114513Z","status":"processing","urls":{"get":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte","cancel":"https://api.replicate.com/v1/trainings/medrnz3bm5dd6ultvad2tejrte/cancel"},"model":"stability-ai/sdxl","version":"39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b","webhook_completed":null}' + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 7f7c1beaedff279c-SEA + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Wed, 16 Aug 2023 19:33:26 GMT + NEL: + - '{"success_fraction":0,"report_to":"cf-nel","max_age":604800}' + Report-To: + - '{"endpoints":[{"url":"https:\/\/a.nel.cloudflare.com\/report\/v3?s=SntiwLHCR4wiv49Qmn%2BR1ZblcX%2FgoVlIgsek4yZliZiWts2SqPjqTjrSkB%2Bwch8oHqR%2BBNVs1cSbihlHd8MWPXsbwC2uShz0c6tD4nclaecblb3FnEp4Mccy9hlZ39izF9Tm"}],"group":"cf-nel","max_age":604800}' + Server: + - cloudflare + Strict-Transport-Security: + - max-age=15552000 + Transfer-Encoding: + - chunked + allow: + - OPTIONS, GET + content-security-policy-report-only: + - "style-src 'report-sample' 'self' 'unsafe-inline' https://fonts.googleapis.com; + 
img-src 'report-sample' 'self' data: https://replicate.delivery https://*.replicate.delivery + https://*.githubusercontent.com https://github.com; worker-src 'none'; media-src + 'report-sample' 'self' https://replicate.delivery https://*.replicate.delivery + https://*.mux.com https://*.gstatic.com https://*.sentry.io; connect-src 'report-sample' + 'self' https://replicate.delivery https://*.replicate.delivery https://*.rudderlabs.com + https://*.rudderstack.com https://*.mux.com https://*.sentry.io; script-src + 'report-sample' 'self' https://cdn.rudderlabs.com/v1.1/rudder-analytics.min.js; + font-src 'report-sample' 'self' data: https://fonts.replicate.ai https://fonts.gstatic.com; + default-src 'self'; report-uri" + cross-origin-opener-policy: + - same-origin + ratelimit-remaining: + - "2999" + ratelimit-reset: + - "1" + referrer-policy: + - same-origin + vary: + - Cookie, origin + via: + - 1.1 vegur, 1.1 google + x-content-type-options: + - nosniff + x-frame-options: + - DENY + http_version: HTTP/1.1 + status_code: 200 version: 1 diff --git a/tests/conftest.py b/tests/conftest.py index a29ed640..103d1693 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,17 @@ +import asyncio import os from unittest import mock import pytest +import pytest_asyncio + + +@pytest_asyncio.fixture(scope="session", autouse=True) +def event_loop(): + event_loop_policy = asyncio.get_event_loop_policy() + loop = event_loop_policy.new_event_loop() + yield loop + loop.close() @pytest.fixture(scope="session") diff --git a/tests/test_prediction.py b/tests/test_prediction.py index c64a5989..c07e02c6 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,4 +1,6 @@ +import httpx import pytest +import respx import replicate @@ -17,7 +19,7 @@ async def test_predictions_create(async_flag): if async_flag: model = await replicate.models.async_get("stability-ai/sdxl") version = await model.versions.async_get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = await replicate.predictions.async_create( version=version, @@ -26,7 +28,7 @@ async def test_predictions_create(async_flag): else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = replicate.predictions.create( version=version, @@ -42,7 +44,7 @@ async def test_predictions_create(async_flag): @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_predictions_create_with_positional_argument(async_flag): - version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" input = { "prompt": "a studio photo of a rainbow colored corgi", @@ -67,21 +69,121 @@ async def test_predictions_create_with_positional_argument(async_flag): assert prediction.status == "starting" -@pytest.mark.vcr("predictions-get.yaml") +@pytest.mark.vcr("predictions-create-by-model.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) -async def test_predictions_get(async_flag): - id = "vgcm4plb7tgzlyznry5d5jkgvu" +async def test_predictions_create_by_model(async_flag): + model = "meta/meta-llama-3-8b-instruct" + input = { + "prompt": "write a haiku about llamas", + } if async_flag: - prediction = await replicate.predictions.async_get(id) + 
prediction = await replicate.predictions.async_create( + model=model, + input=input, + ) else: - prediction = replicate.predictions.get(id) + prediction = replicate.predictions.create( + model=model, + input=input, + ) - assert prediction.id == id + assert prediction.id is not None + # assert prediction.model == model + assert prediction.status == "starting" -@pytest.mark.vcr("predictions-cancel.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_create_by_deployment(async_flag): + router = respx.Router(base_url="https://api.replicate.com/v1") + + router.route( + method="POST", + path="/deployments/replicate/my-app-image-generator/predictions", + name="deployments.predictions.create", + ).mock( + return_value=httpx.Response( + 201, + json={ + "id": "p1", + "model": "replicate/my-app-image-generator", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "source": "api", + "status": "starting", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + ) + + router.route(host="api.replicate.com").pass_through() + + client = replicate.Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + input = {"text": "world"} + + if async_flag: + prediction = await client.predictions.async_create( + deployment="replicate/my-app-image-generator", + input=input, + ) + else: + prediction = client.predictions.create( + deployment="replicate/my-app-image-generator", + input=input, + ) + + assert prediction.id is not None + assert prediction.status == "starting" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_create_fail_with_too_many_arguments(async_flag): + router = respx.Router(base_url="https://api.replicate.com/v1") + + client = replicate.Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + version = "02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" + model = "meta/meta-llama-3-8b-instruct" + deployment = "replicate/my-app-image-generator" + input = {} + + with pytest.raises(ValueError) as exc_info: + if async_flag: + await client.predictions.async_create( + version=version, + model=model, + deployment=deployment, + input=input, + ) + else: + client.predictions.create( + version=version, + model=model, + deployment=deployment, + input=input, + ) + assert ( + str(exc_info.value) + == "Exactly one of 'model', 'version', or 'deployment' must be specified." 
+ ) + + +@pytest.mark.vcr("models-predictions-create.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) async def test_predictions_cancel(async_flag): @@ -95,7 +197,7 @@ async def test_predictions_cancel(async_flag): if async_flag: model = await replicate.models.async_get("stability-ai/sdxl") version = await model.versions.async_get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = await replicate.predictions.async_create( version=version, @@ -111,7 +213,7 @@ async def test_predictions_cancel(async_flag): else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = replicate.predictions.create( version=version, @@ -126,6 +228,20 @@ async def test_predictions_cancel(async_flag): assert prediction.status == "canceled" +@pytest.mark.vcr("predictions-get.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_get(async_flag): + id = "vgcm4plb7tgzlyznry5d5jkgvu" + + if async_flag: + prediction = await replicate.predictions.async_get(id) + else: + prediction = replicate.predictions.get(id) + + assert prediction.id == id + + @pytest.mark.vcr("predictions-cancel.yaml") @pytest.mark.asyncio @pytest.mark.parametrize("async_flag", [True, False]) @@ -140,7 +256,7 @@ async def test_predictions_cancel_instance_method(async_flag): if async_flag: model = await replicate.models.async_get("stability-ai/sdxl") version = await model.versions.async_get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = await replicate.predictions.async_create( version=version, @@ -154,7 +270,7 @@ async def test_predictions_cancel_instance_method(async_flag): else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( - "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" ) prediction = replicate.predictions.create( version=version, @@ -199,6 +315,7 @@ async def test_predictions_stream(async_flag): assert prediction.id is not None assert prediction.version == version.id assert prediction.status == "starting" + assert prediction.urls is not None assert prediction.urls["stream"] is not None diff --git a/tests/test_run.py b/tests/test_run.py index 00c93cbc..84c8f3ab 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -17,7 +17,7 @@ async def test_run(async_flag, record_mode): if record_mode == "none": replicate.default_client.poll_interval = 0.001 - version = "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + version = "39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b" input = { "prompt": "a studio photo of a rainbow colored corgi", diff --git a/tests/test_training.py b/tests/test_training.py index 1955ffe6..64926c64 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -13,7 +13,7 @@ async def test_trainings_create(async_flag, mock_replicate_api_token): if async_flag: training = await replicate.trainings.async_create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ 
"input_images": input_images_url, "use_face_detection_instead": True, @@ -23,7 +23,7 @@ async def test_trainings_create(async_flag, mock_replicate_api_token): else: training = replicate.trainings.create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -47,7 +47,7 @@ async def test_trainings_create_with_named_version_argument( return else: training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -71,7 +71,7 @@ async def test_trainings_create_with_positional_argument( return else: training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", { "input_images": input_images_url, "use_face_detection_instead": True, @@ -93,7 +93,7 @@ async def test_trainings_create_with_invalid_destination( if async_flag: await replicate.trainings.async_create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -103,7 +103,7 @@ async def test_trainings_create_with_invalid_destination( else: replicate.trainings.create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input={ "input_images": input_images_url, }, @@ -140,7 +140,7 @@ async def test_trainings_cancel(async_flag, mock_replicate_api_token): if async_flag: training = await replicate.trainings.async_create( model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input=input, destination=destination, ) @@ -151,7 +151,7 @@ async def test_trainings_cancel(async_flag, mock_replicate_api_token): assert training.status == "canceled" else: training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", destination=destination, input=input, ) @@ -179,7 +179,7 @@ async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_t return else: training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", destination=destination, input=input, )