Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ type Functions struct {
DisableNoAction bool `yaml:"disable_no_action"`
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionDescriptionName string `yaml:"no_action_description_name"`
ParallelCalls bool `yaml:"parallel_calls"`
}

type TemplateConfig struct {
Expand Down
241 changes: 143 additions & 98 deletions api/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return true
})

ss := map[string]interface{}{}
name, args := parseFunctionCall(result)
ss["name"], ss["arguments"] = name, args
results := parseFunctionCall(result, config.FunctionsConfig.ParallelCalls)
noActionToRun := len(results) > 0 && results[0].name == noAction

if name == noAction {
switch {
case noActionToRun:
initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Expand All @@ -78,7 +78,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}
responses <- initialMessage

result, err := handleQuestion(config, req, o, args, prompt)
result, err := handleQuestion(config, req, o, results[0].arguments, prompt)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
Expand All @@ -98,52 +98,56 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}

responses <- resp
close(responses)
return
}

initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: 0,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
default:
for i, ss := range results {
name, args := ss.name, ss.arguments

initialMessage := schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
},
},
},
},
},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage

responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: 0,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
responses <- schema.OpenAIResponse{
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{
Delta: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
Index: i,
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Arguments: args,
},
},
},
},
},
}}},
Object: "chat.completion.chunk",
}}},
Object: "chat.completion.chunk",
}
}
}

close(responses)
}

Expand Down Expand Up @@ -208,9 +212,9 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)

// Update input grammar
jsStruct := funcs.ToJSONStructure()
config.Grammar = jsStruct.Grammar("")
config.Grammar = jsStruct.Grammar("", config.FunctionsConfig.ParallelCalls)
} else if input.JSONFunctionGrammarObject != nil {
config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
config.Grammar = input.JSONFunctionGrammarObject.Grammar("", config.FunctionsConfig.ParallelCalls)
}

// functions are not supported in stream mode (yet?)
Expand Down Expand Up @@ -407,57 +411,74 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
}))
return nil

// no streaming mode
default:
result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) {
if processFunctions {
ss := map[string]interface{}{}

name, args := parseFunctionCall(s)
ss["name"], ss["arguments"] = name, args

// if do nothing, reply with a message
if name == noActionName {
result, err := handleQuestion(config, input, o, args, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &result}})
} else {
if !processFunctions {
// no function is called, just reply and use stop as finish reason
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}

results := parseFunctionCall(s, config.FunctionsConfig.ParallelCalls)
noActionsToRun := len(results) > 0 && results[0].name == noActionName

switch {
case noActionsToRun:
result, err := handleQuestion(config, input, o, results[0].arguments, predInput)
if err != nil {
log.Error().Msgf("error handling question: %s", err.Error())
return
}
*c = append(*c, schema.Choice{
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolChoice := schema.Choice{
Message: &schema.Message{
Role: "assistant",
},
}

if len(input.Tools) > 0 {
toolChoice.FinishReason = "tool_calls"
}

for _, ss := range results {
name, args := ss.name, ss.arguments
if len(input.Tools) > 0 {
// Result is different in the case we have a tool call
*c = append(*c, schema.Choice{
FinishReason: "tool_calls",
Message: &schema.Message{
Role: "assistant",
ToolCalls: []schema.ToolCall{
{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
// If we are using tools, we condense the function calls into
// a single response choice with all the tools
toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls,
schema.ToolCall{
ID: id,
Type: "function",
FunctionCall: schema.FunctionCall{
Name: name,
Arguments: args,
},
},
})
)
} else {
// otherwise reply with the function call
// otherwise we return more choices directly
*c = append(*c, schema.Choice{
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
FunctionCall: ss,
Role: "assistant",
FunctionCall: map[string]interface{}{
"name": name,
"arguments": args,
},
},
})
}
}

return
if len(input.Tools) > 0 {
// we need to append our result if we are using tools
*c = append(*c, toolChoice)
}
}

*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
}, nil)
if err != nil {
return err
Expand Down Expand Up @@ -528,19 +549,43 @@ func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *optio
return backend.Finetune(*config, prompt, prediction.Response), nil
}

func parseFunctionCall(llmresult string) (string, string) {
// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
ss := map[string]interface{}{}
// This prevent newlines to break JSON parsing for clients
s := utils.EscapeNewLines(llmresult)
json.Unmarshal([]byte(s), &ss)
log.Debug().Msgf("Function return: %s %+v", s, ss)

// The grammar defines the function name as "function", while OpenAI returns "name"
func_name := ss["function"]
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
d, _ := json.Marshal(args)

return func_name.(string), string(d)
// funcCallResults is one function call extracted from the LLM output by
// parseFunctionCall.
type funcCallResults struct {
// name is the function name, read from the "function" key of the LLM JSON
// output.
name string
// arguments is the value of the "arguments" key, re-marshaled into a JSON
// string.
arguments string
}

// parseFunctionCall extracts the function calls contained in the raw LLM
// output.
//
// llmresult is the JSON text produced by the model. When multipleResults is
// true the payload may be either a JSON array of calls or a single call
// object (the grammar generated for parallel calls accepts both); when it is
// false a single call object is expected.
//
// Each call object is expected to carry the function name under the
// "function" key and its parameters under the "arguments" key. The returned
// arguments are the JSON-encoded form of the "arguments" value, as OpenAI
// clients expect a stringified object.
//
// The returned slice is never nil. Output that cannot be decoded, and entries
// without a string "function" field, are logged and skipped instead of
// causing a panic.
func parseFunctionCall(llmresult string, multipleResults bool) []funcCallResults {
	results := []funcCallResults{}

	// As we have to change the result before processing, we can't stream the answer token-by-token (yet?)
	// This prevents newlines from breaking JSON parsing for clients.
	s := utils.EscapeNewLines(llmresult)

	var calls []map[string]interface{}
	decoded := false
	if multipleResults {
		if err := json.Unmarshal([]byte(s), &calls); err == nil {
			decoded = true
		} else {
			// Drop whatever was partially decoded and retry as a single object.
			calls = nil
		}
	}

	if !decoded {
		single := map[string]interface{}{}
		if err := json.Unmarshal([]byte(s), &single); err != nil {
			log.Error().Msgf("unable to parse function call result %q: %s", s, err.Error())
			return results
		}
		calls = []map[string]interface{}{single}
	}
	log.Debug().Msgf("Function return: %s %+v", s, calls)

	for _, call := range calls {
		// The grammar defines the function name as "function", while OpenAI returns "name"
		name, ok := call["function"].(string)
		if !ok {
			log.Error().Msgf("function call result has no valid function name: %+v", call)
			continue
		}

		// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
		args, err := json.Marshal(call["arguments"])
		if err != nil {
			log.Error().Msgf("unable to encode arguments of function %q: %s", name, err.Error())
			continue
		}

		results = append(results, funcCallResults{name: name, arguments: string(args)})
	}

	return results
}
31 changes: 24 additions & 7 deletions pkg/grammar/json_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,28 @@ func (sc *JSONSchemaConverter) addRule(name, rule string) string {
return key
}

func (sc *JSONSchemaConverter) formatGrammar() string {
// array is the grammar rule "arr": a bracketed, possibly empty list of
// "realvalue" items separated by ",\n". finalizeGrammar appends it to the
// generated grammar when maybeArray is true.
const array = `arr ::=
"[\n" (
realvalue
(",\n" realvalue)*
)? "]"`

// finalizeGrammar renders the rules collected in sc.rules as newline-separated
// "name ::= rule" lines.
//
// When maybeArray is true the computed "root" rule is emitted under the name
// "realvalue", and a new "root" rule accepting either "arr" or "realvalue" is
// appended together with the array rule definition.
func (sc *JSONSchemaConverter) finalizeGrammar(maybeArray bool) string {
	var out []string

	for ruleName, body := range sc.rules {
		label := ruleName
		if maybeArray && ruleName == "root" {
			// The original root becomes the element type of the array.
			label = "realvalue"
		}
		out = append(out, fmt.Sprintf("%s ::= %s", label, body))
	}

	if !maybeArray {
		return strings.Join(out, "\n")
	}

	out = append(out, "root ::= arr | realvalue", array)
	return strings.Join(out, "\n")
}

Expand Down Expand Up @@ -234,15 +251,15 @@ func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[strin

return def
}
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}, maybeArray bool) string {
sc.visit(schema, "", schema)
return sc.formatGrammar()
return sc.finalizeGrammar(maybeArray)
}

func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte, maybeArray bool) string {
var schema map[string]interface{}
_ = json.Unmarshal(b, &schema)
return sc.Grammar(schema)
return sc.Grammar(schema, maybeArray)
}

func jsonString(v interface{}) string {
Expand Down Expand Up @@ -275,7 +292,7 @@ type JSONFunctionStructure struct {
Defs map[string]interface{} `json:"$defs,omitempty"`
}

func (j JSONFunctionStructure) Grammar(propOrder string) string {
func (j JSONFunctionStructure) Grammar(propOrder string, maybeArray bool) string {
dat, _ := json.Marshal(j)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat)
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat, maybeArray)
}
Loading