Skip to content

Commit

Permalink
Ensure embeddings matrix is never initialized with zero dimensions
Browse files Browse the repository at this point in the history
* Because we divide by 1.5; in the edge case were we have 1 example we'd
  end up rounding down to 0 which will cause problems when we try to initialize a matrix to that size.
Fix #248
  • Loading branch information
jlewi committed Sep 20, 2024
1 parent e14d50a commit b1f1886
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
19 changes: 14 additions & 5 deletions app/pkg/learn/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,8 @@ func (db *InMemoryExampleDB) loadExamples(ctx context.Context) error {
db.examples = make([]*v1alpha1.Example, 0, len(matches))
}

// We intentionally initialize an initial matrix which is too small so that during the initial load
// grow will be triggered. Since we grow by a factor of two we should end up with an overallocated matrix
// This means that by default the matrix should contain extra rows that haven't been populated with examples
// yet. This way we can verify that doesn't trip up rag
if db.embeddings == nil {
db.embeddings = mat.NewDense(int(float32(len(matches))/1.5), oai.SmallEmbeddingsDims, nil)
db.embeddings = mat.NewDense(initialNumberOfRows(len(matches)), oai.SmallEmbeddingsDims, nil)
}

// Load the examples.
Expand All @@ -229,6 +225,19 @@ func (db *InMemoryExampleDB) loadExamples(ctx context.Context) error {
return nil
}

func initialNumberOfRows(numExamples int) int {
// We intentionally initialize an initial matrix which is too small so that during the initial load
// grow will be triggered. Since we grow by a factor of two we should end up with an overallocated matrix
// This means that by default the matrix should contain extra rows that haven't been populated with examples
// yet. This way we can verify that doesn't trip up rag
size := int(float32(numExamples) / 1.5)
// If size is < 1 we end up initializing an empty matrix which will cause a panic
if size < 1 {
size = 1
}
return size
}

func (db *InMemoryExampleDB) Shutdown(ctx context.Context) error {
log := logs.FromContext(ctx)

Expand Down
35 changes: 35 additions & 0 deletions app/pkg/learn/in_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,38 @@ func Test_UpdateExample(t *testing.T) {
})
}
}

func Test_initialNumberOfRows(t *testing.T) {
type testCase struct {
name string
numExamples int
expected int
}

cases := []testCase{
{
name: "edge-case-1",
numExamples: 1,
expected: 1,
},
{
name: "small",
numExamples: 10,
expected: 6,
},
{
name: "large",
numExamples: 100,
expected: 66,
},
}

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
actual := initialNumberOfRows(c.numExamples)
if actual != c.expected {
t.Errorf("Expected %v but got %v", c.expected, actual)
}
})
}
}

0 comments on commit b1f1886

Please sign in to comment.