Skip to content

Commit e531cff

Browse files
authored
Fix signature of create methods (#181)
Provides backwards-compatible migration to typed kwargs following [PEP 692](https://peps.python.org/pep-0692/). --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent 39d6bc9 commit e531cff

File tree

7 files changed

+195
-46
lines changed

7 files changed

+195
-46
lines changed

replicate/collection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ def get(self, key: str) -> Model: # pylint: disable=missing-function-docstring
3232
pass
3333

3434
@abc.abstractmethod
35-
def create(self, **kwargs) -> Model: # pylint: disable=missing-function-docstring
35+
def create( # pylint: disable=missing-function-docstring
36+
self, *args, **kwargs
37+
) -> Model:
3638
pass
3739

3840
def prepare_model(self, attrs: Union[Model, Dict]) -> Model:

replicate/deployment.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload
2+
3+
from typing_extensions import Unpack
24

35
from replicate.base_model import BaseModel
46
from replicate.collection import Collection
57
from replicate.files import upload_file
68
from replicate.json import encode_json
7-
from replicate.prediction import Prediction
9+
from replicate.prediction import Prediction, PredictionCollection
810

911
if TYPE_CHECKING:
1012
from replicate.client import Client
@@ -65,7 +67,11 @@ def get(self, name: str) -> Deployment:
6567
username, name = name.split("/")
6668
return self.prepare_model({"username": username, "name": name})
6769

68-
def create(self, **kwargs) -> Deployment:
70+
def create(
71+
self,
72+
*args,
73+
**kwargs,
74+
) -> Deployment:
6975
"""
7076
Create a deployment.
7177
@@ -114,15 +120,34 @@ def get(self, id: str) -> Prediction:
114120
del obj["version"]
115121
return self.prepare_model(obj)
116122

117-
def create( # type: ignore
123+
@overload
124+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
118125
self,
119126
input: Dict[str, Any],
127+
*,
120128
webhook: Optional[str] = None,
121129
webhook_completed: Optional[str] = None,
122130
webhook_events_filter: Optional[List[str]] = None,
131+
stream: Optional[bool] = None,
132+
) -> Prediction:
133+
...
134+
135+
@overload
136+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
137+
self,
123138
*,
139+
input: Dict[str, Any],
140+
webhook: Optional[str] = None,
141+
webhook_completed: Optional[str] = None,
142+
webhook_events_filter: Optional[List[str]] = None,
124143
stream: Optional[bool] = None,
125-
**kwargs,
144+
) -> Prediction:
145+
...
146+
147+
def create(
148+
self,
149+
*args,
150+
**kwargs: Unpack[PredictionCollection.CreateParams], # type: ignore[misc]
126151
) -> Prediction:
127152
"""
128153
Create a new prediction with the deployment.
@@ -138,18 +163,20 @@ def create( # type: ignore
138163
Prediction: The created prediction object.
139164
"""
140165

141-
input = encode_json(input, upload_file=upload_file)
142-
body: Dict[str, Any] = {
143-
"input": input,
166+
input = args[0] if len(args) > 0 else kwargs.get("input")
167+
if input is None:
168+
raise ValueError(
169+
"An input must be provided as a positional or keyword argument."
170+
)
171+
172+
body = {
173+
"input": encode_json(input, upload_file=upload_file),
144174
}
145-
if webhook is not None:
146-
body["webhook"] = webhook
147-
if webhook_completed is not None:
148-
body["webhook_completed"] = webhook_completed
149-
if webhook_events_filter is not None:
150-
body["webhook_events_filter"] = webhook_events_filter
151-
if stream is True:
152-
body["stream"] = True
175+
176+
for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
177+
value = kwargs.get(key)
178+
if value is not None:
179+
body[key] = value
153180

154181
resp = self._client._request(
155182
"POST",

replicate/model.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,11 @@ def get(self, key: str) -> Model:
139139
resp = self._client._request("GET", f"/v1/models/{key}")
140140
return self.prepare_model(resp.json())
141141

142-
def create(self, **kwargs) -> Model:
142+
def create(
143+
self,
144+
*args,
145+
**kwargs,
146+
) -> Model:
143147
"""
144148
Create a model.
145149

replicate/prediction.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import re
22
import time
33
from dataclasses import dataclass
4-
from typing import Any, Dict, Iterator, List, Optional, Union
4+
from typing import Any, Dict, Iterator, List, Optional, TypedDict, Union, overload
5+
6+
from typing_extensions import Unpack
57

68
from replicate.base_model import BaseModel
79
from replicate.collection import Collection
@@ -137,6 +139,16 @@ class PredictionCollection(Collection):
137139
Namespace for operations related to predictions.
138140
"""
139141

142+
class CreateParams(TypedDict):
143+
"""Parameters for creating a prediction."""
144+
145+
version: Union[Version, str]
146+
input: Dict[str, Any]
147+
webhook: Optional[str]
148+
webhook_completed: Optional[str]
149+
webhook_events_filter: Optional[List[str]]
150+
stream: Optional[bool]
151+
140152
model = Prediction
141153

142154
def list(self) -> List[Prediction]:
@@ -171,16 +183,36 @@ def get(self, id: str) -> Prediction:
171183
del obj["version"]
172184
return self.prepare_model(obj)
173185

174-
def create( # type: ignore
186+
@overload
187+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
175188
self,
176189
version: Union[Version, str],
177190
input: Dict[str, Any],
191+
*,
178192
webhook: Optional[str] = None,
179193
webhook_completed: Optional[str] = None,
180194
webhook_events_filter: Optional[List[str]] = None,
195+
stream: Optional[bool] = None,
196+
) -> Prediction:
197+
...
198+
199+
@overload
200+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
201+
self,
181202
*,
203+
version: Union[Version, str],
204+
input: Dict[str, Any],
205+
webhook: Optional[str] = None,
206+
webhook_completed: Optional[str] = None,
207+
webhook_events_filter: Optional[List[str]] = None,
182208
stream: Optional[bool] = None,
183-
**kwargs,
209+
) -> Prediction:
210+
...
211+
212+
def create(
213+
self,
214+
*args,
215+
**kwargs: Unpack[CreateParams], # type: ignore[misc]
184216
) -> Prediction:
185217
"""
186218
Create a new prediction for the specified model version.
@@ -197,19 +229,28 @@ def create( # type: ignore
197229
Prediction: The created prediction object.
198230
"""
199231

200-
input = encode_json(input, upload_file=upload_file)
201-
body: Dict[str, Any] = {
232+
# Support positional arguments for backwards compatibility
233+
version = args[0] if args else kwargs.get("version")
234+
if version is None:
235+
raise ValueError(
236+
"A version identifier must be provided as a positional or keyword argument."
237+
)
238+
239+
input = args[1] if len(args) > 1 else kwargs.get("input")
240+
if input is None:
241+
raise ValueError(
242+
"An input must be provided as a positional or keyword argument."
243+
)
244+
245+
body = {
202246
"version": version if isinstance(version, str) else version.id,
203-
"input": input,
247+
"input": encode_json(input, upload_file=upload_file),
204248
}
205-
if webhook is not None:
206-
body["webhook"] = webhook
207-
if webhook_completed is not None:
208-
body["webhook_completed"] = webhook_completed
209-
if webhook_events_filter is not None:
210-
body["webhook_events_filter"] = webhook_events_filter
211-
if stream is True:
212-
body["stream"] = True
249+
250+
for key in ["webhook", "webhook_completed", "webhook_events_filter", "stream"]:
251+
value = kwargs.get(key)
252+
if value is not None:
253+
body[key] = value
213254

214255
resp = self._client._request(
215256
"POST",

replicate/training.py

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import re
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, TypedDict, Union
3+
4+
from typing_extensions import NotRequired, Unpack, overload
35

46
from replicate.base_model import BaseModel
57
from replicate.collection import Collection
@@ -68,6 +70,16 @@ class TrainingCollection(Collection):
6870

6971
model = Training
7072

73+
class CreateParams(TypedDict):
74+
"""Parameters for creating a prediction."""
75+
76+
version: Union[Version, str]
77+
destination: str
78+
input: Dict[str, Any]
79+
webhook: NotRequired[str]
80+
webhook_completed: NotRequired[str]
81+
webhook_events_filter: NotRequired[List[str]]
82+
7183
def list(self) -> List[Training]:
7284
"""
7385
List your trainings.
@@ -103,14 +115,36 @@ def get(self, id: str) -> Training:
103115
del obj["version"]
104116
return self.prepare_model(obj)
105117

106-
def create( # type: ignore
118+
@overload
119+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
120+
self,
121+
version: Union[Version, str],
122+
input: Dict[str, Any],
123+
destination: str,
124+
*,
125+
webhook: Optional[str] = None,
126+
webhook_completed: Optional[str] = None,
127+
webhook_events_filter: Optional[List[str]] = None,
128+
) -> Training:
129+
...
130+
131+
@overload
132+
def create( # pylint: disable=arguments-differ disable=too-many-arguments
107133
self,
108-
version: str,
134+
*,
135+
version: Union[Version, str],
109136
input: Dict[str, Any],
110137
destination: str,
111138
webhook: Optional[str] = None,
139+
webhook_completed: Optional[str] = None,
112140
webhook_events_filter: Optional[List[str]] = None,
113-
**kwargs,
141+
) -> Training:
142+
...
143+
144+
def create(
145+
self,
146+
*args,
147+
**kwargs: Unpack[CreateParams], # type: ignore[misc]
114148
) -> Training:
115149
"""
116150
Create a new training using the specified model version as a base.
@@ -120,24 +154,45 @@ def create( # type: ignore
120154
input: The input to the training.
121155
destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request.
122156
webhook: The URL to send a POST request to when the training is completed. Defaults to None.
157+
webhook_completed: The URL to receive a POST request when the prediction is completed.
123158
webhook_events_filter: The events to send to the webhook. Defaults to None.
124159
Returns:
125160
The training object.
126161
"""
127162

128-
input = encode_json(input, upload_file=upload_file)
163+
# Support positional arguments for backwards compatibility
164+
version = args[0] if args else kwargs.get("version")
165+
if version is None:
166+
raise ValueError(
167+
"A version identifier must be provided as a positional or keyword argument."
168+
)
169+
170+
destination = args[1] if len(args) > 1 else kwargs.get("destination")
171+
if destination is None:
172+
raise ValueError(
173+
"A destination must be provided as a positional or keyword argument."
174+
)
175+
176+
input = args[2] if len(args) > 2 else kwargs.get("input")
177+
if input is None:
178+
raise ValueError(
179+
"An input must be provided as a positional or keyword argument."
180+
)
181+
129182
body = {
130-
"input": input,
183+
"input": encode_json(input, upload_file=upload_file),
131184
"destination": destination,
132185
}
133-
if webhook is not None:
134-
body["webhook"] = webhook
135-
if webhook_events_filter is not None:
136-
body["webhook_events_filter"] = webhook_events_filter
186+
187+
for key in ["webhook", "webhook_completed", "webhook_events_filter"]:
188+
value = kwargs.get(key)
189+
if value is not None:
190+
body[key] = value
137191

138192
# Split version in format "username/model_name:version_id"
139193
match = re.match(
140-
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$", version
194+
r"^(?P<username>[^/]+)/(?P<model_name>[^:]+):(?P<version_id>.+)$",
195+
version.id if isinstance(version, Version) else version,
141196
)
142197
if not match:
143198
raise ReplicateException(

replicate/version.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,11 @@ def get(self, id: str) -> Version:
9494
)
9595
return self.prepare_model(resp.json())
9696

97-
def create(self, **kwargs) -> Version:
97+
def create(
98+
self,
99+
*args,
100+
**kwargs,
101+
) -> Version:
98102
"""
99103
Create a model version.
100104

0 commit comments

Comments
 (0)