From 38573c247005f3edc47e87fc674234b0c55dcfca Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Sat, 27 Jan 2024 12:50:48 -0800 Subject: [PATCH] Implement models.versions.delete endpoint Signed-off-by: Mattt Zmuda --- replicate/version.py | 40 ++++++++++++++++++++ tests/test_version.py | 87 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 tests/test_version.py diff --git a/replicate/version.py b/replicate/version.py index f42d711b..5a67ef8b 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -113,6 +113,46 @@ async def async_list(self) -> Page[Version]: return Page[Version](**obj) + def delete(self, id: str) -> bool: + """ + Delete a model version and all associated predictions, including all output files. + + Model version deletion has some restrictions: + + * You can only delete versions from models you own. + * You can only delete versions from private models. + * You cannot delete a version if someone other than you + has run predictions with it. + + Args: + id: The version ID. + """ + + resp = self._client._request( + "DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" + ) + return resp.status_code == 204 + + async def async_delete(self, id: str) -> bool: + """ + Delete a model version and all associated predictions, including all output files. + + Model version deletion has some restrictions: + + * You can only delete versions from models you own. + * You can only delete versions from private models. + * You cannot delete a version if someone other than you + has run predictions with it. + + Args: + id: The version ID. + """ + + resp = await self._client._async_request( + "DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}" + ) + return resp.status_code == 204 + def _json_to_version(json: Dict[str, Any]) -> Version: return Version(**json) diff --git a/tests/test_version.py b/tests/test_version.py new file mode 100644 index 00000000..50af02dd --- /dev/null +++ b/tests/test_version.py @@ -0,0 +1,87 @@ +import httpx +import pytest +import respx + +from replicate.client import Client + +router = respx.Router(base_url="https://api.replicate.com/v1") + +router.route( + method="GET", + path="/models/replicate/hello-world", + name="models.get", +).mock( + return_value=httpx.Response( + 200, + json={ + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "run_count": 1e10, + "url": "https://replicate.com/replicate/hello-world", + "created_at": "2022-04-26T19:13:45.911328Z", + "latest_version": { + "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "cog_version": "0.3.0", + "openapi_schema": { + "openapi": "3.0.2", + "info": {"title": "Cog", "version": "0.1.0"}, + "components": { + "schemas": { + "Input": { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '", + } + }, + }, + "Output": {"type": "string", "title": "Output"}, + } + }, + }, + "created_at": "2022-04-26T19:29:04.418669Z", + }, + }, + ) +) + +router.route( + method="DELETE", + path__regex=r"^/models/replicate/hello-world/versions/(?P\w+)/?", + name="models.versions.delete", +).mock( + return_value=httpx.Response( + 202, + ) +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("async_flag", [True, False]) +async def test_version_delete(async_flag): + client = Client( + api_token="test-token", transport=httpx.MockTransport(router.handler) + ) + + if async_flag: + model = await client.models.async_get("replicate/hello-world") + assert model is not None + assert model.latest_version is not None + + await model.versions.async_delete(model.latest_version.id) + else: + model = client.models.get("replicate/hello-world") + assert model is not None + assert model.latest_version is not None + + model.versions.delete(model.latest_version.id) + + assert router["models.get"].called + assert router["models.versions.delete"].called