diff --git a/replicate/training.py b/replicate/training.py index b60c9613..06391d2d 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -30,7 +30,13 @@ class TrainingCollection(Collection): model = Training def list(self) -> List[Training]: - raise NotImplementedError() + resp = self._client._request("GET", "/v1/trainings") + # TODO: paginate + trainings = resp.json()["results"] + for training in trainings: + # HACK: resolve this? make it lazy somehow? + del training["version"] + return [self.prepare_model(obj) for obj in trainings] def get(self, id: str) -> Training: resp = self._client._request( diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 00000000..b74938db --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,175 @@ +import responses +from responses import matchers + +from .factories import create_client, create_version + + +@responses.activate +def test_create_works_with_webhooks(): + client = create_client() + version = create_version(client) + + rsp = responses.post( + "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", + match=[ + matchers.json_params_matcher( + { + "input": {"data": "..."}, + "destination": "new_owner/new_model", + "webhook": "https://example.com/webhook", + "webhook_events_filter": ["completed"], + } + ), + ], + json={ + "id": "t1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/trainings/t1", + "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"data": "..."}, + "output": None, + "error": None, + "logs": "", + }, + ) + + client.trainings.create( + version=f"owner/model:{version.id}", + input={"data": "..."}, + destination="new_owner/new_model", + webhook="https://example.com/webhook", + webhook_events_filter=["completed"], + ) + + assert rsp.call_count == 1 + + +@responses.activate +def test_cancel(): + client = create_client() + version = create_version(client) + + responses.post( + "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", + match=[ + matchers.json_params_matcher( + { + "input": {"data": "..."}, + "destination": "new_owner/new_model", + "webhook": "https://example.com/webhook", + "webhook_events_filter": ["completed"], + } + ), + ], + json={ + "id": "t1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/trainings/t1", + "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"data": "..."}, + "output": None, + "error": None, + "logs": "", + }, + ) + + training = client.trainings.create( + version=f"owner/model:{version.id}", + input={"data": "..."}, + destination="new_owner/new_model", + webhook="https://example.com/webhook", + webhook_events_filter=["completed"], + ) + + rsp = responses.post("https://api.replicate.com/v1/trainings/t1/cancel", json={}) + training.cancel() + assert rsp.call_count == 1 + + +@responses.activate +def test_async_timings(): + client = create_client() + version = create_version(client) + + responses.post( + "https://api.replicate.com/v1/models/owner/model/versions/v1/trainings", + match=[ + matchers.json_params_matcher( + { + "input": {"data": "..."}, + "destination": "new_owner/new_model", + "webhook": "https://example.com/webhook", + "webhook_events_filter": ["completed"], + } + ), + ], + json={ + "id": "t1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/trainings/t1", + "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "source": "api", + "status": "processing", + "input": {"data": "..."}, + "output": None, + "error": None, + "logs": "", + }, + ) + + responses.get( + "https://api.replicate.com/v1/trainings/t1", + json={ + "id": "t1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/trainings/t1", + "cancel": "https://api.replicate.com/v1/trainings/t1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "succeeded", + "input": {"data": "..."}, + "output": { + "weights": "https://delivery.replicate.com/weights.tgz", + "version": "v2", + }, + "error": None, + "logs": "", + }, + ) + + training = client.trainings.create( + version=f"owner/model:{version.id}", + input={"data": "..."}, + destination="new_owner/new_model", + webhook="https://example.com/webhook", + webhook_events_filter=["completed"], + ) + + assert training.created_at == "2022-04-26T20:00:40.658234Z" + assert training.completed_at is None + assert training.output is None + + # trainings don't have a wait method, so simulate it by calling reload + training.reload() + assert training.created_at == "2022-04-26T20:00:40.658234Z" + assert training.completed_at == "2022-04-26T20:02:27.648305Z" + assert training.output["weights"] == "https://delivery.replicate.com/weights.tgz" + assert training.output["version"] == "v2"