Skip to content

Commit

Permalink
update: add stream chat api
Browse files Browse the repository at this point in the history
Signed-off-by: zwwhdls <zww@hdls.me>
  • Loading branch information
zwwhdls committed Mar 18, 2024
1 parent d0755f3 commit a22f663
Show file tree
Hide file tree
Showing 16 changed files with 289 additions and 158 deletions.
17 changes: 12 additions & 5 deletions cmd/apps/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"strconv"
"strings"
"time"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -60,15 +61,21 @@ func chat(dirId int64, history []map[string]string) error {
f := friday.Fri.WithContext(context.TODO()).History(history).SearchIn(&models.DocQuery{
ParentId: dirId,
})
res := &friday.ChatState{}
f = f.Chat(res)
resp := make(chan map[string]string)
res := &friday.ChatState{
Response: resp,
}
go func() {
f = f.Chat(res)
close(resp)
}()
if f.Error != nil {
return f.Error
}

fmt.Println("Dialogues: ")
d, _ := json.Marshal(res.Dialogues)
fmt.Println(string(d))
fmt.Printf("Usage: %v", res.Tokens)
for line := range res.Response {
fmt.Printf("%v: %v\n", time.Now().Format("15:04:05"), line)
}
return nil
}
6 changes: 3 additions & 3 deletions pkg/friday/friday.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ type Statement struct {
}

type ChatState struct {
Dialogues []map[string]string // dialogue result for chat
Answer string // answer result for question
Tokens map[string]int
Response chan map[string]string // dialogue result for chat
Answer string // answer result for question
Tokens map[string]int
}

type IngestState struct {
Expand Down
5 changes: 3 additions & 2 deletions pkg/friday/keywords_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ func (f FakeKeyWordsLLM) Completion(ctx context.Context, prompt prompts.PromptTe
return []string{"a, b, c"}, nil, nil
}

func (f FakeKeyWordsLLM) Chat(ctx context.Context, history []map[string]string) (answers map[string]string, tokens map[string]int, err error) {
return map[string]string{"content": "a, b, c"}, nil, nil
func (f FakeKeyWordsLLM) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
answers <- map[string]string{"content": "a, b, c"}
return nil, nil
}
46 changes: 28 additions & 18 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,38 +94,48 @@ func (f *Friday) chat(res *ChatState) *Friday {
"role": "system",
"content": "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内",
})
sum, usage, e := f.LLM.Chat(f.statement.context, sumDialogue)
if e != nil {
f.Error = e
var (
sumBuf = make(chan map[string]string)
sum = make(map[string]string)
usage = make(map[string]int)
err error
)
defer close(sumBuf)
go func() {
usage, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
}()
if err != nil {
f.Error = err
return f
}
tokens = mergeTokens(usage, tokens)

// add context prompt for dialogue
dialogues = append(dialogues, []map[string]string{
f.statement.history[0],
{
"role": "system",
"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
},
}...)
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-5:len(f.statement.history)]...)
select {
case <-f.statement.context.Done():
return f
case sum = <-sumBuf:
// add context prompt for dialogue
dialogues = append(dialogues, []map[string]string{
f.statement.history[0],
{
"role": "system",
"content": fmt.Sprintf("这是历史聊天总结作为前情提要:%s", sum["content"]),
},
}...)
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-5:len(f.statement.history)]...)
}
} else {
dialogues = make([]map[string]string, len(f.statement.history))
copy(dialogues, f.statement.history)
}

// go for llm
ans, usage, err := f.LLM.Chat(f.statement.context, dialogues)
usage, err := f.LLM.Chat(f.statement.context, true, dialogues, res.Response)
if err != nil {
f.Error = err
return f
}
f.Log.Debugf("Chat result: %s", ans)
dialogues = append(dialogues, ans)
tokens = mergeTokens(tokens, usage)

res.Dialogues = dialogues
res.Tokens = tokens
return f
}
Expand Down Expand Up @@ -184,7 +194,7 @@ func (f *Friday) searchDocs(q string) {

cs := []string{}
for _, c := range docs {
f.Log.Debugf("searched from [%s] for %s", c.Name, c.Content)
//f.Log.Debugf("searched from [%s] for %s", c.Name, c.Content)
cs = append(cs, c.Content)
}
f.statement.info = strings.Join(cs, "\n")
Expand Down
122 changes: 56 additions & 66 deletions pkg/friday/question_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,23 @@ var _ = Describe("TestQuestion", func() {
"role": "user",
"content": "Who are you?",
}}
res := ChatState{}
f := loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
var (
res = ChatState{
Response: make(chan map[string]string),
}
f *Friday
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response
Expect(f.Error).Should(BeNil())
Expect(len(res.Dialogues)).Should(Equal(3))
Expect(res.Dialogues[0]["role"]).Should(Equal("system"))
Expect(res.Dialogues[0]["content"]).Should(Equal("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: There are logs of questions"))
Expect(res.Dialogues[1]["role"]).Should(Equal("user"))
Expect(res.Dialogues[2]["role"]).Should(Equal("assistant"))
Expect(len(resp)).Should(Equal(2))
Expect(resp["role"]).Should(Equal("assistant"))
})
It("chat for second time", func() {
history := []map[string]string{
{
"role": "system",
"content": "基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。",
},
{
"role": "user",
"content": "Who are you?",
Expand All @@ -96,23 +98,23 @@ var _ = Describe("TestQuestion", func() {
"content": "abc",
},
}
res := ChatState{}
f := loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
var (
res = ChatState{
Response: make(chan map[string]string),
}
f *Friday
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response
Expect(f.Error).Should(BeNil())
Expect(len(res.Dialogues)).Should(Equal(5))
Expect(res.Dialogues[0]["role"]).Should(Equal("system"))
Expect(res.Dialogues[0]["content"]).Should(Equal("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: There are logs of questions"))
Expect(res.Dialogues[1]["role"]).Should(Equal("user"))
Expect(res.Dialogues[2]["role"]).Should(Equal("assistant"))
Expect(res.Dialogues[3]["role"]).Should(Equal("user"))
Expect(res.Dialogues[4]["role"]).Should(Equal("assistant"))
Expect(len(resp)).Should(Equal(2))
Expect(resp["role"]).Should(Equal("assistant"))
})
It("chat for three times", func() {
history := []map[string]string{
{
"role": "system",
"content": "基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。",
},
{
"role": "user",
"content": "one",
Expand All @@ -134,34 +136,24 @@ var _ = Describe("TestQuestion", func() {
"content": "three",
},
}
res := ChatState{}
f := loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
var (
res = ChatState{
Response: make(chan map[string]string),
}
f *Friday
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response

Expect(f.Error).Should(BeNil())
Expect(len(res.Dialogues)).Should(Equal(8))
Expect(res.Dialogues[0]["role"]).Should(Equal("system"))
Expect(res.Dialogues[0]["content"]).Should(Equal("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: There are logs of questions"))
Expect(res.Dialogues[1]["role"]).Should(Equal("system"))
Expect(res.Dialogues[1]["content"]).Should(Equal("这是历史聊天总结作为前情提要:I am an answer"))
Expect(res.Dialogues[2]["role"]).Should(Equal("user"))
Expect(res.Dialogues[2]["content"]).Should(Equal("one"))
Expect(res.Dialogues[3]["role"]).Should(Equal("assistant"))
Expect(res.Dialogues[4]["role"]).Should(Equal("user"))
Expect(res.Dialogues[4]["content"]).Should(Equal("two"))
Expect(res.Dialogues[5]["role"]).Should(Equal("assistant"))
Expect(res.Dialogues[6]["role"]).Should(Equal("user"))
Expect(res.Dialogues[6]["content"]).Should(Equal("three"))
Expect(res.Dialogues[7]["role"]).Should(Equal("assistant"))
Expect(len(resp)).Should(Equal(2))
Expect(resp["role"]).Should(Equal("assistant"))
})
It("chat for four times", func() {
history := []map[string]string{
{
"role": "system",
"content": "基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。",
},
{
"role": "system",
"content": "这是历史聊天总结作为前情提要:I am an answer",
},
{
"role": "user",
"content": "one",
Expand Down Expand Up @@ -191,23 +183,20 @@ var _ = Describe("TestQuestion", func() {
"content": "four",
},
}
res := ChatState{}
f := loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
var (
res = ChatState{
Response: make(chan map[string]string),
}
f *Friday
)
go func() {
f = loFriday.WithContext(context.TODO()).SearchIn(&models.DocQuery{ParentId: 1}).History(history).Chat(&res)
close(res.Response)
}()
resp := <-res.Response
Expect(f.Error).Should(BeNil())
Expect(len(res.Dialogues)).Should(Equal(8))
Expect(res.Dialogues[0]["role"]).Should(Equal("system"))
Expect(res.Dialogues[0]["content"]).Should(Equal("基于以下已知信息,简洁和专业的来回答用户的问题。答案请使用中文。 \n\n已知内容: There are logs of questions"))
Expect(res.Dialogues[1]["role"]).Should(Equal("system"))
Expect(res.Dialogues[1]["content"]).Should(Equal("这是历史聊天总结作为前情提要:I am an answer"))
Expect(res.Dialogues[2]["role"]).Should(Equal("user"))
Expect(res.Dialogues[2]["content"]).Should(Equal("two"))
Expect(res.Dialogues[3]["role"]).Should(Equal("assistant"))
Expect(res.Dialogues[4]["role"]).Should(Equal("user"))
Expect(res.Dialogues[4]["content"]).Should(Equal("three"))
Expect(res.Dialogues[5]["role"]).Should(Equal("assistant"))
Expect(res.Dialogues[6]["role"]).Should(Equal("user"))
Expect(res.Dialogues[6]["content"]).Should(Equal("four"))
Expect(res.Dialogues[7]["role"]).Should(Equal("assistant"))
Expect(len(resp)).Should(Equal(2))
Expect(resp["role"]).Should(Equal("assistant"))
})
})
})
Expand Down Expand Up @@ -251,6 +240,7 @@ func (f FakeQuestionLLM) Completion(ctx context.Context, prompt prompts.PromptTe
return []string{"I am an answer"}, nil, nil
}

func (f FakeQuestionLLM) Chat(ctx context.Context, history []map[string]string) (answers map[string]string, tokens map[string]int, err error) {
return map[string]string{"role": "assistant", "content": "I am an answer"}, nil, nil
func (f FakeQuestionLLM) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
answers <- map[string]string{"role": "assistant", "content": "I am an answer"}
return nil, nil
}
5 changes: 3 additions & 2 deletions pkg/friday/summary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func (f FakeSummaryLLM) Completion(ctx context.Context, prompt prompts.PromptTem
return []string{"a b c"}, nil, nil
}

func (f FakeSummaryLLM) Chat(ctx context.Context, history []map[string]string) (answers map[string]string, tokens map[string]int, err error) {
return map[string]string{"content": "a b c"}, nil, nil
func (f FakeSummaryLLM) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
answers <- map[string]string{"content": "a b c"}
return nil, nil
}
49 changes: 31 additions & 18 deletions pkg/llm/client/gemini/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ import (
"fmt"
)

func (g *Gemini) Chat(ctx context.Context, history []map[string]string) (map[string]string, map[string]int, error) {
path := fmt.Sprintf("v1beta/models/%s:generateContent", *g.conf.Model)
func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]string, answers chan<- map[string]string) (tokens map[string]int, err error) {
path := fmt.Sprintf("v1beta/models/%s:streamGenerateContent", *g.conf.Model)

contents := make([]map[string]any, 0)
for _, hs := range history {
Expand All @@ -35,25 +35,38 @@ func (g *Gemini) Chat(ctx context.Context, history []map[string]string) (map[str
})
}

respBody, err := g.request(ctx, path, "POST", map[string]any{"contents": contents})
buf := make(chan []byte)
go func() {
defer close(buf)
err = g.request(ctx, stream, path, "POST", map[string]any{"contents": contents}, buf)
}()
if err != nil {
return nil, nil, err
return
}

var res ChatResult
err = json.Unmarshal(respBody, &res)
if err != nil {
return nil, nil, err
}
if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
g.log.Errorf("gemini response: %s ", string(respBody))
return nil, nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
}
ans := make(map[string]string)
for _, c := range res.Candidates {
for _, t := range c.Content.Parts {
ans[c.Content.Role] = t.Text
for line := range buf {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
return nil, err
}
if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
g.log.Errorf("gemini response: %s ", string(line))
return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
}
ans := make(map[string]string)
for _, c := range res.Candidates {
for _, t := range c.Content.Parts {
ans[c.Content.Role] = t.Text
}
}
select {
case <-ctx.Done():
err = fmt.Errorf("context timeout in gemini chat")
return
case answers <- ans:
continue
}
}
return ans, nil, err
return nil, err
}
Loading

0 comments on commit a22f663

Please sign in to comment.