Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,60 @@ func TestListModels(t *testing.T) {
assert.Equal(t, "codellama-13b", modelsPage.Results[1].Name)
}

func TestSearchModels(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/models", r.URL.Path)
assert.Equal(t, "QUERY", r.Method)
assert.Equal(t, "text/plain", r.Header.Get("Content-Type"))

body, err := io.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
defer r.Body.Close()

assert.Equal(t, "stable diffusion", string(body))

response := replicate.Page[replicate.Model]{
Results: []replicate.Model{
{
Owner: "stability-ai",
Name: "sdxl",
Description: "A text-to-image generative AI model that creates beautiful 1024x1024 images",
},
{
Owner: "stability-ai",
Name: "stable-diffusion",
Description: "Stable Diffusion is a text-to-image diffusion model",
},
},
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(response)
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NotNil(t, client)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

modelsPage, err := client.SearchModels(ctx, "stable diffusion")
assert.NoError(t, err)
assert.Equal(t, 2, len(modelsPage.Results))
assert.Equal(t, "stability-ai", modelsPage.Results[0].Owner)
assert.Equal(t, "sdxl", modelsPage.Results[0].Name)
assert.Equal(t, "stability-ai", modelsPage.Results[1].Owner)
assert.Equal(t, "stable-diffusion", modelsPage.Results[1].Name)
}

func TestGetModel(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/models/replicate/hello-world", r.URL.Path)
Expand Down
24 changes: 24 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package replicate_test
import (
"context"
"fmt"
"strings"

"github.com/replicate/replicate-go"
)
Expand Down Expand Up @@ -59,3 +60,26 @@ func ExampleClient_CreatePrediction() {
fmt.Println(prediction.Status)
// Output: succeeded
}

func ExampleClient_SearchModels() {
ctx := context.TODO()

r8, err := replicate.NewClient(replicate.WithTokenFromEnv())
if err != nil {
panic(err)
}

query := "llama"
modelsPage, err := r8.SearchModels(ctx, query)
if err != nil {
panic(err)
}

for _, model := range modelsPage.Results {
if model.Owner == "meta" && strings.HasPrefix(model.Name, "meta-llama-3") {
fmt.Printf("Found Meta Llama 3 model")
break
}
}
// Output: Found Meta Llama 3 model
}
16 changes: 16 additions & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strings"
)

type Model struct {
Expand Down Expand Up @@ -79,6 +80,21 @@ func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) {
return response, nil
}

// SearchModels searches for public models.
func (r *Client) SearchModels(ctx context.Context, query string) (*Page[Model], error) {
response := &Page[Model]{}
request, err := r.newRequest(ctx, "QUERY", "/models", strings.NewReader(query))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
request.Header.Set("Content-Type", "text/plain")
err = r.do(request, response)
if err != nil {
return nil, fmt.Errorf("failed to search models: %w", err)
}
return response, nil
}

// GetModel retrieves information about a model.
func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error) {
model := &Model{}
Expand Down