Skip to content

Commit

Permalink
Added support for halfvec type to Diesel
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 5, 2024
1 parent 20a2da1 commit 1ab2b5d
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ jobs:
- run: cargo test --features serde
- run: cargo test --features postgres,halfvec
- run: cargo test --features sqlx,halfvec
- run: cargo test --features diesel,halfvec
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- Added support for `halfvec`, `bit`, and `sparsevec` types to Rust-Postgres
- Added support for `halfvec` type to SQLx
- Added support for `halfvec` type to Diesel
- Added `l1_distance` function for Diesel

## 0.3.2 (2023-10-30)
Expand Down
180 changes: 180 additions & 0 deletions src/diesel_ext/halfvec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
use diesel::deserialize::{self, FromSql};
use diesel::pg::{Pg, PgValue};
use diesel::query_builder::QueryId;
use diesel::serialize::{self, IsNull, Output, ToSql};
use diesel::sql_types::SqlType;
use std::convert::TryFrom;
use std::io::Write;

use crate::HalfVec;

/// Diesel SQL type marker for the PostgreSQL type named `halfvec`.
///
/// This is a zero-sized marker: it carries no data and is only used at the
/// type level. The `ToSql`/`FromSql` conversions for it are implemented on
/// [`HalfVec`].
#[derive(SqlType, QueryId)]
#[diesel(postgres_type(name = "halfvec"))]
pub struct HalfVecType;

impl ToSql<HalfVecType, Pg> for HalfVec {
    /// Serializes the vector into `out`.
    ///
    /// The bytes written are, in order: the element count as a big-endian
    /// `u16`, a zero `u16`, and then the big-endian bytes of every element.
    ///
    /// # Errors
    ///
    /// Fails if the element count does not fit in a `u16`, or if writing to
    /// `out` fails.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        // Header: element count followed by a zeroed 16-bit field.
        let count = u16::try_from(self.0.len())?;
        out.write_all(&count.to_be_bytes())?;
        out.write_all(&0_u16.to_be_bytes())?;

        // Payload: one big-endian value per element, stopping at the first write error.
        self.0
            .iter()
            .try_for_each(|value| out.write_all(&value.to_be_bytes()))?;

        Ok(IsNull::No)
    }
}

impl FromSql<HalfVecType, Pg> for HalfVec {
    /// Deserializes a `HalfVec` from the raw bytes of a PostgreSQL value.
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        let bytes = value.as_bytes();
        // `HalfVec::from_sql` resolves to the inherent byte-slice parser on
        // `HalfVec` (inherent associated functions take precedence over trait
        // ones), so this delegates rather than recursing into this method.
        HalfVec::from_sql(bytes)
    }
}

#[cfg(test)]
mod tests {
    use crate::{HalfVec, VectorExpressionMethods};
    use diesel::pg::PgConnection;
    use diesel::{Connection, QueryDsl, RunQueryDsl};
    use half::f16;

    table! {
        use diesel::sql_types::*;

        diesel_half_items (id) {
            id -> Int4,
            embedding -> Nullable<crate::sql_types::HalfVec>,
        }
    }

    use diesel_half_items as items;

    /// A row read back from `diesel_half_items`.
    #[derive(Queryable)]
    #[diesel(table_name = items)]
    struct Item {
        pub id: i32,
        pub embedding: Option<HalfVec>,
    }

    /// A row to insert into `diesel_half_items`; `id` is assigned by the database.
    #[derive(Insertable)]
    #[diesel(table_name = items)]
    struct NewItem {
        pub embedding: Option<HalfVec>,
    }

    /// Builds a three-element `HalfVec` from `f32` components.
    fn half_vec(components: [f32; 3]) -> HalfVec {
        let halves: Vec<f16> = components.iter().map(|&c| f16::from_f32(c)).collect();
        HalfVec::from(halves)
    }

    /// The query vector shared by every distance assertion below.
    fn ones() -> HalfVec {
        half_vec([1.0, 1.0, 1.0])
    }

    /// Collects the ids of `rows`, preserving their order.
    fn ids(rows: &[Item]) -> Vec<i32> {
        rows.iter().map(|row| row.id).collect::<Vec<i32>>()
    }

    /// End-to-end round trip against the database named in the connection
    /// string below: inserts four rows (one with a NULL embedding) and checks
    /// the ordering produced by each distance method.
    #[test]
    fn it_works() -> Result<(), diesel::result::Error> {
        let mut conn = PgConnection::establish("postgres://localhost/pgvector_rust_test").unwrap();

        // Start from a clean table on every run.
        diesel::sql_query("CREATE EXTENSION IF NOT EXISTS vector").execute(&mut conn)?;
        diesel::sql_query("DROP TABLE IF EXISTS diesel_half_items").execute(&mut conn)?;
        diesel::sql_query(
            "CREATE TABLE diesel_half_items (id serial PRIMARY KEY, embedding halfvec(3))",
        )
        .execute(&mut conn)?;

        let new_items = vec![
            NewItem {
                embedding: Some(half_vec([1.0, 1.0, 1.0])),
            },
            NewItem {
                embedding: Some(half_vec([2.0, 2.0, 2.0])),
            },
            NewItem {
                embedding: Some(half_vec([1.0, 1.0, 2.0])),
            },
            NewItem { embedding: None },
        ];

        diesel::insert_into(items::table)
            .values(&new_items)
            .get_results::<Item>(&mut conn)?;

        let all = items::table.load::<Item>(&mut conn)?;
        assert_eq!(4, all.len());

        // L2 distance; the nearest row must round-trip its embedding exactly.
        let by_l2 = items::table
            .order(items::embedding.l2_distance(ones()))
            .limit(5)
            .load::<Item>(&mut conn)?;
        assert_eq!(vec![1, 3, 2, 4], ids(&by_l2));
        assert_eq!(Some(ones()), by_l2.first().unwrap().embedding);

        // Max inner product.
        let by_inner_product = items::table
            .order(items::embedding.max_inner_product(ones()))
            .limit(5)
            .load::<Item>(&mut conn)?;
        assert_eq!(vec![2, 3, 1, 4], ids(&by_inner_product));

        // Cosine distance.
        let by_cosine = items::table
            .order(items::embedding.cosine_distance(ones()))
            .limit(5)
            .load::<Item>(&mut conn)?;
        assert_eq!(vec![1, 2, 3, 4], ids(&by_cosine));

        // L1 distance.
        let by_l1 = items::table
            .order(items::embedding.l1_distance(ones()))
            .limit(5)
            .load::<Item>(&mut conn)?;
        assert_eq!(vec![1, 3, 2, 4], ids(&by_l1));

        // Selecting the expression itself yields a nullable double per row.
        let distances = items::table
            .select(items::embedding.max_inner_product(ones()))
            .order(items::id)
            .load::<Option<f64>>(&mut conn)?;
        assert_eq!(vec![Some(-3.0), Some(-6.0), Some(-4.0), None], distances);

        Ok(())
    }
}
3 changes: 3 additions & 0 deletions src/diesel_ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
pub(crate) mod vector;

#[cfg(feature = "halfvec")]
pub(crate) mod halfvec;
10 changes: 5 additions & 5 deletions src/diesel_ext/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use diesel::expression::{AsExpression, Expression};
use diesel::pg::{Pg, PgValue};
use diesel::query_builder::QueryId;
use diesel::serialize::{self, IsNull, Output, ToSql};
use diesel::sql_types::{Double, Nullable, SqlType};
use diesel::sql_types::{Double, SqlType};
use std::convert::TryFrom;
use std::io::Write;

Expand Down Expand Up @@ -42,31 +42,31 @@ pub trait VectorExpressionMethods: Expression + Sized {
/// Builds an `L2Distance` expression from `self` and `other`.
///
/// `other` is converted with `as_expression()` before being handed to
/// `L2Distance::new`.
// NOTE(review): the two consecutive `T:` bounds below look like the removed and
// added lines of a diff rendered without +/- markers — confirm which one the
// file actually contains.
fn l2_distance<T>(self, other: T) -> L2Distance<Self, T::Expression>
where
Self::SqlType: SqlType,
T: AsExpression<Nullable<VectorType>>,
T: AsExpression<Self::SqlType>,
{
L2Distance::new(self, other.as_expression())
}

/// Builds a `MaxInnerProduct` expression from `self` and `other`.
///
/// `other` is converted with `as_expression()` before being handed to
/// `MaxInnerProduct::new`.
// NOTE(review): the two consecutive `T:` bounds below look like the removed and
// added lines of a diff rendered without +/- markers — confirm which one the
// file actually contains.
fn max_inner_product<T>(self, other: T) -> MaxInnerProduct<Self, T::Expression>
where
Self::SqlType: SqlType,
T: AsExpression<Nullable<VectorType>>,
T: AsExpression<Self::SqlType>,
{
MaxInnerProduct::new(self, other.as_expression())
}

/// Builds a `CosineDistance` expression from `self` and `other`.
///
/// `other` is converted with `as_expression()` before being handed to
/// `CosineDistance::new`.
// NOTE(review): the two consecutive `T:` bounds below look like the removed and
// added lines of a diff rendered without +/- markers — confirm which one the
// file actually contains.
fn cosine_distance<T>(self, other: T) -> CosineDistance<Self, T::Expression>
where
Self::SqlType: SqlType,
T: AsExpression<Nullable<VectorType>>,
T: AsExpression<Self::SqlType>,
{
CosineDistance::new(self, other.as_expression())
}

/// Builds an `L1Distance` expression from `self` and `other`.
///
/// `other` is converted with `as_expression()` before being handed to
/// `L1Distance::new`.
// NOTE(review): the two consecutive `T:` bounds below look like the removed and
// added lines of a diff rendered without +/- markers — confirm which one the
// file actually contains.
fn l1_distance<T>(self, other: T) -> L1Distance<Self, T::Expression>
where
Self::SqlType: SqlType,
T: AsExpression<Nullable<VectorType>>,
T: AsExpression<Self::SqlType>,
{
L1Distance::new(self, other.as_expression())
}
Expand Down
10 changes: 9 additions & 1 deletion src/halfvec.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use half::f16;

#[cfg(feature = "diesel")]
use crate::diesel_ext::halfvec::HalfVecType;

#[cfg(feature = "diesel")]
use diesel::{deserialize::FromSqlRow, expression::AsExpression};

/// A half vector.
///
/// A newtype around a `Vec<f16>`; the inner vector is only directly accessible
/// inside this crate. When the `diesel` feature is enabled, the type also
/// derives `FromSqlRow` and `AsExpression` for the `HalfVecType` SQL type.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "diesel", derive(FromSqlRow, AsExpression))]
#[cfg_attr(feature = "diesel", diesel(sql_type = HalfVecType))]
pub struct HalfVec(pub(crate) Vec<f16>);

impl From<Vec<f16>> for HalfVec {
Expand All @@ -27,7 +35,7 @@ impl HalfVec {
self.0.as_slice()
}

#[cfg(any(feature = "postgres", feature = "sqlx"))]
#[cfg(any(feature = "postgres", feature = "sqlx", feature = "diesel"))]
pub(crate) fn from_sql(
buf: &[u8],
) -> Result<HalfVec, Box<dyn std::error::Error + Sync + Send>> {
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ mod diesel_ext;

/// Diesel SQL type markers exported by this crate.
#[cfg(feature = "diesel")]
pub mod sql_types {
    // `diesel_ext::halfvec` is only compiled when the `halfvec` feature is
    // enabled, so the re-export must carry the same gate; otherwise building
    // with `diesel` alone fails to resolve this path.
    #[cfg(feature = "halfvec")]
    pub use super::diesel_ext::halfvec::HalfVecType as HalfVec;
    pub use super::diesel_ext::vector::VectorType as Vector;
}

Expand Down

0 comments on commit 1ab2b5d

Please sign in to comment.