From 1969904c0e5ae7c038924d50e2b8d204daec082d Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 16 Nov 2023 06:00:55 -0500 Subject: [PATCH 1/2] Test that stream key is sent when creating a prediction on a deployment Signed-off-by: Mattt Zmuda --- tests/test_deployment.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_deployment.py b/tests/test_deployment.py index b8ee42f0..045677f6 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,3 +1,5 @@ +import json + import httpx import pytest import respx @@ -14,6 +16,7 @@ 201, json={ "id": "p1", + "model": "test/model", "version": "v1", "urls": { "get": "https://api.replicate.com/v1/predictions/p1", @@ -46,6 +49,7 @@ async def test_deployment_predictions_create(async_flag): input={"text": "world"}, webhook="https://example.com/webhook", webhook_events_filter=["completed"], + stream=True, ) else: deployment = client.deployments.get("test/model") @@ -54,8 +58,12 @@ async def test_deployment_predictions_create(async_flag): input={"text": "world"}, webhook="https://example.com/webhook", webhook_events_filter=["completed"], + stream=True, ) assert router["deployments.predictions.create"].called + request = router["deployments.predictions.create"].calls[0].request + assert "stream" in json.loads(request.content) + assert prediction.id == "p1" assert prediction.input == {"text": "world"} From 836a2f80bf90d15a1f7f28decf8ab3eda32e7188 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 16 Nov 2023 06:04:16 -0500 Subject: [PATCH 2/2] Test all values in request body Signed-off-by: Mattt Zmuda --- tests/test_deployment.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 045677f6..a039567a 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -63,7 +63,11 @@ async def test_deployment_predictions_create(async_flag): assert router["deployments.predictions.create"].called request = router["deployments.predictions.create"].calls[0].request - assert "stream" in json.loads(request.content) + request_body = json.loads(request.content) + assert request_body["input"] == {"text": "world"} + assert request_body["webhook"] == "https://example.com/webhook" + assert request_body["webhook_events_filter"] == ["completed"] + assert request_body["stream"] is True assert prediction.id == "p1" assert prediction.input == {"text": "world"}