From 4bba221b2233ecf53cf2f201c693320b5956ce59 Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Sep 2024 04:32:43 -0700 Subject: [PATCH 1/2] Move shared logic for prediction creation into createPredictionRequest helper method --- deployment.go | 22 ++++++---------------- model.go | 21 ++++++--------------- prediction.go | 44 +++++++++++++++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/deployment.go b/deployment.go index 571279b..c07eb2c 100644 --- a/deployment.go +++ b/deployment.go @@ -45,26 +45,16 @@ func (d *Deployment) UnmarshalJSON(data []byte) error { // CreateDeploymentPrediction sends a request to the Replicate API to create a prediction using the specified deployment. func (c *Client) CreatePredictionWithDeployment(ctx context.Context, deploymentOwner string, deploymentName string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { - data := map[string]interface{}{ - "input": input, - } - - if webhook != nil { - data["webhook"] = webhook.URL - if len(webhook.Events) > 0 { - data["webhook_events_filter"] = webhook.Events - } - } + path := fmt.Sprintf("/deployments/%s/%s/predictions", deploymentOwner, deploymentName) - if stream { - data["stream"] = true + req, err := c.createPredictionRequest(ctx, path, nil, input, webhook, stream) + if err != nil { + return nil, err } prediction := &Prediction{} - path := fmt.Sprintf("/deployments/%s/%s/predictions", deploymentOwner, deploymentName) - err := c.fetch(ctx, http.MethodPost, path, data, prediction) - if err != nil { - return nil, fmt.Errorf("failed to create prediction: %w", err) + if err := c.do(req, prediction); err != nil { + return nil, fmt.Errorf("failed to create prediction with deployment: %w", err) } return prediction, nil diff --git a/model.go b/model.go index 9ca5c12..29453bf 100644 --- a/model.go +++ b/model.go @@ -167,25 +167,16 @@ func (r *Client) DeleteModelVersion(ctx context.Context, modelOwner string, mode // CreatePredictionWithModel sends a request to the Replicate API to create a prediction for a model. func (r *Client) CreatePredictionWithModel(ctx context.Context, modelOwner string, modelName string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { - data := map[string]interface{}{ - "input": input, - } - - if webhook != nil { - data["webhook"] = webhook.URL - if len(webhook.Events) > 0 { - data["webhook_events_filter"] = webhook.Events - } - } + path := fmt.Sprintf("/models/%s/%s/predictions", modelOwner, modelName) - if stream { - data["stream"] = true + req, err := r.createPredictionRequest(ctx, path, nil, input, webhook, stream) + if err != nil { + return nil, err } prediction := &Prediction{} - err := r.fetch(ctx, http.MethodPost, fmt.Sprintf("/models/%s/%s/predictions", modelOwner, modelName), data, prediction) - if err != nil { - return nil, err + if err := r.do(req, prediction); err != nil { + return nil, fmt.Errorf("failed to create prediction with model: %w", err) } return prediction, nil diff --git a/prediction.go b/prediction.go index 815bcd4..9f59fb8 100644 --- a/prediction.go +++ b/prediction.go @@ -1,6 +1,7 @@ package replicate import ( + "bytes" "context" "encoding/json" "fmt" @@ -102,8 +103,8 @@ func (p Prediction) Progress() *PredictionProgress { return nil } -// CreatePrediction sends a request to the Replicate API to create a prediction. -func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { +// createPredictionRequest creates a prediction request. +func (r *Client) createPredictionRequest(ctx context.Context, path string, data map[string]interface{}, input PredictionInput, webhook *Webhook, stream bool) (*http.Request, error) { // Convert File objects in input to their "get" URL value for key, value := range input { if file, ok := value.(*File); ok { @@ -111,11 +112,12 @@ func (r *Client) CreatePrediction(ctx context.Context, version string, input Pre } } - data := map[string]interface{}{ - "version": version, - "input": input, + if data == nil { + data = make(map[string]interface{}) } + data["input"] = input + if webhook != nil { data["webhook"] = webhook.URL if len(webhook.Events) > 0 { @@ -127,9 +129,37 @@ func (r *Client) CreatePrediction(ctx context.Context, version string, input Pre data["stream"] = true } - prediction := &Prediction{} - err := r.fetch(ctx, http.MethodPost, "/predictions", data, prediction) + bodyBuffer := &bytes.Buffer{} + if data != nil { + bodyBytes, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + bodyBuffer = bytes.NewBuffer(bodyBytes) + } + + req, err := r.newRequest(ctx, http.MethodPost, path, bodyBuffer) if err != nil { + return nil, fmt.Errorf("failed to create prediction request: %w", err) + } + + return req, nil +} + +// CreatePrediction creates a prediction for a specific version of a model. +func (r *Client) CreatePrediction(ctx context.Context, version string, input PredictionInput, webhook *Webhook, stream bool) (*Prediction, error) { + path := "/predictions" + data := map[string]interface{}{ + "version": version, + } + + req, err := r.createPredictionRequest(ctx, path, data, input, webhook, stream) + if err != nil { + return nil, err + } + + prediction := &Prediction{} + if err := r.do(req, prediction); err != nil { return nil, fmt.Errorf("failed to create prediction: %w", err) } From b84a32de6cc56b4f6aeaba3e5bf58029ca6eb61a Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Mon, 23 Sep 2024 04:40:54 -0700 Subject: [PATCH 2/2] Add WithBlockUntilDone run option Improve docs and comments --- run.go | 37 ++++++++++++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/run.go b/run.go index bb97ef8..55426a6 100644 --- a/run.go +++ b/run.go @@ -17,7 +17,8 @@ type RunOption func(*runOptions) // runOptions represents options for running a model type runOptions struct { - useFileOutput bool + useFileOutput bool + blockUntilDone bool } // FileOutput is a custom type that implements io.ReadCloser and includes a URL field @@ -26,43 +27,73 @@ type FileOutput struct { URL string } -// WithFileOutput sets the UseFileOutput option to true +// WithFileOutput configures the run to automatically convert URLs in output to FileOutput objects func WithFileOutput() RunOption { return func(o *runOptions) { o.useFileOutput = true } } +// WithBlockUntilDone configures the run to block until the prediction is done +func WithBlockUntilDone() RunOption { + return func(o *runOptions) { + o.blockUntilDone = true + } +} + // RunWithOptions runs a model with specified options func (r *Client) RunWithOptions(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook, opts ...RunOption) (PredictionOutput, error) { + // Initialize options options := runOptions{} for _, opt := range opts { opt(&options) } + // Parse the identifier to extract version id, err := ParseIdentifier(identifier) if err != nil { return nil, err } + // Check if version is specified if id.Version == nil { return nil, errors.New("version must be specified") } - prediction, err := r.CreatePrediction(ctx, *id.Version, input, webhook, false) + // Prepare the data for the prediction request + data := map[string]interface{}{ + "version": *id.Version, + } + + // Create the prediction request + req, err := r.createPredictionRequest(ctx, "/predictions", data, input, webhook, false) if err != nil { return nil, err } + // Set the X-Sync header if blockUntilDone is true + if options.blockUntilDone { + req.Header.Set("X-Sync", "true") + } + + // Execute the request and obtain the prediction + prediction := &Prediction{} + if err := r.do(req, prediction); err != nil { + return nil, err + } + + // Wait for the prediction to complete err = r.Wait(ctx, prediction) if err != nil { return nil, err } + // Check for model error in the prediction if prediction.Error != nil { return nil, &ModelError{Prediction: prediction} } + // Transform the output based on the options if options.useFileOutput { return transformOutput(ctx, prediction.Output, r) }