From 1d3393d50a1ebca3f3bb2327184310e8d9dfc691 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 05:26:06 -0700 Subject: [PATCH 1/5] Fix signature of create methods Provide backwards-compatible migration to typed kwargs Signed-off-by: Mattt Zmuda --- replicate/collection.py | 11 ++++++-- replicate/deployment.py | 31 +++++++++++---------- replicate/model.py | 12 ++++++--- replicate/prediction.py | 55 +++++++++++++++++++++++-------------- replicate/training.py | 60 ++++++++++++++++++++++++++++++----------- replicate/version.py | 8 +++++- tests/test_training.py | 20 ++++++++++++-- 7 files changed, 139 insertions(+), 58 deletions(-) diff --git a/replicate/collection.py b/replicate/collection.py index 799b7b63..5900c68c 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,6 +1,8 @@ import abc from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast +from typing_extensions import TypedDict, Unpack + if TYPE_CHECKING: from replicate.client import Client @@ -8,9 +10,10 @@ from replicate.exceptions import ReplicateException Model = TypeVar("Model", bound=BaseModel) +CreateParams = TypeVar("CreateParams", bound=TypedDict) -class Collection(abc.ABC, Generic[Model]): +class Collection(abc.ABC, Generic[Model, CreateParams]): """ A base class for representing objects of a particular type on the server. """ @@ -32,7 +35,11 @@ def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring pass @abc.abstractmethod - def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring + def create( # pylint: disable=missing-function-docstring + self, + *args, + **kwargs: Unpack[CreateParams], # type: ignore[misc] + ) -> Model: pass def prepare_model(self, attrs: Union[Model, Dict]) -> Model: diff --git a/replicate/deployment.py b/replicate/deployment.py index 1a0766c7..06805d57 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -1,10 +1,12 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from typing_extensions import TypedDict, Unpack from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.files import upload_file from replicate.json import encode_json -from replicate.prediction import Prediction +from replicate.prediction import Prediction, PredictionCollection if TYPE_CHECKING: from replicate.client import Client @@ -65,7 +67,11 @@ def get(self, name: str) -> Deployment: username, name = name.split("/") return self.prepare_model({"username": username, "name": name}) - def create(self, **kwargs) -> Deployment: + def create( + self, + *args, + **kwargs: Unpack[TypedDict], # type: ignore[misc] + ) -> Deployment: """ Create a deployment. @@ -114,15 +120,10 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) - def create( # type: ignore + def create( self, - input: Dict[str, Any], - webhook: Optional[str] = None, - webhook_completed: Optional[str] = None, - webhook_events_filter: Optional[List[str]] = None, - *, - stream: Optional[bool] = None, - **kwargs, + *args, + **kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc] ) -> Prediction: """ Create a new prediction with the deployment. @@ -138,14 +139,16 @@ def create( # type: ignore Prediction: The created prediction object. """ - input = encode_json(input, upload_file=upload_file) + webhook = kwargs.get("webhook") + webhook_events_filter = kwargs.get("webhook_events_filter") + stream = kwargs.get("stream") + + input = encode_json(kwargs.get("input"), upload_file=upload_file) body: Dict[str, Any] = { "input": input, } if webhook is not None: body["webhook"] = webhook - if webhook_completed is not None: - body["webhook_completed"] = webhook_completed if webhook_events_filter is not None: body["webhook_events_filter"] = webhook_events_filter if stream is True: diff --git a/replicate/model.py b/replicate/model.py index 3dcc2427..85fe4f06 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, TypedDict, Union -from typing_extensions import deprecated +from typing_extensions import TypedDict, Unpack, deprecated from replicate.base_model import BaseModel from replicate.collection import Collection @@ -139,9 +139,13 @@ def get(self, key: str) -> Model: resp = self._client._request("GET", f"/v1/models/{key}") return self.prepare_model(resp.json()) - def create(self, **kwargs) -> Model: + def create( + self, + *args, + **kwargs: Unpack[TypedDict], # type: ignore[misc] + ) -> Model: """ - Create a model. + Creates a model. Raises: NotImplementedError: This method is not implemented. diff --git a/replicate/prediction.py b/replicate/prediction.py index f8afa5e8..0c790d6d 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Optional, Union +from typing_extensions import TypedDict, Unpack + from replicate.base_model import BaseModel from replicate.collection import Collection from replicate.exceptions import ModelError @@ -137,6 +139,16 @@ class PredictionCollection(Collection): Namespace for operations related to predictions. """ + class CreateParams(TypedDict): + """Parameters for creating a prediction.""" + + version: Union[Version, str] + input: Dict[str, Any] + webhook: Optional[str] + webhook_completed: Optional[str] + webhook_events_filter: Optional[List[str]] + stream: Optional[bool] + model = Prediction def list(self) -> List[Prediction]: @@ -171,16 +183,10 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) - def create( # type: ignore + def create( self, - version: Union[Version, str], - input: Dict[str, Any], - webhook: Optional[str] = None, - webhook_completed: Optional[str] = None, - webhook_events_filter: Optional[List[str]] = None, - *, - stream: Optional[bool] = None, - **kwargs, + *args, + **kwargs: Unpack[CreateParams], # type: ignore[misc] ) -> Prediction: """ Create a new prediction for the specified model version. @@ -197,19 +203,28 @@ def create( # type: ignore Prediction: The created prediction object. """ - input = encode_json(input, upload_file=upload_file) - body: Dict[str, Any] = { + # Support positional arguments for backwards compatibility + version = args[0] if args else kwargs.get("version") + if version is None: + raise ValueError( + "A version identifier must be provided as a positional or keyword argument." + ) + + input = args[1] if len(args) > 1 else kwargs.get("input") + if input is None: + raise ValueError( + "An input must be provided as a positional or keyword argument." + ) + + body = { "version": version if isinstance(version, str) else version.id, - "input": input, + "input": encode_json(input, upload_file=upload_file), } - if webhook is not None: - body["webhook"] = webhook - if webhook_completed is not None: - body["webhook_completed"] = webhook_completed - if webhook_events_filter is not None: - body["webhook_events_filter"] = webhook_events_filter - if stream is True: - body["stream"] = True + + for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]: + value = kwargs.get(key) + if value is not None: + body[key] = value resp = self._client._request( "POST", diff --git a/replicate/training.py b/replicate/training.py index 4499a79e..f275560a 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -1,5 +1,7 @@ import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union + +from typing_extensions import NotRequired, TypedDict, Unpack from replicate.base_model import BaseModel from replicate.collection import Collection @@ -68,6 +70,17 @@ class TrainingCollection(Collection): model = Training + class CreateParams(TypedDict): + """Parameters for creating a prediction.""" + + version: Union[Version, str] + destination: str + input: Dict[str, Any] + webhook: NotRequired[str] + webhook_completed: NotRequired[str] + webhook_events_filter: NotRequired[List[str]] + stream: NotRequired[bool] + def list(self) -> List[Training]: """ List your trainings. @@ -103,14 +116,10 @@ def get(self, id: str) -> Training: del obj["version"] return self.prepare_model(obj) - def create( # type: ignore + def create( self, - version: str, - input: Dict[str, Any], - destination: str, - webhook: Optional[str] = None, - webhook_events_filter: Optional[List[str]] = None, - **kwargs, + *args, + **kwargs: Unpack[CreateParams], # type: ignore[misc] ) -> Training: """ Create a new training using the specified model version as a base. @@ -120,24 +129,45 @@ def create( # type: ignore input: The input to the training. destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. webhook: The URL to send a POST request to when the training is completed. Defaults to None. + webhook_completed: The URL to receive a POST request when the prediction is completed. webhook_events_filter: The events to send to the webhook. Defaults to None. Returns: The training object. """ - input = encode_json(input, upload_file=upload_file) + # Support positional arguments for backwards compatibility + version = args[0] if args else kwargs.get("version") + if version is None: + raise ValueError( + "A version identifier must be provided as a positional or keyword argument." + ) + + destination = args[1] if len(args) > 1 else kwargs.get("destination") + if destination is None: + raise ValueError( + "A destination must be provided as a positional or keyword argument." + ) + + input = args[2] if len(args) > 2 else kwargs.get("input") + if input is None: + raise ValueError( + "An input must be provided as a positional or keyword argument." + ) + body = { - "input": input, + "input": encode_json(input, upload_file=upload_file), "destination": destination, } - if webhook is not None: - body["webhook"] = webhook - if webhook_events_filter is not None: - body["webhook_events_filter"] = webhook_events_filter + + for key in ["webhook", "webhook_completed", "webhook_events_filter"]: + value = kwargs.get(key) + if value is not None: + body[key] = value # Split version in format "username/model_name:version_id" match = re.match( - r"^(?P[^/]+)/(?P[^:]+):(?P.+)$", version + r"^(?P[^/]+)/(?P[^:]+):(?P.+)$", + version.id if isinstance(version, Version) else version, ) if not match: raise ReplicateException( diff --git a/replicate/version.py b/replicate/version.py index c3be8b2e..9e66457a 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -2,6 +2,8 @@ import warnings from typing import TYPE_CHECKING, Any, Iterator, List, Union +from typing_extensions import TypedDict, Unpack + if TYPE_CHECKING: from replicate.client import Client from replicate.model import Model @@ -94,7 +96,11 @@ def get(self, id: str) -> Version: ) return self.prepare_model(resp.json()) - def create(self, **kwargs) -> Version: + def create( + self, + *args, + **kwargs: Unpack[TypedDict], # type: ignore[misc] + ) -> Version: """ Create a model version. diff --git a/tests/test_training.py b/tests/test_training.py index da215749..0c4a4782 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio async def test_trainings_create(mock_replicate_api_token): training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", input={ "input_images": input_images_url, "use_face_detection_instead": True, @@ -22,6 +22,22 @@ async def test_trainings_create(mock_replicate_api_token): assert training.status == "starting" +@pytest.mark.vcr("trainings-create.yaml") +@pytest.mark.asyncio +async def test_trainings_create_with_positional_argument(mock_replicate_api_token): + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + { + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + "replicate/dreambooth-sdxl", + ) + + assert training.id is not None + assert training.status == "starting" + + @pytest.mark.vcr("trainings-create__invalid-destination.yaml") @pytest.mark.asyncio async def test_trainings_create_with_invalid_destination(mock_replicate_api_token): @@ -57,7 +73,7 @@ async def test_trainings_cancel(mock_replicate_api_token): destination = "replicate/dreambooth-sdxl" training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", destination=destination, input=input, ) From f278cdb975659ddc71e0dd72e16e5f3ee3bf1241 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 05:42:28 -0700 Subject: [PATCH 2/5] Fix tense of docstring Signed-off-by: Mattt Zmuda --- replicate/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/model.py b/replicate/model.py index 85fe4f06..77ff3cdd 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -145,7 +145,7 @@ def create( **kwargs: Unpack[TypedDict], # type: ignore[misc] ) -> Model: """ - Creates a model. + Create a model. Raises: NotImplementedError: This method is not implemented. From 6242c0f5439935d8701f88e18f1f0fe2080e355d Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 05:58:07 -0700 Subject: [PATCH 3/5] Remove stream from Training create parameters Signed-off-by: Mattt Zmuda --- replicate/training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/replicate/training.py b/replicate/training.py index f275560a..27e23f41 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -79,7 +79,6 @@ class CreateParams(TypedDict): webhook: NotRequired[str] webhook_completed: NotRequired[str] webhook_events_filter: NotRequired[List[str]] - stream: NotRequired[bool] def list(self) -> List[Training]: """ From a20212fcdcc8e672740afcd6adc34f14362b643f Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 05:58:39 -0700 Subject: [PATCH 4/5] Fix Redefinition of unused `TypedDict` Signed-off-by: Mattt Zmuda --- replicate/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replicate/model.py b/replicate/model.py index 77ff3cdd..8d6e50d9 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, TypedDict, Union +from typing import Dict, List, Optional, Union from typing_extensions import TypedDict, Unpack, deprecated From a2da3a8e005234c9ea78c4cc2921be2fbec5d603 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 2 Nov 2023 06:37:24 -0700 Subject: [PATCH 5/5] Add overloads for create methods Signed-off-by: Mattt Zmuda --- replicate/collection.py | 9 ++----- replicate/deployment.py | 54 +++++++++++++++++++++++++++++------------ replicate/model.py | 4 +-- replicate/prediction.py | 30 +++++++++++++++++++++-- replicate/training.py | 30 +++++++++++++++++++++-- replicate/version.py | 4 +-- 6 files changed, 100 insertions(+), 31 deletions(-) diff --git a/replicate/collection.py b/replicate/collection.py index 5900c68c..19e8e6c1 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -1,8 +1,6 @@ import abc from typing import TYPE_CHECKING, Dict, Generic, List, TypeVar, Union, cast -from typing_extensions import TypedDict, Unpack - if TYPE_CHECKING: from replicate.client import Client @@ -10,10 +8,9 @@ from replicate.exceptions import ReplicateException Model = TypeVar("Model", bound=BaseModel) -CreateParams = TypeVar("CreateParams", bound=TypedDict) -class Collection(abc.ABC, Generic[Model, CreateParams]): +class Collection(abc.ABC, Generic[Model]): """ A base class for representing objects of a particular type on the server. """ @@ -36,9 +33,7 @@ def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring @abc.abstractmethod def create( # pylint: disable=missing-function-docstring - self, - *args, - **kwargs: Unpack[CreateParams], # type: ignore[misc] + self, *args, **kwargs ) -> Model: pass diff --git a/replicate/deployment.py b/replicate/deployment.py index 06805d57..cfdd0da1 100644 --- a/replicate/deployment.py +++ b/replicate/deployment.py @@ -1,6 +1,6 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload -from typing_extensions import TypedDict, Unpack +from typing_extensions import Unpack from replicate.base_model import BaseModel from replicate.collection import Collection @@ -70,7 +70,7 @@ def get(self, name: str) -> Deployment: def create( self, *args, - **kwargs: Unpack[TypedDict], # type: ignore[misc] + **kwargs, ) -> Deployment: """ Create a deployment. @@ -120,6 +120,30 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + input: Dict[str, Any], + *, + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + stream: Optional[bool] = None, + ) -> Prediction: + ... + + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + *, + input: Dict[str, Any], + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + stream: Optional[bool] = None, + ) -> Prediction: + ... + def create( self, *args, @@ -139,20 +163,20 @@ def create( Prediction: The created prediction object. """ - webhook = kwargs.get("webhook") - webhook_events_filter = kwargs.get("webhook_events_filter") - stream = kwargs.get("stream") + input = args[0] if len(args) > 0 else kwargs.get("input") + if input is None: + raise ValueError( + "An input must be provided as a positional or keyword argument." + ) - input = encode_json(kwargs.get("input"), upload_file=upload_file) - body: Dict[str, Any] = { - "input": input, + body = { + "input": encode_json(input, upload_file=upload_file), } - if webhook is not None: - body["webhook"] = webhook - if webhook_events_filter is not None: - body["webhook_events_filter"] = webhook_events_filter - if stream is True: - body["stream"] = True + + for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]: + value = kwargs.get(key) + if value is not None: + body[key] = value resp = self._client._request( "POST", diff --git a/replicate/model.py b/replicate/model.py index 8d6e50d9..887e2bfc 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional, Union -from typing_extensions import TypedDict, Unpack, deprecated +from typing_extensions import deprecated from replicate.base_model import BaseModel from replicate.collection import Collection @@ -142,7 +142,7 @@ def get(self, key: str) -> Model: def create( self, *args, - **kwargs: Unpack[TypedDict], # type: ignore[misc] + **kwargs, ) -> Model: """ Create a model. diff --git a/replicate/prediction.py b/replicate/prediction.py index 0c790d6d..05cdec2b 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -1,9 +1,9 @@ import re import time from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union, overload -from typing_extensions import TypedDict, Unpack +from typing_extensions import Unpack from replicate.base_model import BaseModel from replicate.collection import Collection @@ -183,6 +183,32 @@ def get(self, id: str) -> Prediction: del obj["version"] return self.prepare_model(obj) + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + version: Union[Version, str], + input: Dict[str, Any], + *, + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + stream: Optional[bool] = None, + ) -> Prediction: + ... + + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + *, + version: Union[Version, str], + input: Dict[str, Any], + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + stream: Optional[bool] = None, + ) -> Prediction: + ... + def create( self, *args, diff --git a/replicate/training.py b/replicate/training.py index 27e23f41..ee3fe7e4 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -1,7 +1,7 @@ import re -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, TypedDict, Union -from typing_extensions import NotRequired, TypedDict, Unpack +from typing_extensions import NotRequired, Unpack, overload from replicate.base_model import BaseModel from replicate.collection import Collection @@ -115,6 +115,32 @@ def get(self, id: str) -> Training: del obj["version"] return self.prepare_model(obj) + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + version: Union[Version, str], + input: Dict[str, Any], + destination: str, + *, + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + ) -> Training: + ... + + @overload + def create( # pylint: disable=arguments-differ disable=too-many-arguments + self, + *, + version: Union[Version, str], + input: Dict[str, Any], + destination: str, + webhook: Optional[str] = None, + webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, + ) -> Training: + ... + def create( self, *args, diff --git a/replicate/version.py b/replicate/version.py index 9e66457a..2d7cd3ae 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -2,8 +2,6 @@ import warnings from typing import TYPE_CHECKING, Any, Iterator, List, Union -from typing_extensions import TypedDict, Unpack - if TYPE_CHECKING: from replicate.client import Client from replicate.model import Model @@ -99,7 +97,7 @@ def get(self, id: str) -> Version: def create( self, *args, - **kwargs: Unpack[TypedDict], # type: ignore[misc] + **kwargs, ) -> Version: """ Create a model version.