diff --git a/pkg/friday/friday.go b/pkg/friday/friday.go
index 87951d3..bcade3b 100644
--- a/pkg/friday/friday.go
+++ b/pkg/friday/friday.go
@@ -62,10 +62,12 @@ type Statement struct {
 	context context.Context
 
 	// for chat
-	history  []map[string]string
-	question string
-	query    *models.DocQuery
-	info     string
+	summary        string              // summary of doc
+	history        []map[string]string // real chat history
+	question       string              // question for chat
+	query          *models.DocQuery    // search in doc or dir
+	historySummary string              // summary of chat history
+	info           string              // info of embedding
 
 	// for ingest or summary
 	file *models.File // a whole file providing models.File
diff --git a/pkg/friday/question.go b/pkg/friday/question.go
index 34a467a..fdaebbe 100644
--- a/pkg/friday/question.go
+++ b/pkg/friday/question.go
@@ -17,15 +17,17 @@ package friday
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
+	"html/template"
 	"strings"
 
 	"github.com/basenana/friday/pkg/llm/prompts"
 	"github.com/basenana/friday/pkg/models"
 )
 
-const remainHistoryNum = 3 // must be odd
+const remainHistoryNum = 5 // must be odd
 
 func (f *Friday) History(history []map[string]string) *Friday {
 	f.statement.history = history
@@ -42,22 +44,40 @@ func (f *Friday) Question(q string) *Friday {
 	return f
 }
 
+func (f *Friday) OfSummary(summary string) *Friday {
+	f.statement.summary = summary
+	return f
+}
+
 func (f *Friday) GetRealHistory() []map[string]string {
 	return f.statement.history
 }
 
-func (f *Friday) Chat(res *ChatState) *Friday {
+func (f *Friday) preCheck(res *ChatState) error {
 	if len(f.statement.history) == 0 {
-		f.Error = errors.New("history can not be nil")
-		return f
+		return errors.New("history can not be nil")
 	}
 	if f.LLM == nil {
-		f.Error = errors.New("llm client of friday is not set")
+		return errors.New("llm client of friday is not set")
+	}
+	if res == nil {
+		return errors.New("result can not be nil")
+	}
+	return nil
+}
+
+func (f *Friday) Chat(res *ChatState) *Friday {
+	if f.Error = f.preCheck(res); f.Error != nil {
 		return f
 	}
+	var (
+		dialogues  = []map[string]string{}
+		systemInfo = ""
+	)
+
+	// search for docs
 	if f.statement.query != nil {
-		// search for docs
 		questions := ""
 		for _, d := range f.statement.history {
 			if d["role"] == f.LLM.GetUserModel() {
@@ -69,88 +89,29 @@ func (f *Friday) Chat(res *ChatState) *Friday {
 			return f
 		}
 	}
+
+	// If the number of dialogue rounds exceeds some rounds, should conclude it.
+	if len(f.statement.history) >= remainHistoryNum {
+		f.statement.historySummary = f.summaryHistory().statement.historySummary
+		if f.Error != nil {
+			return f
+		}
+	}
+
+	// if it already has system info, rewrite it
 	if (f.statement.history)[0]["role"] == "system" {
 		f.statement.history = f.statement.history[1:]
 	}
-	f.statement.history = append([]map[string]string{
-		{
-			"role":    f.LLM.GetSystemModel(),
-			"content": fmt.Sprintf("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: %s\n", f.statement.info),
-		},
-		{
-			"role":    f.LLM.GetAssistantModel(),
-			"content": "",
-		},
-	}, f.statement.history...)
-
-	return f.chat(res)
-}
-func (f *Friday) chat(res *ChatState) *Friday {
-	if res == nil {
-		f.Error = errors.New("result can not be nil")
-		return f
-	}
-	var (
-		dialogues = []map[string]string{}
-	)
+	// regenerate system info
+	systemInfo = f.generateSystemInfo()
 
-	// If the number of dialogue rounds exceeds 2 rounds, should conclude it.
 	if len(f.statement.history) >= remainHistoryNum {
-		sumDialogue := make([]map[string]string, len(f.statement.history))
-		copy(sumDialogue, f.statement.history)
-		sumDialogue[len(sumDialogue)-1] = map[string]string{
-			"role":    f.LLM.GetSystemModel(),
-			"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
-		}
-		var (
-			sumBuf = make(chan map[string]string)
-			sum    = make(map[string]string)
-			errCh  = make(chan error)
-		)
-		go func() {
-			defer close(errCh)
-			_, err := f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
-			errCh <- err
-		}()
-		select {
-		case err := <-errCh:
-			if err != nil {
-				f.Error = err
-				return f
-			}
-		case sum = <-sumBuf:
-			// add context prompt for dialogue
-			if f.statement.query != nil {
-				// there has been ingest info, combine them.
-				dialogues = []map[string]string{
-					{
-						"role": f.LLM.GetSystemModel(),
-						"content": fmt.Sprintf(
-							"%s\n%s",
-							f.statement.history[0]["content"],
-							fmt.Sprintf("这是历史聊天总结作为前情提要:%s\n", sum["content"]),
-						),
-					},
-					{
-						"role":    f.LLM.GetAssistantModel(),
-						"content": "",
-					},
-				}
-			} else {
-				dialogues = []map[string]string{
-					{
-						"role":    f.LLM.GetSystemModel(),
-						"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
-					},
-					{
-						"role":    f.LLM.GetAssistantModel(),
-						"content": "",
-					},
-				}
-			}
-			dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:len(f.statement.history)]...)
+		dialogues = []map[string]string{
+			{"role": f.LLM.GetSystemModel(), "content": systemInfo},
+			{"role": f.LLM.GetAssistantModel(), "content": ""},
 		}
+		dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:len(f.statement.history)]...)
 	} else {
 		dialogues = make([]map[string]string, len(f.statement.history))
 		copy(dialogues, f.statement.history)
@@ -158,10 +119,55 @@ func (f *Friday) chat(res *ChatState) *Friday {
 
 	// go for llm
 	f.statement.history = dialogues // return realHistory
-	_, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
-	if err != nil {
-		f.Error = err
-		return f
+	_, f.Error = f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
+	return f
+}
+
+func (f *Friday) generateSystemInfo() string {
+	systemTemplate := "基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n"
+	if f.statement.summary != "" {
+		systemTemplate += "\n这是文章简介: {{ .Summary }}\n"
+	}
+	if f.statement.info != "" {
+		systemTemplate += "\n这是已知内容: {{ .Info }}\n"
+	}
+	if f.statement.historySummary != "" {
+		systemTemplate += "\n这是历史聊天总结作为前情提要: {{ .HistorySummary }}\n"
+	}
+
+	temp := template.Must(template.New("systemInfo").Parse(systemTemplate))
+	prompt := new(bytes.Buffer)
+	f.Error = temp.Execute(prompt, map[string]string{"Summary": f.statement.summary, "Info": f.statement.info, "HistorySummary": f.statement.historySummary})
+	if f.Error != nil {
+		return ""
+	}
+	return prompt.String()
+}
+
+func (f *Friday) summaryHistory() *Friday {
+	sumDialogue := make([]map[string]string, len(f.statement.history))
+	copy(sumDialogue, f.statement.history)
+	sumDialogue[len(sumDialogue)-1] = map[string]string{
+		"role":    f.LLM.GetSystemModel(),
+		"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
+	}
+	var (
+		sumBuf = make(chan map[string]string)
+		sum    = make(map[string]string)
+		errCh  = make(chan error)
+	)
+	go func() {
+		defer close(errCh)
+		_, err := f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
+		errCh <- err
+	}()
+	select {
+	case err := <-errCh:
+		if err != nil {
+			f.Error = err
+		}
+	case sum = <-sumBuf:
+		f.statement.historySummary = sum["content"]
 	}
 	return f
 }
diff --git a/pkg/friday/question_test.go b/pkg/friday/question_test.go
index 5ed0811..a765b69 100644
--- a/pkg/friday/question_test.go
+++ b/pkg/friday/question_test.go
@@ -76,7 +76,6 @@ var _ = Describe("TestQuestion", func() {
 			)
 			go func() {
 				f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
-				close(res.Response)
 			}()
 			resp := <-res.Response
 			Expect(f.Error).Should(BeNil())
@@ -106,7 +105,6 @@ var _ = Describe("TestQuestion", func() {
 			)
 			go func() {
 				f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
-				close(res.Response)
 			}()
 			resp := <-res.Response
 			Expect(f.Error).Should(BeNil())
@@ -144,7 +142,6 @@ var _ = Describe("TestQuestion", func() {
 			)
 			go func() {
 				f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
-				close(res.Response)
 			}()
 
 			resp := <-res.Response
@@ -191,7 +188,6 @@ var _ = Describe("TestQuestion", func() {
 			)
 			go func() {
 				f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
-				close(res.Response)
 			}()
 			resp := <-res.Response
 			Expect(f.Error).Should(BeNil())
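Usage note, an illustrative sketch rather than part of the change: the function below drives the reworked chain, WithContext / SearchIn / History / Chat, the same way the updated tests do, and reads a single reply from res.Response without closing the channel, since the close(res.Response) calls were removed from the test goroutines. The package name, the askOnce helper, and the errCh error plumbing are assumptions of the sketch (the tests instead assign the returned *Friday to an outer variable and check Error after receiving), and ChatState is assumed to need only its Response field.

package example

import (
	"context"

	"github.com/basenana/friday/pkg/friday"
	"github.com/basenana/friday/pkg/models"
)

// askOnce drives the reworked builder chain the same way the updated tests do
// and reads a single reply from res.Response. Note that the caller no longer
// calls close(res.Response); that is the behavioral point of the test change.
func askOnce(loFriday *friday.Friday, history []map[string]string) (string, error) {
	res := friday.ChatState{Response: make(chan map[string]string)} // assumption: only the Response channel needs to be set

	errCh := make(chan error, 1)
	go func() {
		f := loFriday.WithContext(context.TODO()).
			SearchIn(&models.DocQuery{ParentId: 1}). // ParentId 1 is the fixture value used in the tests
			History(history).
			Chat(&res)
		errCh <- f.Error
	}()

	// Either a reply chunk arrives on res.Response or Chat returns an error
	// before sending anything (for example when preCheck fails).
	select {
	case resp := <-res.Response:
		return resp["content"], nil
	case err := <-errCh:
		return "", err
	}
}

Whether Chat or the underlying LLM client eventually closes res.Response is left to the library side after this change, so the sketch performs only a single receive, exactly as the tests do.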