diff --git a/replicate/prediction.py b/replicate/prediction.py index 854aac64..298a4655 100644 --- a/replicate/prediction.py +++ b/replicate/prediction.py @@ -35,6 +35,9 @@ class Prediction(BaseModel): error: Optional[str] """The error encountered during the prediction, if any.""" + metrics: Optional[Dict[str, Any]] + """Metrics for the prediction.""" + created_at: Optional[str] """When the prediction was created.""" diff --git a/tests/test_prediction.py b/tests/test_prediction.py index c014f3b6..ad6ccba6 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -193,6 +193,9 @@ def test_async_timings(): "output": "hello world", "error": None, "logs": "", + "metrics": { + "predict_time": 1.2345, + }, }, ) @@ -210,3 +213,4 @@ def test_async_timings(): assert prediction.created_at == "2022-04-26T20:00:40.658234Z" assert prediction.completed_at == "2022-04-26T20:02:27.648305Z" assert prediction.output == "hello world" + assert prediction.metrics["predict_time"] == 1.2345