From 8837ed5a6cd1e6fa49afbff128170ade78c70e16 Mon Sep 17 00:00:00 2001 From: mudler Date: Tue, 9 May 2023 20:08:57 +0200 Subject: [PATCH 1/4] feat: add dolly models support --- Makefile | 2 +- api/api_test.go | 2 +- api/prediction.go | 24 +++++++++++++++++++ go.mod | 2 +- go.sum | 2 ++ pkg/model/loader.go | 56 +++++++++++++++++++++++++++++++++++++++++---- 6 files changed, 80 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index bffb161ccc48..4ec1bde8913d 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ BINARY_NAME=local-ai GOLLAMA_VERSION?=c03e8adbc45c866e0f6d876af1887d6b01d57eb4 GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109 -GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa +GOGPT2_VERSION?=d498232 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=af62fcc432be2847acb6e0688b2c2491d6588d58 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 diff --git a/api/api_test.go b/api/api_test.go index 9682a218b7e3..0a581912a446 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -79,7 +79,7 @@ var _ = Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 6 errors occurred:")) }) }) diff --git a/api/prediction.go b/api/prediction.go index 95d111ff5f0b..42f6eca48bf9 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -213,6 +213,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback predictOptions..., ) } + case *gpt2.Dolly: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []gpt2.PredictOption{ + gpt2.SetTemperature(c.Temperature), + gpt2.SetTopP(c.TopP), + gpt2.SetTopK(c.TopK), + gpt2.SetTokens(c.Maxtokens), + gpt2.SetThreads(c.Threads), + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch)) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed)) + } + + return model.Predict( + s, + predictOptions..., + ) + } case *gpt2.GPT2: fn = func() (string, error) { // Generate the prediction using the language model diff --git a/go.mod b/go.mod index c25c9e3a7096..d9af07b1c9ea 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35 github.com/go-audio/wav v1.1.0 - github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708 + github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3 github.com/gofiber/fiber/v2 v2.45.0 diff --git a/go.sum b/go.sum index 400d9ff6e749..d2648f6466ad 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyr github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708 h1:cfOi4TWvQ6JsAm9Q1A8I8j9YfNy10bmIfwOiyGyU5wQ= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= +github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E= +github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= github.com/go-skynet/go-llama.cpp v0.0.0-20230508165257-c03e8adbc45c h1:JoW2+LKrSemoV32QRwrEC5f53erym96NCsUSM3wSVbM= diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 167d5d7eec13..d6c799a7fc74 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -27,8 +27,10 @@ type ModelLoader struct { gptmodels map[string]*gptj.GPTJ gpt2models map[string]*gpt2.GPT2 gptstablelmmodels map[string]*gpt2.StableLM - rwkv map[string]*rwkv.RwkvState - promptsTemplates map[string]*template.Template + dollymodels map[string]*gpt2.Dolly + + rwkv map[string]*rwkv.RwkvState + promptsTemplates map[string]*template.Template } func NewModelLoader(modelPath string) *ModelLoader { @@ -37,9 +39,11 @@ func NewModelLoader(modelPath string) *ModelLoader { gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), gptstablelmmodels: make(map[string]*gpt2.StableLM), - models: make(map[string]*llama.LLama), - rwkv: make(map[string]*rwkv.RwkvState), - promptsTemplates: make(map[string]*template.Template), + dollymodels: make(map[string]*gpt2.Dolly), + + models: make(map[string]*llama.LLama), + rwkv: make(map[string]*rwkv.RwkvState), + promptsTemplates: make(map[string]*template.Template), } } @@ -124,6 +128,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error { return nil } +func (ml *ModelLoader) LoadDollyModel(modelName string) (*gpt2.Dolly, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + // Check if we already have a loaded model + if !ml.ExistsInModelPath(modelName) { + return nil, fmt.Errorf("model does not exist") + } + + if m, ok := ml.dollymodels[modelName]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", modelName) + return m, nil + } + + // Load the model and keep it in memory for later use + modelFile := filepath.Join(ml.ModelPath, modelName) + log.Debug().Msgf("Loading model in memory from file: %s", modelFile) + + model, err := gpt2.NewDolly(modelFile) + if err != nil { + return nil, err + } + + // If there is a prompt template, load it + if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil { + return nil, err + } + + ml.dollymodels[modelName] = model + return model, err +} + func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) { ml.mu.Lock() defer ml.mu.Unlock() @@ -295,6 +331,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla return ml.LoadLLaMAModel(modelFile, llamaOpts...) case "stablelm": return ml.LoadStableLMModel(modelFile) + case "dolly": + return ml.LoadDollyModel(modelFile) case "gpt2": return ml.LoadGPT2Model(modelFile) case "gptj": @@ -353,6 +391,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt err = multierror.Append(err, modelerr) } + model, modelerr = ml.LoadDollyModel(modelFile) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads) if modelerr == nil { updateModels(model) From b99ca631c60edf8e5df5a4fa154ac238fe5aaf20 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 10 May 2023 16:58:09 +0200 Subject: [PATCH 2/4] Update Makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index a9c9fbd257d9..03ece6838e09 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ BINARY_NAME=local-ai GOLLAMA_VERSION?=c03e8adbc45c866e0f6d876af1887d6b01d57eb4 GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109 -GOGPT2_VERSION?=d498232 +GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=af62fcc432be2847acb6e0688b2c2491d6588d58 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 From 6273f783a63d22b4841425e20a8c8fd7a41177f8 Mon Sep 17 00:00:00 2001 From: mudler Date: Wed, 10 May 2023 18:30:22 +0200 Subject: [PATCH 3/4] Add redpajama --- Makefile | 2 +- pkg/model/loader.go | 48 ++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 03ece6838e09..a5af2fbda080 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ BINARY_NAME=local-ai GOLLAMA_VERSION?=c03e8adbc45c866e0f6d876af1887d6b01d57eb4 GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109 -GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa +GOGPT2_VERSION?=abf038a7d8efa4eefdc7c891f05ad33d4e59e49d RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=af62fcc432be2847acb6e0688b2c2491d6588d58 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 7e92d6906f51..da21ddc293a5 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -29,9 +29,10 @@ type ModelLoader struct { gpt2models map[string]*gpt2.GPT2 gptstablelmmodels map[string]*gpt2.StableLM dollymodels map[string]*gpt2.Dolly + redpajama map[string]*gpt2.RedPajama rwkv map[string]*rwkv.RwkvState bert map[string]*bert.Bert - promptsTemplates map[string]*template.Template + promptsTemplates map[string]*template.Template } func NewModelLoader(modelPath string) *ModelLoader { @@ -41,8 +42,7 @@ func NewModelLoader(modelPath string) *ModelLoader { gptmodels: make(map[string]*gptj.GPTJ), gptstablelmmodels: make(map[string]*gpt2.StableLM), dollymodels: make(map[string]*gpt2.Dolly), - - + redpajama: make(map[string]*gpt2.RedPajama), models: make(map[string]*llama.LLama), rwkv: make(map[string]*rwkv.RwkvState), bert: make(map[string]*bert.Bert), @@ -131,6 +131,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error { return nil } +func (ml *ModelLoader) LoadRedPajama(modelName string) (*gpt2.RedPajama, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + // Check if we already have a loaded model + if !ml.ExistsInModelPath(modelName) { + return nil, fmt.Errorf("model does not exist") + } + + if m, ok := ml.redpajama[modelName]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", modelName) + return m, nil + } + + // Load the model and keep it in memory for later use + modelFile := filepath.Join(ml.ModelPath, modelName) + log.Debug().Msgf("Loading model in memory from file: %s", modelFile) + + model, err := gpt2.NewRedPajama(modelFile) + if err != nil { + return nil, err + } + + // If there is a prompt template, load it + if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil { + return nil, err + } + + ml.redpajama[modelName] = model + return model, err +} + func (ml *ModelLoader) LoadDollyModel(modelName string) (*gpt2.Dolly, error) { ml.mu.Lock() defer ml.mu.Unlock() @@ -368,6 +400,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla return ml.LoadStableLMModel(modelFile) case "dolly": return ml.LoadDollyModel(modelFile) + case "redpajama": + return ml.LoadRedPajama(modelFile) case "gpt2": return ml.LoadGPT2Model(modelFile) case "gptj": @@ -436,6 +470,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt err = multierror.Append(err, modelerr) } + model, modelerr = ml.LoadRedPajama(modelFile) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads) if modelerr == nil { updateModels(model) From 60b7feeda7d77199b4673e391a253a8c548247f1 Mon Sep 17 00:00:00 2001 From: mudler Date: Thu, 11 May 2023 00:29:04 +0200 Subject: [PATCH 4/4] Add bloomz --- .github/workflows/bump_deps.yaml | 3 ++ Makefile | 24 +++++++++++--- api/api_test.go | 2 +- api/prediction.go | 45 ++++++++++++++++++++++++++ go.mod | 3 +- go.sum | 15 ++------- pkg/model/loader.go | 55 +++++++++++++++++++++++++++++--- 7 files changed, 124 insertions(+), 23 deletions(-) diff --git a/.github/workflows/bump_deps.yaml b/.github/workflows/bump_deps.yaml index c889fab92495..6aa7aa442ed9 100644 --- a/.github/workflows/bump_deps.yaml +++ b/.github/workflows/bump_deps.yaml @@ -27,6 +27,9 @@ jobs: - repository: "go-skynet/go-bert.cpp" variable: "BERT_VERSION" branch: "master" + - repository: "go-skynet/bloomz.cpp" + variable: "BLOOMZ_VERSION" + branch: "main" runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/Makefile b/Makefile index a5af2fbda080..e043c4569eaf 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=af62fcc432be2847acb6e0688b2c2491d6588d58 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13 +BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1 GREEN := $(shell tput -Txterm setaf 2) @@ -18,8 +19,8 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert -LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert +C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz # Use this if you want to set the default behavior ifndef BUILD_TYPE @@ -69,6 +70,18 @@ go-rwkv: go-rwkv/librwkv.a: go-rwkv cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. && cp ggml/src/libggml.a .. +## bloomz +bloomz: + git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz + @find ./bloomz -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + + @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + + @find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + + @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} + + @find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} + + +bloomz/libbloomz.a: bloomz + cd bloomz && make libbloomz.a + go-bert/libgobert.a: go-bert $(MAKE) -C go-bert libgobert.a @@ -111,8 +124,9 @@ replace: $(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv $(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(shell pwd)/whisper.cpp $(GOCMD) mod edit -replace github.com/go-skynet/go-bert.cpp=$(shell pwd)/go-bert + $(GOCMD) mod edit -replace github.com/go-skynet/bloomz.cpp=$(shell pwd)/bloomz -prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv whisper.cpp go-bert +prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv whisper.cpp go-bert bloomz $(GOCMD) mod download ## GENERIC @@ -123,9 +137,10 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-rwkv clean $(MAKE) -C whisper.cpp clean $(MAKE) -C go-bert clean + $(MAKE) -C bloomz clean $(MAKE) build -prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-bert/libgobert.a go-gpt2/libgpt2.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a replace ## Prepares for building +prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-bert/libgobert.a go-gpt2/libgpt2.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a replace ## Prepares for building clean: ## Remove build related file rm -fr ./go-llama @@ -133,6 +148,7 @@ clean: ## Remove build related file rm -rf ./go-gpt2 rm -rf ./go-rwkv rm -rf ./go-bert + rm -rf ./bloomz rm -rf $(BINARY_NAME) ## Build: diff --git a/api/api_test.go b/api/api_test.go index 0a581912a446..639f18d9937d 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -79,7 +79,7 @@ var _ = Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 6 errors occurred:")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 9 errors occurred:")) }) }) diff --git a/api/prediction.go b/api/prediction.go index ce5743d23c4d..b705f6635b20 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -8,6 +8,7 @@ import ( "github.com/donomii/go-rwkv.cpp" model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" gpt2 "github.com/go-skynet/go-gpt2.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp" @@ -198,6 +199,50 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback return response, nil } + case *gpt2.RedPajama: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []gpt2.PredictOption{ + gpt2.SetTemperature(c.Temperature), + gpt2.SetTopP(c.TopP), + gpt2.SetTopK(c.TopK), + gpt2.SetTokens(c.Maxtokens), + gpt2.SetThreads(c.Threads), + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch)) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed)) + } + + return model.Predict( + s, + predictOptions..., + ) + } + case *bloomz.Bloomz: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(c.Temperature), + bloomz.SetTopP(c.TopP), + bloomz.SetTopK(c.TopK), + bloomz.SetTokens(c.Maxtokens), + bloomz.SetThreads(c.Threads), + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) + } + + return model.Predict( + s, + predictOptions..., + ) + } case *gpt2.StableLM: fn = func() (string, error) { // Generate the prediction using the language model diff --git a/go.mod b/go.mod index 40b5e5dc66df..7032eca5054c 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,9 @@ require ( github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35 github.com/go-audio/wav v1.1.0 - github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 + github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea + github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3 github.com/gofiber/fiber/v2 v2.45.0 diff --git a/go.sum b/go.sum index ea89947206b7..3ae82948c1e5 100644 --- a/go.sum +++ b/go.sum @@ -36,16 +36,8 @@ github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7 github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708 h1:cfOi4TWvQ6JsAm9Q1A8I8j9YfNy10bmIfwOiyGyU5wQ= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230422085954-245a5bfe6708/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230508165257-c03e8adbc45c h1:JoW2+LKrSemoV32QRwrEC5f53erym96NCsUSM3wSVbM= -github.com/go-skynet/go-llama.cpp v0.0.0-20230508165257-c03e8adbc45c/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I= -github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3 h1:YNi1oetK5kGJoUgT3/r/Wj3XPOICWf3nwHsz5v89iSs= -github.com/go-skynet/go-llama.cpp v0.0.0-20230509080828-f4d26f43f1d3/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I= +github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f h1:GW8RQa1RVeDF1dOuAP/y6xWVC+BRtf9tJOuEza6Asbg= +github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -197,9 +189,8 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= -gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/pkg/model/loader.go b/pkg/model/loader.go index da21ddc293a5..3679a4615cb3 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -10,14 +10,14 @@ import ( "sync" "text/template" - "github.com/hashicorp/go-multierror" - "github.com/rs/zerolog/log" - rwkv "github.com/donomii/go-rwkv.cpp" + bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" gpt2 "github.com/go-skynet/go-gpt2.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp" llama "github.com/go-skynet/go-llama.cpp" + "github.com/hashicorp/go-multierror" + "github.com/rs/zerolog/log" ) type ModelLoader struct { @@ -31,8 +31,10 @@ type ModelLoader struct { dollymodels map[string]*gpt2.Dolly redpajama map[string]*gpt2.RedPajama rwkv map[string]*rwkv.RwkvState - bert map[string]*bert.Bert - promptsTemplates map[string]*template.Template + bloomz map[string]*bloomz.Bloomz + + bert map[string]*bert.Bert + promptsTemplates map[string]*template.Template } func NewModelLoader(modelPath string) *ModelLoader { @@ -45,6 +47,7 @@ func NewModelLoader(modelPath string) *ModelLoader { redpajama: make(map[string]*gpt2.RedPajama), models: make(map[string]*llama.LLama), rwkv: make(map[string]*rwkv.RwkvState), + bloomz: make(map[string]*bloomz.Bloomz), bert: make(map[string]*bert.Bert), promptsTemplates: make(map[string]*template.Template), } @@ -259,6 +262,38 @@ func (ml *ModelLoader) LoadBERT(modelName string) (*bert.Bert, error) { return model, err } +func (ml *ModelLoader) LoadBloomz(modelName string) (*bloomz.Bloomz, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + // Check if we already have a loaded model + if !ml.ExistsInModelPath(modelName) { + return nil, fmt.Errorf("model does not exist") + } + + if m, ok := ml.bloomz[modelName]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", modelName) + return m, nil + } + + // Load the model and keep it in memory for later use + modelFile := filepath.Join(ml.ModelPath, modelName) + log.Debug().Msgf("Loading model in memory from file: %s", modelFile) + + model, err := bloomz.New(modelFile) + if err != nil { + return nil, err + } + + // If there is a prompt template, load it + if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil { + return nil, err + } + + ml.bloomz[modelName] = model + return model, err +} + func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) { ml.mu.Lock() defer ml.mu.Unlock() @@ -396,6 +431,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla switch strings.ToLower(backendString) { case "llama": return ml.LoadLLaMAModel(modelFile, llamaOpts...) + case "bloomz": + return ml.LoadBloomz(modelFile) case "stablelm": return ml.LoadStableLMModel(modelFile) case "dolly": @@ -478,6 +515,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt err = multierror.Append(err, modelerr) } + model, modelerr = ml.LoadBloomz(modelFile) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + model, modelerr = ml.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads) if modelerr == nil { updateModels(model)