Skip to content

Commit

Permalink
Added support for halfvec, bit, and sparsevec types to Rust-Postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 1, 2024
1 parent 221c36c commit 4d7a9d8
Show file tree
Hide file tree
Showing 12 changed files with 628 additions and 3 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ jobs:
dev-files: true
- run: |
cd /tmp
- git clone --branch v0.6.0 https://github.com/pgvector/pgvector.git
+ git clone --branch v0.7.0 https://github.com/pgvector/pgvector.git
cd pgvector
make
sudo make install
- run: psql -d pgvector_rust_test -c "CREATE EXTENSION vector"

# test features individually
- run: cargo test --features postgres
- run: cargo test --features sqlx
- run: cargo test --features diesel
- run: cargo test --features serde
- run: cargo test --features postgres,halfvec
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.3.3 (unreleased)

- Added support for `halfvec`, `bit`, and `sparsevec` types to Rust-Postgres

## 0.3.2 (2023-10-30)

- Fixed error with Diesel without `with-deprecated` feature
Expand Down
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ postgres-types = { version = "0.2", default-features = false, optional = true }
diesel = { version = "2", default-features = false, features = ["postgres"], optional = true }
sqlx = { version = ">= 0.5, < 0.8", default-features = false, features = ["postgres"], optional = true }
serde = { version = "1", features = ["derive"], optional = true }
half = { version = "2", default-features = false, optional = true }

[dev-dependencies]
postgres = { version = "0.19", default-features = false }
Expand All @@ -28,3 +29,4 @@ serde_json = "1"

[features]
postgres = ["dep:postgres-types", "dep:bytes"]
halfvec = ["dep:half"]
46 changes: 46 additions & 0 deletions src/bit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/// A bit string.
///
/// Pairs a length measured in bits with a borrowed byte slice that stores
/// those bits.
#[derive(Clone, Debug, PartialEq)]
pub struct Bit<'a> {
    /// Number of bits in the string.
    pub(crate) len: usize,
    /// Bytes holding the bits.
    pub(crate) data: &'a [u8],
}

impl<'a> Bit<'a> {
    /// Creates a bit string for a slice of bytes.
    ///
    /// Every bit of every byte is counted, so the resulting length is
    /// `8 * data.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `8 * data.len()` overflows `usize`.
    pub fn from_bytes(data: &'a [u8]) -> Bit<'a> {
        let len = data
            .len()
            .checked_mul(8)
            .expect("bit length overflows usize");

        Bit { len, data }
    }

    /// Returns the number of bits in the vector.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns the vector as a slice of bytes.
    pub fn as_bytes(&self) -> &'a [u8] {
        self.data
    }

    /// Decodes a bit string from its binary representation: a 4-byte
    /// big-endian signed bit count followed by the bytes holding the bits.
    ///
    /// Bytes after the declared payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error (rather than panicking) if the buffer is shorter than
    /// the 4-byte header, if the bit count is negative, or if the buffer holds
    /// fewer data bytes than the bit count requires.
    #[cfg(any(feature = "postgres"))]
    pub(crate) fn from_sql(
        buf: &[u8],
    ) -> Result<Bit<'_>, Box<dyn std::error::Error + Sync + Send>> {
        let header = buf
            .get(..4)
            .ok_or("expected at least 4 bytes for bit length")?;
        let bits = i32::from_be_bytes([header[0], header[1], header[2], header[3]]);
        if bits < 0 {
            return Err("expected bit length to be non-negative".into());
        }

        // Lossless: `bits` was just checked to be non-negative.
        let len = bits as usize;

        // A bit count that is not a multiple of 8 still occupies a final,
        // partially used byte, so round up instead of truncating.
        let byte_len = len / 8 + usize::from(len % 8 != 0);
        let data = buf
            .get(4..4 + byte_len)
            .ok_or("expected bit data to match declared length")?;

        Ok(Bit { len, data })
    }
}

#[cfg(test)]
mod tests {
    use crate::Bit;

    /// Bytes handed to `from_bytes` come back unchanged from `as_bytes`, and
    /// the length counts eight bits per byte.
    #[test]
    fn test_as_bytes() {
        let bytes: [u8; 2] = [0b00000000, 0b11111111];
        let bit = Bit::from_bytes(&bytes);

        assert_eq!(16, bit.len());
        assert_eq!(&bytes, bit.as_bytes());
    }
}
94 changes: 94 additions & 0 deletions src/halfvec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use half::f16;

/// A half vector.
///
/// A newtype around a `Vec<f16>` that owns the vector's elements.
#[derive(Clone, Debug, PartialEq)]
pub struct HalfVec(pub(crate) Vec<f16>);

impl From<Vec<f16>> for HalfVec {
fn from(v: Vec<f16>) -> Self {
HalfVec(v)
}
}

impl From<HalfVec> for Vec<f16> {
fn from(val: HalfVec) -> Self {
val.0
}
}

impl HalfVec {
    /// Returns a copy of the half vector as a `Vec<f16>`.
    pub fn to_vec(&self) -> Vec<f16> {
        self.0.clone()
    }

    /// Returns the half vector as a slice.
    pub fn as_slice(&self) -> &[f16] {
        self.0.as_slice()
    }

    /// Decodes a half vector from its binary representation: a big-endian
    /// `u16` element count, a `u16` that must be zero, then one 2-byte
    /// big-endian `f16` per element.
    ///
    /// Bytes after the declared elements are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error (rather than panicking) if the buffer is shorter than
    /// the 4-byte header, if the unused field is non-zero, or if the buffer
    /// holds fewer element bytes than the element count requires.
    #[cfg(any(feature = "postgres"))]
    pub(crate) fn from_sql(
        buf: &[u8],
    ) -> Result<HalfVec, Box<dyn std::error::Error + Sync + Send>> {
        let header = buf
            .get(..4)
            .ok_or("expected at least 4 bytes for halfvec header")?;
        let dim = usize::from(u16::from_be_bytes([header[0], header[1]]));
        let unused = u16::from_be_bytes([header[2], header[3]]);
        if unused != 0 {
            return Err("expected unused to be 0".into());
        }

        // Check the payload length once up front; `dim` fits in a `u16`, so
        // `4 + 2 * dim` cannot overflow.
        let payload = buf
            .get(4..4 + 2 * dim)
            .ok_or("expected halfvec data to match declared dimensions")?;

        let vec = payload
            .chunks_exact(2)
            .map(|pair| f16::from_be_bytes([pair[0], pair[1]]))
            .collect();

        Ok(HalfVec(vec))
    }
}

#[cfg(test)]
mod tests {
    use crate::HalfVec;
    use half::f16;

    /// The elements 1.0, 2.0 and 3.0 as `f16` values, shared by every test.
    fn elements() -> Vec<f16> {
        [1.0, 2.0, 3.0].iter().map(|&x| f16::from_f32(x)).collect()
    }

    /// Converting into `Vec<f16>` yields the original elements.
    #[test]
    fn test_into() {
        let half_vec = HalfVec::from(elements());
        let f16_vec: Vec<f16> = half_vec.into();
        assert_eq!(f16_vec, elements());
    }

    /// `to_vec` returns a copy of the original elements.
    #[test]
    fn test_to_vec() {
        let half_vec = HalfVec::from(elements());
        assert_eq!(half_vec.to_vec(), elements());
    }

    /// `as_slice` exposes the original elements.
    #[test]
    fn test_as_slice() {
        let half_vec = HalfVec::from(elements());
        let expected = elements();
        assert_eq!(half_vec.as_slice(), expected.as_slice());
    }
}
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,20 @@
#[macro_use]
extern crate diesel;

mod bit;
mod sparsevec;
mod vector;

pub use bit::Bit;
pub use sparsevec::SparseVec;
pub use vector::Vector;

#[cfg(feature = "halfvec")]
mod halfvec;

#[cfg(feature = "halfvec")]
pub use halfvec::HalfVec;

#[cfg(feature = "postgres")]
mod postgres_ext;

Expand Down
103 changes: 103 additions & 0 deletions src/postgres_ext/bit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use bytes::{BufMut, BytesMut};
use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::convert::TryInto;
use std::error::Error;

use crate::Bit;

impl<'a> FromSql<'a> for Bit<'a> {
fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Bit<'a>, Box<dyn Error + Sync + Send>> {
Bit::from_sql(raw)
}

fn accepts(ty: &Type) -> bool {
ty.name() == "bit"
}
}

/// Writes a `Bit` as a 4-byte big-endian bit count followed by its bytes.
impl<'a> ToSql for Bit<'a> {
    fn to_sql(&self, _ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
        // Convert first so nothing is written if the length does not fit an i32.
        let bit_count: i32 = self.len.try_into()?;
        w.put_i32(bit_count);
        w.put_slice(self.data);

        Ok(IsNull::No)
    }

    /// Accepts only types whose name is `bit`.
    fn accepts(ty: &Type) -> bool {
        matches!(ty.name(), "bit")
    }

    to_sql_checked!();
}

#[cfg(test)]
mod tests {
    use crate::Bit;
    use postgres::binary_copy::BinaryCopyInWriter;
    use postgres::types::Type;
    use postgres::{Client, NoTls};

    /// Round-trips `Bit` values through a local `pgvector_rust_test` database:
    /// insert, nearest-neighbour query, NULL handling, text rendering and
    /// binary COPY.
    #[test]
    fn it_works() -> Result<(), postgres::Error> {
        let username = std::env::var("USER").unwrap();
        let mut conn = Client::configure()
            .host("localhost")
            .dbname("pgvector_rust_test")
            .user(username.as_str())
            .connect(NoTls)?;

        // Start from a fresh table on every run.
        for statement in [
            "CREATE EXTENSION IF NOT EXISTS vector",
            "DROP TABLE IF EXISTS postgres_bit_items",
            "CREATE TABLE postgres_bit_items (id bigserial PRIMARY KEY, embedding bit(8))",
        ] {
            conn.execute(statement, &[])?;
        }

        let first = Bit::from_bytes(&[0b10101010]);
        let second = Bit::from_bytes(&[0b01010101]);
        conn.execute(
            "INSERT INTO postgres_bit_items (embedding) VALUES ($1), ($2), (NULL)",
            &[&first, &second],
        )?;

        // The closest row to the query must be the identical bit string.
        let probe = Bit::from_bytes(&[0b10101010]);
        let nearest_row = conn.query_one(
            "SELECT embedding FROM postgres_bit_items ORDER BY embedding <~> $1 LIMIT 1",
            &[&probe],
        )?;
        let nearest: Bit = nearest_row.get(0);
        assert_eq!(first, nearest);
        assert_eq!(8, nearest.len());
        assert_eq!(&[0b10101010], nearest.as_bytes());

        // A NULL column decodes as `None`.
        let null_row = conn.query_one(
            "SELECT embedding FROM postgres_bit_items WHERE embedding IS NULL LIMIT 1",
            &[],
        )?;
        let maybe_bit: Option<Bit> = null_row.get(0);
        assert!(maybe_bit.is_none());

        // ensures binary format is correct
        let text_row = conn.query_one(
            "SELECT embedding::text FROM postgres_bit_items ORDER BY id LIMIT 1",
            &[],
        )?;
        let rendered: String = text_row.get(0);
        assert_eq!("10101010", rendered);

        // copy
        let sink =
            conn.copy_in("COPY postgres_bit_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
        let mut copy_writer = BinaryCopyInWriter::new(sink, &[Type::BIT]);
        copy_writer.write(&[&Bit::from_bytes(&[0b10101010])])?;
        copy_writer.write(&[&Bit::from_bytes(&[0b01010101])])?;
        copy_writer.finish()?;

        Ok(())
    }
}
Loading

0 comments on commit 4d7a9d8

Please sign in to comment.