Skip to content

Commit

Permalink
enhance: Support hybrid search multiple vector fields
Browse files Browse the repository at this point in the history
See also: milvus-io/milvus#25639
milvus pr: #29433

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
  • Loading branch information
congqixia committed Feb 1, 2024
1 parent 1e03ea4 commit 688f2de
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 1 deletion.
72 changes: 72 additions & 0 deletions client/ann_request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package client

import (
"encoding/json"
"fmt"
"strconv"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// ANNSearchRequest describes a single approximate-nearest-neighbour sub-search
// over one vector field. A slice of these is passed to HybridSearch, which
// converts each one into a milvuspb.SearchRequest via getMilvusSearchRequest.
type ANNSearchRequest struct {
// fieldName is sent as the "anns_field" search parameter.
fieldName string
// vectors are the query vectors; they are encoded into the placeholder group
// and their count is sent as Nq.
vectors []entity.Vector
// metricType is sent as the "metric_type" search parameter.
metricType entity.MetricType
// expr is the boolean filter expression, sent as the request Dsl.
expr string
// searchParam supplies the index-specific parameters that are JSON-encoded
// into the "params" search parameter.
searchParam entity.SearchParam
// options are applied to a SearchQueryOption when the request is built.
options []SearchQueryOptionFunc
// limit is sent as the "topk" search parameter.
limit int
}

// NewANNSearchRequest builds an ANN sub-search request for use with HybridSearch.
//
// fieldName is the vector field to search, metricsType the metric used to
// compare vectors, vectors the query vectors, searchParam the index-specific
// search parameters and limit the number of candidates (topk) requested from
// this sub-search. Any options are retained and applied when the request is
// converted into a milvuspb.SearchRequest; a filter expression can be attached
// afterwards with WithExpr.
func NewANNSearchRequest(fieldName string, metricsType entity.MetricType, vectors []entity.Vector, searchParam entity.SearchParam, limit int, options ...SearchQueryOptionFunc) *ANNSearchRequest {
	return &ANNSearchRequest{
		fieldName:   fieldName,
		vectors:     vectors,
		metricType:  metricsType,
		searchParam: searchParam,
		// Previously the options were accepted but never stored, so they had
		// no effect on the generated search request.
		options: options,
		limit:   limit,
	}
}
// WithExpr sets the boolean filter expression of the request and returns the
// same request so that calls can be chained.
func (r *ANNSearchRequest) WithExpr(expr string) *ANNSearchRequest {
	r.expr = expr

	return r
}

// getMilvusSearchRequest converts the ANN request into the milvuspb.SearchRequest
// that is embedded in a hybrid search.
//
// collectionInfo provides the default consistency level; it may be nil (for
// example when the collection is not in the meta cache), in which case the
// zero-value consistency level is used unless an option overrides it.
// The stored options are applied on top of that default.
//
// An error is returned when the request has no search param or when the search
// parameters cannot be JSON-encoded. CollectionName, PartitionNames and
// OutputFields are left for the caller to fill in.
func (r *ANNSearchRequest) getMilvusSearchRequest(collectionInfo *collInfo) (*milvuspb.SearchRequest, error) {
	if r.searchParam == nil {
		// Calling Params() on a nil interface would panic; report it instead.
		return nil, fmt.Errorf("search param of ANN request on field %q is nil", r.fieldName)
	}

	opt := &SearchQueryOption{}
	if collectionInfo != nil {
		opt.ConsistencyLevel = collectionInfo.ConsistencyLevel // collection default
	}
	for _, o := range r.options {
		if o == nil {
			continue
		}
		o(opt)
	}

	// The tuning flag travels inside the JSON-encoded index parameters.
	params := r.searchParam.Params()
	params[forTuningKey] = opt.ForTuning
	bs, err := json.Marshal(params)
	if err != nil {
		return nil, err
	}

	searchParams := entity.MapKvPairs(map[string]string{
		"anns_field":     r.fieldName,
		"topk":           strconv.Itoa(r.limit),
		"params":         string(bs),
		"metric_type":    string(r.metricType),
		"round_decimal":  "-1",
		ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
		offsetKey:        fmt.Sprintf("%d", opt.Offset),
		groupByKey:       opt.GroupByField,
	})

	return &milvuspb.SearchRequest{
		DbName:             "",
		Dsl:                r.expr,
		PlaceholderGroup:   vector2PlaceholderGroupBytes(r.vectors),
		DslType:            commonpb.DslType_BoolExprV1,
		SearchParams:       searchParams,
		GuaranteeTimestamp: opt.GuaranteeTimestamp,
		Nq:                 int64(len(r.vectors)),
	}, nil
}
2 changes: 2 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ type Client interface {
msgsBytes [][]byte, startPositions, endPositions []*msgpb.MsgPosition,
opts ...ReplicateMessageOption,
) (*entity.MessageInfo, error)

HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error)
}

// NewClient create a client connected to remote milvus cluster.
Expand Down
62 changes: 61 additions & 1 deletion client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/milvus-io/milvus-sdk-go/v2/merr"
)

const (
Expand All @@ -35,6 +36,59 @@ const (
groupByKey = `group_by_field`
)

// HybridSearch performs one ANN search per entry of subRequests against
// collName and lets the server merge the candidates using reranker.
//
// partitions and outputFields are applied to every sub-search as well as to the
// enclosing hybrid request. limit is appended to the reranker parameters under
// limitKey. The collection's consistency level is used for the hybrid request.
//
// It returns ErrClientNotReady when the client has no service, the error from
// DescribeCollection when collection metadata cannot be obtained, any error
// raised while building a sub-request, or the RPC/status error reported by
// merr.CheckRPCCall.
func (c *GrpcClient) HybridSearch(ctx context.Context, collName string, partitions []string, limit int, outputFields []string, reranker Reranker, subRequests []*ANNSearchRequest) ([]SearchResult, error) {
	if c.Service == nil {
		return nil, ErrClientNotReady
	}

	// Resolve the collection metadata. The local is deliberately not named
	// collInfo so that the collInfo type remains usable for the fallback below.
	info, ok := MetaCache.getCollectionInfo(collName)
	if !ok || info == nil {
		coll, err := c.DescribeCollection(ctx, collName)
		if err != nil {
			return nil, err
		}
		// On a cache miss the cached pointer is nil; build the metadata from
		// the describe result instead of dereferencing nil further down.
		info = &collInfo{
			Schema:           coll.Schema,
			ConsistencyLevel: coll.ConsistencyLevel,
		}
	}
	schema := info.Schema

	sReqs := make([]*milvuspb.SearchRequest, 0, len(subRequests))
	// nq only sizes the result slice; use the largest query count seen.
	nq := 0
	for _, subRequest := range subRequests {
		r, err := subRequest.getMilvusSearchRequest(info)
		if err != nil {
			return nil, err
		}
		r.CollectionName = collName
		r.PartitionNames = partitions
		r.OutputFields = outputFields
		if n := len(subRequest.vectors); n > nq {
			nq = n
		}
		sReqs = append(sReqs, r)
	}

	// The overall result limit is passed to the server as a rank parameter.
	params := reranker.GetParams()
	params = append(params, &commonpb.KeyValuePair{Key: limitKey, Value: strconv.FormatInt(int64(limit), 10)})

	req := &milvuspb.HybridSearchRequest{
		CollectionName:   collName,
		PartitionNames:   partitions,
		Requests:         sReqs,
		OutputFields:     outputFields,
		ConsistencyLevel: commonpb.ConsistencyLevel(info.ConsistencyLevel),
		RankParams:       params,
	}

	result, err := c.Service.HybridSearch(ctx, req)
	if err = merr.CheckRPCCall(result, err); err != nil {
		return nil, err
	}

	return c.handleSearchResult(schema, outputFields, nq, result)
}

// Search with bool expression
func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error) {
Expand Down Expand Up @@ -63,7 +117,6 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
return nil, err
}

sr := make([]SearchResult, 0, len(vectors))
resp, err := c.Service.Search(ctx, req)
if err != nil {
return nil, err
Expand All @@ -72,6 +125,13 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
return nil, err
}
// 3. parse result into result
return c.handleSearchResult(schema, outputFields, len(vectors), resp)
}

func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []string, nq int, resp *milvuspb.SearchResults) ([]SearchResult, error) {
var err error
sr := make([]SearchResult, 0, nq)
// 3. parse result into result
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
Expand Down
62 changes: 62 additions & 0 deletions client/reranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package client

import (
"encoding/json"

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
)

// Keys and values used when encoding a Reranker into rank parameters.
const (
// rerankType is the key under which the rerank strategy name is sent.
rerankType = "strategy"
// rerankParams is the key under which the JSON-encoded reranker is sent.
rerankParams = "params"
// NOTE(review): rffParam looks like a typo for "rrf"; it and weightedParam
// are not referenced in this file's visible code — presumably they mirror
// the JSON field names of the rerankers. Confirm before renaming.
rffParam = "k"
weightedParam = "weights"

// rrfRerankType and weightedRerankType are the strategy names.
rrfRerankType = `rrf`
weightedRerankType = `weighted`
)

// Reranker describes how the results of the sub-searches of a hybrid search
// are merged.
type Reranker interface {
// GetParams returns the key/value pairs describing the rerank strategy.
// HybridSearch appends the result limit to them and sends them as the
// request's RankParams.
GetParams() []*commonpb.KeyValuePair
}

// rrfReranker is the Reranker that reports the `rrf` strategy.
type rrfReranker struct {
// K is serialized under the JSON key "k"; because of omitempty a zero value
// is left out of the encoded parameters.
K float64 `json:"k,omitempty"`
}

// WithK overrides the K parameter of the reranker and returns the same
// reranker so that calls can be chained.
func (r *rrfReranker) WithK(k float64) *rrfReranker {
	r.K = k

	return r
}

// GetParams encodes the reranker as rank parameters: the strategy name `rrf`
// plus the reranker itself serialized as JSON.
func (r *rrfReranker) GetParams() []*commonpb.KeyValuePair {
	// The Reranker interface has no error return, so a marshal failure is
	// dropped here and yields an empty params value.
	encoded, _ := json.Marshal(r)

	strategy := &commonpb.KeyValuePair{Key: rerankType, Value: rrfRerankType}
	settings := &commonpb.KeyValuePair{Key: rerankParams, Value: string(encoded)}

	return []*commonpb.KeyValuePair{strategy, settings}
}

// NewRRFReranker returns an RRF reranker with K preset to 60; use WithK to
// change it.
func NewRRFReranker() *rrfReranker {
	reranker := new(rrfReranker)
	reranker.K = 60

	return reranker
}

// weightedReranker is the Reranker carrying a list of weights.
type weightedReranker struct {
// Weights is serialized under the JSON key "weights"; because of omitempty
// an empty slice is left out of the encoded parameters.
// NOTE(review): presumably one weight per sub-request, in order — confirm
// against the server-side contract.
Weights []float64 `json:"weights,omitempty"`
}

// GetParams encodes the reranker as rank parameters: the strategy name
// `weighted` plus the reranker itself (its weights) serialized as JSON.
func (r *weightedReranker) GetParams() []*commonpb.KeyValuePair {
	// The Reranker interface has no error return, so a marshal failure is
	// dropped here and yields an empty params value.
	bs, _ := json.Marshal(r)

	return []*commonpb.KeyValuePair{
		// Report the weighted strategy; this used to send rrfRerankType, which
		// made the server treat a weighted reranker as RRF.
		{Key: rerankType, Value: weightedRerankType},
		{Key: rerankParams, Value: string(bs)},
	}
}

// NewWeightedReranker returns a weighted reranker using the given weights.
// The slice is stored as-is, not copied.
func NewWeightedReranker(weights []float64) *weightedReranker {
	reranker := new(weightedReranker)
	reranker.Weights = weights

	return reranker
}
161 changes: 161 additions & 0 deletions examples/multivectors/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package main

import (
"context"
"log"
"math/rand"
"time"

"github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// Settings of the multi-vector demo.
const (
// milvusAddr is the address handed to client.NewClient.
milvusAddr = `localhost:19530`
// nEntities is the number of generated rows, dim the vector dimension.
nEntities, dim = 10000, 128
// collectionName is the collection created (and dropped) by the demo.
collectionName = "hello_multi_vectors"

// Field names: auto-ID primary key, int64 key and the two vector fields.
idCol, keyCol, embeddingCol1, embeddingCol2 = "ID", "key", "vector1", "vector2"
// topK is the limit used for both the plain and the hybrid search.
topK = 3
)

// main demonstrates searching a collection with two vector fields:
// it (re)creates the collection, inserts random data, builds an HNSW index on
// both vector fields, loads the collection, runs a plain search on the first
// vector field and then a hybrid search over both fields merged with the RRF
// reranker. The collection is dropped at the end.
func main() {
	ctx := context.Background()

	log.Println("start connecting to Milvus")
	c, err := client.NewClient(ctx, client.Config{
		Address: milvusAddr,
	})
	if err != nil {
		log.Fatalf("failed to connect to milvus, err: %v", err)
	}
	// Note: log.Fatalf exits the process, so this runs only on a normal return.
	defer c.Close()

	// delete collection if exists
	has, err := c.HasCollection(ctx, collectionName)
	if err != nil {
		log.Fatalf("failed to check collection exists, err: %v", err)
	}
	if has {
		if err := c.DropCollection(ctx, collectionName); err != nil {
			log.Fatalf("failed to drop existing collection, err: %v", err)
		}
	}

	// create collection
	log.Printf("create collection `%s`\n", collectionName)
	schema := entity.NewSchema().WithName(collectionName).WithDescription("hello_multi_vectors is a demo to introduce the multi-vector hybrid search APIs").
		WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
		WithField(entity.NewField().WithName(keyCol).WithDataType(entity.FieldTypeInt64)).
		WithField(entity.NewField().WithName(embeddingCol1).WithDataType(entity.FieldTypeFloatVector).WithDim(dim)).
		WithField(entity.NewField().WithName(embeddingCol2).WithDataType(entity.FieldTypeFloatVector).WithDim(dim))

	if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil { // use default shard number
		log.Fatalf("create collection failed, err: %v", err)
	}

	// generate random keys and embeddings; both vector fields reuse the same
	// embeddings
	keyList := make([]int64, 0, nEntities)
	embeddingList := make([][]float32, 0, nEntities)
	for i := 0; i < nEntities; i++ {
		keyList = append(keyList, rand.Int63()%512)
	}
	for i := 0; i < nEntities; i++ {
		vec := make([]float32, 0, dim)
		for j := 0; j < dim; j++ {
			vec = append(vec, rand.Float32())
		}
		embeddingList = append(embeddingList, vec)
	}
	keyColData := entity.NewColumnInt64(keyCol, keyList)
	embeddingColData1 := entity.NewColumnFloatVector(embeddingCol1, dim, embeddingList)
	embeddingColData2 := entity.NewColumnFloatVector(embeddingCol2, dim, embeddingList)

	log.Println("start to insert data into collection")

	if _, err := c.Insert(ctx, collectionName, "", keyColData, embeddingColData1, embeddingColData2); err != nil {
		log.Fatalf("failed to insert random data into `%s`, err: %v", collectionName, err)
	}

	log.Println("insert data done, start to flush")

	if err := c.Flush(ctx, collectionName, false); err != nil {
		log.Fatalf("failed to flush data, err: %v", err)
	}
	log.Println("flush data done")

	// build index
	log.Println("start creating index HNSW")
	idx, err := entity.NewIndexHNSW(entity.L2, 16, 256)
	if err != nil {
		log.Fatalf("failed to create HNSW index, err: %v", err)
	}
	if err := c.CreateIndex(ctx, collectionName, embeddingCol1, idx, false); err != nil {
		log.Fatalf("failed to create index, err: %v", err)
	}
	if err := c.CreateIndex(ctx, collectionName, embeddingCol2, idx, false); err != nil {
		log.Fatalf("failed to create index, err: %v", err)
	}

	log.Printf("build HNSW index done for collection `%s`\n", collectionName)
	log.Printf("start to load collection `%s`\n", collectionName)

	// load collection
	if err := c.LoadCollection(ctx, collectionName, false); err != nil {
		log.Fatalf("failed to load collection, err: %v", err)
	}

	log.Println("load collection done")

	// currently only nq = 1 is supported
	vec2search1 := []entity.Vector{
		entity.FloatVector(embeddingList[len(embeddingList)-2]),
	}
	vec2search2 := []entity.Vector{
		entity.FloatVector(embeddingList[len(embeddingList)-1]),
	}

	sp, err := entity.NewIndexHNSWSearchParam(30)
	if err != nil {
		log.Fatalf("failed to create HNSW search param, err: %v", err)
	}

	log.Println("start to search vector field 1")
	begin := time.Now()
	result, err := c.Search(ctx, collectionName, nil, "", []string{keyCol, embeddingCol1, embeddingCol2}, vec2search1,
		embeddingCol1, entity.L2, topK, sp)
	if err != nil {
		log.Fatalf("failed to search collection, err: %v", err)
	}

	log.Printf("search `%s` done, latency %v\n", collectionName, time.Since(begin))
	for _, rs := range result {
		for i := 0; i < rs.ResultCount; i++ {
			// Lookup errors are ignored to keep the demo output loop short.
			id, _ := rs.IDs.GetAsInt64(i)
			score := rs.Scores[i]
			embedding, _ := rs.Fields.GetColumn(embeddingCol1).Get(i)

			log.Printf("ID: %d, score %f, embedding: %v\n", id, score, embedding)
		}
	}

	log.Println("start to execute hybrid search")

	// Restart the timer so the reported latency covers the hybrid search only.
	begin = time.Now()
	result, err = c.HybridSearch(ctx, collectionName, nil, topK, []string{keyCol, embeddingCol1, embeddingCol2},
		client.NewRRFReranker(), []*client.ANNSearchRequest{
			client.NewANNSearchRequest(embeddingCol1, entity.L2, vec2search1, sp, topK),
			client.NewANNSearchRequest(embeddingCol2, entity.L2, vec2search2, sp, topK),
		})
	if err != nil {
		log.Fatalf("failed to search collection, err: %v", err)
	}

	log.Printf("hybrid search `%s` done, latency %v\n", collectionName, time.Since(begin))
	for _, rs := range result {
		for i := 0; i < rs.ResultCount; i++ {
			// Lookup errors are ignored to keep the demo output loop short.
			id, _ := rs.IDs.GetAsInt64(i)
			score := rs.Scores[i]
			embedding1, _ := rs.Fields.GetColumn(embeddingCol1).Get(i)
			// Read the second vector field (this previously re-read the first).
			embedding2, _ := rs.Fields.GetColumn(embeddingCol2).Get(i)
			log.Printf("ID: %d, score %f, embedding1: %v, embedding2: %v\n", id, score, embedding1, embedding2)
		}
	}

	// clean up; a failure here is only reported since the demo is finished
	if err := c.DropCollection(ctx, collectionName); err != nil {
		log.Printf("failed to drop collection `%s`, err: %v\n", collectionName, err)
	}
}

0 comments on commit 688f2de

Please sign in to comment.