diff --git a/replicate/prediction.py b/replicate/prediction.py index 7edddc70..1d5df1a6 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -59,15 +59,21 @@ def create( self, version: Version, input: Dict[str, Any], + webhook: Optional[str] = None, webhook_completed: Optional[str] = None, + webhook_events_filter: Optional[List[str]] = None, ) -> Prediction: input = encode_json(input, upload_file=upload_file) body = { "version": version.id, "input": input, } + if webhook is not None: + body["webhook"] = webhook if webhook_completed is not None: body["webhook_completed"] = webhook_completed + if webhook_events_filter is not None: + body["webhook_events_filter"] = webhook_events_filter resp = self._client._request( "POST", diff --git a/tests/test_prediction.py b/tests/test_prediction.py index bfe8a3ec..bbfecb9d 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,10 +1,56 @@ -import replicate import responses from responses import matchers +import replicate + from .factories import create_client, create_version +@responses.activate +def test_create_works_with_webhooks(): + client = create_client() + version = create_version(client) + + rsp = responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher( + { + "version": "v1", + "input": {"text": "world"}, + "webhook": "https://example.com/webhook", + "webhook_events_filter": ["completed"], + } + ), + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + + prediction = client.predictions.create( + version=version, + input={"text": "world"}, + webhook="https://example.com/webhook", + webhook_events_filter=["completed"], + ) + + assert rsp.call_count == 1 + + @responses.activate def test_cancel(): client = create_client()