4 changes: 2 additions & 2 deletions .github/workflows/image.yml
@@ -54,8 +54,8 @@ jobs:
uses: docker/login-action@v2
with:
registry: quay.io
username: ${{ secrets.QUAY_USERNAME }}
password: ${{ secrets.QUAY_PASSWORD }}
username: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
password: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
- name: Build
if: github.event_name != 'pull_request'
uses: docker/build-push-action@v4
79 changes: 45 additions & 34 deletions api/openai.go
@@ -2,6 +2,7 @@ package api

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"os"
@@ -245,7 +246,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,

result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
})
}, nil)
if err != nil {
return err
}
@@ -290,8 +291,9 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread

if input.Stream {
log.Debug().Msgf("Stream request received")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
c.Set("Content-Type", "text/event-stream; charset=utf-8")
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
@@ -312,53 +314,62 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
log.Debug().Msgf("Template found, input modified to: %s", predInput)
}

result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
if input.Stream {
*c = append(*c, Choice{Delta: &Message{Role: "assistant", Content: s}})
} else {
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
}
})
if err != nil {
return err
}

resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
}

if input.Stream {
resp.Object = "chat.completion.chunk"
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
log.Debug().Msgf("Handling stream request")
responses := make(chan OpenAIResponse)

go func() {
ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
resp := OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}},
Object: "chat.completion.chunk",
}

responses <- resp
return true
})
close(responses)
}()

c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "event: data\n")
w.Flush()

fmt.Fprintf(w, "data: %s\n\n", jsonResult)
w.Flush()
for ev := range responses {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)

fmt.Fprintf(w, "event: data\n")
w.Flush()
fmt.Fprintf(w, "event: data\n\n")
fmt.Fprintf(w, "data: %v\n\n", buf.String())
log.Debug().Msgf("Sending chunk: %s", buf.String())
w.Flush()
}

w.WriteString("event: data\n\n")
resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []Choice{{FinishReason: "stop"}},
}
respData, _ := json.Marshal(resp)

fmt.Fprintf(w, "data: %s\n\n", respData)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.Flush()

// fmt.Fprintf(w, "data: [DONE]\n\n")
// w.Flush()
}))
return nil
}

result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
}, nil)
if err != nil {
return err
}

resp := &OpenAIResponse{
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: result,
Object: "chat.completion",
}

// Return the prediction in the response body
return c.JSON(resp)
}
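The streaming branch above now emits server-sent events: one chat.completion.chunk object per generated token, pushed through the responses channel, followed by a final chunk whose choice carries FinishReason "stop". The sketch below shows one way a Go client could consume that stream; the base URL, model name, and request payload are illustrative assumptions, not part of this diff.

// Minimal sketch of a client reading the SSE stream produced by chatEndpoint.
// The endpoint URL and request body below are assumptions for illustration.
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	payload := `{"model":"ggml-model","messages":[{"role":"user","content":"Hello"}],"stream":true}`
	req, err := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", strings.NewReader(payload))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// The handler writes "event: data" and "data: {json}" lines; print each data payload.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data: ") {
			fmt.Println(strings.TrimPrefix(line, "data: "))
		}
	}
}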
@@ -392,7 +403,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread

result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s})
})
}, nil)
if err != nil {
return err
}
20 changes: 15 additions & 5 deletions api/prediction.go
@@ -16,12 +16,12 @@ import (
var mutexMap sync.Mutex
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)

func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) {
var model *llama.LLama
var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2
var stableLMModel *gpt2.StableLM

supportStreams := false
modelFile := c.Model

// Try to load the model
@@ -125,7 +125,13 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
)
}
case model != nil:
supportStreams = true
fn = func() (string, error) {

if tokenCallback != nil {
model.SetTokenCallback(tokenCallback)
}

// Generate the prediction using the language model
predictOptions := []llama.PredictOption{
llama.SetTemperature(c.Temperature),
@@ -185,11 +191,15 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
l.Lock()
defer l.Unlock()

return fn()
res, err := fn()
if tokenCallback != nil && !supportStreams {
tokenCallback(res)
}
return res, err
}, nil
}

func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
result := []Choice{}

n := input.N
@@ -199,7 +209,7 @@ func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, load
}

// get the model function to call for the result
predFunc, err := ModelInference(predInput, loader, *config)
predFunc, err := ModelInference(predInput, loader, *config, tokenCallback)
if err != nil {
return result, err
}
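ComputeChoices now forwards an optional tokenCallback into ModelInference. Backends that support streaming (the llama case sets supportStreams = true) invoke the callback per token; for the others, the returned wrapper calls it once with the complete result, so callers always receive at least one callback. Below is a standalone sketch of that fallback pattern using placeholder functions rather than the real backends; wrapPrediction and predict are hypothetical names, not part of the LocalAI API.

// Standalone sketch of the callback fallback used in ModelInference above.
package main

import "fmt"

func wrapPrediction(fn func() (string, error), supportStreams bool, tokenCallback func(string) bool) func() (string, error) {
	return func() (string, error) {
		res, err := fn()
		// If the backend cannot stream, hand the full result to the callback in one call.
		if tokenCallback != nil && !supportStreams {
			tokenCallback(res)
		}
		return res, err
	}
}

func main() {
	predict := func() (string, error) { return "hello world", nil }

	wrapped := wrapPrediction(predict, false, func(s string) bool {
		fmt.Println("callback received:", s) // fires once with the whole prediction
		return true
	})

	if _, err := wrapped(); err != nil {
		fmt.Println("error:", err)
	}
}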