Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace BinaryHeap for TopN #2186

Merged
merged 6 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/collector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub use self::multi_collector::{FruitHandle, MultiCollector, MultiFruit};
mod top_collector;

mod top_score_collector;
pub use self::top_score_collector::TopDocs;
pub use self::top_score_collector::{TopDocs, TopNComputer};

mod custom_score_top_collector;
pub use self::custom_score_top_collector::{CustomScorer, CustomSegmentScorer};
Expand Down
59 changes: 20 additions & 39 deletions src/collector/top_collector.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::marker::PhantomData;

use super::top_score_collector::TopNComputer;
use crate::{DocAddress, DocId, SegmentOrdinal, SegmentReader};

/// Contains a feature (field, score, etc.) of a document along with the document address.
Expand All @@ -20,6 +20,14 @@ pub(crate) struct ComparableDoc<T, D> {
pub feature: T,
pub doc: D,
}
impl<T: std::fmt::Debug, D: std::fmt::Debug> std::fmt::Debug for ComparableDoc<T, D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ComparableDoc")
.field("feature", &self.feature)
.field("doc", &self.doc)
.finish()
}
}

impl<T: PartialOrd, D: PartialOrd> PartialOrd for ComparableDoc<T, D> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Expand Down Expand Up @@ -91,18 +99,13 @@ where T: PartialOrd + Clone
if self.limit == 0 {
return Ok(Vec::new());
}
let mut top_collector = BinaryHeap::new();
let mut top_collector = TopNComputer::new(self.limit + self.offset);
for child_fruit in children {
for (feature, doc) in child_fruit {
if top_collector.len() < (self.limit + self.offset) {
top_collector.push(ComparableDoc { feature, doc });
} else if let Some(mut head) = top_collector.peek_mut() {
if head.feature < feature {
*head = ComparableDoc { feature, doc };
}
}
top_collector.push(ComparableDoc { feature, doc });
}
}

Ok(top_collector
.into_sorted_vec()
.into_iter()
Expand All @@ -111,7 +114,7 @@ where T: PartialOrd + Clone
.collect())
}

pub(crate) fn for_segment<F: PartialOrd>(
pub(crate) fn for_segment<F: PartialOrd + Clone>(
&self,
segment_id: SegmentOrdinal,
_: &SegmentReader,
Expand All @@ -136,20 +139,18 @@ where T: PartialOrd + Clone
/// The Top Collector keeps track of the K documents
/// sorted by type `T`.
///
/// The implementation is based on a `BinaryHeap`.
/// The implementation is based on a repeatedly truncating on the median after K * 2 documents
/// The theoretical complexity for collecting the top `K` out of `n` documents
/// is `O(n log K)`.
/// is `O(n + K)`.
pub(crate) struct TopSegmentCollector<T> {
limit: usize,
heap: BinaryHeap<ComparableDoc<T, DocId>>,
topn_computer: TopNComputer<T, DocId>,
segment_ord: u32,
}

impl<T: PartialOrd> TopSegmentCollector<T> {
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
fn new(segment_ord: SegmentOrdinal, limit: usize) -> TopSegmentCollector<T> {
TopSegmentCollector {
limit,
heap: BinaryHeap::with_capacity(limit),
topn_computer: TopNComputer::new(limit),
segment_ord,
}
}
Expand All @@ -158,7 +159,7 @@ impl<T: PartialOrd> TopSegmentCollector<T> {
impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
pub fn harvest(self) -> Vec<(T, DocAddress)> {
let segment_ord = self.segment_ord;
self.heap
self.topn_computer
.into_sorted_vec()
.into_iter()
.map(|comparable_doc| {
Expand All @@ -173,33 +174,13 @@ impl<T: PartialOrd + Clone> TopSegmentCollector<T> {
.collect()
}

/// Return true if more documents have been collected than the limit.
#[inline]
pub(crate) fn at_capacity(&self) -> bool {
self.heap.len() >= self.limit
}

/// Collects a document scored by the given feature
///
/// It collects documents until it has reached the max capacity. Once it reaches capacity, it
/// will compare the lowest scoring item with the given one and keep whichever is greater.
#[inline]
pub fn collect(&mut self, doc: DocId, feature: T) {
if self.at_capacity() {
// It's ok to unwrap as long as a limit of 0 is forbidden.
if let Some(limit_feature) = self.heap.peek().map(|head| head.feature.clone()) {
if limit_feature < feature {
if let Some(mut head) = self.heap.peek_mut() {
head.feature = feature;
head.doc = doc;
}
}
}
} else {
// we have not reached capacity yet, so we can just push the
// element.
self.heap.push(ComparableDoc { feature, doc });
}
self.topn_computer.push(ComparableDoc { feature, doc });
}
}

Expand Down
Loading
Loading