Skip to content

Commit

Permalink
enhance: Support Search GroupBy feature (#662)
Browse files Browse the repository at this point in the history
See also milvus-io/milvus#25324 milvus-io/milvus#28983

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
  • Loading branch information
congqixia authored Jan 31, 2024
1 parent b3abfbe commit 1e03ea4
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 65 deletions.
4 changes: 2 additions & 2 deletions client/client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,8 +1017,8 @@ func (m *MockServer) Upsert(ctx context.Context, req *milvuspb.UpsertRequest) (*
return &milvuspb.MutationResult{Status: s}, err
}

func (m *MockServer) SearchV2(ctx context.Context, req *milvuspb.SearchRequestV2) (*milvuspb.SearchResults, error) {
f := m.GetInjection(MUpsert)
func (m *MockServer) HybridSearch(ctx context.Context, req *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
f := m.GetInjection(MSearchV2)
if f != nil {
r, err := f(ctx, req)
return r.(*milvuspb.SearchResults), err
Expand Down
14 changes: 14 additions & 0 deletions client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ const (
limitKey = `limit`
ignoreGrowingKey = `ignore_growing`
forTuningKey = `for_tuning`
groupByKey = `group_by_field`
)

// Search with bool expression
Expand Down Expand Up @@ -74,12 +75,23 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
gb := results.GetGroupByFieldValue()
var gbc entity.Column
if gb != nil {
gbc, err = entity.FieldDataColumn(gb, 0, -1)
if err != nil {
return nil, err
}
}
for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := SearchResult{
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
}
if gbc != nil {
entry.GroupByValue, _ = gbc.Get(i)
}
// parse result set if current nq is not empty
if rc > 0 {
entry.IDs, entry.Err = entity.IDColumns(results.GetIds(), offset, offset+rc)
Expand Down Expand Up @@ -335,7 +347,9 @@ func prepareSearchRequest(collName string, partitions []string,
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
groupByKey: opt.GroupByField,
})

req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
Expand Down
41 changes: 41 additions & 0 deletions client/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ func (s *SearchSuite) SetupSuite() {

s.sch = entity.NewSchema().WithName(testCollectionName).
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("Attr").WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim))
s.schDynamic = entity.NewSchema().WithName(testCollectionName).WithDynamicFieldEnabled(true).
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
Expand Down Expand Up @@ -531,6 +532,46 @@ func (s *SearchSuite) TestSearchSuccess() {
s.NoError(err)
s.Equal("abc", str)
})

s.Run("group_by", func() {
defer s.resetMock()
s.setupDescribeCollection(testCollectionName, s.sch)
s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")).
Run(func(_ context.Context, req *milvuspb.SearchRequest) {
s.Equal(testCollectionName, req.GetCollectionName())
s.Equal(expr, req.GetDsl())
s.Equal(commonpb.DslType_BoolExprV1, req.GetDslType())
s.ElementsMatch([]string{"ID"}, req.GetOutputFields())
s.ElementsMatch([]string{partName}, req.GetPartitionNames())
}).
Return(&milvuspb.SearchResults{
Status: getSuccessStatus(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 10,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
},
},
Scores: make([]float32, 10),
Topks: []int64{10},
GroupByFieldValue: s.getInt64FieldData("Attr", []int64{10}),
},
}, nil)

r, err := c.Search(ctx, testCollectionName, []string{partName}, expr, []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])},
testVectorField, entity.L2, 10, sp, WithIgnoreGrowing(), WithForTuning(), WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(10000000000), WithGroupByField("Attr"))
s.NoError(err)
s.Require().Equal(1, len(r))
result := r[0]
s.Require().NotNil(result.Fields.GetColumn("ID"))
})
}

func TestSearch(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions client/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ type SearchQueryOption struct {

IgnoreGrowing bool
ForTuning bool

GroupByField string
}

// SearchQueryOptionFunc is a function which modifies SearchOption
Expand Down Expand Up @@ -163,6 +165,12 @@ func WithLimit(limit int64) SearchQueryOptionFunc {
}
}

func WithGroupByField(groupByField string) SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
option.GroupByField = groupByField
}
}

// WithSearchQueryConsistencyLevel specifies consistency level
func WithSearchQueryConsistencyLevel(cl entity.ConsistencyLevel) SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
Expand Down
11 changes: 6 additions & 5 deletions client/results.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ import "github.com/milvus-io/milvus-sdk-go/v2/entity"
// Fields contains the data of `outputFieleds` specified or all columns if non
// Scores is actually the distance between the vector current record contains and the search target vector
type SearchResult struct {
ResultCount int // the returning entry count
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields ResultSet //[]entity.Column // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
ResultCount int // the returning entry count
GroupByValue interface{}
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields ResultSet //[]entity.Column // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
}

// ResultSet is an alias type for column slice.
Expand Down
129 changes: 129 additions & 0 deletions examples/groupby/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package main

import (
"context"
"log"
"math/rand"
"time"

"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

const (
milvusAddr = `localhost:19530`
nEntities, dim = 10000, 128
collectionName = "hello_group_by"

idCol, keyCol, embeddingCol = "ID", "key", "embeddings"
topK = 3
)

func main() {
ctx := context.Background()

log.Println("start connecting to Milvus")
c, err := client.NewClient(ctx, client.Config{
Address: milvusAddr,
})
if err != nil {
log.Fatalf("failed to connect to milvus, err: %v", err)
}
defer c.Close()

// delete collection if exists
has, err := c.HasCollection(ctx, collectionName)
if err != nil {
log.Fatalf("failed to check collection exists, err: %v", err)
}
if has {
c.DropCollection(ctx, collectionName)
}

// create collection
log.Printf("create collection `%s`\n", collectionName)
schema := entity.NewSchema().WithName(collectionName).WithDescription("hello_partition_key is the a demo to introduce the partition key related APIs").
WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
WithField(entity.NewField().WithName(keyCol).WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName(embeddingCol).WithDataType(entity.FieldTypeFloatVector).WithDim(dim))

if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil { // use default shard number
log.Fatalf("create collection failed, err: %v", err)
}

var keyList []int64
var embeddingList [][]float32
keyList = make([]int64, 0, nEntities)
embeddingList = make([][]float32, 0, nEntities)
for i := 0; i < nEntities; i++ {
keyList = append(keyList, rand.Int63()%512)
}
for i := 0; i < nEntities; i++ {
vec := make([]float32, 0, dim)
for j := 0; j < dim; j++ {
vec = append(vec, rand.Float32())
}
embeddingList = append(embeddingList, vec)
}
keyColData := entity.NewColumnInt64(keyCol, keyList)
embeddingColData := entity.NewColumnFloatVector(embeddingCol, dim, embeddingList)

log.Println("start to insert data into collection")

if _, err := c.Insert(ctx, collectionName, "", keyColData, embeddingColData); err != nil {
log.Fatalf("failed to insert random data into `%s`, err: %v", collectionName, err)
}

log.Println("insert data done, start to flush")

if err := c.Flush(ctx, collectionName, false); err != nil {
log.Fatalf("failed to flush data, err: %v", err)
}
log.Println("flush data done")

// build index
log.Println("start creating index HNSW")
idx, err := entity.NewIndexHNSW(entity.L2, 16, 256)
if err != nil {
log.Fatalf("failed to create ivf flat index, err: %v", err)
}
if err := c.CreateIndex(ctx, collectionName, embeddingCol, idx, false); err != nil {
log.Fatalf("failed to create index, err: %v", err)
}

log.Printf("build HNSW index done for collection `%s`\n", collectionName)
log.Printf("start to load collection `%s`\n", collectionName)

// load collection
if err := c.LoadCollection(ctx, collectionName, false); err != nil {
log.Fatalf("failed to load collection, err: %v", err)
}

log.Println("load collection done")

vec2search := []entity.Vector{
entity.FloatVector(embeddingList[len(embeddingList)-2]),
entity.FloatVector(embeddingList[len(embeddingList)-1]),
}
begin := time.Now()
sp, _ := entity.NewIndexHNSWSearchParam(30)
result, err := c.Search(ctx, collectionName, nil, "", []string{keyCol, embeddingCol}, vec2search,
embeddingCol, entity.L2, topK, sp, client.WithGroupByField(keyCol))
if err != nil {
log.Fatalf("failed to search collection, err: %v", err)
}

log.Printf("search `%s` done, latency %v\n", collectionName, time.Since(begin))
for _, rs := range result {
log.Printf("GroupByValue: %v\n", rs.GroupByValue)
for i := 0; i < rs.ResultCount; i++ {
id, _ := rs.IDs.GetAsInt64(i)
score := rs.Scores[i]
embedding, _ := rs.Fields.GetColumn(embeddingCol).Get(i)

log.Printf("ID: %d, score %f, embedding: %v\n", id, score, embedding)
}
}

c.DropCollection(ctx, collectionName)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ require (
github.com/go-faker/faker/v4 v4.1.0
github.com/golang/protobuf v1.5.2
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240109020841-d367b5a59df1
github.com/stretchr/testify v1.8.1
github.com/tidwall/gjson v1.14.4
google.golang.org/grpc v1.48.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8=
github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4 h1:HtNGcUb52ojnl+zDAZMmbHyVaTdBjzuCnnBHpb675TU=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240109020841-d367b5a59df1 h1:oNpMivd94JAMhdSVsFw8t1b+olXz8pbzd5PES21sth8=
github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240109020841-d367b5a59df1/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
Expand Down
Loading

0 comments on commit 1e03ea4

Please sign in to comment.