Skip to content

Commit

Permalink
Adding float16 and bfloat16 support
Browse files Browse the repository at this point in the history
Signed-off-by: Ted Xu <ted.xu@zilliz.com>
  • Loading branch information
tedxu committed Feb 5, 2024
1 parent 1e03ea4 commit fa4eb6d
Show file tree
Hide file tree
Showing 10 changed files with 747 additions and 106 deletions.
5 changes: 4 additions & 1 deletion client/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ func (c *GrpcClient) validateSchema(sch *entity.Schema) error {
if field.IsPartitionKey {
hasPartitionKey = true
}
if field.DataType == entity.FieldTypeFloatVector || field.DataType == entity.FieldTypeBinaryVector {
if field.DataType == entity.FieldTypeFloatVector ||
field.DataType == entity.FieldTypeBinaryVector ||
field.DataType == entity.FieldTypeBFloat16Vector ||
field.DataType == entity.FieldTypeFloat16Vector {
vectors++
}
}
Expand Down
101 changes: 101 additions & 0 deletions entity/columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,40 @@ func (fv FloatVector) Serialize() []byte {
return data
}

// Float16Vector is a float16 vector stored as raw bytes, two bytes per element.
type Float16Vector []byte

// Dim returns the vector dimension, i.e. the number of two-byte elements.
func (v Float16Vector) Dim() int {
	const bytesPerElem = 2
	return len(v) / bytesPerElem
}

// FieldType returns the corresponding field type.
func (v Float16Vector) FieldType() FieldType {
	return FieldTypeFloat16Vector
}

// Serialize returns the underlying bytes as-is, without copying.
func (v Float16Vector) Serialize() []byte {
	return []byte(v)
}

// BFloat16Vector is a bfloat16 vector stored as raw bytes, two bytes per element.
type BFloat16Vector []byte

// Dim returns the vector dimension, i.e. the number of two-byte elements.
func (v BFloat16Vector) Dim() int {
	const bytesPerElem = 2
	return len(v) / bytesPerElem
}

// FieldType returns the corresponding field type.
func (v BFloat16Vector) FieldType() FieldType {
	return FieldTypeBFloat16Vector
}

// Serialize returns the underlying bytes as-is, without copying.
func (v BFloat16Vector) Serialize() []byte {
	return []byte(v)
}

// BinaryVector []byte vector wrapper
type BinaryVector []byte

Expand Down Expand Up @@ -306,6 +340,43 @@ func FieldDataColumn(fd *schema.FieldData, begin, end int) (Column, error) {
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil

case schema.DataType_Float16Vector:
	// Float16 vectors arrive as one flat byte buffer. Each element takes
	// two bytes (Float16Vector.Dim is len/2), so one row spans dim*2 bytes.
	vectors := fd.GetVectors()
	x, ok := vectors.GetData().(*schema.VectorField_Float16Vector)
	if !ok {
		return nil, errFieldDataTypeNotMatch
	}
	data := x.Float16Vector
	dim := int(vectors.GetDim())
	rowBytes := dim * 2
	if end < 0 {
		// A negative end means "up to the last row"; guard the division
		// so a missing or invalid dimension yields an error, not a panic.
		if dim <= 0 {
			return nil, fmt.Errorf("invalid dimension %d for float16 vector field %s", dim, fd.GetFieldName())
		}
		end = len(data) / rowBytes
	}
	vector := make([][]byte, 0, end-begin)
	for i := begin; i < end; i++ {
		// Copy each row so the column does not alias the response buffer.
		v := make([]byte, rowBytes)
		copy(v, data[i*rowBytes:(i+1)*rowBytes])
		vector = append(vector, v)
	}
	return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil

case schema.DataType_BFloat16Vector:
	// BFloat16 vectors arrive as one flat byte buffer. Each element takes
	// two bytes (BFloat16Vector.Dim is len/2), so one row spans dim*2 bytes.
	vectors := fd.GetVectors()
	x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector)
	if !ok {
		return nil, errFieldDataTypeNotMatch
	}
	data := x.Bfloat16Vector
	dim := int(vectors.GetDim())
	rowBytes := dim * 2
	if end < 0 {
		// A negative end means "up to the last row"; guard the division
		// so a missing or invalid dimension yields an error, not a panic.
		if dim <= 0 {
			return nil, fmt.Errorf("invalid dimension %d for bfloat16 vector field %s", dim, fd.GetFieldName())
		}
		end = len(data) / rowBytes // shall not have remainder
	}
	vector := make([][]byte, 0, end-begin)
	for i := begin; i < end; i++ {
		// Copy each row so the column does not alias the response buffer.
		v := make([]byte, rowBytes)
		copy(v, data[i*rowBytes:(i+1)*rowBytes])
		vector = append(vector, v)
	}
	return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
default:
return nil, fmt.Errorf("unsupported data type %s", fd.GetType())
}
Expand Down Expand Up @@ -447,6 +518,36 @@ func FieldDataVector(fd *schema.FieldData) (Column, error) {
vector = append(vector, v)
}
return NewColumnBinaryVector(fd.GetFieldName(), dim, vector), nil
case schema.DataType_Float16Vector:
	// Float16 vectors arrive as one flat byte buffer. Each element takes
	// two bytes (Float16Vector.Dim is len/2), so one row spans dim*2 bytes.
	vectors := fd.GetVectors()
	x, ok := vectors.GetData().(*schema.VectorField_Float16Vector)
	if !ok {
		return nil, errFieldDataTypeNotMatch
	}
	data := x.Float16Vector
	dim := int(vectors.GetDim())
	if dim <= 0 {
		// The row count below divides by the row width; fail cleanly
		// instead of panicking on a missing or invalid dimension.
		return nil, fmt.Errorf("invalid dimension %d for float16 vector field %s", dim, fd.GetFieldName())
	}
	rowBytes := dim * 2
	rows := len(data) / rowBytes // shall not have remainder
	vector := make([][]byte, 0, rows)
	for i := 0; i < rows; i++ {
		// Copy each row so the column does not alias the response buffer.
		v := make([]byte, rowBytes)
		copy(v, data[i*rowBytes:(i+1)*rowBytes])
		vector = append(vector, v)
	}
	return NewColumnFloat16Vector(fd.GetFieldName(), dim, vector), nil
case schema.DataType_BFloat16Vector:
	// BFloat16 vectors arrive as one flat byte buffer. Each element takes
	// two bytes (BFloat16Vector.Dim is len/2), so one row spans dim*2 bytes.
	vectors := fd.GetVectors()
	x, ok := vectors.GetData().(*schema.VectorField_Bfloat16Vector)
	if !ok {
		return nil, errFieldDataTypeNotMatch
	}
	data := x.Bfloat16Vector
	dim := int(vectors.GetDim())
	if dim <= 0 {
		// The row count below divides by the row width; fail cleanly
		// instead of panicking on a missing or invalid dimension.
		return nil, fmt.Errorf("invalid dimension %d for bfloat16 vector field %s", dim, fd.GetFieldName())
	}
	rowBytes := dim * 2
	rows := len(data) / rowBytes // shall not have remainder
	vector := make([][]byte, 0, rows)
	for i := 0; i < rows; i++ {
		// Copy each row so the column does not alias the response buffer.
		v := make([]byte, rowBytes)
		copy(v, data[i*rowBytes:(i+1)*rowBytes])
		vector = append(vector, v)
	}
	return NewColumnBFloat16Vector(fd.GetFieldName(), dim, vector), nil
default:
return nil, errors.New("unsupported data type")
}
Expand Down
69 changes: 35 additions & 34 deletions entity/columns_scalar_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit fa4eb6d

Please sign in to comment.