diff --git a/tests/test_deployment.py b/tests/test_deployment.py index b8ee42f0..a039567a 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,3 +1,5 @@ +import json + import httpx import pytest import respx @@ -14,6 +16,7 @@ 201, json={ "id": "p1", + "model": "test/model", "version": "v1", "urls": { "get": "https://api.replicate.com/v1/predictions/p1", @@ -46,6 +49,7 @@ async def test_deployment_predictions_create(async_flag): input={"text": "world"}, webhook="https://example.com/webhook", webhook_events_filter=["completed"], + stream=True, ) else: deployment = client.deployments.get("test/model") @@ -54,8 +58,16 @@ async def test_deployment_predictions_create(async_flag): input={"text": "world"}, webhook="https://example.com/webhook", webhook_events_filter=["completed"], + stream=True, ) assert router["deployments.predictions.create"].called + request = router["deployments.predictions.create"].calls[0].request + request_body = json.loads(request.content) + assert request_body["input"] == {"text": "world"} + assert request_body["webhook"] == "https://example.com/webhook" + assert request_body["webhook_events_filter"] == ["completed"] + assert request_body["stream"] is True + assert prediction.id == "p1" assert prediction.input == {"text": "world"}