From 2ff2ce8204ff089aefbb3bbbb48b3f97f463ead1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 14 May 2024 11:02:18 -0700 Subject: [PATCH 1/3] Add overloads to models.get method Signed-off-by: Mattt Zmuda --- replicate/model.py | 69 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/replicate/model.py b/replicate/model.py index 2349fe5e..a0e36e22 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload from typing_extensions import NotRequired, TypedDict, Unpack, deprecated @@ -207,31 +207,40 @@ async def async_list( return Page[Model](**obj) - def get(self, key: str) -> Model: + @overload + def get(self, key: str) -> Model: ... + + @overload + def get(self, owner: str, name: str) -> Model: ... + + def get(self, *args, **kwargs) -> Model: """ Get a model by name. - - Args: - key: The qualified name of the model, in the format `owner/model-name`. - Returns: - The model. """ - resp = self._client._request("GET", f"/v1/models/{key}") + url = _get_model_url(*args, **kwargs) + resp = self._client._request("GET", url) return _json_to_model(self._client, resp.json()) - async def async_get(self, key: str) -> Model: + @overload + async def async_get(self, key: str) -> Model: ... + + @overload + async def async_get(self, owner: str, name: str) -> Model: ... + + async def async_get(self, *args, **kwargs) -> Model: """ Get a model by name. Args: - key: The qualified name of the model, in the format `owner/model-name`. + key: The qualified name of the model, in the format `owner/name`. Returns: The model. """ - resp = await self._client._async_request("GET", f"/v1/models/{key}") + url = _get_model_url(*args, **kwargs) + resp = await self._client._async_request("GET", url) return _json_to_model(self._client, resp.json()) @@ -288,6 +297,13 @@ async def async_create( return _json_to_model(self._client, resp.json()) + def delete(self, key: str) -> None: + """ + Delete a model. + """ + + self._client._request("DELETE", f"/v1/models/{key}") + class ModelsPredictions(Namespace): """ @@ -374,6 +390,37 @@ def _create_model_body( # pylint: disable=too-many-arguments return body +def _get_model_url(*args, **kwargs) -> str: + if len(args) > 0 and len(kwargs) > 0: + raise ValueError("Cannot mix positional and keyword arguments") + + owner = kwargs.get("owner", None) + name = kwargs.get("name", None) + key = kwargs.get("key", None) + + if key and (owner or name): + raise ValueError( + "Must specify exactly one of 'owner' and 'name' or single 'key' in the format 'owner/name'" + ) + + if args: + if len(args) == 1: + key = args[0] + elif len(args) == 2: + owner, name = args + else: + raise ValueError("Invalid number of arguments") + + if not key: + if not (owner and name): + raise ValueError( + "Both 'owner' and 'name' must be provided if 'key' is not specified." + ) + key = f"{owner}/{name}" + + return f"/v1/models/{key}" + + def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: model = Model(**json) model._client = client From 8d39403a1b5c7032fc867938e3abe830c3fdfd90 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Tue, 14 May 2024 11:11:20 -0700 Subject: [PATCH 2/3] Add overloads to models.delete method Signed-off-by: Mattt Zmuda --- replicate/model.py | 48 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/replicate/model.py b/replicate/model.py index a0e36e22..f408140f 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -244,6 +244,43 @@ async def async_get(self, *args, **kwargs) -> Model: return _json_to_model(self._client, resp.json()) + @overload + def delete(self, key: str) -> Model: ... + + @overload + def delete(self, owner: str, name: str) -> Model: ... + + def delete(self, *args, **kwargs) -> Model: + """ + Delete a model by name. + """ + + url = _delete_model_url(*args, **kwargs) + resp = self._client._request("DELETE", url) + + return _json_to_model(self._client, resp.json()) + + @overload + async def async_delete(self, key: str) -> Model: ... + + @overload + async def async_delete(self, owner: str, name: str) -> Model: ... + + async def async_delete(self, *args, **kwargs) -> Model: + """ + Delete a model by name. + + Args: + key: The qualified name of the model, in the format `owner/name`. + Returns: + The model. + """ + + url = _delete_model_url(*args, **kwargs) + resp = await self._client._async_request("DELETE", url) + + return _json_to_model(self._client, resp.json()) + class CreateModelParams(TypedDict): """Parameters for creating a model.""" @@ -297,13 +334,6 @@ async def async_create( return _json_to_model(self._client, resp.json()) - def delete(self, key: str) -> None: - """ - Delete a model. - """ - - self._client._request("DELETE", f"/v1/models/{key}") - class ModelsPredictions(Namespace): """ @@ -421,6 +451,10 @@ def _get_model_url(*args, **kwargs) -> str: return f"/v1/models/{key}" +def _delete_model_url(*args, **kwargs) -> str: + return _get_model_url(*args, **kwargs) + + def _json_to_model(client: "Client", json: Dict[str, Any]) -> Model: model = Model(**json) model._client = client From cd1201f35f97c0579b27dfc771b26f4d909ba3e5 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 28 Jun 2024 05:51:22 -0700 Subject: [PATCH 3/3] Add missing word to README Signed-off-by: Mattt Zmuda --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ae288708..cea980ab 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,7 @@ background = Image.open("/tmp/out.png") ## List models -You can the models you've created: +You can list the models you've created: ```python replicate.models.list()