Skip to content

Fix signature of create methods #181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
pass

@abc.abstractmethod
def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring
def create( # pylint: disable=missing-function-docstring
self, *args, **kwargs
) -> Model:
pass

def prepare_model(self, attrs: Union[Model, Dict]) -> Model:
Expand Down
59 changes: 43 additions & 16 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload

from typing_extensions import Unpack

from replicate.base_model import BaseModel
from replicate.collection import Collection
from replicate.files import upload_file
from replicate.json import encode_json
from replicate.prediction import Prediction
from replicate.prediction import Prediction, PredictionCollection

if TYPE_CHECKING:
from replicate.client import Client
Expand Down Expand Up @@ -65,7 +67,11 @@ def get(self, name: str) -> Deployment:
username, name = name.split("/")
return self.prepare_model({"username": username, "name": name})

def create(self, **kwargs) -> Deployment:
def create(
self,
*args,
**kwargs,
) -> Deployment:
"""
Create a deployment.

Expand Down Expand Up @@ -114,15 +120,34 @@ def get(self, id: str) -> Prediction:
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
input: Dict[str, Any],
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
) -> Prediction:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
*,
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
**kwargs,
) -> Prediction:
...

def create(
self,
*args,
**kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc]
) -> Prediction:
"""
Create a new prediction with the deployment.
Expand All @@ -138,18 +163,20 @@ def create( # type: ignore
Prediction: The created prediction object.
"""

input = encode_json(input, upload_file=upload_file)
body: Dict[str, Any] = {
"input": input,
input = args[0] if len(args) > 0 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"input": encode_json(input, upload_file=upload_file),
}
if webhook is not None:
body["webhook"] = webhook
if webhook_completed is not None:
body["webhook_completed"] = webhook_completed
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter
if stream is True:
body["stream"] = True

for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
value = kwargs.get(key)
if value is not None:
body[key] = value

resp = self._client._request(
"POST",
Expand Down
6 changes: 5 additions & 1 deletion replicate/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,11 @@ def get(self, key: str) -> Model:
resp = self._client._request("GET", f"/v1/models/{key}")
return self.prepare_model(resp.json())

def create(self, **kwargs) -> Model:
def create(
self,
*args,
**kwargs,
) -> Model:
"""
Create a model.

Expand Down
69 changes: 55 additions & 14 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union, overload

from typing_extensions import Unpack

from replicate.base_model import BaseModel
from replicate.collection import Collection
Expand Down Expand Up @@ -137,6 +139,16 @@ class PredictionCollection(Collection):
Namespace for operations related to predictions.
"""

class CreateParams(TypedDict):
"""Parameters for creating a prediction."""

version: Union[Version, str]
input: Dict[str, Any]
webhook: Optional[str]
webhook_completed: Optional[str]
webhook_events_filter: Optional[List[str]]
stream: Optional[bool]

model = Prediction

def list(self) -> List[Prediction]:
Expand Down Expand Up @@ -171,16 +183,36 @@ def get(self, id: str) -> Prediction:
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
version: Union[Version, str],
input: Dict[str, Any],
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
) -> Prediction:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
*,
version: Union[Version, str],
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
stream: Optional[bool] = None,
**kwargs,
) -> Prediction:
...

def create(
self,
*args,
**kwargs: Unpack[CreateParams], # type: ignore[misc]
) -> Prediction:
"""
Create a new prediction for the specified model version.
Expand All @@ -197,19 +229,28 @@ def create( # type: ignore
Prediction: The created prediction object.
"""

input = encode_json(input, upload_file=upload_file)
body: Dict[str, Any] = {
# Support positional arguments for backwards compatibility
version = args[0] if args else kwargs.get("version")
if version is None:
raise ValueError(
"A version identifier must be provided as a positional or keyword argument."
)

input = args[1] if len(args) > 1 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"version": version if isinstance(version, str) else version.id,
"input": input,
"input": encode_json(input, upload_file=upload_file),
}
if webhook is not None:
body["webhook"] = webhook
if webhook_completed is not None:
body["webhook_completed"] = webhook_completed
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter
if stream is True:
body["stream"] = True

for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
value = kwargs.get(key)
if value is not None:
body[key] = value

resp = self._client._request(
"POST",
Expand Down
77 changes: 66 additions & 11 deletions replicate/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TypedDict, Union

from typing_extensions import NotRequired, Unpack, overload

from replicate.base_model import BaseModel
from replicate.collection import Collection
Expand Down Expand Up @@ -68,6 +70,16 @@ class TrainingCollection(Collection):

model = Training

class CreateParams(TypedDict):
"""Parameters for creating a prediction."""

version: Union[Version, str]
destination: str
input: Dict[str, Any]
webhook: NotRequired[str]
webhook_completed: NotRequired[str]
webhook_events_filter: NotRequired[List[str]]

def list(self) -> List[Training]:
"""
List your trainings.
Expand Down Expand Up @@ -103,14 +115,36 @@ def get(self, id: str) -> Training:
del obj["version"]
return self.prepare_model(obj)

def create( # type: ignore
@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
version: Union[Version, str],
input: Dict[str, Any],
destination: str,
*,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
) -> Training:
...

@overload
def create( # pylint: disable=arguments-differ disable=too-many-arguments
self,
version: str,
*,
version: Union[Version, str],
input: Dict[str, Any],
destination: str,
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
**kwargs,
) -> Training:
...

def create(
self,
*args,
**kwargs: Unpack[CreateParams], # type: ignore[misc]
) -> Training:
"""
Create a new training using the specified model version as a base.
Expand All @@ -120,24 +154,45 @@ def create( # type: ignore
input: The input to the training.
destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request.
webhook: The URL to send a POST request to when the training is completed. Defaults to None.
webhook_completed: The URL to receive a POST request when the prediction is completed.
webhook_events_filter: The events to send to the webhook. Defaults to None.
Returns:
The training object.
"""

input = encode_json(input, upload_file=upload_file)
# Support positional arguments for backwards compatibility
version = args[0] if args else kwargs.get("version")
if version is None:
raise ValueError(
"A version identifier must be provided as a positional or keyword argument."
)

destination = args[1] if len(args) > 1 else kwargs.get("destination")
if destination is None:
raise ValueError(
"A destination must be provided as a positional or keyword argument."
)

input = args[2] if len(args) > 2 else kwargs.get("input")
if input is None:
raise ValueError(
"An input must be provided as a positional or keyword argument."
)

body = {
"input": input,
"input": encode_json(input, upload_file=upload_file),
"destination": destination,
}
if webhook is not None:
body["webhook"] = webhook
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter

for key in ["webhook", "webhook_completed", "webhook_events_filter"]:
value = kwargs.get(key)
if value is not None:
body[key] = value

# Split version in format "username/model_name:version_id"
match = re.match(
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$", version
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$",
version.id if isinstance(version, Version) else version,
)
if not match:
raise ReplicateException(
Expand Down
6 changes: 5 additions & 1 deletion replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ def get(self, id: str) -> Version:
)
return self.prepare_model(resp.json())

def create(self, **kwargs) -> Version:
def create(
self,
*args,
**kwargs,
) -> Version:
"""
Create a model version.

Expand Down
Loading