Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
GOVET=$(GOCMD) vet
BINARY_NAME=local-ai

GOLLAMA_VERSION?=cf9b522db63898dcc5eb86e37c979ab85cbd583e
GOLLAMA_VERSION?=b4e97a42d0c10ada6b529b0ec17b05c72435aeab
GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
Expand Down
12 changes: 10 additions & 2 deletions api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Config struct {
Mirostat int `yaml:"mirostat"`

PromptStrings, InputStrings []string
InputToken [][]int
}

type TemplateConfig struct {
Expand Down Expand Up @@ -186,8 +187,15 @@ func updateConfig(config *Config, input *OpenAIRequest) {
}
case []interface{}:
for _, pp := range inputs {
if s, ok := pp.(string); ok {
config.InputStrings = append(config.InputStrings, s)
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []interface{}:
tokens := []int{}
for _, ii := range i {
tokens = append(tokens, int(ii.(float64)))
}
config.InputToken = append(config.InputToken, tokens)
}
}
}
Expand Down
17 changes: 15 additions & 2 deletions api/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,23 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Parameter Config: %+v", config)
items := []Item{}

for i, s := range config.InputStrings {
for i, s := range config.InputToken {
// get the model function to call for the result
embedFn, err := ModelEmbedding("", s, loader, *config)
if err != nil {
return err
}

embeddings, err := embedFn()
if err != nil {
return err
}
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
}

for i, s := range config.InputStrings {
// get the model function to call for the result
embedFn, err := ModelEmbedding(s, loader, *config)
embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
if err != nil {
return err
}
Expand Down
5 changes: 4 additions & 1 deletion api/prediction.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
return llamaOpts
}

func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
if !c.Embeddings {
return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
}
Expand All @@ -57,6 +57,9 @@ func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]fl
case *llama.LLama:
fn = func() ([]float32, error) {
predictOptions := buildLLamaPredictOptions(c)
if len(tokens) > 0 {
return model.TokenEmbeddings(tokens, predictOptions...)
}
return model.Embeddings(s, predictOptions...)
}
default:
Expand Down