diff --git a/replicate/prediction.py b/replicate/prediction.py index 488e9a3b..a71524ec 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,7 +1,7 @@ import re import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Union from typing_extensions import NotRequired, TypedDict, Unpack @@ -37,7 +37,7 @@ class Prediction(Resource): version: str """An identifier for the version of the model used to create the prediction.""" - status: str + status: Literal["starting", "processing", "succeeded", "failed", "canceled"] """The status of the prediction.""" input: Optional[Dict[str, Any]] diff --git a/replicate/training.py b/replicate/training.py index 39291f76..934489ca 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -3,6 +3,7 @@ Any, Dict, List, + Literal, Optional, Tuple, TypedDict, @@ -48,7 +49,7 @@ class Training(Resource): destination: Optional[str] """The model destination of the training.""" - status: str + status: Literal["starting", "processing", "succeeded", "failed", "canceled"] """The status of the training.""" input: Optional[Dict[str, Any]] diff --git a/tests/test_run.py b/tests/test_run.py index c72c3bdb..00c93cbc 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -121,7 +121,7 @@ def prediction_with_status(status: str) -> dict: router.route(method="POST", path="/predictions").mock( return_value=httpx.Response( 201, - json=prediction_with_status("running"), + json=prediction_with_status("processing"), ) ) router.route(method="GET", path="/predictions/p1").mock(