Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: WhereDocument filter with $and, $or, $contains and $not_contains filters #96

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"path/filepath"
"slices"
"sync"
)

Expand Down Expand Up @@ -64,7 +63,7 @@ type QueryOptions struct {
Where map[string]string

// Conditional filtering on documents.
WhereDocument map[string]string
WhereDocuments []WhereDocument

// Negative is the negative query options.
// They can be used to exclude certain results from the query.
Expand Down Expand Up @@ -296,19 +295,19 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
// - ids: The ids of the documents to delete. If empty, all documents are deleted.
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {
func (c *Collection) Delete(_ context.Context, where map[string]string, whereDocuments []WhereDocument, ids ...string) error {
// must have at least one of where, whereDocument or ids
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
if len(where) == 0 && len(whereDocuments) == 0 && len(ids) == 0 {
return fmt.Errorf("must have at least one of where, whereDocument or ids")
}

if len(c.documents) == 0 {
return nil
}

for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return errors.New("unsupported whereDocument operator")
for _, whereDocument := range whereDocuments {
if err := whereDocument.Validate(); err != nil {
return fmt.Errorf("invalid whereDocument %#v: %w", whereDocument, err)
}
}

Expand All @@ -317,9 +316,9 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s
c.documentsLock.Lock()
defer c.documentsLock.Unlock()

if where != nil || whereDocument != nil {
if where != nil || len(whereDocuments) > 0 {
// metadata + content filters
filteredDocs := filterDocs(c.documents, where, whereDocument)
filteredDocs := filterDocs(c.documents, where, whereDocuments)
for _, doc := range filteredDocs {
docIDs = append(docIDs, doc.ID)
}
Expand Down Expand Up @@ -376,7 +375,7 @@ type Result struct {
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where map[string]string, whereDocument []WhereDocument) ([]Result, error) {
if queryText == "" {
return nil, errors.New("queryText is empty")
}
Expand Down Expand Up @@ -432,7 +431,7 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
}
}

result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocument)
result, err := c.queryEmbedding(ctx, queryVector, negativeVector, negativeFilterThreshold, options.NResults, options.Where, options.WhereDocuments)
if err != nil {
return nil, err
}
Expand All @@ -449,12 +448,12 @@ func (c *Collection) QueryWithOptions(ctx context.Context, options QueryOptions)
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocument)
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where map[string]string, whereDocuments []WhereDocument) ([]Result, error) {
return c.queryEmbedding(ctx, queryEmbedding, nil, 0, nResults, where, whereDocuments)
}

// queryEmbedding performs an exhaustive nearest neighbor search on the collection.
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativeEmbeddings []float32, negativeFilterThreshold float32, nResults int, where map[string]string, whereDocuments []WhereDocument) ([]Result, error) {
if len(queryEmbedding) == 0 {
return nil, errors.New("queryEmbedding is empty")
}
Expand All @@ -472,14 +471,14 @@ func (c *Collection) queryEmbedding(ctx context.Context, queryEmbedding, negativ
}

// Validate whereDocument operators
for k := range whereDocument {
if !slices.Contains(supportedFilters, k) {
return nil, errors.New("unsupported operator")
for _, whereDocument := range whereDocuments {
if err := whereDocument.Validate(); err != nil {
return nil, fmt.Errorf("invalid whereDocument %#v: %w", whereDocument, err)
}
}

// Filter docs by metadata and content
filteredDocs := filterDocs(c.documents, where, whereDocument)
filteredDocs := filterDocs(c.documents, where, whereDocuments)

// No need to continue if the filters got rid of all documents
if len(filteredDocs) == 0 {
Expand Down
9 changes: 5 additions & 4 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"slices"
"strconv"
"strings"
"testing"
)

Expand Down Expand Up @@ -372,10 +373,10 @@ func TestCollection_QueryError(t *testing.T) {
{
name: "Bad content filter",
query: func() error {
_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
_, err := c.Query(context.Background(), "foo", 1, nil, []WhereDocument{{Operator: "invalid", Value: "foo"}})
return err
},
expErr: "unsupported operator",
expErr: "unsupported where document operator invalid",
},
}

Expand All @@ -384,7 +385,7 @@ func TestCollection_QueryError(t *testing.T) {
err := tc.query()
if err == nil {
t.Fatal("expected error, got nil")
} else if err.Error() != tc.expErr {
} else if !strings.Contains(err.Error(), tc.expErr) {
t.Fatal("expected", tc.expErr, "got", err)
}
})
Expand Down Expand Up @@ -502,7 +503,7 @@ func TestCollection_Delete(t *testing.T) {
checkCount(1)

// Test 3 - Remove document by content
err = c.Delete(context.Background(), nil, map[string]string{"$contains": "hallo welt"})
err = c.Delete(context.Background(), nil, []WhereDocument{WhereDocument{Operator: WhereDocumentOperatorContains, Value: "hallo welt"}})
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down
104 changes: 83 additions & 21 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"sync"
)

var supportedFilters = []string{"$contains", "$not_contains"}

type docSim struct {
docID string
similarity float32
Expand Down Expand Up @@ -84,7 +82,7 @@ func (d *maxDocSims) values() []docSim {

// filterDocs filters a map of documents by metadata and content.
// It does this concurrently.
func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
func filterDocs(docs map[string]*Document, where map[string]string, whereDocuments []WhereDocument) []*Document {
filteredDocs := make([]*Document, 0, len(docs))
filteredDocsLock := sync.Mutex{}

Expand All @@ -104,7 +102,7 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin
go func() {
defer wg.Done()
for doc := range docChan {
if documentMatchesFilters(doc, where, whereDocument) {
if documentMatchesFilters(doc, where, whereDocuments) {
filteredDocsLock.Lock()
filteredDocs = append(filteredDocs, doc)
filteredDocsLock.Unlock()
Expand All @@ -128,9 +126,84 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin
return filteredDocs
}

type WhereDocumentOperator string

const (
WhereDocumentOperatorContains WhereDocumentOperator = "$contains"
WhereDocumentOperatorNotContains WhereDocumentOperator = "$not_contains"
WhereDocumentOperatorOr WhereDocumentOperator = "$or"
WhereDocumentOperatorAnd WhereDocumentOperator = "$and"
)

type WhereDocument struct {
Operator WhereDocumentOperator
Value string
WhereDocuments []WhereDocument
}

func (wd *WhereDocument) Validate() error {

if !slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorContains, WhereDocumentOperatorNotContains, WhereDocumentOperatorOr, WhereDocumentOperatorAnd}, wd.Operator) {
return fmt.Errorf("unsupported where document operator %s", wd.Operator)
}

if wd.Operator == "" {
return fmt.Errorf("where document operator is empty")
}

// $contains and $not_contains require a string value
if slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorContains, WhereDocumentOperatorNotContains}, wd.Operator) {
if wd.Value == "" {
return fmt.Errorf("where document operator %s requires a value", wd.Operator)
}
}

// $or requires sub-filters
if slices.Contains([]WhereDocumentOperator{WhereDocumentOperatorOr, WhereDocumentOperatorAnd}, wd.Operator) {
if len(wd.WhereDocuments) == 0 {
return fmt.Errorf("where document operator %s must have at least one sub-filter", wd.Operator)
}
}

for _, wd := range wd.WhereDocuments {
if err := wd.Validate(); err != nil {
return err
}
}

return nil
}

// Matches checks if a document matches the WhereDocument filter(s)
// There is no error checking on the WhereDocument struct, so it must be validated before calling this function.
func (wd *WhereDocument) Matches(doc *Document) bool {
switch wd.Operator {
case WhereDocumentOperatorContains:
return strings.Contains(doc.Content, wd.Value)
case WhereDocumentOperatorNotContains:
return !strings.Contains(doc.Content, wd.Value)
case WhereDocumentOperatorOr:
for _, subFilter := range wd.WhereDocuments {
if subFilter.Matches(doc) {
return true
}
}
return false
case WhereDocumentOperatorAnd:
for _, subFilter := range wd.WhereDocuments {
if !subFilter.Matches(doc) {
return false
}
}
return true
default:
return false
}
}

// documentMatchesFilters checks if a document matches the given filters.
// When calling this function, the whereDocument keys must already be validated!
func documentMatchesFilters(document *Document, where, whereDocument map[string]string) bool {
// When calling this function, the whereDocument structs must already be validated!
func documentMatchesFilters(document *Document, where map[string]string, whereDocuments []WhereDocument) bool {
// A document's metadata must have *all* the fields in the where clause.
for k, v := range where {
// TODO: Do we want to check for existence of the key? I.e. should
Expand All @@ -141,21 +214,10 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
}
}

// A document must satisfy *all* filters, until we support the `$or` operator.
for k, v := range whereDocument {
switch k {
case "$contains":
if !strings.Contains(document.Content, v) {
return false
}
case "$not_contains":
if strings.Contains(document.Content, v) {
return false
}
default:
// No handling (error) required because we already validated the
// operators. This simplifies the concurrency logic (no err var
// and lock, no context to cancel).
// A document must satisfy *all* WhereDocument filters (that's basically a top-level $and operator)
for _, whereDocument := range whereDocuments {
if !whereDocument.Matches(document) {
return false
}
}

Expand Down
Loading