From c3ed7f788f4bb8e257d503c75ecd307f5ab60b38 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:24:35 -0700 Subject: [PATCH 1/7] Add docstrings to methods Signed-off-by: Mattt Zmuda --- replicate/client.py | 8 +++++++- replicate/model.py | 13 +++++++++++++ replicate/prediction.py | 40 ++++++++++++++++++++++++++++++++++++++-- replicate/training.py | 29 +++++++++++++++++++++++++++++ replicate/version.py | 19 ++++++++++++++++++- 5 files changed, 105 insertions(+), 4 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index a08dacf0..1a78eb78 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -115,7 +115,13 @@ def trainings(self) -> TrainingCollection: def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: """ - Run a model in the format owner/name:version. + Runs a model and waits for its output. + + Args: + model_version: The model version to run, in the format `owner/name:version` + kwargs: The input to the model, as a dictionary + Returns: + The output of the model """ # Split model_version into owner, name, version in format owner/name:version m = re.match(r"^(?P[^/]+/[^:]+):(?P.+)$", model_version) diff --git a/replicate/model.py b/replicate/model.py index d6b32fcd..5cf22c5b 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -17,6 +17,10 @@ def predict(self, *args, **kwargs) -> None: @property def versions(self) -> VersionCollection: + """ + Get the versions of this model. + """ + return VersionCollection(client=self._client, model=self) @@ -27,6 +31,15 @@ def list(self) -> List[Model]: raise NotImplementedError() def get(self, name: str) -> Model: + """ + Get a model by name. + + Args: + name: The name of the model, in the format `owner/model-name`. + Returns: + The model. + """ + # TODO: fetch model from server # TODO: support permanent IDs username, name = name.split("/") diff --git a/replicate/prediction.py b/replicate/prediction.py index db197ae3..8c3c006e 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -23,7 +23,9 @@ class Prediction(BaseModel): urls: Optional[Dict[str, str]] def wait(self) -> None: - """Wait for prediction to finish.""" + """ + Waits for prediction to finish. + """ while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) self.reload() @@ -48,7 +50,9 @@ def output_iterator(self) -> Iterator[Any]: yield output def cancel(self) -> None: - """Cancel a currently running prediction""" + """ + Cancels a running prediction. + """ self._client._request("POST", f"/v1/predictions/{self.id}/cancel") @@ -56,6 +60,13 @@ class PredictionCollection(Collection): model = Prediction def list(self) -> List[Prediction]: + """ + List your predictions. + + Returns: + List[Prediction]: A list of prediction objects. + """ + resp = self._client._request("GET", "/v1/predictions") # TODO: paginate predictions = resp.json()["results"] @@ -65,6 +76,15 @@ def list(self) -> List[Prediction]: return [self.prepare_model(obj) for obj in predictions] def get(self, id: str) -> Prediction: + """ + Get a prediction by ID. + + Args: + id (str): The ID of the prediction. + Returns: + Prediction: The prediction object. + """ + resp = self._client._request("GET", f"/v1/predictions/{id}") obj = resp.json() # HACK: resolve this? make it lazy somehow? @@ -80,6 +100,22 @@ def create( # type: ignore webhook_events_filter: Optional[List[str]] = None, **kwargs, ) -> Prediction: + """ + Create a new prediction for the specified model version. + + Args: + version (Version): The model version to use for the prediction. + input (Dict[str, Any]): The input data for the prediction. + webhook (Optional[str]): The URL to receive a POST request with prediction updates. + webhook_completed (Optional[str]): The URL to receive a POST request when the prediction is completed. + webhook_events_filter (Optional[List[str]]): List of events to trigger webhooks. + stream (Optional[bool]): Set to True to enable streaming of prediction output. + + Returns: + Prediction: The created prediction object. + + """ + input = encode_json(input, upload_file=upload_file) body = { "version": version.id, diff --git a/replicate/training.py b/replicate/training.py index d7e97bc3..8b8ddf63 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -31,6 +31,13 @@ class TrainingCollection(Collection): model = Training def list(self) -> List[Training]: + """ + List your trainings. + + Returns: + List[Training]: A list of training objects. + """ + resp = self._client._request("GET", "/v1/trainings") # TODO: paginate trainings = resp.json()["results"] @@ -40,6 +47,15 @@ def list(self) -> List[Training]: return [self.prepare_model(obj) for obj in trainings] def get(self, id: str) -> Training: + """ + Get a training by ID. + + Args: + id (str): The ID of the training. + Returns: + Training: The training object. + """ + resp = self._client._request( "GET", f"/v1/trainings/{id}", @@ -58,6 +74,19 @@ def create( # type: ignore webhook_events_filter: Optional[List[str]] = None, **kwargs, ) -> Training: + """ + Create a new training using the specified model version as a base. + + Args: + version (str): The ID of the base model version that you're using to train a new model version. + input (Dict[str, Any]): The input to the training. + destination (str): The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. + webhook (Optional[str], optional): The URL to send a POST request to when the training is completed. Defaults to None. + webhook_events_filter (Optional[List[str]], optional): The events to send to the webhook. Defaults to None. + Returns: + Training: The training object. + """ + input = encode_json(input, upload_file=upload_file) body = { "input": input, diff --git a/replicate/version.py b/replicate/version.py index d4ed9108..f1646774 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -20,6 +20,15 @@ class Version(BaseModel): openapi_schema: dict def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: + """ + Create a prediction using this model version. + + Args: + kwargs: The input to the model. + Returns: + The output of the model. + """ + warnings.warn( "version.predict() is deprecated. Use replicate.run() instead. It will be removed before version 1.0.", DeprecationWarning, @@ -57,7 +66,12 @@ def __init__(self, client: "Client", model: "Model") -> None: # doesn't exist yet def get(self, id: str) -> Version: """ - Get a specific version. + Get a specific model version. + + Args: + id: The version ID. + Returns: + The model version. """ resp = self._client._request( "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}" @@ -70,6 +84,9 @@ def create(self, **kwargs) -> Version: def list(self) -> List[Version]: """ Return a list of all versions for a model. + + Returns: + List[Version]: A list of version objects. """ resp = self._client._request( "GET", f"/v1/models/{self._model.username}/{self._model.name}/versions" From 68f289ff407da0b1084177e38489cadaa1c2c276 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:42:39 -0700 Subject: [PATCH 2/7] Add docstrings to classes and fields Signed-off-by: Mattt Zmuda --- replicate/model.py | 11 ++++++++++ replicate/prediction.py | 41 ++++++++++++++++++++++++++++++----- replicate/training.py | 48 +++++++++++++++++++++++++++++++++++------ replicate/version.py | 11 ++++++++++ 4 files changed, 99 insertions(+), 12 deletions(-) diff --git a/replicate/model.py b/replicate/model.py index 5cf22c5b..2577cd1c 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -7,8 +7,19 @@ class Model(BaseModel): + """ + A machine learning model hosted on Replicate. + """ + username: str + """ + The name of the user or organization that owns the model. + """ + name: str + """ + The name of the model. + """ def predict(self, *args, **kwargs) -> None: raise ReplicateException( diff --git a/replicate/prediction.py b/replicate/prediction.py index 8c3c006e..2f88e8ef 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -10,17 +10,48 @@ class Prediction(BaseModel): + """ + A prediction made by a model hosted on Replicate. + """ + id: str - error: Optional[str] + """The unique ID of the prediction.""" + + version: Optional[Version] + """The version of the model used to create the prediction.""" + + status: str + """The status of the prediction.""" + input: Optional[Dict[str, Any]] - logs: Optional[str] + """The input to the prediction.""" + output: Optional[Any] - status: str - version: Optional[Version] - started_at: Optional[str] + """The output of the prediction.""" + + logs: Optional[str] + """The logs of the prediction.""" + + error: Optional[str] + """The error encountered during the prediction, if any.""" + created_at: Optional[str] + """When the prediction was created.""" + + started_at: Optional[str] + """When the prediction was started.""" + completed_at: Optional[str] + """When the prediction was completed, if finished.""" + urls: Optional[Dict[str, str]] + """ + URLs associated with the prediction. + + The following keys are available: + - `get`: A URL to fetch the prediction. + - `cancel`: A URL to cancel the prediction. + """ def wait(self) -> None: """ diff --git a/replicate/training.py b/replicate/training.py index 8b8ddf63..bbbe4a2f 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -10,17 +10,51 @@ class Training(BaseModel): - completed_at: Optional[str] - created_at: Optional[str] - destination: Optional[str] - error: Optional[str] + """ + A training made for a model hosted on Replicate. + """ + id: str + """The unique ID of the training.""" + + version: Optional[Version] + """The version of the model used to create the training.""" + + destination: Optional[str] + """The model destination of the training.""" + + status: str + """The status of the training.""" + input: Optional[Dict[str, Any]] - logs: Optional[str] + """The input to the training.""" + output: Optional[Any] + """The output of the training.""" + + logs: Optional[str] + """The logs of the training.""" + + error: Optional[str] + """The error encountered during the training, if any.""" + + created_at: Optional[str] + """When the training was created.""" + started_at: Optional[str] - status: str - version: Optional[Version] + """When the training was started.""" + + completed_at: Optional[str] + """When the training was completed, if finished.""" + + urls: Optional[Dict[str, str]] + """ + URLs associated with the training. + + The following keys are available: + - `get`: A URL to fetch the training. + - `cancel`: A URL to cancel the training. + """ def cancel(self) -> None: """Cancel a running training""" diff --git a/replicate/version.py b/replicate/version.py index f1646774..71cee3ad 100644 --- a/replicate/version.py +++ b/replicate/version.py @@ -14,10 +14,21 @@ class Version(BaseModel): + """ + A version of a model. + """ + id: str + """The unique ID of the version.""" + created_at: datetime.datetime + """When the version was created.""" + cog_version: str + """The version of the Cog used to create the version.""" + openapi_schema: dict + """An OpenAPI description of the model inputs and outputs.""" def predict(self, **kwargs) -> Union[Any, Iterator[Any]]: """ From 07cebc395323173b99902e791a8224bf4c34a0b1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:43:22 -0700 Subject: [PATCH 3/7] Add docstring with deprecation warning Signed-off-by: Mattt Zmuda --- replicate/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/replicate/model.py b/replicate/model.py index 2577cd1c..4787e337 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -22,6 +22,10 @@ class Model(BaseModel): """ def predict(self, *args, **kwargs) -> None: + """ + DEPRECATED: Use `version.predict()` instead. + """ + raise ReplicateException( "The `model.predict()` method has been removed, because it's unstable: if a new version of the model you're using is pushed and its API has changed, your code may break. Use `version.predict()` instead. See https://github.com/replicate/replicate-python#readme" ) From 089e8829d207e658bf408acf8a8707fdb61f8823 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:52:06 -0700 Subject: [PATCH 4/7] Adopt grammatical mood of PEP 257 Signed-off-by: Mattt Zmuda --- replicate/client.py | 2 +- replicate/json.py | 2 +- replicate/prediction.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 1a78eb78..91a2bf07 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -115,7 +115,7 @@ def trainings(self) -> TrainingCollection: def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: """ - Runs a model and waits for its output. + Run a model and wait for its output. Args: model_version: The model version to run, in the format `owner/name:version` diff --git a/replicate/json.py b/replicate/json.py index cd0b864e..9cf0b0f9 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -15,7 +15,7 @@ def encode_json( obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401 ) -> Any: # noqa: ANN401 """ - Returns a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json. + Return a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json. """ if isinstance(obj, dict): return {key: encode_json(value, upload_file) for key, value in obj.items()} diff --git a/replicate/prediction.py b/replicate/prediction.py index 2f88e8ef..dee7716d 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -55,7 +55,7 @@ class Prediction(BaseModel): def wait(self) -> None: """ - Waits for prediction to finish. + Wait for prediction to finish. """ while self.status not in ["succeeded", "failed", "canceled"]: time.sleep(self._client.poll_interval) From c1eb9b88983b9ee0c6b85f8047d1f69ade58f78a Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:52:22 -0700 Subject: [PATCH 5/7] Format Collection class docs docstring Signed-off-by: Mattt Zmuda --- replicate/collection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/replicate/collection.py b/replicate/collection.py index 92e7a88a..32596f89 100644 --- a/replicate/collection.py +++ b/replicate/collection.py @@ -11,8 +11,7 @@ class Collection(abc.ABC, Generic[Model]): """ - A base class for representing all objects of a particular type on the - server. + A base class for representing objects of a particular type on the server. """ def __init__(self, client: "Client") -> None: From a10e5911988e9388aa31cd368186584557fbfffe Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:55:35 -0700 Subject: [PATCH 6/7] Remove implementation notes about Cog from documentation strings Signed-off-by: Mattt Zmuda --- replicate/files.py | 10 +++++++++- replicate/json.py | 4 +++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/replicate/files.py b/replicate/files.py index 55a6612c..27dbb6db 100644 --- a/replicate/files.py +++ b/replicate/files.py @@ -9,8 +9,16 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str: """ - Lifted straight from cog.files + Upload a file to the server. + + Args: + fh: A file handle to upload. + output_file_prefix: A string to prepend to the output file name. + Returns: + str: A URL to the uploaded file. """ + # Lifted straight from cog.files + fh.seek(0) if output_file_prefix is not None: diff --git a/replicate/json.py b/replicate/json.py index 9cf0b0f9..8884cd06 100644 --- a/replicate/json.py +++ b/replicate/json.py @@ -15,8 +15,10 @@ def encode_json( obj: Any, upload_file: Callable[[io.IOBase], str] # noqa: ANN401 ) -> Any: # noqa: ANN401 """ - Return a JSON-compatible version of the object. Effectively the same thing as cog.json.encode_json. + Return a JSON-compatible version of the object. """ + # Effectively the same thing as cog.json.encode_json. + if isinstance(obj, dict): return {key: encode_json(value, upload_file) for key, value in obj.items()} if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): From b27a55b237f54f4834fa3072274fdea9d6b0a31d Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 13:57:05 -0700 Subject: [PATCH 7/7] Remove type information from method args Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 17 ++++++++--------- replicate/training.py | 14 +++++++------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/replicate/prediction.py b/replicate/prediction.py index dee7716d..9f2fc8a7 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -95,7 +95,7 @@ def list(self) -> List[Prediction]: List your predictions. Returns: - List[Prediction]: A list of prediction objects. + A list of prediction objects. """ resp = self._client._request("GET", "/v1/predictions") @@ -111,7 +111,7 @@ def get(self, id: str) -> Prediction: Get a prediction by ID. Args: - id (str): The ID of the prediction. + id: The ID of the prediction. Returns: Prediction: The prediction object. """ @@ -135,16 +135,15 @@ def create( # type: ignore Create a new prediction for the specified model version. Args: - version (Version): The model version to use for the prediction. - input (Dict[str, Any]): The input data for the prediction. - webhook (Optional[str]): The URL to receive a POST request with prediction updates. - webhook_completed (Optional[str]): The URL to receive a POST request when the prediction is completed. - webhook_events_filter (Optional[List[str]]): List of events to trigger webhooks. - stream (Optional[bool]): Set to True to enable streaming of prediction output. + version: The model version to use for the prediction. + input: The input data for the prediction. + webhook: The URL to receive a POST request with prediction updates. + webhook_completed: The URL to receive a POST request when the prediction is completed. + webhook_events_filter: List of events to trigger webhooks. + stream: Set to True to enable streaming of prediction output. Returns: Prediction: The created prediction object. - """ input = encode_json(input, upload_file=upload_file) diff --git a/replicate/training.py b/replicate/training.py index bbbe4a2f..d93b56ab 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -85,7 +85,7 @@ def get(self, id: str) -> Training: Get a training by ID. Args: - id (str): The ID of the training. + id: The ID of the training. Returns: Training: The training object. """ @@ -112,13 +112,13 @@ def create( # type: ignore Create a new training using the specified model version as a base. Args: - version (str): The ID of the base model version that you're using to train a new model version. - input (Dict[str, Any]): The input to the training. - destination (str): The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. - webhook (Optional[str], optional): The URL to send a POST request to when the training is completed. Defaults to None. - webhook_events_filter (Optional[List[str]], optional): The events to send to the webhook. Defaults to None. + version: The ID of the base model version that you're using to train a new model version. + input: The input to the training. + destination: The desired model to push to in the format `{owner}/{model_name}`. This should be an existing model owned by the user or organization making the API request. + webhook: The URL to send a POST request to when the training is completed. Defaults to None. + webhook_events_filter: The events to send to the webhook. Defaults to None. Returns: - Training: The training object. + The training object. """ input = encode_json(input, upload_file=upload_file)