Skip to content

Commit

Permalink
Support search & query with dynamic schema (#457)
Browse files Browse the repository at this point in the history
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
  • Loading branch information
congqixia committed May 22, 2023
1 parent 5f94bce commit b619b53
Show file tree
Hide file tree
Showing 12 changed files with 930 additions and 636 deletions.
25 changes: 18 additions & 7 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ type Client interface {
Search(ctx context.Context, collName string, partitions []string,
expr string, outputFields []string, vectors []entity.Vector, vectorField string, metricType entity.MetricType, topK int, sp entity.SearchParam, opts ...SearchQueryOptionFunc) ([]SearchResult, error)
// QueryByPks query record by specified primary key(s).
QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) ([]entity.Column, error)
QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error)
// Query performs query records with boolean expression.
Query(ctx context.Context, collectionName string, partitionNames []string, expr string, outputFields []string, opts ...SearchQueryOptionFunc) ([]entity.Column, error)
Query(ctx context.Context, collectionName string, partitionNames []string, expr string, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error)

// CalcDistance calculate the distance between vectors specified by ids or provided
CalcDistance(ctx context.Context, collName string, partitions []string,
Expand Down Expand Up @@ -198,11 +198,22 @@ type Client interface {
// Fields contains the data of `outputFields` specified or all columns if none
// Scores is actually the distance between the vector current record contains and the search target vector
type SearchResult struct {
ResultCount int // the returning entry count
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields []entity.Column // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
ResultCount int // the returning entry count
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields ResultSet //[]entity.Column // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
}

// ResultSet is the collection of columns returned by a query or search call.
type ResultSet []entity.Column

// GetColumn returns the column whose field name equals fieldName.
// It returns nil when no column in the set matches.
func (rs ResultSet) GetColumn(fieldName string) entity.Column {
	for i := range rs {
		if rs[i].Name() == fieldName {
			return rs[i]
		}
	}
	return nil
}

var DefaultGrpcOpts = []grpc.DialOption{
Expand Down
72 changes: 71 additions & 1 deletion client/client_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/federpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
server "github.com/milvus-io/milvus-proto/go-api/milvuspb"
schema "github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/milvus-io/milvus-sdk-go/v2/mocks"
"github.com/stretchr/testify/mock"
Expand Down Expand Up @@ -133,13 +134,78 @@ func (s *MockSuiteBase) setupDescribeCollection(collName string, schema *entity.
}, nil)
}

func (s *MockSuiteBase) setupDescirbeCollectionError(errorCode common.ErrorCode, err error) {
// setupDescribeCollectionError registers a DescribeCollection expectation that
// responds with the provided error code in the status and the provided error.
func (s *MockSuiteBase) setupDescribeCollectionError(errorCode common.ErrorCode, err error) {
	resp := &milvuspb.DescribeCollectionResponse{
		Status: &common.Status{ErrorCode: errorCode},
	}
	s.mock.EXPECT().
		DescribeCollection(mock.Anything, mock.AnythingOfType("*milvuspb.DescribeCollectionRequest")).
		Return(resp, err)
}

// getInt64FieldData builds a schemapb FieldData carrying an int64 scalar
// column with the given field name and values, for use in mocked responses.
func (s *MockSuiteBase) getInt64FieldData(name string, data []int64) *schema.FieldData {
	longData := &schema.ScalarField_LongData{
		LongData: &schema.LongArray{Data: data},
	}
	return &schema.FieldData{
		Type:      schema.DataType_Int64,
		FieldName: name,
		Field: &schema.FieldData_Scalars{
			Scalars: &schema.ScalarField{Data: longData},
		},
	}
}

// getVarcharFieldData builds a schemapb FieldData carrying a varchar scalar
// column with the given field name and values, for use in mocked responses.
func (s *MockSuiteBase) getVarcharFieldData(name string, data []string) *schema.FieldData {
	stringData := &schema.ScalarField_StringData{
		StringData: &schema.StringArray{Data: data},
	}
	return &schema.FieldData{
		Type:      schema.DataType_VarChar,
		FieldName: name,
		Field: &schema.FieldData_Scalars{
			Scalars: &schema.ScalarField{Data: stringData},
		},
	}
}

// getJSONBytesFieldData builds a schemapb FieldData carrying a JSON scalar
// column (raw JSON bytes per row) with the given field name, for use in
// mocked responses.
func (s *MockSuiteBase) getJSONBytesFieldData(name string, data [][]byte) *schema.FieldData {
	jsonData := &schema.ScalarField_JsonData{
		JsonData: &schema.JSONArray{Data: data},
	}
	return &schema.FieldData{
		Type:      schema.DataType_JSON,
		FieldName: name,
		Field: &schema.FieldData_Scalars{
			Scalars: &schema.ScalarField{Data: jsonData},
		},
	}
}

// getFloatVectorFieldData builds a schemapb FieldData carrying a float vector
// column with the given field name, dimension and flattened values, for use
// in mocked responses.
func (s *MockSuiteBase) getFloatVectorFieldData(name string, dim int64, data []float32) *schema.FieldData {
	floatVector := &schema.VectorField_FloatVector{
		FloatVector: &schema.FloatArray{Data: data},
	}
	return &schema.FieldData{
		Type:      schema.DataType_FloatVector,
		FieldName: name,
		Field: &schema.FieldData_Vectors{
			Vectors: &schema.VectorField{
				Dim:  dim,
				Data: floatVector,
			},
		},
	}
}

// ref https://stackoverflow.com/questions/42102496/testing-a-grpc-service

var (
Expand Down Expand Up @@ -832,6 +898,10 @@ func (m *MockServer) CheckHealth(ctx context.Context, req *server.CheckHealthReq
return &server.CheckHealthResponse{Status: s}, err
}

// getSuccessStatus returns a commonpb status with the Success error code.
func getSuccessStatus() *common.Status {
	status := common.Status{ErrorCode: common.ErrorCode_Success}
	return &status
}

// SuccessStatus returns a commonpb status with the Success error code and a
// nil error, matching the (status, error) shape of mocked RPC returns.
func SuccessStatus() (*common.Status, error) {
	status := &common.Status{ErrorCode: common.ErrorCode_Success}
	return status, nil
}
Expand Down
106 changes: 76 additions & 30 deletions client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
if c.Service == nil {
return []SearchResult{}, ErrClientNotReady
}
_, ok := MetaCache.getCollectionInfo(collName)
var schema *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collName)
if !ok {
c.DescribeCollection(ctx, collName)
coll, err := c.DescribeCollection(ctx, collName)
if err != nil {
return nil, err
}
schema = coll.Schema
} else {
schema = collInfo.Schema
}

option, err := makeSearchQueryOption(collName, opts...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -76,21 +84,55 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
offset += rc
continue
}
entry.Fields = make([]entity.Column, 0, len(fieldDataList))
for _, fieldData := range fieldDataList {
column, err := entity.FieldDataColumn(fieldData, offset, offset+rc)
if err != nil {
entry.Err = err
continue
}
entry.Fields = append(entry.Fields, column)
}
entry.Fields, entry.Err = c.parseSearchResult(schema, outputFields, fieldDataList, i, offset, offset+rc)
sr = append(sr, entry)
offset += rc
}
return sr, nil
}

// parseSearchResult converts the raw field data of one result entry into
// entity columns, one per requested output field, taking rows [from, to).
// Fields absent from the server payload are resolved through the collection's
// dynamic (JSON) field when the schema declares one; otherwise an error is
// returned. The idx and receiver are currently unused but kept for interface
// stability with existing callers.
func (c *GrpcClient) parseSearchResult(sch *entity.Schema, outputFields []string, fieldDataList []*schema.FieldData, idx, from, to int) ([]entity.Column, error) {
	dynamicField := sch.GetDynamicField()

	// Index returned field data by name; if a dynamic field exists and was
	// returned, decode its JSON column up front so per-field lookups below
	// can project from it.
	byName := make(map[string]*schema.FieldData, len(fieldDataList))
	var dynamicColumn *entity.ColumnJSONBytes
	for _, fd := range fieldDataList {
		byName[fd.GetFieldName()] = fd
		if dynamicField == nil || fd.GetFieldName() != dynamicField.Name {
			continue
		}
		col, err := entity.FieldDataColumn(fd, from, to)
		if err != nil {
			return nil, err
		}
		jsonCol, ok := col.(*entity.ColumnJSONBytes)
		if !ok {
			return nil, errors.New("dynamic field not json")
		}
		dynamicColumn = jsonCol
	}

	columns := make([]entity.Column, 0, len(outputFields))
	for _, fieldName := range outputFields {
		if fd, found := byName[fieldName]; found {
			col, err := entity.FieldDataColumn(fd, from, to)
			if err != nil {
				return nil, err
			}
			columns = append(columns, col)
			continue
		}
		// Field was not returned directly; project it out of the dynamic column.
		if dynamicField == nil {
			return nil, errors.New("output fields not match when dynamic field disabled")
		}
		if dynamicColumn == nil {
			return nil, errors.New("output fields not match and result field data does not contain dynamic field")
		}
		columns = append(columns, entity.NewColumnDynamic(dynamicColumn, fieldName))
	}
	return columns, nil
}

func PKs2Expr(backName string, ids entity.Column) string {
var expr string
var pkName = ids.Name()
Expand All @@ -111,7 +153,7 @@ func PKs2Expr(backName string, ids entity.Column) string {
}

// QueryByPks query record by specified primary key(s)
func (c *GrpcClient) QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) ([]entity.Column, error) {
func (c *GrpcClient) QueryByPks(ctx context.Context, collectionName string, partitionNames []string, ids entity.Column, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error) {
if c.Service == nil {
return nil, ErrClientNotReady
}
Expand All @@ -129,17 +171,21 @@ func (c *GrpcClient) QueryByPks(ctx context.Context, collectionName string, part
}

// Query performs query by expression.
func (c *GrpcClient) Query(ctx context.Context, collectionName string, partitionNames []string, expr string, outputFields []string, opts ...SearchQueryOptionFunc) ([]entity.Column, error) {
func (c *GrpcClient) Query(ctx context.Context, collectionName string, partitionNames []string, expr string, outputFields []string, opts ...SearchQueryOptionFunc) (ResultSet, error) {
if c.Service == nil {
return nil, ErrClientNotReady
}

_, ok := MetaCache.getCollectionInfo(collectionName)
var sch *entity.Schema
collInfo, ok := MetaCache.getCollectionInfo(collectionName)
if !ok {
_, err := c.DescribeCollection(ctx, collectionName)
coll, err := c.DescribeCollection(ctx, collectionName)
if err != nil {
return nil, err
}
sch = coll.Schema
} else {
sch = collInfo.Schema
}

option, err := makeSearchQueryOption(collectionName, opts...)
Expand Down Expand Up @@ -176,22 +222,22 @@ func (c *GrpcClient) Query(ctx context.Context, collectionName string, partition
}

fieldsData := resp.GetFieldsData()
columns := make([]entity.Column, 0, len(fieldsData))
for _, fieldData := range resp.GetFieldsData() {
if fieldData.GetType() == schema.DataType_FloatVector ||
fieldData.GetType() == schema.DataType_BinaryVector {
column, err := entity.FieldDataVector(fieldData)
if err != nil {
return nil, err
}
columns = append(columns, column)
continue
// query always has pk field as output
hasPK := false
pkName := sch.PKFieldName()
for _, output := range outputFields {
if output == pkName {
hasPK = true
break
}
column, err := entity.FieldDataColumn(fieldData, 0, -1)
if err != nil {
return nil, err
}
columns = append(columns, column)
}
if !hasPK {
outputFields = append(outputFields, pkName)
}

columns, err := c.parseSearchResult(sch, outputFields, fieldsData, 0, 0, -1) //entity.FieldDataColumn(fieldData, 0, -1)
if err != nil {
return nil, err
}

return columns, nil
Expand Down
Loading

0 comments on commit b619b53

Please sign in to comment.