diff --git a/client_test.go b/client_test.go index 7f7f0f9..9e1145c 100644 --- a/client_test.go +++ b/client_test.go @@ -1475,6 +1475,79 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) { assert.ErrorContains(t, err, http.StatusText(http.StatusInternalServerError)) } +func TestRunWithOptions(t *testing.T) { + var mockServer *httptest.Server + mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/predictions": + assert.Equal(t, http.MethodPost, r.Method) + prediction := replicate.Prediction{ + ID: "gtsllfynndufawqhdngldkdrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Starting, + } + json.NewEncoder(w).Encode(prediction) + case "/predictions/gtsllfynndufawqhdngldkdrkq": + assert.Equal(t, http.MethodGet, r.Method) + prediction := replicate.Prediction{ + ID: "gtsllfynndufawqhdngldkdrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Succeeded, + Output: map[string]interface{}{ + "image": mockServer.URL + "/output.png", + "text": "Hello, world!", + }, + } + json.NewEncoder(w).Encode(prediction) + case "/output.png": + w.Header().Set("Content-Type", "image/png") + w.Write([]byte("mock image data")) + default: + t.Fatalf("Unexpected request to %s", r.URL.Path) + } + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NoError(t, err) + + ctx := context.Background() + input := replicate.PredictionInput{"prompt": "A test image"} + + // Test with WithFileOutput option + output, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil, replicate.WithFileOutput()) + + require.NoError(t, err) + assert.NotNil(t, output) + + // Check if the image output is transformed to io.ReadCloser + imageOutput, ok := output.(map[string]interface{})["image"].(io.ReadCloser) + require.True(t, ok, "Expected image output to be io.ReadCloser") + + imageData, err := io.ReadAll(imageOutput) + require.NoError(t, err) + assert.Equal(t, []byte("mock image data"), imageData) + + // Check if the text output remains unchanged + textOutput, ok := output.(map[string]interface{})["text"].(string) + require.True(t, ok, "Expected text output to be string") + assert.Equal(t, "Hello, world!", textOutput) + + // Test without WithFileOutput option + outputWithoutFileOption, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil) + + require.NoError(t, err) + assert.NotNil(t, outputWithoutFileOption) + + // Check if the image output remains a URL string + imageOutputURL, ok := outputWithoutFileOption.(map[string]interface{})["image"].(string) + require.True(t, ok, "Expected image output to be string") + assert.Equal(t, mockServer.URL+"/output.png", imageOutputURL) +} + func TestStream(t *testing.T) { tokens := []string{"Alpha", "Bravo", "Charlie", "Delta", "Echo"} diff --git a/run.go b/run.go index 0118f17..bb97ef8 100644 --- a/run.go +++ b/run.go @@ -1,11 +1,45 @@ package replicate import ( + "bytes" "context" + "encoding/base64" "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" ) -func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) { +// RunOption is a function that modifies RunOptions +type RunOption func(*runOptions) + +// runOptions represents options for running a model +type runOptions struct { + useFileOutput bool +} + +// FileOutput is a custom type that implements io.ReadCloser and includes a URL field +type FileOutput struct { + io.ReadCloser + URL string +} + +// WithFileOutput sets the UseFileOutput option to true +func WithFileOutput() RunOption { + return func(o *runOptions) { + o.useFileOutput = true + } +} + +// RunWithOptions runs a model with specified options +func (r *Client) RunWithOptions(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, opts ...RunOption) (PredictionOutput, error) { + options := runOptions{} + for _, opt := range opts { + opt(&options) + } + id, err := ParseIdentifier(identifier) if err != nil { return nil, err @@ -29,5 +63,96 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp return nil, &ModelError{Prediction: prediction} } + if options.useFileOutput { + return transformOutput(ctx, prediction.Output, r) + } + return prediction.Output, nil } + +// Run runs a model and returns the output +func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) { + return r.RunWithOptions(ctx, identifier, input, webhook) +} + +func transformOutput(ctx context.Context, value interface{}, client *Client) (interface{}, error) { + var err error + switch v := value.(type) { + case map[string]interface{}: + for k, val := range v { + v[k], err = transformOutput(ctx, val, client) + if err != nil { + return nil, err + } + } + return v, nil + case []interface{}: + for i, val := range v { + v[i], err = transformOutput(ctx, val, client) + if err != nil { + return nil, err + } + } + return v, nil + case string: + if strings.HasPrefix(v, "data:") { + return readDataURI(v) + } + if strings.HasPrefix(v, "https:") || strings.HasPrefix(v, "http:") { + return readHTTP(ctx, v, client) + } + return v, nil + } + return value, nil +} + +func readDataURI(uri string) (*FileOutput, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + if u.Scheme != "data" { + return nil, errors.New("not a data URI") + } + mediatype, data, found := strings.Cut(u.Opaque, ",") + if !found { + return nil, errors.New("invalid data URI format") + } + var reader io.Reader + if strings.HasSuffix(mediatype, ";base64") { + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + reader = bytes.NewReader(decoded) + } else { + reader = strings.NewReader(data) + } + return &FileOutput{ + ReadCloser: io.NopCloser(reader), + URL: uri, + }, nil +} + +func readHTTP(ctx context.Context, url string, client *Client) (*FileOutput, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + resp, err := client.c.Do(req) + if err != nil { + return nil, err + } + if resp == nil || resp.Body == nil { + return nil, errors.New("HTTP request failed to get a response") + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("HTTP request failed with status code %d", resp.StatusCode) + } + + return &FileOutput{ + ReadCloser: resp.Body, + URL: url, + }, nil +}