From b0d634b9cc8d4be18acbd03e12011981eaa6774e Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 12 Sep 2024 15:49:56 -0700 Subject: [PATCH 1/4] Add RunWithOptions method that supports returning file output as bytes --- run.go | 106 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/run.go b/run.go index 0118f17..b8606bb 100644 --- a/run.go +++ b/run.go @@ -2,10 +2,37 @@ package replicate import ( "context" + "encoding/base64" "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" ) -func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) { +// RunOption is a function that modifies RunOptions +type RunOption func(*runOptions) + +// runOptions represents options for running a model +type runOptions struct { + useFileOutput bool +} + +// WithFileOutput sets the UseFileOutput option to true +func WithFileOutput() RunOption { + return func(o *runOptions) { + o.useFileOutput = true + } +} + +// RunWithOptions runs a model with specified options +func (r *Client) RunWithOptions(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, opts ...RunOption) (PredictionOutput, error) { + options := runOptions{} + for _, opt := range opts { + opt(&options) + } + id, err := ParseIdentifier(identifier) if err != nil { return nil, err @@ -29,5 +56,82 @@ func (r *Client) Run(ctx context.Context, identifier string, input PredictionInp return nil, &ModelError{Prediction: prediction} } + if options.useFileOutput { + return transformOutput(ctx, prediction.Output, r) + } + return prediction.Output, nil } + +// Run runs a model and returns the output +func (r *Client) Run(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (PredictionOutput, error) { + return r.RunWithOptions(ctx, identifier, input, webhook) +} + +func transformOutput(ctx context.Context, value interface{}, client *Client) (interface{}, error) { + var err error + switch v := value.(type) { + case map[string]interface{}: + for k, val := range v { + v[k], err = transformOutput(ctx, val, client) + if err != nil { + return nil, err + } + } + return v, nil + case []interface{}: + for i, val := range v { + v[i], err = transformOutput(ctx, val, client) + if err != nil { + return nil, err + } + } + return v, nil + case string: + if strings.HasPrefix(v, "data:") { + return readDataURI(v) + } + if strings.HasPrefix(v, "https:") || strings.HasPrefix(v, "http:") { + return readHTTP(ctx, v, client) + } + return v, nil + } + return value, nil +} + +func readDataURI(uri string) ([]byte, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + if u.Scheme != "data" { + return nil, errors.New("not a data URI") + } + mediatype, data, found := strings.Cut(u.Opaque, ",") + if !found { + return nil, errors.New("invalid data URI format") + } + if strings.HasSuffix(mediatype, ";base64") { + return base64.StdEncoding.DecodeString(data) + } + return []byte(data), nil +} + +func readHTTP(ctx context.Context, url string, client *Client) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + resp, err := client.c.Do(req) + if resp == nil || resp.Body == nil { + return nil, errors.New("HTTP request failed to get a response") + } + defer resp.Body.Close() + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP request failed with status code %d", resp.StatusCode) + } + return io.ReadAll(resp.Body) +} From d72885b619c910533183ff4fa8da751fda6f6d92 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 12 Sep 2024 15:51:12 -0700 Subject: [PATCH 2/4] Return io.ReadCloser instead of []byte --- run.go | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/run.go b/run.go index b8606bb..9d19088 100644 --- a/run.go +++ b/run.go @@ -1,6 +1,7 @@ package replicate import ( + "bytes" "context" "encoding/base64" "errors" @@ -99,7 +100,7 @@ func transformOutput(ctx context.Context, value interface{}, client *Client) (in return value, nil } -func readDataURI(uri string) ([]byte, error) { +func readDataURI(uri string) (io.ReadCloser, error) { u, err := url.Parse(uri) if err != nil { return nil, err @@ -111,27 +112,34 @@ func readDataURI(uri string) ([]byte, error) { if !found { return nil, errors.New("invalid data URI format") } + var reader io.Reader if strings.HasSuffix(mediatype, ";base64") { - return base64.StdEncoding.DecodeString(data) + decoded, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + reader = bytes.NewReader(decoded) + } else { + reader = strings.NewReader(data) } - return []byte(data), nil + return io.NopCloser(reader), nil } -func readHTTP(ctx context.Context, url string, client *Client) ([]byte, error) { +func readHTTP(ctx context.Context, url string, client *Client) (io.ReadCloser, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err } resp, err := client.c.Do(req) - if resp == nil || resp.Body == nil { - return nil, errors.New("HTTP request failed to get a response") - } - defer resp.Body.Close() if err != nil { return nil, err } + if resp == nil || resp.Body == nil { + return nil, errors.New("HTTP request failed to get a response") + } if resp.StatusCode != http.StatusOK { + resp.Body.Close() return nil, fmt.Errorf("HTTP request failed with status code %d", resp.StatusCode) } - return io.ReadAll(resp.Body) + return resp.Body, nil } From 0407ea8e7861428116622375552036fda218b1f0 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Thu, 12 Sep 2024 16:49:27 -0700 Subject: [PATCH 3/4] Add test coverage for RunWithOptions --- client_test.go | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/client_test.go b/client_test.go index 7f7f0f9..9e1145c 100644 --- a/client_test.go +++ b/client_test.go @@ -1475,6 +1475,79 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) { assert.ErrorContains(t, err, http.StatusText(http.StatusInternalServerError)) } +func TestRunWithOptions(t *testing.T) { + var mockServer *httptest.Server + mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/predictions": + assert.Equal(t, http.MethodPost, r.Method) + prediction := replicate.Prediction{ + ID: "gtsllfynndufawqhdngldkdrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Starting, + } + json.NewEncoder(w).Encode(prediction) + case "/predictions/gtsllfynndufawqhdngldkdrkq": + assert.Equal(t, http.MethodGet, r.Method) + prediction := replicate.Prediction{ + ID: "gtsllfynndufawqhdngldkdrkq", + Version: "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + Status: replicate.Succeeded, + Output: map[string]interface{}{ + "image": mockServer.URL + "/output.png", + "text": "Hello, world!", + }, + } + json.NewEncoder(w).Encode(prediction) + case "/output.png": + w.Header().Set("Content-Type", "image/png") + w.Write([]byte("mock image data")) + default: + t.Fatalf("Unexpected request to %s", r.URL.Path) + } + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NoError(t, err) + + ctx := context.Background() + input := replicate.PredictionInput{"prompt": "A test image"} + + // Test with WithFileOutput option + output, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil, replicate.WithFileOutput()) + + require.NoError(t, err) + assert.NotNil(t, output) + + // Check if the image output is transformed to io.ReadCloser + imageOutput, ok := output.(map[string]interface{})["image"].(io.ReadCloser) + require.True(t, ok, "Expected image output to be io.ReadCloser") + + imageData, err := io.ReadAll(imageOutput) + require.NoError(t, err) + assert.Equal(t, []byte("mock image data"), imageData) + + // Check if the text output remains unchanged + textOutput, ok := output.(map[string]interface{})["text"].(string) + require.True(t, ok, "Expected text output to be string") + assert.Equal(t, "Hello, world!", textOutput) + + // Test without WithFileOutput option + outputWithoutFileOption, err := client.RunWithOptions(ctx, "owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", input, nil) + + require.NoError(t, err) + assert.NotNil(t, outputWithoutFileOption) + + // Check if the image output remains a URL string + imageOutputURL, ok := outputWithoutFileOption.(map[string]interface{})["image"].(string) + require.True(t, ok, "Expected image output to be string") + assert.Equal(t, mockServer.URL+"/output.png", imageOutputURL) +} + func TestStream(t *testing.T) { tokens := []string{"Alpha", "Bravo", "Charlie", "Delta", "Echo"} From b8027080fc0e149d6de9601b54255092385acaa1 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 13 Sep 2024 03:28:34 -0700 Subject: [PATCH 4/4] Return custom FileOutput type that implements io.ReadCloser and provides URL --- run.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/run.go b/run.go index 9d19088..bb97ef8 100644 --- a/run.go +++ b/run.go @@ -20,6 +20,12 @@ type runOptions struct { useFileOutput bool } +// FileOutput is a custom type that implements io.ReadCloser and includes a URL field +type FileOutput struct { + io.ReadCloser + URL string +} + // WithFileOutput sets the UseFileOutput option to true func WithFileOutput() RunOption { return func(o *runOptions) { @@ -100,7 +106,7 @@ func transformOutput(ctx context.Context, value interface{}, client *Client) (in return value, nil } -func readDataURI(uri string) (io.ReadCloser, error) { +func readDataURI(uri string) (*FileOutput, error) { u, err := url.Parse(uri) if err != nil { return nil, err @@ -122,10 +128,13 @@ func readDataURI(uri string) (io.ReadCloser, error) { } else { reader = strings.NewReader(data) } - return io.NopCloser(reader), nil + return &FileOutput{ + ReadCloser: io.NopCloser(reader), + URL: uri, + }, nil } -func readHTTP(ctx context.Context, url string, client *Client) (io.ReadCloser, error) { +func readHTTP(ctx context.Context, url string, client *Client) (*FileOutput, error) { req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return nil, err @@ -141,5 +150,9 @@ func readHTTP(ctx context.Context, url string, client *Client) (io.ReadCloser, e resp.Body.Close() return nil, fmt.Errorf("HTTP request failed with status code %d", resp.StatusCode) } - return resp.Body, nil + + return &FileOutput{ + ReadCloser: resp.Body, + URL: url, + }, nil }