From b6f40e8e5475d6be2f35ce4568f182f3435d5cbd Mon Sep 17 00:00:00 2001 From: Mattt Zmuda Date: Fri, 19 Jul 2024 10:37:14 -0700 Subject: [PATCH] Add support for models.search endpoint --- client_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++++++++ example_test.go | 24 ++++++++++++++++++++++ model.go | 16 +++++++++++++++ 3 files changed, 94 insertions(+) diff --git a/client_test.go b/client_test.go index 3adb149..569b264 100644 --- a/client_test.go +++ b/client_test.go @@ -205,6 +205,60 @@ func TestListModels(t *testing.T) { assert.Equal(t, "codellama-13b", modelsPage.Results[1].Name) } +func TestSearchModels(t *testing.T) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/models", r.URL.Path) + assert.Equal(t, "QUERY", r.Method) + assert.Equal(t, "text/plain", r.Header.Get("Content-Type")) + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + defer r.Body.Close() + + assert.Equal(t, "stable diffusion", string(body)) + + response := replicate.Page[replicate.Model]{ + Results: []replicate.Model{ + { + Owner: "stability-ai", + Name: "sdxl", + Description: "A text-to-image generative AI model that creates beautiful 1024x1024 images", + }, + { + Owner: "stability-ai", + Name: "stable-diffusion", + Description: "Stable Diffusion is a text-to-image diffusion model", + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(response) + })) + defer mockServer.Close() + + client, err := replicate.NewClient( + replicate.WithToken("test-token"), + replicate.WithBaseURL(mockServer.URL), + ) + require.NotNil(t, client) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + modelsPage, err := client.SearchModels(ctx, "stable diffusion") + assert.NoError(t, err) + assert.Equal(t, 2, len(modelsPage.Results)) + assert.Equal(t, "stability-ai", modelsPage.Results[0].Owner) + assert.Equal(t, "sdxl", modelsPage.Results[0].Name) + assert.Equal(t, "stability-ai", modelsPage.Results[1].Owner) + assert.Equal(t, "stable-diffusion", modelsPage.Results[1].Name) +} + func TestGetModel(t *testing.T) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "/models/replicate/hello-world", r.URL.Path) diff --git a/example_test.go b/example_test.go index 4213162..50f0139 100644 --- a/example_test.go +++ b/example_test.go @@ -3,6 +3,7 @@ package replicate_test import ( "context" "fmt" + "strings" "github.com/replicate/replicate-go" ) @@ -59,3 +60,26 @@ func ExampleClient_CreatePrediction() { fmt.Println(prediction.Status) // Output: succeeded } + +func ExampleClient_SearchModels() { + ctx := context.TODO() + + r8, err := replicate.NewClient(replicate.WithTokenFromEnv()) + if err != nil { + panic(err) + } + + query := "llama" + modelsPage, err := r8.SearchModels(ctx, query) + if err != nil { + panic(err) + } + + for _, model := range modelsPage.Results { + if model.Owner == "meta" && strings.HasPrefix(model.Name, "meta-llama-3") { + fmt.Printf("Found Meta Llama 3 model") + break + } + } + // Output: Found Meta Llama 3 model +} diff --git a/model.go b/model.go index 9d4bdae..9ca5c12 100644 --- a/model.go +++ b/model.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "net/http" + "strings" ) type Model struct { @@ -79,6 +80,21 @@ func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) { return response, nil } +// SearchModels searches for public models. +func (r *Client) SearchModels(ctx context.Context, query string) (*Page[Model], error) { + response := &Page[Model]{} + request, err := r.newRequest(ctx, "QUERY", "/models", strings.NewReader(query)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + request.Header.Set("Content-Type", "text/plain") + err = r.do(request, response) + if err != nil { + return nil, fmt.Errorf("failed to search models: %w", err) + } + return response, nil +} + // GetModel retrieves information about a model. func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error) { model := &Model{}