Feature: chat with dir #39

Merged — 3 commits, Mar 18, 2024
81 changes: 81 additions & 0 deletions cmd/apps/chat.go
@@ -0,0 +1,81 @@
/*
* Copyright 2023 friday
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package apps

import (
	"context"
	"encoding/json"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/spf13/cobra"

	"github.com/basenana/friday/pkg/friday"
	"github.com/basenana/friday/pkg/models"
)

var ChatCmd = &cobra.Command{
	Use:   "chat",
	Short: "chat with llm based on knowledge",
	Run: func(cmd *cobra.Command, args []string) {
		if len(args) <= 1 {
			panic("dirId and history are needed.")
		}
		dirIdStr := args[0]
		dirId, err := strconv.Atoi(dirIdStr)
		if err != nil {
			panic(err)
		}

		historyStr := strings.Join(args[1:], " ")

		history := make([]map[string]string, 0)
		err = json.Unmarshal([]byte(historyStr), &history)
		if err != nil {
			panic(err)
		}

		if err := chat(int64(dirId), history); err != nil {
			panic(err)
		}
	},
}

func chat(dirId int64, history []map[string]string) error {
	f := friday.Fri.WithContext(context.TODO()).History(history).SearchIn(&models.DocQuery{
		ParentId: dirId,
	})
	// Surface errors from the statement chain before starting the chat.
	if f.Error != nil {
		return f.Error
	}

	resp := make(chan map[string]string)
	res := &friday.ChatState{
		Response: resp,
	}
	// Chat streams dialogue lines into res.Response; close the channel once it returns.
	go func() {
		f = f.Chat(res)
		close(resp)
	}()

	fmt.Println("Dialogues: ")
	for line := range res.Response {
		fmt.Printf("%v: %v\n", time.Now().Format("15:04:05"), line)
	}
	return f.Error
}
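The new chat subcommand expects a directory id followed by a JSON-encoded chat history; everything after the first argument is joined and unmarshaled into []map[string]string. A minimal sketch of preparing such an invocation is below; the friday binary name, the directory id 10, and the "role"/"content" keys are illustrative assumptions — the command only requires that the argument parses as a JSON array of string maps.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical history payload; the key names are assumptions, the command
	// only requires a JSON array of string-to-string maps.
	history := []map[string]string{
		{"role": "user", "content": "What does this directory talk about?"},
	}
	raw, err := json.Marshal(history)
	if err != nil {
		panic(err)
	}
	// Prints a command line of the assumed form: friday chat 10 '[{...}]'
	fmt.Printf("friday chat 10 '%s'\n", raw)
}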
10 changes: 6 additions & 4 deletions cmd/apps/ingest.go
@@ -41,10 +41,12 @@ var IngestCmd = &cobra.Command{
 }
 
 func ingest(ps string) error {
-	usage, err := friday.Fri.IngestFromOriginFile(context.TODO(), ps)
-	if err != nil {
-		return err
+	f := friday.Fri.WithContext(context.TODO()).OriginFile(&ps)
+	res := &friday.IngestState{}
+	f = f.Ingest(res)
+	if f.Error != nil {
+		return f.Error
 	}
-	fmt.Printf("Usage: %v", usage)
+	fmt.Printf("Usage: %v", res.Tokens)
 	return nil
 }
13 changes: 8 additions & 5 deletions cmd/apps/question.go
@@ -39,12 +39,15 @@ var QuestionCmd = &cobra.Command{
 }
 
 func run(question string) error {
-	a, usage, err := friday.Fri.Question(context.TODO(), 0, question)
-	if err != nil {
-		return err
+	f := friday.Fri.WithContext(context.TODO()).Question(question)
+	res := &friday.ChatState{}
+	f = f.Complete(res)
+	if f.Error != nil {
+		return f.Error
 	}
+
 	fmt.Println("Answer: ")
-	fmt.Println(a)
-	fmt.Printf("Usage: %v", usage)
+	fmt.Println(res.Answer)
+	fmt.Printf("Usage: %v", res.Tokens)
 	return nil
 }
4 changes: 2 additions & 2 deletions cmd/apps/wechat.go
@@ -32,13 +32,13 @@ var WeChatCmd = &cobra.Command{
 	Run: func(cmd *cobra.Command, args []string) {
 		ps := fmt.Sprint(strings.Join(args, " "))
 
-		if err := chat(ps); err != nil {
+		if err := wechat(ps); err != nil {
 			panic(err)
 		}
 	},
 }
 
-func chat(ps string) error {
+func wechat(ps string) error {
 	a, usage, err := friday.Fri.ChatConclusionFromFile(context.TODO(), ps)
 	if err != nil {
 		return err
1 change: 1 addition & 0 deletions cmd/main.go
@@ -47,6 +47,7 @@ func init() {
 	}
 
 	RootCmd.AddCommand(apps.QuestionCmd)
+	RootCmd.AddCommand(apps.ChatCmd)
 	RootCmd.AddCommand(apps.IngestCmd)
 	RootCmd.AddCommand(apps.WeChatCmd)
 	RootCmd.AddCommand(apps.SummaryCmd)
9 changes: 6 additions & 3 deletions config/config.go
@@ -25,12 +25,16 @@ type Config struct {
 	Logger logger.Logger
 
 	// llm limit token
-	LimitToken int `json:"limit_token,omitempty"`
+	LimitToken int `json:"limit_token,omitempty"` // used by summary; splits input into multiple sub-docs summarized by the llm separately.
 
 	// openai key
 	OpenAIBaseUrl string `json:"open_ai_base_url,omitempty"` // if openai is used for embedding or llm, it is needed, default is "https://api.openai.com"
 	OpenAIKey     string `json:"open_ai_key,omitempty"`      // if openai is used for embedding or llm, it is needed
 
+	// gemini key
+	GeminiBaseUri string `json:"gemini_base_uri,omitempty"` // if gemini is used for embedding or llm, it is needed, default is "https://generativelanguage.googleapis.com"
+	GeminiKey     string `json:"gemini_key,omitempty"`      // if gemini is used for embedding or llm, it is needed
+
 	// embedding config
 	EmbeddingConfig EmbeddingConfig `json:"embedding_config,omitempty"`
 
@@ -60,7 +64,7 @@ type OpenAIConfig struct {
 	QueryPerMinute   int      `json:"query_per_minute,omitempty"` // qpm, default is 3
 	Burst            int      `json:"burst,omitempty"`            // burst, default is 5
 	Model            *string  `json:"model,omitempty"`            // model of openai, default for llm is "gpt-3.5-turbo"; default for embedding is "text-embedding-ada-002"
-	MaxReturnToken   *int     `json:"max_return_token,omitempty"`
+	MaxReturnToken   *int     `json:"max_return_token,omitempty"` // MaxReturnToken + VectorStoreConfig.TopK * TextSpliterConfig.SpliterChunkSize <= token limit of the llm model
 	FrequencyPenalty *uint    `json:"frequency_penalty,omitempty"`
 	PresencePenalty  *uint    `json:"presence_penalty,omitempty"`
 	Temperature      *float32 `json:"temperature,omitempty"`
@@ -70,7 +74,6 @@ type GeminiConfig struct {
 	QueryPerMinute int     `json:"query_per_minute,omitempty"` // qpm, default is 3
 	Burst          int     `json:"burst,omitempty"`            // burst, default is 5
 	Model          *string `json:"model,omitempty"`            // model of gemini, default for llm is "gemini-pro"; default for embedding is "embedding-001"
-	Key            string  `json:"key"`                        // key of Gemini api
 }
 
 type EmbeddingConfig struct {
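The new comment on MaxReturnToken encodes a prompt-budget rule: the tokens reserved for the completion plus the retrieved context (TopK chunks of at most SpliterChunkSize tokens each) must fit within the model's token limit. A small sketch of the arithmetic, using illustrative numbers rather than this project's defaults:

package main

import "fmt"

func main() {
	// Assumed figures for illustration only; not defaults of this repository.
	modelLimit := 4096 // token limit of the llm model
	topK := 6          // VectorStoreConfig.TopK
	chunkSize := 512   // TextSpliterConfig.SpliterChunkSize

	maxReturnToken := modelLimit - topK*chunkSize
	fmt.Println("MaxReturnToken must not exceed", maxReturnToken) // 1024
}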
5 changes: 3 additions & 2 deletions flow/operator/ingest.go
@@ -41,6 +41,7 @@ func (i *ingestOperator) Do(ctx context.Context, param *flow.Parameter) error {
 		Name:    source,
 		Content: knowledge,
 	}
-	_, err := friday.Fri.IngestFromFile(context.TODO(), doc)
-	return err
+	res := friday.IngestState{}
+	f := friday.Fri.WithContext(context.TODO()).File(&doc).Ingest(&res)
+	return f.Error
 }
4 changes: 2 additions & 2 deletions pkg/build/withvector/init.go
@@ -57,7 +57,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
 		llmClient = glm_6b.NewGLM(log, conf.LLMConfig.GLM6B.Url)
 	}
 	if conf.LLMConfig.LLMType == config.LLMGemini {
-		llmClient = gemini.NewGemini(log, conf.LLMConfig.Gemini)
+		llmClient = gemini.NewGemini(log, conf.GeminiBaseUri, conf.GeminiKey, conf.LLMConfig.Gemini)
 	}
 
 	if conf.LLMConfig.Prompts != nil {
@@ -80,7 +80,7 @@ func NewFridayWithVector(conf *config.Config, vectorClient vectorstore.VectorSto
 		conf.VectorStoreConfig.EmbeddingDim = len(testEmbed)
 	}
 	if conf.EmbeddingConfig.EmbeddingType == config.EmbeddingGemini {
-		embeddingModel = geminiembedding.NewGeminiEmbedding(log, conf.EmbeddingConfig.Gemini)
+		embeddingModel = geminiembedding.NewGeminiEmbedding(log, conf.GeminiBaseUri, conf.GeminiKey, conf.EmbeddingConfig.Gemini)
 	}
 
 	defaultVectorTopK := friday.DefaultTopK
4 changes: 2 additions & 2 deletions pkg/embedding/gemini/embedding.go
@@ -29,9 +29,9 @@ type GeminiEmbedding struct {
 	*gemini.Gemini
 }
 
-func NewGeminiEmbedding(log logger.Logger, conf config.GeminiConfig) embedding.Embedding {
+func NewGeminiEmbedding(log logger.Logger, baseUrl, key string, conf config.GeminiConfig) embedding.Embedding {
 	return &GeminiEmbedding{
-		Gemini: gemini.NewGemini(log, conf),
+		Gemini: gemini.NewGemini(log, baseUrl, key, conf),
 	}
 }
 
48 changes: 47 additions & 1 deletion pkg/friday/friday.go
@@ -17,8 +17,11 @@
 package friday
 
 import (
+	"context"
+
 	"github.com/basenana/friday/pkg/embedding"
 	"github.com/basenana/friday/pkg/llm"
+	"github.com/basenana/friday/pkg/models"
 	"github.com/basenana/friday/pkg/spliter"
 	"github.com/basenana/friday/pkg/utils/logger"
 	"github.com/basenana/friday/pkg/vectorstore"
@@ -37,7 +40,9 @@
 )
 
 type Friday struct {
-	Log logger.Logger
+	Log       logger.Logger
+	Error     error
+	statement Statement
 
 	LimitToken int
 
@@ -51,3 +56,44 @@ type Friday struct {
 
 	Spliter spliter.Spliter
 }
+
+type Statement struct {
+	context context.Context
+
+	// for chat
+	history  []map[string]string
+	question string
+	query    *models.DocQuery
+	info     string
+
+	// for ingest
+	file        *models.File // a whole file providing models.File
+	elementFile *string      // a whole file given an element-style origin file
+	originFile  *string      // a whole file given an origin file
+	elements    []models.Element
+}
+
+type ChatState struct {
+	Response chan map[string]string // dialogue result for chat
+	Answer   string                 // answer result for question
+	Tokens   map[string]int
+}
+
+type IngestState struct {
+	Tokens map[string]int
+}
+
+func (f *Friday) WithContext(ctx context.Context) *Friday {
+	t := &Friday{
+		Log:        f.Log,
+		statement:  Statement{context: ctx},
+		LimitToken: f.LimitToken,
+		LLM:        f.LLM,
+		Prompts:    f.Prompts,
+		Embedding:  f.Embedding,
+		Vector:     f.Vector,
+		VectorTopK: f.VectorTopK,
+		Spliter:    f.Spliter,
+	}
+	return t
+}
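The statement-builder introduced here collects a request on a copy of the global Friday instance (via WithContext) and hands results back through a state struct, as the cmd changes above show. A minimal sketch of the two simplest flows, assuming friday.Fri was already initialized by the existing build packages; the file path and question text are placeholder values:

package main

import (
	"context"
	"fmt"

	"github.com/basenana/friday/pkg/friday"
)

func main() {
	ctx := context.TODO()

	// Ingest a plain origin file, then read token usage from the state struct.
	path := "./docs/readme.md" // placeholder path
	ingest := &friday.IngestState{}
	if f := friday.Fri.WithContext(ctx).OriginFile(&path).Ingest(ingest); f.Error != nil {
		panic(f.Error)
	}
	fmt.Println("ingest usage:", ingest.Tokens)

	// Ask a one-shot question; the answer and usage come back in a ChatState.
	res := &friday.ChatState{}
	if f := friday.Fri.WithContext(ctx).Question("what is basenana?").Complete(res); f.Error != nil {
		panic(f.Error)
	}
	fmt.Println("answer:", res.Answer, "usage:", res.Tokens)
}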