Skip to content

Commit

Permalink
Properly handle completions in middle of document - Include LLM reque… (
Browse files Browse the repository at this point in the history
#271)

…st/response in the trace

* Fix #111
* If we are generating a completion for a cell that is in the middle of
the document then we need to truncate the cells after the currently
corrected cell. This is needed because our prompt just asks the AI to
continue writing the document. So if we have a bunch of cells after the
currently selected cell we will have problems.

* Truncating the cells after the selected cell is the easiest way to
avoid this confusion.

* Fix #267 include the LLM request and response in the actual trace

* This is a much more efficient way to get the raw request and response
* GetLLMLogs does a linear search over the logs which is not efficient.

* Here's my notebook with my test that its working
https://gist.github.com/jlewi/164389227f5f797752ca1a290a05ad1e
  • Loading branch information
jlewi authored Oct 3, 2024
1 parent 527e171 commit 26dbff4
Show file tree
Hide file tree
Showing 24 changed files with 964 additions and 305 deletions.
34 changes: 34 additions & 0 deletions app/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"time"

"github.com/jlewi/foyle/app/pkg/logs/matchers"

"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"

Expand Down Expand Up @@ -222,3 +224,35 @@ func (L *LogEntry) Time() time.Time {
timestamp := time.Unix(seconds, nanoseconds)
return timestamp
}

// SetRequest sets the request field in the log entry.
// This is only intended for constructing log entries as part of testing
func SetRequest(e *LogEntry, req interface{}) error {
b, err := json.Marshal(req)
if err != nil {
return err
}

o := make(map[string]interface{})
if err := json.Unmarshal(b, &o); err != nil {
return err
}
(*e)[matchers.RequestField] = o
return nil
}

// SetResponse sets the response field in the log entry.
// This is only intended for constructing log entries as part of testing
func SetResponse(e *LogEntry, req interface{}) error {
b, err := json.Marshal(req)
if err != nil {
return err
}

o := make(map[string]interface{})
if err := json.Unmarshal(b, &o); err != nil {
return err
}
(*e)[matchers.ResponseField] = o
return nil
}
23 changes: 17 additions & 6 deletions app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ func (a *Agent) Generate(ctx context.Context, req *v1alpha1.GenerateRequest) (*v
func (a *Agent) completeWithRetries(ctx context.Context, req *v1alpha1.GenerateRequest, examples []*v1alpha1.Example) ([]*v1alpha1.Block, error) {
log := logs.FromContext(ctx)

t := docs.NewTailer(req.Doc.GetBlocks(), MaxDocChars)
cells := preprocessDoc(req)
t := docs.NewTailer(cells, MaxDocChars)

exampleArgs := make([]Example, 0, len(examples))
for _, example := range examples {
Expand Down Expand Up @@ -224,7 +225,8 @@ func (a *Agent) StreamGenerate(ctx context.Context, stream *connect.BidiStream[v
// This should be safe because each time we update pendingDoc we update it to point to
// a new doc object. So the other thread won't be modifying the doc pendingDoc points to
r := &v1alpha1.GenerateRequest{
Doc: pendingDoc,
Doc: pendingDoc,
SelectedIndex: selectedCell,
}
pendingDoc = nil
return r
Expand All @@ -234,7 +236,7 @@ func (a *Agent) StreamGenerate(ctx context.Context, stream *connect.BidiStream[v
continue
}

response, err := a.createCompletion(ctx, generateRequest, notebookUri, selectedCell, state.getContextID())
response, err := a.createCompletion(ctx, generateRequest, notebookUri, state.getContextID())

if err != nil {
log.Error(err, "createCompletion failed")
Expand Down Expand Up @@ -448,7 +450,8 @@ func (a *Agent) GenerateCells(ctx context.Context, req *connect.Request[v1alpha1
return nil, err
}
agentReq := &v1alpha1.GenerateRequest{
Doc: doc,
Doc: doc,
SelectedIndex: req.Msg.GetSelectedIndex(),
}

// Call the agent
Expand All @@ -475,11 +478,12 @@ func (a *Agent) GenerateCells(ctx context.Context, req *connect.Request[v1alpha1
}

// createCompletion is a helper function to create a single completion as part of a stream.
func (a *Agent) createCompletion(ctx context.Context, generateRequest *v1alpha1.GenerateRequest, notebookUri string, selectedCell int32, contextID string) (*v1alpha1.StreamGenerateResponse, error) {
func (a *Agent) createCompletion(ctx context.Context, generateRequest *v1alpha1.GenerateRequest, notebookUri string, contextID string) (*v1alpha1.StreamGenerateResponse, error) {
span := trace.SpanFromContext(ctx)
log := logs.FromContext(ctx)
traceId := span.SpanContext().TraceID()
tp := tracer()

// We need to generate a new ctx with a new trace ID because we want one trace per completion
// We need to use withNewRoot because we want to make it a new trace and not rooted at the current one
generateCtx, generateSpan := tp.Start(ctx, "CreateCompletion", trace.WithNewRoot(), trace.WithAttributes(attribute.String("streamTraceID", traceId.String()), attribute.String("contextID", contextID)))
Expand All @@ -501,7 +505,7 @@ func (a *Agent) createCompletion(ctx context.Context, generateRequest *v1alpha1.
response := &v1alpha1.StreamGenerateResponse{
Cells: cells,
NotebookUri: notebookUri,
InsertAt: selectedCell + 1,
InsertAt: generateRequest.GetSelectedIndex() + 1,
ContextId: contextID,
}

Expand Down Expand Up @@ -631,3 +635,10 @@ func dropResponse(response *v1alpha1.StreamGenerateResponse) bool {
}
return false
}

// preprocessDoc does some preprocessing of the doc.
func preprocessDoc(req *v1alpha1.GenerateRequest) []*v1alpha1.Block {
// We want to remove all cells after the selected cell because our prompt doesn't know how to take them into account.
cells := req.Doc.Blocks[:req.SelectedIndex+1]
return cells
}
70 changes: 70 additions & 0 deletions app/pkg/agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp/cmpopts"

"github.com/google/go-cmp/cmp"

"github.com/sashabaranov/go-openai"
Expand Down Expand Up @@ -424,3 +426,71 @@ func Test_dropResponse(t *testing.T) {
})
}
}

func Test_peprocessDoc(t *testing.T) {
type testCase struct {
name string
input *v1alpha1.GenerateRequest
expected []*v1alpha1.Block
}

doc := &v1alpha1.Doc{
Blocks: []*v1alpha1.Block{
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 0",
},
{
Kind: v1alpha1.BlockKind_CODE,
Contents: "cell 1",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 2",
},
},
}
cases := []testCase{
{
name: "basic",
input: &v1alpha1.GenerateRequest{
Doc: doc,
SelectedIndex: 0,
},
expected: []*v1alpha1.Block{
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 0",
},
},
},
{
name: "middle",
input: &v1alpha1.GenerateRequest{
Doc: doc,
SelectedIndex: 1,
},
expected: []*v1alpha1.Block{
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 0",
},
{
Kind: v1alpha1.BlockKind_CODE,
Contents: "cell 1",
},
},
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
actual := preprocessDoc(c.input)

opts := cmpopts.IgnoreUnexported(v1alpha1.Block{})
if d := cmp.Diff(c.expected, actual, opts); d != "" {
t.Errorf("Unexpected diff:\n%s", d)
}
})
}
}
4 changes: 0 additions & 4 deletions app/pkg/analyze/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ func readLLMLog(ctx context.Context, traceId string, logFile string) (*logspb.Ge
continue
}
isMatch := false
if strings.HasSuffix(entry.Function(), "anthropic.(*Completer).Complete") {
provider = api.ModelProviderAnthropic
isMatch = true
}

if matchers.IsOAIComplete(entry.Function()) {
provider = api.ModelProviderOpenAI
Expand Down
66 changes: 66 additions & 0 deletions app/pkg/analyze/spans.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"strings"

"github.com/jlewi/foyle/app/pkg/logs/matchers"
"google.golang.org/protobuf/proto"

"github.com/jlewi/foyle/app/api"
logspb "github.com/jlewi/foyle/protos/go/foyle/logs"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand All @@ -13,6 +16,47 @@ func logEntryToSpan(ctx context.Context, e *api.LogEntry) *logspb.Span {
if strings.Contains(e.Function(), "learn.(*InMemoryExampleDB).GetExamples") {
return logEntryToRAGSpan(ctx, e)
}

if matchers.IsOAIComplete(e.Function()) || matchers.IsAnthropicComplete(e.Function()) {
return logEntryToLLMSpan(ctx, e)
}
return nil
}

func logEntryToLLMSpan(ctx context.Context, e *api.LogEntry) *logspb.Span {
provider := v1alpha1.ModelProvider_MODEL_PROVIDER_UNKNOWN
if matchers.IsOAIComplete(e.Function()) {
provider = v1alpha1.ModelProvider_OPEN_AI
} else if matchers.IsAnthropicComplete(e.Function()) {
provider = v1alpha1.ModelProvider_ANTHROPIC
}

// Code relies on the fact that the completer field only use the fields request and response for the LLM model.
// The code below also relies on the fact that the request and response are logged on separate log lines
reqB := e.Request()
if reqB != nil {
return &logspb.Span{
Data: &logspb.Span_Llm{
Llm: &logspb.LLMSpan{
Provider: provider,
RequestJson: string(reqB),
},
},
}
}

resB := e.Response()
if resB != nil {
return &logspb.Span{
Data: &logspb.Span_Llm{
Llm: &logspb.LLMSpan{
Provider: provider,
ResponseJson: string(resB),
},
},
}
}

return nil
}

Expand Down Expand Up @@ -54,13 +98,21 @@ func combineSpans(trace *logspb.Trace) {
trace.Spans = make([]*logspb.Span, 0, len(oldSpans))

var ragSpan *logspb.RAGSpan
var llmSpan *logspb.LLMSpan

for _, s := range oldSpans {
if s.GetRag() != nil {
if ragSpan == nil {
ragSpan = s.GetRag()
} else {
ragSpan = combineRAGSpans(ragSpan, s.GetRag())
}
} else if s.GetLlm() != nil {
if llmSpan == nil {
llmSpan = s.GetLlm()
} else {
llmSpan = combineLLMSpans(llmSpan, s.GetLlm())
}
} else {
trace.Spans = append(trace.Spans, s)
}
Expand All @@ -73,6 +125,14 @@ func combineSpans(trace *logspb.Trace) {
},
})
}

if llmSpan != nil {
trace.Spans = append(trace.Spans, &logspb.Span{
Data: &logspb.Span_Llm{
Llm: llmSpan,
},
})
}
}

// combine two RagSpans
Expand All @@ -90,3 +150,9 @@ func combineRAGSpans(a, b *logspb.RAGSpan) *logspb.RAGSpan {
span.Results = append(span.Results, b.Results...)
return span
}

// combine two LLMSpans
func combineLLMSpans(a, b *logspb.LLMSpan) *logspb.LLMSpan {
proto.Merge(a, b)
return a
}
Loading

0 comments on commit 26dbff4

Please sign in to comment.