Adding float16 and bfloat16 support
Signed-off-by: Ted Xu <ted.xu@zilliz.com>
tedxu committed Feb 5, 2024
1 parent 1e03ea4 commit 42fcd3a
Showing 19 changed files with 812 additions and 364 deletions.
4 changes: 2 additions & 2 deletions client/client_mock_test.go
@@ -1017,8 +1017,8 @@ func (m *MockServer) Upsert(ctx context.Context, req *milvuspb.UpsertRequest) (*
return &milvuspb.MutationResult{Status: s}, err
}

func (m *MockServer) HybridSearch(ctx context.Context, req *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
f := m.GetInjection(MSearchV2)
func (m *MockServer) SearchV2(ctx context.Context, req *milvuspb.SearchRequestV2) (*milvuspb.SearchResults, error) {
f := m.GetInjection(MUpsert)
if f != nil {
r, err := f(ctx, req)
return r.(*milvuspb.SearchResults), err
5 changes: 4 additions & 1 deletion client/collection.go
@@ -218,7 +218,10 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error {
if field.IsPartitionKey {
hasPartitionKey = true
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
if field.DataType == entity.FieldTypeFloatVector ||
field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeFloat16Vector {
vectors++
}
}
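
A hedged sketch of declaring one of the new 16-bit vector fields through the schema builder used elsewhere in this change set; the collection name, field names, and dimension are hypothetical, and it assumes the usual "github.com/milvus-io/milvus-sdk-go/v2/entity" import.

// buildFloat16Schema is illustrative only: it shows that a Float16Vector (or
// BFloat16Vector) field is declared like any other vector field and now passes
// validateSchema's vector-field check.
func buildFloat16Schema() *entity.Schema {
	return entity.NewSchema().WithName("fp16_demo").
		WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
		WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloat16Vector).WithDim(128))
}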
14 changes: 0 additions & 14 deletions client/data.go
@@ -32,7 +32,6 @@ const (
limitKey = `limit`
ignoreGrowingKey = `ignore_growing`
forTuningKey = `for_tuning`
groupByKey = `group_by_field`
)

// Search with bool expression
@@ -75,23 +74,12 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s
results := resp.GetResults()
offset := 0
fieldDataList := results.GetFieldsData()
gb := results.GetGroupByFieldValue()
var gbc entity.Column
if gb != nil {
gbc, err = entity.FieldDataColumn(gb, 0, -1)
if err != nil {
return nil, err
}
}
for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := SearchResult{
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
}
if gbc != nil {
entry.GroupByValue, _ = gbc.Get(i)
}
// parse result set if current nq is not empty
if rc > 0 {
entry.IDs, entry.Err = entity.IDColumns(results.GetIds(), offset, offset+rc)
@@ -347,9 +335,7 @@ func prepareSearchRequest(collName string, partitions []string,
"round_decimal": "-1",
ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing),
offsetKey: fmt.Sprintf("%d", opt.Offset),
groupByKey: opt.GroupByField,
})

req := &milvuspb.SearchRequest{
DbName: "",
CollectionName: collName,
41 changes: 0 additions & 41 deletions client/data_test.go
@@ -344,7 +344,6 @@ func (s *SearchSuite) SetupSuite() {

s.sch = entity.NewSchema().WithName(testCollectionName).
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
WithField(entity.NewField().WithName("Attr").WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim))
s.schDynamic = entity.NewSchema().WithName(testCollectionName).WithDynamicFieldEnabled(true).
WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
@@ -532,46 +531,6 @@ func (s *SearchSuite) TestSearchSuccess() {
s.NoError(err)
s.Equal("abc", str)
})

s.Run("group_by", func() {
defer s.resetMock()
s.setupDescribeCollection(testCollectionName, s.sch)
s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")).
Run(func(_ context.Context, req *milvuspb.SearchRequest) {
s.Equal(testCollectionName, req.GetCollectionName())
s.Equal(expr, req.GetDsl())
s.Equal(commonpb.DslType_BoolExprV1, req.GetDslType())
s.ElementsMatch([]string{"ID"}, req.GetOutputFields())
s.ElementsMatch([]string{partName}, req.GetPartitionNames())
}).
Return(&milvuspb.SearchResults{
Status: getSuccessStatus(),
Results: &schemapb.SearchResultData{
NumQueries: 1,
TopK: 10,
FieldsData: []*schemapb.FieldData{
s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
},
},
},
Scores: make([]float32, 10),
Topks: []int64{10},
GroupByFieldValue: s.getInt64FieldData("Attr", []int64{10}),
},
}, nil)

r, err := c.Search(ctx, testCollectionName, []string{partName}, expr, []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])},
testVectorField, entity.L2, 10, sp, WithIgnoreGrowing(), WithForTuning(), WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(10000000000), WithGroupByField("Attr"))
s.NoError(err)
s.Require().Equal(1, len(r))
result := r[0]
s.Require().NotNil(result.Fields.GetColumn("ID"))
})
}

func TestSearch(t *testing.T) {
8 changes: 0 additions & 8 deletions client/options.go
@@ -132,8 +132,6 @@ type SearchQueryOption struct {

IgnoreGrowing bool
ForTuning bool

GroupByField string
}

// SearchQueryOptionFunc is a function which modifies SearchOption
@@ -165,12 +163,6 @@ func WithLimit(limit int64) SearchQueryOptionFunc {
}
}

func WithGroupByField(groupByField string) SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
option.GroupByField = groupByField
}
}

// WithSearchQueryConsistencyLevel specifies consistency level
func WithSearchQueryConsistencyLevel(cl entity.ConsistencyLevel) SearchQueryOptionFunc {
return func(option *SearchQueryOption) {
11 changes: 5 additions & 6 deletions client/results.go
@@ -7,12 +7,11 @@ import "github.com/milvus-io/milvus-sdk-go/v2/entity"
// Fields contains the data of `outputFields` specified, or all columns if none
// Scores is the distance between the current record's vector and the search target vector
type SearchResult struct {
ResultCount int // the returning entry count
GroupByValue interface{}
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields ResultSet //[]entity.Column // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
ResultCount int // the returning entry count
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Fields ResultSet //[]entity.Column // output field data
Scores []float32 // distance to the target vector
Err error // search error if any
}

// ResultSet is an alias type for column slice.
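
With GroupByValue removed, a result entry carries only ResultCount, IDs, Fields, Scores, and Err. A minimal consumption sketch, assuming the usual "github.com/milvus-io/milvus-sdk-go/v2/client" import plus fmt and log; variable names are illustrative.

// printResults walks the per-query entries returned by client.Search.
func printResults(results []client.SearchResult) {
	for _, res := range results {
		if res.Err != nil {
			log.Fatal(res.Err)
		}
		for i := 0; i < res.ResultCount; i++ {
			id, _ := res.IDs.Get(i) // primary key column parsed by IDColumns
			fmt.Printf("id=%v score=%f\n", id, res.Scores[i])
		}
	}
}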
101 changes: 101 additions & 0 deletions entity/columns.go
@@ -88,6 +88,40 @@ func (fv FloatVector) Serialize() []byte {
return data
}

// Float16Vector float16 vector wrapper.
type Float16Vector []byte

// Dim returns vector dimension.
func (fv Float16Vector) Dim() int {
return len(fv) / 2
}

// FieldType returns corresponding field type.
func (fv Float16Vector) FieldType() FieldType {
return FieldTypeFloat16Vector
}

func (fv Float16Vector) Serialize() []byte {
return fv
}

// BFloat16Vector bfloat16 vector wrapper.
type BFloat16Vector []byte

// Dim returns vector dimension.
func (fv BFloat16Vector) Dim() int {
return len(fv) / 2
}

// FieldType returns corresponding field type.
func (fv BFloat16Vector) FieldType() FieldType {
return FieldTypeBFloat16Vector
}

func (fv BFloat16Vector) Serialize() []byte {
return fv
}

// BinaryVector []byte vector wrapper
type BinaryVector []byte
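
Both new types store two bytes per element, which is why Dim() is len/2. A hedged illustration of producing a BFloat16Vector from float32 data, assuming little-endian byte pairs (the commit does not spell out the byte order) and the standard math import; the helper name is hypothetical and would sit alongside these types in the entity package.

// bfloat16FromFloat32 truncates each float32 to bfloat16 (the upper 16 bits of
// its IEEE 754 bit pattern, no rounding) and packs the values as little-endian
// byte pairs, so the returned vector's Dim() equals len(src).
func bfloat16FromFloat32(src []float32) BFloat16Vector {
	buf := make([]byte, 0, len(src)*2)
	for _, f := range src {
		bits := uint16(math.Float32bits(f) >> 16)
		buf = append(buf, byte(bits), byte(bits>>8))
	}
	return BFloat16Vector(buf)
}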

@@ -306,6 +340,43 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil

case schema.DataType_Float16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schema.VectorField_Float16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Float16Vector
dim := int(vectors.GetDim())
if end < 0 {
end = int(len(data) / dim)
}
vector := make([][]byte, 0, end-begin)
for i := begin; i < end; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil

case schema.DataType_BFloat16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Bfloat16Vector
dim := int(vectors.GetDim())
if end < 0 {
end = int(len(data) / dim)
}
vector := make([][]byte, 0, end-begin) // shall not have remainder
for i := begin; i < end; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
default:
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
}
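
Going the other way, each returned row is a []byte holding two bytes per element. A hedged sketch of expanding one bfloat16 row back into float32 values, under the same little-endian assumption as above; the helper is hypothetical and not part of this commit.

// bfloat16RowToFloat32 shifts each little-endian byte pair into the upper
// 16 bits of a float32 bit pattern, the inverse of the truncation above.
func bfloat16RowToFloat32(row []byte) []float32 {
	out := make([]float32, 0, len(row)/2)
	for i := 0; i+1 < len(row); i += 2 {
		bits := uint32(row[i]) | uint32(row[i+1])<<8
		out = append(out, math.Float32frombits(bits<<16))
	}
	return out
}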
@@ -447,6 +518,36 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) {
vector = append(vector, v)
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
case schema.DataType_Float16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schema.VectorField_Float16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Float16Vector
dim := int(vectors.GetDim())
vector := make([][]byte, 0, len(data)/dim) // shall not have remainder
for i := 0; i < len(data)/dim; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
case schema.DataType_BFloat16Vector:
vectors := fd.GetVectors()
x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector)
if !ok {
return nil, errFieldDataTypeNotMatch
}
data := x.Bfloat16Vector
dim := int(vectors.GetDim())
vector := make([][]byte, 0, len(data)/dim) // shall not have remainder
for i := 0; i < len(data)/dim; i++ {
v := make([]byte, dim)
copy(v, data[i*dim:(i+1)*dim])
vector = append(vector, v)
}
return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
default:
return nil, errors.New("unsupported data type")
}
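
The NewColumnFloat16Vector and NewColumnBFloat16Vector constructors used above take a field name, the dimension, and one byte slice per row, so inserting 16-bit vectors mirrors any other column insert. A hedged sketch with hypothetical names, assuming cli was built with client.NewClient and rows holds one encoded vector per entity; imports context, client, and entity are assumed.

// insertFP16 is illustrative only; IDs and error handling are simplified.
func insertFP16(ctx context.Context, cli client.Client, dim int, rows [][]byte) error {
	ids := make([]int64, len(rows))
	for i := range ids {
		ids[i] = int64(i) // hypothetical primary keys
	}
	_, err := cli.Insert(ctx, "fp16_demo", "", // collection from the schema sketch above
		entity.NewColumnInt64("ID", ids),
		entity.NewColumnFloat16Vector("vector", dim, rows))
	return err
}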
(The diff for the remaining changed files is not shown.)
