Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,21 @@ def create(
self,
version: Version,
input: Dict[str, Any],
webhook: Optional[str] = None,
webhook_completed: Optional[str] = None,
webhook_events_filter: Optional[List[str]] = None,
) -> Prediction:
input = encode_json(input, upload_file=upload_file)
body = {
"version": version.id,
"input": input,
}
if webhook is not None:
body["webhook"] = webhook
if webhook_completed is not None:
body["webhook_completed"] = webhook_completed
if webhook_events_filter is not None:
body["webhook_events_filter"] = webhook_events_filter

resp = self._client._request(
"POST",
Expand Down
48 changes: 47 additions & 1 deletion tests/test_prediction.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,56 @@
import replicate
import responses
from responses import matchers

import replicate

from .factories import create_client, create_version


@responses.activate
def test_create_works_with_webhooks():
client = create_client()
version = create_version(client)

rsp = responses.post(
"https://api.replicate.com/v1/predictions",
match=[
matchers.json_params_matcher(
{
"version": "v1",
"input": {"text": "world"},
"webhook": "https://example.com/webhook",
"webhook_events_filter": ["completed"],
}
),
],
json={
"id": "p1",
"version": "v1",
"urls": {
"get": "https://api.replicate.com/v1/predictions/p1",
"cancel": "https://api.replicate.com/v1/predictions/p1/cancel",
},
"created_at": "2022-04-26T20:00:40.658234Z",
"completed_at": "2022-04-26T20:02:27.648305Z",
"source": "api",
"status": "processing",
"input": {"text": "world"},
"output": None,
"error": None,
"logs": "",
},
)

prediction = client.predictions.create(
version=version,
input={"text": "world"},
webhook="https://example.com/webhook",
webhook_events_filter=["completed"],
)

assert rsp.call_count == 1


@responses.activate
def test_cancel():
client = create_client()
Expand Down