Skip to content

Commit

Permalink
Renamed to HalfVec
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Apr 6, 2024
1 parent f570de7 commit 13a6ff4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 27 deletions.
24 changes: 12 additions & 12 deletions src/halfvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@ use half::f16;

/// A half vector.
#[derive(Clone, Debug, PartialEq)]
pub struct HalfVector(pub(crate) Vec<f16>);
pub struct HalfVec(pub(crate) Vec<f16>);

impl From<Vec<f16>> for HalfVector {
impl From<Vec<f16>> for HalfVec {
fn from(v: Vec<f16>) -> Self {
HalfVector(v)
HalfVec(v)
}
}

impl From<HalfVector> for Vec<f16> {
fn from(val: HalfVector) -> Self {
impl From<HalfVec> for Vec<f16> {
fn from(val: HalfVec) -> Self {
val.0
}
}

impl HalfVector {
impl HalfVec {
/// Returns a copy of the vector as a `Vec<f16>`.
pub fn to_vec(&self) -> Vec<f16> {
self.0.clone()
Expand All @@ -30,7 +30,7 @@ impl HalfVector {
#[cfg(any(feature = "postgres"))]
pub(crate) fn from_sql(
buf: &[u8],
) -> Result<HalfVector, Box<dyn std::error::Error + Sync + Send>> {
) -> Result<HalfVec, Box<dyn std::error::Error + Sync + Send>> {
let dim = u16::from_be_bytes(buf[0..2].try_into()?) as usize;
let unused = u16::from_be_bytes(buf[2..4].try_into()?);
if unused != 0 {
Expand All @@ -43,18 +43,18 @@ impl HalfVector {
vec.push(f16::from_be_bytes(buf[s..s + 2].try_into()?));
}

Ok(HalfVector(vec))
Ok(HalfVec(vec))
}
}

#[cfg(test)]
mod tests {
use crate::HalfVector;
use crate::HalfVec;
use half::f16;

#[test]
fn test_into() {
let vec = HalfVector::from(vec![
let vec = HalfVec::from(vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
Expand All @@ -68,7 +68,7 @@ mod tests {

#[test]
fn test_to_vec() {
let vec = HalfVector::from(vec![
let vec = HalfVec::from(vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
Expand All @@ -81,7 +81,7 @@ mod tests {

#[test]
fn test_as_slice() {
let vec = HalfVector::from(vec![
let vec = HalfVec::from(vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use vector::Vector;
mod halfvec;

#[cfg(feature = "halfvec")]
pub use halfvec::HalfVector;
pub use halfvec::HalfVec;

#[cfg(feature = "postgres")]
mod postgres_ext;
Expand Down
28 changes: 14 additions & 14 deletions src/postgres_ext/halfvec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ use postgres::types::{to_sql_checked, FromSql, IsNull, ToSql, Type};
use std::convert::TryInto;
use std::error::Error;

use crate::HalfVector;
use crate::HalfVec;

impl<'a> FromSql<'a> for HalfVector {
fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<HalfVector, Box<dyn Error + Sync + Send>> {
HalfVector::from_sql(raw)
impl<'a> FromSql<'a> for HalfVec {
fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<HalfVec, Box<dyn Error + Sync + Send>> {
HalfVec::from_sql(raw)
}

fn accepts(ty: &Type) -> bool {
ty.name() == "halfvec"
}
}

impl ToSql for HalfVector {
impl ToSql for HalfVec {
fn to_sql(&self, _ty: &Type, w: &mut BytesMut) -> Result<IsNull, Box<dyn Error + Sync + Send>> {
let dim = self.0.len();
w.put_u16(dim.try_into()?);
Expand All @@ -37,7 +37,7 @@ impl ToSql for HalfVector {

#[cfg(test)]
mod tests {
use crate::HalfVector;
use crate::HalfVec;
use half::f16;
use postgres::binary_copy::BinaryCopyInWriter;
use postgres::types::{Kind, Type};
Expand All @@ -59,12 +59,12 @@ mod tests {
&[],
)?;

let vec = HalfVector::from(vec![
let vec = HalfVec::from(vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
]);
let vec2 = HalfVector::from(vec![
let vec2 = HalfVec::from(vec![
f16::from_f32(4.0),
f16::from_f32(5.0),
f16::from_f32(6.0),
Expand All @@ -74,7 +74,7 @@ mod tests {
&[&vec, &vec2],
)?;

let query_vec = HalfVector::from(vec![
let query_vec = HalfVec::from(vec![
f16::from_f32(3.0),
f16::from_f32(1.0),
f16::from_f32(2.0),
Expand All @@ -83,14 +83,14 @@ mod tests {
"SELECT embedding FROM postgres_halfvec_items ORDER BY embedding <-> $1 LIMIT 1",
&[&query_vec],
)?;
let res_vec: HalfVector = row.get(0);
let res_vec: HalfVec = row.get(0);
assert_eq!(vec, res_vec);
assert_eq!(
vec![f16::from_f32(1.0), f16::from_f32(2.0), f16::from_f32(3.0)],
res_vec.to_vec()
);

let empty_vec = HalfVector::from(vec![]);
let empty_vec = HalfVec::from(vec![]);
let empty_res = client.execute(
"INSERT INTO postgres_halfvec_items (embedding) VALUES ($1)",
&[&empty_vec],
Expand All @@ -105,7 +105,7 @@ mod tests {
"SELECT embedding FROM postgres_halfvec_items WHERE embedding IS NULL LIMIT 1",
&[],
)?;
let null_res: Option<HalfVector> = null_row.get(0);
let null_res: Option<HalfVec> = null_row.get(0);
assert!(null_res.is_none());

// ensures binary format is correct
Expand All @@ -122,14 +122,14 @@ mod tests {
.copy_in("COPY postgres_halfvec_items (embedding) FROM STDIN WITH (FORMAT BINARY)")?;
let mut writer = BinaryCopyInWriter::new(writer, &[halfvec_type]);
writer
.write(&[&HalfVector::from(vec![
.write(&[&HalfVec::from(vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
])])
.unwrap();
writer
.write(&[&HalfVector::from(vec![
.write(&[&HalfVec::from(vec![
f16::from_f32(4.0),
f16::from_f32(5.0),
f16::from_f32(6.0),
Expand Down

0 comments on commit 13a6ff4

Please sign in to comment.