Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,46 @@ async def async_list(self) -> Page[Version]:

return Page[Version](**obj)

def delete(self, id: str) -> bool:
"""
Delete a model version and all associated predictions, including all output files.

Model version deletion has some restrictions:

* You can only delete versions from models you own.
* You can only delete versions from private models.
* You cannot delete a version if someone other than you
has run predictions with it.

Args:
id: The version ID.
"""

resp = self._client._request(
"DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}"
)
return resp.status_code == 204

async def async_delete(self, id: str) -> bool:
"""
Delete a model version and all associated predictions, including all output files.

Model version deletion has some restrictions:

* You can only delete versions from models you own.
* You can only delete versions from private models.
* You cannot delete a version if someone other than you
has run predictions with it.

Args:
id: The version ID.
"""

resp = await self._client._async_request(
"DELETE", f"/v1/models/{self.model[0]}/{self.model[1]}/versions/{id}"
)
return resp.status_code == 204


def _json_to_version(json: Dict[str, Any]) -> Version:
return Version(**json)
87 changes: 87 additions & 0 deletions tests/test_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import httpx
import pytest
import respx

from replicate.client import Client

router = respx.Router(base_url="https://api.replicate.com/v1")

router.route(
method="GET",
path="/models/replicate/hello-world",
name="models.get",
).mock(
return_value=httpx.Response(
200,
json={
"owner": "replicate",
"name": "hello-world",
"description": "A tiny model that says hello",
"visibility": "public",
"run_count": 1e10,
"url": "https://replicate.com/replicate/hello-world",
"created_at": "2022-04-26T19:13:45.911328Z",
"latest_version": {
"id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
"cog_version": "0.3.0",
"openapi_schema": {
"openapi": "3.0.2",
"info": {"title": "Cog", "version": "0.1.0"},
"components": {
"schemas": {
"Input": {
"type": "object",
"title": "Input",
"required": ["text"],
"properties": {
"text": {
"type": "string",
"title": "Text",
"x-order": 0,
"description": "Text to prefix with 'hello '",
}
},
},
"Output": {"type": "string", "title": "Output"},
}
},
},
"created_at": "2022-04-26T19:29:04.418669Z",
},
},
)
)

router.route(
method="DELETE",
path__regex=r"^/models/replicate/hello-world/versions/(?P<id>\w+)/?",
name="models.versions.delete",
).mock(
return_value=httpx.Response(
202,
)
)


@pytest.mark.asyncio
@pytest.mark.parametrize("async_flag", [True, False])
async def test_version_delete(async_flag):
client = Client(
api_token="test-token", transport=httpx.MockTransport(router.handler)
)

if async_flag:
model = await client.models.async_get("replicate/hello-world")
assert model is not None
assert model.latest_version is not None

await model.versions.async_delete(model.latest_version.id)
else:
model = client.models.get("replicate/hello-world")
assert model is not None
assert model.latest_version is not None

model.versions.delete(model.latest_version.id)

assert router["models.get"].called
assert router["models.versions.delete"].called