Skip to content

Commit

Permalink
Fix failures to compute embeddings too much context & Shorten RAG Re…
Browse files Browse the repository at this point in the history
…sults (#279)

# Fix failures to compute embeddings because of computing too much
context.
  
* Prior to this change we were computing the embeddings using the
entire notebook.
    This would lead to context exceeded errors on longer documents.
  
  * This had two negative impacts
  
1. We stop learning from long documents because we no longer compute
embeddings for the document
1. When making suggestions we don't end up retrieving any documents
from RAG because we can't compute the embeddings
       for the current document

# Don't include the full Document in the RAG example

* In Example.Query we were including the full document which would then
be injected into the context when generating new suggestions

* This can end up using a lot of tokens and potentially confusing the
agent when generating new suggestions

* Use a simple algorithm to shorten the example. The algorithm is as
follows
   * Include the current selected cell
   * Keep including previous cells as long as they are markup cells

# Results

* On our evaluation results the number of matched cells is approximately
the same; 6 correct, whereas we got 7 prior to this change
* Of these only 4 examples are the same ones that the two experiments
both got correct
* We should probably add level 1 assertions to test whether results are
retrieved to identify bugs like this in the future.
  
# Other
* This PR also refactors the code to share code for computing embeddings
between the learner and the Agent
    to minimize risk of training and serving skew.

* Treat learner failures as permanent failures rather than retrying. If
we see concrete examples of retryable errors then we can add retries
* Right now we were retrying on permanent errors (i.e. the context
length being exceeded).
  
  * Fix #260
  • Loading branch information
jlewi authored Oct 7, 2024
1 parent 59e16c2 commit b4874f5
Show file tree
Hide file tree
Showing 14 changed files with 266 additions and 52 deletions.
11 changes: 10 additions & 1 deletion app/Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
build-dir:
mkdir -p .build

GIT_SHA := $(shell git rev-parse HEAD)
GIT_SHA_SHORT := $(shell git rev-parse --short HEAD)
DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
VERSION := $(shell git describe --tags)-$(GIT_SHA_SHORT)
LDFLAGS := -s -w \
-X 'github.com/jlewi/foyle/app/cmd.date=$(DATE)' \
-X 'github.com/jlewi/foyle/app/cmd.version=$(subst v,,$(VERSION))' \
-X 'github.com/jlewi/foyle/app/cmd.commit=$(GIT_SHA)'

build: build-dir
CGO_ENABLED=0 go build -o .build/foyle github.com/jlewi/foyle/app
CGO_ENABLED=0 go build -o .build/foyle -ldflags="$(LDFLAGS)" github.com/jlewi/foyle/app

build-wasm:
GOARCH=wasm GOOS=js go build -o web/app.wasm ./pwa
Expand Down
2 changes: 1 addition & 1 deletion app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (a *Agent) Generate(ctx context.Context, req *v1alpha1.GenerateRequest) (*v
var examples []*v1alpha1.Example
if a.config.UseRAG() {
var err error
examples, err = a.db.GetExamples(ctx, req.Doc, a.config.RagMaxResults())
examples, err = a.db.GetExamples(ctx, req, a.config.RagMaxResults())
if err != nil {
// Fail gracefully; keep going without examples
log.Error(err, "Failed to get examples")
Expand Down
25 changes: 25 additions & 0 deletions app/pkg/docs/blocks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package docs

import (
	"context"
	"fmt"

	"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
)

// CreateQuery creates a query from a GenerateRequest.
// It returns the blocks that should be used to query for similar documents.
//
// The query is built with a simple algorithm:
//  1. Always include at least the currently selected block.
//  2. Keep including preceding blocks as long as they are markup blocks.
func CreateQuery(ctx context.Context, req *v1alpha1.GenerateRequest) ([]*v1alpha1.Block, error) {
	blocks := req.GetDoc().GetBlocks()
	selected := int(req.GetSelectedIndex())

	// Guard against malformed requests: without this check the indexing and
	// slicing below panic when SelectedIndex is negative or past the end of
	// the document (including the empty-document case).
	if selected < 0 || selected >= len(blocks) {
		return nil, fmt.Errorf("SelectedIndex %d is out of range for a document with %d blocks", selected, len(blocks))
	}

	// Walk backwards from the cell before the selected one, stopping at the
	// first non-markup cell or the start of the document.
	startIndex := selected - 1
	for ; startIndex >= 0; startIndex-- {
		if blocks[startIndex].Kind != v1alpha1.BlockKind_MARKUP {
			break
		}
	}

	return blocks[startIndex+1 : selected+1], nil
}
87 changes: 87 additions & 0 deletions app/pkg/docs/blocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package docs

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/jlewi/foyle/app/pkg/testutil"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
)

func Test_CreateQuery(t *testing.T) {
	// Fixture document: markup cells surrounding a single code cell, so the
	// cases below can exercise each stopping condition of CreateQuery.
	doc1 := &v1alpha1.Doc{
		Blocks: []*v1alpha1.Block{
			{Kind: v1alpha1.BlockKind_MARKUP, Contents: "cell 0"},
			{Kind: v1alpha1.BlockKind_MARKUP, Contents: "cell 1"},
			{Kind: v1alpha1.BlockKind_CODE, Contents: "cell 2"},
			{Kind: v1alpha1.BlockKind_MARKUP, Contents: "cell 3"},
			{Kind: v1alpha1.BlockKind_MARKUP, Contents: "cell 4"},
		},
	}

	cases := []struct {
		name     string
		input    *v1alpha1.GenerateRequest
		expected []*v1alpha1.Block
	}{
		{
			// Walking backwards runs off the front of the document.
			name:     "stop-at-start",
			input:    &v1alpha1.GenerateRequest{Doc: doc1, SelectedIndex: 1},
			expected: doc1.Blocks[0:2],
		},
		{
			// The selected cell itself is code; all preceding markup is kept.
			name:     "start-on-codeblock",
			input:    &v1alpha1.GenerateRequest{Doc: doc1, SelectedIndex: 2},
			expected: doc1.Blocks[0:3],
		},
		{
			// Walking backwards halts at the first non-markup cell.
			name:     "stop-on-code",
			input:    &v1alpha1.GenerateRequest{Doc: doc1, SelectedIndex: 4},
			expected: doc1.Blocks[3:5],
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			actual, err := CreateQuery(context.Background(), c.input)
			if err != nil {
				t.Fatalf("CreateQuery failed: %v", err)
			}
			if len(actual) != len(c.expected) {
				t.Errorf("CreateQuery returned %d blocks; want %d", len(actual), len(c.expected))
			}
			if d := cmp.Diff(c.expected, actual, testutil.DocComparer, testutil.BlockComparer); d != "" {
				t.Errorf("CreateQuery returned unexpected blocks:\n%v", d)
			}
		})
	}
}
1 change: 1 addition & 0 deletions app/pkg/eval/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func (e *Evaluator) processExamples(ctx context.Context, examples []*v1alpha1.Ev
log.V(logs.Debug).Info("Skipping example; already processed")
continue
}
log.Info("Processing example")

var processErr error

Expand Down
18 changes: 12 additions & 6 deletions app/pkg/learn/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ import (
"sort"
"sync"

"github.com/jlewi/foyle/app/pkg/docs"

"github.com/jlewi/foyle/app/pkg/llms"

"github.com/jlewi/monogo/files"

"k8s.io/client-go/util/workqueue"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/logs"
"github.com/jlewi/foyle/app/pkg/oai"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand Down Expand Up @@ -71,21 +72,26 @@ func NewInMemoryExampleDB(cfg config.Config, vectorizer llms.Vectorizer) (*InMem
return db, nil
}

func (db *InMemoryExampleDB) GetExamples(ctx context.Context, doc *v1alpha1.Doc, maxResults int) ([]*v1alpha1.Example, error) {
func (db *InMemoryExampleDB) GetExamples(ctx context.Context, req *v1alpha1.GenerateRequest, maxResults int) ([]*v1alpha1.Example, error) {
log := logs.FromContext(ctx)
query := docs.DocToMarkdown(doc)

if len(db.examples) == 0 {
// TODO(jeremy): What should we do in this case?
return nil, errors.New("No examples available")
// Since there are no examples just return an empty list
return []*v1alpha1.Example{}, nil
}

blocks, err := docs.CreateQuery(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "Failed to create query")
}

// Compute the embedding for the query.
qVec, err := db.vectorizer.Embed(ctx, query)
qVecData, err := db.vectorizer.Embed(ctx, blocks)
if err != nil {
return nil, errors.Wrap(err, "Failed to compute embedding for query")
}

qVec := llms.VectorToVecDense(qVecData)
// Acquire a lock on the data so we can safely read it.
db.lock.RLock()
defer db.lock.RUnlock()
Expand Down
6 changes: 5 additions & 1 deletion app/pkg/learn/in_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ func Test_InMemoryDB(t *testing.T) {
},
},
}
examples, err := db.GetExamples(context.Background(), doc, 1)
req := &v1alpha1.GenerateRequest{
Doc: doc,
SelectedIndex: 0,
}
examples, err := db.GetExamples(context.Background(), req, 1)
if err != nil {
t.Fatalf("Error getting examples; %v", err)
}
Expand Down
67 changes: 36 additions & 31 deletions app/pkg/learn/learner.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (
"io"
"strings"
"sync"
"time"

"github.com/jlewi/foyle/app/pkg/docs"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
Expand All @@ -20,7 +21,6 @@ import (
logspb "github.com/jlewi/foyle/protos/go/foyle/logs"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/logs"
"github.com/jlewi/foyle/app/pkg/oai"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand Down Expand Up @@ -59,18 +59,22 @@ type Learner struct {
postFunc PostLearnEvent
eventLoopIsDone sync.WaitGroup
factory *files.Factory
vectorizer *oai.Vectorizer
}

func NewLearner(cfg config.Config, client *openai.Client, blocksDB *dbutil.LockingDB[*logspb.BlockLog]) (*Learner, error) {
if client == nil {
return nil, errors.New("OpenAI client is required")
}

vectorizer := oai.NewVectorizer(client)
return &Learner{
Config: cfg,
client: client,
blocksDB: blocksDB,
queue: workqueue.NewDelayingQueue(),
factory: &files.Factory{},
Config: cfg,
client: client,
blocksDB: blocksDB,
queue: workqueue.NewDelayingQueue(),
factory: &files.Factory{},
vectorizer: vectorizer,
}, nil
}

Expand Down Expand Up @@ -113,10 +117,10 @@ func (l *Learner) eventLoop(ctx context.Context) {
}

if err := l.Reconcile(ctx, exampleId); err != nil {
// N.B. Right now we treat learning errors as permanent and don't retry.
// The most likely source of retryable errors the vectorizer endpoint should already be handled
// by using a retryable HTTP client.
log.Error(err, "Error learning from example", "example", exampleId)
// Requeue the item so we will try again.
// TODO(jeremy): should we use a rate limiting queue so we eventually give up?
l.queue.AddAfter(exampleId, 30*time.Second)
return
}
}()
Expand Down Expand Up @@ -180,17 +184,31 @@ func (l *Learner) Reconcile(ctx context.Context, id string) error {

if len(expectedFiles) == 0 {
cellsProcessed.WithLabelValues("noExampleFiles").Inc()
log.Error(err, "No training files found", "id", b.GetId())
log.Error(err, "No training files found", "blockId", b.GetId())
return errors.Wrapf(err, "No training files found for example %s", b.GetId())
}

// TODO(jeremy): Should we take into account execution status when looking for mistakes?

// Deep copy the original message
newDoc := proto.Clone(b.Doc).(*v1alpha1.Doc)
newBlock := proto.Clone(b.ExecutedBlock).(*v1alpha1.Block)
answer := []*v1alpha1.Block{newBlock}

req := &v1alpha1.GenerateRequest{
Doc: b.Doc,
SelectedIndex: int32(len(b.Doc.Blocks) - 1),
}
queryBlocks, err := docs.CreateQuery(ctx, req)

newDoc := &v1alpha1.Doc{
Blocks: queryBlocks,
}

if err != nil {
log.Error(err, "Failed to create query", "exampleId", b.GetId())
return errors.Wrapf(err, "Failed to create query for example %s", b.GetId())
}

example := &v1alpha1.Example{
Id: b.GetId(),
Query: newDoc,
Expand Down Expand Up @@ -275,30 +293,17 @@ func (l *Learner) computeEmbeddings(ctx context.Context, example *v1alpha1.Examp
return nil
}

query := docs.DocToMarkdown(example.Query)
qVec, err := l.vectorizer.Embed(ctx, example.Query.GetBlocks())

request := openai.EmbeddingRequestStrings{
Input: []string{query},
Model: openai.SmallEmbedding3,
User: "",
EncodingFormat: "float",
}
resp, err := l.client.CreateEmbeddings(ctx, request)
if err != nil {
log.Error(err, "Failed to create embeddings", "id", example.Id, "query", query)
return errors.Wrapf(err, "Failed to create embeddings")
}

if len(resp.Data) != 1 {
log.Error(err, "Expected exactly 1 embedding", "id", example.Id, "query", query, "got", len(resp.Data))
return errors.Errorf("Expected exactly 1 embedding but got %d", len(resp.Data))
return err
}

if len(resp.Data[0].Embedding) != oai.SmallEmbeddingsDims {
log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", query, "got", len(resp.Data[0].Embedding), "want", oai.SmallEmbeddingsDims)
return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(resp.Data[0].Embedding), oai.SmallEmbeddingsDims)
if len(qVec) != oai.SmallEmbeddingsDims {
log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", example.Query, "got", len(qVec), "want", oai.SmallEmbeddingsDims)
return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(qVec), oai.SmallEmbeddingsDims)
}

example.Embedding = resp.Data[0].Embedding
example.Embedding = qVec
return nil
}
2 changes: 1 addition & 1 deletion app/pkg/learn/learner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func Test_Learner(t *testing.T) {

blocksDB, err := pebble.Open(cfg.GetBlocksDBDir(), &pebble.Options{})
if err != nil {
t.Fatalf("could not open blocks database %s", cfg.GetBlocksDBDir())
t.Fatalf("could not open blocks database %+v", cfg.GetBlocksDBDir())
}
defer blocksDB.Close()

Expand Down
1 change: 1 addition & 0 deletions app/pkg/llms/query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package llms
18 changes: 16 additions & 2 deletions app/pkg/llms/vectorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@ package llms
import (
"context"

"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"

"gonum.org/v1/gonum/mat"
)

type Vector []float32

// Vectorizer computes embedding representations of text.
type Vectorizer interface {
// Embed computes the embedding of the text
Embed(ctx context.Context, text string) (*mat.VecDense, error)
// Embed computes the embedding of the blocks
Embed(ctx context.Context, blocks []*v1alpha1.Block) (Vector, error)
// Length returns the length of the embeddings
Length() int
}

// VectorToVecDense converts a Vector to a *mat.VecDense.
func VectorToVecDense(v Vector) *mat.VecDense {
	// mat.VecDense stores float64 values, so each float32 component must be
	// widened before it can be handed to the dense vector.
	data := make([]float64, len(v))
	for i, val := range v {
		data[i] = float64(val)
	}
	return mat.NewVecDense(len(v), data)
}
Loading

0 comments on commit b4874f5

Please sign in to comment.