Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix failures to compute embeddings too much context & Shorten RAG Results #279

Merged
merged 5 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion app/Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
build-dir:
mkdir -p .build

GIT_SHA := $(shell git rev-parse HEAD)
GIT_SHA_SHORT := $(shell git rev-parse --short HEAD)
DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ")
VERSION := $(shell git describe --tags)-$(GIT_SHA_SHORT)
LDFLAGS := -s -w \
-X 'github.com/jlewi/foyle/app/cmd.date=$(DATE)' \
-X 'github.com/jlewi/foyle/app/cmd.version=$(subst v,,$(VERSION))' \
-X 'github.com/jlewi/foyle/app/cmd.commit=$(GIT_SHA)'

build: build-dir
CGO_ENABLED=0 go build -o .build/foyle github.com/jlewi/foyle/app
CGO_ENABLED=0 go build -o .build/foyle -ldflags="$(LDFLAGS)" github.com/jlewi/foyle/app

build-wasm:
GOARCH=wasm GOOS=js go build -o web/app.wasm ./pwa
Expand Down
2 changes: 1 addition & 1 deletion app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (a *Agent) Generate(ctx context.Context, req *v1alpha1.GenerateRequest) (*v
var examples []*v1alpha1.Example
if a.config.UseRAG() {
var err error
examples, err = a.db.GetExamples(ctx, req.Doc, a.config.RagMaxResults())
examples, err = a.db.GetExamples(ctx, req, a.config.RagMaxResults())
if err != nil {
// Fail gracefully; keep going without examples
log.Error(err, "Failed to get examples")
Expand Down
25 changes: 25 additions & 0 deletions app/pkg/docs/blocks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package docs

import (
"context"

"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
)

// CreateQuery creates a query from a GenerateRequest
// It returns the blocks that should be used to query for similar documents
func CreateQuery(ctx context.Context, req *v1alpha1.GenerateRequest) ([]*v1alpha1.Block, error) {
// Use a simple algorithm.
// 1. Always select at least the current block
// 2. Select additional blocks if they are markup blocks.
startIndex := req.GetSelectedIndex() - 1

for ; startIndex >= 0; startIndex-- {
if req.GetDoc().GetBlocks()[startIndex].Kind != v1alpha1.BlockKind_MARKUP {
break
}
}

blocks := req.GetDoc().GetBlocks()[startIndex+1 : req.GetSelectedIndex()+1]
return blocks, nil
}
87 changes: 87 additions & 0 deletions app/pkg/docs/blocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package docs

import (
"context"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/jlewi/foyle/app/pkg/testutil"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
)

func Test_CreateQuery(t *testing.T) {
doc1 := &v1alpha1.Doc{
Blocks: []*v1alpha1.Block{
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 0",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 1",
},
{
Kind: v1alpha1.BlockKind_CODE,
Contents: "cell 2",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 3",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 4",
},
},
}

type testCase struct {
name string
input *v1alpha1.GenerateRequest
expected []*v1alpha1.Block
}

cases := []testCase{
{
name: "stop-at-start",
input: &v1alpha1.GenerateRequest{
Doc: doc1,
SelectedIndex: 1,
},
expected: doc1.Blocks[0:2],
},
{
name: "start-on-codeblock",
input: &v1alpha1.GenerateRequest{
Doc: doc1,
SelectedIndex: 2,
},
expected: doc1.Blocks[0:3],
},
{
name: "stop-on-code",
input: &v1alpha1.GenerateRequest{
Doc: doc1,
SelectedIndex: 4,
},
expected: doc1.Blocks[3:5],
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
blocks, err := CreateQuery(context.Background(), tc.input)
if err != nil {
t.Fatalf("CreateQuery failed: %v", err)
}
if len(blocks) != len(tc.expected) {
t.Errorf("CreateQuery returned %d blocks; want %d", len(blocks), len(tc.expected))
}

if d := cmp.Diff(tc.expected, blocks, testutil.DocComparer, testutil.BlockComparer); d != "" {
t.Errorf("CreateQuery returned unexpected blocks:\n%v", d)
}

})
}
}
1 change: 1 addition & 0 deletions app/pkg/eval/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func (e *Evaluator) processExamples(ctx context.Context, examples []*v1alpha1.Ev
log.V(logs.Debug).Info("Skipping example; already processed")
continue
}
log.Info("Processing example")

var processErr error

Expand Down
18 changes: 12 additions & 6 deletions app/pkg/learn/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ import (
"sort"
"sync"

"github.com/jlewi/foyle/app/pkg/docs"

"github.com/jlewi/foyle/app/pkg/llms"

"github.com/jlewi/monogo/files"

"k8s.io/client-go/util/workqueue"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/logs"
"github.com/jlewi/foyle/app/pkg/oai"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand Down Expand Up @@ -71,21 +72,26 @@ func NewInMemoryExampleDB(cfg config.Config, vectorizer llms.Vectorizer) (*InMem
return db, nil
}

func (db *InMemoryExampleDB) GetExamples(ctx context.Context, doc *v1alpha1.Doc, maxResults int) ([]*v1alpha1.Example, error) {
func (db *InMemoryExampleDB) GetExamples(ctx context.Context, req *v1alpha1.GenerateRequest, maxResults int) ([]*v1alpha1.Example, error) {
log := logs.FromContext(ctx)
query := docs.DocToMarkdown(doc)

if len(db.examples) == 0 {
// TODO(jeremy): What should we do in this case?
return nil, errors.New("No examples available")
// Since there are no examples just return an empty list
return []*v1alpha1.Example{}, nil
}

blocks, err := docs.CreateQuery(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "Failed to create query")
}

// Compute the embedding for the query.
qVec, err := db.vectorizer.Embed(ctx, query)
qVecData, err := db.vectorizer.Embed(ctx, blocks)
if err != nil {
return nil, errors.Wrap(err, "Failed to compute embedding for query")
}

qVec := llms.VectorToVecDense(qVecData)
// Acquire a lock on the data so we can safely read it.
db.lock.RLock()
defer db.lock.RUnlock()
Expand Down
6 changes: 5 additions & 1 deletion app/pkg/learn/in_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ func Test_InMemoryDB(t *testing.T) {
},
},
}
examples, err := db.GetExamples(context.Background(), doc, 1)
req := &v1alpha1.GenerateRequest{
Doc: doc,
SelectedIndex: 0,
}
examples, err := db.GetExamples(context.Background(), req, 1)
if err != nil {
t.Fatalf("Error getting examples; %v", err)
}
Expand Down
67 changes: 36 additions & 31 deletions app/pkg/learn/learner.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ import (
"io"
"strings"
"sync"
"time"

"github.com/jlewi/foyle/app/pkg/docs"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
Expand All @@ -20,7 +21,6 @@ import (
logspb "github.com/jlewi/foyle/protos/go/foyle/logs"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/logs"
"github.com/jlewi/foyle/app/pkg/oai"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand Down Expand Up @@ -59,18 +59,22 @@ type Learner struct {
postFunc PostLearnEvent
eventLoopIsDone sync.WaitGroup
factory *files.Factory
vectorizer *oai.Vectorizer
}

func NewLearner(cfg config.Config, client *openai.Client, blocksDB *dbutil.LockingDB[*logspb.BlockLog]) (*Learner, error) {
if client == nil {
return nil, errors.New("OpenAI client is required")
}

vectorizer := oai.NewVectorizer(client)
return &Learner{
Config: cfg,
client: client,
blocksDB: blocksDB,
queue: workqueue.NewDelayingQueue(),
factory: &files.Factory{},
Config: cfg,
client: client,
blocksDB: blocksDB,
queue: workqueue.NewDelayingQueue(),
factory: &files.Factory{},
vectorizer: vectorizer,
}, nil
}

Expand Down Expand Up @@ -113,10 +117,10 @@ func (l *Learner) eventLoop(ctx context.Context) {
}

if err := l.Reconcile(ctx, exampleId); err != nil {
// N.B. Right now we treat learning errors as permanent and don't retry.
// The most likely source of retryable errors the vectorizer endpoint should already be handled
// by using a retryable HTTP client.
log.Error(err, "Error learning from example", "example", exampleId)
// Requeue the item so we will try again.
// TODO(jeremy): should we use a rate limiting queue so we eventually give up?
l.queue.AddAfter(exampleId, 30*time.Second)
return
}
}()
Expand Down Expand Up @@ -180,17 +184,31 @@ func (l *Learner) Reconcile(ctx context.Context, id string) error {

if len(expectedFiles) == 0 {
cellsProcessed.WithLabelValues("noExampleFiles").Inc()
log.Error(err, "No training files found", "id", b.GetId())
log.Error(err, "No training files found", "blockId", b.GetId())
return errors.Wrapf(err, "No training files found for example %s", b.GetId())
}

// TODO(jeremy): Should we take into account execution status when looking for mistakes?

// Deep copy the original message
newDoc := proto.Clone(b.Doc).(*v1alpha1.Doc)
newBlock := proto.Clone(b.ExecutedBlock).(*v1alpha1.Block)
answer := []*v1alpha1.Block{newBlock}

req := &v1alpha1.GenerateRequest{
Doc: b.Doc,
SelectedIndex: int32(len(b.Doc.Blocks) - 1),
}
queryBlocks, err := docs.CreateQuery(ctx, req)

newDoc := &v1alpha1.Doc{
Blocks: queryBlocks,
}

if err != nil {
log.Error(err, "Failed to create query", "exampleId", b.GetId())
return errors.Wrapf(err, "Failed to create query for example %s", b.GetId())
}

example := &v1alpha1.Example{
Id: b.GetId(),
Query: newDoc,
Expand Down Expand Up @@ -275,30 +293,17 @@ func (l *Learner) computeEmbeddings(ctx context.Context, example *v1alpha1.Examp
return nil
}

query := docs.DocToMarkdown(example.Query)
qVec, err := l.vectorizer.Embed(ctx, example.Query.GetBlocks())

request := openai.EmbeddingRequestStrings{
Input: []string{query},
Model: openai.SmallEmbedding3,
User: "",
EncodingFormat: "float",
}
resp, err := l.client.CreateEmbeddings(ctx, request)
if err != nil {
log.Error(err, "Failed to create embeddings", "id", example.Id, "query", query)
return errors.Wrapf(err, "Failed to create embeddings")
}

if len(resp.Data) != 1 {
log.Error(err, "Expected exactly 1 embedding", "id", example.Id, "query", query, "got", len(resp.Data))
return errors.Errorf("Expected exactly 1 embedding but got %d", len(resp.Data))
return err
}

if len(resp.Data[0].Embedding) != oai.SmallEmbeddingsDims {
log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", query, "got", len(resp.Data[0].Embedding), "want", oai.SmallEmbeddingsDims)
return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(resp.Data[0].Embedding), oai.SmallEmbeddingsDims)
if len(qVec) != oai.SmallEmbeddingsDims {
log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", example.Query, "got", len(qVec), "want", oai.SmallEmbeddingsDims)
return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(qVec), oai.SmallEmbeddingsDims)
}

example.Embedding = resp.Data[0].Embedding
example.Embedding = qVec
return nil
}
2 changes: 1 addition & 1 deletion app/pkg/learn/learner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func Test_Learner(t *testing.T) {

blocksDB, err := pebble.Open(cfg.GetBlocksDBDir(), &pebble.Options{})
if err != nil {
t.Fatalf("could not open blocks database %s", cfg.GetBlocksDBDir())
t.Fatalf("could not open blocks database %+v", cfg.GetBlocksDBDir())
}
defer blocksDB.Close()

Expand Down
1 change: 1 addition & 0 deletions app/pkg/llms/query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package llms
18 changes: 16 additions & 2 deletions app/pkg/llms/vectorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@ package llms
import (
"context"

"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"

"gonum.org/v1/gonum/mat"
)

type Vector []float32

// Vectorizer computes embedding representations of text.
type Vectorizer interface {
// Embed computes the embedding of the text
Embed(ctx context.Context, text string) (*mat.VecDense, error)
// Embed computes the embedding of the blocks
Embed(ctx context.Context, blocks []*v1alpha1.Block) (Vector, error)
// Length returns the length of the embeddings
Length() int
}

// VectorToVecDense converts a Vector to a *mat.VecDense
func VectorToVecDense(v Vector) *mat.VecDense {
// We need to cast from float32 to float64
qVec := mat.NewVecDense(len(v), nil)
for i := 0; i < len(v); i++ {
qVec.SetVec(i, float64(v[i]))
}
return qVec
}
Loading
Loading