Skip to content

Document fine-tuning and clean up trainings API #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,34 @@ Here's how to list all the available hardware for running models on Replicate
['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large']
```

## Fine-tune a model

Use the [training API](https://replicate.com/docs/fine-tuning)
to fine-tune models to make them better at a particular task.
To see what **language models** currently support fine-tuning,
check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models).

If you're looking to fine-tune **image models**,
check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model).

Here's how to fine-tune a model on Replicate:

```python
training = replicate.trainings.create(
model="stability-ai/sdxl",
version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
input={
"input_images": "https://my-domain/training-images.zip",
"token_string": "TOK",
"caption_prefix": "a photo of TOK",
"max_train_steps": 1000,
"use_face_detection_instead": False
},
# You need to create a model on Replicate that will be the destination for the trained version.
destination="your-username/model-name"
)
```

## Development

See [CONTRIBUTING.md](CONTRIBUTING.md)
9 changes: 9 additions & 0 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,15 @@ def cancel(self) -> None:
for name, value in canceled.dict().items():
setattr(self, name, value)

async def async_cancel(self) -> None:
    """
    Cancel a running prediction asynchronously.

    Awaits ``self._client.predictions.async_cancel`` with this prediction's
    ``id`` and copies every field of the returned object onto ``self``, so
    the instance is updated in place. Nothing is returned.
    """

    canceled = await self._client.predictions.async_cancel(self.id)
    # Mirror the response onto this instance so callers holding a reference
    # see the refreshed fields without re-fetching.
    for name, value in canceled.dict().items():
        setattr(self, name, value)

def reload(self) -> None:
"""
Load this prediction from the server.
Expand Down
27 changes: 25 additions & 2 deletions replicate/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,22 @@ class Training(Resource):
"""

def cancel(self) -> None:
    """
    Cancel a running training.

    Calls ``self._client.trainings.cancel`` exactly once with this
    training's ``id`` and copies every field of the returned object onto
    ``self``, so the instance is updated in place. Nothing is returned.
    """

    canceled = self._client.trainings.cancel(self.id)
    # Mirror the response onto this instance so callers holding a reference
    # see the refreshed fields without re-fetching.
    for name, value in canceled.dict().items():
        setattr(self, name, value)

async def async_cancel(self) -> None:
    """
    Cancel a running training asynchronously.

    Awaits ``self._client.trainings.async_cancel`` with this training's
    ``id`` and copies every field of the returned object onto ``self``, so
    the instance is updated in place. Nothing is returned.
    """

    canceled = await self._client.trainings.async_cancel(self.id)
    # Mirror the response onto this instance so callers holding a reference
    # see the refreshed fields without re-fetching.
    for name, value in canceled.dict().items():
        setattr(self, name, value)

def reload(self) -> None:
"""
Expand All @@ -95,6 +109,15 @@ def reload(self) -> None:
for name, value in updated.dict().items():
setattr(self, name, value)

async def async_reload(self) -> None:
    """
    Load the training from the server asynchronously.

    Awaits ``self._client.trainings.async_get`` with this training's ``id``
    and copies every field of the returned object onto ``self``, so the
    instance is updated in place. Nothing is returned.
    """

    updated = await self._client.trainings.async_get(self.id)
    # Overwrite local fields with whatever the fetched object reports.
    for name, value in updated.dict().items():
        setattr(self, name, value)


class Trainings(Namespace):
"""
Expand Down
57 changes: 53 additions & 4 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ async def test_predictions_cancel(async_flag):
version=version,
input=input,
)

id = prediction.id
assert prediction.status == "starting"

prediction = await replicate.predictions.async_cancel(prediction.id)
assert prediction.id == id
assert prediction.status == "canceled"
else:
model = replicate.models.get("stability-ai/sdxl")
version = model.versions.get(
Expand All @@ -111,11 +118,53 @@ async def test_predictions_cancel(async_flag):
input=input,
)

# id = prediction.id
assert prediction.status == "starting"
id = prediction.id
assert prediction.status == "starting"

prediction = replicate.predictions.cancel(prediction.id)
assert prediction.id == id
assert prediction.status == "canceled"


@pytest.mark.vcr("predictions-cancel.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_predictions_cancel_instance_method(async_flag):
    """
    Cancel a prediction through its instance method (sync and async variants)
    and check that the same object reports the canceled status afterwards.
    """
    input = {
        "prompt": "a studio photo of a rainbow colored corgi",
        "width": 512,
        "height": 512,
        "seed": 42069,
    }

    if async_flag:
        model = await replicate.models.async_get("stability-ai/sdxl")
        version = await model.versions.async_get(
            "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5"
        )
        prediction = await replicate.predictions.async_create(
            version=version,
            input=input,
        )

        assert prediction.status == "starting"

        await prediction.async_cancel()
        assert prediction.status == "canceled"
    else:
        model = replicate.models.get("stability-ai/sdxl")
        version = model.versions.get(
            "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5"
        )
        prediction = replicate.predictions.create(
            version=version,
            input=input,
        )

        assert prediction.status == "starting"

        # Cancel exactly once; the instance is expected to update in place.
        prediction.cancel()
        assert prediction.status == "canceled"


@pytest.mark.vcr("predictions-stream.yaml")
Expand Down
176 changes: 141 additions & 35 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,77 +8,183 @@

@pytest.mark.vcr("trainings-create.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_create(async_flag, mock_replicate_api_token):
    """
    Create a training with explicit ``model`` and ``version`` arguments
    (sync and async variants) and check it comes back in the starting state.
    """
    if async_flag:
        training = await replicate.trainings.async_create(
            model="stability-ai/sdxl",
            version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
            input={
                "input_images": input_images_url,
                "use_face_detection_instead": True,
            },
            destination="replicate/dreambooth-sdxl",
        )
    else:
        training = replicate.trainings.create(
            model="stability-ai/sdxl",
            version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
            input={
                "input_images": input_images_url,
                "use_face_detection_instead": True,
            },
            destination="replicate/dreambooth-sdxl",
        )

    assert training.id is not None
    assert training.status == "starting"


@pytest.mark.vcr("trainings-create.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_create_with_named_version_argument(
    async_flag, mock_replicate_api_token
):
    """
    Create a training by passing a combined ``owner/name:version`` identifier
    as the ``version`` keyword, and check it comes back in the starting state.
    """
    if async_flag:
        # The overload with a model version identifier is soft-deprecated
        # and not supported in the async version.
        return

    training = replicate.trainings.create(
        version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
        input={
            "input_images": input_images_url,
            "use_face_detection_instead": True,
        },
        destination="replicate/dreambooth-sdxl",
    )

    assert training.id is not None
    assert training.status == "starting"


@pytest.mark.vcr("trainings-create.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_create_with_positional_argument(
    async_flag, mock_replicate_api_token
):
    """
    Create a training by passing the version identifier, input and
    destination positionally, and check it comes back in the starting state.
    """
    if async_flag:
        # The overload with positional arguments is soft-deprecated
        # and not supported in the async version.
        return

    training = replicate.trainings.create(
        "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
        {
            "input_images": input_images_url,
            "use_face_detection_instead": True,
        },
        "replicate/dreambooth-sdxl",
    )

    assert training.id is not None
    assert training.status == "starting"


@pytest.mark.vcr("trainings-create__invalid-destination.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_create_with_invalid_destination(
    async_flag, mock_replicate_api_token
):
    """
    Creating a training with a malformed destination must raise
    ``ReplicateException`` in both the sync and async variants.
    """
    # Both variants send the same arguments so the only difference under
    # test is the sync/async code path.
    create_kwargs = {
        "model": "stability-ai/sdxl",
        "version": "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
        "input": {
            "input_images": input_images_url,
            "use_face_detection_instead": True,
        },
        "destination": "<invalid>",
    }

    with pytest.raises(ReplicateException):
        if async_flag:
            await replicate.trainings.async_create(**create_kwargs)
        else:
            replicate.trainings.create(**create_kwargs)


@pytest.mark.vcr("trainings-get.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_get(async_flag, mock_replicate_api_token):
    """
    Fetch a training by id (sync and async variants) and check the returned
    object carries that id and the recorded status.
    """
    training_id = "medrnz3bm5dd6ultvad2tejrte"

    if async_flag:
        training = await replicate.trainings.async_get(training_id)
    else:
        training = replicate.trainings.get(training_id)

    assert training.id == training_id
    assert training.status == "processing"


@pytest.mark.vcr("trainings-cancel.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_cancel(async_flag, mock_replicate_api_token):
    """
    Create a training, cancel it through the ``trainings`` namespace
    (sync and async variants), and check the returned status.
    """
    input = {
        "input_images": input_images_url,
        "use_face_detection_instead": True,
    }

    destination = "replicate/dreambooth-sdxl"

    if async_flag:
        training = await replicate.trainings.async_create(
            model="stability-ai/sdxl",
            version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
            input=input,
            destination=destination,
        )

        assert training.status == "starting"

        # Use the async namespace method so the async path is the one
        # actually exercised by this branch.
        training = await replicate.trainings.async_cancel(training.id)
        assert training.status == "canceled"
    else:
        training = replicate.trainings.create(
            version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
            destination=destination,
            input=input,
        )

        assert training.status == "starting"

        training = replicate.trainings.cancel(training.id)
        assert training.status == "canceled"


@pytest.mark.vcr("trainings-cancel.yaml")
@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_token):
    """
    Create a training, cancel it through its instance method (sync and async
    variants), and check that the same object reports the canceled status.
    """
    input = {
        "input_images": input_images_url,
        "use_face_detection_instead": True,
    }

    destination = "replicate/dreambooth-sdxl"

    if async_flag:
        training = await replicate.trainings.async_create(
            model="stability-ai/sdxl",
            version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
            input=input,
            destination=destination,
        )

        assert training.status == "starting"

        await training.async_cancel()
        assert training.status == "canceled"
    else:
        training = replicate.trainings.create(
            version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5",
            destination=destination,
            input=input,
        )

        assert training.status == "starting"

        # Cancel exactly once; the instance is expected to update in place.
        training.cancel()
        assert training.status == "canceled"