From c72808f18b7ea9811557b201e82e88a54b2abc61 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 17 Feb 2024 10:00:34 +0100 Subject: [PATCH] feat(tools): support Tool calls in the API (#1715) * feat(tools): support Tools in the API Co-authored-by: =?UTF-8?q?Stephan=20A=C3=9Fmus?= * feat(tools): support function streaming * Adhere to new return types when using tools instead of functions * Keep backward compatibility with function calling * Evaluate function names in chat templates * Disable recovery with --debug * Correctly stream out the entire result * Detect when llm chooses to reply and to not perform any action in SSE * Feedback from code review --------- Co-authored-by: =?UTF-8?q?Stephan=20A=C3=9Fmus?= --- api/api.go | 6 +- api/openai/chat.go | 329 ++++++++++++++++++++++++++++----------- api/openai/request.go | 15 ++ api/schema/openai.go | 21 +++ pkg/grammar/functions.go | 6 + pkg/model/loader.go | 1 + 6 files changed, 286 insertions(+), 92 deletions(-) diff --git a/api/api.go b/api/api.go index 7ec95f1b63a..946204d2b06 100644 --- a/api/api.go +++ b/api/api.go @@ -146,7 +146,11 @@ func App(opts ...options.AppOption) (*fiber.App, error) { } // Default middleware config - app.Use(recover.New()) + + if !options.Debug { + app.Use(recover.New()) + } + if options.Metrics != nil { app.Use(metrics.APIMiddleware(options.Metrics)) } diff --git a/api/openai/chat.go b/api/openai/chat.go index 819cd6b2d6c..68c3a291a1b 100644 --- a/api/openai/chat.go +++ b/api/openai/chat.go @@ -55,6 +55,98 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) }) close(responses) } + processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + result := "" + _, tokenUsage, _ := ComputeChoices(req, prompt, config, o, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { + result += s + // TODO: Change generated BNF grammar to be compliant with the schema so we can + // stream the result token by token here. + return true + }) + + ss := map[string]interface{}{} + name, args := parseFunctionCall(result) + ss["name"], ss["arguments"] = name, args + + if name == noAction { + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &emptyMessage}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + result, err := handleQuestion(config, req, o, args, prompt) + if err != nil { + log.Error().Msgf("error handling question: %s", err.Error()) + return + } + + resp := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}}, + Object: "chat.completion.chunk", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + + responses <- resp + close(responses) + return + } + + initialMessage := schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: 0, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + responses <- schema.OpenAIResponse{ + ID: id, + Created: created, + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []schema.Choice{{ + Delta: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + Index: 0, + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Arguments: args, + }, + }, + }, + }}}, + Object: "chat.completion.chunk", + } + close(responses) + } + return func(c *fiber.Ctx) error { processFunctions := false funcs := grammar.Functions{} @@ -122,7 +214,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } // functions are not supported in stream mode (yet?) - toStream := input.Stream && !processFunctions + toStream := input.Stream log.Debug().Msgf("Parameters: %+v", config) @@ -145,6 +237,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) } r := config.Roles[role] contentExists := i.Content != nil && i.StringContent != "" + // First attempt to populate content via a chat message specific template if config.TemplateConfig.ChatMessage != "" { chatMessageData := model.ChatMessageTemplateData{ @@ -152,6 +245,7 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) Role: r, RoleName: role, Content: i.StringContent, + FunctionName: i.Name, MessageIndex: messageIndex, } templatedChatMessage, err := o.Loader.EvaluateTemplateForChatMessage(config.TemplateConfig.ChatMessage, chatMessageData) @@ -254,17 +348,24 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) log.Debug().Msgf("Grammar: %+v", config.Grammar) } - if toStream { + switch { + case toStream: responses := make(chan schema.OpenAIResponse) - go process(predInput, input, config, o.Loader, responses) + if !processFunctions { + go process(predInput, input, config, o.Loader, responses) + } else { + go processTools(noActionName, predInput, input, config, o.Loader, responses) + } c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - usage := &schema.OpenAIUsage{} - + toolsCalled := false for ev := range responses { usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it + if len(ev.Choices[0].Delta.ToolCalls) > 0 { + toolsCalled = true + } var buf bytes.Buffer enc := json.NewEncoder(&buf) enc.Encode(ev) @@ -278,13 +379,20 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) w.Flush() } + finishReason := "stop" + if toolsCalled { + finishReason = "tool_calls" + } else if toolsCalled && len(input.Tools) == 0 { + finishReason = "function_call" + } + resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{ { - FinishReason: "stop", + FinishReason: finishReason, Index: 0, Delta: &schema.Message{Content: &emptyMessage}, }}, @@ -298,102 +406,141 @@ func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) w.Flush() })) return nil - } - result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) - ss := map[string]interface{}{} - // This prevent newlines to break JSON parsing for clients - s = utils.EscapeNewLines(s) - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == noActionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = backend.Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &message}}) - return - } - } - } + default: + result, tokenUsage, err := ComputeChoices(input, predInput, config, o, o.Loader, func(s string, c *[]schema.Choice) { + if processFunctions { + ss := map[string]interface{}{} - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - images := []string{} - for _, m := range input.Messages { - images = append(images, m.StringImages...) 
- } - predFunc, err := backend.ModelInference(input.Context, predInput, images, o.Loader, *config, o, nil) - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } + name, args := parseFunctionCall(s) + ss["name"], ss["arguments"] = name, args - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return + // if do nothing, reply with a message + if name == noActionName { + result, err := handleQuestion(config, input, o, args, predInput) + if err != nil { + log.Error().Msgf("error handling question: %s", err.Error()) + return + } + *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &result}}) + } else { + if len(input.Tools) > 0 { + // Result is different in the case we have a tool call + *c = append(*c, schema.Choice{ + FinishReason: "tool_calls", + Message: &schema.Message{ + Role: "assistant", + ToolCalls: []schema.ToolCall{ + { + ID: id, + Type: "function", + FunctionCall: schema.FunctionCall{ + Name: name, + Arguments: args, + }, + }, + }, + }, + }) + } else { + // otherwise reply with the function call + *c = append(*c, schema.Choice{ + FinishReason: "function_call", + Message: &schema.Message{ + Role: "assistant", + FunctionCall: ss, + }, + }) + } } - fineTunedResponse := backend.Finetune(*config, predInput, prediction.Response) - *c = append(*c, schema.Choice{Message: &schema.Message{Role: "assistant", Content: &fineTunedResponse}}) - } else { - // otherwise reply with the function call - *c = append(*c, schema.Choice{ - FinishReason: "function_call", - Message: &schema.Message{Role: "assistant", FunctionCall: ss}, - }) + return } - return + *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) + }, nil) + if err != nil { + return err } - *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return err + + resp := &schema.OpenAIResponse{ + ID: id, + Created: created, + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + Usage: schema.OpenAIUsage{ + PromptTokens: tokenUsage.Prompt, + CompletionTokens: tokenUsage.Completion, + TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, + }, + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) } - resp := &schema.OpenAIResponse{ - ID: id, - Created: created, - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: result, - Object: "chat.completion", - Usage: schema.OpenAIUsage{ - PromptTokens: tokenUsage.Prompt, - CompletionTokens: tokenUsage.Completion, - TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, - }, + } +} + +func handleQuestion(config *config.Config, input *schema.OpenAIRequest, o *options.Option, args, prompt string) (string, error) { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(args), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, prompt, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + return message, nil + } } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU/GPU) another computation + config.Grammar = "" + images := []string{} + for _, m := range input.Messages { + images = append(images, m.StringImages...) + } + + predFunc, err := backend.ModelInference(input.Context, prompt, images, o.Loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return "", err + } - // Return the prediction in the response body - return c.JSON(resp) + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return "", err } + return backend.Finetune(*config, prompt, prediction.Response), nil +} + +func parseFunctionCall(llmresult string) (string, string) { + // As we have to change the result before processing, we can't stream the answer token-by-token (yet?) 
+ ss := map[string]interface{}{} + // This prevent newlines to break JSON parsing for clients + s := utils.EscapeNewLines(llmresult) + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + return func_name.(string), string(d) } diff --git a/api/openai/request.go b/api/openai/request.go index 382a930e1c7..6a7a14e8502 100644 --- a/api/openai/request.go +++ b/api/openai/request.go @@ -13,6 +13,7 @@ import ( fiberContext "github.com/go-skynet/LocalAI/api/ctx" options "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/schema" + "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -136,6 +137,20 @@ func updateRequestConfig(config *config.Config, input *schema.OpenAIRequest) { } } + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + var toolChoice grammar.Tool + json.Unmarshal([]byte(input.ToolsChoice.(string)), &toolChoice) + input.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + // Decode each request's message content index := 0 for i, m := range input.Messages { diff --git a/api/schema/openai.go b/api/schema/openai.go index 6355ff63d5e..12a39b4284d 100644 --- a/api/schema/openai.go +++ b/api/schema/openai.go @@ -68,6 +68,10 @@ type ContentURL struct { type Message struct { // The message role Role string `json:"role,omitempty" yaml:"role"` + + // The message name (used for tools calls) + Name string `json:"name,omitempty" yaml:"name"` + // The message content Content interface{} `json:"content" yaml:"content"` @@ -76,6 +80,20 @@ type Message struct { // A result of a function call FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` + + ToolCalls []ToolCall `json:"tool_calls,omitempty" yaml:"tool_call,omitempty"` +} + +type ToolCall struct { + Index int `json:"index"` + ID string `json:"id"` + Type string `json:"type"` + FunctionCall FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + Arguments string `json:"arguments"` } type OpenAIModel struct { @@ -117,6 +135,9 @@ type OpenAIRequest struct { Functions []grammar.Function `json:"functions" yaml:"functions"` FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + Tools []grammar.Tool `json:"tools,omitempty" yaml:"tools"` + ToolsChoice interface{} `json:"tool_choice,omitempty" yaml:"tool_choice"` + Stream bool `json:"stream"` // Image (not supported by OpenAI) diff --git a/pkg/grammar/functions.go b/pkg/grammar/functions.go index ef56662b7b9..1038f5e6f14 100644 --- a/pkg/grammar/functions.go +++ b/pkg/grammar/functions.go @@ -11,6 +11,12 @@ type Function struct { } type Functions []Function +type Tool struct { + Type string `json:"type"` + Function Function `json:"function,omitempty"` +} +type Tools []Tool + func (f Functions) ToJSONStructure() JSONFunctionStructure { js := JSONFunctionStructure{} for _, function := range f { diff --git 
a/pkg/model/loader.go b/pkg/model/loader.go index 37c2a603a63..bea32fb72a4 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -33,6 +33,7 @@ type ChatMessageTemplateData struct { SystemPrompt string Role string RoleName string + FunctionName string Content string MessageIndex int }
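
A minimal client-side sketch of how the new tools support can be exercised once this patch is applied: the request carries OpenAI-style "tools", and the reply carries "tool_calls" whose "arguments" field is a stringified JSON object (as produced by parseFunctionCall), with a finish_reason of "tool_calls". The base URL (localhost:8080), the model name and the get_weather function below are placeholders for illustration, not something defined by this patch.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// Chat completion request using the new "tools" field handled by updateRequestConfig.
	reqBody := map[string]interface{}{
		"model": "my-model", // placeholder model name
		"messages": []map[string]interface{}{
			{"role": "user", "content": "What is the weather like in Rome?"},
		},
		"tools": []map[string]interface{}{
			{
				"type": "function",
				"function": map[string]interface{}{
					"name":        "get_weather", // hypothetical function, for illustration only
					"description": "Get the current weather for a city",
					"parameters": map[string]interface{}{
						"type": "object",
						"properties": map[string]interface{}{
							"city": map[string]interface{}{"type": "string"},
						},
						"required": []string{"city"},
					},
				},
			},
		},
	}

	payload, err := json.Marshal(reqBody)
	if err != nil {
		log.Fatal(err)
	}

	// Assumed LocalAI address; adjust to the actual deployment.
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(payload))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Decode only the fields this patch adds to the response schema.
	var out struct {
		Choices []struct {
			FinishReason string `json:"finish_reason"`
			Message      struct {
				ToolCalls []struct {
					ID       string `json:"id"`
					Type     string `json:"type"`
					Function struct {
						Name      string `json:"name"`
						Arguments string `json:"arguments"` // stringified JSON object
					} `json:"function"`
				} `json:"tool_calls"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		log.Fatal(err)
	}

	for _, choice := range out.Choices {
		fmt.Println("finish_reason:", choice.FinishReason)
		for _, tc := range choice.Message.ToolCalls {
			fmt.Printf("tool call %s: %s(%s)\n", tc.ID, tc.Function.Name, tc.Function.Arguments)
		}
	}
}

With "stream" set to true, the same information arrives as chat.completion.chunk deltas from processTools: a first chunk carrying the function name, a second carrying the stringified arguments, and a closing chunk whose finish_reason is "tool_calls".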