-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
mod vector; | ||
mod sparsevec; | ||
|
||
#[cfg(feature = "halfvec")] | ||
mod halfvec; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
use bytes::{BufMut, BytesMut}; | ||
use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type}; | ||
use std::convert::TryInto; | ||
use std::error::Error; | ||
|
||
use crate::SparseVec; | ||
|
||
impl<'a> FromSql<'a> for SparseVec { | ||
fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<SparseVec, Box<dyn Error + Sync + Send>> { | ||
SparseVec::from_sql(raw) | ||
} | ||
|
||
fn accepts(ty: &Type) -> bool { | ||
ty.name() == "sparsevec" | ||
} | ||
} | ||
|
||
impl ToSql for SparseVec { | ||
fn to_sql(&self, _ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> { | ||
let dim = self.dim; | ||
let nnz = self.indices.len(); | ||
w.put_i32(dim.try_into()?); | ||
w.put_i32(nnz.try_into()?); | ||
w.put_i32(0); | ||
|
||
for v in &self.indices { | ||
w.put_i32(*v + 1); | ||
} | ||
|
||
for v in &self.values { | ||
w.put_f32(*v); | ||
} | ||
|
||
Ok(IsNull::No) | ||
} | ||
|
||
fn accepts(ty: &Type) -> bool { | ||
ty.name() == "sparsevec" | ||
} | ||
|
||
to_sql_checked!(); | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::SparseVec; | ||
use postgres::binary_copy::BinaryCopyInWriter; | ||
use postgres::types::{Kind, Type}; | ||
use postgres::{Client, NoTls}; | ||
|
||
#[test] | ||
fn it_works() -> Result<(), postgres::Error> { | ||
let user = std::env::var("USER").unwrap(); | ||
let mut client = Client::configure() | ||
.host("localhost") | ||
.dbname("pgvector_rust_test") | ||
.user(user.as_str()) | ||
.connect(NoTls)?; | ||
|
||
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?; | ||
client.execute("DROP TABLE IF EXISTS postgres_sparse_items", &[])?; | ||
client.execute( | ||
"CREATE TABLE postgres_sparse_items (id bigserial PRIMARY KEY, embedding sparsevec(5))", | ||
&[], | ||
)?; | ||
|
||
let vec = SparseVec::new(5, vec![0, 2, 4], vec![1.0, 2.0, 3.0]); | ||
let vec2 = SparseVec::new(5, vec![0, 2, 4], vec![4.0, 5.0, 6.0]); | ||
client.execute( | ||
"INSERT INTO postgres_sparse_items (embedding) VALUES ($1), ($2), (NULL)", | ||
&[&vec, &vec2], | ||
)?; | ||
|
||
let query_vec = SparseVec::new(5, vec![0, 2, 4], vec![3.0, 1.0, 2.0]); | ||
let row = client.query_one( | ||
"SELECT embedding FROM postgres_sparse_items ORDER BY embedding <-> $1 LIMIT 1", | ||
&[&query_vec], | ||
)?; | ||
let res_vec: SparseVec = row.get(0); | ||
assert_eq!(vec, res_vec); | ||
|
||
// let empty_vec = SparseVec::from(vec![]); | ||
// let empty_res = client.execute( | ||
// "INSERT INTO postgres_sparse_items (embedding) VALUES ($1)", | ||
// &[&empty_vec], | ||
// ); | ||
// assert!(empty_res.is_err()); | ||
// assert!(empty_res | ||
// .unwrap_err() | ||
// .to_string() | ||
// .contains("vector must have at least 1 dimension")); | ||
|
||
// let null_row = client.query_one( | ||
// "SELECT embedding FROM postgres_sparse_items WHERE embedding IS NULL LIMIT 1", | ||
// &[], | ||
// )?; | ||
// let null_res: Option<SparseVec> = null_row.get(0); | ||
// assert!(null_res.is_none()); | ||
|
||
// // ensures binary format is correct | ||
// let text_row = client.query_one( | ||
// "SELECT embedding::text FROM postgres_sparse_items ORDER BY id LIMIT 1", | ||
// &[], | ||
// )?; | ||
// let text_res: String = text_row.get(0); | ||
// assert_eq!("[1,2,3]", text_res); | ||
|
||
// // copy | ||
// let vector_type = get_type(&mut client, "sparsevec")?; | ||
// let writer = | ||
// client.copy_in("COPY postgres_sparse_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?; | ||
// let mut writer = BinaryCopyInWriter::new(writer, &[vector_type]); | ||
// writer.write(&[&SparseVec::from(vec![1.0, 2.0, 3.0])]).unwrap(); | ||
// writer.write(&[&SparseVec::from(vec![4.0, 5.0, 6.0])]).unwrap(); | ||
// writer.finish()?; | ||
|
||
Ok(()) | ||
} | ||
|
||
fn get_type(client: &mut Client, name: &str) -> Result<Type, postgres::Error> { | ||
let row = client.query_one("SELECT pg_type.oid, nspname AS schema FROM pg_type INNER JOIN pg_namespace ON pg_namespace.oid = pg_type.typnamespace WHERE typname = $1", &[&name])?; | ||
Ok(Type::new( | ||
name.into(), | ||
row.get("oid"), | ||
Kind::Simple, | ||
row.get("schema"), | ||
)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/// A sparse vector. | ||
#[derive(Clone, Debug, PartialEq)] | ||
pub struct SparseVec { | ||
pub(crate) dim: usize, | ||
pub(crate) indices: Vec<i32>, | ||
pub(crate) values: Vec<f32>, | ||
} | ||
|
||
impl SparseVec { | ||
pub fn new(dim: usize, indices: Vec<i32>, values: Vec<f32>) -> SparseVec { | ||
SparseVec { dim, indices, values } | ||
} | ||
|
||
#[cfg(any(feature = "postgres"))] | ||
pub(crate) fn from_sql(buf: &[u8]) -> Result<SparseVec, Box<dyn std::error::Error + Sync + Send>> { | ||
let dim = i32::from_be_bytes(buf[0..4].try_into()?) as usize; | ||
let nnz = i32::from_be_bytes(buf[4..8].try_into()?) as usize; | ||
let unused = i32::from_be_bytes(buf[8..12].try_into()?); | ||
if unused != 0 { | ||
return Err("expected unused to be 0".into()); | ||
} | ||
|
||
let mut indices = Vec::with_capacity(nnz); | ||
for i in 0..nnz { | ||
let s = 12 + 4 * i; | ||
indices.push(i32::from_be_bytes(buf[s..s + 4].try_into()?) - 1); | ||
} | ||
|
||
let mut values = Vec::with_capacity(nnz); | ||
for i in 0..nnz { | ||
let s = 12 + 4 * nnz + 4 * i; | ||
values.push(f32::from_be_bytes(buf[s..s + 4].try_into()?)); | ||
} | ||
|
||
Ok(SparseVec { dim, indices, values }) | ||
} | ||
} |