diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c8da2e2..ca609b1 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,6 +1,8 @@ name: Unittest on: + workflow_dispatch: + push: paths: - 'src/**' @@ -9,8 +11,7 @@ on: - 'build.rs' - '.github/**' - 'docker-compose.yml' - - # Triggers the workflow on push or pull request events but only for the master branch + pull_request: paths: - 'src/**' @@ -20,34 +21,28 @@ on: - '.github/**' - 'docker-compose.yml' - jobs: - # This workflow contains a single job called "build" - build: - name: Unittest AMD64 Ubuntu ${{ matrix.ubuntu }} + # This workflow contains a single job called "build_and_test" + build_and_test: + name: Unittest AMD64 Ubuntu # The type of runner that the job will run on runs-on: ubuntu-latest timeout-minutes: 30 - strategy: - fail-fast: false - matrix: - ubuntu: [18.04] - env: - UBUNTU: ${{ matrix.ubuntu }} steps: # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: submodules: 'true' - + + # do not put docker volume in the source code directory - name: Setup Milvus - run: sudo docker-compose up -d && sleep 15s - + run: DOCKER_VOLUME_DIRECTORY=../ docker compose up -d && sleep 15s + - name: Setup protoc uses: arduino/setup-protoc@v1.1.2 with: repo-token: ${{ secrets.GITHUB_TOKEN }} - + # Runs a single command using the runners shell - name: Run Unittest run: RUST_BACKTRACE=1 cargo test diff --git a/Cargo.toml b/Cargo.toml index 838d2c2..dff9852 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ strum = "0.24" strum_macros = "0.24" base64 = "0.21.0" dashmap = "5.5.3" +# fp16/bf16 support +half = "2.4.1" [build-dependencies] tonic-build = { version = "0.8.2", default-features = false, features = [ @@ -30,3 +32,4 @@ tonic-build = { version = "0.8.2", default-features = false, features = [ [dev-dependencies] rand = "0.8.5" +futures = "0.3" diff --git a/README.md b/README.md index 62529bb..21424b8 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Pre-requisites: ### How to test Many tests require the Milvus server, the project provide a docker-compose file to setup a Milvus cluster: ``` -docker-compose -f ./docker-compose.yml up -d +docker compose -f ./docker-compose.yml up -d ``` You may need to wait for seconds until the system ready @@ -57,4 +57,4 @@ cargo test Enable the full backtrace for debugging: ``` RUST_BACKTRACE=1 cargo test -``` \ No newline at end of file +``` diff --git a/examples/collection.rs b/examples/collection.rs index 97cbe74..8f8965b 100644 --- a/examples/collection.rs +++ b/examples/collection.rs @@ -9,7 +9,8 @@ use std::collections::HashMap; use rand::prelude::*; -const DEFAULT_VEC_FIELD: &str = "embed"; +const FP32_VEC_FIELD: &str = "float32_vector_field"; + const DIM: i64 = 256; #[tokio::main] @@ -26,8 +27,8 @@ async fn main() -> Result<(), Error> { true, )) .add_field(FieldSchema::new_float_vector( - DEFAULT_VEC_FIELD, - "feature field", + FP32_VEC_FIELD, + "fp32 feature field", DIM, )) .build()?; @@ -48,8 +49,7 @@ async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result< let embed = rng.gen(); embed_data.push(embed); } - let embed_column = - FieldColumn::new(collection.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(collection.get_field(FP32_VEC_FIELD).unwrap(), embed_data)?; client .insert(collection.name(), vec![embed_column], None) @@ -62,7 +62,7 @@ async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result< HashMap::from([("nlist".to_owned(), "32".to_owned())]), ); client - .create_index(collection.name(), DEFAULT_VEC_FIELD, index_params) + .create_index(collection.name(), FP32_VEC_FIELD, index_params) .await?; client .load_collection(collection.name(), Some(LoadOptions::default())) diff --git a/examples/fp16_and_bf16.rs b/examples/fp16_and_bf16.rs new file mode 100644 index 0000000..bea9c08 --- /dev/null +++ b/examples/fp16_and_bf16.rs @@ -0,0 +1,119 @@ +use milvus::index::{IndexParams, IndexType}; +use milvus::options::LoadOptions; +use milvus::query::SearchOptions; +use milvus::schema::{CollectionSchema, CollectionSchemaBuilder}; +use milvus::{ + client::Client, data::FieldColumn, error::Error, schema::FieldSchema, +}; + +use half::prelude::*; +use rand::prelude::*; +use std::collections::HashMap; + +const FP16_VEC_FIELD: &str = "float16_vector_field"; +const BF16_VEC_FIELD: &str = "bfloat16_vector_field"; + +const DIM: i64 = 64; + +#[tokio::main] +async fn main() -> Result<(), Error> { + const URL: &str = "http://localhost:19530"; + + let client = Client::new(URL).await?; + + let schema = + CollectionSchemaBuilder::new("milvus_fp16", "fp16/bf16 example for milvus rust SDK") + .add_field(FieldSchema::new_primary_int64( + "id", + "primary key field", + true, + )) + .add_field(FieldSchema::new_float16_vector( + FP16_VEC_FIELD, + "fp16 feature field", + DIM, + )) + .add_field(FieldSchema::new_bfloat16_vector( + BF16_VEC_FIELD, + "bf16 feature field", + DIM, + )) + .build()?; + client.create_collection(schema.clone(), None).await?; + + if let Err(err) = fp16_insert_and_query(&client, &schema).await { + println!("failed to run hello milvus: {:?}", err); + } + client.drop_collection(schema.name()).await?; + + Ok(()) +} + +fn gen_random_f32_vector(n: i64) -> Vec { + let mut data = Vec::::with_capacity(n as usize); + let mut rng = rand::thread_rng(); + for _ in 0..n { + data.push(rng.gen()); + } + data +} + +async fn fp16_insert_and_query( + client: &Client, + collection: &CollectionSchema, +) -> Result<(), Error> { + let mut embed_data = Vec::::new(); + for _ in 1..=DIM * 1000 { + let mut rng = rand::thread_rng(); + let embed = rng.gen(); + embed_data.push(embed); + } + + // fp16 or bf16 vector accept Vec, Vec or Vec/Vec as input + let bf16_column = FieldColumn::new( + collection.get_field(BF16_VEC_FIELD).unwrap(), + Vec::::from_f32_slice(embed_data.as_slice()), + )?; + let fp16_column = FieldColumn::new(collection.get_field(FP16_VEC_FIELD).unwrap(), embed_data)?; + + let result = client + .insert(collection.name(), vec![fp16_column, bf16_column], None) + .await?; + println!("insert cnt: {}", result.insert_cnt); + client.flush(collection.name()).await?; + + let create_index_fut = [FP16_VEC_FIELD, BF16_VEC_FIELD].map(|field_name| { + let index_params = IndexParams::new( + field_name.to_string() + "_index", + IndexType::IvfFlat, + milvus::index::MetricType::L2, + HashMap::from([("nlist".to_owned(), "32".to_owned())]), + ); + client.create_index(collection.name(), field_name, index_params) + }); + futures::future::try_join_all(create_index_fut).await?; + client.flush(collection.name()).await?; + client + .load_collection(collection.name(), Some(LoadOptions::default())) + .await?; + + // search + let q1 = Vec::::from_f32_slice(&gen_random_f32_vector(DIM)); + let q2 = Vec::::from_f32_slice(&gen_random_f32_vector(DIM)); + let option = SearchOptions::with_limit(3) + .metric_type(milvus::index::MetricType::L2) + .output_fields(vec!["id".to_owned(), FP16_VEC_FIELD.to_owned()]); + let result = client + .search( + collection.name(), + vec![q1.into(), q2.into()], + FP16_VEC_FIELD, + &option, + ) + .await?; + + println!("{:?}", result[0]); + println!("result num: {}, {}", result[0].size, result[1].size); + + Ok(()) +} diff --git a/src/collection.rs b/src/collection.rs index d70748f..c8f7306 100644 --- a/src/collection.rs +++ b/src/collection.rs @@ -660,6 +660,7 @@ pub type ParamValue = serde_json::Value; pub use serde_json::json as ParamValue; // search result for a single vector +#[derive(Debug)] pub struct SearchResult<'a> { pub size: i64, pub id: Vec>, diff --git a/src/data.rs b/src/data.rs index 7258bf6..cc90fd3 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,13 +1,14 @@ -use std::borrow::Cow; - +use crate::error::Result; use crate::{ proto::schema::{ self, field_data::Field, scalar_field::Data as ScalarData, vector_field::Data as VectorData, DataType, ScalarField, VectorField, }, schema::FieldSchema, - value::{Value, ValueVec}, + value::{TryIntoValueVecWithSchema, Value, ValueVec}, }; +use half::prelude::*; +use std::borrow::Cow; pub trait HasDataType { fn data_type() -> DataType; @@ -34,12 +35,17 @@ impl_has_data_type! { String, DataType::String, Cow<'_, str>, DataType::String, Vec, DataType::FloatVector, + Vec, DataType::Float16Vector, + Vec, DataType::BFloat16Vector, Vec, DataType::BinaryVector, Cow<'_, [f32]>, DataType::FloatVector, + Cow<'_, [f16]>, DataType::Float16Vector, + Cow<'_, [bf16]>, DataType::BFloat16Vector, Cow<'_, [u8]>, DataType::BinaryVector } -#[derive(Debug, Clone)] +/// FieldColumn represents a column of data. +#[derive(Debug, Clone, PartialEq)] pub struct FieldColumn { pub name: String, pub dtype: DataType, @@ -72,15 +78,35 @@ impl From for FieldColumn { } impl FieldColumn { - pub fn new>(schm: &FieldSchema, v: V) -> FieldColumn { - FieldColumn { + /// Create a new FieldColumn from a FieldSchema and a value vector. + /// Returns an error if the value vector does not match the schema. + /// + /// # Example + /// ``` + /// use milvus::data::FieldColumn; + /// use milvus::schema::FieldSchema; + /// use milvus::proto::schema::DataType; + /// + /// let schema = FieldSchema::new_int32("int32_schema", ""); + /// let column = FieldColumn::new(&schema, vec![1, 2, 3]).unwrap(); + /// assert_eq!(column.dtype, DataType::Int32); + /// assert_eq!(column.len(), 3); + /// + /// let schema = FieldSchema::new_float16_vector("float16_vector_schema", "", 8); + /// let column = FieldColumn::new(&schema, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, + /// 0.8]).unwrap(); + /// assert_eq!(column.dtype, DataType::Float16Vector); + /// ``` + pub fn new(schm: &FieldSchema, v: V) -> Result { + let value: ValueVec = v.try_into_value_vec(schm)?; + Ok(FieldColumn { name: schm.name.clone(), dtype: schm.dtype, - value: v.into(), + value, dim: schm.dim, max_length: schm.max_length, is_dynamic: false, - } + }) } pub fn get(&self, idx: usize) -> Option> { @@ -107,6 +133,14 @@ impl FieldColumn { let dim = (self.dim / 8) as usize; Value::Binary(Cow::Borrowed(&v[idx * dim..idx * dim + dim])) } + ValueVec::Float16(v) => { + let dim = self.dim as usize; + Value::Float16Array(Cow::Borrowed(&v[idx * dim..idx * dim + dim])) + } + ValueVec::BFloat16(v) => { + let dim = self.dim as usize; + Value::BFloat16Array(Cow::Borrowed(&v[idx * dim..idx * dim + dim])) + } ValueVec::String(v) => Value::String(Cow::Borrowed(v.get(idx)?.as_ref())), ValueVec::Json(v) => Value::Json(Cow::Borrowed(v.get(idx)?.as_ref())), ValueVec::Array(v) => Value::Array(Cow::Borrowed(v.get(idx)?)), @@ -126,6 +160,14 @@ impl FieldColumn { (ValueVec::String(vec), Value::String(i)) => vec.push(i.to_string()), (ValueVec::Binary(vec), Value::Binary(i)) => vec.extend_from_slice(i.as_ref()), (ValueVec::Float(vec), Value::FloatArray(i)) => vec.extend_from_slice(i.as_ref()), + (ValueVec::Float16(vec), Value::Float16Array(i)) => vec.extend_from_slice(i.as_ref()), + (ValueVec::Float16(vec), Value::FloatArray(i)) => { + vec.extend(Vec::::from_f32_slice(i.as_ref())) + } + (ValueVec::BFloat16(vec), Value::BFloat16Array(i)) => vec.extend_from_slice(i.as_ref()), + (ValueVec::BFloat16(vec), Value::FloatArray(i)) => { + vec.extend(Vec::::from_f32_slice(i.as_ref())) + } _ => panic!("column type mismatch"), } } @@ -135,6 +177,11 @@ impl FieldColumn { self.value.len() / self.dim as usize } + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + pub fn copy_with_metadata(&self) -> Self { Self { dim: self.dim, @@ -151,6 +198,8 @@ impl FieldColumn { ValueVec::String(_) => ValueVec::String(Vec::new()), ValueVec::Json(_) => ValueVec::Json(Vec::new()), ValueVec::Binary(_) => ValueVec::Binary(Vec::new()), + ValueVec::Float16(_) => ValueVec::Float16(Vec::new()), + ValueVec::BFloat16(_) => ValueVec::BFloat16(Vec::new()), ValueVec::Array(_) => ValueVec::Array(Vec::new()), }, is_dynamic: self.is_dynamic, @@ -176,6 +225,7 @@ impl From for schema::FieldData { data: Some(ScalarData::LongData(schema::LongArray { data: v })), }), ValueVec::Float(v) => match this.dtype { + // both scalar and vector fields accept 1-d float array DataType::Float => Field::Scalars(ScalarField { data: Some(ScalarData::FloatData(schema::FloatArray { data: v })), }), @@ -204,6 +254,21 @@ impl From for schema::FieldData { data: Some(VectorData::BinaryVector(v)), dim: this.dim, }), + // milvus-proto assumes that float16 and bfloat16 are stored as little-endian bytes + ValueVec::BFloat16(v) => { + let v: Vec = v.into_iter().flat_map(|x| x.to_le_bytes()).collect(); + Field::Vectors(VectorField { + data: Some(VectorData::Bfloat16Vector(v)), + dim: this.dim, + }) + } + ValueVec::Float16(v) => { + let v: Vec = v.into_iter().flat_map(|x| x.to_le_bytes()).collect(); + Field::Vectors(VectorField { + data: Some(VectorData::Float16Vector(v)), + dim: this.dim, + }) + } }), is_dynamic: false, } @@ -236,7 +301,10 @@ impl_from_field! { Vec[Field::Scalars(ScalarField {data: Some(ScalarData::LongData(schema::LongArray { data }))}) => Some(data)], Vec[Field::Scalars(ScalarField {data: Some(ScalarData::StringData(schema::StringArray { data }))}) => Some(data)], Vec[Field::Scalars(ScalarField {data: Some(ScalarData::DoubleData(schema::DoubleArray { data }))}) => Some(data)], - Vec[Field::Vectors(VectorField {data: Some(VectorData::BinaryVector(data)), ..}) => Some(data)] + Vec[Field::Vectors(VectorField {data: Some(VectorData::BinaryVector(data)), ..}) => Some(data)], + // milvus-proto assumes that float16 and bfloat16 are stored as little-endian bytes + Vec[Field::Vectors(VectorField {data: Some(VectorData::Float16Vector(data)), ..}) => Some(data.chunks(2).map(|x|f16::from_le_bytes([x[0], x[1]])).collect())], + Vec[Field::Vectors(VectorField {data: Some(VectorData::Bfloat16Vector(data)), ..}) => Some(data.chunks(2).map(|x|bf16::from_le_bytes([x[0], x[1]])).collect())] } impl FromField for Vec { @@ -265,3 +333,109 @@ fn get_dim_max_length(field: &Field) -> (Option, Option) { (Some(dim), None) // no idea how to get max_length } + +#[cfg(test)] +mod test { + use crate::data::*; + #[test] + fn test_data_type() { + assert_eq!(bool::data_type(), DataType::Bool); + assert_eq!(i8::data_type(), DataType::Int8); + assert_eq!(i16::data_type(), DataType::Int16); + assert_eq!(i32::data_type(), DataType::Int32); + assert_eq!(i64::data_type(), DataType::Int64); + assert_eq!(f32::data_type(), DataType::Float); + assert_eq!(f64::data_type(), DataType::Double); + assert_eq!(String::data_type(), DataType::String); + assert_eq!(std::borrow::Cow::::data_type(), DataType::String); + assert_eq!(Vec::::data_type(), DataType::FloatVector); + assert_eq!(Vec::::data_type(), DataType::Float16Vector); + assert_eq!(Vec::::data_type(), DataType::BFloat16Vector); + assert_eq!(Vec::::data_type(), DataType::BinaryVector); + assert_eq!( + std::borrow::Cow::<[f32]>::data_type(), + DataType::FloatVector + ); + assert_eq!( + std::borrow::Cow::<[f16]>::data_type(), + DataType::Float16Vector + ); + assert_eq!( + std::borrow::Cow::<[bf16]>::data_type(), + DataType::BFloat16Vector + ); + assert_eq!( + std::borrow::Cow::<[u8]>::data_type(), + DataType::BinaryVector + ); + } + + #[test] + fn test_field_column() { + let field = FieldColumn { + name: "test".to_string(), + dtype: DataType::Int32, + value: ValueVec::Int(vec![1, 2, 3]), + dim: 1, + max_length: 0, + is_dynamic: false, + }; + + let field_data: schema::FieldData = field.clone().into(); + let field_column: FieldColumn = field_data.into(); + assert_eq!(field, field_column); + } + + #[test] + fn test_field_column_from_schema() { + let field_schema = FieldSchema::new_int32("int32_schema", ""); + let field_column = FieldColumn::new(&field_schema, vec![1, 2, 3]).unwrap(); + assert_eq!(field_column.dtype, DataType::Int32); + assert_eq!(field_column.dim, 1); + assert_eq!(field_column.len(), 3); + + let field_schema = FieldSchema::new_float("float_schema", ""); + let field_column_res = FieldColumn::new(&field_schema, Vec::::new()); + assert!(field_column_res.is_err()); + let field_column = FieldColumn::new(&field_schema, Vec::::new()).unwrap(); + assert_eq!(field_column.dtype, DataType::Float); + assert_eq!(field_column.dim, 1); + assert_eq!(field_column.len(), 0); + + let test_cases: [(fn(&str, &str, i64) -> FieldSchema, DataType); 3] = [ + (FieldSchema::new_bfloat16_vector, DataType::BFloat16Vector), + (FieldSchema::new_float16_vector, DataType::Float16Vector), + (FieldSchema::new_float_vector, DataType::FloatVector), + ]; + for (new_fn, dtype) in test_cases { + let field_schema = new_fn("feat", "", 8); + + let field_column_res = FieldColumn::new(&field_schema, vec![0.1, 0.2, 0.3]); + assert!(field_column_res.is_err()); + + let field_column_res = FieldColumn::new(&field_schema, Vec::::new()); + assert!(field_column_res.is_ok()); + let field_column_res = FieldColumn::new(&field_schema, Vec::::new()); + assert!(field_column_res.is_ok()); + let field_column_res = + FieldColumn::new(&field_schema, vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]); + assert!(field_column_res.is_ok()); + let field_column_res = FieldColumn::new( + &field_schema, + vec![ + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, + ], + ); + let mut field_column = field_column_res.unwrap(); + assert_eq!(field_column.dtype, dtype); + assert_eq!(field_column.dim, 8); + assert_eq!(field_column.len(), 2); + let value = field_column.get(0).unwrap(); + assert_eq!(value.data_type(), dtype); + field_column.push(vec![1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4].into()); + assert_eq!(field_column.len(), 3); + let value = field_column.get(2).unwrap(); + assert_eq!(value.data_type(), dtype); + } + } +} diff --git a/src/mutate.rs b/src/mutate.rs index 17680f6..0322dee 100644 --- a/src/mutate.rs +++ b/src/mutate.rs @@ -1,19 +1,15 @@ -use prost::bytes::{BufMut, BytesMut}; use crate::error::Result; use crate::{ client::Client, - collection, data::FieldColumn, error::Error, proto::{ self, common::{MsgBase, MsgType}, milvus::{InsertRequest, UpsertRequest}, - schema::{scalar_field::Data, DataType}, + schema::DataType, }, - schema::FieldData, - utils::status_to_result, value::ValueVec, }; diff --git a/src/partition.rs b/src/partition.rs index 0e13909..34af03a 100644 --- a/src/partition.rs +++ b/src/partition.rs @@ -4,7 +4,6 @@ use crate::error::*; use crate::{ client::Client, proto::{ - self, common::{MsgBase, MsgType}, }, utils::status_to_result, diff --git a/src/query.rs b/src/query.rs index 044e4c6..78f6cc7 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,8 +1,7 @@ -use std::borrow::Borrow; -use std::collections::HashMap; - +use half::{bf16, f16}; use prost::bytes::BytesMut; use prost::Message; +use std::collections::HashMap; use crate::client::{Client, ConsistencyLevel}; use crate::collection::{ParamValue, SearchResult}; @@ -372,6 +371,8 @@ fn get_place_holder_value(vectors: Vec) -> Result { match vectors[0] { Value::FloatArray(_) => place_holder.r#type = PlaceholderType::FloatVector as _, + Value::Float16Array(_) => place_holder.r#type = PlaceholderType::Float16Vector as _, + Value::BFloat16Array(_) => place_holder.r#type = PlaceholderType::BFloat16Vector as _, Value::Binary(_) => place_holder.r#type = PlaceholderType::BinaryVector as _, _ => { return Err(SuperError::from(crate::collection::Error::IllegalType( @@ -381,14 +382,26 @@ fn get_place_holder_value(vectors: Vec) -> Result { } }; + macro_rules! place_holder_push_bytes { + ($d:expr, $t:ty) => { + let mut bytes = Vec::::with_capacity($d.len() * size_of::<$t>()); + for f in $d.iter() { + // milvus-proto assumes that float16 and bfloat16 are stored as little-endian bytes + bytes.extend_from_slice(&f.to_le_bytes()); + } + place_holder.values.push(bytes) + }; + } for v in &vectors { match (v, &vectors[0]) { (Value::FloatArray(d), Value::FloatArray(_)) => { - let mut bytes = Vec::::with_capacity(d.len() * 4); - for f in d.iter() { - bytes.extend_from_slice(&f.to_le_bytes()); - } - place_holder.values.push(bytes) + place_holder_push_bytes!(d, f32); + } + (Value::Float16Array(d), Value::Float16Array(_)) => { + place_holder_push_bytes!(d, f16); + } + (Value::BFloat16Array(d), Value::BFloat16Array(_)) => { + place_holder_push_bytes!(d, bf16); } (Value::Binary(d), Value::Binary(_)) => place_holder.values.push(d.to_vec()), _ => { diff --git a/src/schema.rs b/src/schema.rs index 73460b2..2e3eb7a 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -28,6 +28,16 @@ use crate::proto::{ pub use crate::proto::schema::FieldData; +pub fn is_dense_vector_type(dtype: DataType) -> bool { + matches!( + dtype, + DataType::BinaryVector + | DataType::FloatVector + | DataType::Float16Vector + | DataType::BFloat16Vector + ) +} + pub trait Schema { // fn name(&self) -> &str; // fn description(&self) -> &str; @@ -153,7 +163,7 @@ pub struct FieldSchema { pub is_primary: bool, pub auto_id: bool, pub chunk_size: usize, - pub dim: i64, // only for BinaryVector and FloatVector + pub dim: i64, // only for BinaryVector, FloatVector, Float16Vector and BFloat16Vector pub max_length: i32, // only for VarChar } @@ -392,12 +402,49 @@ impl FieldSchema { max_length: 0, } } + + pub fn new_float16_vector(name: &str, description: &str, dim: i64) -> Self { + if dim <= 0 { + panic!("dim should be positive"); + } + + Self { + name: name.to_owned(), + description: description.to_owned(), + dtype: DataType::Float16Vector, + chunk_size: dim as usize, + dim, + is_primary: false, + auto_id: false, + max_length: 0, + } + } + + pub fn new_bfloat16_vector(name: &str, description: &str, dim: i64) -> Self { + if dim <= 0 { + panic!("dim should be positive"); + } + + Self { + name: name.to_owned(), + description: description.to_owned(), + dtype: DataType::BFloat16Vector, + chunk_size: dim as usize, + dim, + is_primary: false, + auto_id: false, + max_length: 0, + } + } } impl From for schema::FieldSchema { fn from(fld: FieldSchema) -> schema::FieldSchema { let params = match fld.dtype { - DataType::BinaryVector | DataType::FloatVector => vec![KeyValuePair { + DataType::BinaryVector + | DataType::FloatVector + | DataType::Float16Vector + | DataType::BFloat16Vector => vec![KeyValuePair { key: "dim".to_string(), value: fld.dim.to_string(), }], @@ -465,16 +512,12 @@ impl CollectionSchema { pub fn is_valid_vector_field(&self, field_name: &str) -> Result<()> { for f in &self.fields { if f.name == field_name { - if f.dtype == DataType::BinaryVector || f.dtype == DataType::FloatVector { - return Ok(()); - } else { - return Err(error::Error::from(Error::NotVectorField( - field_name.to_owned(), - ))); - } + return is_dense_vector_type(f.dtype) + .then_some(()) + .ok_or(Error::NotVectorField(field_name.to_owned()).into()); } } - return Err(error::Error::from(Error::NoSuchKey(field_name.to_owned()))); + Err(error::Error::from(Error::NoSuchKey(field_name.to_owned()))) } } diff --git a/src/value.rs b/src/value.rs index f2c72e1..0ad9cd2 100644 --- a/src/value.rs +++ b/src/value.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use crate::proto::{ self, schema::{ @@ -7,7 +5,13 @@ use crate::proto::{ DataType, }, }; +use crate::schema::FieldSchema; +use half::prelude::*; +use std::borrow::Cow; +use strum_macros::Display; +/// Value represents a scalar value or a vector value. +#[derive(Debug, Clone, Display, PartialEq)] pub enum Value<'a> { None, Bool(bool), @@ -18,6 +22,10 @@ pub enum Value<'a> { Float(f32), Double(f64), FloatArray(Cow<'a, [f32]>), + /// schema.proto uses bytes to represent binary_vector, + /// float16_vector and bfloat16_vector. + Float16Array(Cow<'a, [f16]>), + BFloat16Array(Cow<'a, [bf16]>), Binary(Cow<'a, [u8]>), String(Cow<'a, str>), Json(Cow<'a, [u8]>), @@ -58,6 +66,8 @@ impl Value<'_> { Value::String(_) => DataType::String, Value::Json(_) => DataType::Json, Value::FloatArray(_) => DataType::FloatVector, + Value::Float16Array(_) => DataType::Float16Vector, + Value::BFloat16Array(_) => DataType::BFloat16Vector, Value::Binary(_) => DataType::BinaryVector, Value::Array(_) => DataType::Array, } @@ -100,14 +110,89 @@ impl From> for Value<'static> { } } -#[derive(Debug, Clone)] +impl<'a> From<&'a [f16]> for Value<'a> { + fn from(v: &'a [f16]) -> Self { + Self::Float16Array(Cow::Borrowed(v)) + } +} + +impl From> for Value<'static> { + fn from(v: Vec) -> Self { + Self::Float16Array(Cow::Owned(v)) + } +} + +impl<'a> From<&'a [bf16]> for Value<'a> { + fn from(v: &'a [bf16]) -> Self { + Self::BFloat16Array(Cow::Borrowed(v)) + } +} + +impl From> for Value<'static> { + fn from(v: Vec) -> Self { + Self::BFloat16Array(Cow::Owned(v)) + } +} + +macro_rules! impl_try_from_for_value_column { + ( $($o: ident,$t: ty ),+ ) => {$( + impl TryFrom> for $t { + type Error = crate::error::Error; + fn try_from(value: Value<'_>) -> Result { + match value { + Value::$o(v) => Ok(v), + _ => Err(crate::error::Error::Conversion), + } + } + } + )*}; +} + +impl_try_from_for_value_column! { + Bool,bool, + Int8,i8, + Int16,i16, + Int32,i32, + Long,i64, + Float,f32, + Double,f64 +} + +macro_rules! impl_try_from_for_value_column { + ( $($o: ident,$t: ty ),+ ) => {$( + impl TryFrom> for $t { + type Error = crate::error::Error; + fn try_from(value: Value<'_>) -> Result { + match value { + Value::$o(v) => Ok(v.into_owned()), + _ => Err(crate::error::Error::Conversion), + } + } + } + )*}; +} + +impl_try_from_for_value_column! { + FloatArray,Vec, + Float16Array,Vec, + BFloat16Array,Vec +} + +/// ValueVec represents a column of data. +/// Both scalar_field value and vector_field value are represented by 1-d array. +#[derive(Debug, Clone, Display, PartialEq)] pub enum ValueVec { None, Bool(Vec), Int(Vec), Long(Vec), + /// float or float vector Float(Vec), Double(Vec), + /// float16 vector + Float16(Vec), + /// bfloat16 vector + BFloat16(Vec), Binary(Vec), String(Vec), Json(Vec>), @@ -130,10 +215,110 @@ impl_from_for_value_vec! { Vec, Long, Vec, String, Vec, Binary, + Vec, BFloat16, + Vec, Float16, Vec, Float, Vec, Double } +trait IntoFpVec: Into { + fn to_fp32_vec(self) -> Vec; + fn to_fp16_vec(self) -> Vec; + fn to_bf16_vec(self) -> Vec; +} + +impl IntoFpVec for Vec { + fn to_fp32_vec(self) -> Vec { + self + } + + fn to_fp16_vec(self) -> Vec { + Vec::::from_f32_slice(&self) + } + + fn to_bf16_vec(self) -> Vec { + Vec::::from_f32_slice(&self) + } +} + +impl IntoFpVec for Vec { + fn to_fp32_vec(self) -> Vec { + self.iter().map(|x| *x as f32).collect() + } + + fn to_fp16_vec(self) -> Vec { + Vec::::from_f64_slice(&self) + } + + fn to_bf16_vec(self) -> Vec { + Vec::::from_f64_slice(&self) + } +} + +pub trait TryIntoValueVecWithSchema: Into { + fn try_into_value_vec(self, schema: &FieldSchema) -> crate::error::Result; +} + +impl TryIntoValueVecWithSchema for V { + fn try_into_value_vec(self, schema: &FieldSchema) -> crate::error::Result { + debug_assert!(schema.dim > 0); + let v: ValueVec = match schema.dtype { + DataType::FloatVector => { + let v: Vec = self.to_fp32_vec(); + if v.len() % schema.dim as usize != 0 { + return Err(crate::error::Error::Conversion); + } + v.into() + } + DataType::Float16Vector => { + let v: Vec = self.to_fp16_vec(); + if v.len() % schema.dim as usize != 0 { + return Err(crate::error::Error::Conversion); + } + v.into() + } + DataType::BFloat16Vector => { + let v: Vec = self.to_bf16_vec(); + if v.len() % schema.dim as usize != 0 { + return Err(crate::error::Error::Conversion); + } + v.into() + } + _ => self.into(), + }; + if v.check_dtype(schema.dtype) { + Ok(v) + } else { + Err(crate::error::Error::Conversion) + } + } +} + +macro_rules! impl_try_into_value_vec_with_datatype { + ( $($t: ty),+ ) => {$( + impl TryIntoValueVecWithSchema for $t { + fn try_into_value_vec(self, schema: &FieldSchema) -> crate::error::Result { + let v: ValueVec = self.into(); + if v.check_dtype(schema.dtype) { + Ok(v) + } else { + Err(crate::error::Error::Conversion) + } + } + } + )*}; +} + +impl_try_into_value_vec_with_datatype!( + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec +); + macro_rules! impl_try_from_for_value_vec { ( $($o: ident, $t: ty ),+ ) => {$( impl TryFrom for $t { @@ -154,6 +339,8 @@ impl_try_from_for_value_vec! { Long, Vec, String, Vec, Binary, Vec, + BFloat16, Vec, + Float16, Vec, Float, Vec, Double, Vec } @@ -187,27 +374,49 @@ impl ValueVec { DataType::Array => Self::Array(Vec::new()), DataType::BinaryVector => Self::Binary(Vec::new()), DataType::FloatVector => Self::Float(Vec::new()), - DataType::Float16Vector => Self::Binary(Vec::new()), - DataType::BFloat16Vector => Self::Binary(Vec::new()), + DataType::Float16Vector => Self::Float16(Vec::new()), + DataType::BFloat16Vector => Self::BFloat16(Vec::new()), } } + pub fn float16_from_f32_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f32_slice(v.as_slice()); + Self::Float16(v) + } + + pub fn float16_from_f64_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f64_slice(v.as_slice()); + Self::Float16(v) + } + + pub fn bfloat16_from_f32_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f32_slice(v.as_slice()); + Self::BFloat16(v) + } + + pub fn bfloat16_from_f64_slice(v: &Vec) -> Self { + let v: Vec = Vec::::from_f64_slice(v.as_slice()); + Self::BFloat16(v) + } + pub fn check_dtype(&self, dtype: DataType) -> bool { - match (self, dtype) { + matches!( + (self, dtype), (ValueVec::Binary(..), DataType::BinaryVector) - | (ValueVec::Float(..), DataType::FloatVector) - | (ValueVec::Float(..), DataType::Float) - | (ValueVec::Int(..), DataType::Int8) - | (ValueVec::Int(..), DataType::Int16) - | (ValueVec::Int(..), DataType::Int32) - | (ValueVec::Long(..), DataType::Int64) - | (ValueVec::Bool(..), DataType::Bool) - | (ValueVec::String(..), DataType::String) - | (ValueVec::String(..), DataType::VarChar) - | (ValueVec::None, _) - | (ValueVec::Double(..), DataType::Double) => true, - _ => false, - } + | (ValueVec::BFloat16(..), DataType::BFloat16Vector) + | (ValueVec::Float16(..), DataType::Float16Vector) + | (ValueVec::Float(..), DataType::FloatVector) + | (ValueVec::Float(..), DataType::Float) + | (ValueVec::Int(..), DataType::Int8) + | (ValueVec::Int(..), DataType::Int16) + | (ValueVec::Int(..), DataType::Int32) + | (ValueVec::Long(..), DataType::Int64) + | (ValueVec::Bool(..), DataType::Bool) + | (ValueVec::String(..), DataType::String) + | (ValueVec::String(..), DataType::VarChar) + | (ValueVec::None, _) + | (ValueVec::Double(..), DataType::Double) + ) } #[inline] @@ -221,6 +430,8 @@ impl ValueVec { ValueVec::Bool(v) => v.len(), ValueVec::Int(v) => v.len(), ValueVec::Long(v) => v.len(), + ValueVec::BFloat16(v) => v.len(), + ValueVec::Float16(v) => v.len(), ValueVec::Float(v) => v.len(), ValueVec::Double(v) => v.len(), ValueVec::Binary(v) => v.len(), @@ -236,6 +447,8 @@ impl ValueVec { ValueVec::Bool(v) => v.clear(), ValueVec::Int(v) => v.clear(), ValueVec::Long(v) => v.clear(), + ValueVec::BFloat16(v) => v.clear(), + ValueVec::Float16(v) => v.clear(), ValueVec::Float(v) => v.clear(), ValueVec::Double(v) => v.clear(), ValueVec::Binary(v) => v.clear(), @@ -268,8 +481,20 @@ impl From for ValueVec { Some(x) => match x { VectorData::FloatVector(v) => Self::Float(v.data), VectorData::BinaryVector(v) => Self::Binary(v), - VectorData::Bfloat16Vector(v) => Self::Binary(v), - VectorData::Float16Vector(v) => Self::Binary(v), + VectorData::Bfloat16Vector(v) => { + let v: Vec = v + .chunks_exact(2) + .map(|x| bf16::from_le_bytes([x[0], x[1]])) + .collect(); + Self::BFloat16(v) + } + VectorData::Float16Vector(v) => { + let v: Vec = v + .chunks_exact(2) + .map(|x| f16::from_le_bytes([x[0], x[1]])) + .collect(); + Self::Float16(v) + } }, None => Self::None, }, @@ -277,36 +502,13 @@ impl From for ValueVec { } } -macro_rules! impl_try_from_for_value_column { - ( $($o: ident,$t: ty ),+ ) => {$( - impl TryFrom> for $t { - type Error = crate::error::Error; - fn try_from(value: Value<'_>) -> Result { - match value { - Value::$o(v) => Ok(v), - _ => Err(crate::error::Error::Conversion), - } - } - } - )*}; -} - -impl_try_from_for_value_column! { - Bool,bool, - Int8,i8, - Int16,i16, - Int32,i32, - Long,i64, - Float,f32, - Double,f64 -} - #[cfg(test)] mod test { use crate::{ error::Error, value::{Value, ValueVec}, }; + use half::prelude::*; #[test] fn test_try_from_for_value_column() { @@ -345,6 +547,24 @@ mod test { let r: Result = double.try_into(); assert!(r.is_ok()); assert_eq!(22104f64, r.unwrap()); + + let float_array: Value = vec![1.1, 2.2, 3.3].into(); + let r: Result, Error> = float_array.try_into(); + assert_eq!(vec![1.1, 2.2, 3.3], r.unwrap()); + + let float16_array: Value = Vec::::from_f32_slice(&[1.1, 2.2, 3.3, -1.1]).into(); + let r: Result, Error> = float16_array.try_into(); + assert_eq!( + Vec::::from_f32_slice(&[1.1, 2.2, 3.3, -1.1]), + r.unwrap() + ); + + let bfloat16_array: Value = Vec::::from_f32_slice(&[1.1, 2.2, 3.3, 8.21, 0.]).into(); + let r: Result, Error> = bfloat16_array.try_into(); + assert_eq!( + Vec::::from_f32_slice(&[1.1, 2.2, 3.3, 8.21, 0.]), + r.unwrap() + ); } #[test] @@ -365,6 +585,17 @@ mod test { assert!(b.is_ok()); assert_eq!(v, b.unwrap()); + let v: Vec = Vec::from_f32_slice(&[-1.1, -2.2, 3.3, -4.4, -5.5]); + let b = ValueVec::BFloat16(v.clone()); + let b: Result, Error> = b.try_into(); + assert!(b.is_ok()); + + let v: Vec = Vec::from_f32_slice(&[1.1, 2.2, 3.3, 4.4, 5.5]); + let b = ValueVec::Float16(v.clone()); + let b: Result, Error> = b.try_into(); + assert!(b.is_ok()); + assert_eq!(v, b.unwrap()); + let v: Vec = vec![11., 7., 754., 68., 34.]; let b = ValueVec::Float(v.clone()); let b: Result, Error> = b.try_into(); diff --git a/tests/client.rs b/tests/client.rs index 8720e4a..9aaf0b8 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -75,6 +75,11 @@ async fn create_has_drop_collection() -> Result<()> { let schema = schema .add_field(FieldSchema::new_int64("i64_field", "")) .add_field(FieldSchema::new_bool("bool_field", "")) + .add_field(FieldSchema::new_bfloat16_vector( + BF16_VEC_FIELD, + "bfloat", + 128, + )) .set_primary_key("i64_field")? .enable_auto_id()? .build()?; @@ -83,7 +88,7 @@ async fn create_has_drop_collection() -> Result<()> { client.drop_collection(NAME).await?; } - let collection = client + client .create_collection( schema, Some(CreateCollectionOptions::with_consistency_level( diff --git a/tests/collection.rs b/tests/collection.rs index afb58f7..046dddf 100644 --- a/tests/collection.rs +++ b/tests/collection.rs @@ -23,10 +23,9 @@ use milvus::mutate::InsertOptions; use milvus::options::LoadOptions; use milvus::query::{QueryOptions, SearchOptions}; use std::collections::HashMap; - mod common; use common::*; - +use half::prelude::*; use milvus::value::ValueVec; #[tokio::test] @@ -42,8 +41,8 @@ async fn collection_upsert() -> Result<()> { let (client, schema) = create_test_collection(false).await?; let pk_data = gen_random_int64_vector(2000); let vec_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let pk_col = FieldColumn::new(schema.get_field("id").unwrap(), pk_data); - let vec_col = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), vec_data); + let pk_col = FieldColumn::new(schema.get_field("id").unwrap(), pk_data)?; + let vec_col = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), vec_data)?; client .upsert(schema.name(), vec![pk_col, vec_col], None) .await?; @@ -77,7 +76,7 @@ async fn collection_basic() -> Result<()> { let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data)?; client .insert(schema.name(), vec![embed_column], None) @@ -108,13 +107,114 @@ async fn collection_basic() -> Result<()> { Ok(()) } +#[tokio::test] +async fn collection_fp16_bf16_vec() -> Result<()> { + let (client, schema) = create_test_fp16_bf16_collection(true).await?; + let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); + + let fp16_vector_column = FieldColumn::new( + schema.get_field(FP16_VEC_FIELD).unwrap(), + embed_data.clone(), + )?; + let bf16_vector_column = + FieldColumn::new(schema.get_field(BF16_VEC_FIELD).unwrap(), embed_data)?; + + client + .insert( + schema.name(), + vec![fp16_vector_column, bf16_vector_column], + None, + ) + .await?; + client.flush(schema.name()).await?; + for (index_name, field_name) in [ + (FP16_VEC_INDEX_NAME, FP16_VEC_FIELD), + (BF16_VEC_INDEX_NAME, BF16_VEC_FIELD), + ] { + let index_params = IndexParams::new( + index_name.to_owned(), + IndexType::IvfFlat, + milvus::index::MetricType::L2, + HashMap::from([("nlist".to_owned(), "32".to_owned())]), + ); + client + .create_index(schema.name(), field_name, index_params) + .await?; + } + client + .load_collection(schema.name(), Some(LoadOptions::default())) + .await?; + + let options = QueryOptions::default(); + let result = client.query(schema.name(), "id > 0", &options).await?; + + println!( + "result num: {}", + result.first().map(|c| c.len()).unwrap_or(0), + ); + // create index + let create_index_futs = [ + (FP16_VEC_FIELD, FP16_VEC_INDEX_NAME), + (BF16_VEC_FIELD, BF16_VEC_INDEX_NAME), + ] + .map(|(field_name, index_name)| { + client.create_index( + schema.name(), + field_name, + IndexParams::new( + index_name.to_owned(), + IndexType::IvfFlat, + milvus::index::MetricType::L2, + HashMap::from([("nlist".to_owned(), "32".to_owned())]), + ), + ) + }); + futures::future::join_all(create_index_futs).await; + client.flush(schema.name()).await?; + client + .load_collection(schema.name(), Some(LoadOptions::default())) + .await?; + + // search + let mut option = SearchOptions::with_limit(10) + .metric_type(MetricType::L2) + .output_fields(vec!["id".to_owned(), FP16_VEC_FIELD.to_owned()]); + option = option.add_param("nprobe", ParamValue!(16)); + let query_vec = gen_random_f32_vector(DEFAULT_DIM); + + let result = client + .search( + schema.name(), + vec![Vec::::from_f32_slice(query_vec.as_slice()).into()], + FP16_VEC_FIELD, + &option, + ) + .await?; + assert_eq!(result.len(), 1); + assert_eq!(result[0].size, 10); + assert_eq!(result[0].field.len(), 2); + + let result = client + .search( + schema.name(), + vec![Vec::::from_f32_slice(query_vec.as_slice()).into()], + BF16_VEC_FIELD, + &option, + ) + .await?; + assert_eq!(result[0].size, 10); + + client.drop_collection(schema.name()).await?; + Ok(()) +} + #[tokio::test] async fn collection_index() -> Result<()> { let (client, schema) = create_test_collection(true).await?; let feature = gen_random_f32_vector(DEFAULT_DIM * 2000); - let feature_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), feature); + let feature_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), feature)?; client .insert(schema.name(), vec![feature_column], None) @@ -149,7 +249,7 @@ async fn collection_search() -> Result<()> { let (client, schema) = create_test_collection(true).await?; let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data)?; client .insert(schema.name(), vec![embed_column], None) @@ -195,7 +295,7 @@ async fn collection_range_search() -> Result<()> { let (client, schema) = create_test_collection(true).await?; let embed_data = gen_random_f32_vector(DEFAULT_DIM * 2000); - let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data); + let embed_column = FieldColumn::new(schema.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data)?; client .insert(schema.name(), vec![embed_column], None) diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 2e17886..c0954e1 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -5,14 +5,18 @@ use milvus::schema::{CollectionSchema, CollectionSchemaBuilder, FieldSchema}; use rand::Rng; pub const DEFAULT_DIM: i64 = 128; -pub const DEFAULT_VEC_FIELD: &str = "feature"; -pub const DEFAULT_INDEX_NAME: &str = "feature_index"; +pub const DEFAULT_VEC_FIELD: &str = "float32_vec_field"; +pub const FP16_VEC_FIELD: &str = "float16_vec_field"; +pub const BF16_VEC_FIELD: &str = "bfloat16_vec_field"; +pub const DEFAULT_INDEX_NAME: &str = "float32_vec_field_index"; +pub const FP16_VEC_INDEX_NAME: &str = "float16_vec_field_index"; +pub const BF16_VEC_INDEX_NAME: &str = "bfloat16_vec_field_index"; + pub const URL: &str = "http://localhost:19530"; pub async fn create_test_collection(autoid: bool) -> Result<(Client, CollectionSchema)> { let collection_name = gen_random_name(); let collection_name = format!("{}_{}", "test_collection", collection_name); - let client = Client::new(URL).await?; let schema = CollectionSchemaBuilder::new(&collection_name, "") .add_field(FieldSchema::new_primary_int64("id", "", autoid)) .add_field(FieldSchema::new_float_vector( @@ -21,8 +25,34 @@ pub async fn create_test_collection(autoid: bool) -> Result<(Client, CollectionS DEFAULT_DIM, )) .build()?; - if client.has_collection(&collection_name).await? { - client.drop_collection(&collection_name).await?; + create_test_collection_with_schema(schema).await +} + +pub async fn create_test_fp16_bf16_collection(autoid: bool) -> Result<(Client, CollectionSchema)> { + let collection_name = gen_random_name(); + let collection_name = format!("{}_{}", "test_collection", collection_name); + let schema = CollectionSchemaBuilder::new(&collection_name, "") + .add_field(FieldSchema::new_primary_int64("id", "", autoid)) + .add_field(FieldSchema::new_float16_vector( + FP16_VEC_FIELD, + "fp16 vector field", + DEFAULT_DIM, + )) + .add_field(FieldSchema::new_bfloat16_vector( + BF16_VEC_FIELD, + "bf16 vector field", + DEFAULT_DIM, + )) + .build()?; + create_test_collection_with_schema(schema).await +} + +async fn create_test_collection_with_schema( + schema: CollectionSchema, +) -> Result<(Client, CollectionSchema)> { + let client = Client::new(URL).await?; + if client.has_collection(schema.name()).await? { + client.drop_collection(schema.name()).await?; } client .create_collection(