-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added support for halfvec, bit, and sparsevec types to Rust-Postgres
- Loading branch information
Showing
12 changed files
with
628 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/// A bit string.
///
/// Borrows its packed bytes from the caller and tracks the length in bits
/// separately from the length of the byte slice.
#[derive(Clone, Debug, PartialEq)]
pub struct Bit<'a> {
    /// Number of bits in the string.
    pub(crate) len: usize,
    /// Packed bits, eight per byte.
    pub(crate) data: &'a [u8],
}

impl<'a> Bit<'a> {
    /// Creates a bit string for a slice of bytes.
    ///
    /// Every bit of `data` is part of the string, so the resulting length is
    /// `data.len() * 8`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() * 8` overflows `usize`.
    pub fn from_bytes(data: &'a [u8]) -> Bit<'a> {
        let len = data
            .len()
            .checked_mul(8)
            .expect("bit length overflows usize");
        Bit { len, data }
    }

    /// Returns the number of bits in the vector.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the vector contains no bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the vector as a slice of bytes.
    pub fn as_bytes(&self) -> &'a [u8] {
        self.data
    }

    /// Parses a bit string from a buffer holding a big-endian `i32` bit count
    /// followed by the packed bytes.
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) when the buffer is shorter than
    /// the 4-byte header, when the bit count is negative, or when fewer data
    /// bytes are present than the bit count requires.
    #[cfg(any(feature = "postgres"))]
    pub(crate) fn from_sql(buf: &[u8]) -> Result<Bit<'_>, Box<dyn std::error::Error + Sync + Send>> {
        let (raw_len, payload) = match buf {
            [b0, b1, b2, b3, payload @ ..] => {
                (i32::from_be_bytes([*b0, *b1, *b2, *b3]), payload)
            }
            _ => return Err("expected at least 4 bytes for bit length".into()),
        };
        if raw_len < 0 {
            return Err("expected bit length to be non-negative".into());
        }
        // Lossless: the value was just checked to be non-negative.
        let len = raw_len as usize;

        // Round up so that a trailing partial byte is kept when the bit count
        // is not a multiple of 8.
        let byte_len = (len + 7) / 8;
        let data = payload
            .get(..byte_len)
            .ok_or("bit data is shorter than its declared length")?;

        Ok(Bit { len, data })
    }
}
|
||
#[cfg(test)]
mod tests {
    use crate::Bit;

    /// A bit string built from two bytes reports 16 bits and hands back the
    /// same bytes it was given.
    #[test]
    fn test_as_bytes() {
        let raw = [0b00000000, 0b11111111];
        let bits = Bit::from_bytes(&raw);

        assert_eq!(16, bits.len());
        assert_eq!(&raw[..], bits.as_bytes());
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
use half::f16; | ||
|
||
/// A half vector. | ||
#[derive(Clone, Debug, PartialEq)] | ||
pub struct HalfVec(pub(crate) Vec<f16>); | ||
|
||
impl From<Vec<f16>> for HalfVec { | ||
fn from(v: Vec<f16>) -> Self { | ||
HalfVec(v) | ||
} | ||
} | ||
|
||
impl From<HalfVec> for Vec<f16> { | ||
fn from(val: HalfVec) -> Self { | ||
val.0 | ||
} | ||
} | ||
|
||
impl HalfVec { | ||
/// Returns a copy of the half vector as a `Vec<f16>`. | ||
pub fn to_vec(&self) -> Vec<f16> { | ||
self.0.clone() | ||
} | ||
|
||
/// Returns the half vector as a slice. | ||
pub fn as_slice(&self) -> &[f16] { | ||
self.0.as_slice() | ||
} | ||
|
||
#[cfg(any(feature = "postgres"))] | ||
pub(crate) fn from_sql( | ||
buf: &[u8], | ||
) -> Result<HalfVec, Box<dyn std::error::Error + Sync + Send>> { | ||
let dim = u16::from_be_bytes(buf[0..2].try_into()?) as usize; | ||
let unused = u16::from_be_bytes(buf[2..4].try_into()?); | ||
if unused != 0 { | ||
return Err("expected unused to be 0".into()); | ||
} | ||
|
||
let mut vec = Vec::with_capacity(dim); | ||
for i in 0..dim { | ||
let s = 4 + 2 * i; | ||
vec.push(f16::from_be_bytes(buf[s..s + 2].try_into()?)); | ||
} | ||
|
||
Ok(HalfVec(vec)) | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use crate::HalfVec;
    use half::f16;

    /// Builds the `[1.0, 2.0, 3.0]` fixture used by every test in this module.
    fn sample() -> Vec<f16> {
        vec![f16::from_f32(1.0), f16::from_f32(2.0), f16::from_f32(3.0)]
    }

    /// Converting a `HalfVec` into a `Vec<f16>` yields the original values.
    #[test]
    fn test_into() {
        let half_vec = HalfVec::from(sample());
        let f16_vec: Vec<f16> = half_vec.into();
        assert_eq!(f16_vec, sample());
    }

    /// `to_vec` returns a copy holding the original values.
    #[test]
    fn test_to_vec() {
        let half_vec = HalfVec::from(sample());
        assert_eq!(half_vec.to_vec(), sample());
    }

    /// `as_slice` exposes the original values.
    #[test]
    fn test_as_slice() {
        let half_vec = HalfVec::from(sample());
        let expected = sample();
        assert_eq!(half_vec.as_slice(), expected.as_slice());
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
use bytes::{BufMut, BytesMut}; | ||
use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; | ||
use std::convert::TryInto; | ||
use std::error::Error; | ||
|
||
use crate::Bit; | ||
|
||
impl<'a> FromSql<'a> for Bit<'a> { | ||
fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Bit<'a>, Box<dyn Error + Sync + Send>> { | ||
Bit::from_sql(raw) | ||
} | ||
|
||
fn accepts(ty: &Type) -> bool { | ||
ty.name() == "bit" | ||
} | ||
} | ||
|
||
impl<'a> ToSql for Bit<'a> { | ||
fn to_sql(&self, _ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> { | ||
let len = self.len; | ||
w.put_i32(len.try_into()?); | ||
|
||
for v in self.data { | ||
w.put_u8(*v); | ||
} | ||
|
||
Ok(IsNull::No) | ||
} | ||
|
||
fn accepts(ty: &Type) -> bool { | ||
ty.name() == "bit" | ||
} | ||
|
||
to_sql_checked!(); | ||
} | ||
|
||
#[cfg(test)]
mod tests {
    use crate::Bit;
    use postgres::binary_copy::BinaryCopyInWriter;
    use postgres::types::Type;
    use postgres::{Client, NoTls};

    /// Exercises `Bit` against the `pgvector_rust_test` database on localhost:
    /// insert, nearest-neighbour query, NULL handling, a text-format check,
    /// and a binary COPY.
    #[test]
    fn it_works() -> Result<(), postgres::Error> {
        // Connect as the user named by the `USER` environment variable.
        let username = std::env::var("USER").unwrap();
        let mut config = Client::configure();
        config
            .host("localhost")
            .dbname("pgvector_rust_test")
            .user(username.as_str());
        let mut db = config.connect(NoTls)?;

        // Start from a fresh table.
        db.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
        db.execute("DROP TABLE IF EXISTS postgres_bit_items", &[])?;
        db.execute(
            "CREATE TABLE postgres_bit_items (id bigserial PRIMARY KEY, embedding bit(8))",
            &[],
        )?;

        // Two values plus a NULL row.
        let first = Bit::from_bytes(&[0b10101010]);
        let second = Bit::from_bytes(&[0b01010101]);
        db.execute(
            "INSERT INTO postgres_bit_items (embedding) VALUES ($1), ($2), (NULL)",
            &[&first, &second],
        )?;

        // The closest match to the first value must be the first value itself.
        let probe = Bit::from_bytes(&[0b10101010]);
        let nearest = db.query_one(
            "SELECT embedding FROM postgres_bit_items ORDER BY embedding <~> $1 LIMIT 1",
            &[&probe],
        )?;
        let fetched: Bit = nearest.get(0);
        assert_eq!(first, fetched);
        assert_eq!(8, fetched.len());
        assert_eq!(&[0b10101010], fetched.as_bytes());

        // A NULL column decodes to `None`.
        let null_row = db.query_one(
            "SELECT embedding FROM postgres_bit_items WHERE embedding IS NULL LIMIT 1",
            &[],
        )?;
        let maybe_bits: Option<Bit> = null_row.get(0);
        assert!(maybe_bits.is_none());

        // ensures binary format is correct
        let text_row = db.query_one(
            "SELECT embedding::text FROM postgres_bit_items ORDER BY id LIMIT 1",
            &[],
        )?;
        let rendered: String = text_row.get(0);
        assert_eq!("10101010", rendered);

        // copy
        let sink =
            db.copy_in("COPY postgres_bit_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
        let mut copier = BinaryCopyInWriter::new(sink, &[Type::BIT]);
        copier.write(&[&Bit::from_bytes(&[0b10101010])])?;
        copier.write(&[&Bit::from_bytes(&[0b01010101])])?;
        copier.finish()?;

        Ok(())
    }
}
Oops, something went wrong.