From 09e87e16b471feff16c2eebdfb7d56bcdfc64133 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 31 Jul 2023 11:53:47 -0700 Subject: [PATCH] Add stream parameter to predictions.create Signed-off-by: Mattt Zmuda --- replicate/prediction.py | 4 ++++ tests/test_prediction.py | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/replicate/prediction.py b/replicate/prediction.py index 9f2fc8a7..854aac64 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -129,6 +129,8 @@ def create( # type: ignore webhook: Optional[str] = None, webhook_completed: Optional[str] = None, webhook_events_filter: Optional[List[str]] = None, + *, + stream: Optional[bool] = None, **kwargs, ) -> Prediction: """ @@ -157,6 +159,8 @@ def create( # type: ignore body["webhook_completed"] = webhook_completed if webhook_events_filter is not None: body["webhook_events_filter"] = webhook_events_filter + if stream is True: + body["stream"] = "true" resp = self._client._request( "POST", diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 3a336a86..c014f3b6 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -94,6 +94,55 @@ def test_cancel(): assert rsp.call_count == 1 +@responses.activate +def test_stream(): + client = create_client() + version = create_version(client) + + rsp = responses.post( + "https://api.replicate.com/v1/predictions", + match=[ + matchers.json_params_matcher( + { + "version": "v1", + "input": {"text": "world"}, + "stream": "true", + } + ), + ], + json={ + "id": "p1", + "version": "v1", + "urls": { + "get": "https://api.replicate.com/v1/predictions/p1", + "cancel": "https://api.replicate.com/v1/predictions/p1/cancel", + "stream": "https://streaming.api.replicate.com/v1/predictions/p1", + }, + "created_at": "2022-04-26T20:00:40.658234Z", + "completed_at": "2022-04-26T20:02:27.648305Z", + "source": "api", + "status": "processing", + "input": {"text": "world"}, + "output": None, + "error": None, + "logs": "", + }, + ) + + prediction = client.predictions.create( + version=version, + input={"text": "world"}, + stream=True, + ) + + assert rsp.call_count == 1 + + assert ( + prediction.urls["stream"] + == "https://streaming.api.replicate.com/v1/predictions/p1" + ) + + @responses.activate def test_async_timings(): client = create_client()