Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: err in chat #44

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions pkg/friday/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,19 @@ func (f *Friday) chat(res *ChatState) *Friday {
var (
sumBuf = make(chan map[string]string)
sum = make(map[string]string)
err error
errCh = make(chan error)
)
go func() {
_, err = f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
defer close(errCh)
_, err := f.LLM.Chat(f.statement.context, false, sumDialogue, sumBuf)
errCh <- err
}()
select {
case err := <-errCh:
if err != nil {
f.Error = err
return
return f
}
}()
select {
case <-f.statement.context.Done():
f.Error = errors.New("context canceled")
return f
case sum = <-sumBuf:
// add context prompt for dialogue
if f.statement.query != nil {
Expand Down Expand Up @@ -151,9 +151,6 @@ func (f *Friday) chat(res *ChatState) *Friday {
}
dialogues = append(dialogues, f.statement.history[len(f.statement.history)-remainHistoryNum:len(f.statement.history)]...)
}
if f.Error != nil {
return f
}
} else {
dialogues = make([]map[string]string, len(f.statement.history))
copy(dialogues, f.statement.history)
Expand Down
72 changes: 36 additions & 36 deletions pkg/llm/client/gemini/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,52 +53,52 @@ func (g *Gemini) Chat(ctx context.Context, stream bool, history []map[string]str
})
}

buf := make(chan []byte)
var (
buf = make(chan []byte)
errCh = make(chan error)
)
go func() {
defer close(buf)
err = g.request(ctx, stream, path, "POST", map[string]any{"contents": contents}, buf)
if err != nil {
return
}
defer close(errCh)
errCh <- g.request(ctx, stream, path, "POST", map[string]any{"contents": contents}, buf)
}()

for line := range buf {
ans := make(map[string]string)
l := strings.TrimSpace(string(line))
if stream {
if l == "EOF" {
ans["content"] = "EOF"
} else {
for {
select {
case <-ctx.Done():
err = fmt.Errorf("context timeout in gemini chat")
return
case err = <-errCh:
return nil, err
case line, ok := <-buf:
if !ok {
return nil, nil
}
ans := make(map[string]string)
l := strings.TrimSpace(string(line))
if stream {
if !strings.HasPrefix(l, "\"text\"") {
continue
}
// it should be: "text": "xxx"
ans["content"] = l[9 : len(l)-2]
}
} else {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
return nil, err
}
if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
g.log.Errorf("gemini response: %s ", string(line))
return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
}
for _, c := range res.Candidates {
for _, t := range c.Content.Parts {
ans["role"] = c.Content.Role
ans["content"] = t.Text
} else {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
return nil, err
}
if len(res.Candidates) == 0 && res.PromptFeedback.BlockReason != "" {
g.log.Errorf("gemini response: %s ", string(line))
return nil, fmt.Errorf("gemini api block because of %s", res.PromptFeedback.BlockReason)
}
for _, c := range res.Candidates {
for _, t := range c.Content.Parts {
ans["role"] = c.Content.Role
ans["content"] = t.Text
}
}
}
}
select {
case <-ctx.Done():
err = fmt.Errorf("context timeout in gemini chat")
return
case answers <- ans:
continue
answers <- ans
}
}
return nil, err
}
1 change: 1 addition & 0 deletions pkg/llm/client/gemini/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func NewGemini(log logger.Logger, baseUrl, key string, conf config.GeminiConfig)
}

func (g *Gemini) request(ctx context.Context, stream bool, path string, method string, data map[string]any, res chan<- []byte) error {
defer close(res)
jsonData, _ := json.Marshal(data)

maxRetry := 100
Expand Down
61 changes: 31 additions & 30 deletions pkg/llm/client/openai/v1/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,31 @@ func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]s
"stream": stream,
}

buf := make(chan []byte)
var (
buf = make(chan []byte)
errCh = make(chan error)
)
go func() {
defer close(buf)
err = o.request(ctx, stream, path, "POST", data, buf)
if err != nil {
return
}
defer close(errCh)
errCh <- o.request(ctx, stream, path, "POST", data, buf)
}()

for line := range buf {
var delta map[string]string
if stream {
var res ChatStreamResult
if string(line) == "EOF" {
delta = map[string]string{"content": "EOF"}
} else {
for {
select {
case <-ctx.Done():
err = fmt.Errorf("context timeout in openai chat")
return
case err = <-errCh:
if err != nil {
return nil, err
}
case line, ok := <-buf:
if !ok {
return nil, nil
}
var delta map[string]string
if stream {
var res ChatStreamResult
if !strings.HasPrefix(string(line), "data:") || strings.Contains(string(line), "data: [DONE]") {
continue
}
Expand All @@ -107,24 +116,16 @@ func (o *OpenAIV1) Chat(ctx context.Context, stream bool, history []map[string]s
return
}
delta = res.Choices[0].Delta
} else {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
return
}
delta = res.Choices[0].Message
}
} else {
var res ChatResult
err = json.Unmarshal(line, &res)
if err != nil {
err = fmt.Errorf("cannot marshal msg: %s, err: %v", line, err)
return
}
delta = res.Choices[0].Message
}

select {
case <-ctx.Done():
err = fmt.Errorf("context timeout in openai chat")
return
case resp <- delta:
continue
resp <- delta
}
}
return
}
1 change: 1 addition & 0 deletions pkg/llm/client/openai/v1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ func NewOpenAIV1(log logger.Logger, baseUrl, key string, conf config.OpenAIConfi
var _ llm.LLM = &OpenAIV1{}

func (o *OpenAIV1) request(ctx context.Context, stream bool, path string, method string, data map[string]any, res chan<- []byte) error {
defer close(res)
jsonData, _ := json.Marshal(data)

maxRetry := 100
Expand Down
Loading