From f63e32a3db04cbbe7d9c8c1ddc18bb9257c7a325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8D=BB=E5=8D=87?= <63136897+ongdisheng@users.noreply.github.com> Date: Sun, 22 Dec 2024 21:20:00 +0800 Subject: [PATCH 01/25] feat(wren-launcher): Added OpenAI API key validation (#1043) --- wren-launcher/commands/launch.go | 35 ++++++++++++++++++++++++++++++++ wren-launcher/go.mod | 1 + wren-launcher/go.sum | 2 ++ 3 files changed, 38 insertions(+) diff --git a/wren-launcher/commands/launch.go b/wren-launcher/commands/launch.go index cdf183cbc..a22d1c6f5 100644 --- a/wren-launcher/commands/launch.go +++ b/wren-launcher/commands/launch.go @@ -1,6 +1,7 @@ package commands import ( + "context" "errors" "fmt" "os" @@ -14,6 +15,7 @@ import ( "github.com/common-nighthawk/go-figure" "github.com/manifoldco/promptui" "github.com/pterm/pterm" + openai "github.com/sashabaranov/go-openai" ) func prepareProjectDir() string { @@ -190,6 +192,12 @@ func Launch() { return } + // check if OpenAI API key is valid + shouldReturn = validateOpenaiApiKey(openaiApiKey) + if shouldReturn { + return + } + // ask for OpenAI generation model pterm.Print("\n") openaiGenerationModel, shouldReturn = getOpenaiGenerationModel() @@ -365,3 +373,30 @@ func getLLMProvider() (string, bool) { } return llmProvider, false } + +func validateOpenaiApiKey(apiKey string) bool { + // validate if input api key is valid by sending a hello request + pterm.Info.Println("Sending a hello request to OpenAI...") + client := openai.NewClient(apiKey) + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT4oMini20240718, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + }, + ) + + // insufficient credit balance error + if err != nil { + pterm.Error.Println("Invalid API key", err) + return true + } + + pterm.Info.Println("Valid API key, Response:", resp.Choices[0].Message.Content) + return false +} diff --git a/wren-launcher/go.mod b/wren-launcher/go.mod index 909fc5404..0b186782c 100644 --- a/wren-launcher/go.mod +++ b/wren-launcher/go.mod @@ -8,6 +8,7 @@ require ( github.com/docker/docker v26.1.5+incompatible github.com/google/uuid v1.6.0 github.com/manifoldco/promptui v0.9.0 + github.com/sashabaranov/go-openai v1.36.0 ) require ( diff --git a/wren-launcher/go.sum b/wren-launcher/go.sum index 20c3a0473..2223393d6 100644 --- a/wren-launcher/go.sum +++ b/wren-launcher/go.sum @@ -478,6 +478,8 @@ github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sashabaranov/go-openai v1.36.0 h1:fcSrn8uGuorzPWCBp8L0aCR95Zjb/Dd+ZSML0YZy9EI= +github.com/sashabaranov/go-openai v1.36.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/secure-systems-lab/go-securesystemslib v0.4.0 h1:b23VGrQhTA8cN2CbBw7/FulN9fTtqYUdS5+Oxzt+DUE= github.com/secure-systems-lab/go-securesystemslib v0.4.0/go.mod h1:FGBZgq2tXWICsxWQW1msNf49F0Pf2Op5Htayx335Qbs= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= From 8915cc3bbcc18171c78420bc3875ba31dda3356e Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Mon, 23 Dec 2024 10:32:10 +0800 Subject: [PATCH 02/25] chore(wren-launcher): simplify launcher choosing custom 
(#1044) --- wren-launcher/commands/launch.go | 32 ++----- wren-launcher/go.mod | 2 +- wren-launcher/go.sum | 4 +- wren-launcher/utils/docker.go | 153 +++++++++++++++++++++---------- wren-launcher/utils/network.go | 2 +- 5 files changed, 117 insertions(+), 76 deletions(-) diff --git a/wren-launcher/commands/launch.go b/wren-launcher/commands/launch.go index a22d1c6f5..08e3edfe9 100644 --- a/wren-launcher/commands/launch.go +++ b/wren-launcher/commands/launch.go @@ -121,18 +121,6 @@ func askForGenerationModel() (string, error) { return result, nil } -func isConfigFileValidForCustomLLM(projectDir string) error { - // validate if config.yaml file exists in ~/.wrenai - configFilePath := path.Join(projectDir, "config.yaml") - - if _, err := os.Stat(configFilePath); os.IsNotExist(err) { - errMessage := fmt.Sprintf("Please create a config.yaml file in %s first, more details at https://docs.getwren.ai/oss/installation/custom_llm#running-wren-ai-with-your-custom-llm-or-document-store", projectDir) - return errors.New(errMessage) - } - - return nil -} - func Launch() { // recover from panic defer func() { @@ -210,12 +198,6 @@ func Launch() { if err != nil { panic(err) } - } else { - // check if config.yaml file exists - err := isConfigFileValidForCustomLLM(projectDir) - if err != nil { - panic(err) - } } // ask for telemetry consent @@ -250,7 +232,7 @@ func Launch() { uiPort := utils.FindAvailablePort(3000) aiPort := utils.FindAvailablePort(5555) - err = utils.PrepareDockerFiles(openaiApiKey, openaiGenerationModel, uiPort, aiPort, projectDir, telemetryEnabled) + err = utils.PrepareDockerFiles(openaiApiKey, openaiGenerationModel, uiPort, aiPort, projectDir, telemetryEnabled, llmProvider) if err != nil { panic(err) } @@ -264,10 +246,8 @@ func Launch() { } pterm.Info.Println("Wren AI is starting, please wait for a moment...") - if llmProvider == "Custom" { - pterm.Info.Println("If you choose Ollama as LLM provider, please make sure you have started the Ollama service first. Also, Wren AI will automatically pull your chosen models if you have not done so. 
You can check the progress by executing `docker logs -f wrenai-wren-ai-service-1` in the terminal.") - } - url := fmt.Sprintf("http://localhost:%d", uiPort) + uiUrl := fmt.Sprintf("http://localhost:%d", uiPort) + aiUrl := fmt.Sprintf("http://localhost:%d", aiPort) // wait until checking if CheckUIServiceStarted return without error // if timeout 2 minutes, panic timeoutTime := time.Now().Add(2 * time.Minute) @@ -277,7 +257,7 @@ func Launch() { } // check if ui is ready - err := utils.CheckUIServiceStarted(url) + err := utils.CheckUIServiceStarted(uiUrl) if err == nil { pterm.Info.Println("UI Service is ready") break @@ -294,7 +274,7 @@ func Launch() { } // check if ai service is ready - err := utils.CheckAIServiceStarted(aiPort) + err := utils.CheckAIServiceStarted(aiUrl) if err == nil { pterm.Info.Println("AI Service is Ready") break @@ -304,7 +284,7 @@ func Launch() { // open browser pterm.Info.Println("Opening browser") - utils.Openbrowser(url) + utils.Openbrowser(uiUrl) pterm.Info.Println("You can now safely close this terminal window") fmt.Scanf("h") diff --git a/wren-launcher/go.mod b/wren-launcher/go.mod index 0b186782c..a3590fcee 100644 --- a/wren-launcher/go.mod +++ b/wren-launcher/go.mod @@ -138,7 +138,7 @@ require ( go.uber.org/mock v0.4.0 // indirect golang.org/x/crypto v0.31.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/net v0.25.0 // indirect + golang.org/x/net v0.33.0 // indirect golang.org/x/oauth2 v0.15.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/wren-launcher/go.sum b/wren-launcher/go.sum index 2223393d6..c1666d33a 100644 --- a/wren-launcher/go.sum +++ b/wren-launcher/go.sum @@ -630,8 +630,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.15.0 h1:s8pnnxNVzjWyrvYdFUQq5llS1PX2zhPXmccZv99h7uQ= golang.org/x/oauth2 v0.15.0/go.mod h1:q48ptWNTY5XWf+JNten23lcvHpLJ0ZSxF5ttTHKVCAM= diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go index c5f17bfa6..aa41b4344 100644 --- a/wren-launcher/utils/docker.go +++ b/wren-launcher/utils/docker.go @@ -228,7 +228,7 @@ func mergeEnvContent(newEnvFile string, envFileContent string) (string, error) { return envFileContent, nil } -func PrepareDockerFiles(openaiApiKey string, openaiGenerationModel string, hostPort int, aiPort int, projectDir string, telemetryEnabled bool) error { +func PrepareDockerFiles(openaiApiKey string, openaiGenerationModel string, hostPort int, aiPort int, projectDir string, telemetryEnabled bool, llmProvider string) error { // download docker-compose file composeFile := path.Join(projectDir, "docker-compose.yaml") pterm.Info.Println("Downloading docker-compose file to", composeFile) @@ -237,54 +237,66 @@ func PrepareDockerFiles(openaiApiKey string, 
openaiGenerationModel string, hostP return err } - userUUID, err := prepareUserUUID(projectDir) - if err != nil { - return err - } + if strings.ToLower(llmProvider) == "openai" { + userUUID, err := prepareUserUUID(projectDir) + if err != nil { + return err + } - // download env file - envExampleFile := path.Join(projectDir, ".env.example") - pterm.Info.Println("Downloading env file to", envExampleFile) - err = downloadFile(envExampleFile, DOCKER_COMPOSE_ENV_URL) - if err != nil { - return err - } + // download env file + envExampleFile := path.Join(projectDir, ".env.example") + pterm.Info.Println("Downloading env file to", envExampleFile) + err = downloadFile(envExampleFile, DOCKER_COMPOSE_ENV_URL) + if err != nil { + return err + } - // read the file - envExampleFileContent, err := os.ReadFile(envExampleFile) - if err != nil { - return err - } + // read the file + envExampleFileContent, err := os.ReadFile(envExampleFile) + if err != nil { + return err + } - // replace the content with regex - envFileContent := replaceEnvFileContent( - string(envExampleFileContent), - projectDir, - openaiApiKey, - openaiGenerationModel, - hostPort, - aiPort, - userUUID, - telemetryEnabled, - ) - newEnvFile := getEnvFilePath(projectDir) - - // merge the env file content with the existing env file - envFileContent, err = mergeEnvContent(newEnvFile, envFileContent) - if err != nil { - return err - } + // replace the content with regex + envFileContent := replaceEnvFileContent( + string(envExampleFileContent), + projectDir, + openaiApiKey, + openaiGenerationModel, + hostPort, + aiPort, + userUUID, + telemetryEnabled, + ) + newEnvFile := getEnvFilePath(projectDir) + + // merge the env file content with the existing env file + envFileContent, err = mergeEnvContent(newEnvFile, envFileContent) + if err != nil { + return err + } - // write the file - err = os.WriteFile(newEnvFile, []byte(envFileContent), 0644) - if err != nil { - return err - } + // write the file + err = os.WriteFile(newEnvFile, []byte(envFileContent), 0644) + if err != nil { + return err + } - // remove the old env file - err = os.Remove(envExampleFile) - if err != nil { - return err + // remove the old env file + err = os.Remove(envExampleFile) + if err != nil { + return err + } + } else if strings.ToLower(llmProvider) == "custom" { + // if .env file does not exist, return error + if _, err := os.Stat(getEnvFilePath(projectDir)); os.IsNotExist(err) { + return fmt.Errorf(".env file does not exist, please download the env file from %s to ~/.wrenai, rename it to .env and fill in the required information", DOCKER_COMPOSE_ENV_URL) + } + + // if config.yaml file does not exist, return error + if _, err := os.Stat(getConfigFilePath(projectDir)); os.IsNotExist(err) { + return fmt.Errorf("config.yaml file does not exist, please download the config.yaml file from %s to ~/.wrenai, rename it to config.yaml and fill in the required information", AI_SERVICE_CONFIG_URL) + } } return nil @@ -294,6 +306,10 @@ func getEnvFilePath(projectDir string) string { return path.Join(projectDir, ".env") } +func getConfigFilePath(projectDir string) string { + return path.Join(projectDir, "config.yaml") +} + func RunDockerCompose(projectName string, projectDir string, llmProvider string) error { ctx := context.Background() composeFilePath := path.Join(projectDir, "docker-compose.yaml") @@ -341,6 +357,22 @@ func RunDockerCompose(projectName string, projectDir string, llmProvider string) return err } + if strings.ToLower(llmProvider) == "custom" { + // Create up options for 
force recreating only wren-ai-service + upOptions := api.UpOptions{ + Create: api.CreateOptions{ + Recreate: api.RecreateForce, + Services: []string{"wren-ai-service"}, + }, + } + + // Run the up command with specific options for wren-ai-service + err = apiService.Up(ctx, projectType, upOptions) + if err != nil { + return err + } + } + return nil } @@ -384,6 +416,21 @@ func findWrenUIContainer() (types.Container, error) { return types.Container{}, fmt.Errorf("WrenUI container not found") } +func findAIServiceContainer() (types.Container, error) { + containers, err := listProcess() + if err != nil { + return types.Container{}, err + } + + for _, container := range containers { + if container.Labels["com.docker.compose.project"] == "wrenai" && container.Labels["com.docker.compose.service"] == "wren-ai-service" { + return container, nil + } + } + + return types.Container{}, fmt.Errorf("WrenAI service container not found") +} + func IfPortUsedByWrenUI(port int) bool { container, err := findWrenUIContainer() if err != nil { @@ -399,6 +446,21 @@ func IfPortUsedByWrenUI(port int) bool { return false } +func IfPortUsedByAIService(port int) bool { + container, err := findAIServiceContainer() + if err != nil { + return false + } + + for _, containerPort := range container.Ports { + if containerPort.PublicPort == uint16(port) { + return true + } + } + + return false +} + func CheckUIServiceStarted(url string) error { // check response from localhost:3000 resp, err := http.Get(url) @@ -413,9 +475,8 @@ func CheckUIServiceStarted(url string) error { return nil } -func CheckAIServiceStarted(port int) error { +func CheckAIServiceStarted(url string) error { // health check - url := fmt.Sprintf("http://localhost:%d/health", port) resp, err := http.Get(url) if err != nil { return err diff --git a/wren-launcher/utils/network.go b/wren-launcher/utils/network.go index 86b611c58..3edf5f8dd 100644 --- a/wren-launcher/utils/network.go +++ b/wren-launcher/utils/network.go @@ -23,7 +23,7 @@ func FindAvailablePort(defaultPort int) int { if !ifPortUsed(port) { // Return the port if it's not used return port - } else if IfPortUsedByWrenUI(port) { + } else if IfPortUsedByWrenUI(port) || IfPortUsedByAIService(port) { // Return the port if it's used, but used by wrenAI return port } From d5091604fe21baa3b5c63a08c6fbb89b721d565b Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Mon, 23 Dec 2024 10:47:35 +0800 Subject: [PATCH 03/25] chore(wren-ai-service): make openai llm version fixed (#1045) --- deployment/kustomizations/base/cm.yaml | 40 +++++++++---------- docker/config.example.yaml | 40 +++++++++---------- .../src/web/v1/services/__init__.py | 4 +- .../test_relationship_recommendation.py | 2 +- .../tools/config/config.example.yaml | 40 +++++++++---------- wren-ai-service/tools/config/config.full.yaml | 40 +++++++++---------- wren-launcher/utils/docker.go | 7 +++- 7 files changed, 89 insertions(+), 84 deletions(-) diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index 6ca078121..4e7f27666 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -23,7 +23,7 @@ data: POSTHOG_HOST: "https://app.posthog.com" TELEMETRY_ENABLED: "false" # this is for telemetry to know the model, i think ai-service might be able to provide a endpoint to get the information - GENERATION_MODEL: "gpt-4o-mini" + GENERATION_MODEL: "gpt-4o-mini-2024-07-18" # service endpoints of AI service & engine service WREN_ENGINE_ENDPOINT: "http://wren-engine-svc:8080" @@ 
-53,7 +53,7 @@ data: provider: litellm_llm timeout: 120 models: - - model: gpt-4o-mini + - model: gpt-4o-mini-2024-07-18 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ -106,58 +106,58 @@ data: embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: db_schema_retrieval - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: historical_question_retrieval embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_correction - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: followup_sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_summary - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: preprocess_sql_data - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_breakdown - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_expansion - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_explanation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_regeneration - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: semantics_description - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: question_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: intent_classification - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: data_assistance - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_executor engine: wren_ui - name: chart_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: chart_adjustment - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 --- settings: column_indexing_batch_size: 50 diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 59a1f11b7..9bd546be7 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -2,7 +2,7 @@ type: llm provider: litellm_llm timeout: 120 models: -- model: gpt-4o-mini +- model: gpt-4o-mini-2024-07-18 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ -11,7 +11,7 @@ models: max_tokens: 4096 response_format: type: json_object -- model: gpt-4o +- model: gpt-4o-2024-08-06 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ -56,58 +56,58 @@ pipes: embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: db_schema_retrieval - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: historical_question_retrieval embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: sql_generation - llm: litellm_llm.gpt-4o-mini + 
llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_correction - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: followup_sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_summary - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_breakdown - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_expansion - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_explanation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_regeneration - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: semantics_description - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: question_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: intent_classification - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: data_assistance - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: preprocess_sql_data - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_executor engine: wren_ui - name: chart_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: chart_adjustment - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 --- settings: column_indexing_batch_size: 50 diff --git a/wren-ai-service/src/web/v1/services/__init__.py b/wren-ai-service/src/web/v1/services/__init__.py index 6f5b06f95..3294b96aa 100644 --- a/wren-ai-service/src/web/v1/services/__init__.py +++ b/wren-ai-service/src/web/v1/services/__init__.py @@ -28,7 +28,7 @@ class FiscalYear(BaseModel): end: str class Timezone(BaseModel): - name: str = "Asia/Taipei" + name: str = "UTC" utc_offset: str = "" # Deprecated, will be removed in the future def show_current_time(self): @@ -41,7 +41,7 @@ def show_current_time(self): return f'{current_time.strftime("%Y-%m-%d %A %H:%M:%S")}' # YYYY-MM-DD weekday_name HH:MM:SS, ex: 2024-10-23 Wednesday 12:00:00 fiscal_year: Optional[FiscalYear] = None - language: Optional[str] = "en" + language: Optional[str] = "English" timezone: Optional[Timezone] = Timezone() diff --git a/wren-ai-service/tests/pytest/services/test_relationship_recommendation.py b/wren-ai-service/tests/pytest/services/test_relationship_recommendation.py index 57be1b0a8..65672afde 100644 --- a/wren-ai-service/tests/pytest/services/test_relationship_recommendation.py +++ b/wren-ai-service/tests/pytest/services/test_relationship_recommendation.py @@ -27,7 +27,7 @@ async def test_recommend_success(relationship_recommendation_service, mock_pipel assert response.id == "test_id" assert response.status == "finished" assert response.response == {"test": "data"} - mock_pipeline.run.assert_called_once_with(mdl={"key": "value"}, language="en") + mock_pipeline.run.assert_called_once_with(mdl={"key": "value"}, language="English") @pytest.mark.asyncio diff --git a/wren-ai-service/tools/config/config.example.yaml 
b/wren-ai-service/tools/config/config.example.yaml index 742495f4e..c439ccc59 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -2,7 +2,7 @@ type: llm provider: litellm_llm timeout: 120 models: -- model: gpt-4o-mini +- model: gpt-4o-mini-2024-07-18 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ -11,7 +11,7 @@ models: max_tokens: 4096 response_format: type: json_object -- model: gpt-4o +- model: gpt-4o-2024-08-06 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ -70,56 +70,56 @@ pipes: embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: db_schema_retrieval - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: historical_question_retrieval embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_correction - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: followup_sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_summary - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_breakdown - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_expansion - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_explanation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_regeneration - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: semantics_description - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: question_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: chart_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: chart_adjustment - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: intent_classification - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: data_assistance - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: preprocess_sql_data - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_executor engine: wren_ui --- diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index 64bdd9173..526c32464 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -2,7 +2,7 @@ type: llm provider: litellm_llm timeout: 120 models: -- model: gpt-4o-mini +- model: gpt-4o-mini-2024-07-18 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ -11,7 +11,7 @@ models: max_tokens: 4096 response_format: type: json_object -- model: gpt-4o +- model: gpt-4o-2024-08-06 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: @@ 
-89,58 +89,58 @@ pipes: embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: db_schema_retrieval - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: historical_question_retrieval embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_correction - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: followup_sql_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_summary - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_answer - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_breakdown - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_expansion - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: sql_explanation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_regeneration - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: semantics_description - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: relationship_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 engine: wren_ui - name: question_recommendation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: chart_generation - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: chart_adjustment - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: intent_classification - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 embedder: openai_embedder.text-embedding-3-large document_store: qdrant - name: data_assistance - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 - name: sql_executor engine: wren_ui - name: preprocess_sql_data - llm: litellm_llm.gpt-4o-mini + llm: litellm_llm.gpt-4o-mini-2024-07-18 --- settings: diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go index aa41b4344..0bb06a9ac 100644 --- a/wren-launcher/utils/docker.go +++ b/wren-launcher/utils/docker.go @@ -30,6 +30,11 @@ const ( AI_SERVICE_CONFIG_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/config.example.yaml" ) +var generationModelToModelName = map[string]string{ + "gpt-4o-mini": "gpt-4o-mini-2024-07-18", + "gpt-4o": "gpt-4o-2024-08-06", +} + func replaceEnvFileContent(content string, projectDir string, openaiApiKey string, openAIGenerationModel string, hostPort int, aiPort int, userUUID string, telemetryEnabled bool) string { // replace PROJECT_DIR reg := regexp.MustCompile(`PROJECT_DIR=(.*)`) @@ -152,7 +157,7 @@ func PrepareConfigFileForOpenAI(projectDir string, generationModel string) error // replace the generation model in config.yaml config := string(content) - config = strings.ReplaceAll(config, "litellm_llm.gpt-4o-mini", "litellm_llm."+generationModel) + config = strings.ReplaceAll(config, "litellm_llm.gpt-4o-mini-2024-07-18", "litellm_llm."+generationModelToModelName[generationModel]) // write back to config.yaml err = os.WriteFile(configPath, []byte(config), 0644) From 
2485d24654b738d3b2bb2fb15af66566be40415d Mon Sep 17 00:00:00 2001 From: imAsterSun <61279528+imAsterSun@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:28:33 +0800 Subject: [PATCH 04/25] chore(wren-ai-service): Add JSON function to prompt (#964) Co-authored-by: Chih-Yu Yeh --- wren-ai-service/src/pipelines/common.py | 33 +++++++++++++++++++ .../src/pipelines/indexing/utils/helper.py | 10 ++++++ wren-ai-service/tests/pytest/test_usecases.py | 5 +-- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/wren-ai-service/src/pipelines/common.py b/wren-ai-service/src/pipelines/common.py index 9b1eb4ef6..b106ac7ab 100644 --- a/wren-ai-service/src/pipelines/common.py +++ b/wren-ai-service/src/pipelines/common.py @@ -301,6 +301,39 @@ async def _task(result: Dict[str, str]): - `=` - `<>` - `!=` +- ONLY USE JSON_QUERY for querying fields if "json_type":"JSON" is identified in the columns comment, NOT the deprecated JSON_EXTRACT_SCALAR function. + - DON'T USE CAST for JSON fields, ONLY USE the following funtions: + - LAX_BOOL for boolean fields + - LAX_FLOAT64 for double and float fields + - LAX_INT64 for bigint fields + - LAX_STRING for varchar fields + - For Example: + DATA SCHEMA: + `/* {"displayName":"users","description":"A model representing the users data."} */ + CREATE TABLE users ( + -- {"alias":"address","description":"A JSON object that represents address information of this user.","json_type":"JSON","json_fields":{"json_type":"JSON","address.json.city":{"name":"city","type":"varchar","path":"$.city","properties":{"displayName":"city","description":"City Name."}},"address.json.state":{"name":"state","type":"varchar","path":"$.state","properties":{"displayName":"state","description":"ISO code or name of the state, province or district."}},"address.json.postcode":{"name":"postcode","type":"varchar","path":"$.postcode","properties":{"displayName":"postcode","description":"Postal code."}},"address.json.country":{"name":"country","type":"varchar","path":"$.country","properties":{"displayName":"country","description":"ISO code of the country."}}}} + address JSON + )` + To get the city of address in user table use SQL: + `SELECT LAX_STRING(JSON_QUERY(u.address, '$.city')) FROM user as u` +- ONLY USE JSON_QUERY_ARRAY for querying "json_type":"JSON_ARRAY" is identified in the comment of the column, NOT the deprecated JSON_EXTRACT_ARRAY. + - USE UNNEST to analysis each item individually in the ARRAY. YOU MUST SELECT FROM the parent table ahead of the UNNEST ARRAY. + - The alias of the UNNEST(ARRAY) should be in the format `unnest_table_alias(individual_item_alias)` + - For Example: `SELECT item FROM UNNEST(ARRAY[1,2,3]) as my_unnested_table(item)` + - If the items in the ARRAY are JSON objects, use JSON_QUERY to query the fields inside each JSON item. 
+ - For Example: + DATA SCHEMA + `/* {"displayName":"my_table","description":"A test my_table"} */ + CREATE TABLE my_table ( + -- {"alias":"elements","description":"elements column","json_type":"JSON_ARRAY","json_fields":{"json_type":"JSON_ARRAY","elements.json_array.id":{"name":"id","type":"bigint","path":"$.id","properties":{"displayName":"id","description":"data ID."}},"elements.json_array.key":{"name":"key","type":"varchar","path":"$.key","properties":{"displayName":"key","description":"data Key."}},"elements.json_array.value":{"name":"value","type":"varchar","path":"$.value","properties":{"displayName":"value","description":"data Value."}}}} + elements JSON + )` + To get the number of elements in my_table table use SQL: + `SELECT LAX_INT64(JSON_QUERY(element, '$.number')) FROM my_table as t, UNNEST(JSON_QUERY_ARRAY(elements)) AS my_unnested_table(element) WHERE LAX_FLOAT64(JSON_QUERY(element, '$.value')) > 3.5` + - To JOIN ON the fields inside UNNEST(ARRAY), YOU MUST SELECT FROM the parent table ahead of the UNNEST syntax, and the alias of the UNNEST(ARRAY) SHOULD BE IN THE FORMAT unnest_table_alias(individual_item_alias) + - For Example: `SELECT p.column_1, j.column_2 FROM parent_table AS p, join_table AS j JOIN UNNEST(p.array_column) AS unnested(array_item) ON j.id = array_item.id` +- DON'T USE JSON_QUERY and JSON_QUERY_ARRAY when "json_type":"". +- DON'T USE LAX_BOOL, LAX_FLOAT64, LAX_INT64, LAX_STRING when "json_type":"". """ diff --git a/wren-ai-service/src/pipelines/indexing/utils/helper.py b/wren-ai-service/src/pipelines/indexing/utils/helper.py index ef1c98330..0eaad9db6 100644 --- a/wren-ai-service/src/pipelines/indexing/utils/helper.py +++ b/wren-ai-service/src/pipelines/indexing/utils/helper.py @@ -1,6 +1,7 @@ import importlib import logging import pkgutil +import re import sys from typing import Any, Callable, Dict @@ -30,6 +31,7 @@ def _properties_comment(column: Dict[str, Any], **_) -> str: column_properties = { "alias": props.get("displayName", ""), "description": props.get("description", ""), + "json_type": props.get("json_type", ""), } # Add any nested columns if they exist @@ -37,6 +39,14 @@ def _properties_comment(column: Dict[str, Any], **_) -> str: if nested: column_properties["nested_columns"] = nested + json_fields = { + k: v + for k, v in column["properties"].items() + if re.match(r".*json.*", k) + } + if json_fields: + column_properties["json_fields"] = json_fields + return f"-- {orjson.dumps(column_properties).decode('utf-8')}\n " diff --git a/wren-ai-service/tests/pytest/test_usecases.py b/wren-ai-service/tests/pytest/test_usecases.py index 41b9956c1..58e6ad8e1 100644 --- a/wren-ai-service/tests/pytest/test_usecases.py +++ b/wren-ai-service/tests/pytest/test_usecases.py @@ -147,6 +147,7 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id usecase_to_dataset_type = { "hubspot": "bigquery", "ga4": "bigquery", + "woocommerce": "bigquery", "ecommerce": "duckdb", "hr": "duckdb", } @@ -157,12 +158,12 @@ async def ask_questions(questions: list[str], url: str, semantics_preperation_id type=str, nargs="+", default=["all"], - choices=["all", "hubspot", "ga4", "ecommerce", "hr"], + choices=["all", "hubspot", "ga4", "ecommerce", "hr", "woocommerce"], ) args = parser.parse_args() if "all" in args.usecases: - usecases = ["hubspot", "ga4", "ecommerce", "hr"] + usecases = ["hubspot", "ga4", "ecommerce", "hr", "woocommerce"] else: usecases = args.usecases From 1a03df6c8a8ebb8bafd992a99778a66e76ec73b5 Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh 
Date: Mon, 23 Dec 2024 11:40:50 +0800 Subject: [PATCH 05/25] release: 0.13.2-rc.1 (#1046) --- docker/.env.example | 2 +- wren-launcher/utils/docker.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index 22a6d6fc1..726366a09 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -24,7 +24,7 @@ QDRANT_API_KEY= # version # CHANGE THIS TO THE LATEST VERSION -WREN_PRODUCT_VERSION=0.13.1 +WREN_PRODUCT_VERSION=0.13.2-rc.1 WREN_ENGINE_VERSION=0.12.6 WREN_AI_SERVICE_VERSION=0.13.4 IBIS_SERVER_VERSION=0.12.6 diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go index 0bb06a9ac..17f53373d 100644 --- a/wren-launcher/utils/docker.go +++ b/wren-launcher/utils/docker.go @@ -24,7 +24,7 @@ import ( const ( // please change the version when the version is updated - WREN_PRODUCT_VERSION string = "0.13.1" + WREN_PRODUCT_VERSION string = "0.13.2-rc.1" DOCKER_COMPOSE_YAML_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/docker-compose.yaml" DOCKER_COMPOSE_ENV_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/.env.example" AI_SERVICE_CONFIG_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/config.example.yaml" From 20e081cbc7fa3f573a2d5eb5138710e12b82e4ad Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Mon, 23 Dec 2024 14:43:14 +0800 Subject: [PATCH 06/25] release: 0.13.2 (#1048) --- docker/.env.example | 2 +- wren-launcher/utils/docker.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index 726366a09..03b62e717 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -24,7 +24,7 @@ QDRANT_API_KEY= # version # CHANGE THIS TO THE LATEST VERSION -WREN_PRODUCT_VERSION=0.13.2-rc.1 +WREN_PRODUCT_VERSION=0.13.2 WREN_ENGINE_VERSION=0.12.6 WREN_AI_SERVICE_VERSION=0.13.4 IBIS_SERVER_VERSION=0.12.6 diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go index 17f53373d..d57e7bd24 100644 --- a/wren-launcher/utils/docker.go +++ b/wren-launcher/utils/docker.go @@ -24,7 +24,7 @@ import ( const ( // please change the version when the version is updated - WREN_PRODUCT_VERSION string = "0.13.2-rc.1" + WREN_PRODUCT_VERSION string = "0.13.2" DOCKER_COMPOSE_YAML_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/docker-compose.yaml" DOCKER_COMPOSE_ENV_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/.env.example" AI_SERVICE_CONFIG_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/config.example.yaml" From 69d6e517193322cd44294fbe6fae575a2e6c019c Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Mon, 23 Dec 2024 18:11:48 +0800 Subject: [PATCH 07/25] chore(wren-ai-service): fix add comment to column properties (#1050) --- .../src/pipelines/indexing/utils/helper.py | 18 ++++++++++-------- .../pytest/pipelines/indexing/test_helper.py | 5 ++++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/wren-ai-service/src/pipelines/indexing/utils/helper.py b/wren-ai-service/src/pipelines/indexing/utils/helper.py index 0eaad9db6..54258a67a 100644 --- a/wren-ai-service/src/pipelines/indexing/utils/helper.py +++ b/wren-ai-service/src/pipelines/indexing/utils/helper.py @@ -31,7 +31,6 @@ def _properties_comment(column: Dict[str, Any], **_) -> str: column_properties = { 
"alias": props.get("displayName", ""), "description": props.get("description", ""), - "json_type": props.get("json_type", ""), } # Add any nested columns if they exist @@ -39,13 +38,16 @@ def _properties_comment(column: Dict[str, Any], **_) -> str: if nested: column_properties["nested_columns"] = nested - json_fields = { - k: v - for k, v in column["properties"].items() - if re.match(r".*json.*", k) - } - if json_fields: - column_properties["json_fields"] = json_fields + if (json_type := props.get("json_type", "")) and json_type in [ + "JSON", + "JSON_ARRAY", + ]: + json_fields = { + k: v for k, v in column["properties"].items() if re.match(r".*json.*", k) + } + if json_fields: + column_properties["json_type"] = json_type + column_properties["json_fields"] = json_fields return f"-- {orjson.dumps(column_properties).decode('utf-8')}\n " diff --git a/wren-ai-service/tests/pytest/pipelines/indexing/test_helper.py b/wren-ai-service/tests/pytest/pipelines/indexing/test_helper.py index da86940dd..3534ac43a 100644 --- a/wren-ai-service/tests/pytest/pipelines/indexing/test_helper.py +++ b/wren-ai-service/tests/pytest/pipelines/indexing/test_helper.py @@ -78,7 +78,10 @@ def test_properties_comment_helper(): test_column = { "name": "test_column", - "properties": {"displayName": "Test Column", "description": "Test description"}, + "properties": { + "displayName": "Test Column", + "description": "Test description", + }, } assert helper.condition(test_column) is True From e9d70c2dd15a65459900a750e8a5082443024324 Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Tue, 24 Dec 2024 02:10:58 +0000 Subject: [PATCH 08/25] Upgrade AI Service version to 0.13.5 --- wren-ai-service/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index 86139b04b..769c1ce48 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.13.4" +version = "0.13.5" description = "" authors = ["Jimmy Yeh ", "Pao Sheng Wang ", "Aster Sun "] license = "AGPL-3.0" From 2f77836dea8da28dbf8ba59e1183a34d8a42649e Mon Sep 17 00:00:00 2001 From: Shimin Date: Wed, 25 Dec 2024 10:35:33 +0800 Subject: [PATCH 09/25] fix(wren-ui): Calculated fields in model preview Issue & improvements (#1051) --- wren-ui/src/apollo/server/resolvers/modelResolver.ts | 6 +++++- wren-ui/src/apollo/server/utils/model.ts | 6 ++++++ wren-ui/src/components/selectors/lineageSelector/index.tsx | 7 ++++++- wren-ui/src/utils/columnType.tsx | 2 ++ wren-ui/src/utils/enum/columnType.ts | 3 +++ wren-ui/src/utils/validator/calculatedFieldValidator.ts | 1 + 6 files changed, 23 insertions(+), 2 deletions(-) diff --git a/wren-ui/src/apollo/server/resolvers/modelResolver.ts b/wren-ui/src/apollo/server/resolvers/modelResolver.ts index 73fc5fbf4..5711d7076 100644 --- a/wren-ui/src/apollo/server/resolvers/modelResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/modelResolver.ts @@ -21,6 +21,7 @@ import { replaceAllowableSyntax, validateDisplayName } from '../utils/regex'; import { Model, ModelColumn } from '../repositories'; import { findColumnsToUpdate, + getPreviewColumnsStr, handleNestedColumns, replaceInvalidReferenceName, updateModelPrimaryKey, @@ -890,7 +891,10 @@ export class ModelResolver { } const project = await ctx.projectService.getCurrentProject(); const { manifest } = await ctx.mdlService.makeCurrentModelMDL(); - const sql = `select * from "${model.referenceName}"`; + const modelColumns = 
await ctx.modelColumnRepository.findColumnsByModelIds([ + model.id, + ]); + const sql = `select ${getPreviewColumnsStr(modelColumns)} from "${model.referenceName}"`; const data = (await ctx.queryService.preview(sql, { project, diff --git a/wren-ui/src/apollo/server/utils/model.ts b/wren-ui/src/apollo/server/utils/model.ts index fb20bce53..2dfe2e337 100644 --- a/wren-ui/src/apollo/server/utils/model.ts +++ b/wren-ui/src/apollo/server/utils/model.ts @@ -6,6 +6,12 @@ import { import { replaceAllowableSyntax } from './regex'; import { CompactColumn } from '@server/services/metadataService'; +export function getPreviewColumnsStr(modelColumns: ModelColumn[]) { + if (modelColumns.length === 0) return '*'; + const columns = modelColumns.map((column) => `"${column.sourceColumnName}"`); + return columns.join(','); +} + export function transformInvalidColumnName(columnName: string) { let referenceName = replaceAllowableSyntax(columnName); // If the reference name does not start with a letter, add a prefix diff --git a/wren-ui/src/components/selectors/lineageSelector/index.tsx b/wren-ui/src/components/selectors/lineageSelector/index.tsx index eabffc708..90f37fc7d 100644 --- a/wren-ui/src/components/selectors/lineageSelector/index.tsx +++ b/wren-ui/src/components/selectors/lineageSelector/index.tsx @@ -106,7 +106,9 @@ export const getLineageOptions = (data: { const convertor = (field) => { const value = compactObject(getFieldValue(field)); const isRelationship = field.nodeType === NODE_TYPE.RELATION; - + // check if source model's calculated field + const isSourceModelCalculatedField = + isSourceModel && field.nodeType === NODE_TYPE.CALCULATED_FIELD; // check if user select aggregation functions, then the source model fields cannot be selected const isSourceModelFieldsWithAggregation = aggregations.includes(expression) && isSourceModel && !isRelationship; @@ -133,6 +135,7 @@ export const getLineageOptions = (data: { const disabled = isSourceModelFieldsWithAggregation || isRelationshipWithoutPrimaryKey || + isSourceModelCalculatedField || isInUsedRelationship || isInvalidType; @@ -143,6 +146,8 @@ export const getLineageOptions = (data: { } else if (isRelationshipWithoutPrimaryKey) { title = 'Please set a primary key within this model to use it in a calculated field.'; + } else if (isSourceModelCalculatedField) { + title = 'Calculated field from the source model is not supported.'; } else if (isInUsedRelationship) { title = 'This relationship is in use.'; } else if (isInvalidType) { diff --git a/wren-ui/src/utils/columnType.tsx b/wren-ui/src/utils/columnType.tsx index 3ffb5b143..621105fb4 100644 --- a/wren-ui/src/utils/columnType.tsx +++ b/wren-ui/src/utils/columnType.tsx @@ -26,6 +26,7 @@ export const getColumnTypeIcon = (payload: { type: string }, attrs?: any) => { return ; case COLUMN_TYPE.BYTEA: + case COLUMN_TYPE.VARBINARY: return ; case COLUMN_TYPE.UUID: @@ -39,6 +40,7 @@ export const getColumnTypeIcon = (payload: { type: string }, attrs?: any) => { case COLUMN_TYPE.INTEGER: case COLUMN_TYPE.INT8: case COLUMN_TYPE.BIGINT: + case COLUMN_TYPE.INT64: case COLUMN_TYPE.NUMERIC: case COLUMN_TYPE.DECIMAL: case COLUMN_TYPE.FLOAT4: diff --git a/wren-ui/src/utils/enum/columnType.ts b/wren-ui/src/utils/enum/columnType.ts index 3d0a2701c..61fe458d1 100644 --- a/wren-ui/src/utils/enum/columnType.ts +++ b/wren-ui/src/utils/enum/columnType.ts @@ -17,6 +17,8 @@ export enum COLUMN_TYPE { INT8 = 'INT8', BIGINT = 'BIGINT', // alias for INT8 + INT64 = 'INT64', + NUMERIC = 'NUMERIC', DECIMAL = 'DECIMAL', @@ -53,6 +55,7 
@@ export enum COLUMN_TYPE { // Binary Data Types BYTEA = 'BYTEA', + VARBINARY = 'VARBINARY', // UUID Type UUID = 'UUID', diff --git a/wren-ui/src/utils/validator/calculatedFieldValidator.ts b/wren-ui/src/utils/validator/calculatedFieldValidator.ts index eab0ca574..f26942bed 100644 --- a/wren-ui/src/utils/validator/calculatedFieldValidator.ts +++ b/wren-ui/src/utils/validator/calculatedFieldValidator.ts @@ -55,6 +55,7 @@ export const checkNumberFunctionAllowType = makeCheckAllowType(mathFunctions, [ COLUMN_TYPE.INTEGER, COLUMN_TYPE.INT8, COLUMN_TYPE.BIGINT, + COLUMN_TYPE.INT64, COLUMN_TYPE.NUMERIC, COLUMN_TYPE.DECIMAL, COLUMN_TYPE.FLOAT4, From 4da34d3f33241ee2e0a73ecf8b75e478aaee6f83 Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Wed, 25 Dec 2024 02:40:27 +0000 Subject: [PATCH 10/25] update wren-ui version to 0.18.8 --- wren-ui/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ui/package.json b/wren-ui/package.json index acf26e6a9..7dc0b8e6b 100644 --- a/wren-ui/package.json +++ b/wren-ui/package.json @@ -1,6 +1,6 @@ { "name": "wren-ui", - "version": "0.18.7", + "version": "0.18.8", "private": true, "scripts": { "dev": "next dev", From 2ceacdf5a2a52a89bbc94f490ea663e4db02558b Mon Sep 17 00:00:00 2001 From: Freda Lai <42527625+fredalai@users.noreply.github.com> Date: Wed, 25 Dec 2024 10:52:39 +0800 Subject: [PATCH 11/25] fix(wren-ui): fix answerDetail nested resolver for text-to-answer (#1057) --- .../textBasedAnswerBackgroundTracker.ts | 2 +- .../apollo/server/resolvers/askingResolver.ts | 22 +++++++++++-------- .../apollo/server/services/askingService.ts | 3 ++- .../apollo/server/services/projectService.ts | 3 ++- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/wren-ui/src/apollo/server/backgrounds/textBasedAnswerBackgroundTracker.ts b/wren-ui/src/apollo/server/backgrounds/textBasedAnswerBackgroundTracker.ts index 298aa1300..6a7e350fc 100644 --- a/wren-ui/src/apollo/server/backgrounds/textBasedAnswerBackgroundTracker.ts +++ b/wren-ui/src/apollo/server/backgrounds/textBasedAnswerBackgroundTracker.ts @@ -90,7 +90,7 @@ export class TextBasedAnswerBackgroundTracker { sqlData: data, threadId: threadResponse.threadId.toString(), configurations: { - language: project.language as WrenAILanguage, + language: WrenAILanguage[project.language] || WrenAILanguage.EN, }, }); diff --git a/wren-ui/src/apollo/server/resolvers/askingResolver.ts b/wren-ui/src/apollo/server/resolvers/askingResolver.ts index 90cbfc33e..e8b819f2b 100644 --- a/wren-ui/src/apollo/server/resolvers/askingResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/askingResolver.ts @@ -509,17 +509,21 @@ export class AskingResolver { return { ...view, displayName }; }, answerDetail: (parent: ThreadResponse, _args: any, _ctx: IContext) => { - const content = parent?.answerDetail?.content - ? 
parent?.answerDetail?.content - // replace the \\n to \n - .replace(/\\n/g, '\n') - // replace the \\\" to \", - .replace(/\\"/g, '"') - : parent?.answerDetail?.content; + if (!parent?.answerDetail) return null; + + const { content, ...rest } = parent.answerDetail; + + if (!content) return parent.answerDetail; + + const formattedContent = content + // replace the \\n to \n + .replace(/\\n/g, '\n') + // replace the \\\" to \", + .replace(/\\"/g, '"'); return { - ...parent.answerDetail, - content, + ...rest, + content: formattedContent, }; }, sql: (parent: ThreadResponse, _args: any, _ctx: IContext) => { diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index efc2f523f..6ace876ee 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -10,6 +10,7 @@ import { RecommendationQuestionStatus, ChartStatus, ChartAdjustmentOption, + WrenAILanguage, } from '@server/models/adaptor'; import { IDeployService } from './deployService'; import { IProjectService } from './projectService'; @@ -1001,7 +1002,7 @@ export class AskingService implements IAskingService { maxCategories: config.threadRecommendationQuestionMaxCategories, maxQuestions: config.threadRecommendationQuestionsMaxQuestions, configuration: { - language: project.language, + language: WrenAILanguage[project.language] || WrenAILanguage.EN, }, }; } diff --git a/wren-ui/src/apollo/server/services/projectService.ts b/wren-ui/src/apollo/server/services/projectService.ts index e47c99bcf..ebb5616c2 100644 --- a/wren-ui/src/apollo/server/services/projectService.ts +++ b/wren-ui/src/apollo/server/services/projectService.ts @@ -14,6 +14,7 @@ import { RecommendationQuestion, RecommendationQuestionStatus, WrenAIError, + WrenAILanguage, } from '@server/models/adaptor'; import { encryptConnectionInfo } from '../dataSource'; import { IWrenAIAdaptor } from '../adaptors'; @@ -240,7 +241,7 @@ export class ProjectService implements IProjectService { maxQuestions: config.projectRecommendationQuestionsMaxQuestions, regenerate: true, configuration: { - language: project.language, + language: WrenAILanguage[project.language] || WrenAILanguage.EN, }, }; } From 28c83d2924a62c9a3f61658deb4812c8c7a7c4ce Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Wed, 25 Dec 2024 16:48:22 +0800 Subject: [PATCH 12/25] fix(wren-ai-service): fix column pruning check criteria (#1059) --- wren-ai-service/src/pipelines/retrieval/retrieval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/src/pipelines/retrieval/retrieval.py b/wren-ai-service/src/pipelines/retrieval/retrieval.py index 2cbd186dc..5fdafeb1a 100644 --- a/wren-ai-service/src/pipelines/retrieval/retrieval.py +++ b/wren-ai-service/src/pipelines/retrieval/retrieval.py @@ -405,7 +405,7 @@ def __init__( # for the first time, we need to load the encodings _model = llm_provider.get_model() - if _model == "gpt-4o-mini" or _model == "gpt-4o": + if "gpt-4o" in _model or "gpt-4o-mini" in _model: allow_using_db_schemas_without_pruning = True _encoding = tiktoken.get_encoding("o200k_base") else: From 235639870c8591ebe5e455578eb6c1e6ad9c8e09 Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Wed, 25 Dec 2024 08:49:18 +0000 Subject: [PATCH 13/25] Upgrade AI Service version to 0.13.6 --- wren-ai-service/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index 
769c1ce48..a91a58607 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.13.5" +version = "0.13.6" description = "" authors = ["Jimmy Yeh ", "Pao Sheng Wang ", "Aster Sun "] license = "AGPL-3.0" From cc3bbd1e5d2ec369df34a4f033466bfc993cf7a8 Mon Sep 17 00:00:00 2001 From: Freda Lai <42527625+fredalai@users.noreply.github.com> Date: Thu, 26 Dec 2024 10:16:33 +0800 Subject: [PATCH 14/25] fix(wren-ui): add lang for adjust chart api (#1061) --- wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index 4c2ec7e31..37428648c 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -423,7 +423,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } private transformChartAdjustmentInput(input: ChartAdjustmentInput) { - const { query, sql, adjustmentOption, chartSchema } = input; + const { query, sql, adjustmentOption, chartSchema, configurations } = input; return { query, sql, @@ -436,6 +436,7 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { theta: adjustmentOption.theta, }, chart_schema: chartSchema, + configurations, }; } From edd04d6135e2d989a5b1a1d8b00416026136ccea Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Thu, 26 Dec 2024 02:25:39 +0000 Subject: [PATCH 15/25] update wren-ui version to 0.18.9 --- wren-ui/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ui/package.json b/wren-ui/package.json index 7dc0b8e6b..6ba4227a6 100644 --- a/wren-ui/package.json +++ b/wren-ui/package.json @@ -1,6 +1,6 @@ { "name": "wren-ui", - "version": "0.18.8", + "version": "0.18.9", "private": true, "scripts": { "dev": "next dev", From d53ed1bc397ae1a21ea60efe18bb359010682a5d Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Thu, 26 Dec 2024 11:55:13 +0800 Subject: [PATCH 16/25] chore(wren-ai-service): improve ai service (#1009) --- deployment/kustomizations/base/cm.yaml | 4 +- docker/config.example.yaml | 4 + wren-ai-service/Justfile | 5 +- wren-ai-service/demo/utils.py | 4 - wren-ai-service/src/__main__.py | 1 + wren-ai-service/src/force_update_config.py | 26 + wren-ai-service/src/pipelines/common.py | 482 +---------------- .../generation/followup_sql_generation.py | 4 +- .../generation/intent_classification.py | 65 ++- .../src/pipelines/generation/sql_breakdown.py | 5 +- .../pipelines/generation/sql_correction.py | 2 +- .../src/pipelines/generation/sql_expansion.py | 3 +- .../pipelines/generation/sql_generation.py | 6 +- .../pipelines/generation/sql_regeneration.py | 2 +- .../src/pipelines/generation/utils/sql.py | 485 ++++++++++++++++++ wren-ai-service/src/web/v1/services/ask.py | 65 ++- wren-ai-service/tests/pytest/test_usecases.py | 9 +- .../tools/config/config.example.yaml | 4 + wren-ai-service/tools/config/config.full.yaml | 4 + 19 files changed, 645 insertions(+), 535 deletions(-) create mode 100644 wren-ai-service/src/force_update_config.py create mode 100644 wren-ai-service/src/pipelines/generation/utils/sql.py diff --git a/deployment/kustomizations/base/cm.yaml b/deployment/kustomizations/base/cm.yaml index 4e7f27666..145fbcd64 100644 --- a/deployment/kustomizations/base/cm.yaml +++ b/deployment/kustomizations/base/cm.yaml @@ -59,15 +59,17 @@ data: kwargs: temperature: 0 n: 1 + seed: 0 max_tokens: 4096 
response_format: type: json_object - - model: gpt-4o + - model: gpt-4o-2024-08-06 api_base: https://api.openai.com/v1 api_key_name: LLM_OPENAI_API_KEY kwargs: temperature: 0 n: 1 + seed: 0 max_tokens: 4096 response_format: type: json_object diff --git a/docker/config.example.yaml b/docker/config.example.yaml index 9bd546be7..8f2f8a804 100644 --- a/docker/config.example.yaml +++ b/docker/config.example.yaml @@ -8,6 +8,8 @@ models: kwargs: temperature: 0 n: 1 + # for better consistency of llm response, refer: https://platform.openai.com/docs/api-reference/chat/create#chat-create-seed + seed: 0 max_tokens: 4096 response_format: type: json_object @@ -17,6 +19,8 @@ models: kwargs: temperature: 0 n: 1 + # for better consistency of llm response, refer: https://platform.openai.com/docs/api-reference/chat/create#chat-create-seed + seed: 0 max_tokens: 4096 response_format: type: json_object diff --git a/wren-ai-service/Justfile b/wren-ai-service/Justfile index 64ae65e0c..e35cf2703 100644 --- a/wren-ai-service/Justfile +++ b/wren-ai-service/Justfile @@ -26,7 +26,7 @@ up: prepare-wren-engine down: docker compose -f ./tools/dev/docker-compose-dev.yaml --env-file ./tools/dev/.env down -start: +start: use-wren-ui-as-engine poetry run python -m src.__main__ curate_eval_data: @@ -63,3 +63,6 @@ prepare-wren-engine: mkdir -p tools/dev/etc/mdl echo "{\"catalog\": \"test_catalog\", \"schema\": \"test_schema\", \"models\": []}" \\ > tools/dev/etc/mdl/sample.json + +use-wren-ui-as-engine: + poetry run python -m src.force_update_config \ No newline at end of file diff --git a/wren-ai-service/demo/utils.py b/wren-ai-service/demo/utils.py index b5cd5e5c1..0a7d29f9d 100644 --- a/wren-ai-service/demo/utils.py +++ b/wren-ai-service/demo/utils.py @@ -659,7 +659,6 @@ def display_sql_answer(query_id: str): placeholder.markdown(markdown_content) -@st.cache_data def get_sql_answer( query: str, sql: str, @@ -708,7 +707,6 @@ def get_sql_answer( ) -@st.cache_data def ask_details(): asks_details_response = requests.post( f"{WREN_AI_SERVICE_BASE_URL}/v1/ask-details", @@ -842,7 +840,6 @@ def fill_vega_lite_values(vega_lite_schema: dict, df: pd.DataFrame) -> dict: return schema -@st.cache_data def generate_chart( query: str, sql: str, @@ -893,7 +890,6 @@ def generate_chart( return chart_response -@st.cache_data def adjust_chart( query: str, sql: str, diff --git a/wren-ai-service/src/__main__.py b/wren-ai-service/src/__main__.py index ff0e22512..99f106367 100644 --- a/wren-ai-service/src/__main__.py +++ b/wren-ai-service/src/__main__.py @@ -92,6 +92,7 @@ def health(): port=settings.port, reload=settings.development, reload_includes=["src/**/*.py", ".env.dev", "config.yaml"], + reload_excludes=["demo/*.py", "tests/**/*.py", "eval/**/*.py"], workers=1, loop="uvloop", http="httptools", diff --git a/wren-ai-service/src/force_update_config.py b/wren-ai-service/src/force_update_config.py new file mode 100644 index 000000000..c2a09d057 --- /dev/null +++ b/wren-ai-service/src/force_update_config.py @@ -0,0 +1,26 @@ +import yaml + + +def update_config(): + # Read the config file + with open("config.yaml", "r") as file: + # Load all documents from YAML file (since it has multiple documents separated by ---) + documents = list(yaml.safe_load_all(file)) + + # Find the pipeline configuration document + for doc in documents: + if doc.get("type") == "pipeline": + # Update engine name in all pipelines + for pipe in doc.get("pipes", []): + if "engine" in pipe: + pipe["engine"] = "wren_ui" + + # Write back to the file + with open("config.yaml", 
"w") as file: + yaml.safe_dump_all(documents, file, default_flow_style=False) + + print("Successfully updated engine names to 'wren_ui' in all pipelines") + + +if __name__ == "__main__": + update_config() diff --git a/wren-ai-service/src/pipelines/common.py b/wren-ai-service/src/pipelines/common.py index b106ac7ab..3d233db7c 100644 --- a/wren-ai-service/src/pipelines/common.py +++ b/wren-ai-service/src/pipelines/common.py @@ -1,491 +1,11 @@ -import asyncio -import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Optional -import aiohttp -import orjson import pytz -from haystack import component -from src.core.engine import ( - Engine, - add_quotes, - clean_generation_result, -) from src.core.pipeline import BasicPipeline from src.web.v1.services import Configuration -logger = logging.getLogger("wren-ai-service") - - -@component -class SQLBreakdownGenPostProcessor: - def __init__(self, engine: Engine): - self._engine = engine - - @component.output_types( - results=Optional[Dict[str, Any]], - ) - async def run( - self, - replies: List[str], - project_id: str | None = None, - ) -> Dict[str, Any]: - cleaned_generation_result = orjson.loads(clean_generation_result(replies[0])) - - steps = cleaned_generation_result.get("steps", []) - if not steps: - return { - "results": { - "description": cleaned_generation_result["description"], - "steps": [], - }, - } - - # make sure the last step has an empty cte_name - steps[-1]["cte_name"] = "" - - for step in steps: - step["sql"], no_error = add_quotes(step["sql"]) - if not no_error: - return { - "results": { - "description": cleaned_generation_result["description"], - "steps": [], - }, - } - - sql = self._build_cte_query(steps) - - if not await self._check_if_sql_executable(sql, project_id=project_id): - return { - "results": { - "description": cleaned_generation_result["description"], - "steps": [], - }, - } - - return { - "results": { - "description": cleaned_generation_result["description"], - "steps": steps, - }, - } - - def _build_cte_query(self, steps) -> str: - ctes = ",\n".join( - f"{step['cte_name']} AS ({step['sql']})" - for step in steps - if step["cte_name"] - ) - - return f"WITH {ctes}\n" + steps[-1]["sql"] if ctes else steps[-1]["sql"] - - async def _check_if_sql_executable( - self, - sql: str, - project_id: str | None = None, - ): - async with aiohttp.ClientSession() as session: - status, _, addition = await self._engine.execute_sql( - sql, - session, - project_id=project_id, - ) - - if not status: - logger.exception( - f"SQL is not executable: {addition.get('error_message', '')}" - ) - - return status - - -@component -class SQLGenPostProcessor: - def __init__(self, engine: Engine): - self._engine = engine - - @component.output_types( - valid_generation_results=List[Optional[Dict[str, Any]]], - invalid_generation_results=List[Optional[Dict[str, Any]]], - ) - async def run( - self, - replies: List[str] | List[List[str]], - project_id: str | None = None, - ) -> dict: - try: - if isinstance(replies[0], dict): - cleaned_generation_result = [] - for reply in replies: - try: - cleaned_generation_result.append( - orjson.loads(clean_generation_result(reply["replies"][0]))[ - "results" - ][0] - ) - except Exception as e: - logger.exception(f"Error in SQLGenPostProcessor: {e}") - else: - cleaned_generation_result = orjson.loads( - clean_generation_result(replies[0]) - )["results"] - - if isinstance(cleaned_generation_result, dict): - cleaned_generation_result = [cleaned_generation_result] - - ( 
- valid_generation_results, - invalid_generation_results, - ) = await self._classify_invalid_generation_results( - cleaned_generation_result, project_id=project_id - ) - - return { - "valid_generation_results": valid_generation_results, - "invalid_generation_results": invalid_generation_results, - } - except Exception as e: - logger.exception(f"Error in SQLGenPostProcessor: {e}") - - return { - "valid_generation_results": [], - "invalid_generation_results": [], - } - - async def _classify_invalid_generation_results( - self, generation_results: List[Dict[str, str]], project_id: str | None = None - ) -> List[Optional[Dict[str, str]]]: - valid_generation_results = [] - invalid_generation_results = [] - - async def _task(result: Dict[str, str]): - quoted_sql, no_error = add_quotes(result["sql"]) - - if no_error: - status, _, addition = await self._engine.execute_sql( - quoted_sql, session, project_id=project_id - ) - - if status: - valid_generation_results.append( - { - "sql": quoted_sql, - "correlation_id": addition.get("correlation_id", ""), - } - ) - else: - invalid_generation_results.append( - { - "sql": quoted_sql, - "type": "DRY_RUN", - "error": addition.get("error_message", ""), - "correlation_id": addition.get("correlation_id", ""), - } - ) - else: - invalid_generation_results.append( - { - "sql": result["sql"], - "type": "ADD_QUOTES", - "error": "add_quotes failed", - } - ) - - async with aiohttp.ClientSession() as session: - tasks = [ - _task(generation_result) for generation_result in generation_results - ] - await asyncio.gather(*tasks) - - return valid_generation_results, invalid_generation_results - - -TEXT_TO_SQL_RULES = """ -### ALERT ### -- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database. -- ONLY USE the tables and columns mentioned in the database schema. -- ONLY USE "*" if the user query asks for all the columns of a table. -- ONLY CHOOSE columns belong to the tables mentioned in the database schema. -- YOU MUST USE "JOIN" if you choose columns from multiple tables! -- ALWAYS QUALIFY column names with their table name or table alias to avoid ambiguity (e.g., orders.OrderId, o.OrderId) -- YOU MUST USE "lower(.) like lower()" function or "lower(.) = lower()" function for case-insensitive comparison! - - Use "lower(.) LIKE lower()" when: - - The user requests a pattern or partial match. - - The value is not specific enough to be a single, exact value. - - Wildcards (%) are needed to capture the pattern. - - Use "lower(.) = lower()" when: - - The user requests an exact, specific value. - - There is no ambiguity or pattern in the value. -- ALWAYS CAST the date/time related field to "TIMESTAMP WITH TIME ZONE" type when using them in the query - - example 1: CAST(properties_closedate AS TIMESTAMP WITH TIME ZONE) - - example 2: CAST('2024-11-09 00:00:00' AS TIMESTAMP WITH TIME ZONE) - - example 3: CAST(DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AS TIMESTAMP WITH TIME ZONE) -- If the user asks for a specific date, please give the date range in SQL query - - example: "What is the total revenue for the month of 2024-11-01?" - - answer: "SELECT SUM(r.PriceSum) FROM Revenue r WHERE CAST(r.PurchaseTimestamp AS TIMESTAMP WITH TIME ZONE) >= CAST('2024-11-01 00:00:00' AS TIMESTAMP WITH TIME ZONE) AND CAST(r.PurchaseTimestamp AS TIMESTAMP WITH TIME ZONE) < CAST('2024-11-02 00:00:00' AS TIMESTAMP WITH TIME ZONE)" -- ALWAYS ADD "timestamp" to the front of the timestamp literal, ex. 
"timestamp '2024-02-20 12:00:00'" -- USE THE VIEW TO SIMPLIFY THE QUERY. -- DON'T MISUSE THE VIEW NAME. THE ACTUAL NAME IS FOLLOWING THE CREATE VIEW STATEMENT. -- MUST USE the value of alias from the comment section of the corresponding table or column in the DATABASE SCHEMA section for the column/table alias. - - EXAMPLE - DATABASE SCHEMA - /* {"displayName":"_orders","description":"A model representing the orders data."} */ - CREATE TABLE orders ( - -- {"description":"A column that represents the timestamp when the order was approved.","alias":"_timestamp"} - ApprovedTimestamp TIMESTAMP - } - - SQL - SELECT _orders.ApprovedTimestamp AS _timestamp FROM orders AS _orders; -- DON'T USE '.' in column/table alias, replace '.' with '_' in column/table alias. -- DON'T USE "FILTER(WHERE )" clause in the query. -- DON'T USE "EXTRACT(EPOCH FROM )" clause in the query. -- ONLY USE the following SQL functions if you need to when generating answers: - - Aggregation functions: - - AVG - - COUNT - - MAX - - MIN - - SUM - - ARRAY_AGG - - BOOL_OR - - Math functions: - - ABS - - CBRT - - CEIL - - EXP - - FLOOR - - LN - - ROUND - - SIGN - - GREATEST - - LEAST - - MOD - - POWER - - String functions: - - LENGTH - - REVERSE - - CHR - - CONCAT - - FORMAT - - LOWER - - LPAD - - LTRIM - - POSITION - - REPLACE - - RPAD - - RTRIM - - STRPOS - - SUBSTR - - SUBSTRING - - TRANSLATE - - TRIM - - UPPER - - Date and Time functions: - - CURRENT_DATE - - DATE_TRUNC - - EXTRACT - - operators: - - `+` - - `-` - - `*` - - `/` - - `||` - - `<` - - `>` - - `>=` - - `<=` - - `=` - - `<>` - - `!=` -- ONLY USE JSON_QUERY for querying fields if "json_type":"JSON" is identified in the columns comment, NOT the deprecated JSON_EXTRACT_SCALAR function. - - DON'T USE CAST for JSON fields, ONLY USE the following funtions: - - LAX_BOOL for boolean fields - - LAX_FLOAT64 for double and float fields - - LAX_INT64 for bigint fields - - LAX_STRING for varchar fields - - For Example: - DATA SCHEMA: - `/* {"displayName":"users","description":"A model representing the users data."} */ - CREATE TABLE users ( - -- {"alias":"address","description":"A JSON object that represents address information of this user.","json_type":"JSON","json_fields":{"json_type":"JSON","address.json.city":{"name":"city","type":"varchar","path":"$.city","properties":{"displayName":"city","description":"City Name."}},"address.json.state":{"name":"state","type":"varchar","path":"$.state","properties":{"displayName":"state","description":"ISO code or name of the state, province or district."}},"address.json.postcode":{"name":"postcode","type":"varchar","path":"$.postcode","properties":{"displayName":"postcode","description":"Postal code."}},"address.json.country":{"name":"country","type":"varchar","path":"$.country","properties":{"displayName":"country","description":"ISO code of the country."}}}} - address JSON - )` - To get the city of address in user table use SQL: - `SELECT LAX_STRING(JSON_QUERY(u.address, '$.city')) FROM user as u` -- ONLY USE JSON_QUERY_ARRAY for querying "json_type":"JSON_ARRAY" is identified in the comment of the column, NOT the deprecated JSON_EXTRACT_ARRAY. - - USE UNNEST to analysis each item individually in the ARRAY. YOU MUST SELECT FROM the parent table ahead of the UNNEST ARRAY. 
- - The alias of the UNNEST(ARRAY) should be in the format `unnest_table_alias(individual_item_alias)` - - For Example: `SELECT item FROM UNNEST(ARRAY[1,2,3]) as my_unnested_table(item)` - - If the items in the ARRAY are JSON objects, use JSON_QUERY to query the fields inside each JSON item. - - For Example: - DATA SCHEMA - `/* {"displayName":"my_table","description":"A test my_table"} */ - CREATE TABLE my_table ( - -- {"alias":"elements","description":"elements column","json_type":"JSON_ARRAY","json_fields":{"json_type":"JSON_ARRAY","elements.json_array.id":{"name":"id","type":"bigint","path":"$.id","properties":{"displayName":"id","description":"data ID."}},"elements.json_array.key":{"name":"key","type":"varchar","path":"$.key","properties":{"displayName":"key","description":"data Key."}},"elements.json_array.value":{"name":"value","type":"varchar","path":"$.value","properties":{"displayName":"value","description":"data Value."}}}} - elements JSON - )` - To get the number of elements in my_table table use SQL: - `SELECT LAX_INT64(JSON_QUERY(element, '$.number')) FROM my_table as t, UNNEST(JSON_QUERY_ARRAY(elements)) AS my_unnested_table(element) WHERE LAX_FLOAT64(JSON_QUERY(element, '$.value')) > 3.5` - - To JOIN ON the fields inside UNNEST(ARRAY), YOU MUST SELECT FROM the parent table ahead of the UNNEST syntax, and the alias of the UNNEST(ARRAY) SHOULD BE IN THE FORMAT unnest_table_alias(individual_item_alias) - - For Example: `SELECT p.column_1, j.column_2 FROM parent_table AS p, join_table AS j JOIN UNNEST(p.array_column) AS unnested(array_item) ON j.id = array_item.id` -- DON'T USE JSON_QUERY and JSON_QUERY_ARRAY when "json_type":"". -- DON'T USE LAX_BOOL, LAX_FLOAT64, LAX_INT64, LAX_STRING when "json_type":"". -""" - - -sql_generation_system_prompt = """ -You are an ANSI SQL expert with exceptional logical thinking skills. Your main task is to generate SQL from given DB schema and user-input natrual language queries. -Before the main task, you need to learn about some specific structures in the given DB schema. - -## LESSON 1 ## -The first structure is the special column marked as "Calculated Field". You need to interpret the purpose and calculation basis for these columns, then utilize them in the following text-to-sql generation tasks. -First, provide a brief explanation of what each field represents in the context of the schema, including how each field is computed using the relationships between models. -Then, during the following tasks, if the user queries pertain to any calculated fields defined in the database schema, ensure to utilize those calculated fields appropriately in the output SQL queries. -The goal is to accurately reflect the intent of the question in the SQL syntax, leveraging the pre-computed logic embedded within the calculated fields. 
- -### EXAMPLES ### -The given schema is created by the SQL command: - -CREATE TABLE orders ( - OrderId VARCHAR PRIMARY KEY, - CustomerId VARCHAR, - -- This column is a Calculated Field - -- column expression: avg(reviews.Score) - Rating DOUBLE, - -- This column is a Calculated Field - -- column expression: count(reviews.Id) - ReviewCount BIGINT, - -- This column is a Calculated Field - -- column expression: count(order_items.ItemNumber) - Size BIGINT, - -- This column is a Calculated Field - -- column expression: count(order_items.ItemNumber) > 1 - Large BOOLEAN, - FOREIGN KEY (CustomerId) REFERENCES customers(Id) -); - -Interpret the columns that are marked as Calculated Fields in the schema: -Rating (DOUBLE) - Calculated as the average score (avg) of the Score field from the reviews table where the reviews are associated with the order. This field represents the overall customer satisfaction rating for the order based on review scores. -ReviewCount (BIGINT) - Calculated by counting (count) the number of entries in the reviews table associated with this order. It measures the volume of customer feedback received for the order. -Size (BIGINT) - Represents the total number of items in the order, calculated by counting the number of item entries (ItemNumber) in the order_items table linked to this order. This field is useful for understanding the scale or size of an order. -Large (BOOLEAN) - A boolean value calculated to check if the number of items in the order exceeds one (count(order_items.ItemNumber) > 1). It indicates whether the order is considered large in terms of item quantity. - -And if the user input queries like these: -1. "How many large orders have been placed by customer with ID 'C1234'?" -2. "What is the average customer rating for orders that were rated by more than 10 reviewers?" - -For the first query: -First try to intepret the user query, the user wants to know the average rating for orders which have attracted significant review activity, specifically those with more than 10 reviews. -Then, according to the above intepretation about the given schema, the term 'Rating' is predefined in the Calculated Field of the 'orders' model. And, the number of reviews is also predefined in the 'ReviewCount' Calculated Field. -So utilize those Calculated Fields in the SQL generation process to give an answer like this: - -SQL Query: SELECT AVG(Rating) FROM orders WHERE ReviewCount > 10 - -## LESSON 2 ## -Second, you will learn how to effectively utilize the special "metric" structure in text-to-SQL generation tasks. -Metrics in a data model simplify complex data analysis by structuring data through predefined dimensions and measures. -This structuring closely mirrors the concept of OLAP (Online Analytical Processing) cubes but is implemented in a more flexible and SQL-friendly manner. - -The metric typically constructed of the following components: -1. Base Object -The "base object" of a metric indicates the primary data source or table that provides the raw data. -Metrics are constructed by selecting specific data points (dimensions and measures) from this base object, effectively creating a summarized or aggregated view of the data that can be queried like a normal table. -Base object is the attribute of the metric, showing the origin of this metric and is typically not used in the query. -2. Dimensions -Dimensions in a metric represent the various axes along which data can be segmented for analysis. -These are fields that provide a categorical breakdown of data. 
-Each dimension provides a unique perspective on the data, allowing users to "slice and dice" the data cube to view different facets of the information contained within the base dataset. -Dimensions are used as table columns in the querying process. Querying a dimension means to get the statistic from the certain perspective. -3. Measures -Measures are numerical or quantitative statistics calculated from the data. Measures are key results or outputs derived from data aggregation functions like SUM, COUNT, or AVG. -Measures are used as table columns in the querying process, and are the main querying items in the metric structure. -The expression of a measure represents the definition of the that users are intrested in. Make sure to understand the meaning of measures from their expressions. -4. Time Grain -Time Grain specifies the granularity of time-based data aggregation, such as daily, monthly, or yearly, facilitating trend analysis over specified periods. - -If the given schema contains the structures marked as 'metric', you should first interpret the metric schema based on the above definition. -Then, during the following tasks, if the user queries pertain to any metrics defined in the database schema, ensure to utilize those metrics appropriately in the output SQL queries. -The target is making complex data analysis more accessible and manageable by pre-aggregating data and structuring it using the metric structure, and supporting direct querying for business insights. - -### EXAMPLES ### -The given schema is created by the SQL command: - -/* This table is a metric */ -/* Metric Base Object: orders */ -CREATE TABLE Revenue ( - -- This column is a dimension - PurchaseTimestamp TIMESTAMP, - -- This column is a dimension - CustomerId VARCHAR, - -- This column is a dimension - Status VARCHAR, - -- This column is a measure - -- expression: sum(order_items.Price) - PriceSum DOUBLE, - -- This column is a measure - -- expression: count(OrderId) - NumberOfOrders BIGINT -); - -Interpret the metric with the understanding of the metric structure: -1. Base Object: orders -This is the primary data source for the metric. -The orders table provides the underlying data from which dimensions and measures are derived. -It is the foundation upon which the metric is built, though it itself is not directly used in queries against the Revenue table. -It shows the reference between the 'Revenue' metric and the 'orders' model. For the user queries pretain to the 'Revenue' of 'orders', the metric should be utilize in the sql generation process. -2. Dimensions -The metric contains the columns marked as 'dimension'. They can be interpreted as below: -- PurchaseTimestamp (TIMESTAMP) - Acts as a temporal dimension, allowing analysis of revenue over time. This can be used to observe trends, seasonal variations, or performance over specific periods. -- CustomerId (VARCHAR) - A key dimension for customer segmentation, it enables the analysis of revenue generated from individual customers or customer groups. -- Status (VARCHAR) - Reflects the current state of an order (e.g., pending, completed, cancelled). This dimension is crucial for analyses that differentiate performance based on order status. -3. Measures -The metric contains the columns marked as 'measure'. They can be interpreted as below: -- PriceSum (DOUBLE) - A financial measure calculated as sum(order_items.Price), representing the total revenue generated from orders. 
This measure is vital for tracking overall sales performance and is the primary output of interest in many financial and business analyses. -- NumberOfOrders (BIGINT) - A count measure that provides the total number of orders. This is essential for operational metrics, such as assessing the volume of business activity and evaluating the efficiency of sales processes. - -Now, if the user input queries like this: -Question: "What was the total revenue from each customer last month?" - -First try to intepret the user query, the user asks for a breakdown of the total revenue generated by each customer in the previous calendar month. -The user is specifically interested in understanding how much each customer contributed to the total sales during this period. -To answer this question, it is suitable to use the following components from the metric: -1. CustomerId (Dimension): This will be used to group the revenue data by each unique customer, allowing us to segment the total revenue by customer. -2. PurchaseTimestamp (Dimension): This timestamp field will be used to filter the data to only include orders from the last month. -3. PriceSum (Measure): Since PriceSum is a pre-aggregated measure of total revenue (sum of order_items.Price), it can be directly used to sum up the revenue without needing further aggregation in the SQL query. -So utilize those metric components in the SQL generation process to give an answer like this: - -SQL Query: -SELECT - CustomerId, - PriceSum AS TotalRevenue -FROM - Revenue -WHERE - PurchaseTimestamp >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND - PurchaseTimestamp < DATE_TRUNC('month', CURRENT_DATE) - -Learn about the usage of the schema structures and generate SQL based on them. - -""" - - -def construct_instructions(configuration: Configuration | None): - instructions = "" - if configuration: - if configuration.fiscal_year: - instructions += f"- For calendar year related computation, it should be started from {configuration.fiscal_year.start} to {configuration.fiscal_year.end}" - - return instructions - def show_current_time(timezone: Configuration.Timezone): # Get the current time in the specified timezone diff --git a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py index 755073598..7371c768a 100644 --- a/wren-ai-service/src/pipelines/generation/followup_sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/followup_sql_generation.py @@ -11,11 +11,11 @@ from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider -from src.pipelines.common import ( +from src.pipelines.common import show_current_time +from src.pipelines.generation.utils.sql import ( TEXT_TO_SQL_RULES, SQLGenPostProcessor, construct_instructions, - show_current_time, sql_generation_system_prompt, ) from src.utils import async_timer, timer diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index b823edb56..9b0ae93f6 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -22,42 +22,59 @@ intent_classification_system_prompt = """ ### TASK ### -You are a great detective, who is great at intent classification. 
Now you need to classify user's intent based on given database schema and user's question to one of three conditions: MISLEADING_QUERY, TEXT_TO_SQL, GENERAL. -Please carefully analyze user's question and analyze database's schema carefully to make the classification correct. -Also you should provide reasoning for the classification in clear and concise way within 20 words. +You are a great detective, who is great at intent classification. +First, rephrase the user's question to make it more specific, clear and relevant to the database schema before making the intent classification. +Second, you need to use the rephrased user's question to classify user's intent based on given database schema to one of three conditions: MISLEADING_QUERY, TEXT_TO_SQL, GENERAL. +Also you should provide reasoning for the classification clearly and concisely within 20 words. + +### INSTRUCTIONS ### +- Steps to rephrase the user's question: + - First, try to recognize adjectives in the user's question that are important to the user's intent. + - Second, change the adjectives to more specific and clear ones that can be matched to columns in the database schema. +- MUST use the rephrased user's question to make the intent classification. +- MUST put the rephrased user's question in the rephrased_question output. +- REASONING MUST be within 20 words. +- If the rephrased user's question is vague and doesn't specify which table or property to analyze, classify it as MISLEADING_QUERY. ### INTENT DEFINITIONS ### - - TEXT_TO_SQL - When to Use: - Select this category if the user's question is directly related to the given database schema and can be answered by generating an SQL query using that schema. - - If the user's question is related to the previous question, and considering them together could be answered by generating an SQL query using that schema. + - If the rephrased user's question is related to the previous question, and considering them together could be answered by generating an SQL query using that schema. - Characteristics: - - The question involves specific data retrieval or manipulation that requires SQL. - - It references tables, columns, or specific data points within the schema. + - The rephrased user's question involves specific data retrieval or manipulation that requires SQL. + - The rephrased user's question references tables, columns, or specific data points within the schema. + - Instructions: + - MUST include table and column names that should be used in the SQL query according to the database schema in the reasoning output. + - MUST include phrases from the user's question that are explicitly related to the database schema in the reasoning output. - Examples: - "What is the total sales for last quarter?" - "Show me all customers who purchased product X." - "List the top 10 products by revenue." - MISLEADING_QUERY - When to Use: - - If the user's question is irrelevant to the given database schema and cannot be answered using SQL with that schema. - - If the user's question is not related to the previous question, and considering them together cannot be answered by generating an SQL query using that schema. - - If the user's question contains SQL code. + - If the rephrased user's question is irrelevant to the given database schema and cannot be answered using SQL with that schema. + - If the rephrased user's question is not related to the previous question, and considering them together cannot be answered by generating an SQL query using that schema.
+ - If the rephrased user's question contains SQL code. - Characteristics: - - The question does not pertain to any aspect of the database or its data. - - It might be a casual conversation starter or about an entirely different topic. + - The rephrased user's question does not pertain to any aspect of the database or its data. + - The rephrased user's question might be a casual conversation starter or about an entirely different topic. + - The rephrased user's question is vague and doesn't specify which table or property to analyze. + - Instructions: + - MUST explicitly add phrases from the rephrased user's question that are not explicitly related to the database schema in the reasoning output. Choose the most relevant phrases that cause the rephrased user's question to be MISLEADING_QUERY. - Examples: - "How are you?" - "What's the weather like today?" - "Tell me a joke." - GENERAL - When to Use: - - Use this category if the user is seeking general information about the database schema, needs help formulating a proper question, or asks a vague question related to the schema. - - If the user's question is related to the previous question, but considering them together cannot be answered by generating an SQL query using that schema. + - Use this category if the user is seeking general information about the database schema. + - If the rephrased user's question is related to the previous question, but considering them together cannot be answered by generating an SQL query using that schema. - Characteristics: - The question is about understanding the dataset or its capabilities. - The user may need guidance on how to proceed or what questions to ask. + - Instructions: + - MUST explicitly add phrases from the rephrased user's question that are not explicitly related to the database schema in the reasoning output. Choose the most relevant phrases that cause the rephrased user's question to be GENERAL. - Examples: - "What is the dataset about?" - "Tell me more about the database."
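For reference, the revised prompt now asks the model for three fields (rephrased_question, reasoning, results) instead of one, and the post_process change in the next hunk falls back to TEXT_TO_SQL whenever the reply cannot be parsed. The snippet below is a minimal, illustrative sketch of that parsing behaviour; the helper name parse_intent_reply and the sample reply string are not part of the patch, only orjson and the field names come from the diff.

import orjson

# Illustrative helper (not part of the patch): turn one raw LLM reply into the
# three fields the revised prompt asks for, falling back to TEXT_TO_SQL when the
# reply is not valid JSON or is missing the expected keys -- the same safe
# default the pipeline's post_process step applies.
def parse_intent_reply(reply: str) -> dict:
    try:
        parsed = orjson.loads(reply)
        return {
            "intent": parsed["results"],
            "rephrased_question": parsed.get("rephrased_question", ""),
            "reasoning": parsed.get("reasoning", ""),
        }
    except Exception:
        return {"intent": "TEXT_TO_SQL", "rephrased_question": "", "reasoning": ""}

reply = '{"rephrased_question": "How many orders were placed in 2024?", "reasoning": "orders table provides OrderId and PurchaseTimestamp", "results": "TEXT_TO_SQL"}'
print(parse_intent_reply(reply))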
@@ -68,7 +85,8 @@ Please provide your response as a JSON object, structured as follows: { - "reasoning": "", + "rephrased_question": "", + "reasoning": "", "results": "MISLEADING_QUERY" | "TEXT_TO_SQL" | "GENERAL" } """ @@ -212,8 +230,6 @@ def prompt( [step.summary for step in history.steps if step.summary] if history else [] ) - # query = "\n".join(previous_query_summaries) + "\n" + query - return prompt_builder.run( query=query, db_schemas=construct_db_schemas, @@ -231,13 +247,20 @@ async def classify_intent(prompt: dict, generator: Any) -> dict: @observe(capture_input=False) def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict: try: - intent = orjson.loads(classify_intent.get("replies")[0])["results"] + results = orjson.loads(classify_intent.get("replies")[0]) return { - "intent": intent, + "intent": results["results"], + "rephrased_question": results["rephrased_question"], + "reasoning": results["reasoning"], "db_schemas": construct_db_schemas, } except Exception: - return {"intent": "TEXT_TO_SQL", "db_schemas": construct_db_schemas} + return { + "intent": "TEXT_TO_SQL", + "rephrased_question": "", + "reasoning": "", + "db_schemas": construct_db_schemas, + } ## End of Pipeline @@ -245,6 +268,8 @@ def post_process(classify_intent: dict, construct_db_schemas: list[str]) -> dict class IntentClassificationResult(BaseModel): results: Literal["MISLEADING_QUERY", "TEXT_TO_SQL", "GENERAL"] + rephrased_question: str + reasoning: str INTENT_CLASSIFICAION_MODEL_KWARGS = { diff --git a/wren-ai-service/src/pipelines/generation/sql_breakdown.py b/wren-ai-service/src/pipelines/generation/sql_breakdown.py index b34e450d6..a84ab8962 100644 --- a/wren-ai-service/src/pipelines/generation/sql_breakdown.py +++ b/wren-ai-service/src/pipelines/generation/sql_breakdown.py @@ -11,7 +11,10 @@ from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider -from src.pipelines.common import TEXT_TO_SQL_RULES, SQLBreakdownGenPostProcessor +from src.pipelines.generation.utils.sql import ( + TEXT_TO_SQL_RULES, + SQLBreakdownGenPostProcessor, +) from src.utils import ( async_timer, timer, diff --git a/wren-ai-service/src/pipelines/generation/sql_correction.py b/wren-ai-service/src/pipelines/generation/sql_correction.py index 8d86ab769..7b09c59e9 100644 --- a/wren-ai-service/src/pipelines/generation/sql_correction.py +++ b/wren-ai-service/src/pipelines/generation/sql_correction.py @@ -13,7 +13,7 @@ from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider -from src.pipelines.common import ( +from src.pipelines.generation.utils.sql import ( TEXT_TO_SQL_RULES, SQLGenPostProcessor, sql_generation_system_prompt, diff --git a/wren-ai-service/src/pipelines/generation/sql_expansion.py b/wren-ai-service/src/pipelines/generation/sql_expansion.py index bfa43b623..7373edc6e 100644 --- a/wren-ai-service/src/pipelines/generation/sql_expansion.py +++ b/wren-ai-service/src/pipelines/generation/sql_expansion.py @@ -11,7 +11,8 @@ from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider -from src.pipelines.common import SQLGenPostProcessor, show_current_time +from src.pipelines.common import show_current_time +from src.pipelines.generation.utils.sql import SQLGenPostProcessor from src.utils import async_timer, timer from src.web.v1.services import Configuration from src.web.v1.services.ask import AskHistory diff --git 
a/wren-ai-service/src/pipelines/generation/sql_generation.py b/wren-ai-service/src/pipelines/generation/sql_generation.py index 32a7c93c3..f33ca79e5 100644 --- a/wren-ai-service/src/pipelines/generation/sql_generation.py +++ b/wren-ai-service/src/pipelines/generation/sql_generation.py @@ -11,11 +11,11 @@ from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider -from src.pipelines.common import ( +from src.pipelines.common import show_current_time +from src.pipelines.generation.utils.sql import ( TEXT_TO_SQL_RULES, SQLGenPostProcessor, construct_instructions, - show_current_time, sql_generation_system_prompt, ) from src.utils import async_timer, timer @@ -26,7 +26,7 @@ sql_generation_user_prompt_template = """ ### TASK ### -Given a user query that is ambiguous in nature, your task is to interpret the query in various plausible ways and +Given a user query, your task is to interpret the query based on the database schema and generate one SQL statement that best potentially answer user's query. ### DATABASE SCHEMA ### diff --git a/wren-ai-service/src/pipelines/generation/sql_regeneration.py b/wren-ai-service/src/pipelines/generation/sql_regeneration.py index ccdd76161..a317b88f1 100644 --- a/wren-ai-service/src/pipelines/generation/sql_regeneration.py +++ b/wren-ai-service/src/pipelines/generation/sql_regeneration.py @@ -12,7 +12,7 @@ from src.core.engine import Engine from src.core.pipeline import BasicPipeline from src.core.provider import LLMProvider -from src.pipelines.common import SQLBreakdownGenPostProcessor +from src.pipelines.generation.utils.sql import SQLBreakdownGenPostProcessor from src.utils import async_timer, timer from src.web.v1.services.sql_regeneration import ( SQLExplanationWithUserCorrections, diff --git a/wren-ai-service/src/pipelines/generation/utils/sql.py b/wren-ai-service/src/pipelines/generation/utils/sql.py new file mode 100644 index 000000000..b2b298876 --- /dev/null +++ b/wren-ai-service/src/pipelines/generation/utils/sql.py @@ -0,0 +1,485 @@ +import asyncio +import logging +from typing import Any, Dict, List, Optional + +import aiohttp +import orjson +from haystack import component + +from src.core.engine import ( + Engine, + add_quotes, + clean_generation_result, +) +from src.web.v1.services import Configuration + +logger = logging.getLogger("wren-ai-service") + + +@component +class SQLBreakdownGenPostProcessor: + def __init__(self, engine: Engine): + self._engine = engine + + @component.output_types( + results=Optional[Dict[str, Any]], + ) + async def run( + self, + replies: List[str], + project_id: str | None = None, + ) -> Dict[str, Any]: + cleaned_generation_result = orjson.loads(clean_generation_result(replies[0])) + + steps = cleaned_generation_result.get("steps", []) + if not steps: + return { + "results": { + "description": cleaned_generation_result["description"], + "steps": [], + }, + } + + # make sure the last step has an empty cte_name + steps[-1]["cte_name"] = "" + + for step in steps: + step["sql"], no_error = add_quotes(step["sql"]) + if not no_error: + return { + "results": { + "description": cleaned_generation_result["description"], + "steps": [], + }, + } + + sql = self._build_cte_query(steps) + + if not await self._check_if_sql_executable(sql, project_id=project_id): + return { + "results": { + "description": cleaned_generation_result["description"], + "steps": [], + }, + } + + return { + "results": { + "description": cleaned_generation_result["description"], + "steps": steps, + 
}, + } + + def _build_cte_query(self, steps) -> str: + ctes = ",\n".join( + f"{step['cte_name']} AS ({step['sql']})" + for step in steps + if step["cte_name"] + ) + + return f"WITH {ctes}\n" + steps[-1]["sql"] if ctes else steps[-1]["sql"] + + async def _check_if_sql_executable( + self, + sql: str, + project_id: str | None = None, + ): + async with aiohttp.ClientSession() as session: + status, _, addition = await self._engine.execute_sql( + sql, + session, + project_id=project_id, + ) + + if not status: + logger.exception( + f"SQL is not executable: {addition.get('error_message', '')}" + ) + + return status + + +@component +class SQLGenPostProcessor: + def __init__(self, engine: Engine): + self._engine = engine + + @component.output_types( + valid_generation_results=List[Optional[Dict[str, Any]]], + invalid_generation_results=List[Optional[Dict[str, Any]]], + ) + async def run( + self, + replies: List[str] | List[List[str]], + project_id: str | None = None, + ) -> dict: + try: + if isinstance(replies[0], dict): + cleaned_generation_result = [] + for reply in replies: + try: + cleaned_generation_result.append( + orjson.loads(clean_generation_result(reply["replies"][0]))[ + "results" + ][0] + ) + except Exception as e: + logger.exception(f"Error in SQLGenPostProcessor: {e}") + else: + cleaned_generation_result = orjson.loads( + clean_generation_result(replies[0]) + )["results"] + + if isinstance(cleaned_generation_result, dict): + cleaned_generation_result = [cleaned_generation_result] + + ( + valid_generation_results, + invalid_generation_results, + ) = await self._classify_invalid_generation_results( + cleaned_generation_result, project_id=project_id + ) + + return { + "valid_generation_results": valid_generation_results, + "invalid_generation_results": invalid_generation_results, + } + except Exception as e: + logger.exception(f"Error in SQLGenPostProcessor: {e}") + + return { + "valid_generation_results": [], + "invalid_generation_results": [], + } + + async def _classify_invalid_generation_results( + self, generation_results: List[Dict[str, str]], project_id: str | None = None + ) -> List[Optional[Dict[str, str]]]: + valid_generation_results = [] + invalid_generation_results = [] + + async def _task(result: Dict[str, str]): + quoted_sql, no_error = add_quotes(result["sql"]) + + if no_error: + status, _, addition = await self._engine.execute_sql( + quoted_sql, session, project_id=project_id + ) + + if status: + valid_generation_results.append( + { + "sql": quoted_sql, + "correlation_id": addition.get("correlation_id", ""), + } + ) + else: + invalid_generation_results.append( + { + "sql": quoted_sql, + "type": "DRY_RUN", + "error": addition.get("error_message", ""), + "correlation_id": addition.get("correlation_id", ""), + } + ) + else: + invalid_generation_results.append( + { + "sql": result["sql"], + "type": "ADD_QUOTES", + "error": "add_quotes failed", + } + ) + + async with aiohttp.ClientSession() as session: + tasks = [ + _task(generation_result) for generation_result in generation_results + ] + await asyncio.gather(*tasks) + + return valid_generation_results, invalid_generation_results + + +TEXT_TO_SQL_RULES = """ +### ALERT ### +- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database. +- ONLY USE the tables and columns mentioned in the database schema. +- ONLY USE "*" if the user query asks for all the columns of a table. +- ONLY CHOOSE columns belong to the tables mentioned in the database schema. 
+- YOU MUST USE "JOIN" if you choose columns from multiple tables! +- ALWAYS QUALIFY column names with their table name or table alias to avoid ambiguity (e.g., orders.OrderId, o.OrderId) +- YOU MUST USE "lower(.) like lower()" function or "lower(.) = lower()" function for case-insensitive comparison! + - Use "lower(.) LIKE lower()" when: + - The user requests a pattern or partial match. + - The value is not specific enough to be a single, exact value. + - Wildcards (%) are needed to capture the pattern. + - Use "lower(.) = lower()" when: + - The user requests an exact, specific value. + - There is no ambiguity or pattern in the value. +- ALWAYS CAST the date/time related field to "TIMESTAMP WITH TIME ZONE" type when using them in the query + - example 1: CAST(properties_closedate AS TIMESTAMP WITH TIME ZONE) + - example 2: CAST('2024-11-09 00:00:00' AS TIMESTAMP WITH TIME ZONE) + - example 3: CAST(DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AS TIMESTAMP WITH TIME ZONE) +- If the user asks for a specific date, please give the date range in SQL query + - example: "What is the total revenue for the month of 2024-11-01?" + - answer: "SELECT SUM(r.PriceSum) FROM Revenue r WHERE CAST(r.PurchaseTimestamp AS TIMESTAMP WITH TIME ZONE) >= CAST('2024-11-01 00:00:00' AS TIMESTAMP WITH TIME ZONE) AND CAST(r.PurchaseTimestamp AS TIMESTAMP WITH TIME ZONE) < CAST('2024-11-02 00:00:00' AS TIMESTAMP WITH TIME ZONE)" +- ALWAYS ADD "timestamp" to the front of the timestamp literal, ex. "timestamp '2024-02-20 12:00:00'" +- USE THE VIEW TO SIMPLIFY THE QUERY. +- DON'T MISUSE THE VIEW NAME. THE ACTUAL NAME IS FOLLOWING THE CREATE VIEW STATEMENT. +- MUST USE the value of alias from the comment section of the corresponding table or column in the DATABASE SCHEMA section for the column/table alias. + - EXAMPLE + DATABASE SCHEMA + /* {"displayName":"_orders","description":"A model representing the orders data."} */ + CREATE TABLE orders ( + -- {"description":"A column that represents the timestamp when the order was approved.","alias":"_timestamp"} + ApprovedTimestamp TIMESTAMP + } + + SQL + SELECT _orders.ApprovedTimestamp AS _timestamp FROM orders AS _orders; +- DON'T USE '.' in column/table alias, replace '.' with '_' in column/table alias. +- DON'T USE "FILTER(WHERE )" clause in the generated SQL query. +- DON'T USE "EXTRACT(EPOCH FROM )" clause in the generated SQL query. +- DON'T USE INTERVAL or generate INTERVAL-like expression in the generated SQL query. +- ONLY USE the following SQL keywords while generating SQL query: + - Aggregation functions: + - AVG + - COUNT + - MAX + - MIN + - SUM + - ARRAY_AGG + - BOOL_OR + - Math functions: + - ABS + - CBRT + - CEIL + - EXP + - FLOOR + - LN + - ROUND + - SIGN + - GREATEST + - LEAST + - MOD + - POWER + - String functions: + - LENGTH + - REVERSE + - CHR + - CONCAT + - FORMAT + - LOWER + - LPAD + - LTRIM + - POSITION + - REPLACE + - RPAD + - RTRIM + - STRPOS + - SUBSTR + - SUBSTRING + - TRANSLATE + - TRIM + - UPPER + - Date and Time functions: + - CURRENT_DATE + - DATE_TRUNC + - EXTRACT + - operators: + - `+` + - `-` + - `*` + - `/` + - `||` + - `<` + - `>` + - `>=` + - `<=` + - `=` + - `<>` + - `!=` +- ONLY USE JSON_QUERY for querying fields if "json_type":"JSON" is identified in the columns comment, NOT the deprecated JSON_EXTRACT_SCALAR function. 
+ - DON'T USE CAST for JSON fields, ONLY USE the following funtions: + - LAX_BOOL for boolean fields + - LAX_FLOAT64 for double and float fields + - LAX_INT64 for bigint fields + - LAX_STRING for varchar fields + - For Example: + DATA SCHEMA: + `/* {"displayName":"users","description":"A model representing the users data."} */ + CREATE TABLE users ( + -- {"alias":"address","description":"A JSON object that represents address information of this user.","json_type":"JSON","json_fields":{"json_type":"JSON","address.json.city":{"name":"city","type":"varchar","path":"$.city","properties":{"displayName":"city","description":"City Name."}},"address.json.state":{"name":"state","type":"varchar","path":"$.state","properties":{"displayName":"state","description":"ISO code or name of the state, province or district."}},"address.json.postcode":{"name":"postcode","type":"varchar","path":"$.postcode","properties":{"displayName":"postcode","description":"Postal code."}},"address.json.country":{"name":"country","type":"varchar","path":"$.country","properties":{"displayName":"country","description":"ISO code of the country."}}}} + address JSON + )` + To get the city of address in user table use SQL: + `SELECT LAX_STRING(JSON_QUERY(u.address, '$.city')) FROM user as u` +- ONLY USE JSON_QUERY_ARRAY for querying "json_type":"JSON_ARRAY" is identified in the comment of the column, NOT the deprecated JSON_EXTRACT_ARRAY. + - USE UNNEST to analysis each item individually in the ARRAY. YOU MUST SELECT FROM the parent table ahead of the UNNEST ARRAY. + - The alias of the UNNEST(ARRAY) should be in the format `unnest_table_alias(individual_item_alias)` + - For Example: `SELECT item FROM UNNEST(ARRAY[1,2,3]) as my_unnested_table(item)` + - If the items in the ARRAY are JSON objects, use JSON_QUERY to query the fields inside each JSON item. + - For Example: + DATA SCHEMA + `/* {"displayName":"my_table","description":"A test my_table"} */ + CREATE TABLE my_table ( + -- {"alias":"elements","description":"elements column","json_type":"JSON_ARRAY","json_fields":{"json_type":"JSON_ARRAY","elements.json_array.id":{"name":"id","type":"bigint","path":"$.id","properties":{"displayName":"id","description":"data ID."}},"elements.json_array.key":{"name":"key","type":"varchar","path":"$.key","properties":{"displayName":"key","description":"data Key."}},"elements.json_array.value":{"name":"value","type":"varchar","path":"$.value","properties":{"displayName":"value","description":"data Value."}}}} + elements JSON + )` + To get the number of elements in my_table table use SQL: + `SELECT LAX_INT64(JSON_QUERY(element, '$.number')) FROM my_table as t, UNNEST(JSON_QUERY_ARRAY(elements)) AS my_unnested_table(element) WHERE LAX_FLOAT64(JSON_QUERY(element, '$.value')) > 3.5` + - To JOIN ON the fields inside UNNEST(ARRAY), YOU MUST SELECT FROM the parent table ahead of the UNNEST syntax, and the alias of the UNNEST(ARRAY) SHOULD BE IN THE FORMAT unnest_table_alias(individual_item_alias) + - For Example: `SELECT p.column_1, j.column_2 FROM parent_table AS p, join_table AS j JOIN UNNEST(p.array_column) AS unnested(array_item) ON j.id = array_item.id` +- DON'T USE JSON_QUERY and JSON_QUERY_ARRAY when "json_type":"". +- DON'T USE LAX_BOOL, LAX_FLOAT64, LAX_INT64, LAX_STRING when "json_type":"". +""" + + +sql_generation_system_prompt = """ +You are an ANSI SQL expert with exceptional logical thinking skills. Your main task is to generate SQL from given DB schema and user-input natrual language queries. 
+Before the main task, you need to learn about some specific structures in the given DB schema.
+
+## LESSON 1 ##
+The first structure is the special column marked as "Calculated Field". You need to interpret the purpose and calculation basis for these columns, then utilize them in the following text-to-SQL generation tasks.
+First, provide a brief explanation of what each field represents in the context of the schema, including how each field is computed using the relationships between models.
+Then, during the following tasks, if the user queries pertain to any calculated fields defined in the database schema, ensure you utilize those calculated fields appropriately in the output SQL queries.
+The goal is to accurately reflect the intent of the question in the SQL syntax, leveraging the pre-computed logic embedded within the calculated fields.
+
+### EXAMPLES ###
+The given schema is created by the SQL command:
+
+CREATE TABLE orders (
+    OrderId VARCHAR PRIMARY KEY,
+    CustomerId VARCHAR,
+    -- This column is a Calculated Field
+    -- column expression: avg(reviews.Score)
+    Rating DOUBLE,
+    -- This column is a Calculated Field
+    -- column expression: count(reviews.Id)
+    ReviewCount BIGINT,
+    -- This column is a Calculated Field
+    -- column expression: count(order_items.ItemNumber)
+    Size BIGINT,
+    -- This column is a Calculated Field
+    -- column expression: count(order_items.ItemNumber) > 1
+    Large BOOLEAN,
+    FOREIGN KEY (CustomerId) REFERENCES customers(Id)
+);
+
+Interpret the columns that are marked as Calculated Fields in the schema:
+Rating (DOUBLE) - Calculated as the average score (avg) of the Score field from the reviews table where the reviews are associated with the order. This field represents the overall customer satisfaction rating for the order based on review scores.
+ReviewCount (BIGINT) - Calculated by counting (count) the number of entries in the reviews table associated with this order. It measures the volume of customer feedback received for the order.
+Size (BIGINT) - Represents the total number of items in the order, calculated by counting the number of item entries (ItemNumber) in the order_items table linked to this order. This field is useful for understanding the scale or size of an order.
+Large (BOOLEAN) - A boolean value calculated to check whether the number of items in the order exceeds one (count(order_items.ItemNumber) > 1). It indicates whether the order is considered large in terms of item quantity.
+
+And if the user inputs queries like these:
+1. "How many large orders have been placed by customer with ID 'C1234'?"
+2. "What is the average customer rating for orders that were rated by more than 10 reviewers?"
+
+For the second query:
+First, try to interpret the user query: the user wants to know the average rating for orders which have attracted significant review activity, specifically those with more than 10 reviews.
+Then, according to the above interpretation of the given schema, the term 'Rating' is predefined as a Calculated Field of the 'orders' model, and the number of reviews is also predefined in the 'ReviewCount' Calculated Field.
+So utilize those Calculated Fields in the SQL generation process to give an answer like this:
+
+SQL Query: SELECT AVG(Rating) FROM orders WHERE ReviewCount > 10
+
+## LESSON 2 ##
+Second, you will learn how to effectively utilize the special "metric" structure in text-to-SQL generation tasks.
+Metrics in a data model simplify complex data analysis by structuring data through predefined dimensions and measures.
+This structuring closely mirrors the concept of OLAP (Online Analytical Processing) cubes but is implemented in a more flexible and SQL-friendly manner.
+
+A metric is typically constructed from the following components:
+1. Base Object
+The "base object" of a metric indicates the primary data source or table that provides the raw data.
+Metrics are constructed by selecting specific data points (dimensions and measures) from this base object, effectively creating a summarized or aggregated view of the data that can be queried like a normal table.
+The base object is an attribute of the metric, showing the origin of the metric, and is typically not used in the query.
+2. Dimensions
+Dimensions in a metric represent the various axes along which data can be segmented for analysis.
+These are fields that provide a categorical breakdown of data.
+Each dimension provides a unique perspective on the data, allowing users to "slice and dice" the data cube to view different facets of the information contained within the base dataset.
+Dimensions are used as table columns in the querying process. Querying a dimension means getting the statistic from a certain perspective.
+3. Measures
+Measures are numerical or quantitative statistics calculated from the data. Measures are key results or outputs derived from data aggregation functions like SUM, COUNT, or AVG.
+Measures are used as table columns in the querying process, and are the main querying items in the metric structure.
+The expression of a measure represents the definition of the statistic that users are interested in. Make sure to understand the meaning of measures from their expressions.
+4. Time Grain
+Time Grain specifies the granularity of time-based data aggregation, such as daily, monthly, or yearly, facilitating trend analysis over specified periods.
+
+If the given schema contains structures marked as 'metric', you should first interpret the metric schema based on the above definition.
+Then, during the following tasks, if the user queries pertain to any metrics defined in the database schema, ensure you utilize those metrics appropriately in the output SQL queries.
+The goal is to make complex data analysis more accessible and manageable by pre-aggregating data and structuring it with the metric structure, supporting direct querying for business insights.
+
+### EXAMPLES ###
+The given schema is created by the SQL command:
+
+/* This table is a metric */
+/* Metric Base Object: orders */
+CREATE TABLE Revenue (
+    -- This column is a dimension
+    PurchaseTimestamp TIMESTAMP,
+    -- This column is a dimension
+    CustomerId VARCHAR,
+    -- This column is a dimension
+    Status VARCHAR,
+    -- This column is a measure
+    -- expression: sum(order_items.Price)
+    PriceSum DOUBLE,
+    -- This column is a measure
+    -- expression: count(OrderId)
+    NumberOfOrders BIGINT
+);
+
+Interpret the metric with the understanding of the metric structure:
+1. Base Object: orders
+This is the primary data source for the metric.
+The orders table provides the underlying data from which dimensions and measures are derived.
+It is the foundation upon which the metric is built, though it itself is not directly used in queries against the Revenue table.
+It shows the reference between the 'Revenue' metric and the 'orders' model. For user queries that pertain to the 'Revenue' of 'orders', the metric should be utilized in the SQL generation process.
+2. Dimensions
+The metric contains the columns marked as 'dimension'. They can be interpreted as below:
+- PurchaseTimestamp (TIMESTAMP)
+  Acts as a temporal dimension, allowing analysis of revenue over time. This can be used to observe trends, seasonal variations, or performance over specific periods.
+- CustomerId (VARCHAR)
+  A key dimension for customer segmentation, it enables the analysis of revenue generated from individual customers or customer groups.
+- Status (VARCHAR)
+  Reflects the current state of an order (e.g., pending, completed, cancelled). This dimension is crucial for analyses that differentiate performance based on order status.
+3. Measures
+The metric contains the columns marked as 'measure'. They can be interpreted as below:
+- PriceSum (DOUBLE)
+  A financial measure calculated as sum(order_items.Price), representing the total revenue generated from orders. This measure is vital for tracking overall sales performance and is the primary output of interest in many financial and business analyses.
+- NumberOfOrders (BIGINT)
+  A count measure that provides the total number of orders. This is essential for operational metrics, such as assessing the volume of business activity and evaluating the efficiency of sales processes.
+
+Now, if the user inputs a query like this:
+Question: "What was the total revenue from each customer last month?"
+
+First, try to interpret the user query: the user asks for a breakdown of the total revenue generated by each customer in the previous calendar month.
+The user is specifically interested in understanding how much each customer contributed to the total sales during this period.
+To answer this question, it is suitable to use the following components from the metric:
+1. CustomerId (Dimension): This will be used to group the revenue data by each unique customer, allowing us to segment the total revenue by customer.
+2. PurchaseTimestamp (Dimension): This timestamp field will be used to filter the data to only include orders from the last month.
+3. PriceSum (Measure): Since PriceSum is a pre-aggregated measure of total revenue (sum of order_items.Price), it can be directly used to sum up the revenue without needing further aggregation in the SQL query.
+So utilize those metric components in the SQL generation process to give an answer like this:
+
+SQL Query:
+SELECT
+    CustomerId,
+    PriceSum AS TotalRevenue
+FROM
+    Revenue
+WHERE
+    PurchaseTimestamp >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '1 month') AND
+    PurchaseTimestamp < DATE_TRUNC('month', CURRENT_DATE)
+
+Learn about the usage of the schema structures and generate SQL based on them.
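Both lessons come down to the same point: calculated fields and metric measures surface as ordinary columns, so generated SQL can select and filter them directly instead of re-deriving the aggregation. Below is a minimal sketch of that idea against the LESSON 1 example, using Python's built-in sqlite3 module with a few made-up rows (the data values are illustrative assumptions, not taken from any real schema).

import sqlite3

# Toy copy of the LESSON 1 "orders" table; Rating and ReviewCount stand in for
# the pre-computed Calculated Fields, so the example query runs unchanged.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE orders (OrderId TEXT PRIMARY KEY, CustomerId TEXT, "
    "Rating REAL, ReviewCount INTEGER, Size INTEGER, Large BOOLEAN)"
)
conn.executemany(
    "INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)",
    [
        ("O1", "C1234", 4.6, 12, 3, True),  # kept: ReviewCount > 10
        ("O2", "C1234", 3.9, 11, 2, True),  # kept: ReviewCount > 10
        ("O3", "C9999", 2.5, 4, 1, False),  # filtered out by the WHERE clause
    ],
)
avg_rating = conn.execute(
    "SELECT AVG(Rating) FROM orders WHERE ReviewCount > 10"
).fetchone()[0]
print(round(avg_rating, 2))  # 4.25

The same reasoning applies to the Revenue metric: PriceSum is already aggregated per dimension combination, which is why the generated query selects it directly rather than wrapping it in another SUM.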
+ +""" + + +def construct_instructions(configuration: Configuration | None): + instructions = "" + if configuration: + if configuration.fiscal_year: + instructions += f"- For calendar year related computation, it should be started from {configuration.fiscal_year.start} to {configuration.fiscal_year.end}" + + return instructions diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 4b1701e2d..6ef6bad2e 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -90,6 +90,8 @@ class AskResultResponse(BaseModel): "failed", "stopped", ] + rephrased_question: Optional[str] = None + intent_reasoning: Optional[str] = None type: Optional[Literal["MISLEADING_QUERY", "GENERAL", "TEXT_TO_SQL"]] = None response: Optional[List[AskResult]] = None error: Optional[AskError] = None @@ -137,11 +139,13 @@ async def ask( }, } + query_id = ask_request.query_id + rephrased_question = None + intent_reasoning = None + try: # ask status can be understanding, searching, generating, finished, failed, stopped # we will need to handle business logic for each status - query_id = ask_request.query_id - if not self._is_stopped(query_id): self._ask_results[query_id] = AskResultResponse( status="understanding", @@ -155,17 +159,28 @@ async def ask( ) ).get("post_process", {}) intent = intent_classification_result.get("intent") + rephrased_question = intent_classification_result.get( + "rephrased_question" + ) + intent_reasoning = intent_classification_result.get("reasoning") + + user_query = ( + ask_request.query if not rephrased_question else rephrased_question + ) + if intent == "MISLEADING_QUERY": self._ask_results[query_id] = AskResultResponse( status="finished", type="MISLEADING_QUERY", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) results["metadata"]["type"] = "MISLEADING_QUERY" return results elif intent == "GENERAL": asyncio.create_task( self._pipelines["data_assistance"].run( - query=ask_request.query, + query=user_query, history=ask_request.history, db_schemas=intent_classification_result.get("db_schemas"), language=ask_request.configurations.language, @@ -174,27 +189,38 @@ async def ask( ) self._ask_results[query_id] = AskResultResponse( - status="finished", type="GENERAL" + status="finished", + type="GENERAL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) results["metadata"]["type"] = "GENERAL" return results + else: + self._ask_results[query_id] = AskResultResponse( + status="understanding", + type="TEXT_TO_SQL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, + ) if not self._is_stopped(query_id): self._ask_results[query_id] = AskResultResponse( status="searching", + type="TEXT_TO_SQL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) retrieval_result = await self._pipelines["retrieval"].run( - query=ask_request.query, + query=user_query, history=ask_request.history, id=ask_request.project_id, ) documents = retrieval_result.get("construct_retrieval_results", []) if not documents: - logger.exception( - f"ask pipeline - NO_RELEVANT_DATA: {ask_request.query}" - ) + logger.exception(f"ask pipeline - NO_RELEVANT_DATA: {user_query}") if not self._is_stopped(query_id): self._ask_results[query_id] = AskResultResponse( status="failed", @@ -203,6 +229,8 @@ async def ask( code="NO_RELEVANT_DATA", message="No relevant data", ), + rephrased_question=rephrased_question, + 
intent_reasoning=intent_reasoning, ) results["metadata"]["error_type"] = "NO_RELEVANT_DATA" results["metadata"]["type"] = "TEXT_TO_SQL" @@ -211,10 +239,13 @@ async def ask( if not self._is_stopped(query_id): self._ask_results[query_id] = AskResultResponse( status="generating", + type="TEXT_TO_SQL", + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) historical_question = await self._pipelines["historical_question"].run( - query=ask_request.query, + query=user_query, id=ask_request.project_id, ) @@ -240,7 +271,7 @@ async def ask( text_to_sql_generation_results = await self._pipelines[ "followup_sql_generation" ].run( - query=ask_request.query, + query=user_query, contexts=documents, history=ask_request.history, project_id=ask_request.project_id, @@ -250,7 +281,7 @@ async def ask( text_to_sql_generation_results = await self._pipelines[ "sql_generation" ].run( - query=ask_request.query, + query=user_query, contexts=documents, exclude=historical_question_result, project_id=ask_request.project_id, @@ -304,13 +335,13 @@ async def ask( status="finished", type="TEXT_TO_SQL", response=api_results, + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) results["ask_result"] = api_results results["metadata"]["type"] = "TEXT_TO_SQL" else: - logger.exception( - f"ask pipeline - NO_RELEVANT_SQL: {ask_request.query}" - ) + logger.exception(f"ask pipeline - NO_RELEVANT_SQL: {user_query}") if not self._is_stopped(query_id): self._ask_results[query_id] = AskResultResponse( status="failed", @@ -319,6 +350,8 @@ async def ask( code="NO_RELEVANT_SQL", message="No relevant SQL", ), + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) results["metadata"]["error_type"] = "NO_RELEVANT_SQL" results["metadata"]["type"] = "TEXT_TO_SQL" @@ -327,13 +360,15 @@ async def ask( except Exception as e: logger.exception(f"ask pipeline - OTHERS: {e}") - self._ask_results[ask_request.query_id] = AskResultResponse( + self._ask_results[query_id] = AskResultResponse( status="failed", type="TEXT_TO_SQL", error=AskError( code="OTHERS", message=str(e), ), + rephrased_question=rephrased_question, + intent_reasoning=intent_reasoning, ) results["metadata"]["error_type"] = "OTHERS" diff --git a/wren-ai-service/tests/pytest/test_usecases.py b/wren-ai-service/tests/pytest/test_usecases.py index 58e6ad8e1..9436d8276 100644 --- a/wren-ai-service/tests/pytest/test_usecases.py +++ b/wren-ai-service/tests/pytest/test_usecases.py @@ -10,6 +10,7 @@ import aiohttp import orjson import requests +import yaml from demo.utils import ( _get_connection_info, @@ -34,16 +35,16 @@ def test_load_mdl_and_questions(usecases: list[str]): with open(f"tests/data/usecases/{usecase}/mdl.json", "r") as f: mdl_str = orjson.dumps(json.load(f)).decode("utf-8") - with open(f"tests/data/usecases/{usecase}/questions.json", "r") as f: - questions = json.load(f) + with open(f"tests/data/usecases/{usecase}/questions.yaml", "r") as f: + questions = yaml.safe_load(f) mdls_and_questions[usecase] = { "mdl_str": mdl_str, - "questions": questions, + "questions": [question["question"] for question in questions], } except FileNotFoundError: raise Exception( - f"tests/data/usecases/{usecase}/mdl.json or tests/data/usecases/{usecase}/questions.json not found" + f"tests/data/usecases/{usecase}/mdl.json or tests/data/usecases/{usecase}/questions.yaml not found" ) return mdls_and_questions diff --git a/wren-ai-service/tools/config/config.example.yaml b/wren-ai-service/tools/config/config.example.yaml index 
c439ccc59..2ba4c8ecc 100644 --- a/wren-ai-service/tools/config/config.example.yaml +++ b/wren-ai-service/tools/config/config.example.yaml @@ -8,6 +8,8 @@ models: kwargs: temperature: 0 n: 1 + # for better consistency of llm response, refer: https://platform.openai.com/docs/api-reference/chat/create#chat-create-seed + seed: 0 max_tokens: 4096 response_format: type: json_object @@ -17,6 +19,8 @@ models: kwargs: temperature: 0 n: 1 + # for better consistency of llm response, refer: https://platform.openai.com/docs/api-reference/chat/create#chat-create-seed + seed: 0 max_tokens: 4096 response_format: type: json_object diff --git a/wren-ai-service/tools/config/config.full.yaml b/wren-ai-service/tools/config/config.full.yaml index 526c32464..742cafaaf 100644 --- a/wren-ai-service/tools/config/config.full.yaml +++ b/wren-ai-service/tools/config/config.full.yaml @@ -8,6 +8,8 @@ models: kwargs: temperature: 0 n: 1 + # for better consistency of llm response + seed: 0 max_tokens: 4096 response_format: type: json_object @@ -17,6 +19,8 @@ models: kwargs: temperature: 0 n: 1 + # for better consistency of llm response + seed: 0 max_tokens: 4096 response_format: type: json_object From bd66a8f4cd65579c16efab37586a582dc74eabb6 Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Thu, 26 Dec 2024 03:56:08 +0000 Subject: [PATCH 17/25] Upgrade AI Service version to 0.13.7 --- wren-ai-service/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index a91a58607..4dd1c75ff 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.13.6" +version = "0.13.7" description = "" authors = ["Jimmy Yeh ", "Pao Sheng Wang ", "Aster Sun "] license = "AGPL-3.0" From d82c737153b95501f0be38bfbbe9e97586fadc55 Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Thu, 26 Dec 2024 14:17:49 +0800 Subject: [PATCH 18/25] chore(wren-ai-service): fix historical question query input (#1064) --- wren-ai-service/src/web/v1/services/ask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/src/web/v1/services/ask.py b/wren-ai-service/src/web/v1/services/ask.py index 6ef6bad2e..c8e22018b 100644 --- a/wren-ai-service/src/web/v1/services/ask.py +++ b/wren-ai-service/src/web/v1/services/ask.py @@ -245,7 +245,7 @@ async def ask( ) historical_question = await self._pipelines["historical_question"].run( - query=user_query, + query=ask_request.query, id=ask_request.project_id, ) From e7c47bfaddd019693427c861fe70fca6c833f43a Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Thu, 26 Dec 2024 06:20:43 +0000 Subject: [PATCH 19/25] Upgrade AI Service version to 0.13.8 --- wren-ai-service/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index 4dd1c75ff..411762941 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.13.7" +version = "0.13.8" description = "" authors = ["Jimmy Yeh ", "Pao Sheng Wang ", "Aster Sun "] license = "AGPL-3.0" From ce4529c7d453d2fdd276d8b0d38b71968fd7f326 Mon Sep 17 00:00:00 2001 From: Freda Lai <42527625+fredalai@users.noreply.github.com> Date: Thu, 26 Dec 2024 14:23:31 +0800 Subject: [PATCH 20/25] feat(wren-ui): remove unnecessary sql column of thread table (#1063) --- .../20241226135712_remove_thread_sql.js | 26 
+++++++++++++++++++ .../src/apollo/client/graphql/__types__.ts | 4 --- .../apollo/client/graphql/home.generated.ts | 9 +++---- wren-ui/src/apollo/client/graphql/home.ts | 3 --- .../server/repositories/threadRepository.ts | 1 - wren-ui/src/apollo/server/schema.ts | 8 ------ .../apollo/server/services/askingService.ts | 3 --- 7 files changed, 29 insertions(+), 25 deletions(-) create mode 100644 wren-ui/migrations/20241226135712_remove_thread_sql.js diff --git a/wren-ui/migrations/20241226135712_remove_thread_sql.js b/wren-ui/migrations/20241226135712_remove_thread_sql.js new file mode 100644 index 000000000..cf453b326 --- /dev/null +++ b/wren-ui/migrations/20241226135712_remove_thread_sql.js @@ -0,0 +1,26 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = async function (knex) { + // drop foreign key constraint before altering column type to prevent data loss + await knex.schema.alterTable('thread_response', (table) => { + table.dropForeign('thread_id'); + }); + await knex.schema.alterTable('thread', (table) => { + table.dropColumn('sql'); + }); + await knex.schema.alterTable('thread_response', (table) => { + table.foreign('thread_id').references('thread.id').onDelete('CASCADE'); + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = async function (knex) { + await knex.schema.alterTable('thread', (table) => { + table.text('sql').nullable(); + }); +}; diff --git a/wren-ui/src/apollo/client/graphql/__types__.ts b/wren-ui/src/apollo/client/graphql/__types__.ts index 3e41e03d4..412aa82af 100644 --- a/wren-ui/src/apollo/client/graphql/__types__.ts +++ b/wren-ui/src/apollo/client/graphql/__types__.ts @@ -257,8 +257,6 @@ export type DetailedThread = { __typename?: 'DetailedThread'; id: Scalars['Int']; responses: Array; - /** @deprecated Doesn't seem to be reasonable to put a sql in a thread */ - sql: Scalars['String']; }; export type Diagram = { @@ -997,8 +995,6 @@ export type Task = { export type Thread = { __typename?: 'Thread'; id: Scalars['Int']; - /** @deprecated Doesn't seem to be reasonable to put a sql in a thread */ - sql: Scalars['String']; summary: Scalars['String']; }; diff --git a/wren-ui/src/apollo/client/graphql/home.generated.ts b/wren-ui/src/apollo/client/graphql/home.generated.ts index ec71b81eb..fec180e44 100644 --- a/wren-ui/src/apollo/client/graphql/home.generated.ts +++ b/wren-ui/src/apollo/client/graphql/home.generated.ts @@ -37,7 +37,7 @@ export type ThreadQueryVariables = Types.Exact<{ }>; -export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, sql: string, responses: Array<{ __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: 
string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }> } }; +export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, responses: Array<{ __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }> } }; export type ThreadResponseQueryVariables = Types.Exact<{ responseId: Types.Scalars['Int']; @@ -65,7 +65,7 @@ export type CreateThreadMutationVariables = Types.Exact<{ }>; -export type CreateThreadMutation = { __typename?: 'Mutation', createThread: { __typename?: 'Thread', id: number, sql: string } }; +export type CreateThreadMutation = { __typename?: 'Mutation', createThread: { __typename?: 'Thread', id: number } }; export type CreateThreadResponseMutationVariables = Types.Exact<{ threadId: Types.Scalars['Int']; @@ -81,7 +81,7 @@ export type UpdateThreadMutationVariables = Types.Exact<{ }>; -export type UpdateThreadMutation = { __typename?: 'Mutation', updateThread: { __typename?: 'Thread', id: number, sql: string, summary: string } }; +export type UpdateThreadMutation = { __typename?: 'Mutation', updateThread: { __typename?: 'Thread', id: number, summary: string } }; export type DeleteThreadMutationVariables = Types.Exact<{ where: Types.ThreadUniqueWhereInput; @@ -386,7 +386,6 @@ export const ThreadDocument = gql` query Thread($threadId: Int!) { thread(threadId: $threadId) { id - sql responses { ...CommonResponse } @@ -524,7 +523,6 @@ export const CreateThreadDocument = gql` mutation CreateThread($data: CreateThreadInput!) { createThread(data: $data) { id - sql } } `; @@ -592,7 +590,6 @@ export const UpdateThreadDocument = gql` mutation UpdateThread($where: ThreadUniqueWhereInput!, $data: UpdateThreadInput!) 
{ updateThread(where: $where, data: $data) { id - sql summary } } diff --git a/wren-ui/src/apollo/client/graphql/home.ts b/wren-ui/src/apollo/client/graphql/home.ts index 53236c489..f6b12c1ad 100644 --- a/wren-ui/src/apollo/client/graphql/home.ts +++ b/wren-ui/src/apollo/client/graphql/home.ts @@ -144,7 +144,6 @@ export const THREAD = gql` query Thread($threadId: Int!) { thread(threadId: $threadId) { id - sql responses { ...CommonResponse } @@ -180,7 +179,6 @@ export const CREATE_THREAD = gql` mutation CreateThread($data: CreateThreadInput!) { createThread(data: $data) { id - sql } } `; @@ -204,7 +202,6 @@ export const UPDATE_THREAD = gql` ) { updateThread(where: $where, data: $data) { id - sql summary } } diff --git a/wren-ui/src/apollo/server/repositories/threadRepository.ts b/wren-ui/src/apollo/server/repositories/threadRepository.ts index 2a243be54..647c2a4a3 100644 --- a/wren-ui/src/apollo/server/repositories/threadRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadRepository.ts @@ -17,7 +17,6 @@ export interface ThreadRecommendationQuestionResult { export interface Thread { id: number; // ID projectId: number; // Reference to project.id - sql: string; // SQL summary: string; // Thread summary // recommend question diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts index 167633d34..970680712 100644 --- a/wren-ui/src/apollo/server/schema.ts +++ b/wren-ui/src/apollo/server/schema.ts @@ -700,20 +700,12 @@ export const typeDefs = gql` # Thread only consists of basic information of a thread type Thread { id: Int! - sql: String! - @deprecated( - reason: "Doesn't seem to be reasonable to put a sql in a thread" - ) summary: String! } # Detailed thread consists of thread and thread responses type DetailedThread { id: Int! - sql: String! - @deprecated( - reason: "Doesn't seem to be reasonable to put a sql in a thread" - ) responses: [ThreadResponse!]! 
} diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 6ace876ee..f6982649c 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -575,7 +575,6 @@ export class AskingService implements IAskingService { const { id } = await this.projectService.getCurrentProject(); const thread = await this.threadRepository.createOne({ projectId: id, - sql: input.sql, summary: input.question, }); @@ -969,8 +968,6 @@ export class AskingService implements IAskingService { const { id } = await this.projectService.getCurrentProject(); const thread = await this.threadRepository.createOne({ projectId: id, - // todo: remove sql from thread - sql: view.statement, summary: input.question, }); From 0e4b53ee8807c172b4ff0b2bf1173b62774b52f3 Mon Sep 17 00:00:00 2001 From: Shimin Date: Thu, 26 Dec 2024 14:43:23 +0800 Subject: [PATCH 21/25] fix(wren-ui): remove custom scale & add adjustment flag for adjust chart scenario (#1062) --- .../src/apollo/client/graphql/__types__.ts | 1 + .../apollo/client/graphql/home.generated.ts | 19 +++--- wren-ui/src/apollo/client/graphql/home.ts | 1 + .../src/apollo/server/backgrounds/chart.ts | 1 + .../repositories/threadResponseRepository.ts | 1 + wren-ui/src/apollo/server/schema.ts | 1 + .../apollo/server/services/askingService.ts | 1 + wren-ui/src/components/chart/handler.ts | 60 ------------------- .../pages/home/promptThread/ChartAnswer.tsx | 22 ++++--- wren-ui/src/styles/components/chart.less | 6 +- 10 files changed, 35 insertions(+), 78 deletions(-) diff --git a/wren-ui/src/apollo/client/graphql/__types__.ts b/wren-ui/src/apollo/client/graphql/__types__.ts index 412aa82af..c411e7097 100644 --- a/wren-ui/src/apollo/client/graphql/__types__.ts +++ b/wren-ui/src/apollo/client/graphql/__types__.ts @@ -1040,6 +1040,7 @@ export type ThreadResponseBreakdownDetail = { export type ThreadResponseChartDetail = { __typename?: 'ThreadResponseChartDetail'; + adjustment?: Maybe; chartSchema?: Maybe; description?: Maybe; error?: Maybe; diff --git a/wren-ui/src/apollo/client/graphql/home.generated.ts b/wren-ui/src/apollo/client/graphql/home.generated.ts index fec180e44..1380c96e3 100644 --- a/wren-ui/src/apollo/client/graphql/home.generated.ts +++ b/wren-ui/src/apollo/client/graphql/home.generated.ts @@ -9,9 +9,9 @@ export type CommonBreakdownDetailFragment = { __typename?: 'ThreadResponseBreakd export type CommonAnswerDetailFragment = { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null }; -export type CommonChartDetailFragment = { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null }; +export type CommonChartDetailFragment = { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null }; 
-export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }; +export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }; export type CommonRecommendedQuestionsTaskFragment = { __typename?: 'RecommendedQuestionsTask', status: Types.RecommendedQuestionsTaskStatus, questions: Array<{ __typename?: 'ResultQuestion', question: string, category: string, sql: string }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null }; @@ -37,14 +37,14 @@ export type ThreadQueryVariables = Types.Exact<{ }>; -export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, responses: Array<{ __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: 
string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }> } }; +export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, responses: Array<{ __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }> } }; export type ThreadResponseQueryVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, 
stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type CreateAskingTaskMutationVariables = Types.Exact<{ data: Types.AskingTaskInput; @@ -73,7 +73,7 @@ export type CreateThreadResponseMutationVariables = Types.Exact<{ }>; -export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', 
id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type UpdateThreadMutationVariables = Types.Exact<{ where: Types.ThreadUniqueWhereInput; @@ -154,21 +154,21 @@ export type GenerateThreadResponseBreakdownMutationVariables = Types.Exact<{ }>; -export type GenerateThreadResponseBreakdownMutation = { __typename?: 'Mutation', generateThreadResponseBreakdown: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type GenerateThreadResponseBreakdownMutation = { __typename?: 'Mutation', generateThreadResponseBreakdown: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { 
__typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type GenerateThreadResponseAnswerMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type GenerateThreadResponseAnswerMutation = { __typename?: 'Mutation', generateThreadResponseAnswer: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type GenerateThreadResponseAnswerMutation = { __typename?: 'Mutation', generateThreadResponseAnswer: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | 
null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type GenerateThreadResponseChartMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type GenerateThreadResponseChartMutation = { __typename?: 'Mutation', generateThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type GenerateThreadResponseChartMutation = { __typename?: 'Mutation', generateThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type AdjustThreadResponseChartMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; @@ -176,7 +176,7 @@ export type AdjustThreadResponseChartMutationVariables = Types.Exact<{ }>; -export type AdjustThreadResponseChartMutation = { __typename?: 'Mutation', adjustThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { 
__typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type AdjustThreadResponseChartMutation = { __typename?: 'Mutation', adjustThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql: string, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export const CommonErrorFragmentDoc = gql` fragment CommonError on Error { @@ -221,6 +221,7 @@ export const CommonChartDetailFragmentDoc = gql` error { ...CommonError } + adjustment } ${CommonErrorFragmentDoc}`; export const CommonResponseFragmentDoc = gql` diff --git a/wren-ui/src/apollo/client/graphql/home.ts b/wren-ui/src/apollo/client/graphql/home.ts index f6b12c1ad..994042dfd 100644 --- a/wren-ui/src/apollo/client/graphql/home.ts +++ b/wren-ui/src/apollo/client/graphql/home.ts @@ -50,6 +50,7 @@ const COMMON_CHART_DETAIL = gql` error { ...CommonError } + adjustment } `; diff --git a/wren-ui/src/apollo/server/backgrounds/chart.ts b/wren-ui/src/apollo/server/backgrounds/chart.ts index c5451ec26..fd435326c 100644 --- a/wren-ui/src/apollo/server/backgrounds/chart.ts +++ b/wren-ui/src/apollo/server/backgrounds/chart.ts @@ -205,6 +205,7 @@ export class ChartAdjustmentBackgroundTracker { error: result?.error, description: result?.response?.reasoning, chartSchema: result?.response?.chartSchema, + adjustment: true, }; logger.debug( `Job 
${threadResponse.id} chart status changed, updating`, diff --git a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts index f3d3a81e6..2bb4ee5e4 100644 --- a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts @@ -35,6 +35,7 @@ export interface ThreadResponseChartDetail { error?: object; description?: string; chartSchema?: Record; + adjustment?: boolean; } export interface ThreadResponse { diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts index 970680712..c565a5a85 100644 --- a/wren-ui/src/apollo/server/schema.ts +++ b/wren-ui/src/apollo/server/schema.ts @@ -684,6 +684,7 @@ export const typeDefs = gql` error: Error description: String chartSchema: JSON + adjustment: Boolean } type ThreadResponse { diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index f6982649c..5c42a35b4 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -779,6 +779,7 @@ export class AskingService implements IAskingService { chartDetail: { queryId: response.queryId, status: ChartStatus.FETCHING, + adjustment: true, }, }, ); diff --git a/wren-ui/src/components/chart/handler.ts b/wren-ui/src/components/chart/handler.ts index a0d90a0c6..572a0080a 100644 --- a/wren-ui/src/components/chart/handler.ts +++ b/wren-ui/src/components/chart/handler.ts @@ -234,7 +234,6 @@ export default class ChartSpecHandler { private addEncoding(encoding: EncodingSpec) { this.encoding = encoding; - const { x, y } = this.getAxisDomain(); // fill color by x field if AI not provide color(category) field if (isNil(this.encoding.color)) { @@ -253,26 +252,6 @@ export default class ChartSpecHandler { // handle scale on bar chart if (this.mark.type === MarkType.BAR) { - if (y) { - this.encoding.y = { - ...this.encoding.y, - scale: { - domain: y, - nice: false, - }, - }; - } - - if (x) { - this.encoding.x = { - ...this.encoding.x, - scale: { - domain: x, - nice: false, - }, - }; - } - if ('stack' in this.encoding.y) { this.encoding.y.stack = this.options.stack; } @@ -291,45 +270,6 @@ export default class ChartSpecHandler { this.addHoverHighlight(this.encoding); } - private getAxisDomain() { - const xField = this.encoding.x as PositionFieldDef; - const yField = this.encoding.y as PositionFieldDef; - const calculateMaxDomain = (field: PositionFieldDef) => { - if (field?.type !== 'quantitative') return null; - const fieldValue = field.field; - const values = (this.data as any).values.map((d) => d[fieldValue]); - - const maxValue = Math.max(...values); - - // Get the magnitude (e.g., 1, 10, 100, 1000) - const magnitude = Math.pow(10, Math.floor(Math.log10(maxValue))); - - // Get number between 1-10 - const normalizedValue = maxValue / magnitude; - let niceNumber; - - if (normalizedValue <= 1.2) niceNumber = 1.2; - else if (normalizedValue <= 1.5) niceNumber = 1.5; - else if (normalizedValue <= 2) niceNumber = 2; - else if (normalizedValue <= 2.5) niceNumber = 2.5; - else if (normalizedValue <= 3) niceNumber = 3; - else if (normalizedValue <= 4) niceNumber = 4; - else if (normalizedValue <= 5) niceNumber = 5; - else if (normalizedValue <= 7.5) niceNumber = 7.5; - else if (normalizedValue <= 8) niceNumber = 8; - else niceNumber = 10; - - const domainMax = niceNumber * magnitude; - return [0, domainMax]; - }; - 
const xDomain = calculateMaxDomain(xField); - const yDomain = calculateMaxDomain(yField); - return { - x: xDomain, - y: yDomain, - }; - } - private addHoverHighlight(encoding: EncodingSpec) { const category = ( encoding.color?.condition ? encoding.color.condition : encoding.color diff --git a/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx b/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx index 4a1d36a6a..423ebffa9 100644 --- a/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx +++ b/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx @@ -96,7 +96,7 @@ export default function ChartAnswer(props: Props) { const [form] = Form.useForm(); const { chartDetail } = threadResponse; - const { error, status } = chartDetail || {}; + const { error, status, adjustment } = chartDetail || {}; const [previewData, previewDataResult] = usePreviewDataMutation({ onError: (error) => console.error(error), @@ -203,6 +203,14 @@ export default function ChartAnswer(props: Props) { onResetState(); }; + const regenerateBtn = ( +
+ +
+ ); + if (error) { return (
@@ -212,15 +220,13 @@ export default function ChartAnswer(props: Props) { type="error" showIcon /> -
- -
+ {regenerateBtn}
); } + const chartRegenerateBtn = adjustment ? regenerateBtn : null; + return (
{chartDetail?.description} - {chartSpec && ( + {chartSpec ? ( + ) : ( + chartRegenerateBtn )}
diff --git a/wren-ui/src/styles/components/chart.less b/wren-ui/src/styles/components/chart.less index 1fd6c0ccc..c5317572f 100644 --- a/wren-ui/src/styles/components/chart.less +++ b/wren-ui/src/styles/components/chart.less @@ -17,7 +17,8 @@ cursor: pointer; width: 28px; height: 28px; - opacity: 0.2; + opacity: 0.4; + color: @gray-8; transition: all 0.4s ease-in; &:hover { @@ -33,7 +34,7 @@ .vega-embed { &:hover { summary { - opacity: 0.2 !important; + opacity: 0.4 !important; } } @@ -43,6 +44,7 @@ box-shadow: none; transition: all 0.4s ease-in; color: @gray-8; + opacity: 0.4 !important; &:hover { opacity: 1 !important; From c7faa089abae1795e79b356363a05595f63d231c Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Thu, 26 Dec 2024 07:11:05 +0000 Subject: [PATCH 22/25] update wren-ui version to 0.18.10 --- wren-ui/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ui/package.json b/wren-ui/package.json index 6ba4227a6..9458f0779 100644 --- a/wren-ui/package.json +++ b/wren-ui/package.json @@ -1,6 +1,6 @@ { "name": "wren-ui", - "version": "0.18.9", + "version": "0.18.10", "private": true, "scripts": { "dev": "next dev", From 0e25b4c1c11d1d4f77809e960166b951534f5abe Mon Sep 17 00:00:00 2001 From: Shimin Date: Thu, 26 Dec 2024 15:40:50 +0800 Subject: [PATCH 23/25] fix(wren-ui): chart handler lint (#1066) --- wren-ui/src/components/chart/handler.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/wren-ui/src/components/chart/handler.ts b/wren-ui/src/components/chart/handler.ts index 572a0080a..1b2d71d0d 100644 --- a/wren-ui/src/components/chart/handler.ts +++ b/wren-ui/src/components/chart/handler.ts @@ -1,7 +1,6 @@ import { ChartType } from '@/apollo/client/graphql/__types__'; import { isNil, cloneDeep, uniq, sortBy, omit, isNumber } from 'lodash'; import { Config, TopLevelSpec } from 'vega-lite'; -import { PositionFieldDef } from 'vega-lite/build/src/channeldef'; enum MarkType { ARC = 'arc', From 67603cadd4fe705adc56ecc9454bfc1e53e91f1f Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Thu, 26 Dec 2024 16:08:27 +0800 Subject: [PATCH 24/25] chore(wren-ai-service): refine intent classification for time related questions (#1067) --- .../src/pipelines/generation/intent_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wren-ai-service/src/pipelines/generation/intent_classification.py b/wren-ai-service/src/pipelines/generation/intent_classification.py index 9b0ae93f6..62022acd0 100644 --- a/wren-ai-service/src/pipelines/generation/intent_classification.py +++ b/wren-ai-service/src/pipelines/generation/intent_classification.py @@ -31,6 +31,7 @@ - Steps to rephrase the user's question: - First, try to recognize adjectives in the user's question that are important to the user's intent. - Second, change the adjectives to more specific and clear ones that can be matched to columns in the database schema. + - Third, if the user's question is related to time/date, add time/date format(such as YYYY-MM-DD) in the rephrased_question output. - MUST use the rephrased user's question to make the intent classification. - MUST put the rephrased user's question in the rephrased_question output. - REASONING MUST be within 20 words. 
From 81e4f2a6fabf8b8ba2440897ace01d3121f2afa0 Mon Sep 17 00:00:00 2001 From: "wren-ai[bot]" Date: Thu, 26 Dec 2024 08:13:37 +0000 Subject: [PATCH 25/25] Upgrade AI Service version to 0.13.9 --- wren-ai-service/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index 411762941..c52a21a86 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.13.8" +version = "0.13.9" description = "" authors = ["Jimmy Yeh ", "Pao Sheng Wang ", "Aster Sun "] license = "AGPL-3.0"
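The intent-classification tweak above asks the model to attach an explicit date format (such as YYYY-MM-DD) when rephrasing time-related questions. The rough Python helper below, which is hypothetical and not part of the service code, illustrates why that helps: a vague phrase like "last month" resolves to concrete bounds that a WHERE clause can use directly, mirroring the DATE_TRUNC example in the prompt.

from datetime import date

def last_month_bounds(today: date) -> tuple[str, str]:
    """Resolve 'last month' to an explicit [start, end) range in YYYY-MM-DD."""
    upper = today.replace(day=1)  # exclusive upper bound: first day of this month
    if upper.month == 1:
        lower = upper.replace(year=upper.year - 1, month=12)
    else:
        lower = upper.replace(month=upper.month - 1)
    return lower.isoformat(), upper.isoformat()

print(last_month_bounds(date(2024, 12, 26)))  # ('2024-11-01', '2024-12-01')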