From 1564104d65ace2804837c6f816d1cb2414185e57 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 02:37:32 -0800 Subject: [PATCH 1/7] Add test for trainings.create overload with model and version Signed-off-by: Mattt Zmuda --- tests/test_training.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_training.py b/tests/test_training.py index 0c4a4782..ce565d76 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -9,6 +9,23 @@ @pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio async def test_trainings_create(mock_replicate_api_token): + training = replicate.trainings.create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) + + assert training.id is not None + assert training.status == "starting" + + +@pytest.mark.vcr("trainings-create.yaml") +@pytest.mark.asyncio +async def test_trainings_create_with_named_version_argument(mock_replicate_api_token): training = replicate.trainings.create( version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", input={ From f648f75e169368a5d978d3e09c1cdf0eb572c343 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 02:45:23 -0800 Subject: [PATCH 2/7] Add 'Fine-tune a model' section to README Signed-off-by: Mattt Zmuda --- README.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/README.md b/README.md index b8c3df54..d70be2a7 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,34 @@ Here's how to list of all the available hardware for running models on Replicate ['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large'] ``` +## Fine-tune a model + +Use the [training API](https://replicate.com/docs/fine-tuning) +to fine-tune language models to make them better at a particular task. +To see what **language models** currently support fine-tuning, +check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models). + +If you're looking to fine-tune **image models**, +check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model). + +Here's how to fine-tune a model on Replicate: + +```python +training = replicate.trainings.create( + model="stability-ai/sdxl", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", + input={ + "input_images": "https://my-domain/training-images.zip", + "token_string": "TOK", + "caption_prefix": "a photo of TOK", + "max_train_steps": 1000, + "use_face_detection_instead": False + }, + # You need to create a model on Replicate that will be the destination for the trained version. + destination="your-username/model-name" +) +``` + ## Development See [CONTRIBUTING.md](CONTRIBUTING.md) From f92099fe4480d371dfdd85bab3d5d0e67ddcf4a2 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 03:01:43 -0800 Subject: [PATCH 3/7] Update Training.cancel instance method to mutate caller This matches the behavior of reload Signed-off-by: Mattt Zmuda --- replicate/training.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/replicate/training.py b/replicate/training.py index 8cb74c64..030eecd6 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -84,7 +84,10 @@ class Training(Resource): def cancel(self) -> None: """Cancel a running training""" - self._client.trainings.cancel(self.id) + + canceled = self._client.trainings.cancel(self.id) + for name, value in canceled.dict().items(): + setattr(self, name, value) def reload(self) -> None: """ From deecda7793be0bf0b85ee435cddeda796b6710dd Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 03:02:36 -0800 Subject: [PATCH 4/7] Add test coverage for async variants of training methods Add test coverage for trainings.cancel Signed-off-by: Mattt Zmuda --- tests/test_training.py | 187 ++++++++++++++++++++++++++++++----------- 1 file changed, 138 insertions(+), 49 deletions(-) diff --git a/tests/test_training.py b/tests/test_training.py index ce565d76..1955ffe6 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -8,16 +8,28 @@ @pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio -async def test_trainings_create(mock_replicate_api_token): - training = replicate.trainings.create( - model="stability-ai/sdxl", - version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - input={ - "input_images": input_images_url, - "use_face_detection_instead": True, - }, - destination="replicate/dreambooth-sdxl", - ) +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create(async_flag, mock_replicate_api_token): + if async_flag: + training = await replicate.trainings.async_create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) + else: + training = replicate.trainings.create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) assert training.id is not None assert training.status == "starting" @@ -25,15 +37,23 @@ async def test_trainings_create(mock_replicate_api_token): @pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio -async def test_trainings_create_with_named_version_argument(mock_replicate_api_token): - training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - input={ - "input_images": input_images_url, - "use_face_detection_instead": True, - }, - destination="replicate/dreambooth-sdxl", - ) +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create_with_named_version_argument( + async_flag, mock_replicate_api_token +): + if async_flag: + # The overload with a model version identifier is soft-deprecated + # and not supported in the async version. + return + else: + training = replicate.trainings.create( + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) assert training.id is not None assert training.status == "starting" @@ -41,39 +61,66 @@ async def test_trainings_create_with_named_version_argument(mock_replicate_api_t @pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio -async def test_trainings_create_with_positional_argument(mock_replicate_api_token): - training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - { - "input_images": input_images_url, - "use_face_detection_instead": True, - }, - "replicate/dreambooth-sdxl", - ) +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create_with_positional_argument( + async_flag, mock_replicate_api_token +): + if async_flag: + # The overload with positional arguments is soft-deprecated + # and not supported in the async version. + return + else: + training = replicate.trainings.create( + "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + { + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + "replicate/dreambooth-sdxl", + ) - assert training.id is not None - assert training.status == "starting" + assert training.id is not None + assert training.status == "starting" @pytest.mark.vcr("trainings-create__invalid-destination.yaml") @pytest.mark.asyncio -async def test_trainings_create_with_invalid_destination(mock_replicate_api_token): +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create_with_invalid_destination( + async_flag, mock_replicate_api_token +): with pytest.raises(ReplicateException): - replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - input={ - "input_images": input_images_url, - }, - destination="", - ) + if async_flag: + await replicate.trainings.async_create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="", + ) + else: + replicate.trainings.create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + }, + destination="", + ) @pytest.mark.vcr("trainings-get.yaml") @pytest.mark.asyncio -async def test_trainings_get(mock_replicate_api_token): +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_get(async_flag, mock_replicate_api_token): id = "medrnz3bm5dd6ultvad2tejrte" - training = replicate.trainings.get(id) + if async_flag: + training = await replicate.trainings.async_get(id) + else: + training = replicate.trainings.get(id) assert training.id == id assert training.status == "processing" @@ -81,7 +128,8 @@ async def test_trainings_get(mock_replicate_api_token): @pytest.mark.vcr("trainings-cancel.yaml") @pytest.mark.asyncio -async def test_trainings_cancel(mock_replicate_api_token): +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_cancel(async_flag, mock_replicate_api_token): input = { "input_images": input_images_url, "use_face_detection_instead": True, @@ -89,13 +137,54 @@ async def test_trainings_cancel(mock_replicate_api_token): destination = "replicate/dreambooth-sdxl" - training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - destination=destination, - input=input, - ) + if async_flag: + training = await replicate.trainings.async_create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input=input, + destination=destination, + ) + + assert training.status == "starting" - assert training.status == "starting" + training = replicate.trainings.cancel(training.id) + assert training.status == "canceled" + else: + training = replicate.trainings.create( + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + destination=destination, + input=input, + ) + + assert training.status == "starting" + + training = replicate.trainings.cancel(training.id) + assert training.status == "canceled" + + +@pytest.mark.vcr("trainings-cancel.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_token): + input = { + "input_images": input_images_url, + "use_face_detection_instead": True, + } + + destination = "replicate/dreambooth-sdxl" + + if async_flag: + # The cancel instance method is soft-deprecated, + # and not supported in the async version. + return + else: + training = replicate.trainings.create( + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + destination=destination, + input=input, + ) + + assert training.status == "starting" - # training = replicate.trainings.cancel(training) - training.cancel() + training.cancel() + assert training.status == "canceled" From 61e5f812f5d549498a7e3a6801d5012273651426 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 03:08:12 -0800 Subject: [PATCH 5/7] Add async_cancel and async_reload methods Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 9 +++++++++ replicate/training.py | 22 +++++++++++++++++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index 4319856e..014a77ce 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -177,6 +177,15 @@ def cancel(self) -> None: for name, value in canceled.dict().items(): setattr(self, name, value) + async def async_cancel(self) -> None: + """ + Cancels a running prediction asynchronously. + """ + + canceled = await self._client.predictions.async_cancel(self.id) + for name, value in canceled.dict().items(): + setattr(self, name, value) + def reload(self) -> None: """ Load this prediction from the server. diff --git a/replicate/training.py b/replicate/training.py index 030eecd6..091e1990 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -83,12 +83,23 @@ class Training(Resource): """ def cancel(self) -> None: - """Cancel a running training""" + """ + Cancel a running training. + """ canceled = self._client.trainings.cancel(self.id) for name, value in canceled.dict().items(): setattr(self, name, value) + async def async_cancel(self) -> None: + """ + Cancel a running training asynchronously. + """ + + canceled = await self._client.trainings.async_cancel(self.id) + for name, value in canceled.dict().items(): + setattr(self, name, value) + def reload(self) -> None: """ Load the training from the server. @@ -98,6 +109,15 @@ def reload(self) -> None: for name, value in updated.dict().items(): setattr(self, name, value) + async def async_reload(self) -> None: + """ + Load the training from the server asynchronously. + """ + + updated = await self._client.trainings.async_get(self.id) + for name, value in updated.dict().items(): + setattr(self, name, value) + class Trainings(Namespace): """ From 813177a37ffdb805e08cc6777c4408c9253b37d3 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sun, 4 Feb 2024 03:11:20 -0800 Subject: [PATCH 6/7] Add test coverage for predictions.cancel and predictions.async_cancel Signed-off-by: Mattt Zmuda --- tests/test_prediction.py | 57 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index b50259cb..c64a5989 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -101,6 +101,13 @@ async def test_predictions_cancel(async_flag): version=version, input=input, ) + + id = prediction.id + assert prediction.status == "starting" + + prediction = await replicate.predictions.async_cancel(prediction.id) + assert prediction.id == id + assert prediction.status == "canceled" else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( @@ -111,11 +118,53 @@ async def test_predictions_cancel(async_flag): input=input, ) - # id = prediction.id - assert prediction.status == "starting" + id = prediction.id + assert prediction.status == "starting" + + prediction = replicate.predictions.cancel(prediction.id) + assert prediction.id == id + assert prediction.status == "canceled" + + +@pytest.mark.vcr("predictions-cancel.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_cancel_instance_method(async_flag): + input = { + "prompt": "a studio photo of a rainbow colored corgi", + "width": 512, + "height": 512, + "seed": 42069, + } + + if async_flag: + model = await replicate.models.async_get("stability-ai/sdxl") + version = await model.versions.async_get( + "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + ) + prediction = await replicate.predictions.async_create( + version=version, + input=input, + ) + + assert prediction.status == "starting" + + await prediction.async_cancel() + assert prediction.status == "canceled" + else: + model = replicate.models.get("stability-ai/sdxl") + version = model.versions.get( + "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + ) + prediction = replicate.predictions.create( + version=version, + input=input, + ) + + assert prediction.status == "starting" - # prediction = replicate.predictions.cancel(prediction) - prediction.cancel() + prediction.cancel() + assert prediction.status == "canceled" @pytest.mark.vcr("predictions-stream.yaml") From 1092a8ec55ad1b1aba8ad4e1161f4d26ecbb6c23 Mon Sep 17 00:00:00 2001 From: Mattt Date: Mon, 5 Feb 2024 08:59:00 -0800 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: Zeke Sikelianos --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d70be2a7..695b95d2 100644 --- a/README.md +++ b/README.md @@ -308,7 +308,7 @@ Here's how to list of all the available hardware for running models on Replicate ## Fine-tune a model Use the [training API](https://replicate.com/docs/fine-tuning) -to fine-tune language models to make them better at a particular task. +to fine-tune models to make them better at a particular task. To see what **language models** currently support fine-tuning, check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models).