55 changes: 35 additions & 20 deletions Makefile
@@ -9,14 +9,18 @@ GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
# renovate: datasource=git-refs packageNameTemplate=https://github.com/go-skynet/go-gpt2.cpp currentValueTemplate=master depNameTemplate=go-gpt2.cpp
GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa

# here until https://github.com/donomii/go-rwkv.cpp/pull/1 is merged
RWKV_REPO?=https://github.com/mudler/go-rwkv.cpp
RWKV_VERSION?=6ba15255b03016b5ecce36529b500d21815399a7

GREEN := $(shell tput -Txterm setaf 2)
YELLOW := $(shell tput -Txterm setaf 3)
WHITE := $(shell tput -Txterm setaf 7)
CYAN := $(shell tput -Txterm setaf 6)
RESET := $(shell tput -Txterm sgr0)

C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2
LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2
C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv

# Use this if you want to set the default behavior
ifndef BUILD_TYPE
@@ -33,16 +37,6 @@ endif

all: help

## Build:

build: prepare ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./

generic-build: ## Build the project using generic
BUILD_TYPE="generic" $(MAKE) build

## GPT4ALL-J
go-gpt4all-j:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
@@ -57,11 +51,19 @@ go-gpt4all-j:
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +

## RWKV
go-rwkv:
git clone --recurse-submodules $(RWKV_REPO) go-rwkv
cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1

go-rwkv/librwkv.a: go-rwkv
cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. && cp ggml/src/libggml.a ..

go-gpt4all-j/libgptj.a: go-gpt4all-j
$(MAKE) -C go-gpt4all-j $(GENERIC_PREFIX)libgptj.a

# CEREBRAS GPT
go-gpt2:
## CEREBRAS GPT
go-gpt2:
git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2
cd go-gpt2 && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
# This is hackish, but needed as both go-llama and go-gpt4all-j have their own version of ggml...
@@ -74,7 +76,6 @@ go-gpt2:

go-gpt2/libgpt2.a: go-gpt2
$(MAKE) -C go-gpt2 $(GENERIC_PREFIX)libgpt2.a


go-llama:
git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
@@ -86,26 +87,40 @@ replace:
$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2
$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv

prepare-sources: go-llama go-gpt2 go-gpt4all-j
prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv
$(GOCMD) mod download

rebuild:
## GENERIC
rebuild: ## Rebuilds the project
$(MAKE) -C go-llama clean
$(MAKE) -C go-gpt4all-j clean
$(MAKE) -C go-gpt2 clean
$(MAKE) -C go-rwkv clean
$(MAKE) build

prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a replace
prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a go-rwkv/librwkv.a replace ## Prepares for building

clean: ## Remove build-related files
rm -fr ./go-llama
rm -rf ./go-gpt4all-j
rm -rf ./go-gpt2
rm -rf ./go-rwkv
rm -rf $(BINARY_NAME)

## Run:
run: prepare
## Build:

build: prepare ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./

generic-build: ## Build the project using generic
BUILD_TYPE="generic" $(MAKE) build

## Run
run: prepare ## Run local-ai
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./main.go

test-models/testmodel:
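A note on the C_INCLUDE_PATH and LIBRARY_PATH additions above: cgo invokes the system C toolchain, and gcc/clang consult both variables when resolving headers and static archives, which is how the freshly built librwkv.a becomes visible to the Go build. Below is a minimal sketch of the kind of cgo preamble a binding such as go-rwkv.cpp relies on; the exact directives are an assumption, not taken from this PR.

```go
package rwkv

// Hypothetical cgo preamble (the real go-rwkv.cpp may differ): rwkv.h and
// librwkv.a are resolved via C_INCLUDE_PATH / LIBRARY_PATH, which the
// Makefile extends with $(shell pwd)/go-rwkv.

/*
#cgo LDFLAGS: -lrwkv -lggml -lm -lstdc++
#include "rwkv.h"
*/
import "C"
```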
10 changes: 9 additions & 1 deletion README.md
@@ -15,7 +15,7 @@
- Supports multiple models
- Once loaded the first time, it keeps models loaded in memory for faster inference
- Support for prompt templates
- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).
- Doesn't shell out, but uses C bindings for faster inference and better performance.

LocalAI is a community-driven project, focused on making AI accessible to anyone. Contributions, feedback and PRs are welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).

@@ -39,6 +39,7 @@ Tested with:
- [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin)
- Koala
- [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml)
- [RWKV](https://github.com/BlinkDL/RWKV-LM) with [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp)

It should also be compatible with StableLM and GPTNeoX ggml models (untested)

@@ -506,6 +507,13 @@ LocalAI is a community-driven project. It was initially created by [mudler](http

MIT

## Golang bindings used

- [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
- [go-skynet/go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp)
- [go-skynet/go-gpt2.cpp](https://github.com/go-skynet/go-gpt2.cpp)
- [donomii/go-rwkv.cpp](https://github.com/donomii/go-rwkv.cpp)

## Acknowledgements

- [llama.cpp](https://github.com/ggerganov/llama.cpp)
2 changes: 1 addition & 1 deletion api/api_test.go
@@ -79,7 +79,7 @@ var _ = Describe("API test", func() {
It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:"))
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:"))
})

})
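For context: the expected message changes from "4 errors occurred" to "5 errors occurred" because the greedy loader now attempts one additional backend. When no backend can load the bogus model, the RWKV attempt contributes a fifth error to the accumulated multierror.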
35 changes: 31 additions & 4 deletions api/prediction.go
@@ -6,21 +6,24 @@ import (
"strings"
"sync"

"github.com/donomii/go-rwkv.cpp"
model "github.com/go-skynet/LocalAI/pkg/model"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/hashicorp/go-multierror"
)

const tokenizerSuffix = ".tokenizer.json"

// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutexMap sync.Mutex
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)

var loadedModels map[string]interface{} = map[string]interface{}{}
var muModels sync.Mutex

func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
switch strings.ToLower(backendString) {
case "llama":
return loader.LoadLLaMAModel(modelFile, llamaOpts...)
@@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st
return loader.LoadGPT2Model(modelFile)
case "gptj":
return loader.LoadGPTJModel(modelFile)
case "rwkv":
return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
default:
return nil, fmt.Errorf("backend unsupported: %s", backendString)
}
}

func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
updateModels := func(model interface{}) {
muModels.Lock()
defer muModels.Unlock()
@@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama
err = multierror.Append(err, modelerr)
}

model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
if modelerr == nil {
updateModels(model)
return model, nil
} else {
err = multierror.Append(err, modelerr)
}

return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
}

@@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
var inferenceModel interface{}
var err error
if c.Backend == "" {
inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads))
} else {
inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads))
}
if err != nil {
return nil, err
@@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
var fn func() (string, error)

switch model := inferenceModel.(type) {
case *rwkv.RwkvState:
supportStreams = true

fn = func() (string, error) {
//model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah")
stopWord := "\n"
if len(c.StopWords) > 0 {
stopWord = c.StopWords[0]
}

response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)

return response, nil
}
case *gpt2.StableLM:
fn = func() (string, error) {
// Generate the prediction using the language model
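Condensed from the *rwkv.RwkvState branch above, a hedged sketch of the RWKV generation path: generation stops at the first configured stop word (defaulting to a newline), and partial output is streamed through the token callback. The parameter names and the callback signature are assumptions inferred from the diff.

```go
package inference

import (
	rwkv "github.com/donomii/go-rwkv.cpp"
)

// generateRWKV mirrors the new inference case: pick a stop word, then
// delegate to the binding's GenerateResponse, which calls onToken for
// each generated token so callers can stream partial output.
func generateRWKV(m *rwkv.RwkvState, maxTokens int, stopWords []string,
	temperature, topP float64, onToken func(string) bool) string {
	stopWord := "\n" // same default the new branch uses
	if len(stopWords) > 0 {
		stopWord = stopWords[0]
	}
	return m.GenerateResponse(maxTokens, stopWord, float32(temperature), float32(topP), onToken)
}
```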
1 change: 1 addition & 0 deletions go.mod
@@ -23,6 +23,7 @@ require (
github.com/StackExchange/wmi v1.2.1 // indirect
github.com/andybalholm/brotli v1.0.5 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/go-logr/logr v1.2.3 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
2 changes: 2 additions & 0 deletions go.sum
@@ -12,6 +12,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d h1:lSHwlYf1H4WAWYgf7rjEVTGen1qmigUq2Egpu8mnQiY=
github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d/go.mod h1:H6QBF7/Tz6DAEBDXQged4H1BvsmqY/K5FG9wQRGa01g=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
36 changes: 34 additions & 2 deletions pkg/model/loader.go
@@ -12,6 +12,7 @@ import (

"github.com/rs/zerolog/log"

rwkv "github.com/donomii/go-rwkv.cpp"
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
@@ -25,8 +26,8 @@ type ModelLoader struct {
gptmodels map[string]*gptj.GPTJ
gpt2models map[string]*gpt2.GPT2
gptstablelmmodels map[string]*gpt2.StableLM

promptsTemplates map[string]*template.Template
rwkv map[string]*rwkv.RwkvState
promptsTemplates map[string]*template.Template
}

func NewModelLoader(modelPath string) *ModelLoader {
@@ -36,6 +37,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
gptmodels: make(map[string]*gptj.GPTJ),
gptstablelmmodels: make(map[string]*gpt2.StableLM),
models: make(map[string]*llama.LLama),
rwkv: make(map[string]*rwkv.RwkvState),
promptsTemplates: make(map[string]*template.Template),
}
}
@@ -218,6 +220,36 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
return model, err
}

func (ml *ModelLoader) LoadRWKV(modelName, tokenFile string, threads uint32) (*rwkv.RwkvState, error) {
ml.mu.Lock()
defer ml.mu.Unlock()

log.Debug().Msgf("Loading model name: %s", modelName)

// Make sure the requested model actually exists in the model path
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}

if m, ok := ml.rwkv[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}

// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.ModelPath, modelName)
tokenPath := filepath.Join(ml.ModelPath, tokenFile)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)

model := rwkv.LoadFiles(modelFile, tokenPath, threads)
if model == nil {
return nil, fmt.Errorf("could not load model")
}

ml.rwkv[modelName] = model
return model, nil
}

func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
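Putting the loader pieces together, a hedged usage sketch of the new entry point. The file names are illustrative, and the tokenizer argument follows the modelFile + ".tokenizer.json" convention the greedy loader applies.

```go
package main

import (
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/rs/zerolog/log"
)

func main() {
	// Both file names are resolved relative to the loader's model path.
	loader := model.NewModelLoader("./models")
	m, err := loader.LoadRWKV("rwkv-raven.bin", "rwkv-raven.bin.tokenizer.json", 4)
	if err != nil {
		log.Fatal().Err(err).Msg("could not load RWKV model")
	}
	_ = m // *rwkv.RwkvState, ready for GenerateResponse
}
```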