Skip to content

Commit

Permalink
feat!: support fp16/bf16 vector
Browse files Browse the repository at this point in the history
BREAKING CHANGES:

- `FieldColumn::new() -> FieldColumn` -> `FieldColumn::new() -> Result<FieldColumn>`
    - add type checking so it may return Error

issue: milvus-io/milvus#37448

Signed-off-by: Yinzuo Jiang <jiangyinzuo@foxmail.com>
Signed-off-by: Yinzuo Jiang <yinzuo.jiang@zilliz.com>
  • Loading branch information
jiangyinzuo committed Dec 4, 2024
1 parent a1e2b02 commit 0638758
Show file tree
Hide file tree
Showing 13 changed files with 640 additions and 96 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ strum = "0.24"
strum_macros = "0.24"
base64 = "0.21.0"
dashmap = "5.5.3"
half = "2.4.1"

[build-dependencies]
tonic-build = { version = "0.8.2", default-features = false, features = [
Expand All @@ -30,3 +31,4 @@ tonic-build = { version = "0.8.2", default-features = false, features = [

[dev-dependencies]
rand = "0.8.5"
futures = "0.3"
12 changes: 6 additions & 6 deletions examples/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use std::collections::HashMap;

use rand::prelude::*;

const DEFAULT_VEC_FIELD: &str = "embed";
const FP32_VEC_FIELD: &str = "float32_vector_field";

const DIM: i64 = 256;

#[tokio::main]
Expand All @@ -26,8 +27,8 @@ async fn main() -> Result<(), Error> {
true,
))
.add_field(FieldSchema::new_float_vector(
DEFAULT_VEC_FIELD,
"feature field",
FP32_VEC_FIELD,
"fp32 feature field",
DIM,
))
.build()?;
Expand All @@ -48,8 +49,7 @@ async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result<
let embed = rng.gen();
embed_data.push(embed);
}
let embed_column =
FieldColumn::new(collection.get_field(DEFAULT_VEC_FIELD).unwrap(), embed_data);
let embed_column = FieldColumn::new(collection.get_field(FP32_VEC_FIELD).unwrap(), embed_data)?;

client
.insert(collection.name(), vec![embed_column], None)
Expand All @@ -62,7 +62,7 @@ async fn hello_milvus(client: &Client, collection: &CollectionSchema) -> Result<
HashMap::from([("nlist".to_owned(), "32".to_owned())]),
);
client
.create_index(collection.name(), DEFAULT_VEC_FIELD, index_params)
.create_index(collection.name(), FP32_VEC_FIELD, index_params)
.await?;
client
.load_collection(collection.name(), Some(LoadOptions::default()))
Expand Down
119 changes: 119 additions & 0 deletions examples/fp16_and_bf16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use milvus::index::{IndexParams, IndexType};
use milvus::options::LoadOptions;
use milvus::query::SearchOptions;
use milvus::schema::{CollectionSchema, CollectionSchemaBuilder};
use milvus::{
client::Client, data::FieldColumn, error::Error, schema::FieldSchema,
};

use half::prelude::*;
use rand::prelude::*;
use std::collections::HashMap;

const FP16_VEC_FIELD: &str = "float16_vector_field";
const BF16_VEC_FIELD: &str = "bfloat16_vector_field";

const DIM: i64 = 64;

#[tokio::main]
async fn main() -> Result<(), Error> {
const URL: &str = "http://localhost:19530";

let client = Client::new(URL).await?;

let schema =
CollectionSchemaBuilder::new("milvus_fp16", "fp16/bf16 example for milvus rust SDK")
.add_field(FieldSchema::new_primary_int64(
"id",
"primary key field",
true,
))
.add_field(FieldSchema::new_float16_vector(
FP16_VEC_FIELD,
"fp16 feature field",
DIM,
))
.add_field(FieldSchema::new_bfloat16_vector(
BF16_VEC_FIELD,
"bf16 feature field",
DIM,
))
.build()?;
client.create_collection(schema.clone(), None).await?;

if let Err(err) = fp16_insert_and_query(&client, &schema).await {
println!("failed to run hello milvus: {:?}", err);
}
client.drop_collection(schema.name()).await?;

Ok(())
}

fn gen_random_f32_vector(n: i64) -> Vec<f32> {
let mut data = Vec::<f32>::with_capacity(n as usize);
let mut rng = rand::thread_rng();
for _ in 0..n {
data.push(rng.gen());
}
data
}

async fn fp16_insert_and_query(
client: &Client,
collection: &CollectionSchema,
) -> Result<(), Error> {
let mut embed_data = Vec::<f32>::new();
for _ in 1..=DIM * 1000 {
let mut rng = rand::thread_rng();
let embed = rng.gen();
embed_data.push(embed);
}

// fp16 or bf16 vector accept Vec<f32>, Vec<f64> or Vec<f16>/Vec<bf16> as input
let bf16_column = FieldColumn::new(
collection.get_field(BF16_VEC_FIELD).unwrap(),
Vec::<bf16>::from_f32_slice(embed_data.as_slice()),
)?;
let fp16_column = FieldColumn::new(collection.get_field(FP16_VEC_FIELD).unwrap(), embed_data)?;

let result = client
.insert(collection.name(), vec![fp16_column, bf16_column], None)
.await?;
println!("insert cnt: {}", result.insert_cnt);
client.flush(collection.name()).await?;

let create_index_fut = [FP16_VEC_FIELD, BF16_VEC_FIELD].map(|field_name| {
let index_params = IndexParams::new(
field_name.to_string() + "_index",
IndexType::IvfFlat,
milvus::index::MetricType::L2,
HashMap::from([("nlist".to_owned(), "32".to_owned())]),
);
client.create_index(collection.name(), field_name, index_params)
});
futures::future::try_join_all(create_index_fut).await?;
client.flush(collection.name()).await?;
client
.load_collection(collection.name(), Some(LoadOptions::default()))
.await?;

// search
let q1 = Vec::<f16>::from_f32_slice(&gen_random_f32_vector(DIM));
let q2 = Vec::<f16>::from_f32_slice(&gen_random_f32_vector(DIM));
let option = SearchOptions::with_limit(3)
.metric_type(milvus::index::MetricType::L2)
.output_fields(vec!["id".to_owned(), FP16_VEC_FIELD.to_owned()]);
let result = client
.search(
collection.name(),
vec![q1.into(), q2.into()],
FP16_VEC_FIELD,
&option,
)
.await?;

println!("{:?}", result[0]);
println!("result num: {}, {}", result[0].size, result[1].size);

Ok(())
}
1 change: 1 addition & 0 deletions src/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ pub type ParamValue = serde_json::Value;
pub use serde_json::json as ParamValue;

// search result for a single vector
#[derive(Debug)]
pub struct SearchResult<'a> {
pub size: i64,
pub id: Vec<Value<'a>>,
Expand Down
56 changes: 48 additions & 8 deletions src/data.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::borrow::Cow;

use crate::error::Result;
use crate::{
proto::schema::{
self, field_data::Field, scalar_field::Data as ScalarData,
vector_field::Data as VectorData, DataType, ScalarField, VectorField,
},
schema::FieldSchema,
value::{Value, ValueVec},
value::{TryIntoValueVecWithDataType, Value, ValueVec},
};
use half::prelude::*;
use std::borrow::Cow;

pub trait HasDataType {
fn data_type() -> DataType;
Expand All @@ -34,8 +35,12 @@ impl_has_data_type! {
String, DataType::String,
Cow<'_, str>, DataType::String,
Vec<f32>, DataType::FloatVector,
Vec<f16>, DataType::Float16Vector,
Vec<bf16>, DataType::BFloat16Vector,
Vec<u8>, DataType::BinaryVector,
Cow<'_, [f32]>, DataType::FloatVector,
Cow<'_, [f16]>, DataType::Float16Vector,
Cow<'_, [bf16]>, DataType::BFloat16Vector,
Cow<'_, [u8]>, DataType::BinaryVector
}

Expand Down Expand Up @@ -72,15 +77,16 @@ impl From<schema::FieldData> for FieldColumn {
}

impl FieldColumn {
pub fn new<V: Into<ValueVec>>(schm: &FieldSchema, v: V) -> FieldColumn {
FieldColumn {
pub fn new<V: TryIntoValueVecWithDataType>(schm: &FieldSchema, v: V) -> Result<FieldColumn> {
let value: ValueVec = v.try_into_value_vec(schm.dtype)?;
Ok(FieldColumn {
name: schm.name.clone(),
dtype: schm.dtype,
value: v.into(),
value,
dim: schm.dim,
max_length: schm.max_length,
is_dynamic: false,
}
})
}

pub fn get(&self, idx: usize) -> Option<Value<'_>> {
Expand All @@ -107,6 +113,14 @@ impl FieldColumn {
let dim = (self.dim / 8) as usize;
Value::Binary(Cow::Borrowed(&v[idx * dim..idx * dim + dim]))
}
ValueVec::Float16(v) => {
let dim = self.dim as usize;
Value::Float16Array(Cow::Borrowed(&v[idx * dim..idx * dim + dim]))
}
ValueVec::BFloat16(v) => {
let dim = self.dim as usize;
Value::BFloat16Array(Cow::Borrowed(&v[idx * dim..idx * dim + dim]))
}
ValueVec::String(v) => Value::String(Cow::Borrowed(v.get(idx)?.as_ref())),
ValueVec::Json(v) => Value::Json(Cow::Borrowed(v.get(idx)?.as_ref())),
ValueVec::Array(v) => Value::Array(Cow::Borrowed(v.get(idx)?)),
Expand All @@ -126,6 +140,8 @@ impl FieldColumn {
(ValueVec::String(vec), Value::String(i)) => vec.push(i.to_string()),
(ValueVec::Binary(vec), Value::Binary(i)) => vec.extend_from_slice(i.as_ref()),
(ValueVec::Float(vec), Value::FloatArray(i)) => vec.extend_from_slice(i.as_ref()),
(ValueVec::Float16(vec), Value::Float16Array(i)) => vec.extend_from_slice(i.as_ref()),
(ValueVec::BFloat16(vec), Value::BFloat16Array(i)) => vec.extend_from_slice(i.as_ref()),
_ => panic!("column type mismatch"),
}
}
Expand All @@ -135,6 +151,11 @@ impl FieldColumn {
self.value.len() / self.dim as usize
}

#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}

pub fn copy_with_metadata(&self) -> Self {
Self {
dim: self.dim,
Expand All @@ -151,6 +172,8 @@ impl FieldColumn {
ValueVec::String(_) => ValueVec::String(Vec::new()),
ValueVec::Json(_) => ValueVec::Json(Vec::new()),
ValueVec::Binary(_) => ValueVec::Binary(Vec::new()),
ValueVec::Float16(_) => ValueVec::Float16(Vec::new()),
ValueVec::BFloat16(_) => ValueVec::BFloat16(Vec::new()),
ValueVec::Array(_) => ValueVec::Array(Vec::new()),
},
is_dynamic: self.is_dynamic,
Expand All @@ -176,6 +199,7 @@ impl From<FieldColumn> for schema::FieldData {
data: Some(ScalarData::LongData(schema::LongArray { data: v })),
}),
ValueVec::Float(v) => match this.dtype {
// both scalar and vector fields accept 1-d float array
DataType::Float => Field::Scalars(ScalarField {
data: Some(ScalarData::FloatData(schema::FloatArray { data: v })),
}),
Expand Down Expand Up @@ -204,6 +228,20 @@ impl From<FieldColumn> for schema::FieldData {
data: Some(VectorData::BinaryVector(v)),
dim: this.dim,
}),
ValueVec::BFloat16(v) => {
let v: Vec<u8> = v.into_iter().flat_map(|x| x.to_le_bytes()).collect();
Field::Vectors(VectorField {
data: Some(VectorData::Bfloat16Vector(v)),
dim: this.dim,
})
}
ValueVec::Float16(v) => {
let v: Vec<u8> = v.into_iter().flat_map(|x| x.to_le_bytes()).collect();
Field::Vectors(VectorField {
data: Some(VectorData::Float16Vector(v)),
dim: this.dim,
})
}
}),
is_dynamic: false,
}
Expand Down Expand Up @@ -236,7 +274,9 @@ impl_from_field! {
Vec<i64>[Field::Scalars(ScalarField {data: Some(ScalarData::LongData(schema::LongArray { data }))}) => Some(data)],
Vec<String>[Field::Scalars(ScalarField {data: Some(ScalarData::StringData(schema::StringArray { data }))}) => Some(data)],
Vec<f64>[Field::Scalars(ScalarField {data: Some(ScalarData::DoubleData(schema::DoubleArray { data }))}) => Some(data)],
Vec<u8>[Field::Vectors(VectorField {data: Some(VectorData::BinaryVector(data)), ..}) => Some(data)]
Vec<u8>[Field::Vectors(VectorField {data: Some(VectorData::BinaryVector(data)), ..}) => Some(data)],
Vec<f16>[Field::Vectors(VectorField {data: Some(VectorData::Float16Vector(data)), ..}) => Some(data.chunks(2).map(|x|f16::from_le_bytes([x[0], x[1]])).collect())],
Vec<bf16>[Field::Vectors(VectorField {data: Some(VectorData::Bfloat16Vector(data)), ..}) => Some(data.chunks(2).map(|x|bf16::from_le_bytes([x[0], x[1]])).collect())]
}

impl FromField for Vec<f32> {
Expand Down
6 changes: 1 addition & 5 deletions src/mutate.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
use prost::bytes::{BufMut, BytesMut};

use crate::error::Result;
use crate::{
client::Client,
collection,
data::FieldColumn,
error::Error,
proto::{
self,
common::{MsgBase, MsgType},
milvus::{InsertRequest, UpsertRequest},
schema::{scalar_field::Data, DataType},
schema::DataType,
},
schema::FieldData,
utils::status_to_result,
value::ValueVec,
};

Expand Down
1 change: 0 additions & 1 deletion src/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::error::*;
use crate::{
client::Client,
proto::{
self,
common::{MsgBase, MsgType},
},
utils::status_to_result,
Expand Down
28 changes: 20 additions & 8 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::borrow::Borrow;
use std::collections::HashMap;

use half::{bf16, f16};
use prost::bytes::BytesMut;
use prost::Message;
use std::collections::HashMap;

use crate::client::{Client, ConsistencyLevel};
use crate::collection::{ParamValue, SearchResult};
Expand Down Expand Up @@ -372,6 +371,8 @@ fn get_place_holder_value(vectors: Vec<Value>) -> Result<PlaceholderValue> {

match vectors[0] {
Value::FloatArray(_) => place_holder.r#type = PlaceholderType::FloatVector as _,
Value::Float16Array(_) => place_holder.r#type = PlaceholderType::Float16Vector as _,
Value::BFloat16Array(_) => place_holder.r#type = PlaceholderType::BFloat16Vector as _,
Value::Binary(_) => place_holder.r#type = PlaceholderType::BinaryVector as _,
_ => {
return Err(SuperError::from(crate::collection::Error::IllegalType(
Expand All @@ -381,14 +382,25 @@ fn get_place_holder_value(vectors: Vec<Value>) -> Result<PlaceholderValue> {
}
};

macro_rules! place_holder_push_bytes {
($d:expr, $t:ty) => {
let mut bytes = Vec::<u8>::with_capacity($d.len() * size_of::<$t>());
for f in $d.iter() {
bytes.extend_from_slice(&f.to_le_bytes());
}
place_holder.values.push(bytes)
};
}
for v in &vectors {
match (v, &vectors[0]) {
(Value::FloatArray(d), Value::FloatArray(_)) => {
let mut bytes = Vec::<u8>::with_capacity(d.len() * 4);
for f in d.iter() {
bytes.extend_from_slice(&f.to_le_bytes());
}
place_holder.values.push(bytes)
place_holder_push_bytes!(d, f32);
}
(Value::Float16Array(d), Value::Float16Array(_)) => {
place_holder_push_bytes!(d, f16);
}
(Value::BFloat16Array(d), Value::BFloat16Array(_)) => {
place_holder_push_bytes!(d, bf16);
}
(Value::Binary(d), Value::Binary(_)) => place_holder.values.push(d.to_vec()),
_ => {
Expand Down
Loading

0 comments on commit 0638758

Please sign in to comment.