Skip to content

Commit

Permalink
Improved test
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 30, 2023
1 parent 45d988f commit 41afd8a
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions src/diesel_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ mod tests {
}

#[test]
fn it_works() {
fn it_works() -> Result<(), diesel::result::Error> {
use crate::Vector;
use crate::VectorExpressionMethods;
use diesel::pg::PgConnection;
use diesel::{Connection, QueryDsl, RunQueryDsl};

let mut conn = PgConnection::establish("postgres://localhost/pgvector_rust_test").unwrap();
diesel::sql_query("CREATE EXTENSION IF NOT EXISTS vector").execute(&mut conn).unwrap();
diesel::sql_query("DROP TABLE IF EXISTS items").execute(&mut conn).unwrap();
diesel::sql_query("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3))").execute(&mut conn).unwrap();
diesel::sql_query("CREATE EXTENSION IF NOT EXISTS vector").execute(&mut conn)?;
diesel::sql_query("DROP TABLE IF EXISTS items").execute(&mut conn)?;
diesel::sql_query("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3))").execute(&mut conn)?;

let new_items = vec![
Item {
Expand All @@ -108,37 +108,35 @@ mod tests {
},
];

diesel::insert_into(items::table).values(&new_items).get_results::<Item>(&mut conn).unwrap();
diesel::insert_into(items::table).values(&new_items).get_results::<Item>(&mut conn)?;

let all = items::table.load::<Item>(&mut conn).unwrap();
let all = items::table.load::<Item>(&mut conn)?;
assert_eq!(4, all.len());

let neighbors = items::table
.order(items::embedding.l2_distance(Vector::from(vec![1.0, 1.0, 1.0])))
.limit(5)
.load::<Item>(&mut conn)
.unwrap();
.load::<Item>(&mut conn)?;
assert_eq!(vec![1, 3, 2, 4], neighbors.into_iter().map(|v| v.id).collect::<Vec<i32>>());

let neighbors = items::table
.order(items::embedding.max_inner_product(Vector::from(vec![1.0, 1.0, 1.0])))
.limit(5)
.load::<Item>(&mut conn)
.unwrap();
.load::<Item>(&mut conn)?;
assert_eq!(vec![2, 3, 1, 4], neighbors.into_iter().map(|v| v.id).collect::<Vec<i32>>());

let neighbors = items::table
.order(items::embedding.cosine_distance(Vector::from(vec![1.0, 1.0, 1.0])))
.limit(5)
.load::<Item>(&mut conn)
.unwrap();
.load::<Item>(&mut conn)?;
assert_eq!(vec![1, 2, 3, 4], neighbors.into_iter().map(|v| v.id).collect::<Vec<i32>>());

let distances = items::table
.select(items::embedding.max_inner_product(Vector::from(vec![1.0, 1.0, 1.0])))
.order(items::id)
.load::<Option<f64>>(&mut conn)
.unwrap();
.load::<Option<f64>>(&mut conn)?;
assert_eq!(vec![Some(-3.0), Some(-6.0), Some(-4.0), None], distances);

Ok(())
}
}

0 comments on commit 41afd8a

Please sign in to comment.