Skip to content

Commit

Permalink
Improved bit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent e79487c commit 67972f0
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 36 deletions.
16 changes: 8 additions & 8 deletions src/diesel_ext/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,24 +63,24 @@ mod tests {
diesel::sql_query("CREATE EXTENSION IF NOT EXISTS vector").execute(&mut conn)?;
diesel::sql_query("DROP TABLE IF EXISTS diesel_bit_items").execute(&mut conn)?;
diesel::sql_query(
"CREATE TABLE diesel_bit_items (id serial PRIMARY KEY, embedding bit(10))",
"CREATE TABLE diesel_bit_items (id serial PRIMARY KEY, embedding bit(9))",
)
.execute(&mut conn)?;

let new_items = vec![
NewItem {
embedding: Some(Bit::new(&[
false, false, false, false, false, false, false, false, false, true,
false, false, false, false, false, false, false, false, true,
])),
},
NewItem {
embedding: Some(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
false, true, false, true, false, false, false, false, true,
])),
},
NewItem {
embedding: Some(Bit::new(&[
true, true, true, false, false, false, false, false, false, true,
false, true, true, true, false, false, false, false, true,
])),
},
NewItem { embedding: None },
Expand All @@ -95,7 +95,7 @@ mod tests {

let neighbors = items::table
.order(items::embedding.hamming_distance(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
false, true, false, true, false, false, false, false, true,
])))
.limit(5)
.load::<Item>(&mut conn)?;
Expand All @@ -105,14 +105,14 @@ mod tests {
);
assert_eq!(
Some(Bit::new(&[
true, false, true, false, false, false, false, false, false, true
false, true, false, true, false, false, false, false, true
])),
neighbors.first().unwrap().embedding
);

let neighbors = items::table
.order(items::embedding.jaccard_distance(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
false, true, false, true, false, false, false, false, true,
])))
.limit(5)
.load::<Item>(&mut conn)?;
Expand All @@ -123,7 +123,7 @@ mod tests {

let distances = items::table
.select(items::embedding.hamming_distance(Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
false, true, false, true, false, false, false, false, true,
])))
.order(items::id)
.load::<Option<f64>>(&mut conn)?;
Expand Down
24 changes: 9 additions & 15 deletions src/postgres_ext/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,32 +49,26 @@ mod tests {
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
client.execute("DROP TABLE IF EXISTS postgres_bit_items", &[])?;
client.execute(
"CREATE TABLE postgres_bit_items (id bigserial PRIMARY KEY, embedding bit(10))",
"CREATE TABLE postgres_bit_items (id bigserial PRIMARY KEY, embedding bit(9))",
&[],
)?;

let vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let vec2 = Bit::new(&[
false, true, false, false, false, false, false, false, false, true,
]);
let vec = Bit::new(&[false, true, false, true, false, false, false, false, true]);
let vec2 = Bit::new(&[false, false, true, false, false, false, false, false, true]);
client.execute(
"INSERT INTO postgres_bit_items (embedding) VALUES ($1), ($2), (NULL)",
&[&vec, &vec2],
)?;

let query_vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let query_vec = Bit::new(&[false, true, false, true, false, false, false, false, true]);
let row = client.query_one(
"SELECT embedding FROM postgres_bit_items ORDER BY embedding <~> $1 LIMIT 1",
&[&query_vec],
)?;
let res_vec: Bit = row.get(0);
assert_eq!(vec, res_vec);
assert_eq!(10, res_vec.len());
assert_eq!(&[0b10100000, 0b01000000], res_vec.as_bytes());
assert_eq!(9, res_vec.len());
assert_eq!(&[0b01010000, 0b10000000], res_vec.as_bytes());

let null_row = client.query_one(
"SELECT embedding FROM postgres_bit_items WHERE embedding IS NULL LIMIT 1",
Expand All @@ -89,18 +83,18 @@ mod tests {
&[],
)?;
let text_res: String = text_row.get(0);
assert_eq!("1010000001", text_res);
assert_eq!("010100001", text_res);

// copy
let bit_type = Type::BIT;
let writer = client
.copy_in("COPY postgres_bit_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
let mut writer = BinaryCopyInWriter::new(writer, &[bit_type]);
writer.write(&[&Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
false, true, false, true, false, false, false, false, true,
])])?;
writer.write(&[&Bit::new(&[
false, true, false, false, false, false, false, false, false, true,
false, false, true, false, false, false, false, false, true,
])])?;
writer.finish()?;

Expand Down
20 changes: 7 additions & 13 deletions src/sqlx_ext/bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,27 @@ mod tests {
sqlx::query("DROP TABLE IF EXISTS sqlx_bit_items")
.execute(&pool)
.await?;
sqlx::query("CREATE TABLE sqlx_bit_items (id bigserial PRIMARY KEY, embedding bit(10))")
sqlx::query("CREATE TABLE sqlx_bit_items (id bigserial PRIMARY KEY, embedding bit(9))")
.execute(&pool)
.await?;

let vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let vec2 = Bit::new(&[
false, true, false, false, false, false, false, false, false, true,
]);
let vec = Bit::new(&[false, true, false, true, false, false, false, false, true]);
let vec2 = Bit::new(&[false, false, true, false, false, false, false, false, true]);
sqlx::query("INSERT INTO sqlx_bit_items (embedding) VALUES ($1), ($2), (NULL)")
.bind(&vec)
.bind(&vec2)
.execute(&pool)
.await?;

let query_vec = Bit::new(&[
true, false, true, false, false, false, false, false, false, true,
]);
let query_vec = Bit::new(&[false, true, false, true, false, false, false, false, true]);
let row =
sqlx::query("SELECT embedding FROM sqlx_bit_items ORDER BY embedding <~> $1 LIMIT 1")
.bind(query_vec)
.fetch_one(&pool)
.await?;
let res_vec: Bit = row.try_get("embedding").unwrap();
assert_eq!(vec, res_vec);
assert_eq!(&[0b10100000, 0b01000000], res_vec.as_bytes());
assert_eq!(&[0b01010000, 0b10000000], res_vec.as_bytes());

let null_row =
sqlx::query("SELECT embedding FROM sqlx_bit_items WHERE embedding IS NULL LIMIT 1")
Expand All @@ -94,9 +88,9 @@ mod tests {
.fetch_one(&pool)
.await?;
let text_res: String = text_row.try_get("embedding").unwrap();
assert_eq!("1010000001", text_res);
assert_eq!("010100001", text_res);

sqlx::query("ALTER TABLE sqlx_bit_items ADD COLUMN factors bit(10)[]")
sqlx::query("ALTER TABLE sqlx_bit_items ADD COLUMN factors bit(9)[]")
.execute(&pool)
.await?;

Expand Down

0 comments on commit 67972f0

Please sign in to comment.