Merge pull request #46 from basenana/feature/chat_ofsummary
add summary in chat
zwwhdls authored May 6, 2024
2 parents 0dd2292 + 4405dcc commit 4996b19
Showing 3 changed files with 97 additions and 93 deletions.
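
The change threads a per-document summary, set with the new OfSummary option, into the chat's system prompt. Here is a minimal sketch of where the option slots into the existing builder chain, assuming the fluent API visible in the tests below; the friday import path, the ChatState literal, and the chatWithSummary wrapper are illustrative assumptions, while the builder methods, the Response channel, and models.DocQuery come from this commit:

```go
package example

import (
	"context"

	"github.com/basenana/friday/pkg/friday" // import path assumed from the repository layout
	"github.com/basenana/friday/pkg/models"
)

// chatWithSummary shows where OfSummary fits in the builder chain. It assumes
// ChatState only needs the Response channel seen in this diff and that Chat
// finishes every send on Response before it returns.
func chatWithSummary(loFriday *friday.Friday, history []map[string]string, docSummary string) ([]map[string]string, error) {
	res := friday.ChatState{Response: make(chan map[string]string)}
	done := make(chan *friday.Friday, 1)

	go func() {
		done <- loFriday.WithContext(context.TODO()).
			SearchIn(&models.DocQuery{ParentId: 1}). // retrieval scope, as in the tests below
			OfSummary(docSummary).                   // new in this commit: document summary for the prompt
			History(history).                        // prior dialogue turns
			Chat(&res)
	}()

	var reply []map[string]string
	for {
		select {
		case chunk := <-res.Response:
			reply = append(reply, chunk) // collect streamed chunks
		case f := <-done:
			return reply, f.Error
		}
	}
}
```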
10 changes: 6 additions & 4 deletions pkg/friday/friday.go
@@ -62,10 +62,12 @@ type Statement struct {
context context.Context

// for chat
history []map[string]string
question string
query *models.DocQuery
info string
summary string // summary of doc
history []map[string]string // real chat history
question string // question for chat
query *models.DocQuery // search in doc or dir
historySummary string // summary of chat history
info string // info of embedding

// for ingest or summary
file *models.File // a whole file providing models.File
176 changes: 91 additions & 85 deletions pkg/friday/question.go
@@ -17,15 +17,17 @@
package friday

import (
"bytes"
"errors"
"fmt"
"html/template"
"strings"

"github.com/basenana/friday/pkg/llm/prompts"
"github.com/basenana/friday/pkg/models"
)

const remainHistoryNum = 3 // must be odd
const remainHistoryNum = 5 // must be odd

func (f *Friday) History(history []map[string]string) *Friday {
f.statement.history = history
@@ -42,22 +44,40 @@ func (f *Friday) Question(q string) *Friday {
return f
}

func (f *Friday) OfSummary(summary string) *Friday {
f.statement.summary = summary
return f
}

func (f *Friday) GetRealHistory() []map[string]string {
return f.statement.history
}

func (f *Friday) Chat(res *ChatState) *Friday {
func (f *Friday) preCheck(res *ChatState) error {
if len(f.statement.history) == 0 {
f.Error = errors.New("history can not be nil")
return f
return errors.New("history can not be nil")
}
if f.LLM == nil {
f.Error = errors.New("llm client of friday is not set")
return errors.New("llm client of friday is not set")
}
if res == nil {
return errors.New("result can not be nil")
}
return nil
}

func (f *Friday) Chat(res *ChatState) *Friday {
if f.Error = f.preCheck(res); f.Error != nil {
return f
}

var (
dialogues = []map[string]string{}
systemInfo = ""
)

// search for docs
if f.statement.query != nil {
// search for docs
questions := ""
for _, d := range f.statement.history {
if d["role"] == f.LLM.GetUserModel() {
@@ -69,99 +89,85 @@ func (f *Friday) Chat(res *ChatState) *Friday {
return f
}
}

// Once the dialogue reaches remainHistoryNum entries, summarize the conversation so far.
if len(f.statement.history) >= remainHistoryNum {
f.statement.historySummary = f.summaryHistory().statement.historySummary
if f.Error != nil {
return f
}
}

// if it already has system info, rewrite it
if (f.statement.history)[0]["role"] == "system" {
f.statement.history = f.statement.history[1:]
}
f.statement.history = append([]map[string]string{
{
"role": f.LLM.GetSystemModel(),
"content": fmt.Sprintf("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: %s\n", f.statement.info),
},
{
"role": f.LLM.GetAssistantModel(),
"content": "",
},
}, f.statement.history...)

return f.chat(res)
}

func (f *Friday) chat(res *ChatState) *Friday {
if res == nil {
f.Error = errors.New("result can not be nil")
return f
}
var (
dialogues = []map[string]string{}
)
// regenerate system info
systemInfo = f.generateSystemInfo()

// If the number of dialogue rounds exceeds 2 rounds, should conclude it.
if len(f.statement.history) >= remainHistoryNum {
sumDialogue := make([]map[string]string, len(f.statement.history))
copy(sumDialogue, f.statement.history)
sumDialogue[len(sumDialogue)-1] = map[string]string{
"role": f.LLM.GetSystemModel(),
"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
}
var (
sumBuf = make(chan map[string]string)
sum = make(map[string]string)
errCh = make(chan error)
)
go func() {
defer close(errCh)
_, err := f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
errCh <- err
}()
select {
case err := <-errCh:
if err != nil {
f.Error = err
return f
}
case sum = <-sumBuf:
// add context prompt for dialogue
if f.statement.query != nil {
// there has been ingest info, combine them.
dialogues = []map[string]string{
{
"role": f.LLM.GetSystemModel(),
"content": fmt.Sprintf(
"%s\n%s",
f.statement.history[0]["content"],
fmt.Sprintf("这是历史聊天总结作为前情提要:%s\n", sum["content"]),
),
},
{
"role": f.LLM.GetAssistantModel(),
"content": "",
},
}
} else {
dialogues = []map[string]string{
{
"role": f.LLM.GetSystemModel(),
"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
},
{
"role": f.LLM.GetAssistantModel(),
"content": "",
},
}
}
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:len(f.statement.history)]...)
dialogues = []map[string]string{
{"role": f.LLM.GetSystemModel(), "content": systemInfo},
{"role": f.LLM.GetAssistantModel(), "content": ""},
}
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:len(f.statement.history)]...)
} else {
dialogues = make([]map[string]string, len(f.statement.history))
copy(dialogues, f.statement.history)
}

// go for llm
f.statement.history = dialogues // return realHistory
_, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
if err != nil {
f.Error = err
return f
_, f.Error = f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
return f
}
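
To make the pruning in the new Chat path concrete: once the history reaches remainHistoryNum entries, only the trailing remainHistoryNum turns are forwarded verbatim, and older turns survive only through the history summary that generateSystemInfo folds into the system message. A small self-contained sketch with invented turns; the literal "system"/"assistant" roles stand in for GetSystemModel()/GetAssistantModel():

```go
package main

import "fmt"

const remainHistoryNum = 5 // mirrors the constant in question.go

func main() {
	history := []map[string]string{
		{"role": "user", "content": "q1"}, {"role": "assistant", "content": "a1"},
		{"role": "user", "content": "q2"}, {"role": "assistant", "content": "a2"},
		{"role": "user", "content": "q3"}, {"role": "assistant", "content": "a3"},
		{"role": "user", "content": "q4"},
	}

	// Stand-in for generateSystemInfo output, which already carries the doc
	// summary, embedding info and history summary when they are set.
	systemInfo := "system prompt rendered by generateSystemInfo"

	var dialogues []map[string]string
	if len(history) >= remainHistoryNum {
		dialogues = []map[string]string{
			{"role": "system", "content": systemInfo},
			{"role": "assistant", "content": ""},
		}
		// Only the last remainHistoryNum turns go out verbatim.
		dialogues = append(dialogues, history[len(history)-remainHistoryNum:]...)
	} else {
		// Short dialogues are forwarded unchanged.
		dialogues = append(dialogues, history...)
	}

	for _, d := range dialogues {
		fmt.Printf("%s: %s\n", d["role"], d["content"])
	}
}
```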

func (f *Friday) generateSystemInfo() string {
systemTemplate := "基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n"
if f.statement.summary != "" {
systemTemplate += "\n这是文章简介: {{ .Summary }}\n"
}
if f.statement.info != "" {
systemTemplate += "\n这是已知内容: {{ .Info }}\n"
}
if f.statement.historySummary != "" {
systemTemplate += "\n这是历史聊天总结作为前情提要: {{ .HistorySummary }}\n"
}

temp := template.Must(template.New("systemInfo").Parse(systemTemplate))
prompt := new(bytes.Buffer)
f.Error = temp.Execute(prompt, f.statement)
if f.Error != nil {
return ""
}
return prompt.String()
}
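
For a concrete picture of what generateSystemInfo renders, here is a hedged, self-contained sketch of the same conditional template assembly. The template strings are copied from the function above; the promptData struct and the sample values are invented for illustration (exported fields are used so the template can reach them):

```go
package main

import (
	"bytes"
	"fmt"
	"html/template"
)

// promptData is a stand-in for the chat-related Statement fields that the
// template reads; the real code executes the template against f.statement.
type promptData struct {
	Summary        string
	Info           string
	HistorySummary string
}

func main() {
	data := promptData{
		Summary:        "a short synopsis of the ingested document", // invented sample value
		HistorySummary: "the running summary produced by summaryHistory",
	}

	// Same pattern as generateSystemInfo: start from the base instruction and
	// append a section only when the corresponding field is non-empty.
	tmpl := "基于以下内容,简洁和专业的来回答用户的问题。答案请使用中文。\n"
	if data.Summary != "" {
		tmpl += "\n这是文章简介: {{ .Summary }}\n"
	}
	if data.Info != "" {
		tmpl += "\n这是已知内容: {{ .Info }}\n"
	}
	if data.HistorySummary != "" {
		tmpl += "\n这是历史聊天总结作为前情提要: {{ .HistorySummary }}\n"
	}

	var prompt bytes.Buffer
	if err := template.Must(template.New("systemInfo").Parse(tmpl)).Execute(&prompt, data); err != nil {
		panic(err)
	}
	fmt.Println(prompt.String())
}
```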

func (f *Friday) summaryHistory() *Friday {
sumDialogue := make([]map[string]string, len(f.statement.history))
copy(sumDialogue, f.statement.history)
sumDialogue[len(sumDialogue)-1] = map[string]string{
"role": f.LLM.GetSystemModel(),
"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
}
var (
sumBuf = make(chan map[string]string)
sum = make(map[string]string)
errCh = make(chan error)
)
go func() {
defer close(errCh)
_, err := f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
errCh <- err
}()
select {
case err := <-errCh:
if err != nil {
f.Error = err
}
case sum = <-sumBuf:
f.statement.historySummary = sum["content"]
}
return f
}
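
For readers less familiar with the channel handshake in summaryHistory: the function copies the dialogue, swaps the final entry for a summarization instruction, runs the non-streaming LLM call in a goroutine, and then waits for whichever of the error or the summary arrives first. A stripped-down sketch of that pattern with a stubbed chat function; chatFn, summarize and the stub are inventions for this sketch, not part of the commit, and the error channel is buffered here so the goroutine can never block on its final send:

```go
package main

import "fmt"

// chatFn stands in for f.LLM.Chat in non-streaming mode: it delivers one
// message on out and returns an error. The signature is simplified.
type chatFn func(dialogue []map[string]string, out chan<- map[string]string) error

func summarize(history []map[string]string, chat chatFn) (string, error) {
	sumDialogue := make([]map[string]string, len(history))
	copy(sumDialogue, history)
	// Replace the last turn with the summarization instruction used above.
	sumDialogue[len(sumDialogue)-1] = map[string]string{
		"role":    "system",
		"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
	}

	sumBuf := make(chan map[string]string)
	errCh := make(chan error, 1) // buffered: the sender never blocks if the summary wins
	go func() {
		errCh <- chat(sumDialogue, sumBuf)
	}()

	select {
	case err := <-errCh: // the call failed before delivering a summary
		return "", err
	case sum := <-sumBuf: // the summary arrived; keep its content
		return sum["content"], nil
	}
}

func main() {
	history := []map[string]string{
		{"role": "user", "content": "hello"},
		{"role": "assistant", "content": "hi there"},
		{"role": "user", "content": "what did we discuss?"},
	}
	stub := func(dialogue []map[string]string, out chan<- map[string]string) error {
		out <- map[string]string{"content": "a short greeting exchange"} // pretend summary
		return nil
	}
	sum, err := summarize(history, stub)
	fmt.Println(sum, err)
}
```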
4 changes: 0 additions & 4 deletions pkg/friday/question_test.go
@@ -76,7 +76,6 @@ var _ = Describe("TestQuestion", func() {
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response
Expect(f.Error).Should(BeNil())
@@ -106,7 +105,6 @@ var _ = Describe("TestQuestion", func() {
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response
Expect(f.Error).Should(BeNil())
@@ -144,7 +142,6 @@ var _ = Describe("TestQuestion", func() {
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response

@@ -191,7 +188,6 @@ var _ = Describe("TestQuestion", func() {
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response
Expect(f.Error).Should(BeNil())
