Skip to content

Commit

Permalink
refactor: rewrite error-based reranker
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <usamoi@outlook.com>
  • Loading branch information
usamoi committed Aug 28, 2024
1 parent 84601d9 commit c302125
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 61 deletions.
42 changes: 42 additions & 0 deletions crates/rabitq/src/quant/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use base::always_equal::AlwaysEqual;
use base::distance::Distance;
use base::search::RerankerPop;
use std::cmp::Reverse;
use std::collections::BinaryHeap;

pub struct ErrorFlatReranker<T, R> {
rerank: R,
heap: BinaryHeap<(Reverse<Distance>, AlwaysEqual<u32>)>,
cache: BinaryHeap<(Reverse<Distance>, AlwaysEqual<u32>, AlwaysEqual<T>)>,
}

impl<T, R> ErrorFlatReranker<T, R> {
pub fn new(heap: Vec<(Reverse<Distance>, AlwaysEqual<u32>)>, rerank: R) -> Self
where
R: Fn(u32) -> (Distance, T),
{
Self {
rerank,
heap: heap.into(),
cache: BinaryHeap::new(),
}
}
}

impl<T, R> RerankerPop<T> for ErrorFlatReranker<T, R>
where
R: Fn(u32) -> (Distance, T),
{
fn pop(&mut self) -> Option<(Distance, u32, T)> {
while !self.heap.is_empty()
&& self.heap.peek().map(|x| x.0) > self.cache.peek().map(|x| x.0)
{
let (_, AlwaysEqual(u)) = self.heap.pop().unwrap();
let (dis_u, pay_u) = (self.rerank)(u);
self.cache
.push((Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)));
}
let (Reverse(dist), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.cache.pop()?;
Some((dist, u, pay_u))
}
}
58 changes: 0 additions & 58 deletions crates/rabitq/src/quant/error_based.rs

This file was deleted.

2 changes: 1 addition & 1 deletion crates/rabitq/src/quant/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pub mod error_based;
pub mod error;
pub mod quantization;
pub mod quantizer;
4 changes: 2 additions & 2 deletions crates/rabitq/src/quant/quantizer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::error_based::ErrorBasedReranker;
use super::error::ErrorFlatReranker;
use crate::operator::OperatorRabitq;
use base::always_equal::AlwaysEqual;
use base::distance::Distance;
Expand Down Expand Up @@ -198,6 +198,6 @@ impl<O: OperatorRabitq> RabitqQuantizer<O> {
heap: Vec<(Reverse<Distance>, AlwaysEqual<u32>)>,
r: impl Fn(u32) -> (Distance, T) + 'a,
) -> impl RerankerPop<T> + 'a {
ErrorBasedReranker::new(heap, r)
ErrorFlatReranker::new(heap, r)
}
}

0 comments on commit c302125

Please sign in to comment.