Skip to content

Commit

Permalink
Improved test
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Oct 1, 2023
1 parent 11badfd commit 744a575
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions src/diesel_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,31 +73,33 @@ mod tests {
table! {
use diesel::sql_types::*;

items (id) {
diesel_items (id) {
id -> Int4,
embedding -> Nullable<crate::sql_types::Vector>,
}
}

#[derive(PartialEq, Queryable)]
#[diesel(table_name = items)]
use diesel_items as items;

#[derive(Queryable)]
#[diesel(table_name = diesel_items)]
struct Item {
pub id: i32,
pub embedding: Option<crate::Vector>,
pub embedding: Option<Vector>,
}

#[derive(Insertable)]
#[diesel(table_name = items)]
#[diesel(table_name = diesel_items)]
struct NewItem {
pub embedding: Option<crate::Vector>,
pub embedding: Option<Vector>,
}

#[test]
fn it_works() -> Result<(), diesel::result::Error> {
let mut conn = PgConnection::establish("postgres://localhost/pgvector_rust_test").unwrap();
diesel::sql_query("CREATE EXTENSION IF NOT EXISTS vector").execute(&mut conn)?;
diesel::sql_query("DROP TABLE IF EXISTS items").execute(&mut conn)?;
diesel::sql_query("CREATE TABLE items (id serial PRIMARY KEY, embedding vector(3))")
diesel::sql_query("DROP TABLE IF EXISTS diesel_items").execute(&mut conn)?;
diesel::sql_query("CREATE TABLE diesel_items (id serial PRIMARY KEY, embedding vector(3))")
.execute(&mut conn)?;

let new_items = vec![
Expand Down Expand Up @@ -126,7 +128,11 @@ mod tests {
.load::<Item>(&mut conn)?;
assert_eq!(
vec![1, 3, 2, 4],
neighbors.into_iter().map(|v| v.id).collect::<Vec<i32>>()
neighbors.iter().map(|v| v.id).collect::<Vec<i32>>()
);
assert_eq!(
Some(Vector::from(vec![1.0, 1.0, 1.0])),
neighbors.first().unwrap().embedding
);

let neighbors = items::table
Expand All @@ -135,7 +141,7 @@ mod tests {
.load::<Item>(&mut conn)?;
assert_eq!(
vec![2, 3, 1, 4],
neighbors.into_iter().map(|v| v.id).collect::<Vec<i32>>()
neighbors.iter().map(|v| v.id).collect::<Vec<i32>>()
);

let neighbors = items::table
Expand All @@ -144,7 +150,7 @@ mod tests {
.load::<Item>(&mut conn)?;
assert_eq!(
vec![1, 2, 3, 4],
neighbors.into_iter().map(|v| v.id).collect::<Vec<i32>>()
neighbors.iter().map(|v| v.id).collect::<Vec<i32>>()
);

let distances = items::table
Expand Down

0 comments on commit 744a575

Please sign in to comment.