diff --git a/client/client_mock_test.go b/client/client_mock_test.go index 45cd851c..8bcdabe2 100644 --- a/client/client_mock_test.go +++ b/client/client_mock_test.go @@ -1017,8 +1017,8 @@ func (m *MockServer) Upsert(ctx context.Context, req *milvuspb.UpsertRequest) (* return &milvuspb.MutationResult{Status: s}, err } -func (m *MockServer) HybridSearch(ctx context.Context, req *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) { - f := m.GetInjection(MSearchV2) +func (m *MockServer) SearchV2(ctx context.Context, req *milvuspb.SearchRequestV2) (*milvuspb.SearchResults, error) { + f := m.GetInjection(MSearchV2) if f != nil { r, err := f(ctx, req) return r.(*milvuspb.SearchResults), err diff --git a/client/collection.go b/client/collection.go index 748d7015..8f53d13f 100644 --- a/client/collection.go +++ b/client/collection.go @@ -218,7 +218,10 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error { if field.IsPartitionKey { hasPartitionKey = true } - if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector { + if field.DataType == entity.FieldTypeFloatVector || + field.DataType == entity.FieldTypeBinaryVector || + field.DataType == entity.FieldTypeBFloat16Vector || + field.DataType == entity.FieldTypeFloat16Vector { vectors++ } } diff --git a/client/data.go b/client/data.go index 9a85c43f..d40c45df 100644 --- a/client/data.go +++ b/client/data.go @@ -32,7 +32,6 @@ const ( limitKey = `limit` ignoreGrowingKey = `ignore_growing` forTuningKey = `for_tuning` - groupByKey = `group_by_field` ) // Search with bool expression @@ -75,23 +74,12 @@ func (c *GrpcClient) Search(ctx context.Context, collName string, partitions []s results := resp.GetResults() offset := 0 fieldDataList := results.GetFieldsData() - gb := results.GetGroupByFieldValue() - var gbc entity.Column - if gb != nil { - gbc, err = entity.FieldDataColumn(gb, 0, -1) - if err != nil { - return nil, err - } - } for i := 0; i < int(results.GetNumQueries()); i++ { rc := int(results.GetTopks()[i]) // result entry count for current query entry := SearchResult{ ResultCount: rc, Scores: results.GetScores()[offset : offset+rc], } - if gbc != nil { - entry.GroupByValue, _ = gbc.Get(i) - } // parse result set if current nq is not empty if rc > 0 { entry.IDs, entry.Err = entity.IDColumns(results.GetIds(), offset, offset+rc) @@ -347,9 +335,7 @@ func prepareSearchRequest(collName string, partitions []string, "round_decimal": "-1", ignoreGrowingKey: strconv.FormatBool(opt.IgnoreGrowing), offsetKey: fmt.Sprintf("%d", opt.Offset), - groupByKey: opt.GroupByField, }) - req := &milvuspb.SearchRequest{ DbName: "", CollectionName: collName, diff --git a/client/data_test.go b/client/data_test.go index 3797f7bf..8e2748d6 100644 --- a/client/data_test.go +++ b/client/data_test.go @@ -344,7 +344,6 @@ func (s *SearchSuite) SetupSuite() { s.sch = entity.NewSchema().WithName(testCollectionName). WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)). - WithField(entity.NewField().WithName("Attr").WithDataType(entity.FieldTypeInt64)). WithField(entity.NewField().WithName("vector").WithDataType(entity.FieldTypeFloatVector).WithDim(testVectorDim)) s.schDynamic = entity.NewSchema().WithName(testCollectionName).WithDynamicFieldEnabled(true). WithField(entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)).
@@ -532,46 +531,6 @@ func (s *SearchSuite) TestSearchSuccess() { s.NoError(err) s.Equal("abc", str) }) - - s.Run("group_by", func() { - defer s.resetMock() - s.setupDescribeCollection(testCollectionName, s.sch) - s.mock.EXPECT().Search(mock.Anything, mock.AnythingOfType("*milvuspb.SearchRequest")). - Run(func(_ context.Context, req *milvuspb.SearchRequest) { - s.Equal(testCollectionName, req.GetCollectionName()) - s.Equal(expr, req.GetDsl()) - s.Equal(commonpb.DslType_BoolExprV1, req.GetDslType()) - s.ElementsMatch([]string{"ID"}, req.GetOutputFields()) - s.ElementsMatch([]string{partName}, req.GetPartitionNames()) - }). - Return(&milvuspb.SearchResults{ - Status: getSuccessStatus(), - Results: &schemapb.SearchResultData{ - NumQueries: 1, - TopK: 10, - FieldsData: []*schemapb.FieldData{ - s.getInt64FieldData("ID", []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), - }, - Ids: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - }, - }, - }, - Scores: make([]float32, 10), - Topks: []int64{10}, - GroupByFieldValue: s.getInt64FieldData("Attr", []int64{10}), - }, - }, nil) - - r, err := c.Search(ctx, testCollectionName, []string{partName}, expr, []string{"ID"}, []entity.Vector{entity.FloatVector(vectors[0])}, - testVectorField, entity.L2, 10, sp, WithIgnoreGrowing(), WithForTuning(), WithSearchQueryConsistencyLevel(entity.ClCustomized), WithGuaranteeTimestamp(10000000000), WithGroupByField("Attr")) - s.NoError(err) - s.Require().Equal(1, len(r)) - result := r[0] - s.Require().NotNil(result.Fields.GetColumn("ID")) - }) } func TestSearch(t *testing.T) { diff --git a/client/options.go b/client/options.go index 6a84e6b9..3146cb79 100644 --- a/client/options.go +++ b/client/options.go @@ -132,8 +132,6 @@ type SearchQueryOption struct { IgnoreGrowing bool ForTuning bool - - GroupByField string } // SearchQueryOptionFunc is a function which modifies SearchOption @@ -165,12 +163,6 @@ func WithLimit(limit int64) SearchQueryOptionFunc { } } -func WithGroupByField(groupByField string) SearchQueryOptionFunc { - return func(option *SearchQueryOption) { - option.GroupByField = groupByField - } -} - // WithSearchQueryConsistencyLevel specifies consistency level func WithSearchQueryConsistencyLevel(cl entity.ConsistencyLevel) SearchQueryOptionFunc { return func(option *SearchQueryOption) { diff --git a/client/results.go b/client/results.go index 75ecddfa..51c7c521 100644 --- a/client/results.go +++ b/client/results.go @@ -7,12 +7,11 @@ import "github.com/milvus-io/milvus-sdk-go/v2/entity" // Fields contains the data of `outputFieleds` specified or all columns if non // Scores is actually the distance between the vector current record contains and the search target vector type SearchResult struct { - ResultCount int // the returning entry count - GroupByValue interface{} - IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API - Fields ResultSet //[]entity.Column // output field data - Scores []float32 // distance to the target vector - Err error // search error if any + ResultCount int // the returning entry count + IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API + Fields ResultSet //[]entity.Column // output field data + Scores []float32 // distance to the target vector + Err error // search error if any } // ResultSet is an alias type for column slice. 
diff --git a/entity/columns.go b/entity/columns.go index af35632c..cc3e4b90 100644 --- a/entity/columns.go +++ b/entity/columns.go @@ -88,6 +88,40 @@ func (fv FloatVector) Serialize() []byte { return data } +// Float16Vector float16 vector wrapper. +type Float16Vector []byte + +// Dim returns vector dimension. +func (fv Float16Vector) Dim() int { + return len(fv) / 2 +} + +// FieldType returns corresponding field type. +func (fv Float16Vector) FieldType() FieldType { + return FieldTypeFloat16Vector +} + +func (fv Float16Vector) Serialize() []byte { + return fv +} + +// BFloat16Vector bfloat16 vector wrapper. +type BFloat16Vector []byte + +// Dim returns vector dimension. +func (fv BFloat16Vector) Dim() int { + return len(fv) / 2 +} + +// FieldType returns corresponding field type. +func (fv BFloat16Vector) FieldType() FieldType { + return FieldTypeBFloat16Vector +} + +func (fv BFloat16Vector) Serialize() []byte { + return fv +} + // BinaryVector []byte vector wrapper type BinaryVector []byte @@ -306,6 +340,43 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) { } return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_Float16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Float16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Float16Vector + dim := int(vectors.GetDim()) + if end < 0 { + end = int(len(data) / dim / 2) + } + vector := make([][]byte, 0, end-begin) + for i := begin; i < end; i++ { + v := make([]byte, dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil + + case schema.DataType_BFloat16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Bfloat16Vector + dim := int(vectors.GetDim()) + if end < 0 { + end = int(len(data) / dim / 2) + } + vector := make([][]byte, 0, end-begin) // shall not have remainder + for i := begin; i < end; i++ { + v := make([]byte, dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil default: return nil, fmt.Errorf("unsupported data type %s", fd.GetType()) } @@ -447,6 +518,36 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) { vector = append(vector, v) } return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_Float16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Float16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Float16Vector + dim := int(vectors.GetDim()) + vector := make([][]byte, 0, len(data)/dim/2) // shall not have remainder + for i := 0; i < len(data)/dim/2; i++ { + v := make([]byte, dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_BFloat16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Bfloat16Vector + dim := int(vectors.GetDim()) + vector := make([][]byte, 0, len(data)/dim/2) // shall not have remainder + for i := 0; i < len(data)/dim/2; i++ { + v := make([]byte, dim*2) + copy(v, data[i*dim*2:(i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil default: return nil,
errors.New("unsupported data type") } diff --git a/entity/columns_scalar_gen.go b/entity/columns_scalar_gen.go index 67d91f61..ac65a1b2 100755 --- a/entity/columns_scalar_gen.go +++ b/entity/columns_scalar_gen.go @@ -1,7 +1,7 @@ // Code generated by go generate; DO NOT EDIT -// This file is generated by go generate +// This file is generated by go generate -package entity +package entity import ( "errors" @@ -44,7 +44,7 @@ func (c *ColumnBool) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnBool) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Bool, + Type: schema.DataType_Bool, FieldName: c.name, } data := make([]bool, 0, c.Len()) @@ -74,7 +74,7 @@ func (c *ColumnBool) ValueByIdx(idx int) (bool, error) { } // AppendValue append value into column -func (c *ColumnBool) AppendValue(i interface{}) error { +func(c *ColumnBool) AppendValue(i interface{}) error { v, ok := i.(bool) if !ok { return fmt.Errorf("invalid type, expected bool, got %T", i) @@ -91,8 +91,8 @@ func (c *ColumnBool) Data() []bool { // NewColumnBool auto generated constructor func NewColumnBool(name string, values []bool) *ColumnBool { - return &ColumnBool{ - name: name, + return &ColumnBool { + name: name, values: values, } } @@ -131,7 +131,7 @@ func (c *ColumnInt8) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt8) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int8, + Type: schema.DataType_Int8, FieldName: c.name, } data := make([]int32, 0, c.Len()) @@ -161,7 +161,7 @@ func (c *ColumnInt8) ValueByIdx(idx int) (int8, error) { } // AppendValue append value into column -func (c *ColumnInt8) AppendValue(i interface{}) error { +func(c *ColumnInt8) AppendValue(i interface{}) error { v, ok := i.(int8) if !ok { return fmt.Errorf("invalid type, expected int8, got %T", i) @@ -178,8 +178,8 @@ func (c *ColumnInt8) Data() []int8 { // NewColumnInt8 auto generated constructor func NewColumnInt8(name string, values []int8) *ColumnInt8 { - return &ColumnInt8{ - name: name, + return &ColumnInt8 { + name: name, values: values, } } @@ -218,7 +218,7 @@ func (c *ColumnInt16) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt16) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int16, + Type: schema.DataType_Int16, FieldName: c.name, } data := make([]int32, 0, c.Len()) @@ -248,7 +248,7 @@ func (c *ColumnInt16) ValueByIdx(idx int) (int16, error) { } // AppendValue append value into column -func (c *ColumnInt16) AppendValue(i interface{}) error { +func(c *ColumnInt16) AppendValue(i interface{}) error { v, ok := i.(int16) if !ok { return fmt.Errorf("invalid type, expected int16, got %T", i) @@ -265,8 +265,8 @@ func (c *ColumnInt16) Data() []int16 { // NewColumnInt16 auto generated constructor func NewColumnInt16(name string, values []int16) *ColumnInt16 { - return &ColumnInt16{ - name: name, + return &ColumnInt16 { + name: name, values: values, } } @@ -305,7 +305,7 @@ func (c *ColumnInt32) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt32) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int32, + Type: schema.DataType_Int32, FieldName: c.name, } data := make([]int32, 0, c.Len()) @@ -335,7 +335,7 @@ func (c *ColumnInt32) ValueByIdx(idx int) (int32, error) { } // 
AppendValue append value into column -func (c *ColumnInt32) AppendValue(i interface{}) error { +func(c *ColumnInt32) AppendValue(i interface{}) error { v, ok := i.(int32) if !ok { return fmt.Errorf("invalid type, expected int32, got %T", i) @@ -352,8 +352,8 @@ func (c *ColumnInt32) Data() []int32 { // NewColumnInt32 auto generated constructor func NewColumnInt32(name string, values []int32) *ColumnInt32 { - return &ColumnInt32{ - name: name, + return &ColumnInt32 { + name: name, values: values, } } @@ -392,7 +392,7 @@ func (c *ColumnInt64) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt64) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int64, + Type: schema.DataType_Int64, FieldName: c.name, } data := make([]int64, 0, c.Len()) @@ -422,7 +422,7 @@ func (c *ColumnInt64) ValueByIdx(idx int) (int64, error) { } // AppendValue append value into column -func (c *ColumnInt64) AppendValue(i interface{}) error { +func(c *ColumnInt64) AppendValue(i interface{}) error { v, ok := i.(int64) if !ok { return fmt.Errorf("invalid type, expected int64, got %T", i) @@ -439,8 +439,8 @@ func (c *ColumnInt64) Data() []int64 { // NewColumnInt64 auto generated constructor func NewColumnInt64(name string, values []int64) *ColumnInt64 { - return &ColumnInt64{ - name: name, + return &ColumnInt64 { + name: name, values: values, } } @@ -479,7 +479,7 @@ func (c *ColumnFloat) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnFloat) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Float, + Type: schema.DataType_Float, FieldName: c.name, } data := make([]float32, 0, c.Len()) @@ -509,7 +509,7 @@ func (c *ColumnFloat) ValueByIdx(idx int) (float32, error) { } // AppendValue append value into column -func (c *ColumnFloat) AppendValue(i interface{}) error { +func(c *ColumnFloat) AppendValue(i interface{}) error { v, ok := i.(float32) if !ok { return fmt.Errorf("invalid type, expected float32, got %T", i) @@ -526,8 +526,8 @@ func (c *ColumnFloat) Data() []float32 { // NewColumnFloat auto generated constructor func NewColumnFloat(name string, values []float32) *ColumnFloat { - return &ColumnFloat{ - name: name, + return &ColumnFloat { + name: name, values: values, } } @@ -566,7 +566,7 @@ func (c *ColumnDouble) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnDouble) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Double, + Type: schema.DataType_Double, FieldName: c.name, } data := make([]float64, 0, c.Len()) @@ -596,7 +596,7 @@ func (c *ColumnDouble) ValueByIdx(idx int) (float64, error) { } // AppendValue append value into column -func (c *ColumnDouble) AppendValue(i interface{}) error { +func(c *ColumnDouble) AppendValue(i interface{}) error { v, ok := i.(float64) if !ok { return fmt.Errorf("invalid type, expected float64, got %T", i) @@ -613,8 +613,8 @@ func (c *ColumnDouble) Data() []float64 { // NewColumnDouble auto generated constructor func NewColumnDouble(name string, values []float64) *ColumnDouble { - return &ColumnDouble{ - name: name, + return &ColumnDouble { + name: name, values: values, } } @@ -653,7 +653,7 @@ func (c *ColumnString) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnString) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_String, 
+ Type: schema.DataType_String, FieldName: c.name, } data := make([]string, 0, c.Len()) @@ -683,7 +683,7 @@ func (c *ColumnString) ValueByIdx(idx int) (string, error) { } // AppendValue append value into column -func (c *ColumnString) AppendValue(i interface{}) error { +func(c *ColumnString) AppendValue(i interface{}) error { v, ok := i.(string) if !ok { return fmt.Errorf("invalid type, expected string, got %T", i) @@ -700,8 +700,9 @@ func (c *ColumnString) Data() []string { // NewColumnString auto generated constructor func NewColumnString(name string, values []string) *ColumnString { - return &ColumnString{ - name: name, + return &ColumnString { + name: name, values: values, } } + diff --git a/entity/columns_scalar_gen_test.go b/entity/columns_scalar_gen_test.go index 5366e405..70e1408c 100755 --- a/entity/columns_scalar_gen_test.go +++ b/entity/columns_scalar_gen_test.go @@ -1,7 +1,7 @@ // Code generated by go generate; DO NOT EDIT -// This file is generated by go generated +// This file is generated by go generated -package entity +package entity import ( "fmt" @@ -9,8 +9,8 @@ import ( "testing" "time" - schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" + schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func TestColumnBool(t *testing.T) { @@ -60,7 +60,7 @@ func TestFieldDataBoolColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Bool_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Bool, + Type: schema.DataType_Bool, FieldName: name, } @@ -77,7 +77,7 @@ func TestFieldDataBoolColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeBool, column.Type()) @@ -92,12 +92,13 @@ func TestFieldDataBoolColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -111,7 +112,7 @@ func TestFieldDataBoolColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeBool, column.Type()) @@ -165,7 +166,7 @@ func TestFieldDataInt8Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int8_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int8, + Type: schema.DataType_Int8, FieldName: name, } @@ -182,7 +183,7 @@ func TestFieldDataInt8Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt8, column.Type()) @@ -197,12 +198,13 @@ func TestFieldDataInt8Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -216,7 +218,7 @@ func TestFieldDataInt8Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt8, column.Type()) @@ 
-270,7 +272,7 @@ func TestFieldDataInt16Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int16_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int16, + Type: schema.DataType_Int16, FieldName: name, } @@ -287,7 +289,7 @@ func TestFieldDataInt16Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt16, column.Type()) @@ -302,12 +304,13 @@ func TestFieldDataInt16Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -321,7 +324,7 @@ func TestFieldDataInt16Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt16, column.Type()) @@ -375,7 +378,7 @@ func TestFieldDataInt32Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int32_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int32, + Type: schema.DataType_Int32, FieldName: name, } @@ -392,7 +395,7 @@ func TestFieldDataInt32Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt32, column.Type()) @@ -407,12 +410,13 @@ func TestFieldDataInt32Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -426,7 +430,7 @@ func TestFieldDataInt32Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt32, column.Type()) @@ -480,7 +484,7 @@ func TestFieldDataInt64Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int64_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int64, + Type: schema.DataType_Int64, FieldName: name, } @@ -497,7 +501,7 @@ func TestFieldDataInt64Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt64, column.Type()) @@ -512,12 +516,13 @@ func TestFieldDataInt64Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -531,7 +536,7 @@ func TestFieldDataInt64Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt64, column.Type()) @@ -585,7 +590,7 @@ func TestFieldDataFloatColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Float_%d", rand.Int()) fd := &schema.FieldData{ - Type: 
schema.DataType_Float, + Type: schema.DataType_Float, FieldName: name, } @@ -602,7 +607,7 @@ func TestFieldDataFloatColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeFloat, column.Type()) @@ -617,12 +622,13 @@ func TestFieldDataFloatColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -636,7 +642,7 @@ func TestFieldDataFloatColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeFloat, column.Type()) @@ -690,7 +696,7 @@ func TestFieldDataDoubleColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Double_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Double, + Type: schema.DataType_Double, FieldName: name, } @@ -707,7 +713,7 @@ func TestFieldDataDoubleColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeDouble, column.Type()) @@ -722,12 +728,13 @@ func TestFieldDataDoubleColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -741,7 +748,7 @@ func TestFieldDataDoubleColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeDouble, column.Type()) @@ -795,7 +802,7 @@ func TestFieldDataStringColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_String_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_String, + Type: schema.DataType_String, FieldName: name, } @@ -812,7 +819,7 @@ func TestFieldDataStringColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeString, column.Type()) @@ -827,12 +834,13 @@ func TestFieldDataStringColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -846,9 +854,10 @@ func TestFieldDataStringColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeString, column.Type()) }) } + diff --git a/entity/columns_vector_gen.go b/entity/columns_vector_gen.go index 87a874c7..f70e826b 100755 --- a/entity/columns_vector_gen.go +++ b/entity/columns_vector_gen.go @@ -10,6 +10,7 @@ import ( "github.com/cockroachdb/errors" ) + // ColumnBinaryVector generated columns type for BinaryVector type ColumnBinaryVector struct { 
ColumnBase @@ -29,7 +30,7 @@ func (c *ColumnBinaryVector) Type() FieldType { } // Len returns column data length -func (c *ColumnBinaryVector) Len() int { +func (c * ColumnBinaryVector) Len() int { return len(c.values) } @@ -47,7 +48,7 @@ func (c *ColumnBinaryVector) Get(idx int) (interface{}, error) { } // AppendValue append value into column -func (c *ColumnBinaryVector) AppendValue(i interface{}) error { +func(c *ColumnBinaryVector) AppendValue(i interface{}) error { v, ok := i.([]byte) if !ok { return fmt.Errorf("invalid type, expected []byte, got %T", i) @@ -65,11 +66,11 @@ func (c *ColumnBinaryVector) Data() [][]byte { // FieldData return column data mapped to schema.FieldData func (c *ColumnBinaryVector) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_BinaryVector, + Type: schema.DataType_BinaryVector, FieldName: c.name, } - data := make([]byte, 0, len(c.values)*c.dim) + data := make([]byte, 0, len(c.values)* c.dim) for _, vector := range c.values { data = append(data, vector...) @@ -78,9 +79,10 @@ func (c *ColumnBinaryVector) FieldData() *schema.FieldData { fd.Field = &schema.FieldData_Vectors{ Vectors: &schema.VectorField{ Dim: int64(c.dim), - - Data: &schema.VectorField_BinaryVector{ + + Data: &schema.VectorField_BinaryVector{ BinaryVector: data, + }, }, } @@ -89,7 +91,7 @@ func (c *ColumnBinaryVector) FieldData() *schema.FieldData { // NewColumnBinaryVector auto generated constructor func NewColumnBinaryVector(name string, dim int, values [][]byte) *ColumnBinaryVector { - return &ColumnBinaryVector{ + return &ColumnBinaryVector { name: name, dim: dim, values: values, @@ -115,7 +117,7 @@ func (c *ColumnFloatVector) Type() FieldType { } // Len returns column data length -func (c *ColumnFloatVector) Len() int { +func (c * ColumnFloatVector) Len() int { return len(c.values) } @@ -133,7 +135,7 @@ func (c *ColumnFloatVector) Get(idx int) (interface{}, error) { } // AppendValue append value into column -func (c *ColumnFloatVector) AppendValue(i interface{}) error { +func(c *ColumnFloatVector) AppendValue(i interface{}) error { v, ok := i.([]float32) if !ok { return fmt.Errorf("invalid type, expected []float32, got %T", i) @@ -151,11 +153,11 @@ func (c *ColumnFloatVector) Data() [][]float32 { // FieldData return column data mapped to schema.FieldData func (c *ColumnFloatVector) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_FloatVector, + Type: schema.DataType_FloatVector, FieldName: c.name, } - data := make([]float32, 0, len(c.values)*c.dim) + data := make([]float32, 0, len(c.values)* c.dim) for _, vector := range c.values { data = append(data, vector...) 
@@ -164,11 +166,12 @@ func (c *ColumnFloatVector) FieldData() *schema.FieldData { fd.Field = &schema.FieldData_Vectors{ Vectors: &schema.VectorField{ Dim: int64(c.dim), - + Data: &schema.VectorField_FloatVector{ FloatVector: &schema.FloatArray{ Data: data, }, + }, }, } @@ -177,9 +180,184 @@ func (c *ColumnFloatVector) FieldData() *schema.FieldData { // NewColumnFloatVector auto generated constructor func NewColumnFloatVector(name string, dim int, values [][]float32) *ColumnFloatVector { - return &ColumnFloatVector{ + return &ColumnFloatVector { + name: name, + dim: dim, + values: values, + } +} + +// ColumnFloat16Vector generated columns type for Float16Vector +type ColumnFloat16Vector struct { + ColumnBase + name string + dim int + values [][]byte +} + +// Name returns column name +func (c *ColumnFloat16Vector) Name() string { + return c.name +} + +// Type returns column FieldType +func (c *ColumnFloat16Vector) Type() FieldType { + return FieldTypeFloat16Vector +} + +// Len returns column data length +func (c * ColumnFloat16Vector) Len() int { + return len(c.values) +} + +// Dim returns vector dimension +func (c *ColumnFloat16Vector) Dim() int { + return c.dim +} + +// Get returns values at index as interface{}. +func (c *ColumnFloat16Vector) Get(idx int) (interface{}, error) { + if idx < 0 || idx >= c.Len() { + return nil, errors.New("index out of range") + } + return c.values[idx], nil +} + +// AppendValue append value into column +func(c *ColumnFloat16Vector) AppendValue(i interface{}) error { + v, ok := i.([]byte) + if !ok { + return fmt.Errorf("invalid type, expected []byte, got %T", i) + } + c.values = append(c.values, v) + + return nil +} + +// Data returns column data +func (c *ColumnFloat16Vector) Data() [][]byte { + return c.values +} + +// FieldData return column data mapped to schema.FieldData +func (c *ColumnFloat16Vector) FieldData() *schema.FieldData { + fd := &schema.FieldData{ + Type: schema.DataType_Float16Vector, + FieldName: c.name, + } + + data := make([]byte, 0, len(c.values)* c.dim) + + for _, vector := range c.values { + data = append(data, vector...) + } + + fd.Field = &schema.FieldData_Vectors{ + Vectors: &schema.VectorField{ + Dim: int64(c.dim), + + Data: &schema.VectorField_Float16Vector{ + Float16Vector: data, + + }, + }, + } + return fd +} + +// NewColumnFloat16Vector auto generated constructor +func NewColumnFloat16Vector(name string, dim int, values [][]byte) *ColumnFloat16Vector { + return &ColumnFloat16Vector { name: name, dim: dim, values: values, } } + +// ColumnBFloat16Vector generated columns type for BFloat16Vector +type ColumnBFloat16Vector struct { + ColumnBase + name string + dim int + values [][]byte +} + +// Name returns column name +func (c *ColumnBFloat16Vector) Name() string { + return c.name +} + +// Type returns column FieldType +func (c *ColumnBFloat16Vector) Type() FieldType { + return FieldTypeBFloat16Vector +} + +// Len returns column data length +func (c * ColumnBFloat16Vector) Len() int { + return len(c.values) +} + +// Dim returns vector dimension +func (c *ColumnBFloat16Vector) Dim() int { + return c.dim +} + +// Get returns values at index as interface{}. 
+func (c *ColumnBFloat16Vector) Get(idx int) (interface{}, error) { + if idx < 0 || idx >= c.Len() { + return nil, errors.New("index out of range") + } + return c.values[idx], nil +} + +// AppendValue append value into column +func(c *ColumnBFloat16Vector) AppendValue(i interface{}) error { + v, ok := i.([]byte) + if !ok { + return fmt.Errorf("invalid type, expected []byte, got %T", i) + } + c.values = append(c.values, v) + + return nil +} + +// Data returns column data +func (c *ColumnBFloat16Vector) Data() [][]byte { + return c.values +} + +// FieldData return column data mapped to schema.FieldData +func (c *ColumnBFloat16Vector) FieldData() *schema.FieldData { + fd := &schema.FieldData{ + Type: schema.DataType_BFloat16Vector, + FieldName: c.name, + } + + data := make([]byte, 0, len(c.values)* c.dim) + + for _, vector := range c.values { + data = append(data, vector...) + } + + fd.Field = &schema.FieldData_Vectors{ + Vectors: &schema.VectorField{ + Dim: int64(c.dim), + + Data: &schema.VectorField_Bfloat16Vector{ + Bfloat16Vector: data, + + }, + }, + } + return fd +} + +// NewColumnBFloat16Vector auto generated constructor +func NewColumnBFloat16Vector(name string, dim int, values [][]byte) *ColumnBFloat16Vector { + return &ColumnBFloat16Vector { + name: name, + dim: dim, + values: values, + } +} + diff --git a/entity/columns_vector_gen_test.go b/entity/columns_vector_gen_test.go index d959e2c5..c5753adf 100755 --- a/entity/columns_vector_gen_test.go +++ b/entity/columns_vector_gen_test.go @@ -1,7 +1,7 @@ // Code generated by go generate; DO NOT EDIT -// This file is generated by go generated +// This file is generated by go generated -package entity +package entity import ( "fmt" @@ -19,16 +19,17 @@ func TestColumnBinaryVector(t *testing.T) { columnLen := 12 + rand.Intn(10) dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] - v := make([][]byte, 0, columnLen) + v := make([][]byte,0, columnLen) dlen := dim dlen /= 8 - + + for i := 0; i < columnLen; i++ { entry := make([]byte, dlen) v = append(v, entry) } column := NewColumnBinaryVector(columnName, dim, v) - + t.Run("test meta", func(t *testing.T) { ft := FieldTypeBinaryVector assert.Equal(t, "BinaryVector", ft.Name()) @@ -43,13 +44,13 @@ func TestColumnBinaryVector(t *testing.T) { assert.Equal(t, FieldTypeBinaryVector, column.Type()) assert.Equal(t, columnLen, column.Len()) assert.Equal(t, dim, column.Dim()) - assert.Equal(t, v, column.Data()) - + assert.Equal(t ,v, column.Data()) + var ev []byte err := column.AppendValue(ev) assert.Equal(t, columnLen+1, column.Len()) assert.Nil(t, err) - + err = column.AppendValue(struct{}{}) assert.Equal(t, columnLen+1, column.Len()) assert.NotNil(t, err) @@ -70,7 +71,7 @@ func TestColumnBinaryVector(t *testing.T) { Type: schema.DataType_BinaryVector, FieldName: columnName, } - _, err := FieldDataVector(fd) + _, err := FieldDataVector(fd) assert.Error(t, err) }) @@ -82,15 +83,17 @@ func TestColumnFloatVector(t *testing.T) { columnLen := 12 + rand.Intn(10) dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] - v := make([][]float32, 0, columnLen) + v := make([][]float32,0, columnLen) dlen := dim - + + + for i := 0; i < columnLen; i++ { entry := make([]float32, dlen) v = append(v, entry) } column := NewColumnFloatVector(columnName, dim, v) - + t.Run("test meta", func(t *testing.T) { ft := FieldTypeFloatVector assert.Equal(t, "FloatVector", ft.Name()) @@ -105,13 +108,13 @@ func TestColumnFloatVector(t *testing.T) { assert.Equal(t, FieldTypeFloatVector, column.Type()) assert.Equal(t, columnLen, column.Len()) 
assert.Equal(t, dim, column.Dim()) - assert.Equal(t, v, column.Data()) - + assert.Equal(t ,v, column.Data()) + var ev []float32 err := column.AppendValue(ev) assert.Equal(t, columnLen+1, column.Len()) assert.Nil(t, err) - + err = column.AppendValue(struct{}{}) assert.Equal(t, columnLen+1, column.Len()) assert.NotNil(t, err) @@ -132,8 +135,137 @@ func TestColumnFloatVector(t *testing.T) { Type: schema.DataType_FloatVector, FieldName: columnName, } - _, err := FieldDataVector(fd) + _, err := FieldDataVector(fd) assert.Error(t, err) }) } + +func TestColumnFloat16Vector(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + columnName := fmt.Sprintf("column_Float16Vector_%d", rand.Int()) + columnLen := 12 + rand.Intn(10) + dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] + + v := make([][]byte,0, columnLen) + dlen := dim + + dlen *= 2 + + for i := 0; i < columnLen; i++ { + entry := make([]byte, dlen) + v = append(v, entry) + } + column := NewColumnFloat16Vector(columnName, dim, v) + + t.Run("test meta", func(t *testing.T) { + ft := FieldTypeFloat16Vector + assert.Equal(t, "Float16Vector", ft.Name()) + assert.Equal(t, "[]byte", ft.String()) + pbName, pbType := ft.PbFieldType() + assert.Equal(t, "[]byte", pbName) + assert.Equal(t, "", pbType) + }) + + t.Run("test column attribute", func(t *testing.T) { + assert.Equal(t, columnName, column.Name()) + assert.Equal(t, FieldTypeFloat16Vector, column.Type()) + assert.Equal(t, columnLen, column.Len()) + assert.Equal(t, dim, column.Dim()) + assert.Equal(t ,v, column.Data()) + + var ev []byte + err := column.AppendValue(ev) + assert.Equal(t, columnLen+1, column.Len()) + assert.Nil(t, err) + + err = column.AppendValue(struct{}{}) + assert.Equal(t, columnLen+1, column.Len()) + assert.NotNil(t, err) + }) + + t.Run("test column field data", func(t *testing.T) { + fd := column.FieldData() + assert.NotNil(t, fd) + assert.Equal(t, fd.GetFieldName(), columnName) + + c, err := FieldDataVector(fd) + assert.NotNil(t, c) + assert.NoError(t, err) + }) + + t.Run("test column field data error", func(t *testing.T) { + fd := &schema.FieldData{ + Type: schema.DataType_Float16Vector, + FieldName: columnName, + } + _, err := FieldDataVector(fd) + assert.Error(t, err) + }) + +} + +func TestColumnBFloat16Vector(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + columnName := fmt.Sprintf("column_BFloat16Vector_%d", rand.Int()) + columnLen := 12 + rand.Intn(10) + dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] + + v := make([][]byte,0, columnLen) + dlen := dim + + dlen *= 2 + + for i := 0; i < columnLen; i++ { + entry := make([]byte, dlen) + v = append(v, entry) + } + column := NewColumnBFloat16Vector(columnName, dim, v) + + t.Run("test meta", func(t *testing.T) { + ft := FieldTypeBFloat16Vector + assert.Equal(t, "BFloat16Vector", ft.Name()) + assert.Equal(t, "[]byte", ft.String()) + pbName, pbType := ft.PbFieldType() + assert.Equal(t, "[]byte", pbName) + assert.Equal(t, "", pbType) + }) + + t.Run("test column attribute", func(t *testing.T) { + assert.Equal(t, columnName, column.Name()) + assert.Equal(t, FieldTypeBFloat16Vector, column.Type()) + assert.Equal(t, columnLen, column.Len()) + assert.Equal(t, dim, column.Dim()) + assert.Equal(t ,v, column.Data()) + + var ev []byte + err := column.AppendValue(ev) + assert.Equal(t, columnLen+1, column.Len()) + assert.Nil(t, err) + + err = column.AppendValue(struct{}{}) + assert.Equal(t, columnLen+1, column.Len()) + assert.NotNil(t, err) + }) + + t.Run("test column field data", func(t *testing.T) { + fd := column.FieldData() + 
assert.NotNil(t, fd) + assert.Equal(t, fd.GetFieldName(), columnName) + + c, err := FieldDataVector(fd) + assert.NotNil(t, c) + assert.NoError(t, err) + }) + + t.Run("test column field data error", func(t *testing.T) { + fd := &schema.FieldData{ + Type: schema.DataType_BFloat16Vector, + FieldName: columnName, + } + _, err := FieldDataVector(fd) + assert.Error(t, err) + }) + +} + diff --git a/entity/gen/gen.go b/entity/gen/gen.go index 1faef617..b7af8965 100644 --- a/entity/gen/gen.go +++ b/entity/gen/gen.go @@ -204,17 +204,19 @@ func (c *Column{{.TypeName}}) FieldData() *schema.FieldData { fd.Field = &schema.FieldData_Vectors{ Vectors: &schema.VectorField{ Dim: int64(c.dim), - {{if eq .TypeName "BinaryVector" }} - Data: &schema.VectorField_BinaryVector{ - BinaryVector: data, - }, - {{else}} - Data: &schema.VectorField_FloatVector{ + {{if eq .TypeName "FloatVector" }} + Data: &schema.VectorField_{{.TypeName}}{ FloatVector: &schema.FloatArray{ Data: data, }, - }, + {{else if eq .TypeName "BFloat16Vector"}} + Data: &schema.VectorField_Bfloat16Vector{ + Bfloat16Vector: data, + {{else}} + Data: &schema.VectorField_{{.TypeName}}{ + {{.TypeName}}: data, {{end}} + }, }, } return fd @@ -377,6 +379,7 @@ func TestColumn{{.TypeName}}(t *testing.T) { v := make([]{{.TypeDef}},0, columnLen) dlen := dim {{if eq .TypeName "BinaryVector" }}dlen /= 8{{end}} + {{if or (eq .TypeName "BFloat16Vector") (eq .TypeName "Float16Vector") }}dlen *= 2{{end}} for i := 0; i < columnLen; i++ { entry := make({{.TypeDef}}, dlen) @@ -447,6 +450,8 @@ func main() { vectorFieldTypes := []entity.FieldType{ entity.FieldTypeBinaryVector, entity.FieldTypeFloatVector, + entity.FieldTypeFloat16Vector, + entity.FieldTypeBFloat16Vector, } pf := func(ft entity.FieldType) interface{} { diff --git a/entity/schema.go b/entity/schema.go index ee7179ab..868561e6 100644 --- a/entity/schema.go +++ b/entity/schema.go @@ -366,6 +366,10 @@ func (t FieldType) Name() string { return "BinaryVector" case FieldTypeFloatVector: return "FloatVector" + case FieldTypeFloat16Vector: + return "Float16Vector" + case FieldTypeBFloat16Vector: + return "BFloat16Vector" default: return "undefined" } @@ -400,6 +404,10 @@ func (t FieldType) String() string { return "[]byte" case FieldTypeFloatVector: return "[]float32" + case FieldTypeFloat16Vector: + return "[]byte" + case FieldTypeBFloat16Vector: + return "[]byte" default: return "undefined" } @@ -432,6 +440,10 @@ func (t FieldType) PbFieldType() (string, string) { return "[]byte", "" case FieldTypeFloatVector: return "[]float32", "" + case FieldTypeFloat16Vector: + return "[]byte", "" + case FieldTypeBFloat16Vector: + return "[]byte", "" default: return "undefined", "" } @@ -467,4 +479,8 @@ const ( FieldTypeBinaryVector FieldType = 100 // FieldTypeFloatVector field type float vector FieldTypeFloatVector FieldType = 101 + // FieldTypeBinaryVector field type float16 vector + FieldTypeFloat16Vector FieldType = 102 + // FieldTypeBinaryVector field type bf16 vector + FieldTypeBFloat16Vector FieldType = 103 ) diff --git a/entity/schema_test.go b/entity/schema_test.go index 07d07de6..811a46ba 100644 --- a/entity/schema_test.go +++ b/entity/schema_test.go @@ -26,6 +26,8 @@ func TestFieldSchema(t *testing.T) { NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"), NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true), 
NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128), + NewField().WithName("fp16_field").WithDataType(FieldTypeFloat16Vector).WithDim(128), + NewField().WithName("bf16_field").WithDataType(FieldTypeBFloat16Vector).WithDim(128), /* NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true), NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1), @@ -117,6 +119,45 @@ func (s *SchemaSuite) TestBasic() { } } +func (s *SchemaSuite) TestFp16Vector() { + cases := []struct { + tag string + input *Schema + pkName string + }{ + { + "test_collection", + NewSchema().WithName("test_collection_1").WithDescription("test_collection_1 desc").WithAutoID(true). + WithField(NewField().WithName("fp16_field").WithDataType(FieldTypeFloat16Vector).WithDim(128)). + WithField(NewField().WithName("bf16_field").WithDataType(FieldTypeBFloat16Vector).WithDim(128)), + "", + }, + } + + for _, c := range cases { + s.Run(c.tag, func() { + sch := c.input + p := sch.ProtoMessage() + s.Equal(sch.CollectionName, p.GetName()) + s.Equal(sch.AutoID, p.GetAutoID()) + s.Equal(sch.Description, p.GetDescription()) + s.Equal(sch.EnableDynamicField, p.GetEnableDynamicField()) + s.Equal(len(sch.Fields), len(p.GetFields())) + + nsch := &Schema{} + nsch = nsch.ReadProto(p) + + s.Equal(sch.CollectionName, nsch.CollectionName) + s.Equal(sch.AutoID, nsch.AutoID) + s.Equal(sch.Description, nsch.Description) + s.Equal(sch.EnableDynamicField, nsch.EnableDynamicField) + s.Equal(len(sch.Fields), len(nsch.Fields)) + s.Equal(c.pkName, sch.PKFieldName()) + s.Equal(c.pkName, nsch.PKFieldName()) + }) + } +} + func TestSchema(t *testing.T) { suite.Run(t, new(SchemaSuite)) } diff --git a/examples/bfloat16/bfloat16.go b/examples/bfloat16/bfloat16.go new file mode 100644 index 00000000..5095b08b --- /dev/null +++ b/examples/bfloat16/bfloat16.go @@ -0,0 +1,155 @@ +package main + +import ( + "context" + "encoding/binary" + "fmt" + "log" + "math/rand" + "time" + "unsafe" + + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +const ( + milvusAddr = `localhost:19530` + nEntities, dim = 3000, 128 + collectionName = "query_example" + + msgFmt = "==== %s ====\n" + idCol, randomCol, embeddingCol = "ID", "random", "embeddings" +) + +func toBFloat16(f float32) []byte { + bs := make([]byte, 2) + u32 := *(*uint32)(unsafe.Pointer(&f)) + binary.LittleEndian.PutUint16(bs, uint16(u32>>16)) + return bs +} + +func main() { + ctx := context.Background() + + log.Printf(msgFmt, "start connecting to Milvus") + c, err := client.NewClient(ctx, client.Config{ + Address: milvusAddr, + }) + if err != nil { + log.Fatalf("failed to connect to milvus, err: %v", err) + } + defer c.Close() + + // delete collection if exists + has, err := c.HasCollection(ctx, collectionName) + if err != nil { + log.Fatalf("failed to check collection exists, err: %v", err) + } + if has { + c.DropCollection(ctx, collectionName) + } + + // define schema + log.Printf(msgFmt, fmt.Sprintf("create collection, `%s`", collectionName)) + schema := entity.NewSchema().WithName(collectionName). + WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(false)). + WithField(entity.NewField().WithName(randomCol).WithDataType(entity.FieldTypeDouble)). 
+ WithField(entity.NewField().WithName(embeddingCol).WithDataType(entity.FieldTypeBFloat16Vector).WithDim(dim)) + + // create collection with consistency level, which serves as the default search/query consistency level + if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber, client.WithConsistencyLevel(entity.ClBounded)); err != nil { + log.Fatalf("create collection failed, err: %v", err) + } + + log.Printf(msgFmt, "start inserting random entities") + idList, randomList := make([]int64, 0, nEntities), make([]float64, 0, nEntities) + embeddingList := make([][]byte, 0, nEntities) + + rand.Seed(time.Now().UnixNano()) + + // generate data + for i := 0; i < nEntities; i++ { + idList = append(idList, int64(i)) + } + for i := 0; i < nEntities; i++ { + randomList = append(randomList, rand.Float64()) + } + for i := 0; i < nEntities; i++ { + vec := make([]byte, 0, dim*2) + for j := 0; j < dim; j++ { + vec = append(vec, toBFloat16(rand.Float32())...) + } + embeddingList = append(embeddingList, vec) + } + idColData := entity.NewColumnInt64(idCol, idList) + randomColData := entity.NewColumnDouble(randomCol, randomList) + embeddingColData := entity.NewColumnBFloat16Vector(embeddingCol, dim, embeddingList) + + // build index + log.Printf(msgFmt, "start creating index IVF_FLAT") + idx, err := entity.NewIndexIvfFlat(entity.L2, 128) + if err != nil { + log.Fatalf("failed to create ivf flat index, err: %v", err) + } + if err := c.CreateIndex(ctx, collectionName, embeddingCol, idx, false); err != nil { + log.Fatalf("failed to create index, err: %v", err) + } + + // insert data + if _, err := c.Insert(ctx, collectionName, "", idColData, randomColData, embeddingColData); err != nil { + log.Fatalf("failed to insert random data into %s, err: %s", collectionName, err.Error()) + } + + log.Printf(msgFmt, "start loading collection") + err = c.LoadCollection(ctx, collectionName, false) + if err != nil { + log.Fatalf("failed to load collection, err: %v", err) + } + + //query + expr := "ID in [0, 1, 2]" + log.Printf(msgFmt, fmt.Sprintf("query with expr `%s`", expr)) + resultSet, err := c.Query(ctx, collectionName, nil, expr, []string{idCol, randomCol}) + if err != nil { + log.Fatalf("failed to query result, err: %v", err) + } + printResultSet(resultSet, idCol, randomCol) + + // drop collection + log.Printf(msgFmt, fmt.Sprintf("drop collection `%s`", collectionName)) + if err := c.DropCollection(ctx, collectionName); err != nil { + log.Fatalf("failed to drop collection, err: %v", err) + } +} + +func printResultSet(rs client.ResultSet, outputFields ...string) { + for _, fieldName := range outputFields { + column := rs.GetColumn(fieldName) + if column == nil { + log.Printf("column %s not exists in result set\n", fieldName) + } + switch column.Type() { + case entity.FieldTypeInt64: + var result []int64 + for i := 0; i < column.Len(); i++ { + v, err := column.GetAsInt64(i) + if err != nil { + log.Printf("column %s row %d cannot GetAsInt64, %s\n", fieldName, i, err.Error()) + } + result = append(result, v) + } + log.Printf("Column %s: value: %v\n", fieldName, result) + case entity.FieldTypeDouble: + var result []float64 + for i := 0; i < column.Len(); i++ { + v, err := column.GetAsDouble(i) + if err != nil { + log.Printf("column %s row %d cannot GetAsDouble, %s\n", fieldName, i, err.Error()) + } + result = append(result, v) + } + log.Printf("Column %s: value: %v\n", fieldName, result) + } + } +} diff --git a/examples/groupby/main.go b/examples/groupby/main.go deleted file mode 100644 index 307c1cef..00000000 
--- a/examples/groupby/main.go +++ /dev/null @@ -1,129 +0,0 @@ -package main - -import ( - "context" - "log" - "math/rand" - "time" - - "github.com/milvus-io/milvus-sdk-go/v2/client" - "github.com/milvus-io/milvus-sdk-go/v2/entity" -) - -const ( - milvusAddr = `localhost:19530` - nEntities, dim = 10000, 128 - collectionName = "hello_group_by" - - idCol, keyCol, embeddingCol = "ID", "key", "embeddings" - topK = 3 -) - -func main() { - ctx := context.Background() - - log.Println("start connecting to Milvus") - c, err := client.NewClient(ctx, client.Config{ - Address: milvusAddr, - }) - if err != nil { - log.Fatalf("failed to connect to milvus, err: %v", err) - } - defer c.Close() - - // delete collection if exists - has, err := c.HasCollection(ctx, collectionName) - if err != nil { - log.Fatalf("failed to check collection exists, err: %v", err) - } - if has { - c.DropCollection(ctx, collectionName) - } - - // create collection - log.Printf("create collection `%s`\n", collectionName) - schema := entity.NewSchema().WithName(collectionName).WithDescription("hello_partition_key is the a demo to introduce the partition key related APIs"). - WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)). - WithField(entity.NewField().WithName(keyCol).WithDataType(entity.FieldTypeInt64)). - WithField(entity.NewField().WithName(embeddingCol).WithDataType(entity.FieldTypeFloatVector).WithDim(dim)) - - if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil { // use default shard number - log.Fatalf("create collection failed, err: %v", err) - } - - var keyList []int64 - var embeddingList [][]float32 - keyList = make([]int64, 0, nEntities) - embeddingList = make([][]float32, 0, nEntities) - for i := 0; i < nEntities; i++ { - keyList = append(keyList, rand.Int63()%512) - } - for i := 0; i < nEntities; i++ { - vec := make([]float32, 0, dim) - for j := 0; j < dim; j++ { - vec = append(vec, rand.Float32()) - } - embeddingList = append(embeddingList, vec) - } - keyColData := entity.NewColumnInt64(keyCol, keyList) - embeddingColData := entity.NewColumnFloatVector(embeddingCol, dim, embeddingList) - - log.Println("start to insert data into collection") - - if _, err := c.Insert(ctx, collectionName, "", keyColData, embeddingColData); err != nil { - log.Fatalf("failed to insert random data into `%s`, err: %v", collectionName, err) - } - - log.Println("insert data done, start to flush") - - if err := c.Flush(ctx, collectionName, false); err != nil { - log.Fatalf("failed to flush data, err: %v", err) - } - log.Println("flush data done") - - // build index - log.Println("start creating index HNSW") - idx, err := entity.NewIndexHNSW(entity.L2, 16, 256) - if err != nil { - log.Fatalf("failed to create ivf flat index, err: %v", err) - } - if err := c.CreateIndex(ctx, collectionName, embeddingCol, idx, false); err != nil { - log.Fatalf("failed to create index, err: %v", err) - } - - log.Printf("build HNSW index done for collection `%s`\n", collectionName) - log.Printf("start to load collection `%s`\n", collectionName) - - // load collection - if err := c.LoadCollection(ctx, collectionName, false); err != nil { - log.Fatalf("failed to load collection, err: %v", err) - } - - log.Println("load collection done") - - vec2search := []entity.Vector{ - entity.FloatVector(embeddingList[len(embeddingList)-2]), - entity.FloatVector(embeddingList[len(embeddingList)-1]), - } - begin := time.Now() - sp, _ := entity.NewIndexHNSWSearchParam(30) - 
result, err := c.Search(ctx, collectionName, nil, "", []string{keyCol, embeddingCol}, vec2search, - embeddingCol, entity.L2, topK, sp, client.WithGroupByField(keyCol)) - if err != nil { - log.Fatalf("failed to search collection, err: %v", err) - } - - log.Printf("search `%s` done, latency %v\n", collectionName, time.Since(begin)) - for _, rs := range result { - log.Printf("GroupByValue: %v\n", rs.GroupByValue) - for i := 0; i < rs.ResultCount; i++ { - id, _ := rs.IDs.GetAsInt64(i) - score := rs.Scores[i] - embedding, _ := rs.Fields.GetColumn(embeddingCol).Get(i) - - log.Printf("ID: %d, score %f, embedding: %v\n", id, score, embedding) - } - } - - c.DropCollection(ctx, collectionName) -} diff --git a/go.mod b/go.mod index b0ee74e3..c540c693 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/go-faker/faker/v4 v4.1.0 github.com/golang/protobuf v1.5.2 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 - github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240109020841-d367b5a59df1 + github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4 github.com/stretchr/testify v1.8.1 github.com/tidwall/gjson v1.14.4 google.golang.org/grpc v1.48.0 diff --git a/go.sum b/go.sum index a2c251a7..fa3e7ab3 100644 --- a/go.sum +++ b/go.sum @@ -157,8 +157,8 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240109020841-d367b5a59df1 h1:oNpMivd94JAMhdSVsFw8t1b+olXz8pbzd5PES21sth8= -github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240109020841-d367b5a59df1/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4 h1:HtNGcUb52ojnl+zDAZMmbHyVaTdBjzuCnnBHpb675TU= +github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= diff --git a/mocks/MilvusServiceServer.go b/mocks/MilvusServiceServer.go index fb41abba..9f562bbf 100644 --- a/mocks/MilvusServiceServer.go +++ b/mocks/MilvusServiceServer.go @@ -2887,61 +2887,6 @@ func (_c *MilvusServiceServer_HasPartition_Call) RunAndReturn(run func(context.C return _c } -// HybridSearch provides a mock function with given fields: _a0, _a1 -func (_m *MilvusServiceServer) HybridSearch(_a0 context.Context, _a1 *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) { - ret := _m.Called(_a0, _a1) - - var r0 *milvuspb.SearchResults - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error)); ok { - return rf(_a0, _a1) - } - if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.HybridSearchRequest) *milvuspb.SearchResults); ok { - r0 = rf(_a0, _a1) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*milvuspb.SearchResults) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.HybridSearchRequest) error); ok { - r1 = rf(_a0, _a1) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// 
MilvusServiceServer_HybridSearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HybridSearch' -type MilvusServiceServer_HybridSearch_Call struct { - *mock.Call -} - -// HybridSearch is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 *milvuspb.HybridSearchRequest -func (_e *MilvusServiceServer_Expecter) HybridSearch(_a0 interface{}, _a1 interface{}) *MilvusServiceServer_HybridSearch_Call { - return &MilvusServiceServer_HybridSearch_Call{Call: _e.mock.On("HybridSearch", _a0, _a1)} -} - -func (_c *MilvusServiceServer_HybridSearch_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.HybridSearchRequest)) *MilvusServiceServer_HybridSearch_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*milvuspb.HybridSearchRequest)) - }) - return _c -} - -func (_c *MilvusServiceServer_HybridSearch_Call) Return(_a0 *milvuspb.SearchResults, _a1 error) *MilvusServiceServer_HybridSearch_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MilvusServiceServer_HybridSearch_Call) RunAndReturn(run func(context.Context, *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error)) *MilvusServiceServer_HybridSearch_Call { - _c.Call.Return(run) - return _c -} - // Import provides a mock function with given fields: _a0, _a1 func (_m *MilvusServiceServer) Import(_a0 context.Context, _a1 *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { ret := _m.Called(_a0, _a1) @@ -4097,6 +4042,61 @@ func (_c *MilvusServiceServer_Search_Call) RunAndReturn(run func(context.Context return _c } +// SearchV2 provides a mock function with given fields: _a0, _a1 +func (_m *MilvusServiceServer) SearchV2(_a0 context.Context, _a1 *milvuspb.SearchRequestV2) (*milvuspb.SearchResults, error) { + ret := _m.Called(_a0, _a1) + + var r0 *milvuspb.SearchResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SearchRequestV2) (*milvuspb.SearchResults, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, *milvuspb.SearchRequestV2) *milvuspb.SearchResults); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*milvuspb.SearchResults) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *milvuspb.SearchRequestV2) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MilvusServiceServer_SearchV2_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SearchV2' +type MilvusServiceServer_SearchV2_Call struct { + *mock.Call +} + +// SearchV2 is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 *milvuspb.SearchRequestV2 +func (_e *MilvusServiceServer_Expecter) SearchV2(_a0 interface{}, _a1 interface{}) *MilvusServiceServer_SearchV2_Call { + return &MilvusServiceServer_SearchV2_Call{Call: _e.mock.On("SearchV2", _a0, _a1)} +} + +func (_c *MilvusServiceServer_SearchV2_Call) Run(run func(_a0 context.Context, _a1 *milvuspb.SearchRequestV2)) *MilvusServiceServer_SearchV2_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*milvuspb.SearchRequestV2)) + }) + return _c +} + +func (_c *MilvusServiceServer_SearchV2_Call) Return(_a0 *milvuspb.SearchResults, _a1 error) *MilvusServiceServer_SearchV2_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MilvusServiceServer_SearchV2_Call) RunAndReturn(run func(context.Context, *milvuspb.SearchRequestV2) (*milvuspb.SearchResults, error)) 
*MilvusServiceServer_SearchV2_Call { + _c.Call.Return(run) + return _c +} + // SelectGrant provides a mock function with given fields: _a0, _a1 func (_m *MilvusServiceServer) SelectGrant(_a0 context.Context, _a1 *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { ret := _m.Called(_a0, _a1)
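
A minimal usage sketch of the new BFloat16Vector column API, assuming the entity package from this patch; the toBFloat16 truncation mirrors the helper in examples/bfloat16/bfloat16.go, and the "embeddings" field name and dimension are illustrative only:

package main

import (
	"encoding/binary"
	"fmt"
	"math"

	"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

// toBFloat16 keeps the upper 16 bits of a float32, which is the bfloat16
// encoding used by the bfloat16 example in this patch (little-endian order).
func toBFloat16(f float32) []byte {
	bs := make([]byte, 2)
	binary.LittleEndian.PutUint16(bs, uint16(math.Float32bits(f)>>16))
	return bs
}

func main() {
	const dim = 4
	rows := [][]float32{
		{0.1, 0.2, 0.3, 0.4},
		{0.5, 0.6, 0.7, 0.8},
	}

	// Each bfloat16 vector occupies dim*2 bytes.
	values := make([][]byte, 0, len(rows))
	for _, row := range rows {
		vec := make([]byte, 0, dim*2)
		for _, f := range row {
			vec = append(vec, toBFloat16(f)...)
		}
		values = append(values, vec)
	}

	// The column can be passed to client.Insert together with the other columns
	// of a collection whose "embeddings" field is FieldTypeBFloat16Vector with dim 4.
	col := entity.NewColumnBFloat16Vector("embeddings", dim, values)
	fmt.Println(col.Name(), col.Dim(), col.Len(), col.Type() == entity.FieldTypeBFloat16Vector)
}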