diff --git a/src/codeocean/computation.py b/src/codeocean/computation.py index bd72451..723834b 100644 --- a/src/codeocean/computation.py +++ b/src/codeocean/computation.py @@ -4,7 +4,7 @@ from dataclasses_json import dataclass_json from requests_toolbelt.sessions import BaseUrlSession from typing import Optional -from time import sleep +from time import sleep, time from codeocean.enum import StrEnum @@ -141,20 +141,39 @@ def run_capsule(self, run_params: RunParams) -> Computation: return Computation.from_dict(res.json()) - def wait_until_completed(self, computation: Computation, polling_interval: int = 5) -> Computation: + def wait_until_completed( + self, + computation: Computation, + polling_interval: float = 5, + timeout: Optional[float] = None, + ) -> Computation: """ Polls the given computation until it reaches the 'Completed' or 'Failed' state. + + - `polling_interval` and `timeout` are in seconds """ if polling_interval < 5: raise ValueError( f"Polling interval {polling_interval} should be greater than or equal to 5" ) + if timeout is not None and timeout < polling_interval: + raise ValueError( + f"Timeout {timeout} should be greater than or equal to polling interval {polling_interval}" + ) + if timeout is not None and timeout < 0: + raise ValueError( + f"Timeout {timeout} should be greater than or equal to 0 (seconds), or None" + ) + t0 = time() while True: comp = self.get_computation(computation.id) if comp.state in [ComputationState.Completed, ComputationState.Failed]: return comp + if timeout is not None and (time() - t0) > timeout: + raise TimeoutError(f"Computation {computation.id} did not complete within {timeout} seconds") + sleep(polling_interval) def list_computation_results(self, computation_id: str, path: str = "") -> Folder: diff --git a/src/codeocean/data_asset.py b/src/codeocean/data_asset.py index e7783af..69f8cc7 100644 --- a/src/codeocean/data_asset.py +++ b/src/codeocean/data_asset.py @@ -3,7 +3,7 @@ from dataclasses_json import dataclass_json from dataclasses import dataclass from requests_toolbelt.sessions import BaseUrlSession -from time import sleep +from time import sleep, time from typing import Optional from codeocean.components import SortOrder, SearchFilter, Permissions @@ -237,20 +237,39 @@ def create_data_asset(self, data_asset_params: DataAssetParams) -> DataAsset: return DataAsset.from_dict(res.json()) - def wait_until_ready(self, data_asset: DataAsset, polling_interval: int = 5) -> DataAsset: + def wait_until_ready( + self, + data_asset: DataAsset, + polling_interval: float = 5, + timeout: float | None = None, + ) -> DataAsset: """ Polls the given data asset until it reaches the 'Ready' or 'Failed' state. + + - `polling_interval` and `timeout` are in seconds """ if polling_interval < 5: raise ValueError( f"Polling interval {polling_interval} should be greater than or equal to 5" ) + if timeout is not None and timeout < polling_interval: + raise ValueError( + f"Timeout {timeout} should be greater than or equal to polling interval {polling_interval}" + ) + if timeout is not None and timeout < 0: + raise ValueError( + f"Timeout {timeout} should be greater than or equal to 0 (seconds), or None" + ) + t0 = time() while True: da = self.get_data_asset(data_asset.id) if da.state in [DataAssetState.Ready, DataAssetState.Failed]: return da + if timeout is not None and (time() - t0) > timeout: + raise TimeoutError(f"Data asset {data_asset.id} was not ready within {timeout} seconds") + sleep(polling_interval) def delete_data_asset(self, data_asset_id: str):