diff --git a/client_test.go b/client_test.go index 569b264..7f7f0f9 100644 --- a/client_test.go +++ b/client_test.go @@ -1075,6 +1075,96 @@ func TestWaitAsync(t *testing.T) { assert.Equal(t, replicate.Succeeded, lastStatus) } +func TestRun(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/predictions": + assert.Equal(t, http.MethodPost, r.Method) + prediction := replicate.Prediction{ + ID: "gtsllfynndufawqhdngldkdrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Starting, + } + json.NewEncoder(w).Encode(prediction) + case "/predictions/gtsllfynndufawqhdngldkdrkq": + assert.Equal(t, http.MethodGet, r.Method) + prediction := replicate.Prediction{ + ID: "gtsllfynndufawqhdngldkdrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Succeeded, + Output: "Hello, world!", + } + json.NewEncoder(w).Encode(prediction) + default: + t.Fatalf("Unexpected request to %s", r.URL.Path) + } + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NoError(t, err) + + ctx := context.Background() + input := replicate.PredictionInput{"prompt": "Hello"} + output, err := client.Run(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil) + + require.NoError(t, err) + assert.NotNil(t, output) + assert.Equal(t, "Hello, world!", output) +} + +func TestRunReturningModelError(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/predictions": + assert.Equal(t, http.MethodPost, r.Method) + prediction := replicate.Prediction{ + ID: "fynndufawqhdngldkgtslldrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Starting, + } + json.NewEncoder(w).Encode(prediction) + case "/predictions/fynndufawqhdngldkgtslldrkq": + assert.Equal(t, http.MethodGet, r.Method) + + logs := "Could not say hello" + prediction := replicate.Prediction{ + ID: "fynndufawqhdngldkgtslldrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Failed, + Logs: &logs, + Error: "Model execution failed", + } + json.NewEncoder(w).Encode(prediction) + default: + t.Fatalf("Unexpected request to %s", r.URL.Path) + } + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NoError(t, err) + + ctx := context.Background() + input := replicate.PredictionInput{"prompt": "Hello"} + _, err = client.Run(ctx, "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil) + + require.Error(t, err) + modelErr, ok := err.(*replicate.ModelError) + require.True(t, ok, "Expected error to be of type *replicate.ModelError") + assert.Equal(t, "model error: Model execution failed", modelErr.Error()) + assert.Equal(t, "fynndufawqhdngldkgtslldrkq", modelErr.Prediction.ID) + assert.Equal(t, replicate.Failed, modelErr.Prediction.Status) + assert.Equal(t, "Model execution failed", modelErr.Prediction.Error) + assert.Equal(t, "Could not say hello", *modelErr.Prediction.Logs) +} + func TestCreateTraining(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodPost, r.Method) diff --git a/apierror.go b/error.go similarity index 83% rename from apierror.go rename to error.go index 470aeaf..b97a2bb 100644 --- a/apierror.go +++ b/error.go @@ -55,7 +55,7 @@ func (e APIError) Error() string { output := strings.Join(components, ": ") if output == "" { - output = "Unknown error" + output = "unknown error" } if e.Instance != "" { @@ -78,3 +78,16 @@ func (e *APIError) WriteHTTPResponse(w http.ResponseWriter) { http.Error(w, err.Error(), http.StatusInternalServerError) } } + +// ModelError represents an error returned by a model for a failed prediction. +type ModelError struct { + Prediction *Prediction `json:"prediction"` +} + +func (e *ModelError) Error() string { + if e.Prediction == nil { + return "unknown model error" + } + + return fmt.Sprintf("model error: %s", e.Prediction.Error) +} diff --git a/run.go b/run.go index 544664b..0118f17 100644 --- a/run.go +++ b/run.go @@ -21,6 +21,13 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp } err = r.Wait(ctx, prediction) + if err != nil { + return nil, err + } + + if prediction.Error != nil { + return nil, &ModelError{Prediction: prediction} + } - return prediction.Output, err + return prediction.Output, nil }