-
Notifications
You must be signed in to change notification settings - Fork 34
/
streamaccumulator.go
188 lines (161 loc) · 6.22 KB
/
streamaccumulator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
package openai
// ChatCompletionAccumulator is a helper to accumulate chunks from a stream into
// a single ChatCompletion.
//
// Feed it every chunk, in order, with AddChunk; the embedded ChatCompletion then
// holds the accumulated result. After each AddChunk, JustFinishedContent,
// JustFinishedRefusal and JustFinishedToolCall report whether that chunk marked
// the end of a piece of the response.
type ChatCompletionAccumulator struct {
	// The up-to-date accumulation of model's responses
	ChatCompletion

	// choiceChatCompletionStates holds, per choice index, the kind of delta the
	// most recent chunk for that choice carried.
	choiceChatCompletionStates []chatCompletionResponseState

	// justFinished is the state that the most recently added chunk ended, or
	// the zero value if it ended none.
	justFinished chatCompletionResponseState

	// justFinishedChoice is the choice index justFinished belongs to, so the
	// JustFinished* methods read from the right choice when the stream carries
	// more than one choice.
	justFinishedChoice int
}

// FinishedChatCompletionToolCall is a completed tool call as reported by
// ChatCompletionAccumulator.JustFinishedToolCall.
type FinishedChatCompletionToolCall struct {
	ChatCompletionMessageToolCallFunction
	// Index is the position of the tool call within the choice's tool calls.
	Index int
}

// chatCompletionResponseState records which kind of delta the latest chunk for
// a choice carried. index is the tool call index when state is toolResponseState.
type chatCompletionResponseState struct {
	state chatCompletionResponseStateEnum
	index int
}

// chatCompletionResponseStateEnum enumerates the kinds of delta tracked by
// chatCompletionResponseState.
type chatCompletionResponseStateEnum int

const (
	// emptyResponseState is the zero value: nothing has been recorded.
	emptyResponseState chatCompletionResponseStateEnum = iota
	// contentResponseState: the delta carried message content.
	contentResponseState
	// refusalResponseState: the delta carried a refusal.
	refusalResponseState
	// toolResponseState: the delta carried tool call data.
	toolResponseState
	// finishedResponseState: the delta carried none of the above.
	finishedResponseState
)

// AddChunk incorporates a chunk into the accumulation. Chunks must be added in order.
// Returns false if the chunk could not be successfully accumulated.
//
// The ChatCompletion field JSON does not get accumulated.
func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool {
	// Clear the previous chunk's event so each event is reported at most once.
	acc.justFinished = chatCompletionResponseState{}
	acc.justFinishedChoice = 0
	if !acc.accumulateDelta(chunk) {
		return false
	}

	// only chunks with choices can cause finished events
	if len(chunk.Choices) == 0 {
		return true
	}

	// Only the first choice in a chunk drives the finished-event state machine.
	chunkIndex := int(chunk.Choices[0].Index)
	acc.choiceChatCompletionStates = expandToFit(acc.choiceChatCompletionStates, chunkIndex)
	acc.justFinished = acc.choiceChatCompletionStates[chunkIndex].update(chunk)
	acc.justFinishedChoice = chunkIndex
	return true
}

// JustFinishedContent retrieves the chat completion content when it is known to have just been completed.
// The content is "just completed" when the last added chunk no longer contains a content
// delta. If the content is just completed, the content is returned and the boolean is true. Otherwise,
// an empty string is returned and the boolean will be false.
func (acc *ChatCompletionAccumulator) JustFinishedContent() (content string, ok bool) {
	if acc.justFinished.state == contentResponseState {
		return acc.Choices[acc.justFinishedChoice].Message.Content, true
	}
	return "", false
}

// JustFinishedRefusal retrieves the chat completion refusal when it is known to have just been completed.
// The refusal is "just completed" when the last added chunk no longer contains a refusal
// delta. If the refusal is just completed, the refusal is returned and the boolean is true. Otherwise,
// an empty string is returned and the boolean will be false.
func (acc *ChatCompletionAccumulator) JustFinishedRefusal() (refusal string, ok bool) {
	if acc.justFinished.state == refusalResponseState {
		return acc.Choices[acc.justFinishedChoice].Message.Refusal, true
	}
	return "", false
}

// JustFinishedToolCall retrieves a tool call when it is known to have just been completed.
// A tool call is "just completed" when the last added chunk no longer contains a tool call
// delta or contains a delta for a different tool call. If the tool call is just completed,
// a FinishedChatCompletionToolCall is returned and the boolean is true. Otherwise, an empty
// tool call is returned and the boolean will be false.
//
// You cannot rely on this with a stream that has ParallelToolCalls enabled.
func (acc *ChatCompletionAccumulator) JustFinishedToolCall() (toolcall FinishedChatCompletionToolCall, ok bool) {
	if acc.justFinished.state == toolResponseState {
		f := acc.Choices[acc.justFinishedChoice].Message.ToolCalls[acc.justFinished.index].Function
		return FinishedChatCompletionToolCall{
			Index: acc.justFinished.index,
			ChatCompletionMessageToolCallFunction: ChatCompletionMessageToolCallFunction{
				Name:      f.Name,
				Arguments: f.Arguments,
			},
		}, true
	}
	return FinishedChatCompletionToolCall{}, false
}
// accumulateDelta concatenates a ChatCompletionChunk onto a ChatCompletion.
// It returns false and leaves cc unchanged if a mismatch is detected: the
// chunk's ID differs from the ID already accumulated, or the chunk carries a
// negative choice or tool call index.
//
// Ignores the JSON field
func (cc *ChatCompletion) accumulateDelta(chunk ChatCompletionChunk) bool {
	if len(cc.ID) != 0 && cc.ID != chunk.ID {
		return false
	}

	// Validate every index before mutating anything, so a malformed chunk is
	// rejected as a whole instead of panicking on a negative slice index or
	// being only partially applied.
	for i := range chunk.Choices {
		delta := &chunk.Choices[i]
		if delta.Index < 0 {
			return false
		}
		for j := range delta.Delta.ToolCalls {
			if delta.Delta.ToolCalls[j].Index < 0 {
				return false
			}
		}
	}

	// Either cc.ID was empty or it already equals chunk.ID.
	cc.ID = chunk.ID

	for i := range chunk.Choices {
		delta := &chunk.Choices[i]
		cc.Choices = expandToFit(cc.Choices, int(delta.Index))
		choice := &cc.Choices[delta.Index]
		choice.Index = delta.Index
		// Only overwrite with a non-empty reason (like Role, tool ID and Type
		// below) so a later chunk cannot erase a reason already received.
		if delta.FinishReason != "" {
			choice.FinishReason = ChatCompletionChoicesFinishReason(delta.FinishReason)
		}
		if delta.Delta.Role != "" {
			choice.Message.Role = ChatCompletionMessageRole(delta.Delta.Role)
		}
		choice.Message.Content += delta.Delta.Content
		choice.Message.Refusal += delta.Delta.Refusal
		for j := range delta.Delta.ToolCalls {
			deltaTool := &delta.Delta.ToolCalls[j]
			choice.Message.ToolCalls = expandToFit(choice.Message.ToolCalls, int(deltaTool.Index))
			tool := &choice.Message.ToolCalls[deltaTool.Index]
			if deltaTool.ID != "" {
				tool.ID = deltaTool.ID
			}
			if deltaTool.Type != "" {
				tool.Type = ChatCompletionMessageToolCallType(deltaTool.Type)
			}
			// Names and arguments arrive as fragments and are concatenated.
			tool.Function.Name += deltaTool.Function.Name
			tool.Function.Arguments += deltaTool.Function.Arguments
		}
		choice.Logprobs.Content = append(choice.Logprobs.Content, delta.Logprobs.Content...)
		choice.Logprobs.Refusal = append(choice.Logprobs.Refusal, delta.Logprobs.Refusal...)
	}

	// Usage counts are summed across chunks; the remaining metadata takes the
	// value from the latest chunk.
	cc.Usage.CompletionTokens += chunk.Usage.CompletionTokens
	cc.Usage.PromptTokens += chunk.Usage.PromptTokens
	cc.Usage.TotalTokens += chunk.Usage.TotalTokens
	cc.Model = chunk.Model
	cc.Created = chunk.Created
	cc.SystemFingerprint = chunk.SystemFingerprint
	cc.ServiceTier = ChatCompletionServiceTier(chunk.ServiceTier)
	cc.Object = ChatCompletionObject(chunk.Object)
	return true
}
// update records the kind of delta carried by the first choice of chunk. If
// that differs from the previously recorded state, the previous state is
// returned as "just finished"; otherwise the zero state is returned. This
// ensures that JustFinished events only fire once per transition.
//
// The first of content, refusal and tool_calls whose JSON metadata is not
// IsNull determines the new state; a delta with none of them is "finished".
func (prev *chatCompletionResponseState) update(chunk ChatCompletionChunk) (justFinished chatCompletionResponseState) {
	delta := chunk.Choices[0].Delta
	var next chatCompletionResponseState
	switch {
	case !delta.JSON.Content.IsNull():
		next.state = contentResponseState
	case !delta.JSON.Refusal.IsNull():
		next.state = refusalResponseState
	case !delta.JSON.ToolCalls.IsNull() && len(delta.ToolCalls) > 0:
		// The length check guards ToolCalls[0]: an empty tool_calls array
		// would otherwise panic with an index out of range.
		next.state = toolResponseState
		next.index = int(delta.ToolCalls[0].Index)
	default:
		next.state = finishedResponseState
	}
	if *prev != next {
		justFinished = *prev
	}
	*prev = next
	return justFinished
}
// expandToFit returns slice grown, if necessary, so that index is a valid
// position (len(result) >= index+1). Every element added by the growth is the
// zero value of T. If index is already in range (including any negative
// index), slice is returned unchanged.
//
// The result may share its backing array with slice.
func expandToFit[T any](slice []T, index int) []T {
	if index < len(slice) {
		return slice
	}
	// append writes fresh zero values into the new tail even when it reuses
	// spare capacity, so stale elements beyond len(slice) never resurface; it
	// also grows capacity geometrically instead of by exactly one element.
	// The compiler recognizes append(s, make([]T, n)...) and does not allocate
	// the temporary slice.
	return append(slice, make([]T, index+1-len(slice))...)
}