Skip to content

Commit 371731c

Browse files
mattt and zeke
authored and committed
Document fine-tuning and clean up trainings API (#240)
This PR improves the experience of using Replicate's [training APIs](https://replicate.com/docs/reference/http#trainings.create) by making the following changes:

- Adding a section to the README discussing how to create a training
- Updating the `Training.cancel` instance method to mutate the caller, matching the behavior of `prediction.reload`
- Adding test coverage for async variants of training namespace methods

This PR also makes corresponding changes to predictions APIs to fix a few discrepancies that were found along the way.

---------

Signed-off-by: Mattt Zmuda <[email protected]>
Co-authored-by: Zeke Sikelianos <[email protected]>
1 parent cb222e5 commit 371731c

File tree

5 files changed

+256
-41
lines changed

5 files changed

+256
-41
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -305,6 +305,34 @@ Here's how to list all the available hardware for running models on Replicate
305305
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']
306306
```
307307

308+
## Fine-tune a model
309+
310+
Use the [training API](https://replicate.com/docs/fine-tuning)
311+
to fine-tune models to make them better at a particular task.
312+
To see what **language models** currently support fine-tuning,
313+
check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models).
314+
315+
If you're looking to fine-tune **image models**,
316+
check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model).
317+
318+
Here's how to fine-tune a model on Replicate:
319+
320+
```python
321+
training = replicate.trainings.create(
322+
model="stability-ai/sdxl",
323+
version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
324+
input={
325+
"input_images": "https://my-domain/training-images.zip",
326+
"token_string": "TOK",
327+
"caption_prefix": "a photo of TOK",
328+
"max_train_steps": 1000,
329+
"use_face_detection_instead": False
330+
},
331+
# You need to create a model on Replicate that will be the destination for the trained version.
332+
destination="your-username/model-name"
333+
)
334+
```
335+
308336
## Development
309337

310338
See [CONTRIBUTING.md](CONTRIBUTING.md)

replicate/prediction.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -177,6 +177,15 @@ def cancel(self) -> None:
177177
for name, value in canceled.dict().items():
178178
setattr(self, name, value)
179179

180+
async def async_cancel(self) -> None:
181+
"""
182+
Cancels a running prediction asynchronously.
183+
"""
184+
185+
canceled = await self._client.predictions.async_cancel(self.id)
186+
for name, value in canceled.dict().items():
187+
setattr(self, name, value)
188+
180189
def reload(self) -> None:
181190
"""
182191
Load this prediction from the server.

replicate/training.py

Lines changed: 25 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,22 @@ class Training(Resource):
8383
"""
8484

8585
def cancel(self) -> None:
86-
"""Cancel a running training"""
87-
self._client.trainings.cancel(self.id)
86+
"""
87+
Cancel a running training.
88+
"""
89+
90+
canceled = self._client.trainings.cancel(self.id)
91+
for name, value in canceled.dict().items():
92+
setattr(self, name, value)
93+
94+
async def async_cancel(self) -> None:
95+
"""
96+
Cancel a running training asynchronously.
97+
"""
98+
99+
canceled = await self._client.trainings.async_cancel(self.id)
100+
for name, value in canceled.dict().items():
101+
setattr(self, name, value)
88102

89103
def reload(self) -> None:
90104
"""
@@ -95,6 +109,15 @@ def reload(self) -> None:
95109
for name, value in updated.dict().items():
96110
setattr(self, name, value)
97111

112+
async def async_reload(self) -> None:
113+
"""
114+
Load the training from the server asynchronously.
115+
"""
116+
117+
updated = await self._client.trainings.async_get(self.id)
118+
for name, value in updated.dict().items():
119+
setattr(self, name, value)
120+
98121

99122
class Trainings(Namespace):
100123
"""

tests/test_prediction.py

Lines changed: 53 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,13 @@ async def test_predictions_cancel(async_flag):
101101
version=version,
102102
input=input,
103103
)
104+
105+
id = prediction.id
106+
assert prediction.status == "starting"
107+
108+
prediction = await replicate.predictions.async_cancel(prediction.id)
109+
assert prediction.id == id
110+
assert prediction.status == "canceled"
104111
else:
105112
model = replicate.models.get("stability-ai/sdxl")
106113
version = model.versions.get(
@@ -111,11 +118,53 @@ async def test_predictions_cancel(async_flag):
111118
input=input,
112119
)
113120

114-
# id = prediction.id
115-
assert prediction.status == "starting"
121+
id = prediction.id
122+
assert prediction.status == "starting"
123+
124+
prediction = replicate.predictions.cancel(prediction.id)
125+
assert prediction.id == id
126+
assert prediction.status == "canceled"
127+
128+
129+
@pytest.mark.vcr("predictions-cancel.yaml")
130+
@pytest.mark.asyncio
131+
@pytest.mark.parametrize("async_flag", [True, False])
132+
async def test_predictions_cancel_instance_method(async_flag):
133+
input = {
134+
"prompt": "a studio photo of a rainbow colored corgi",
135+
"width": 512,
136+
"height": 512,
137+
"seed": 42069,
138+
}
139+
140+
if async_flag:
141+
model = await replicate.models.async_get("stability-ai/sdxl")
142+
version = await model.versions.async_get(
143+
"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5"
144+
)
145+
prediction = await replicate.predictions.async_create(
146+
version=version,
147+
input=input,
148+
)
149+
150+
assert prediction.status == "starting"
151+
152+
await prediction.async_cancel()
153+
assert prediction.status == "canceled"
154+
else:
155+
model = replicate.models.get("stability-ai/sdxl")
156+
version = model.versions.get(
157+
"a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5"
158+
)
159+
prediction = replicate.predictions.create(
160+
version=version,
161+
input=input,
162+
)
163+
164+
assert prediction.status == "starting"
116165

117-
# prediction = replicate.predictions.cancel(prediction)
118-
prediction.cancel()
166+
prediction.cancel()
167+
assert prediction.status == "canceled"
119168

120169

121170
@pytest.mark.vcr("predictions-stream.yaml")

tests/test_training.py

Lines changed: 141 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -8,77 +8,183 @@
88

99
@pytest.mark.vcr("trainings-create.yaml")
1010
@pytest.mark.asyncio
11-
async def test_trainings_create(mock_replicate_api_token):
12-
training = replicate.trainings.create(
13-
version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
14-
input={
15-
"input_images": input_images_url,
16-
"use_face_detection_instead": True,
17-
},
18-
destination="replicate/dreambooth-sdxl",
19-
)
11+
@pytest.mark.parametrize("async_flag", [True, False])
12+
async def test_trainings_create(async_flag, mock_replicate_api_token):
13+
if async_flag:
14+
training = await replicate.trainings.async_create(
15+
model="stability-ai/sdxl",
16+
version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
17+
input={
18+
"input_images": input_images_url,
19+
"use_face_detection_instead": True,
20+
},
21+
destination="replicate/dreambooth-sdxl",
22+
)
23+
else:
24+
training = replicate.trainings.create(
25+
model="stability-ai/sdxl",
26+
version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
27+
input={
28+
"input_images": input_images_url,
29+
"use_face_detection_instead": True,
30+
},
31+
destination="replicate/dreambooth-sdxl",
32+
)
2033

2134
assert training.id is not None
2235
assert training.status == "starting"
2336

2437

2538
@pytest.mark.vcr("trainings-create.yaml")
2639
@pytest.mark.asyncio
27-
async def test_trainings_create_with_positional_argument(mock_replicate_api_token):
28-
training = replicate.trainings.create(
29-
"stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
30-
{
31-
"input_images": input_images_url,
32-
"use_face_detection_instead": True,
33-
},
34-
"replicate/dreambooth-sdxl",
35-
)
40+
@pytest.mark.parametrize("async_flag", [True, False])
41+
async def test_trainings_create_with_named_version_argument(
42+
async_flag, mock_replicate_api_token
43+
):
44+
if async_flag:
45+
# The overload with a model version identifier is soft-deprecated
46+
# and not supported in the async version.
47+
return
48+
else:
49+
training = replicate.trainings.create(
50+
version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
51+
input={
52+
"input_images": input_images_url,
53+
"use_face_detection_instead": True,
54+
},
55+
destination="replicate/dreambooth-sdxl",
56+
)
3657

3758
assert training.id is not None
3859
assert training.status == "starting"
3960

4061

41-
@pytest.mark.vcr("trainings-create__invalid-destination.yaml")
62+
@pytest.mark.vcr("trainings-create.yaml")
4263
@pytest.mark.asyncio
43-
async def test_trainings_create_with_invalid_destination(mock_replicate_api_token):
44-
with pytest.raises(ReplicateException):
45-
replicate.trainings.create(
64+
@pytest.mark.parametrize("async_flag", [True, False])
65+
async def test_trainings_create_with_positional_argument(
66+
async_flag, mock_replicate_api_token
67+
):
68+
if async_flag:
69+
# The overload with positional arguments is soft-deprecated
70+
# and not supported in the async version.
71+
return
72+
else:
73+
training = replicate.trainings.create(
4674
"stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
47-
input={
75+
{
4876
"input_images": input_images_url,
77+
"use_face_detection_instead": True,
4978
},
50-
destination="<invalid>",
79+
"replicate/dreambooth-sdxl",
5180
)
5281

82+
assert training.id is not None
83+
assert training.status == "starting"
84+
85+
86+
@pytest.mark.vcr("trainings-create__invalid-destination.yaml")
87+
@pytest.mark.asyncio
88+
@pytest.mark.parametrize("async_flag", [True, False])
89+
async def test_trainings_create_with_invalid_destination(
90+
async_flag, mock_replicate_api_token
91+
):
92+
with pytest.raises(ReplicateException):
93+
if async_flag:
94+
await replicate.trainings.async_create(
95+
model="stability-ai/sdxl",
96+
version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
97+
input={
98+
"input_images": input_images_url,
99+
"use_face_detection_instead": True,
100+
},
101+
destination="<invalid>",
102+
)
103+
else:
104+
replicate.trainings.create(
105+
model="stability-ai/sdxl",
106+
version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
107+
input={
108+
"input_images": input_images_url,
109+
},
110+
destination="<invalid>",
111+
)
112+
53113

54114
@pytest.mark.vcr("trainings-get.yaml")
55115
@pytest.mark.asyncio
56-
async def test_trainings_get(mock_replicate_api_token):
116+
@pytest.mark.parametrize("async_flag", [True, False])
117+
async def test_trainings_get(async_flag, mock_replicate_api_token):
57118
id = "medrnz3bm5dd6ultvad2tejrte"
58119

59-
training = replicate.trainings.get(id)
120+
if async_flag:
121+
training = await replicate.trainings.async_get(id)
122+
else:
123+
training = replicate.trainings.get(id)
60124

61125
assert training.id == id
62126
assert training.status == "processing"
63127

64128

65129
@pytest.mark.vcr("trainings-cancel.yaml")
66130
@pytest.mark.asyncio
67-
async def test_trainings_cancel(mock_replicate_api_token):
131+
@pytest.mark.parametrize("async_flag", [True, False])
132+
async def test_trainings_cancel(async_flag, mock_replicate_api_token):
68133
input = {
69134
"input_images": input_images_url,
70135
"use_face_detection_instead": True,
71136
}
72137

73138
destination = "replicate/dreambooth-sdxl"
74139

75-
training = replicate.trainings.create(
76-
version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
77-
destination=destination,
78-
input=input,
79-
)
140+
if async_flag:
141+
training = await replicate.trainings.async_create(
142+
model="stability-ai/sdxl",
143+
version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
144+
input=input,
145+
destination=destination,
146+
)
80147

81-
assert training.status == "starting"
148+
assert training.status == "starting"
149+
150+
training = replicate.trainings.cancel(training.id)
151+
assert training.status == "canceled"
152+
else:
153+
training = replicate.trainings.create(
154+
version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
155+
destination=destination,
156+
input=input,
157+
)
158+
159+
assert training.status == "starting"
160+
161+
training = replicate.trainings.cancel(training.id)
162+
assert training.status == "canceled"
163+
164+
165+
@pytest.mark.vcr("trainings-cancel.yaml")
166+
@pytest.mark.asyncio
167+
@pytest.mark.parametrize("async_flag", [True, False])
168+
async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_token):
169+
input = {
170+
"input_images": input_images_url,
171+
"use_face_detection_instead": True,
172+
}
173+
174+
destination = "replicate/dreambooth-sdxl"
175+
176+
if async_flag:
177+
# The cancel instance method is soft-deprecated,
178+
# and not supported in the async version.
179+
return
180+
else:
181+
training = replicate.trainings.create(
182+
version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
183+
destination=destination,
184+
input=input,
185+
)
186+
187+
assert training.status == "starting"
82188

83-
# training = replicate.trainings.cancel(training)
84-
training.cancel()
189+
training.cancel()
190+
assert training.status == "canceled"

0 commit comments

Comments
 (0)