diff --git a/Cargo.toml b/Cargo.toml index 838d2c2..3df4260 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ strum = "0.24" strum_macros = "0.24" base64 = "0.21.0" dashmap = "5.5.3" +half = "2.4.1" [build-dependencies] tonic-build = { version = "0.8.2", default-features = false, features = [ @@ -30,3 +31,4 @@ tonic-build = { version = "0.8.2", default-features = false, features = [ [dev-dependencies] rand = "0.8.5" +futures = "0.3" diff --git a/examples/collection.rs b/examples/collection.rs index 97cbe74..8f8965b 100644 --- a/examples/collection.rs +++ b/examples/collection.rs @@ -9,7 +9,8 @@ use std::collections::HashMap; use rand::prelude::*; -const DEFAULT_VEC_FIELD: &str = "embed"; +const FP32_VEC_FIELD: &str = "float32_vector_field"; + const DIM: i64 = 256; #[tokio::main] @@ -26,8 +27,8 @@ async fn main() -> Result<(), Error> { true, )) .add_field(FieldSchema::new_float_vector( - DEFAULT_VEC_FIELD, - "feature field", + FP32_VEC_FIELD, + "fp32 feature field", DIM, )) .build()?; @@ -48,8 +49,7 @@ async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result< let embed = rng.gen(); embed_data.push(embed); } - let embed_column = - FieldColumn::new(collection.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(collection.get_field(FP32_VEC_FIELD).unwrap(), embed_data)?; client .insert(collection.name(), vec![embed_column], None) @@ -62,7 +62,7 @@ async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result< HashMap::from([("nlist".to_owned(), "32".to_owned())]), ); client - .create_index(collection.name(), DEFAULT_VEC_FIELD, index_params) + .create_index(collection.name(), FP32_VEC_FIELD, index_params) .await?; client .load_collection(collection.name(), Some(LoadOptions::default())) diff --git a/examples/fp16_and_bf16.rs b/examples/fp16_and_bf16.rs new file mode 100644 index 0000000..bea9c08 --- /dev/null +++ b/examples/fp16_and_bf16.rs @@ -0,0 +1,119 @@ +use milvus::index::{IndexParams, IndexType}; +use milvus::options::LoadOptions; +use milvus::query::SearchOptions; +use milvus::schema::{CollectionSchema, CollectionSchemaBuilder}; +use milvus::{ + client::Client, data::FieldColumn, error::Error, schema::FieldSchema, +}; + +use half::prelude::*; +use rand::prelude::*; +use std::collections::HashMap; + +const FP16_VEC_FIELD: &str = "float16_vector_field"; +const BF16_VEC_FIELD: &str = "bfloat16_vector_field"; + +const DIM: i64 = 64; + +#[tokio::main] +async fn main() -> Result<(), Error> { + const URL: &str = "http://localhost:19530"; + + let client = Client::new(URL).await?; + + let schema = + CollectionSchemaBuilder::new("milvus_fp16", "fp16/bf16 example for milvus rust SDK") + .add_field(FieldSchema::new_primary_int64( + "id", + "primary key field", + true, + )) + .add_field(FieldSchema::new_float16_vector( + FP16_VEC_FIELD, + "fp16 feature field", + DIM, + )) + .add_field(FieldSchema::new_bfloat16_vector( + BF16_VEC_FIELD, + "bf16 feature field", + DIM, + )) + .build()?; + client.create_collection(schema.clone(), None).await?; + + if let Err(err) = fp16_insert_and_query(&client, &schema).await { + println!("failed to run hello milvus: {:?}", err); + } + client.drop_collection(schema.name()).await?; + + Ok(()) +} + +fn gen_random_f32_vector(n: i64) -> Vec { + let mut data = Vec::::with_capacity(n as usize); + let mut rng = rand::thread_rng(); + for _ in 0..n { + data.push(rng.gen()); + } + data +} + +async fn fp16_insert_and_query( + client: &Client, + collection: &CollectionSchema, +) -> Result<(), Error> { + let mut embed_data = Vec::::new(); + for _ in 1..=DIM * 1000 { + let mut rng = rand::thread_rng(); + let embed = rng.gen(); + embed_data.push(embed); + } + + // fp16 or bf16 vector accept Vec, Vec or Vec/Vec as input + let bf16_column = FieldColumn::new( + collection.get_field(BF16_VEC_FIELD).unwrap(), + Vec::::from_f32_slice(embed_data.as_slice()), + )?; + let fp16_column = FieldColumn::new(collection.get_field(FP16_VEC_FIELD).unwrap(), embed_data)?; + + let result = client + .insert(collection.name(), vec![fp16_column, bf16_column], None) + .await?; + println!("insert cnt: {}", result.insert_cnt); + client.flush(collection.name()).await?; + + let create_index_fut = [FP16_VEC_FIELD, BF16_VEC_FIELD].map(|field_name| { + let index_params = IndexParams::new( + field_name.to_string() + "_index", + IndexType::IvfFlat, + milvus::index::MetricType::L2, + HashMap::from([("nlist".to_owned(), "32".to_owned())]), + ); + client.create_index(collection.name(), field_name, index_params) + }); + futures::future::try_join_all(create_index_fut).await?; + client.flush(collection.name()).await?; + client + .load_collection(collection.name(), Some(LoadOptions::default())) + .await?; + + // search + let q1 = Vec::::from_f32_slice(&gen_random_f32_vector(DIM)); + let q2 = Vec::::from_f32_slice(&gen_random_f32_vector(DIM)); + let option = SearchOptions::with_limit(3) + .metric_type(milvus::index::MetricType::L2) + .output_fields(vec!["id".to_owned(), FP16_VEC_FIELD.to_owned()]); + let result = client + .search( + collection.name(), + vec![q1.into(), q2.into()], + FP16_VEC_FIELD, + &option, + ) + .await?; + + println!("{:?}", result[0]); + println!("result num: {}, {}", result[0].size, result[1].size); + + Ok(()) +} diff --git a/src/collection.rs b/src/collection.rs index d70748f..c8f7306 100644 --- a/src/collection.rs +++ b/src/collection.rs @@ -660,6 +660,7 @@ pub type ParamValue = serde_json::Value; pub use serde_json::json as ParamValue; // search result for a single vector +#[derive(Debug)] pub struct SearchResult<'a> { pub size: i64, pub id: Vec>, diff --git a/src/data.rs b/src/data.rs index 7258bf6..efb0c23 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,13 +1,14 @@ -use std::borrow::Cow; - +use crate::error::Result; use crate::{ proto::schema::{ self, field_data::Field, scalar_field::Data as ScalarData, vector_field::Data as VectorData, DataType, ScalarField, VectorField, }, schema::FieldSchema, - value::{Value, ValueVec}, + value::{TryIntoValueVecWithDataType, Value, ValueVec}, }; +use half::prelude::*; +use std::borrow::Cow; pub trait HasDataType { fn data_type() -> DataType; @@ -34,8 +35,12 @@ impl_has_data_type! { String, DataType::String, Cow<'_, str>, DataType::String, Vec, DataType::FloatVector, + Vec, DataType::Float16Vector, + Vec, DataType::BFloat16Vector, Vec, DataType::BinaryVector, Cow<'_, [f32]>, DataType::FloatVector, + Cow<'_, [f16]>, DataType::Float16Vector, + Cow<'_, [bf16]>, DataType::BFloat16Vector, Cow<'_, [u8]>, DataType::BinaryVector } @@ -72,15 +77,16 @@ impl From for FieldColumn { } impl FieldColumn { - pub fn new>(schm: &FieldSchema, v: V) -> FieldColumn { - FieldColumn { + pub fn new(schm: &FieldSchema, v: V) -> Result { + let value: ValueVec = v.try_into_value_vec(schm.dtype)?; + Ok(FieldColumn { name: schm.name.clone(), dtype: schm.dtype, - value: v.into(), + value, dim: schm.dim, max_length: schm.max_length, is_dynamic: false, - } + }) } pub fn get(&self, idx: usize) -> Option> { @@ -107,6 +113,14 @@ impl FieldColumn { let dim = (self.dim / 8) as usize; Value::Binary(Cow::Borrowed(&v[idx * dim..idx * dim + dim])) } + ValueVec::Float16(v) => { + let dim = self.dim as usize; + Value::Float16Array(Cow::Borrowed(&v[idx * dim..idx * dim + dim])) + } + ValueVec::BFloat16(v) => { + let dim = self.dim as usize; + Value::BFloat16Array(Cow::Borrowed(&v[idx * dim..idx * dim + dim])) + } ValueVec::String(v) => Value::String(Cow::Borrowed(v.get(idx)?.as_ref())), ValueVec::Json(v) => Value::Json(Cow::Borrowed(v.get(idx)?.as_ref())), ValueVec::Array(v) => Value::Array(Cow::Borrowed(v.get(idx)?)), @@ -126,6 +140,8 @@ impl FieldColumn { (ValueVec::String(vec), Value::String(i)) => vec.push(i.to_string()), (ValueVec::Binary(vec), Value::Binary(i)) => vec.extend_from_slice(i.as_ref()), (ValueVec::Float(vec), Value::FloatArray(i)) => vec.extend_from_slice(i.as_ref()), + (ValueVec::Float16(vec), Value::Float16Array(i)) => vec.extend_from_slice(i.as_ref()), + (ValueVec::BFloat16(vec), Value::BFloat16Array(i)) => vec.extend_from_slice(i.as_ref()), _ => panic!("column type mismatch"), } } @@ -135,6 +151,11 @@ impl FieldColumn { self.value.len() / self.dim as usize } + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn copy_with_metadata(&self) -> Self { Self { dim: self.dim, @@ -151,6 +172,8 @@ impl FieldColumn { ValueVec::String(_) => ValueVec::String(Vec::new()), ValueVec::Json(_) => ValueVec::Json(Vec::new()), ValueVec::Binary(_) => ValueVec::Binary(Vec::new()), + ValueVec::Float16(_) => ValueVec::Float16(Vec::new()), + ValueVec::BFloat16(_) => ValueVec::BFloat16(Vec::new()), ValueVec::Array(_) => ValueVec::Array(Vec::new()), }, is_dynamic: self.is_dynamic, @@ -176,6 +199,7 @@ impl From for schema::FieldData { data: Some(ScalarData::LongData(schema::LongArray { data: v })), }), ValueVec::Float(v) => match this.dtype { + // both scalar and vector fields accept 1-d float array DataType::Float => Field::Scalars(ScalarField { data: Some(ScalarData::FloatData(schema::FloatArray { data: v })), }), @@ -204,6 +228,20 @@ impl From for schema::FieldData { data: Some(VectorData::BinaryVector(v)), dim: this.dim, }), + ValueVec::BFloat16(v) => { + let v: Vec = v.into_iter().flat_map(|x| x.to_le_bytes()).collect(); + Field::Vectors(VectorField { + data: Some(VectorData::Bfloat16Vector(v)), + dim: this.dim, + }) + } + ValueVec::Float16(v) => { + let v: Vec = v.into_iter().flat_map(|x| x.to_le_bytes()).collect(); + Field::Vectors(VectorField { + data: Some(VectorData::Float16Vector(v)), + dim: this.dim, + }) + } }), is_dynamic: false, } @@ -236,7 +274,9 @@ impl_from_field! { Vec[Field::Scalars(ScalarField {data: Some(ScalarData::LongData(schema::LongArray { data }))}) => Some(data)], Vec[Field::Scalars(ScalarField {data: Some(ScalarData::StringData(schema::StringArray { data }))}) => Some(data)], Vec[Field::Scalars(ScalarField {data: Some(ScalarData::DoubleData(schema::DoubleArray { data }))}) => Some(data)], - Vec[Field::Vectors(VectorField {data: Some(VectorData::BinaryVector(data)), ..}) => Some(data)] + Vec[Field::Vectors(VectorField {data: Some(VectorData::BinaryVector(data)), ..}) => Some(data)], + Vec[Field::Vectors(VectorField {data: Some(VectorData::Float16Vector(data)), ..}) => Some(data.chunks(2).map(|x|f16::from_le_bytes([x[0], x[1]])).collect())], + Vec[Field::Vectors(VectorField {data: Some(VectorData::Bfloat16Vector(data)), ..}) => Some(data.chunks(2).map(|x|bf16::from_le_bytes([x[0], x[1]])).collect())] } impl FromField for Vec { diff --git a/src/mutate.rs b/src/mutate.rs index 17680f6..0322dee 100644 --- a/src/mutate.rs +++ b/src/mutate.rs @@ -1,19 +1,15 @@ -use prost::bytes::{BufMut, BytesMut}; use crate::error::Result; use crate::{ client::Client, - collection, data::FieldColumn, error::Error, proto::{ self, common::{MsgBase, MsgType}, milvus::{InsertRequest, UpsertRequest}, - schema::{scalar_field::Data, DataType}, + schema::DataType, }, - schema::FieldData, - utils::status_to_result, value::ValueVec, }; diff --git a/src/partition.rs b/src/partition.rs index 0e13909..34af03a 100644 --- a/src/partition.rs +++ b/src/partition.rs @@ -4,7 +4,6 @@ use crate::error::*; use crate::{ client::Client, proto::{ - self, common::{MsgBase, MsgType}, }, utils::status_to_result, diff --git a/src/query.rs b/src/query.rs index 044e4c6..c13b882 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,8 +1,7 @@ -use std::borrow::Borrow; -use std::collections::HashMap; - +use half::{bf16, f16}; use prost::bytes::BytesMut; use prost::Message; +use std::collections::HashMap; use crate::client::{Client, ConsistencyLevel}; use crate::collection::{ParamValue, SearchResult}; @@ -372,6 +371,8 @@ fn get_place_holder_value(vectors: Vec) -> Result { match vectors[0] { Value::FloatArray(_) => place_holder.r#type = PlaceholderType::FloatVector as _, + Value::Float16Array(_) => place_holder.r#type = PlaceholderType::Float16Vector as _, + Value::BFloat16Array(_) => place_holder.r#type = PlaceholderType::BFloat16Vector as _, Value::Binary(_) => place_holder.r#type = PlaceholderType::BinaryVector as _, _ => { return Err(SuperError::from(crate::collection::Error::IllegalType( @@ -381,14 +382,25 @@ fn get_place_holder_value(vectors: Vec) -> Result { } }; + macro_rules! place_holder_push_bytes { + ($d:expr, $t:ty) => { + let mut bytes = Vec::::with_capacity($d.len() * size_of::<$t>()); + for f in $d.iter() { + bytes.extend_from_slice(&f.to_le_bytes()); + } + place_holder.values.push(bytes) + }; + } for v in &vectors { match (v, &vectors[0]) { (Value::FloatArray(d), Value::FloatArray(_)) => { - let mut bytes = Vec::::with_capacity(d.len() * 4); - for f in d.iter() { - bytes.extend_from_slice(&f.to_le_bytes()); - } - place_holder.values.push(bytes) + place_holder_push_bytes!(d, f32); + } + (Value::Float16Array(d), Value::Float16Array(_)) => { + place_holder_push_bytes!(d, f16); + } + (Value::BFloat16Array(d), Value::BFloat16Array(_)) => { + place_holder_push_bytes!(d, bf16); } (Value::Binary(d), Value::Binary(_)) => place_holder.values.push(d.to_vec()), _ => { diff --git a/src/schema.rs b/src/schema.rs index 73460b2..b866b4b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -153,7 +153,7 @@ pub struct FieldSchema { pub is_primary: bool, pub auto_id: bool, pub chunk_size: usize, - pub dim: i64, // only for BinaryVector and FloatVector + pub dim: i64, // only for BinaryVector, FloatVector, Float16Vector and BFloat16Vector pub max_length: i32, // only for VarChar } @@ -392,12 +392,49 @@ impl FieldSchema { max_length: 0, } } + + pub fn new_float16_vector(name: &str, description: &str, dim: i64) -> Self { + if dim <= 0 { + panic!("dim should be positive"); + } + + Self { + name: name.to_owned(), + description: description.to_owned(), + dtype: DataType::Float16Vector, + chunk_size: dim as usize, + dim, + is_primary: false, + auto_id: false, + max_length: 0, + } + } + + pub fn new_bfloat16_vector(name: &str, description: &str, dim: i64) -> Self { + if dim <= 0 { + panic!("dim should be positive"); + } + + Self { + name: name.to_owned(), + description: description.to_owned(), + dtype: DataType::BFloat16Vector, + chunk_size: dim as usize, + dim, + is_primary: false, + auto_id: false, + max_length: 0, + } + } } impl From for schema::FieldSchema { fn from(fld: FieldSchema) -> schema::FieldSchema { let params = match fld.dtype { - DataType::BinaryVector | DataType::FloatVector => vec![KeyValuePair { + DataType::BinaryVector + | DataType::FloatVector + | DataType::Float16Vector + | DataType::BFloat16Vector => vec![KeyValuePair { key: "dim".to_string(), value: fld.dim.to_string(), }], @@ -465,16 +502,16 @@ impl CollectionSchema { pub fn is_valid_vector_field(&self, field_name: &str) -> Result<()> { for f in &self.fields { if f.name == field_name { - if f.dtype == DataType::BinaryVector || f.dtype == DataType::FloatVector { - return Ok(()); - } else { - return Err(error::Error::from(Error::NotVectorField( - field_name.to_owned(), - ))); - } + return match f.dtype { + DataType::BinaryVector + | DataType::FloatVector + | DataType::Float16Vector + | DataType::BFloat16Vector => Ok(()), + _ => Err(Error::NotVectorField(field_name.to_owned()).into()), + }; } } - return Err(error::Error::from(Error::NoSuchKey(field_name.to_owned()))); + Err(error::Error::from(Error::NoSuchKey(field_name.to_owned()))) } } diff --git a/src/value.rs b/src/value.rs index f2c72e1..5e4b53b 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use crate::proto::{ self, schema::{ @@ -7,7 +5,12 @@ use crate::proto::{ DataType, }, }; +use half::prelude::*; +use std::borrow::Cow; +use strum_macros::Display; +/// Value represents a scalar value or a vector value. +#[derive(Debug, Clone, Display)] pub enum Value<'a> { None, Bool(bool), @@ -18,6 +21,10 @@ pub enum Value<'a> { Float(f32), Double(f64), FloatArray(Cow<'a, [f32]>), + /// schema.proto uses bytes to represent binary_vector, + /// float16_vector and bfloat16_vector. + Float16Array(Cow<'a, [f16]>), + BFloat16Array(Cow<'a, [bf16]>), Binary(Cow<'a, [u8]>), String(Cow<'a, str>), Json(Cow<'a, [u8]>), @@ -58,6 +65,8 @@ impl Value<'_> { Value::String(_) => DataType::String, Value::Json(_) => DataType::Json, Value::FloatArray(_) => DataType::FloatVector, + Value::Float16Array(_) => DataType::Float16Vector, + Value::BFloat16Array(_) => DataType::BFloat16Vector, Value::Binary(_) => DataType::BinaryVector, Value::Array(_) => DataType::Array, } @@ -100,14 +109,89 @@ impl From> for Value<'static> { } } -#[derive(Debug, Clone)] +impl<'a> From<&'a [f16]> for Value<'a> { + fn from(v: &'a [f16]) -> Self { + Self::Float16Array(Cow::Borrowed(v)) + } +} + +impl From> for Value<'static> { + fn from(v: Vec) -> Self { + Self::Float16Array(Cow::Owned(v)) + } +} + +impl<'a> From<&'a [bf16]> for Value<'a> { + fn from(v: &'a [bf16]) -> Self { + Self::BFloat16Array(Cow::Borrowed(v)) + } +} + +impl From> for Value<'static> { + fn from(v: Vec) -> Self { + Self::BFloat16Array(Cow::Owned(v)) + } +} + +macro_rules! impl_try_from_for_value_column { + ( $($o: ident,$t: ty ),+ ) => {$( + impl TryFrom> for $t { + type Error = crate::error::Error; + fn try_from(value: Value<'_>) -> Result { + match value { + Value::$o(v) => Ok(v), + _ => Err(crate::error::Error::Conversion), + } + } + } + )*}; +} + +impl_try_from_for_value_column! { + Bool,bool, + Int8,i8, + Int16,i16, + Int32,i32, + Long,i64, + Float,f32, + Double,f64 +} + +macro_rules! impl_try_from_for_value_column { + ( $($o: ident,$t: ty ),+ ) => {$( + impl TryFrom> for $t { + type Error = crate::error::Error; + fn try_from(value: Value<'_>) -> Result { + match value { + Value::$o(v) => Ok(v.into_owned()), + _ => Err(crate::error::Error::Conversion), + } + } + } + )*}; +} + +impl_try_from_for_value_column! { + FloatArray,Vec, + Float16Array,Vec, + BFloat16Array,Vec +} + +/// ValueVec represents a column of data. +/// Both scalar_field value and vector_field value are represented by 1-d array. +#[derive(Debug, Clone, Display)] pub enum ValueVec { None, Bool(Vec), Int(Vec), Long(Vec), + /// float or float vector Float(Vec), Double(Vec), + /// float16 vector + Float16(Vec), + /// bfloat16 vector + BFloat16(Vec), Binary(Vec), String(Vec), Json(Vec>), @@ -130,10 +214,83 @@ impl_from_for_value_vec! { Vec, Long, Vec, String, Vec, Binary, + Vec, BFloat16, + Vec, Float16, Vec, Float, Vec, Double } +pub trait TryIntoValueVecWithDataType: Into { + fn try_into_value_vec(self, dtype: DataType) -> crate::error::Result; +} + +impl TryIntoValueVecWithDataType for Vec { + fn try_into_value_vec(self, dtype: DataType) -> crate::error::Result { + let v: ValueVec = match dtype { + DataType::Float16Vector => { + let v: Vec = Vec::::from_f32_slice(&self); + v.into() + } + DataType::BFloat16Vector => { + let v: Vec = Vec::::from_f32_slice(&self); + v.into() + } + _ => self.into(), + }; + if v.check_dtype(dtype) { + Ok(v) + } else { + Err(crate::error::Error::Conversion) + } + } +} + +impl TryIntoValueVecWithDataType for Vec { + fn try_into_value_vec(self, dtype: DataType) -> crate::error::Result { + let v: ValueVec = match dtype { + DataType::Float16Vector => { + let v: Vec = Vec::::from_f64_slice(&self); + v.into() + } + DataType::BFloat16Vector => { + let v: Vec = Vec::::from_f64_slice(&self); + v.into() + } + _ => self.into(), + }; + if v.check_dtype(dtype) { + Ok(v) + } else { + Err(crate::error::Error::Conversion) + } + } +} + +macro_rules! impl_try_into_value_vec_with_datatype { + ( $($t: ty),+ ) => {$( + impl TryIntoValueVecWithDataType for $t { + fn try_into_value_vec(self, dtype: DataType) -> crate::error::Result { + let v: ValueVec = self.into(); + if v.check_dtype(dtype) { + Ok(v) + } else { + Err(crate::error::Error::Conversion) + } + } + } + )*}; +} + +impl_try_into_value_vec_with_datatype!( + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec +); + macro_rules! impl_try_from_for_value_vec { ( $($o: ident, $t: ty ),+ ) => {$( impl TryFrom for $t { @@ -154,6 +311,8 @@ impl_try_from_for_value_vec! { Long, Vec, String, Vec, Binary, Vec, + BFloat16, Vec, + Float16, Vec, Float, Vec, Double, Vec } @@ -187,27 +346,49 @@ impl ValueVec { DataType::Array => Self::Array(Vec::new()), DataType::BinaryVector => Self::Binary(Vec::new()), DataType::FloatVector => Self::Float(Vec::new()), - DataType::Float16Vector => Self::Binary(Vec::new()), - DataType::BFloat16Vector => Self::Binary(Vec::new()), + DataType::Float16Vector => Self::Float16(Vec::new()), + DataType::BFloat16Vector => Self::BFloat16(Vec::new()), } } + pub fn float16_from_f32_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f32_slice(v.as_slice()); + Self::Float16(v) + } + + pub fn float16_from_f64_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f64_slice(v.as_slice()); + Self::Float16(v) + } + + pub fn bfloat16_from_f32_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f32_slice(v.as_slice()); + Self::BFloat16(v) + } + + pub fn bfloat16_from_f64_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f64_slice(v.as_slice()); + Self::BFloat16(v) + } + pub fn check_dtype(&self, dtype: DataType) -> bool { - match (self, dtype) { + matches!( + (self, dtype), (ValueVec::Binary(..), DataType::BinaryVector) - | (ValueVec::Float(..), DataType::FloatVector) - | (ValueVec::Float(..), DataType::Float) - | (ValueVec::Int(..), DataType::Int8) - | (ValueVec::Int(..), DataType::Int16) - | (ValueVec::Int(..), DataType::Int32) - | (ValueVec::Long(..), DataType::Int64) - | (ValueVec::Bool(..), DataType::Bool) - | (ValueVec::String(..), DataType::String) - | (ValueVec::String(..), DataType::VarChar) - | (ValueVec::None, _) - | (ValueVec::Double(..), DataType::Double) => true, - _ => false, - } + | (ValueVec::BFloat16(..), DataType::BFloat16Vector) + | (ValueVec::Float16(..), DataType::Float16Vector) + | (ValueVec::Float(..), DataType::FloatVector) + | (ValueVec::Float(..), DataType::Float) + | (ValueVec::Int(..), DataType::Int8) + | (ValueVec::Int(..), DataType::Int16) + | (ValueVec::Int(..), DataType::Int32) + | (ValueVec::Long(..), DataType::Int64) + | (ValueVec::Bool(..), DataType::Bool) + | (ValueVec::String(..), DataType::String) + | (ValueVec::String(..), DataType::VarChar) + | (ValueVec::None, _) + | (ValueVec::Double(..), DataType::Double) + ) } #[inline] @@ -221,6 +402,8 @@ impl ValueVec { ValueVec::Bool(v) => v.len(), ValueVec::Int(v) => v.len(), ValueVec::Long(v) => v.len(), + ValueVec::BFloat16(v) => v.len(), + ValueVec::Float16(v) => v.len(), ValueVec::Float(v) => v.len(), ValueVec::Double(v) => v.len(), ValueVec::Binary(v) => v.len(), @@ -236,6 +419,8 @@ impl ValueVec { ValueVec::Bool(v) => v.clear(), ValueVec::Int(v) => v.clear(), ValueVec::Long(v) => v.clear(), + ValueVec::BFloat16(v) => v.clear(), + ValueVec::Float16(v) => v.clear(), ValueVec::Float(v) => v.clear(), ValueVec::Double(v) => v.clear(), ValueVec::Binary(v) => v.clear(), @@ -268,8 +453,20 @@ impl From for ValueVec { Some(x) => match x { VectorData::FloatVector(v) => Self::Float(v.data), VectorData::BinaryVector(v) => Self::Binary(v), - VectorData::Bfloat16Vector(v) => Self::Binary(v), - VectorData::Float16Vector(v) => Self::Binary(v), + VectorData::Bfloat16Vector(v) => { + let v: Vec = v + .chunks_exact(2) + .map(|x| bf16::from_le_bytes([x[0], x[1]])) + .collect(); + Self::BFloat16(v) + } + VectorData::Float16Vector(v) => { + let v: Vec = v + .chunks_exact(2) + .map(|x| f16::from_le_bytes([x[0], x[1]])) + .collect(); + Self::Float16(v) + } }, None => Self::None, }, @@ -277,36 +474,13 @@ impl From for ValueVec { } } -macro_rules! impl_try_from_for_value_column { - ( $($o: ident,$t: ty ),+ ) => {$( - impl TryFrom> for $t { - type Error = crate::error::Error; - fn try_from(value: Value<'_>) -> Result { - match value { - Value::$o(v) => Ok(v), - _ => Err(crate::error::Error::Conversion), - } - } - } - )*}; -} - -impl_try_from_for_value_column! { - Bool,bool, - Int8,i8, - Int16,i16, - Int32,i32, - Long,i64, - Float,f32, - Double,f64 -} - #[cfg(test)] mod test { use crate::{ error::Error, value::{Value, ValueVec}, }; + use half::prelude::*; #[test] fn test_try_from_for_value_column() { @@ -345,6 +519,24 @@ mod test { let r: Result = double.try_into(); assert!(r.is_ok()); assert_eq!(22104f64, r.unwrap()); + + let float_array: Value = vec![1.1, 2.2, 3.3].into(); + let r: Result, Error> = float_array.try_into(); + assert_eq!(vec![1.1, 2.2, 3.3], r.unwrap()); + + let float16_array: Value = Vec::::from_f32_slice(&[1.1, 2.2, 3.3, -1.1]).into(); + let r: Result, Error> = float16_array.try_into(); + assert_eq!( + Vec::::from_f32_slice(&[1.1, 2.2, 3.3, -1.1]), + r.unwrap() + ); + + let bfloat16_array: Value = Vec::::from_f32_slice(&[1.1, 2.2, 3.3, 8.21, 0.]).into(); + let r: Result, Error> = bfloat16_array.try_into(); + assert_eq!( + Vec::::from_f32_slice(&[1.1, 2.2, 3.3, 8.21, 0.]), + r.unwrap() + ); } #[test] @@ -365,6 +557,17 @@ mod test { assert!(b.is_ok()); assert_eq!(v, b.unwrap()); + let v: Vec = Vec::from_f32_slice(&[-1.1, -2.2, 3.3, -4.4, -5.5]); + let b = ValueVec::BFloat16(v.clone()); + let b: Result, Error> = b.try_into(); + assert!(b.is_ok()); + + let v: Vec = Vec::from_f32_slice(&[1.1, 2.2, 3.3, 4.4, 5.5]); + let b = ValueVec::Float16(v.clone()); + let b: Result, Error> = b.try_into(); + assert!(b.is_ok()); + assert_eq!(v, b.unwrap()); + let v: Vec = vec![11., 7., 754., 68., 34.]; let b = ValueVec::Float(v.clone()); let b: Result, Error> = b.try_into(); diff --git a/tests/client.rs b/tests/client.rs index 8720e4a..9c5aefb 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -75,6 +75,11 @@ async fn create_has_drop_collection() -> Result<()> { let schema = schema .add_field(FieldSchema::new_int64("i64_field", "")) .add_field(FieldSchema::new_bool("bool_field", "")) + .add_field(FieldSchema::new_bfloat16_vector( + BF16_VEC_FIELD, + "bfloat", + 128, + )) .set_primary_key("i64_field")? .enable_auto_id()? .build()?; diff --git a/tests/collection.rs b/tests/collection.rs index afb58f7..046dddf 100644 --- a/tests/collection.rs +++ b/tests/collection.rs @@ -23,10 +23,9 @@ use milvus::mutate::InsertOptions; use milvus::options::LoadOptions; use milvus::query::{QueryOptions, SearchOptions}; use std::collections::HashMap; - mod common; use common::*; - +use half::prelude::*; use milvus::value::ValueVec; #[tokio::test] @@ -42,8 +41,8 @@ async fn collection_upsert() -> Result<()> { let (client, schema) = create_test_collection(false).await?; let pk_data = gen_random_int64_vector(2000); let vec_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let pk_col = FieldColumn::new(schema.get_field("id").unwrap(), pk_data); - let vec_col = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), vec_data); + let pk_col = FieldColumn::new(schema.get_field("id").unwrap(), pk_data)?; + let vec_col = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), vec_data)?; client .upsert(schema.name(), vec![pk_col, vec_col], None) .await?; @@ -77,7 +76,7 @@ async fn collection_basic() -> Result<()> { let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data)?; client .insert(schema.name(), vec![embed_column], None) @@ -108,13 +107,114 @@ async fn collection_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn collection_fp16_bf16_vec() -> Result<()> { + let (client, schema) = create_test_fp16_bf16_collection(true).await?; + let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); + + let fp16_vector_column = FieldColumn::new( + schema.get_field(FP16_VEC_FIELD).unwrap(), + embed_data.clone(), + )?; + let bf16_vector_column = + FieldColumn::new(schema.get_field(BF16_VEC_FIELD).unwrap(), embed_data)?; + + client + .insert( + schema.name(), + vec![fp16_vector_column, bf16_vector_column], + None, + ) + .await?; + client.flush(schema.name()).await?; + for (index_name, field_name) in [ + (FP16_VEC_INDEX_NAME, FP16_VEC_FIELD), + (BF16_VEC_INDEX_NAME, BF16_VEC_FIELD), + ] { + let index_params = IndexParams::new( + index_name.to_owned(), + IndexType::IvfFlat, + milvus::index::MetricType::L2, + HashMap::from([("nlist".to_owned(), "32".to_owned())]), + ); + client + .create_index(schema.name(), field_name, index_params) + .await?; + } + client + .load_collection(schema.name(), Some(LoadOptions::default())) + .await?; + + let options = QueryOptions::default(); + let result = client.query(schema.name(), "id > 0", &options).await?; + + println!( + "result num: {}", + result.first().map(|c| c.len()).unwrap_or(0), + ); + // create index + let create_index_futs = [ + (FP16_VEC_FIELD, FP16_VEC_INDEX_NAME), + (BF16_VEC_FIELD, BF16_VEC_INDEX_NAME), + ] + .map(|(field_name, index_name)| { + client.create_index( + schema.name(), + field_name, + IndexParams::new( + index_name.to_owned(), + IndexType::IvfFlat, + milvus::index::MetricType::L2, + HashMap::from([("nlist".to_owned(), "32".to_owned())]), + ), + ) + }); + futures::future::join_all(create_index_futs).await; + client.flush(schema.name()).await?; + client + .load_collection(schema.name(), Some(LoadOptions::default())) + .await?; + + // search + let mut option = SearchOptions::with_limit(10) + .metric_type(MetricType::L2) + .output_fields(vec!["id".to_owned(), FP16_VEC_FIELD.to_owned()]); + option = option.add_param("nprobe", ParamValue!(16)); + let query_vec = gen_random_f32_vector(DEFAULT_DIM); + + let result = client + .search( + schema.name(), + vec![Vec::::from_f32_slice(query_vec.as_slice()).into()], + FP16_VEC_FIELD, + &option, + ) + .await?; + assert_eq!(result.len(), 1); + assert_eq!(result[0].size, 10); + assert_eq!(result[0].field.len(), 2); + + let result = client + .search( + schema.name(), + vec![Vec::::from_f32_slice(query_vec.as_slice()).into()], + BF16_VEC_FIELD, + &option, + ) + .await?; + assert_eq!(result[0].size, 10); + + client.drop_collection(schema.name()).await?; + Ok(()) +} + #[tokio::test] async fn collection_index() -> Result<()> { let (client, schema) = create_test_collection(true).await?; let feature = gen_random_f32_vector(DEFAULT_DIM * 2000); - let feature_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), feature); + let feature_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), feature)?; client .insert(schema.name(), vec![feature_column], None) @@ -149,7 +249,7 @@ async fn collection_search() -> Result<()> { let (client, schema) = create_test_collection(true).await?; let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data)?; client .insert(schema.name(), vec![embed_column], None) @@ -195,7 +295,7 @@ async fn collection_range_search() -> Result<()> { let (client, schema) = create_test_collection(true).await?; let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data)?; client .insert(schema.name(), vec![embed_column], None) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 2e17886..c0954e1 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -5,14 +5,18 @@ use milvus::schema::{CollectionSchema, CollectionSchemaBuilder, FieldSchema}; use rand::Rng; pub const DEFAULT_DIM: i64 = 128; -pub const DEFAULT_VEC_FIELD: &str = "feature"; -pub const DEFAULT_INDEX_NAME: &str = "feature_index"; +pub const DEFAULT_VEC_FIELD: &str = "float32_vec_field"; +pub const FP16_VEC_FIELD: &str = "float16_vec_field"; +pub const BF16_VEC_FIELD: &str = "bfloat16_vec_field"; +pub const DEFAULT_INDEX_NAME: &str = "float32_vec_field_index"; +pub const FP16_VEC_INDEX_NAME: &str = "float16_vec_field_index"; +pub const BF16_VEC_INDEX_NAME: &str = "bfloat16_vec_field_index"; + pub const URL: &str = "http://localhost:19530"; pub async fn create_test_collection(autoid: bool) -> Result<(Client, CollectionSchema)> { let collection_name = gen_random_name(); let collection_name = format!("{}_{}", "test_collection", collection_name); - let client = Client::new(URL).await?; let schema = CollectionSchemaBuilder::new(&collection_name, "") .add_field(FieldSchema::new_primary_int64("id", "", autoid)) .add_field(FieldSchema::new_float_vector( @@ -21,8 +25,34 @@ pub async fn create_test_collection(autoid: bool) -> Result<(Client, CollectionS DEFAULT_DIM, )) .build()?; - if client.has_collection(&collection_name).await? { - client.drop_collection(&collection_name).await?; + create_test_collection_with_schema(schema).await +} + +pub async fn create_test_fp16_bf16_collection(autoid: bool) -> Result<(Client, CollectionSchema)> { + let collection_name = gen_random_name(); + let collection_name = format!("{}_{}", "test_collection", collection_name); + let schema = CollectionSchemaBuilder::new(&collection_name, "") + .add_field(FieldSchema::new_primary_int64("id", "", autoid)) + .add_field(FieldSchema::new_float16_vector( + FP16_VEC_FIELD, + "fp16 vector field", + DEFAULT_DIM, + )) + .add_field(FieldSchema::new_bfloat16_vector( + BF16_VEC_FIELD, + "bf16 vector field", + DEFAULT_DIM, + )) + .build()?; + create_test_collection_with_schema(schema).await +} + +async fn create_test_collection_with_schema( + schema: CollectionSchema, +) -> Result<(Client, CollectionSchema)> { + let client = Client::new(URL).await?; + if client.has_collection(schema.name()).await? { + client.drop_collection(schema.name()).await?; } client .create_collection(