diff --git a/app/Makefile b/app/Makefile index 36c9e3fc..44c0673a 100644 --- a/app/Makefile +++ b/app/Makefile @@ -1,8 +1,17 @@ build-dir: mkdir -p .build +GIT_SHA := $(shell git rev-parse HEAD) +GIT_SHA_SHORT := $(shell git rev-parse --short HEAD) +DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") +VERSION := $(shell git describe --tags)-$(GIT_SHA_SHORT) +LDFLAGS := -s -w \ + -X 'github.com/jlewi/foyle/app/cmd.date=$(DATE)' \ + -X 'github.com/jlewi/foyle/app/cmd.version=$(subst v,,$(VERSION))' \ + -X 'github.com/jlewi/foyle/app/cmd.commit=$(GIT_SHA)' + build: build-dir - CGO_ENABLED=0 go build -o .build/foyle github.com/jlewi/foyle/app + CGO_ENABLED=0 go build -o .build/foyle -ldflags="$(LDFLAGS)" github.com/jlewi/foyle/app build-wasm: GOARCH=wasm GOOS=js go build -o web/app.wasm ./pwa diff --git a/app/pkg/agent/agent.go b/app/pkg/agent/agent.go index 8835d24d..44354cef 100644 --- a/app/pkg/agent/agent.go +++ b/app/pkg/agent/agent.go @@ -92,7 +92,7 @@ func (a *Agent) Generate(ctx context.Context, req *v1alpha1.GenerateRequest) (*v var examples []*v1alpha1.Example if a.config.UseRAG() { var err error - examples, err = a.db.GetExamples(ctx, req.Doc, a.config.RagMaxResults()) + examples, err = a.db.GetExamples(ctx, req, a.config.RagMaxResults()) if err != nil { // Fail gracefully; keep going without examples log.Error(err, "Failed to get examples") diff --git a/app/pkg/docs/blocks.go b/app/pkg/docs/blocks.go new file mode 100644 index 00000000..09ad9b8c --- /dev/null +++ b/app/pkg/docs/blocks.go @@ -0,0 +1,25 @@ +package docs + +import ( + "context" + + "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" +) + +// CreateQuery creates a query from a GenerateRequest +// It returns the blocks that should be used to query for similar documents +func CreateQuery(ctx context.Context, req *v1alpha1.GenerateRequest) ([]*v1alpha1.Block, error) { + // Use a simple algorithm. + // 1. Always select at least the current block + // 2. Select additional blocks if they are markup blocks. + startIndex := req.GetSelectedIndex() - 1 + + for ; startIndex >= 0; startIndex-- { + if req.GetDoc().GetBlocks()[startIndex].Kind != v1alpha1.BlockKind_MARKUP { + break + } + } + + blocks := req.GetDoc().GetBlocks()[startIndex+1 : req.GetSelectedIndex()+1] + return blocks, nil +} diff --git a/app/pkg/docs/blocks_test.go b/app/pkg/docs/blocks_test.go new file mode 100644 index 00000000..e9d844d0 --- /dev/null +++ b/app/pkg/docs/blocks_test.go @@ -0,0 +1,87 @@ +package docs + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/jlewi/foyle/app/pkg/testutil" + "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" +) + +func Test_CreateQuery(t *testing.T) { + doc1 := &v1alpha1.Doc{ + Blocks: []*v1alpha1.Block{ + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 0", + }, + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 1", + }, + { + Kind: v1alpha1.BlockKind_CODE, + Contents: "cell 2", + }, + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 3", + }, + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 4", + }, + }, + } + + type testCase struct { + name string + input *v1alpha1.GenerateRequest + expected []*v1alpha1.Block + } + + cases := []testCase{ + { + name: "stop-at-start", + input: &v1alpha1.GenerateRequest{ + Doc: doc1, + SelectedIndex: 1, + }, + expected: doc1.Blocks[0:2], + }, + { + name: "start-on-codeblock", + input: &v1alpha1.GenerateRequest{ + Doc: doc1, + SelectedIndex: 2, + }, + expected: doc1.Blocks[0:3], + }, + { + name: "stop-on-code", + input: &v1alpha1.GenerateRequest{ + Doc: doc1, + SelectedIndex: 4, + }, + expected: doc1.Blocks[3:5], + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + blocks, err := CreateQuery(context.Background(), tc.input) + if err != nil { + t.Fatalf("CreateQuery failed: %v", err) + } + if len(blocks) != len(tc.expected) { + t.Errorf("CreateQuery returned %d blocks; want %d", len(blocks), len(tc.expected)) + } + + if d := cmp.Diff(tc.expected, blocks, testutil.DocComparer, testutil.BlockComparer); d != "" { + t.Errorf("CreateQuery returned unexpected blocks:\n%v", d) + } + + }) + } +} diff --git a/app/pkg/eval/evaluator.go b/app/pkg/eval/evaluator.go index 0dd20e1d..7b134cd5 100644 --- a/app/pkg/eval/evaluator.go +++ b/app/pkg/eval/evaluator.go @@ -171,6 +171,7 @@ func (e *Evaluator) processExamples(ctx context.Context, examples []*v1alpha1.Ev log.V(logs.Debug).Info("Skipping example; already processed") continue } + log.Info("Processing example") var processErr error diff --git a/app/pkg/learn/in_memory.go b/app/pkg/learn/in_memory.go index d2dda97f..3b60ce32 100644 --- a/app/pkg/learn/in_memory.go +++ b/app/pkg/learn/in_memory.go @@ -6,6 +6,8 @@ import ( "sort" "sync" + "github.com/jlewi/foyle/app/pkg/docs" + "github.com/jlewi/foyle/app/pkg/llms" "github.com/jlewi/monogo/files" @@ -13,7 +15,6 @@ import ( "k8s.io/client-go/util/workqueue" "github.com/jlewi/foyle/app/pkg/config" - "github.com/jlewi/foyle/app/pkg/docs" "github.com/jlewi/foyle/app/pkg/logs" "github.com/jlewi/foyle/app/pkg/oai" "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" @@ -71,21 +72,26 @@ func NewInMemoryExampleDB(cfg config.Config, vectorizer llms.Vectorizer) (*InMem return db, nil } -func (db *InMemoryExampleDB) GetExamples(ctx context.Context, doc *v1alpha1.Doc, maxResults int) ([]*v1alpha1.Example, error) { +func (db *InMemoryExampleDB) GetExamples(ctx context.Context, req *v1alpha1.GenerateRequest, maxResults int) ([]*v1alpha1.Example, error) { log := logs.FromContext(ctx) - query := docs.DocToMarkdown(doc) if len(db.examples) == 0 { - // TODO(jeremy): What should we do in this case? - return nil, errors.New("No examples available") + // Since there are no examples just return an empty list + return []*v1alpha1.Example{}, nil + } + + blocks, err := docs.CreateQuery(ctx, req) + if err != nil { + return nil, errors.Wrap(err, "Failed to create query") } // Compute the embedding for the query. - qVec, err := db.vectorizer.Embed(ctx, query) + qVecData, err := db.vectorizer.Embed(ctx, blocks) if err != nil { return nil, errors.Wrap(err, "Failed to compute embedding for query") } + qVec := llms.VectorToVecDense(qVecData) // Acquire a lock on the data so we can safely read it. db.lock.RLock() defer db.lock.RUnlock() diff --git a/app/pkg/learn/in_memory_test.go b/app/pkg/learn/in_memory_test.go index 1031f3f5..1c7d3563 100644 --- a/app/pkg/learn/in_memory_test.go +++ b/app/pkg/learn/in_memory_test.go @@ -114,7 +114,11 @@ func Test_InMemoryDB(t *testing.T) { }, }, } - examples, err := db.GetExamples(context.Background(), doc, 1) + req := &v1alpha1.GenerateRequest{ + Doc: doc, + SelectedIndex: 0, + } + examples, err := db.GetExamples(context.Background(), req, 1) if err != nil { t.Fatalf("Error getting examples; %v", err) } diff --git a/app/pkg/learn/learner.go b/app/pkg/learn/learner.go index ad7b8b0a..0a2a254d 100644 --- a/app/pkg/learn/learner.go +++ b/app/pkg/learn/learner.go @@ -6,7 +6,8 @@ import ( "io" "strings" "sync" - "time" + + "github.com/jlewi/foyle/app/pkg/docs" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -20,7 +21,6 @@ import ( logspb "github.com/jlewi/foyle/protos/go/foyle/logs" "github.com/jlewi/foyle/app/pkg/config" - "github.com/jlewi/foyle/app/pkg/docs" "github.com/jlewi/foyle/app/pkg/logs" "github.com/jlewi/foyle/app/pkg/oai" "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" @@ -59,18 +59,22 @@ type Learner struct { postFunc PostLearnEvent eventLoopIsDone sync.WaitGroup factory *files.Factory + vectorizer *oai.Vectorizer } func NewLearner(cfg config.Config, client *openai.Client, blocksDB *dbutil.LockingDB[*logspb.BlockLog]) (*Learner, error) { if client == nil { return nil, errors.New("OpenAI client is required") } + + vectorizer := oai.NewVectorizer(client) return &Learner{ - Config: cfg, - client: client, - blocksDB: blocksDB, - queue: workqueue.NewDelayingQueue(), - factory: &files.Factory{}, + Config: cfg, + client: client, + blocksDB: blocksDB, + queue: workqueue.NewDelayingQueue(), + factory: &files.Factory{}, + vectorizer: vectorizer, }, nil } @@ -113,10 +117,10 @@ func (l *Learner) eventLoop(ctx context.Context) { } if err := l.Reconcile(ctx, exampleId); err != nil { + // N.B. Right now we treat learning errors as permanent and don't retry. + // The most likely source of retryable errors the vectorizer endpoint should already be handled + // by using a retryable HTTP client. log.Error(err, "Error learning from example", "example", exampleId) - // Requeue the item so we will try again. - // TODO(jeremy): should we use a rate limiting queue so we eventually give up? - l.queue.AddAfter(exampleId, 30*time.Second) return } }() @@ -180,17 +184,31 @@ func (l *Learner) Reconcile(ctx context.Context, id string) error { if len(expectedFiles) == 0 { cellsProcessed.WithLabelValues("noExampleFiles").Inc() - log.Error(err, "No training files found", "id", b.GetId()) + log.Error(err, "No training files found", "blockId", b.GetId()) return errors.Wrapf(err, "No training files found for example %s", b.GetId()) } // TODO(jeremy): Should we take into account execution status when looking for mistakes? // Deep copy the original message - newDoc := proto.Clone(b.Doc).(*v1alpha1.Doc) newBlock := proto.Clone(b.ExecutedBlock).(*v1alpha1.Block) answer := []*v1alpha1.Block{newBlock} + req := &v1alpha1.GenerateRequest{ + Doc: b.Doc, + SelectedIndex: int32(len(b.Doc.Blocks) - 1), + } + queryBlocks, err := docs.CreateQuery(ctx, req) + + newDoc := &v1alpha1.Doc{ + Blocks: queryBlocks, + } + + if err != nil { + log.Error(err, "Failed to create query", "exampleId", b.GetId()) + return errors.Wrapf(err, "Failed to create query for example %s", b.GetId()) + } + example := &v1alpha1.Example{ Id: b.GetId(), Query: newDoc, @@ -275,30 +293,17 @@ func (l *Learner) computeEmbeddings(ctx context.Context, example *v1alpha1.Examp return nil } - query := docs.DocToMarkdown(example.Query) + qVec, err := l.vectorizer.Embed(ctx, example.Query.GetBlocks()) - request := openai.EmbeddingRequestStrings{ - Input: []string{query}, - Model: openai.SmallEmbedding3, - User: "", - EncodingFormat: "float", - } - resp, err := l.client.CreateEmbeddings(ctx, request) if err != nil { - log.Error(err, "Failed to create embeddings", "id", example.Id, "query", query) - return errors.Wrapf(err, "Failed to create embeddings") - } - - if len(resp.Data) != 1 { - log.Error(err, "Expected exactly 1 embedding", "id", example.Id, "query", query, "got", len(resp.Data)) - return errors.Errorf("Expected exactly 1 embedding but got %d", len(resp.Data)) + return err } - if len(resp.Data[0].Embedding) != oai.SmallEmbeddingsDims { - log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", query, "got", len(resp.Data[0].Embedding), "want", oai.SmallEmbeddingsDims) - return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(resp.Data[0].Embedding), oai.SmallEmbeddingsDims) + if len(qVec) != oai.SmallEmbeddingsDims { + log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", example.Query, "got", len(qVec), "want", oai.SmallEmbeddingsDims) + return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(qVec), oai.SmallEmbeddingsDims) } - example.Embedding = resp.Data[0].Embedding + example.Embedding = qVec return nil } diff --git a/app/pkg/learn/learner_test.go b/app/pkg/learn/learner_test.go index 5646c608..6272e922 100644 --- a/app/pkg/learn/learner_test.go +++ b/app/pkg/learn/learner_test.go @@ -32,7 +32,7 @@ func Test_Learner(t *testing.T) { blocksDB, err := pebble.Open(cfg.GetBlocksDBDir(), &pebble.Options{}) if err != nil { - t.Fatalf("could not open blocks database %s", cfg.GetBlocksDBDir()) + t.Fatalf("could not open blocks database %+v", cfg.GetBlocksDBDir()) } defer blocksDB.Close() diff --git a/app/pkg/llms/query.go b/app/pkg/llms/query.go new file mode 100644 index 00000000..d03cb402 --- /dev/null +++ b/app/pkg/llms/query.go @@ -0,0 +1 @@ +package llms diff --git a/app/pkg/llms/vectorizer.go b/app/pkg/llms/vectorizer.go index c508a4d5..169505d5 100644 --- a/app/pkg/llms/vectorizer.go +++ b/app/pkg/llms/vectorizer.go @@ -3,13 +3,27 @@ package llms import ( "context" + "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" + "gonum.org/v1/gonum/mat" ) +type Vector []float32 + // Vectorizer computes embedding representations of text. type Vectorizer interface { - // Embed computes the embedding of the text - Embed(ctx context.Context, text string) (*mat.VecDense, error) + // Embed computes the embedding of the blocks + Embed(ctx context.Context, blocks []*v1alpha1.Block) (Vector, error) // Length returns the length of the embeddings Length() int } + +// VectorToVecDense converts a Vector to a *mat.VecDense +func VectorToVecDense(v Vector) *mat.VecDense { + // We need to cast from float32 to float64 + qVec := mat.NewVecDense(len(v), nil) + for i := 0; i < len(v); i++ { + qVec.SetVec(i, float64(v[i])) + } + return qVec +} diff --git a/app/pkg/oai/embeddings.go b/app/pkg/oai/embeddings.go index c391c5bb..37c7d0d7 100644 --- a/app/pkg/oai/embeddings.go +++ b/app/pkg/oai/embeddings.go @@ -3,10 +3,13 @@ package oai import ( "context" + "github.com/jlewi/foyle/app/pkg/docs" + "github.com/jlewi/foyle/app/pkg/llms" + "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" + "github.com/jlewi/foyle/app/pkg/logs" "github.com/pkg/errors" "github.com/sashabaranov/go-openai" - "gonum.org/v1/gonum/mat" ) func NewVectorizer(client *openai.Client) *Vectorizer { @@ -19,7 +22,10 @@ type Vectorizer struct { client *openai.Client } -func (v *Vectorizer) Embed(ctx context.Context, text string) (*mat.VecDense, error) { +func (v *Vectorizer) Embed(ctx context.Context, blocks []*v1alpha1.Block) (llms.Vector, error) { + text := docs.BlocksToMarkdown(blocks) + + // Compute the embedding for the query. log := logs.FromContext(ctx) log.Info("RAG Query", "query", text) request := openai.EmbeddingRequestStrings{ @@ -29,6 +35,7 @@ func (v *Vectorizer) Embed(ctx context.Context, text string) (*mat.VecDense, err EncodingFormat: "float", } + // N.B. regarding retries. We should already be doing retries in the HTTP client. resp, err := v.client.CreateEmbeddings(ctx, request) if err != nil { return nil, errors.Errorf("Failed to create embeddings") @@ -42,12 +49,7 @@ func (v *Vectorizer) Embed(ctx context.Context, text string) (*mat.VecDense, err return nil, errors.Errorf("Embeddings have wrong dimension; got %v, want %v", len(resp.Data[0].Embedding), SmallEmbeddingsDims) } - // Compute the cosine similarity between the query and each example. - qVec := mat.NewVecDense(SmallEmbeddingsDims, nil) - for i := 0; i < SmallEmbeddingsDims; i++ { - qVec.SetVec(i, float64(resp.Data[0].Embedding[i])) - } - return qVec, nil + return resp.Data[0].Embedding, nil } func (v *Vectorizer) Length() int { diff --git a/app/pkg/oai/errors.go b/app/pkg/oai/errors.go index cd5c3536..1de26dff 100644 --- a/app/pkg/oai/errors.go +++ b/app/pkg/oai/errors.go @@ -1,6 +1,9 @@ package oai -import "github.com/sashabaranov/go-openai" +import ( + "github.com/pkg/errors" + "github.com/sashabaranov/go-openai" +) const ( // ContextLengthExceededCode the error code returned by OpenAI to indicate the context length was exceeded @@ -21,3 +24,13 @@ func ErrorIs(err error, oaiCode string) bool { return val == oaiCode } + +// HTTPStatusCode returns the HTTP status code from the error if it is an OpenAI error. +// Returns -1 if its not of type APIError. +func HTTPStatusCode(err error) int { + target := &openai.APIError{} + if !errors.As(err, &target) { + return -1 + } + return target.HTTPStatusCode +} diff --git a/app/pkg/oai/errors_test.go b/app/pkg/oai/errors_test.go new file mode 100644 index 00000000..e5e32299 --- /dev/null +++ b/app/pkg/oai/errors_test.go @@ -0,0 +1,47 @@ +package oai + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/sashabaranov/go-openai" +) + +func Test_HTTPStatusCode(t *testing.T) { + type testCase struct { + name string + err error + expected int + } + + cases := []testCase{ + { + name: "basic", + err: &openai.APIError{ + HTTPStatusCode: 404, + }, + expected: 404, + }, + { + name: "wrapped", + err: errors.Wrapf(&openai.APIError{ + HTTPStatusCode: 509, + }, "wrapped"), + expected: 509, + }, + { + name: "not api error", + err: errors.New("not an api error"), + expected: -1, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + actual := HTTPStatusCode(tc.err) + if actual != tc.expected { + t.Errorf("expected %v, got %v", tc.expected, actual) + } + }) + } +}