Skip to content

Commit

Permalink
Started support for sparsevec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Apr 6, 2024
1 parent 13a6ff4 commit c333770
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ extern crate diesel;
mod vector;
pub use vector::Vector;

mod sparsevec;
pub use sparsevec::SparseVec;

#[cfg(feature = "halfvec")]
mod halfvec;

Expand Down
1 change: 1 addition & 0 deletions src/postgres_ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod vector;
mod sparsevec;

#[cfg(feature = "halfvec")]
mod halfvec;
129 changes: 129 additions & 0 deletions src/postgres_ext/sparsevec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use bytes::{BufMut, BytesMut};
use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::convert::TryInto;
use std::error::Error;

use crate::SparseVec;

/// Lets `postgres` decode binary `sparsevec` column values into [`SparseVec`].
impl<'a> FromSql<'a> for SparseVec {
    /// Decodes `raw` by delegating to the inherent `SparseVec::from_sql`.
    ///
    /// The type descriptor is not consulted here; `accepts` is what restricts
    /// this impl to the `sparsevec` type.
    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<SparseVec, Box<dyn Error + Sync + Send>> {
        let decoded = SparseVec::from_sql(raw)?;
        Ok(decoded)
    }

    /// Returns `true` only for a type whose name is exactly `sparsevec`.
    fn accepts(ty: &Type) -> bool {
        matches!(ty.name(), "sparsevec")
    }
}

/// Lets `postgres` encode a [`SparseVec`] as a binary `sparsevec` parameter.
impl ToSql for SparseVec {
    /// Serializes the vector into `w`.
    ///
    /// Layout (every field written with `put_i32` / `put_f32`):
    /// `dim`, number of stored entries, a zero word, then every index,
    /// then every value. Each index is written incremented by one, mirroring
    /// the decrement applied by `SparseVec::from_sql` when reading.
    ///
    /// # Errors
    ///
    /// Returns an error, without writing anything to `w`, when:
    /// - `indices` and `values` differ in length (the header would otherwise
    ///   announce a count that does not match the number of values written),
    /// - `dim` or the entry count does not fit in an `i32`,
    /// - an index is `i32::MAX`, so adding one would overflow.
    fn to_sql(&self, _ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
        if self.indices.len() != self.values.len() {
            return Err("sparsevec indices and values must have the same length".into());
        }

        // Perform every fallible conversion up front so a failure cannot
        // leave a half-written message in the buffer.
        let dim: i32 = self.dim.try_into()?;
        let nnz: i32 = self.indices.len().try_into()?;
        let wire_indices = self
            .indices
            .iter()
            .map(|index| index.checked_add(1).ok_or("sparsevec index is too large"))
            .collect::<Result<Vec<i32>, _>>()?;

        w.put_i32(dim);
        w.put_i32(nnz);
        // Third header word; `SparseVec::from_sql` rejects any non-zero value here.
        w.put_i32(0);

        for index in wire_indices {
            w.put_i32(index);
        }

        for value in &self.values {
            w.put_f32(*value);
        }

        Ok(IsNull::No)
    }

    /// Returns `true` only for a type whose name is exactly `sparsevec`.
    fn accepts(ty: &Type) -> bool {
        ty.name() == "sparsevec"
    }

    to_sql_checked!();
}

// Integration tests for the `postgres` bindings of `SparseVec`.
// They talk to a real PostgreSQL server, so they need a reachable database
// (see the connection settings in `it_works`).
#[cfg(test)]
mod tests {
use crate::SparseVec;
// NOTE(review): `BinaryCopyInWriter` is referenced only by the commented-out
// COPY section at the end of `it_works`.
use postgres::binary_copy::BinaryCopyInWriter;
use postgres::types::{Kind, Type};
use postgres::{Client, NoTls};

// Round-trips sparse vectors through a `sparsevec(5)` column and checks that a
// nearest-neighbour query returns the expected row.
#[test]
fn it_works() -> Result<(), postgres::Error> {
// Connects as the current OS user; panics if `USER` is not set.
let user = std::env::var("USER").unwrap();
let mut client = Client::configure()
.host("localhost")
.dbname("pgvector_rust_test")
.user(user.as_str())
.connect(NoTls)?;

// Start from a clean table so the test can be re-run.
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
client.execute("DROP TABLE IF EXISTS postgres_sparse_items", &[])?;
client.execute(
"CREATE TABLE postgres_sparse_items (id bigserial PRIMARY KEY, embedding sparsevec(5))",
&[],
)?;

// Two encoded vectors plus a literal NULL row (exercises `ToSql`).
let vec = SparseVec::new(5, vec![0, 2, 4], vec![1.0, 2.0, 3.0]);
let vec2 = SparseVec::new(5, vec![0, 2, 4], vec![4.0, 5.0, 6.0]);
client.execute(
"INSERT INTO postgres_sparse_items (embedding) VALUES ($1), ($2), (NULL)",
&[&vec, &vec2],
)?;

// Order by the `<->` operator against the query vector and decode the first
// row (exercises `FromSql`); the test expects that row to equal `vec`.
let query_vec = SparseVec::new(5, vec![0, 2, 4], vec![3.0, 1.0, 2.0]);
let row = client.query_one(
"SELECT embedding FROM postgres_sparse_items ORDER BY embedding <-> $1 LIMIT 1",
&[&query_vec],
)?;
let res_vec: SparseVec = row.get(0);
assert_eq!(vec, res_vec);

// NOTE(review): everything below is disabled work in progress. It uses
// `SparseVec::from(...)`, but the only constructor visible in this commit is
// `SparseVec::new`. The expected error message and the "[1,2,3]" text form
// look copied from the dense vector tests; confirm them against sparsevec
// behaviour before enabling.

// let empty_vec = SparseVec::from(vec![]);
// let empty_res = client.execute(
// "INSERT INTO postgres_sparse_items (embedding) VALUES ($1)",
// &[&empty_vec],
// );
// assert!(empty_res.is_err());
// assert!(empty_res
// .unwrap_err()
// .to_string()
// .contains("vector must have at least 1 dimension"));

// let null_row = client.query_one(
// "SELECT embedding FROM postgres_sparse_items WHERE embedding IS NULL LIMIT 1",
// &[],
// )?;
// let null_res: Option<SparseVec> = null_row.get(0);
// assert!(null_res.is_none());

// // ensures binary format is correct
// let text_row = client.query_one(
// "SELECT embedding::text FROM postgres_sparse_items ORDER BY id LIMIT 1",
// &[],
// )?;
// let text_res: String = text_row.get(0);
// assert_eq!("[1,2,3]", text_res);

// // copy
// let vector_type = get_type(&mut client, "sparsevec")?;
// let writer =
// client.copy_in("COPY postgres_sparse_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
// let mut writer = BinaryCopyInWriter::new(writer, &[vector_type]);
// writer.write(&[&SparseVec::from(vec![1.0, 2.0, 3.0])]).unwrap();
// writer.write(&[&SparseVec::from(vec![4.0, 5.0, 6.0])]).unwrap();
// writer.finish()?;

Ok(())
}

// Looks up the OID and schema of the type called `name` in `pg_type` /
// `pg_namespace` and builds a `Type` of kind `Kind::Simple` from them.
// `query_one` returns an error unless exactly one row matches.
// NOTE(review): the only call site is in the commented-out COPY section above.
fn get_type(client: &mut Client, name: &str) -> Result<Type, postgres::Error> {
let row = client.query_one("SELECT pg_type.oid, nspname AS schema FROM pg_type INNER JOIN pg_namespace ON pg_namespace.oid = pg_type.typnamespace WHERE typname = $1", &[&name])?;
Ok(Type::new(
name.into(),
row.get("oid"),
Kind::Simple,
row.get("schema"),
))
}
}
37 changes: 37 additions & 0 deletions src/sparsevec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/// A sparse vector.
///
/// Stores a dimension together with two parallel lists: `indices[k]` is the
/// position associated with `values[k]`.
#[derive(Clone, Debug, PartialEq)]
pub struct SparseVec {
    /// Declared dimension of the vector.
    pub(crate) dim: usize,
    /// Positions of the stored entries (one less than the value used on the wire).
    pub(crate) indices: Vec<i32>,
    /// Stored entries, in the same order as `indices`.
    pub(crate) values: Vec<f32>,
}

impl SparseVec {
    /// Creates a sparse vector from its dimension, indices and values.
    ///
    /// The arguments are stored as given; no validation is performed here, so
    /// callers are expected to pass `indices` and `values` of equal length.
    pub fn new(dim: usize, indices: Vec<i32>, values: Vec<f32>) -> SparseVec {
        SparseVec { dim, indices, values }
    }

    /// Decodes a sparse vector from its binary representation.
    ///
    /// Layout, all fields big-endian and four bytes wide: `dim` (i32), entry
    /// count `nnz` (i32), a word that must be zero, `nnz` indices (i32), then
    /// `nnz` values (f32). Each index is decremented by one after reading.
    /// Bytes past the declared payload are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error instead of panicking when the buffer is shorter than
    /// the header or the declared payload, when `dim` or `nnz` is negative,
    /// when the third header word is non-zero, or when an index cannot be
    /// decremented without overflow.
    #[cfg(any(feature = "postgres"))]
    pub(crate) fn from_sql(buf: &[u8]) -> Result<SparseVec, Box<dyn std::error::Error + Sync + Send>> {
        const WORD: usize = 4;
        const HEADER_LEN: usize = 3 * WORD;

        let header = buf
            .get(..HEADER_LEN)
            .ok_or("sparsevec buffer is too short for the header")?;
        // `try_from` rejects negative counts, which `as usize` would silently
        // turn into enormous values.
        let dim = usize::try_from(i32::from_be_bytes(header[0..4].try_into()?))?;
        let nnz = usize::try_from(i32::from_be_bytes(header[4..8].try_into()?))?;
        let unused = i32::from_be_bytes(header[8..12].try_into()?);
        if unused != 0 {
            return Err("expected unused to be 0".into());
        }

        // Validate the payload length before allocating anything, so a corrupt
        // `nnz` cannot trigger a huge allocation or an out-of-bounds slice.
        let section_len = nnz
            .checked_mul(WORD)
            .ok_or("sparsevec element count is too large")?;
        let payload_len = section_len
            .checked_mul(2)
            .ok_or("sparsevec element count is too large")?;
        let payload = buf[HEADER_LEN..]
            .get(..payload_len)
            .ok_or("sparsevec buffer is too short for the declared number of elements")?;
        let (index_bytes, value_bytes) = payload.split_at(section_len);

        let indices = index_bytes
            .chunks_exact(WORD)
            .map(|c| {
                i32::from_be_bytes([c[0], c[1], c[2], c[3]])
                    .checked_sub(1)
                    .ok_or("sparsevec index is out of range")
            })
            .collect::<Result<Vec<i32>, _>>()?;

        let values = value_bytes
            .chunks_exact(WORD)
            .map(|c| f32::from_be_bytes([c[0], c[1], c[2], c[3]]))
            .collect();

        Ok(SparseVec { dim, indices, values })
    }
}

0 comments on commit c333770

Please sign in to comment.