diff --git a/README.md b/README.md index b8c3df54..695b95d2 100644 --- a/README.md +++ b/README.md @@ -305,6 +305,34 @@ Here's how to list of all the available hardware for running models on Replicate ['cpu', 'gpu-t4', 'gpu-a40-small', 'gpu-a40-large'] ``` +## Fine-tune a model + +Use the [training API](https://replicate.com/docs/fine-tuning) +to fine-tune models to make them better at a particular task. +To see what **language models** currently support fine-tuning, +check out Replicate's [collection of trainable language models](https://replicate.com/collections/trainable-language-models). + +If you're looking to fine-tune **image models**, +check out Replicate's [guide to fine-tuning image models](https://replicate.com/docs/guides/fine-tune-an-image-model). + +Here's how to fine-tune a model on Replicate: + +```python +training = replicate.trainings.create( + model="stability-ai/sdxl", + version="39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", + input={ + "input_images": "https://my-domain/training-images.zip", + "token_string": "TOK", + "caption_prefix": "a photo of TOK", + "max_train_steps": 1000, + "use_face_detection_instead": False + }, + # You need to create a model on Replicate that will be the destination for the trained version. + destination="your-username/model-name" +) +``` + ## Development See [CONTRIBUTING.md](CONTRIBUTING.md) diff --git a/replicate/prediction.py b/replicate/prediction.py index 4319856e..014a77ce 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -177,6 +177,15 @@ def cancel(self) -> None: for name, value in canceled.dict().items(): setattr(self, name, value) + async def async_cancel(self) -> None: + """ + Cancels a running prediction asynchronously. + """ + + canceled = await self._client.predictions.async_cancel(self.id) + for name, value in canceled.dict().items(): + setattr(self, name, value) + def reload(self) -> None: """ Load this prediction from the server. 
diff --git a/replicate/training.py b/replicate/training.py index 8cb74c64..091e1990 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -83,8 +83,22 @@ class Training(Resource): """ def cancel(self) -> None: - """Cancel a running training""" - self._client.trainings.cancel(self.id) + """ + Cancel a running training. + """ + + canceled = self._client.trainings.cancel(self.id) + for name, value in canceled.dict().items(): + setattr(self, name, value) + + async def async_cancel(self) -> None: + """ + Cancel a running training asynchronously. + """ + + canceled = await self._client.trainings.async_cancel(self.id) + for name, value in canceled.dict().items(): + setattr(self, name, value) def reload(self) -> None: """ @@ -95,6 +109,15 @@ def reload(self) -> None: for name, value in updated.dict().items(): setattr(self, name, value) + async def async_reload(self) -> None: + """ + Load the training from the server asynchronously. + """ + + updated = await self._client.trainings.async_get(self.id) + for name, value in updated.dict().items(): + setattr(self, name, value) + class Trainings(Namespace): """ diff --git a/tests/test_prediction.py b/tests/test_prediction.py index b50259cb..c64a5989 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -101,6 +101,13 @@ async def test_predictions_cancel(async_flag): version=version, input=input, ) + + id = prediction.id + assert prediction.status == "starting" + + prediction = await replicate.predictions.async_cancel(prediction.id) + assert prediction.id == id + assert prediction.status == "canceled" else: model = replicate.models.get("stability-ai/sdxl") version = model.versions.get( @@ -111,11 +118,53 @@ async def test_predictions_cancel(async_flag): input=input, ) - # id = prediction.id - assert prediction.status == "starting" + id = prediction.id + assert prediction.status == "starting" + + prediction = replicate.predictions.cancel(prediction.id) + assert prediction.id == id + assert 
prediction.status == "canceled" + + +@pytest.mark.vcr("predictions-cancel.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_predictions_cancel_instance_method(async_flag): + input = { + "prompt": "a studio photo of a rainbow colored corgi", + "width": 512, + "height": 512, + "seed": 42069, + } + + if async_flag: + model = await replicate.models.async_get("stability-ai/sdxl") + version = await model.versions.async_get( + "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + ) + prediction = await replicate.predictions.async_create( + version=version, + input=input, + ) + + assert prediction.status == "starting" + + await prediction.async_cancel() + assert prediction.status == "canceled" + else: + model = replicate.models.get("stability-ai/sdxl") + version = model.versions.get( + "a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5" + ) + prediction = replicate.predictions.create( + version=version, + input=input, + ) + + assert prediction.status == "starting" - # prediction = replicate.predictions.cancel(prediction) - prediction.cancel() + prediction.cancel() + assert prediction.status == "canceled" @pytest.mark.vcr("predictions-stream.yaml") diff --git a/tests/test_training.py b/tests/test_training.py index 0c4a4782..1955ffe6 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -8,15 +8,28 @@ @pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio -async def test_trainings_create(mock_replicate_api_token): - training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - input={ - "input_images": input_images_url, - "use_face_detection_instead": True, - }, - destination="replicate/dreambooth-sdxl", - ) +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create(async_flag, mock_replicate_api_token): + if async_flag: + training = await replicate.trainings.async_create( 
+ model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) + else: + training = replicate.trainings.create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) assert training.id is not None assert training.status == "starting" @@ -24,39 +37,90 @@ async def test_trainings_create(mock_replicate_api_token): @pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio -async def test_trainings_create_with_positional_argument(mock_replicate_api_token): - training = replicate.trainings.create( - "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - { - "input_images": input_images_url, - "use_face_detection_instead": True, - }, - "replicate/dreambooth-sdxl", - ) +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create_with_named_version_argument( + async_flag, mock_replicate_api_token +): + if async_flag: + # The overload with a model version identifier is soft-deprecated + # and not supported in the async version. 
+ return + else: + training = replicate.trainings.create( + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="replicate/dreambooth-sdxl", + ) assert training.id is not None assert training.status == "starting" -@pytest.mark.vcr("trainings-create__invalid-destination.yaml") +@pytest.mark.vcr("trainings-create.yaml") @pytest.mark.asyncio -async def test_trainings_create_with_invalid_destination(mock_replicate_api_token): - with pytest.raises(ReplicateException): - replicate.trainings.create( +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create_with_positional_argument( + async_flag, mock_replicate_api_token +): + if async_flag: + # The overload with positional arguments is soft-deprecated + # and not supported in the async version. + return + else: + training = replicate.trainings.create( "stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - input={ + { "input_images": input_images_url, + "use_face_detection_instead": True, }, - destination="", + "replicate/dreambooth-sdxl", ) + assert training.id is not None + assert training.status == "starting" + + +@pytest.mark.vcr("trainings-create__invalid-destination.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_create_with_invalid_destination( + async_flag, mock_replicate_api_token +): + with pytest.raises(ReplicateException): + if async_flag: + await replicate.trainings.async_create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ + "input_images": input_images_url, + "use_face_detection_instead": True, + }, + destination="", + ) + else: + replicate.trainings.create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input={ 
+ "input_images": input_images_url, + }, + destination="", + ) + @pytest.mark.vcr("trainings-get.yaml") @pytest.mark.asyncio -async def test_trainings_get(mock_replicate_api_token): +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_get(async_flag, mock_replicate_api_token): id = "medrnz3bm5dd6ultvad2tejrte" - training = replicate.trainings.get(id) + if async_flag: + training = await replicate.trainings.async_get(id) + else: + training = replicate.trainings.get(id) assert training.id == id assert training.status == "processing" @@ -64,7 +128,8 @@ async def test_trainings_get(mock_replicate_api_token): @pytest.mark.vcr("trainings-cancel.yaml") @pytest.mark.asyncio -async def test_trainings_cancel(mock_replicate_api_token): +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_cancel(async_flag, mock_replicate_api_token): input = { "input_images": input_images_url, "use_face_detection_instead": True, @@ -72,13 +137,54 @@ async def test_trainings_cancel(mock_replicate_api_token): destination = "replicate/dreambooth-sdxl" - training = replicate.trainings.create( - version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", - destination=destination, - input=input, - ) + if async_flag: + training = await replicate.trainings.async_create( + model="stability-ai/sdxl", + version="a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + input=input, + destination=destination, + ) - assert training.status == "starting" + assert training.status == "starting" + + training = await replicate.trainings.async_cancel(training.id) + assert training.status == "canceled" + else: + training = replicate.trainings.create( + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + destination=destination, + input=input, + ) + + assert training.status == "starting" + + training = replicate.trainings.cancel(training.id) + assert training.status == "canceled" + + 
+@pytest.mark.vcr("trainings-cancel.yaml") +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_trainings_cancel_instance_method(async_flag, mock_replicate_api_token): + input = { + "input_images": input_images_url, + "use_face_detection_instead": True, + } + + destination = "replicate/dreambooth-sdxl" + + if async_flag: + # The cancel instance method is soft-deprecated, so the async case is skipped here. + # TODO(review): `Training.async_cancel` exists but is not exercised by this test - confirm or add coverage. + return + else: + training = replicate.trainings.create( + version="stability-ai/sdxl:a00d0b7dcbb9c3fbb34ba87d2d5b46c56969c84a628bf778a7fdaec30b1b99c5", + destination=destination, + input=input, + ) + + assert training.status == "starting" - # training = replicate.trainings.cancel(training) - training.cancel() + training.cancel() + assert training.status == "canceled"