From fa4eb6dc840355962e3d9719d237789a23002acc Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Mon, 5 Feb 2024 15:36:59 +0800 Subject: [PATCH] Adding float16 and bfloat16 support Signed-off-by: Ted Xu --- client/collection.go | 5 +- entity/columns.go | 101 +++++++++++++++ entity/columns_scalar_gen.go | 69 +++++----- entity/columns_scalar_gen_test.go | 79 +++++++----- entity/columns_vector_gen.go | 204 ++++++++++++++++++++++++++++-- entity/columns_vector_gen_test.go | 164 +++++++++++++++++++++--- entity/gen/gen.go | 19 ++- entity/schema.go | 16 +++ entity/schema_test.go | 41 ++++++ examples/bfloat16/bfloat16.go | 155 +++++++++++++++++++++++ 10 files changed, 747 insertions(+), 106 deletions(-) create mode 100644 examples/bfloat16/bfloat16.go diff --git a/client/collection.go b/client/collection.go index 748d7015..8f53d13f 100644 --- a/client/collection.go +++ b/client/collection.go @@ -218,7 +218,10 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error { if field.IsPartitionKey { hasPartitionKey = true } - if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector { + if field.DataType == entity.FieldTypeFloatVector || + field.DataType == entity.FieldTypeBinaryVector || + field.DataType == entity.FieldTypeBFloat16Vector || + field.DataType == entity.FieldTypeFloat16Vector { vectors++ } } diff --git a/entity/columns.go b/entity/columns.go index af35632c..cc3e4b90 100644 --- a/entity/columns.go +++ b/entity/columns.go @@ -88,6 +88,40 @@ func (fv FloatVector) Serialize() []byte { return data } +// Float16Vector float16 vector wrapper, stored as raw bytes (2 bytes per dimension). +type Float16Vector []byte + +// Dim returns vector dimension. +func (fv Float16Vector) Dim() int { + return len(fv) / 2 +} + +// FieldType returns corresponding field type. +func (fv Float16Vector) FieldType() FieldType { + return FieldTypeFloat16Vector +} + +func (fv Float16Vector) Serialize() []byte { + return fv +} + +// BFloat16Vector bfloat16 vector wrapper, stored as raw bytes (2 bytes per dimension). 
+type BFloat16Vector []byte + +// Dim returns vector dimension. +func (fv BFloat16Vector) Dim() int { + return len(fv) / 2 +} + +// FieldType returns corresponding field type. +func (fv BFloat16Vector) FieldType() FieldType { + return FieldTypeBFloat16Vector +} + +func (fv BFloat16Vector) Serialize() []byte { + return fv +} + // BinaryVector []byte vector wrapper type BinaryVector []byte @@ -306,6 +340,43 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) { } return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_Float16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Float16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Float16Vector // raw bytes, 2 bytes per float16 element + dim := int(vectors.GetDim()) + if end < 0 { + end = len(data) / (dim * 2) // negative end means "through the last row" + } + vector := make([][]byte, 0, end-begin) + for i := begin; i < end; i++ { + v := make([]byte, dim*2) // one row is dim elements of 2 bytes each + copy(v, data[i*dim*2 : (i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil + + case schema.DataType_BFloat16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Bfloat16Vector // raw bytes, 2 bytes per bfloat16 element + dim := int(vectors.GetDim()) + if end < 0 { + end = len(data) / (dim * 2) // negative end means "through the last row" + } + vector := make([][]byte, 0, end-begin) + for i := begin; i < end; i++ { + v := make([]byte, dim*2) // one row is dim elements of 2 bytes each + copy(v, data[i*dim*2 : (i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil default: return nil, fmt.Errorf("unsupported data type %s", fd.GetType()) } @@ -447,6 +518,36 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) { vector = append(vector, v) } return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_Float16Vector: + vectors := fd.GetVectors() + x, ok := 
vectors.GetData().(*schema.VectorField_Float16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Float16Vector // raw bytes, 2 bytes per float16 element + dim := int(vectors.GetDim()) + vector := make([][]byte, 0, len(data)/(dim*2)) // shall not have remainder + for i := 0; i < len(data)/(dim*2); i++ { + v := make([]byte, dim*2) // one row is dim elements of 2 bytes each + copy(v, data[i*dim*2 : (i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil + case schema.DataType_BFloat16Vector: + vectors := fd.GetVectors() + x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector) + if !ok { + return nil, errFieldDataTypeNotMatch + } + data := x.Bfloat16Vector // raw bytes, 2 bytes per bfloat16 element + dim := int(vectors.GetDim()) + vector := make([][]byte, 0, len(data)/(dim*2)) // shall not have remainder + for i := 0; i < len(data)/(dim*2); i++ { + v := make([]byte, dim*2) // one row is dim elements of 2 bytes each + copy(v, data[i*dim*2 : (i+1)*dim*2]) + vector = append(vector, v) + } + return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil default: return nil, errors.New("unsupported data type") } diff --git a/entity/columns_scalar_gen.go b/entity/columns_scalar_gen.go index 67d91f61..ac65a1b2 100755 --- a/entity/columns_scalar_gen.go +++ b/entity/columns_scalar_gen.go @@ -1,7 +1,7 @@ // Code generated by go generate; DO NOT EDIT -// This file is generated by go generate +// This file is generated by go generate -package entity +package entity import ( "errors" @@ -44,7 +44,7 @@ func (c *ColumnBool) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnBool) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Bool, + Type: schema.DataType_Bool, FieldName: c.name, } data := make([]bool, 0, c.Len()) @@ -74,7 +74,7 @@ func (c *ColumnBool) ValueByIdx(idx int) (bool, error) { } // AppendValue append value into column -func (c *ColumnBool) AppendValue(i interface{}) error { +func(c *ColumnBool) AppendValue(i interface{}) error { v, ok := i.(bool) if !ok { return fmt.Errorf("invalid type, 
expected bool, got %T", i) @@ -91,8 +91,8 @@ func (c *ColumnBool) Data() []bool { // NewColumnBool auto generated constructor func NewColumnBool(name string, values []bool) *ColumnBool { - return &ColumnBool{ - name: name, + return &ColumnBool { + name: name, values: values, } } @@ -131,7 +131,7 @@ func (c *ColumnInt8) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt8) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int8, + Type: schema.DataType_Int8, FieldName: c.name, } data := make([]int32, 0, c.Len()) @@ -161,7 +161,7 @@ func (c *ColumnInt8) ValueByIdx(idx int) (int8, error) { } // AppendValue append value into column -func (c *ColumnInt8) AppendValue(i interface{}) error { +func(c *ColumnInt8) AppendValue(i interface{}) error { v, ok := i.(int8) if !ok { return fmt.Errorf("invalid type, expected int8, got %T", i) @@ -178,8 +178,8 @@ func (c *ColumnInt8) Data() []int8 { // NewColumnInt8 auto generated constructor func NewColumnInt8(name string, values []int8) *ColumnInt8 { - return &ColumnInt8{ - name: name, + return &ColumnInt8 { + name: name, values: values, } } @@ -218,7 +218,7 @@ func (c *ColumnInt16) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt16) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int16, + Type: schema.DataType_Int16, FieldName: c.name, } data := make([]int32, 0, c.Len()) @@ -248,7 +248,7 @@ func (c *ColumnInt16) ValueByIdx(idx int) (int16, error) { } // AppendValue append value into column -func (c *ColumnInt16) AppendValue(i interface{}) error { +func(c *ColumnInt16) AppendValue(i interface{}) error { v, ok := i.(int16) if !ok { return fmt.Errorf("invalid type, expected int16, got %T", i) @@ -265,8 +265,8 @@ func (c *ColumnInt16) Data() []int16 { // NewColumnInt16 auto generated constructor func NewColumnInt16(name string, values []int16) 
*ColumnInt16 { - return &ColumnInt16{ - name: name, + return &ColumnInt16 { + name: name, values: values, } } @@ -305,7 +305,7 @@ func (c *ColumnInt32) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt32) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int32, + Type: schema.DataType_Int32, FieldName: c.name, } data := make([]int32, 0, c.Len()) @@ -335,7 +335,7 @@ func (c *ColumnInt32) ValueByIdx(idx int) (int32, error) { } // AppendValue append value into column -func (c *ColumnInt32) AppendValue(i interface{}) error { +func(c *ColumnInt32) AppendValue(i interface{}) error { v, ok := i.(int32) if !ok { return fmt.Errorf("invalid type, expected int32, got %T", i) @@ -352,8 +352,8 @@ func (c *ColumnInt32) Data() []int32 { // NewColumnInt32 auto generated constructor func NewColumnInt32(name string, values []int32) *ColumnInt32 { - return &ColumnInt32{ - name: name, + return &ColumnInt32 { + name: name, values: values, } } @@ -392,7 +392,7 @@ func (c *ColumnInt64) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnInt64) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Int64, + Type: schema.DataType_Int64, FieldName: c.name, } data := make([]int64, 0, c.Len()) @@ -422,7 +422,7 @@ func (c *ColumnInt64) ValueByIdx(idx int) (int64, error) { } // AppendValue append value into column -func (c *ColumnInt64) AppendValue(i interface{}) error { +func(c *ColumnInt64) AppendValue(i interface{}) error { v, ok := i.(int64) if !ok { return fmt.Errorf("invalid type, expected int64, got %T", i) @@ -439,8 +439,8 @@ func (c *ColumnInt64) Data() []int64 { // NewColumnInt64 auto generated constructor func NewColumnInt64(name string, values []int64) *ColumnInt64 { - return &ColumnInt64{ - name: name, + return &ColumnInt64 { + name: name, values: values, } } @@ -479,7 +479,7 @@ func (c *ColumnFloat) 
Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnFloat) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Float, + Type: schema.DataType_Float, FieldName: c.name, } data := make([]float32, 0, c.Len()) @@ -509,7 +509,7 @@ func (c *ColumnFloat) ValueByIdx(idx int) (float32, error) { } // AppendValue append value into column -func (c *ColumnFloat) AppendValue(i interface{}) error { +func(c *ColumnFloat) AppendValue(i interface{}) error { v, ok := i.(float32) if !ok { return fmt.Errorf("invalid type, expected float32, got %T", i) @@ -526,8 +526,8 @@ func (c *ColumnFloat) Data() []float32 { // NewColumnFloat auto generated constructor func NewColumnFloat(name string, values []float32) *ColumnFloat { - return &ColumnFloat{ - name: name, + return &ColumnFloat { + name: name, values: values, } } @@ -566,7 +566,7 @@ func (c *ColumnDouble) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnDouble) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_Double, + Type: schema.DataType_Double, FieldName: c.name, } data := make([]float64, 0, c.Len()) @@ -596,7 +596,7 @@ func (c *ColumnDouble) ValueByIdx(idx int) (float64, error) { } // AppendValue append value into column -func (c *ColumnDouble) AppendValue(i interface{}) error { +func(c *ColumnDouble) AppendValue(i interface{}) error { v, ok := i.(float64) if !ok { return fmt.Errorf("invalid type, expected float64, got %T", i) @@ -613,8 +613,8 @@ func (c *ColumnDouble) Data() []float64 { // NewColumnDouble auto generated constructor func NewColumnDouble(name string, values []float64) *ColumnDouble { - return &ColumnDouble{ - name: name, + return &ColumnDouble { + name: name, values: values, } } @@ -653,7 +653,7 @@ func (c *ColumnString) Get(idx int) (interface{}, error) { // FieldData return column data mapped to schema.FieldData func (c *ColumnString) 
FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_String, + Type: schema.DataType_String, FieldName: c.name, } data := make([]string, 0, c.Len()) @@ -683,7 +683,7 @@ func (c *ColumnString) ValueByIdx(idx int) (string, error) { } // AppendValue append value into column -func (c *ColumnString) AppendValue(i interface{}) error { +func(c *ColumnString) AppendValue(i interface{}) error { v, ok := i.(string) if !ok { return fmt.Errorf("invalid type, expected string, got %T", i) @@ -700,8 +700,9 @@ func (c *ColumnString) Data() []string { // NewColumnString auto generated constructor func NewColumnString(name string, values []string) *ColumnString { - return &ColumnString{ - name: name, + return &ColumnString { + name: name, values: values, } } + diff --git a/entity/columns_scalar_gen_test.go b/entity/columns_scalar_gen_test.go index 5366e405..70e1408c 100755 --- a/entity/columns_scalar_gen_test.go +++ b/entity/columns_scalar_gen_test.go @@ -1,7 +1,7 @@ // Code generated by go generate; DO NOT EDIT -// This file is generated by go generated +// This file is generated by go generated -package entity +package entity import ( "fmt" @@ -9,8 +9,8 @@ import ( "testing" "time" - schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/stretchr/testify/assert" + schema "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) func TestColumnBool(t *testing.T) { @@ -60,7 +60,7 @@ func TestFieldDataBoolColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Bool_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Bool, + Type: schema.DataType_Bool, FieldName: name, } @@ -77,7 +77,7 @@ func TestFieldDataBoolColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeBool, column.Type()) @@ -92,12 +92,13 @@ func TestFieldDataBoolColumn(t *testing.T) { 
assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -111,7 +112,7 @@ func TestFieldDataBoolColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeBool, column.Type()) @@ -165,7 +166,7 @@ func TestFieldDataInt8Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int8_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int8, + Type: schema.DataType_Int8, FieldName: name, } @@ -182,7 +183,7 @@ func TestFieldDataInt8Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt8, column.Type()) @@ -197,12 +198,13 @@ func TestFieldDataInt8Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -216,7 +218,7 @@ func TestFieldDataInt8Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt8, column.Type()) @@ -270,7 +272,7 @@ func TestFieldDataInt16Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int16_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int16, + Type: schema.DataType_Int16, FieldName: name, } @@ -287,7 +289,7 @@ func TestFieldDataInt16Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) 
assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt16, column.Type()) @@ -302,12 +304,13 @@ func TestFieldDataInt16Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -321,7 +324,7 @@ func TestFieldDataInt16Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt16, column.Type()) @@ -375,7 +378,7 @@ func TestFieldDataInt32Column(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Int32_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int32, + Type: schema.DataType_Int32, FieldName: name, } @@ -392,7 +395,7 @@ func TestFieldDataInt32Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt32, column.Type()) @@ -407,12 +410,13 @@ func TestFieldDataInt32Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -426,7 +430,7 @@ func TestFieldDataInt32Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt32, column.Type()) @@ -480,7 +484,7 @@ func TestFieldDataInt64Column(t *testing.T) { len := rand.Intn(10) + 8 name := 
fmt.Sprintf("fd_Int64_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Int64, + Type: schema.DataType_Int64, FieldName: name, } @@ -497,7 +501,7 @@ func TestFieldDataInt64Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt64, column.Type()) @@ -512,12 +516,13 @@ func TestFieldDataInt64Column(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -531,7 +536,7 @@ func TestFieldDataInt64Column(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeInt64, column.Type()) @@ -585,7 +590,7 @@ func TestFieldDataFloatColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Float_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Float, + Type: schema.DataType_Float, FieldName: name, } @@ -602,7 +607,7 @@ func TestFieldDataFloatColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeFloat, column.Type()) @@ -617,12 +622,13 @@ func TestFieldDataFloatColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -636,7 +642,7 @@ func TestFieldDataFloatColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) 
assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeFloat, column.Type()) @@ -690,7 +696,7 @@ func TestFieldDataDoubleColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_Double_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_Double, + Type: schema.DataType_Double, FieldName: name, } @@ -707,7 +713,7 @@ func TestFieldDataDoubleColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeDouble, column.Type()) @@ -722,12 +728,13 @@ func TestFieldDataDoubleColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -741,7 +748,7 @@ func TestFieldDataDoubleColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeDouble, column.Type()) @@ -795,7 +802,7 @@ func TestFieldDataStringColumn(t *testing.T) { len := rand.Intn(10) + 8 name := fmt.Sprintf("fd_String_%d", rand.Int()) fd := &schema.FieldData{ - Type: schema.DataType_String, + Type: schema.DataType_String, FieldName: name, } @@ -812,7 +819,7 @@ func TestFieldDataStringColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, len) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeString, column.Type()) @@ -827,12 +834,13 @@ func TestFieldDataStringColumn(t *testing.T) { assert.NotNil(t, err) }) + t.Run("nil data", func(t *testing.T) { fd.Field = nil _, err := FieldDataColumn(fd, 0, len) 
assert.NotNil(t, err) }) - + t.Run("get all data", func(t *testing.T) { fd.Field = &schema.FieldData_Scalars{ Scalars: &schema.ScalarField{ @@ -846,9 +854,10 @@ func TestFieldDataStringColumn(t *testing.T) { column, err := FieldDataColumn(fd, 0, -1) assert.Nil(t, err) assert.NotNil(t, column) - + assert.Equal(t, name, column.Name()) assert.Equal(t, len, column.Len()) assert.Equal(t, FieldTypeString, column.Type()) }) } + diff --git a/entity/columns_vector_gen.go b/entity/columns_vector_gen.go index 87a874c7..f70e826b 100755 --- a/entity/columns_vector_gen.go +++ b/entity/columns_vector_gen.go @@ -10,6 +10,7 @@ import ( "github.com/cockroachdb/errors" ) + // ColumnBinaryVector generated columns type for BinaryVector type ColumnBinaryVector struct { ColumnBase @@ -29,7 +30,7 @@ func (c *ColumnBinaryVector) Type() FieldType { } // Len returns column data length -func (c *ColumnBinaryVector) Len() int { +func (c * ColumnBinaryVector) Len() int { return len(c.values) } @@ -47,7 +48,7 @@ func (c *ColumnBinaryVector) Get(idx int) (interface{}, error) { } // AppendValue append value into column -func (c *ColumnBinaryVector) AppendValue(i interface{}) error { +func(c *ColumnBinaryVector) AppendValue(i interface{}) error { v, ok := i.([]byte) if !ok { return fmt.Errorf("invalid type, expected []byte, got %T", i) @@ -65,11 +66,11 @@ func (c *ColumnBinaryVector) Data() [][]byte { // FieldData return column data mapped to schema.FieldData func (c *ColumnBinaryVector) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_BinaryVector, + Type: schema.DataType_BinaryVector, FieldName: c.name, } - data := make([]byte, 0, len(c.values)*c.dim) + data := make([]byte, 0, len(c.values)* c.dim) for _, vector := range c.values { data = append(data, vector...) 
@@ -78,9 +79,10 @@ func (c *ColumnBinaryVector) FieldData() *schema.FieldData { fd.Field = &schema.FieldData_Vectors{ Vectors: &schema.VectorField{ Dim: int64(c.dim), - - Data: &schema.VectorField_BinaryVector{ + + Data: &schema.VectorField_BinaryVector{ BinaryVector: data, + }, }, } @@ -89,7 +91,7 @@ func (c *ColumnBinaryVector) FieldData() *schema.FieldData { // NewColumnBinaryVector auto generated constructor func NewColumnBinaryVector(name string, dim int, values [][]byte) *ColumnBinaryVector { - return &ColumnBinaryVector{ + return &ColumnBinaryVector { name: name, dim: dim, values: values, @@ -115,7 +117,7 @@ func (c *ColumnFloatVector) Type() FieldType { } // Len returns column data length -func (c *ColumnFloatVector) Len() int { +func (c * ColumnFloatVector) Len() int { return len(c.values) } @@ -133,7 +135,7 @@ func (c *ColumnFloatVector) Get(idx int) (interface{}, error) { } // AppendValue append value into column -func (c *ColumnFloatVector) AppendValue(i interface{}) error { +func(c *ColumnFloatVector) AppendValue(i interface{}) error { v, ok := i.([]float32) if !ok { return fmt.Errorf("invalid type, expected []float32, got %T", i) @@ -151,11 +153,11 @@ func (c *ColumnFloatVector) Data() [][]float32 { // FieldData return column data mapped to schema.FieldData func (c *ColumnFloatVector) FieldData() *schema.FieldData { fd := &schema.FieldData{ - Type: schema.DataType_FloatVector, + Type: schema.DataType_FloatVector, FieldName: c.name, } - data := make([]float32, 0, len(c.values)*c.dim) + data := make([]float32, 0, len(c.values)* c.dim) for _, vector := range c.values { data = append(data, vector...) 
@@ -164,11 +166,12 @@ func (c *ColumnFloatVector) FieldData() *schema.FieldData { fd.Field = &schema.FieldData_Vectors{ Vectors: &schema.VectorField{ Dim: int64(c.dim), - + Data: &schema.VectorField_FloatVector{ FloatVector: &schema.FloatArray{ Data: data, }, + }, }, } @@ -177,9 +180,184 @@ func (c *ColumnFloatVector) FieldData() *schema.FieldData { // NewColumnFloatVector auto generated constructor func NewColumnFloatVector(name string, dim int, values [][]float32) *ColumnFloatVector { - return &ColumnFloatVector{ + return &ColumnFloatVector { + name: name, + dim: dim, + values: values, + } +} + +// ColumnFloat16Vector generated columns type for Float16Vector +type ColumnFloat16Vector struct { + ColumnBase + name string + dim int + values [][]byte +} + +// Name returns column name +func (c *ColumnFloat16Vector) Name() string { + return c.name +} + +// Type returns column FieldType +func (c *ColumnFloat16Vector) Type() FieldType { + return FieldTypeFloat16Vector +} + +// Len returns column data length +func (c * ColumnFloat16Vector) Len() int { + return len(c.values) +} + +// Dim returns vector dimension +func (c *ColumnFloat16Vector) Dim() int { + return c.dim +} + +// Get returns values at index as interface{}. 
+func (c *ColumnFloat16Vector) Get(idx int) (interface{}, error) { + if idx < 0 || idx >= c.Len() { + return nil, errors.New("index out of range") + } + return c.values[idx], nil +} + +// AppendValue append value into column +func(c *ColumnFloat16Vector) AppendValue(i interface{}) error { + v, ok := i.([]byte) + if !ok { + return fmt.Errorf("invalid type, expected []byte, got %T", i) + } + c.values = append(c.values, v) + + return nil +} + +// Data returns column data +func (c *ColumnFloat16Vector) Data() [][]byte { + return c.values +} + +// FieldData return column data mapped to schema.FieldData +func (c *ColumnFloat16Vector) FieldData() *schema.FieldData { + fd := &schema.FieldData{ + Type: schema.DataType_Float16Vector, + FieldName: c.name, + } + + data := make([]byte, 0, len(c.values)* c.dim) + + for _, vector := range c.values { + data = append(data, vector...) + } + + fd.Field = &schema.FieldData_Vectors{ + Vectors: &schema.VectorField{ + Dim: int64(c.dim), + + Data: &schema.VectorField_Float16Vector{ + Float16Vector: data, + + }, + }, + } + return fd +} + +// NewColumnFloat16Vector auto generated constructor +func NewColumnFloat16Vector(name string, dim int, values [][]byte) *ColumnFloat16Vector { + return &ColumnFloat16Vector { name: name, dim: dim, values: values, } } + +// ColumnBFloat16Vector generated columns type for BFloat16Vector +type ColumnBFloat16Vector struct { + ColumnBase + name string + dim int + values [][]byte +} + +// Name returns column name +func (c *ColumnBFloat16Vector) Name() string { + return c.name +} + +// Type returns column FieldType +func (c *ColumnBFloat16Vector) Type() FieldType { + return FieldTypeBFloat16Vector +} + +// Len returns column data length +func (c * ColumnBFloat16Vector) Len() int { + return len(c.values) +} + +// Dim returns vector dimension +func (c *ColumnBFloat16Vector) Dim() int { + return c.dim +} + +// Get returns values at index as interface{}. 
+func (c *ColumnBFloat16Vector) Get(idx int) (interface{}, error) { + if idx < 0 || idx >= c.Len() { + return nil, errors.New("index out of range") + } + return c.values[idx], nil +} + +// AppendValue append value into column +func(c *ColumnBFloat16Vector) AppendValue(i interface{}) error { + v, ok := i.([]byte) + if !ok { + return fmt.Errorf("invalid type, expected []byte, got %T", i) + } + c.values = append(c.values, v) + + return nil +} + +// Data returns column data +func (c *ColumnBFloat16Vector) Data() [][]byte { + return c.values +} + +// FieldData return column data mapped to schema.FieldData +func (c *ColumnBFloat16Vector) FieldData() *schema.FieldData { + fd := &schema.FieldData{ + Type: schema.DataType_BFloat16Vector, + FieldName: c.name, + } + + data := make([]byte, 0, len(c.values)* c.dim) + + for _, vector := range c.values { + data = append(data, vector...) + } + + fd.Field = &schema.FieldData_Vectors{ + Vectors: &schema.VectorField{ + Dim: int64(c.dim), + + Data: &schema.VectorField_Bfloat16Vector{ + Bfloat16Vector: data, + + }, + }, + } + return fd +} + +// NewColumnBFloat16Vector auto generated constructor +func NewColumnBFloat16Vector(name string, dim int, values [][]byte) *ColumnBFloat16Vector { + return &ColumnBFloat16Vector { + name: name, + dim: dim, + values: values, + } +} + diff --git a/entity/columns_vector_gen_test.go b/entity/columns_vector_gen_test.go index d959e2c5..c5753adf 100755 --- a/entity/columns_vector_gen_test.go +++ b/entity/columns_vector_gen_test.go @@ -1,7 +1,7 @@ // Code generated by go generate; DO NOT EDIT -// This file is generated by go generated +// This file is generated by go generated -package entity +package entity import ( "fmt" @@ -19,16 +19,17 @@ func TestColumnBinaryVector(t *testing.T) { columnLen := 12 + rand.Intn(10) dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] - v := make([][]byte, 0, columnLen) + v := make([][]byte,0, columnLen) dlen := dim dlen /= 8 - + + for i := 0; i < columnLen; i++ { entry := 
make([]byte, dlen) v = append(v, entry) } column := NewColumnBinaryVector(columnName, dim, v) - + t.Run("test meta", func(t *testing.T) { ft := FieldTypeBinaryVector assert.Equal(t, "BinaryVector", ft.Name()) @@ -43,13 +44,13 @@ func TestColumnBinaryVector(t *testing.T) { assert.Equal(t, FieldTypeBinaryVector, column.Type()) assert.Equal(t, columnLen, column.Len()) assert.Equal(t, dim, column.Dim()) - assert.Equal(t, v, column.Data()) - + assert.Equal(t ,v, column.Data()) + var ev []byte err := column.AppendValue(ev) assert.Equal(t, columnLen+1, column.Len()) assert.Nil(t, err) - + err = column.AppendValue(struct{}{}) assert.Equal(t, columnLen+1, column.Len()) assert.NotNil(t, err) @@ -70,7 +71,7 @@ func TestColumnBinaryVector(t *testing.T) { Type: schema.DataType_BinaryVector, FieldName: columnName, } - _, err := FieldDataVector(fd) + _, err := FieldDataVector(fd) assert.Error(t, err) }) @@ -82,15 +83,17 @@ func TestColumnFloatVector(t *testing.T) { columnLen := 12 + rand.Intn(10) dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] - v := make([][]float32, 0, columnLen) + v := make([][]float32,0, columnLen) dlen := dim - + + + for i := 0; i < columnLen; i++ { entry := make([]float32, dlen) v = append(v, entry) } column := NewColumnFloatVector(columnName, dim, v) - + t.Run("test meta", func(t *testing.T) { ft := FieldTypeFloatVector assert.Equal(t, "FloatVector", ft.Name()) @@ -105,13 +108,13 @@ func TestColumnFloatVector(t *testing.T) { assert.Equal(t, FieldTypeFloatVector, column.Type()) assert.Equal(t, columnLen, column.Len()) assert.Equal(t, dim, column.Dim()) - assert.Equal(t, v, column.Data()) - + assert.Equal(t ,v, column.Data()) + var ev []float32 err := column.AppendValue(ev) assert.Equal(t, columnLen+1, column.Len()) assert.Nil(t, err) - + err = column.AppendValue(struct{}{}) assert.Equal(t, columnLen+1, column.Len()) assert.NotNil(t, err) @@ -132,8 +135,137 @@ func TestColumnFloatVector(t *testing.T) { Type: schema.DataType_FloatVector, FieldName: 
columnName, } - _, err := FieldDataVector(fd) + _, err := FieldDataVector(fd) assert.Error(t, err) }) } + +func TestColumnFloat16Vector(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + columnName := fmt.Sprintf("column_Float16Vector_%d", rand.Int()) + columnLen := 12 + rand.Intn(10) + dim := ([]int{64, 128, 256, 512})[rand.Intn(4)] + + v := make([][]byte,0, columnLen) + dlen := dim + + dlen *= 2 + + for i := 0; i < columnLen; i++ { + entry := make([]byte, dlen) + v = append(v, entry) + } + column := NewColumnFloat16Vector(columnName, dim, v) + + t.Run("test meta", func(t *testing.T) { + ft := FieldTypeFloat16Vector + assert.Equal(t, "Float16Vector", ft.Name()) + assert.Equal(t, "[]byte", ft.String()) + pbName, pbType := ft.PbFieldType() + assert.Equal(t, "[]byte", pbName) + assert.Equal(t, "", pbType) + }) + + t.Run("test column attribute", func(t *testing.T) { + assert.Equal(t, columnName, column.Name()) + assert.Equal(t, FieldTypeFloat16Vector, column.Type()) + assert.Equal(t, columnLen, column.Len()) + assert.Equal(t, dim, column.Dim()) + assert.Equal(t ,v, column.Data()) + + var ev []byte + err := column.AppendValue(ev) + assert.Equal(t, columnLen+1, column.Len()) + assert.Nil(t, err) + + err = column.AppendValue(struct{}{}) + assert.Equal(t, columnLen+1, column.Len()) + assert.NotNil(t, err) + }) + + t.Run("test column field data", func(t *testing.T) { + fd := column.FieldData() + assert.NotNil(t, fd) + assert.Equal(t, fd.GetFieldName(), columnName) + + c, err := FieldDataVector(fd) + assert.NotNil(t, c) + assert.NoError(t, err) + }) + + t.Run("test column field data error", func(t *testing.T) { + fd := &schema.FieldData{ + Type: schema.DataType_Float16Vector, + FieldName: columnName, + } + _, err := FieldDataVector(fd) + assert.Error(t, err) + }) + +} + +func TestColumnBFloat16Vector(t *testing.T) { + rand.Seed(time.Now().UnixNano()) + columnName := fmt.Sprintf("column_BFloat16Vector_%d", rand.Int()) + columnLen := 12 + rand.Intn(10) + dim := ([]int{64, 
128, 256, 512})[rand.Intn(4)] + + v := make([][]byte,0, columnLen) + dlen := dim + + dlen *= 2 + + for i := 0; i < columnLen; i++ { + entry := make([]byte, dlen) + v = append(v, entry) + } + column := NewColumnBFloat16Vector(columnName, dim, v) + + t.Run("test meta", func(t *testing.T) { + ft := FieldTypeBFloat16Vector + assert.Equal(t, "BFloat16Vector", ft.Name()) + assert.Equal(t, "[]byte", ft.String()) + pbName, pbType := ft.PbFieldType() + assert.Equal(t, "[]byte", pbName) + assert.Equal(t, "", pbType) + }) + + t.Run("test column attribute", func(t *testing.T) { + assert.Equal(t, columnName, column.Name()) + assert.Equal(t, FieldTypeBFloat16Vector, column.Type()) + assert.Equal(t, columnLen, column.Len()) + assert.Equal(t, dim, column.Dim()) + assert.Equal(t ,v, column.Data()) + + var ev []byte + err := column.AppendValue(ev) + assert.Equal(t, columnLen+1, column.Len()) + assert.Nil(t, err) + + err = column.AppendValue(struct{}{}) + assert.Equal(t, columnLen+1, column.Len()) + assert.NotNil(t, err) + }) + + t.Run("test column field data", func(t *testing.T) { + fd := column.FieldData() + assert.NotNil(t, fd) + assert.Equal(t, fd.GetFieldName(), columnName) + + c, err := FieldDataVector(fd) + assert.NotNil(t, c) + assert.NoError(t, err) + }) + + t.Run("test column field data error", func(t *testing.T) { + fd := &schema.FieldData{ + Type: schema.DataType_BFloat16Vector, + FieldName: columnName, + } + _, err := FieldDataVector(fd) + assert.Error(t, err) + }) + +} + diff --git a/entity/gen/gen.go b/entity/gen/gen.go index 1faef617..b7af8965 100644 --- a/entity/gen/gen.go +++ b/entity/gen/gen.go @@ -204,17 +204,19 @@ func (c *Column{{.TypeName}}) FieldData() *schema.FieldData { fd.Field = &schema.FieldData_Vectors{ Vectors: &schema.VectorField{ Dim: int64(c.dim), - {{if eq .TypeName "BinaryVector" }} - Data: &schema.VectorField_BinaryVector{ - BinaryVector: data, - }, - {{else}} - Data: &schema.VectorField_FloatVector{ + {{if eq .TypeName "FloatVector" }} + Data: 
&schema.VectorField_{{.TypeName}}{ FloatVector: &schema.FloatArray{ Data: data, }, - }, + {{else if eq .TypeName "BFloat16Vector"}} + Data: &schema.VectorField_Bfloat16Vector{ + Bfloat16Vector: data, + {{else}} + Data: &schema.VectorField_{{.TypeName}}{ + {{.TypeName}}: data, {{end}} + }, }, } return fd @@ -377,6 +379,7 @@ func TestColumn{{.TypeName}}(t *testing.T) { v := make([]{{.TypeDef}},0, columnLen) dlen := dim {{if eq .TypeName "BinaryVector" }}dlen /= 8{{end}} + {{if or (eq .TypeName "BFloat16Vector") (eq .TypeName "Float16Vector") }}dlen *= 2{{end}} for i := 0; i < columnLen; i++ { entry := make({{.TypeDef}}, dlen) @@ -447,6 +450,8 @@ func main() { vectorFieldTypes := []entity.FieldType{ entity.FieldTypeBinaryVector, entity.FieldTypeFloatVector, + entity.FieldTypeFloat16Vector, + entity.FieldTypeBFloat16Vector, } pf := func(ft entity.FieldType) interface{} { diff --git a/entity/schema.go b/entity/schema.go index ee7179ab..868561e6 100644 --- a/entity/schema.go +++ b/entity/schema.go @@ -366,6 +366,10 @@ func (t FieldType) Name() string { return "BinaryVector" case FieldTypeFloatVector: return "FloatVector" + case FieldTypeFloat16Vector: + return "Float16Vector" + case FieldTypeBFloat16Vector: + return "BFloat16Vector" default: return "undefined" } @@ -400,6 +404,10 @@ func (t FieldType) String() string { return "[]byte" case FieldTypeFloatVector: return "[]float32" + case FieldTypeFloat16Vector: + return "[]byte" + case FieldTypeBFloat16Vector: + return "[]byte" default: return "undefined" } @@ -432,6 +440,10 @@ func (t FieldType) PbFieldType() (string, string) { return "[]byte", "" case FieldTypeFloatVector: return "[]float32", "" + case FieldTypeFloat16Vector: + return "[]byte", "" + case FieldTypeBFloat16Vector: + return "[]byte", "" default: return "undefined", "" } @@ -467,4 +479,8 @@ const ( FieldTypeBinaryVector FieldType = 100 // FieldTypeFloatVector field type float vector FieldTypeFloatVector FieldType = 101 + // FieldTypeFloat16Vector field type
float16 vector + FieldTypeFloat16Vector FieldType = 102 + // FieldTypeBFloat16Vector field type bf16 vector + FieldTypeBFloat16Vector FieldType = 103 ) diff --git a/entity/schema_test.go b/entity/schema_test.go index 07d07de6..811a46ba 100644 --- a/entity/schema_test.go +++ b/entity/schema_test.go @@ -26,6 +26,8 @@ func TestFieldSchema(t *testing.T) { NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"), NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true), NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128), + NewField().WithName("fp16_field").WithDataType(FieldTypeFloat16Vector).WithDim(128), + NewField().WithName("bf16_field").WithDataType(FieldTypeBFloat16Vector).WithDim(128), /* NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true), NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1), @@ -117,6 +119,45 @@ func (s *SchemaSuite) TestBasic() { } } +func (s *SchemaSuite) TestFp16Vector() { + cases := []struct { + tag string + input *Schema + pkName string + }{ + { + "test_collection", + NewSchema().WithName("test_collection_1").WithDescription("test_collection_1 desc").WithAutoID(true). + WithField(NewField().WithName("fp16_field").WithDataType(FieldTypeFloat16Vector).WithDim(128)).
+ WithField(NewField().WithName("bf16_field").WithDataType(FieldTypeBFloat16Vector).WithDim(128)), + "", + }, + } + + for _, c := range cases { + s.Run(c.tag, func() { + sch := c.input + p := sch.ProtoMessage() + s.Equal(sch.CollectionName, p.GetName()) + s.Equal(sch.AutoID, p.GetAutoID()) + s.Equal(sch.Description, p.GetDescription()) + s.Equal(sch.EnableDynamicField, p.GetEnableDynamicField()) + s.Equal(len(sch.Fields), len(p.GetFields())) + + nsch := &Schema{} + nsch = nsch.ReadProto(p) + + s.Equal(sch.CollectionName, nsch.CollectionName) + s.Equal(sch.AutoID, nsch.AutoID) + s.Equal(sch.Description, nsch.Description) + s.Equal(sch.EnableDynamicField, nsch.EnableDynamicField) + s.Equal(len(sch.Fields), len(nsch.Fields)) + s.Equal(c.pkName, sch.PKFieldName()) + s.Equal(c.pkName, nsch.PKFieldName()) + }) + } +} + func TestSchema(t *testing.T) { suite.Run(t, new(SchemaSuite)) } diff --git a/examples/bfloat16/bfloat16.go b/examples/bfloat16/bfloat16.go new file mode 100644 index 00000000..5095b08b --- /dev/null +++ b/examples/bfloat16/bfloat16.go @@ -0,0 +1,155 @@ +package main + +import ( + "context" + "encoding/binary" + "fmt" + "log" + "math/rand" + "time" + "unsafe" + + "github.com/milvus-io/milvus-sdk-go/v2/client" + "github.com/milvus-io/milvus-sdk-go/v2/entity" +) + +const ( + milvusAddr = `localhost:19530` + nEntities, dim = 3000, 128 + collectionName = "query_example" + + msgFmt = "==== %s ====\n" + idCol, randomCol, embeddingCol = "ID", "random", "embeddings" +) + +func toBFloat16(f float32) []byte { + bs := make([]byte, 2) + u32 := *(*uint32)(unsafe.Pointer(&f)) + binary.LittleEndian.PutUint16(bs, uint16(u32>>16)) + return bs +} + +func main() { + ctx := context.Background() + + log.Printf(msgFmt, "start connecting to Milvus") + c, err := client.NewClient(ctx, client.Config{ + Address: milvusAddr, + }) + if err != nil { + log.Fatalf("failed to connect to milvus, err: %v", err) + } + defer c.Close() + + // delete collection if exists + has, err := 
c.HasCollection(ctx, collectionName) + if err != nil { + log.Fatalf("failed to check collection exists, err: %v", err) + } + if has { + c.DropCollection(ctx, collectionName) + } + + // define schema + log.Printf(msgFmt, fmt.Sprintf("create collection, `%s`", collectionName)) + schema := entity.NewSchema().WithName(collectionName). + WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(false)). + WithField(entity.NewField().WithName(randomCol).WithDataType(entity.FieldTypeDouble)). + WithField(entity.NewField().WithName(embeddingCol).WithDataType(entity.FieldTypeBFloat16Vector).WithDim(dim)) + + // create collection with consistency level, which serves as the default search/query consistency level + if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber, client.WithConsistencyLevel(entity.ClBounded)); err != nil { + log.Fatalf("create collection failed, err: %v", err) + } + + log.Printf(msgFmt, "start inserting random entities") + idList, randomList := make([]int64, 0, nEntities), make([]float64, 0, nEntities) + embeddingList := make([][]byte, 0, nEntities) + + rand.Seed(time.Now().UnixNano()) + + // generate data + for i := 0; i < nEntities; i++ { + idList = append(idList, int64(i)) + } + for i := 0; i < nEntities; i++ { + randomList = append(randomList, rand.Float64()) + } + for i := 0; i < nEntities; i++ { + vec := make([]byte, 0, dim*2) + for j := 0; j < dim; j++ { + vec = append(vec, toBFloat16(rand.Float32())...) 
+ } + embeddingList = append(embeddingList, vec) + } + idColData := entity.NewColumnInt64(idCol, idList) + randomColData := entity.NewColumnDouble(randomCol, randomList) + embeddingColData := entity.NewColumnBFloat16Vector(embeddingCol, dim, embeddingList) + + // build index + log.Printf(msgFmt, "start creating index IVF_FLAT") + idx, err := entity.NewIndexIvfFlat(entity.L2, 128) + if err != nil { + log.Fatalf("failed to create ivf flat index, err: %v", err) + } + if err := c.CreateIndex(ctx, collectionName, embeddingCol, idx, false); err != nil { + log.Fatalf("failed to create index, err: %v", err) + } + + // insert data + if _, err := c.Insert(ctx, collectionName, "", idColData, randomColData, embeddingColData); err != nil { + log.Fatalf("failed to insert random data into %s, err: %s", collectionName, err.Error()) + } + + log.Printf(msgFmt, "start loading collection") + err = c.LoadCollection(ctx, collectionName, false) + if err != nil { + log.Fatalf("failed to load collection, err: %v", err) + } + + //query + expr := "ID in [0, 1, 2]" + log.Printf(msgFmt, fmt.Sprintf("query with expr `%s`", expr)) + resultSet, err := c.Query(ctx, collectionName, nil, expr, []string{idCol, randomCol}) + if err != nil { + log.Fatalf("failed to query result, err: %v", err) + } + printResultSet(resultSet, idCol, randomCol) + + // drop collection + log.Printf(msgFmt, fmt.Sprintf("drop collection `%s`", collectionName)) + if err := c.DropCollection(ctx, collectionName); err != nil { + log.Fatalf("failed to drop collection, err: %v", err) + } +} + +func printResultSet(rs client.ResultSet, outputFields ...string) { + for _, fieldName := range outputFields { + column := rs.GetColumn(fieldName) + if column == nil { + log.Printf("column %s not exists in result set\n", fieldName) + } + switch column.Type() { + case entity.FieldTypeInt64: + var result []int64 + for i := 0; i < column.Len(); i++ { + v, err := column.GetAsInt64(i) + if err != nil { + log.Printf("column %s row %d cannot 
GetAsInt64, %s\n", fieldName, i, err.Error()) + } + result = append(result, v) + } + log.Printf("Column %s: value: %v\n", fieldName, result) + case entity.FieldTypeDouble: + var result []float64 + for i := 0; i < column.Len(); i++ { + v, err := column.GetAsDouble(i) + if err != nil { + log.Printf("column %s row %d cannot GetAsDouble, %s\n", fieldName, i, err.Error()) + } + result = append(result, v) + } + log.Printf("Column %s: value: %v\n", fieldName, result) + } + } +}