Skip to content

Commit

Permalink
set max bucket size as parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
PSeitz committed May 13, 2022
1 parent 11ac451 commit 44ea731
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 73 deletions.
2 changes: 1 addition & 1 deletion examples/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ fn main() -> tantivy::Result<()> {
.into_iter()
.collect();

let collector = AggregationCollector::from_aggs(agg_req_1);
let collector = AggregationCollector::from_aggs(agg_req_1, None);

let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
Expand Down
20 changes: 16 additions & 4 deletions src/aggregation/agg_req_with_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::sync::Arc;
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
use super::segment_agg_result::BucketCount;
use super::VecWithNames;
use crate::fastfield::{
type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader,
Expand Down Expand Up @@ -62,7 +63,7 @@ pub struct BucketAggregationWithAccessor {
pub(crate) field_type: Type,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) bucket_count: Rc<AtomicU32>,
pub(crate) bucket_count: BucketCount,
}

impl BucketAggregationWithAccessor {
Expand All @@ -71,6 +72,7 @@ impl BucketAggregationWithAccessor {
sub_aggregation: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
Expand All @@ -96,10 +98,18 @@ impl BucketAggregationWithAccessor {
Ok(BucketAggregationWithAccessor {
accessor,
field_type,
sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
sub_aggregation: get_aggs_with_accessor_and_validate(
&sub_aggregation,
reader,
bucket_count.clone(),
max_bucket_count,
)?,
bucket_agg: bucket.clone(),
inverted_index,
bucket_count,
bucket_count: BucketCount {
bucket_count,
max_bucket_count,
},
})
}
}
Expand Down Expand Up @@ -139,8 +149,9 @@ impl MetricAggregationWithAccessor {
pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
max_bucket_count: u32,
) -> crate::Result<AggregationsWithAccessor> {
let bucket_count: Rc<AtomicU32> = Default::default();
let mut metrics = vec![];
let mut buckets = vec![];
for (key, agg) in aggs.iter() {
Expand All @@ -152,6 +163,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
&bucket.sub_aggregation,
reader,
Rc::clone(&bucket_count),
max_bucket_count,
)?,
)),
Aggregation::Metric(metric) => metrics.push((
Expand Down
8 changes: 3 additions & 5 deletions src/aggregation/bucket/histogram/histogram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ use crate::aggregation::f64_from_fastfield_u64;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::{
validate_bucket_count, SegmentAggregationResultsCollector,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::fastfield::{DynamicFastFieldReader, FastFieldReader};
use crate::schema::Type;
use crate::{DocId, TantivyError};
Expand Down Expand Up @@ -254,8 +252,8 @@ impl SegmentHistogramCollector {

agg_with_accessor
.bucket_count
.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
validate_bucket_count(&agg_with_accessor.bucket_count)?;
.add_count(buckets.len() as u32);
agg_with_accessor.bucket_count.validate_bucket_count()?;

Ok(IntermediateBucketResult::Histogram { buckets })
}
Expand Down
14 changes: 5 additions & 9 deletions src/aggregation/bucket/range.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::fmt::Debug;
use std::ops::Range;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;

use fnv::FnvHashMap;
use serde::{Deserialize, Serialize};
Expand All @@ -12,9 +10,7 @@ use crate::aggregation::agg_req_with_accessor::{
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::{
validate_bucket_count, SegmentAggregationResultsCollector,
};
use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector};
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey};
use crate::fastfield::FastFieldReader;
use crate::schema::Type;
Expand Down Expand Up @@ -179,7 +175,7 @@ impl SegmentRangeCollector {
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &AggregationsWithAccessor,
bucket_count: &Rc<AtomicU32>,
bucket_count: &BucketCount,
field_type: Type,
) -> crate::Result<Self> {
// The range input on the request is f64.
Expand Down Expand Up @@ -218,8 +214,8 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;

bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
validate_bucket_count(bucket_count)?;
bucket_count.add_count(buckets.len() as u32);
bucket_count.validate_bucket_count()?;

Ok(SegmentRangeCollector {
buckets,
Expand Down Expand Up @@ -438,7 +434,7 @@ mod tests {
.into_iter()
.collect();

let collector = AggregationCollector::from_aggs(agg_req);
let collector = AggregationCollector::from_aggs(agg_req, None);

let reader = index.reader()?;
let searcher = reader.searcher();
Expand Down
39 changes: 24 additions & 15 deletions src/aggregation/bucket/term_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ use crate::aggregation::agg_req_with_accessor::{
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::{
validate_bucket_count, SegmentAggregationResultsCollector,
};
use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector};
use crate::error::DataCorruption;
use crate::fastfield::MultiValuedFastFieldReader;
use crate::schema::Type;
Expand Down Expand Up @@ -246,23 +244,23 @@ impl TermBuckets {
&mut self,
term_ids: &[u64],
doc: DocId,
bucket_with_accessor: &BucketAggregationWithAccessor,
sub_aggregation: &AggregationsWithAccessor,
bucket_count: &BucketCount,
blueprint: &Option<SegmentAggregationResultsCollector>,
) -> crate::Result<()> {
for &term_id in term_ids {
let entry = self.entries.entry(term_id as u32).or_insert_with(|| {
bucket_with_accessor
.bucket_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
bucket_count.add_count(1);

TermBucketEntry::from_blueprint(blueprint)
});
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?;
sub_aggregations.collect(doc, &sub_aggregation)?;
}
}
validate_bucket_count(&bucket_with_accessor.bucket_count)?;
bucket_count.validate_bucket_count()?;

Ok(())
}

Expand Down Expand Up @@ -447,25 +445,29 @@ impl SegmentTermCollector {
self.term_buckets.increment_bucket(
&vals1,
docs[0],
bucket_with_accessor,
&bucket_with_accessor.sub_aggregation,
&bucket_with_accessor.bucket_count,
&self.blueprint,
)?;
self.term_buckets.increment_bucket(
&vals2,
docs[1],
bucket_with_accessor,
&bucket_with_accessor.sub_aggregation,
&bucket_with_accessor.bucket_count,
&self.blueprint,
)?;
self.term_buckets.increment_bucket(
&vals3,
docs[2],
bucket_with_accessor,
&bucket_with_accessor.sub_aggregation,
&bucket_with_accessor.bucket_count,
&self.blueprint,
)?;
self.term_buckets.increment_bucket(
&vals4,
docs[3],
bucket_with_accessor,
&bucket_with_accessor.sub_aggregation,
&bucket_with_accessor.bucket_count,
&self.blueprint,
)?;
}
Expand All @@ -475,7 +477,8 @@ impl SegmentTermCollector {
self.term_buckets.increment_bucket(
&vals1,
doc,
bucket_with_accessor,
&bucket_with_accessor.sub_aggregation,
&bucket_with_accessor.bucket_count,
&self.blueprint,
)?;
}
Expand Down Expand Up @@ -1326,9 +1329,15 @@ mod bench {
let mut collector = get_collector_with_buckets(total_terms);
let vals = get_rand_terms(total_terms, num_terms);
let aggregations_with_accessor: AggregationsWithAccessor = Default::default();
let bucket_count: BucketCount = BucketCount {
bucket_count: Default::default(),
max_bucket_count: 1_000_001u32,
};
b.iter(|| {
for &val in &vals {
collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None);
collector
.increment_bucket(&[val], 0, &aggregations_with_accessor, &bucket_count, &None)
.unwrap();
}
})
}
Expand Down
42 changes: 34 additions & 8 deletions src/aggregation/collector.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::rc::Rc;

use super::agg_req::Aggregations;
use super::agg_req_with_accessor::AggregationsWithAccessor;
use super::agg_result::AggregationResults;
Expand All @@ -7,17 +9,25 @@ use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_valida
use crate::collector::{Collector, SegmentCollector};
use crate::SegmentReader;

pub const MAX_BUCKET_COUNT: u32 = 65000;

/// Collector for aggregations.
///
/// The collector collects all aggregations by the underlying aggregation request.
pub struct AggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
}

impl AggregationCollector {
/// Create collector from aggregation request.
pub fn from_aggs(agg: Aggregations) -> Self {
Self { agg }
///
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
}
}

Expand All @@ -28,15 +38,21 @@ impl AggregationCollector {
/// # Purpose
/// AggregationCollector returns `IntermediateAggregationResults` and not the final
/// `AggregationResults`, so that results from differenct indices can be merged and then converted
/// into the final `AggregationResults` via the `into()` method.
/// into the final `AggregationResults` via the `into_final_result()` method.
pub struct DistributedAggregationCollector {
agg: Aggregations,
max_bucket_count: u32,
}

impl DistributedAggregationCollector {
/// Create collector from aggregation request.
pub fn from_aggs(agg: Aggregations) -> Self {
Self { agg }
///
/// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset
pub fn from_aggs(agg: Aggregations, max_bucket_count: Option<u32>) -> Self {
Self {
agg,
max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT),
}
}
}

Expand All @@ -50,7 +66,11 @@ impl Collector for DistributedAggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader)
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
}

fn requires_scoring(&self) -> bool {
Expand All @@ -75,7 +95,11 @@ impl Collector for AggregationCollector {
_segment_local_id: crate::SegmentOrdinal,
reader: &crate::SegmentReader,
) -> crate::Result<Self::Child> {
AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader)
AggregationSegmentCollector::from_agg_req_and_reader(
&self.agg,
reader,
self.max_bucket_count,
)
}

fn requires_scoring(&self) -> bool {
Expand Down Expand Up @@ -117,8 +141,10 @@ impl AggregationSegmentCollector {
pub fn from_agg_req_and_reader(
agg: &Aggregations,
reader: &SegmentReader,
max_bucket_count: u32,
) -> crate::Result<Self> {
let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader)?;
let aggs_with_accessor =
get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?;
let result =
SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?;
Ok(AggregationSegmentCollector {
Expand Down
4 changes: 2 additions & 2 deletions src/aggregation/intermediate_agg_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ pub struct IntermediateAggregationResults {
}

impl IntermediateAggregationResults {
/// Convert and intermediate result and its aggregation request to the final result
/// Convert intermediate result and its aggregation request to the final result.
pub(crate) fn into_final_bucket_result(
self,
req: Aggregations,
) -> crate::Result<AggregationResults> {
self.into_final_bucket_result_internal(&(req.into()))
}

/// Convert and intermediate result and its aggregation request to the final result
/// Convert intermediate result and its aggregation request to the final result.
///
/// Internal function, AggregationsInternal is used instead Aggregations, which is optimized
/// for internal processing, by splitting metric and buckets into seperate groups.
Expand Down
4 changes: 2 additions & 2 deletions src/aggregation/metric/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ mod tests {
.into_iter()
.collect();

let collector = AggregationCollector::from_aggs(agg_req_1);
let collector = AggregationCollector::from_aggs(agg_req_1, None);

let reader = index.reader()?;
let searcher = reader.searcher();
Expand Down Expand Up @@ -299,7 +299,7 @@ mod tests {
.into_iter()
.collect();

let collector = AggregationCollector::from_aggs(agg_req_1);
let collector = AggregationCollector::from_aggs(agg_req_1, None);

let searcher = reader.searcher();
let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap();
Expand Down
Loading

0 comments on commit 44ea731

Please sign in to comment.