diff --git a/examples/aggregation.rs b/examples/aggregation.rs index 82cc0fccd3..ae11dc5a5a 100644 --- a/examples/aggregation.rs +++ b/examples/aggregation.rs @@ -117,7 +117,7 @@ fn main() -> tantivy::Result<()> { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/agg_req_with_accessor.rs b/src/aggregation/agg_req_with_accessor.rs index 10597b3de0..491faf2137 100644 --- a/src/aggregation/agg_req_with_accessor.rs +++ b/src/aggregation/agg_req_with_accessor.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation}; use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation}; use super::metric::{AverageAggregation, StatsAggregation}; +use super::segment_agg_result::BucketCount; use super::VecWithNames; use crate::fastfield::{ type_and_cardinality, DynamicFastFieldReader, FastType, MultiValuedFastFieldReader, @@ -62,7 +63,7 @@ pub struct BucketAggregationWithAccessor { pub(crate) field_type: Type, pub(crate) bucket_agg: BucketAggregationType, pub(crate) sub_aggregation: AggregationsWithAccessor, - pub(crate) bucket_count: Rc, + pub(crate) bucket_count: BucketCount, } impl BucketAggregationWithAccessor { @@ -71,6 +72,7 @@ impl BucketAggregationWithAccessor { sub_aggregation: &Aggregations, reader: &SegmentReader, bucket_count: Rc, + max_bucket_count: u32, ) -> crate::Result { let mut inverted_index = None; let (accessor, field_type) = match &bucket { @@ -96,10 +98,18 @@ impl BucketAggregationWithAccessor { Ok(BucketAggregationWithAccessor { accessor, field_type, - sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?, + sub_aggregation: get_aggs_with_accessor_and_validate( + &sub_aggregation, + reader, + 
bucket_count.clone(), + max_bucket_count, + )?, bucket_agg: bucket.clone(), inverted_index, - bucket_count, + bucket_count: BucketCount { + bucket_count, + max_bucket_count, + }, }) } } @@ -139,8 +149,9 @@ impl MetricAggregationWithAccessor { pub(crate) fn get_aggs_with_accessor_and_validate( aggs: &Aggregations, reader: &SegmentReader, + bucket_count: Rc, + max_bucket_count: u32, ) -> crate::Result { - let bucket_count: Rc = Default::default(); let mut metrics = vec![]; let mut buckets = vec![]; for (key, agg) in aggs.iter() { @@ -152,6 +163,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate( &bucket.sub_aggregation, reader, Rc::clone(&bucket_count), + max_bucket_count, )?, )), Aggregation::Metric(metric) => metrics.push(( diff --git a/src/aggregation/bucket/histogram/histogram.rs b/src/aggregation/bucket/histogram/histogram.rs index 69111c71fc..70acf0f117 100644 --- a/src/aggregation/bucket/histogram/histogram.rs +++ b/src/aggregation/bucket/histogram/histogram.rs @@ -13,9 +13,7 @@ use crate::aggregation::f64_from_fastfield_u64; use crate::aggregation::intermediate_agg_result::{ IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry, }; -use crate::aggregation::segment_agg_result::{ - validate_bucket_count, SegmentAggregationResultsCollector, -}; +use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector; use crate::fastfield::{DynamicFastFieldReader, FastFieldReader}; use crate::schema::Type; use crate::{DocId, TantivyError}; @@ -254,8 +252,8 @@ impl SegmentHistogramCollector { agg_with_accessor .bucket_count - .fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed); - validate_bucket_count(&agg_with_accessor.bucket_count)?; + .add_count(buckets.len() as u32); + agg_with_accessor.bucket_count.validate_bucket_count()?; Ok(IntermediateBucketResult::Histogram { buckets }) } diff --git a/src/aggregation/bucket/range.rs b/src/aggregation/bucket/range.rs index 590165c158..7faa500e7c 100644 --- 
a/src/aggregation/bucket/range.rs +++ b/src/aggregation/bucket/range.rs @@ -1,7 +1,5 @@ use std::fmt::Debug; use std::ops::Range; -use std::rc::Rc; -use std::sync::atomic::AtomicU32; use fnv::FnvHashMap; use serde::{Deserialize, Serialize}; @@ -12,9 +10,7 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult, }; -use crate::aggregation::segment_agg_result::{ - validate_bucket_count, SegmentAggregationResultsCollector, -}; +use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey}; use crate::fastfield::FastFieldReader; use crate::schema::Type; @@ -179,7 +175,7 @@ impl SegmentRangeCollector { pub(crate) fn from_req_and_validate( req: &RangeAggregation, sub_aggregation: &AggregationsWithAccessor, - bucket_count: &Rc, + bucket_count: &BucketCount, field_type: Type, ) -> crate::Result { // The range input on the request is f64. 
@@ -218,8 +214,8 @@ impl SegmentRangeCollector { }) .collect::>()?; - bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed); - validate_bucket_count(bucket_count)?; + bucket_count.add_count(buckets.len() as u32); + bucket_count.validate_bucket_count()?; Ok(SegmentRangeCollector { buckets, @@ -438,7 +434,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let reader = index.reader()?; let searcher = reader.searcher(); diff --git a/src/aggregation/bucket/term_agg.rs b/src/aggregation/bucket/term_agg.rs index 312522017e..52e120cc96 100644 --- a/src/aggregation/bucket/term_agg.rs +++ b/src/aggregation/bucket/term_agg.rs @@ -11,9 +11,7 @@ use crate::aggregation::agg_req_with_accessor::{ use crate::aggregation::intermediate_agg_result::{ IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult, }; -use crate::aggregation::segment_agg_result::{ - validate_bucket_count, SegmentAggregationResultsCollector, -}; +use crate::aggregation::segment_agg_result::{BucketCount, SegmentAggregationResultsCollector}; use crate::error::DataCorruption; use crate::fastfield::MultiValuedFastFieldReader; use crate::schema::Type; @@ -246,23 +244,23 @@ impl TermBuckets { &mut self, term_ids: &[u64], doc: DocId, - bucket_with_accessor: &BucketAggregationWithAccessor, + sub_aggregation: &AggregationsWithAccessor, + bucket_count: &BucketCount, blueprint: &Option, ) -> crate::Result<()> { for &term_id in term_ids { let entry = self.entries.entry(term_id as u32).or_insert_with(|| { - bucket_with_accessor - .bucket_count - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + bucket_count.add_count(1); TermBucketEntry::from_blueprint(blueprint) }); entry.doc_count += 1; if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() { - sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?; + 
sub_aggregations.collect(doc, &sub_aggregation)?; } } - validate_bucket_count(&bucket_with_accessor.bucket_count)?; + bucket_count.validate_bucket_count()?; + Ok(()) } @@ -447,25 +445,29 @@ impl SegmentTermCollector { self.term_buckets.increment_bucket( &vals1, docs[0], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals2, docs[1], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals3, docs[2], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; self.term_buckets.increment_bucket( &vals4, docs[3], - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; } @@ -475,7 +477,8 @@ impl SegmentTermCollector { self.term_buckets.increment_bucket( &vals1, doc, - bucket_with_accessor, + &bucket_with_accessor.sub_aggregation, + &bucket_with_accessor.bucket_count, &self.blueprint, )?; } @@ -1326,9 +1329,15 @@ mod bench { let mut collector = get_collector_with_buckets(total_terms); let vals = get_rand_terms(total_terms, num_terms); let aggregations_with_accessor: AggregationsWithAccessor = Default::default(); + let bucket_count: BucketCount = BucketCount { + bucket_count: Default::default(), + max_bucket_count: 1_000_001u32, + }; b.iter(|| { for &val in &vals { - collector.increment_bucket(&[val], 0, &aggregations_with_accessor, &None); + collector + .increment_bucket(&[val], 0, &aggregations_with_accessor, &bucket_count, &None) + .unwrap(); } }) } diff --git a/src/aggregation/collector.rs b/src/aggregation/collector.rs index 69913931cd..cf2848f383 100644 --- a/src/aggregation/collector.rs +++ b/src/aggregation/collector.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use super::agg_req::Aggregations; use 
super::agg_req_with_accessor::AggregationsWithAccessor; use super::agg_result::AggregationResults; @@ -7,17 +9,25 @@ use crate::aggregation::agg_req_with_accessor::get_aggs_with_accessor_and_valida use crate::collector::{Collector, SegmentCollector}; use crate::SegmentReader; +pub const MAX_BUCKET_COUNT: u32 = 65000; + /// Collector for aggregations. /// /// The collector collects all aggregations by the underlying aggregation request. pub struct AggregationCollector { agg: Aggregations, + max_bucket_count: u32, } impl AggregationCollector { /// Create collector from aggregation request. - pub fn from_aggs(agg: Aggregations) -> Self { - Self { agg } + /// + /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset + pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { + Self { + agg, + max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), + } + } } @@ -28,15 +38,21 @@ impl AggregationCollector { /// # Purpose /// AggregationCollector returns `IntermediateAggregationResults` and not the final /// `AggregationResults`, so that results from differenct indices can be merged and then converted -/// into the final `AggregationResults` via the `into()` method. +/// into the final `AggregationResults` via the `into_final_bucket_result()` method. pub struct DistributedAggregationCollector { agg: Aggregations, + max_bucket_count: u32, } impl DistributedAggregationCollector { /// Create collector from aggregation request. 
- pub fn from_aggs(agg: Aggregations) -> Self { - Self { agg } + /// + /// max_bucket_count will default to `MAX_BUCKET_COUNT` (65000) when unset + pub fn from_aggs(agg: Aggregations, max_bucket_count: Option) -> Self { + Self { + agg, + max_bucket_count: max_bucket_count.unwrap_or(MAX_BUCKET_COUNT), + } } } @@ -50,7 +66,11 @@ impl Collector for DistributedAggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + self.max_bucket_count, + ) } fn requires_scoring(&self) -> bool { @@ -75,7 +95,11 @@ impl Collector for AggregationCollector { _segment_local_id: crate::SegmentOrdinal, reader: &crate::SegmentReader, ) -> crate::Result { - AggregationSegmentCollector::from_agg_req_and_reader(&self.agg, reader) + AggregationSegmentCollector::from_agg_req_and_reader( + &self.agg, + reader, + self.max_bucket_count, + ) } fn requires_scoring(&self) -> bool { @@ -117,8 +141,10 @@ impl AggregationSegmentCollector { pub fn from_agg_req_and_reader( agg: &Aggregations, reader: &SegmentReader, + max_bucket_count: u32, ) -> crate::Result { - let aggs_with_accessor = get_aggs_with_accessor_and_validate(agg, reader)?; + let aggs_with_accessor = + get_aggs_with_accessor_and_validate(agg, reader, Rc::default(), max_bucket_count)?; let result = SegmentAggregationResultsCollector::from_req_and_validate(&aggs_with_accessor)?; Ok(AggregationSegmentCollector { diff --git a/src/aggregation/intermediate_agg_result.rs b/src/aggregation/intermediate_agg_result.rs index cb2f9f416c..20eef59c07 100644 --- a/src/aggregation/intermediate_agg_result.rs +++ b/src/aggregation/intermediate_agg_result.rs @@ -35,7 +35,7 @@ pub struct IntermediateAggregationResults { } impl IntermediateAggregationResults { - /// Convert and intermediate result and its aggregation request to the final result + /// 
Convert intermediate result and its aggregation request to the final result. pub(crate) fn into_final_bucket_result( self, req: Aggregations, @@ -43,7 +43,7 @@ impl IntermediateAggregationResults { self.into_final_bucket_result_internal(&(req.into())) } - /// Convert and intermediate result and its aggregation request to the final result + /// Convert intermediate result and its aggregation request to the final result. /// /// Internal function, AggregationsInternal is used instead Aggregations, which is optimized /// for internal processing, by splitting metric and buckets into seperate groups. diff --git a/src/aggregation/metric/stats.rs b/src/aggregation/metric/stats.rs index 0498ffbe80..2f704b17d0 100644 --- a/src/aggregation/metric/stats.rs +++ b/src/aggregation/metric/stats.rs @@ -222,7 +222,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let reader = index.reader()?; let searcher = reader.searcher(); @@ -299,7 +299,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); diff --git a/src/aggregation/mod.rs b/src/aggregation/mod.rs index 7fe2d82847..7f6f8378ce 100644 --- a/src/aggregation/mod.rs +++ b/src/aggregation/mod.rs @@ -28,7 +28,7 @@ //! //! ```verbatim //! let agg_req: Aggregations = serde_json::from_str(json_request_string).unwrap(); -//! let collector = AggregationCollector::from_aggs(agg_req); +//! let collector = AggregationCollector::from_aggs(agg_req, None); //! let searcher = reader.searcher(); //! let agg_res = searcher.search(&term_query, &collector).unwrap_err(); //! let json_response_string: String = &serde_json::to_string(&agg_res)?; @@ -68,7 +68,7 @@ //! .into_iter() //! .collect(); //! 
-//! let collector = AggregationCollector::from_aggs(agg_req); +//! let collector = AggregationCollector::from_aggs(agg_req, None); //! //! let searcher = reader.searcher(); //! let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap(); @@ -358,7 +358,7 @@ mod tests { index: &Index, query: Option<(&str, &str)>, ) -> crate::Result { - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let reader = index.reader()?; let searcher = reader.searcher(); @@ -547,7 +547,7 @@ mod tests { .unwrap(); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); + let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); let intermediate_agg_result = searcher.search(&AllQuery, &collector).unwrap(); @@ -555,7 +555,7 @@ mod tests { .into_final_bucket_result(agg_req) .unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); searcher.search(&AllQuery, &collector).unwrap() @@ -792,7 +792,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = searcher.search(&term_query, &collector).unwrap(); @@ -982,7 +982,7 @@ mod tests { assert_eq!(field_names, vec!["text".to_string()].into_iter().collect()); let agg_res: AggregationResults = if use_distributed_collector { - let collector = DistributedAggregationCollector::from_aggs(agg_req.clone()); + let collector = DistributedAggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); let res = searcher.search(&term_query, &collector).unwrap(); @@ -991,7 +991,7 @@ mod tests { 
serde_json::from_str(&serde_json::to_string(&res).unwrap()).unwrap(); res.into_final_bucket_result(agg_req.clone()).unwrap() } else { - let collector = AggregationCollector::from_aggs(agg_req.clone()); + let collector = AggregationCollector::from_aggs(agg_req.clone(), None); let searcher = reader.searcher(); searcher.search(&term_query, &collector).unwrap() @@ -1049,7 +1049,7 @@ mod tests { ); // Test empty result set - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); searcher.search(&query_with_no_hits, &collector).unwrap(); @@ -1114,7 +1114,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); @@ -1227,7 +1227,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1258,7 +1258,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1289,7 +1289,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1328,7 +1328,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1357,7 +1357,7 @@ mod tests { .into_iter() .collect(); - let collector = 
AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1386,7 +1386,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req); + let collector = AggregationCollector::from_aggs(agg_req, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1422,7 +1422,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1457,7 +1457,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1496,7 +1496,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1526,7 +1526,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = @@ -1582,7 +1582,7 @@ mod tests { .into_iter() .collect(); - let collector = AggregationCollector::from_aggs(agg_req_1); + let collector = AggregationCollector::from_aggs(agg_req_1, None); let searcher = reader.searcher(); let agg_res: AggregationResults = diff --git a/src/aggregation/segment_agg_result.rs b/src/aggregation/segment_agg_result.rs index 57c545f7d6..fe07400897 100644 --- a/src/aggregation/segment_agg_result.rs +++ b/src/aggregation/segment_agg_result.rs @@ -12,6 +12,7 @@ use super::agg_req_with_accessor::{ 
AggregationsWithAccessor, BucketAggregationWithAccessor, MetricAggregationWithAccessor, }; use super::bucket::{SegmentHistogramCollector, SegmentRangeCollector, SegmentTermCollector}; +use super::collector::MAX_BUCKET_COUNT; use super::intermediate_agg_result::{IntermediateAggregationResults, IntermediateBucketResult}; use super::metric::{ AverageAggregation, SegmentAverageCollector, SegmentStatsCollector, StatsAggregation, @@ -277,11 +278,36 @@ impl SegmentBucketResultCollector { } } -pub(crate) fn validate_bucket_count(bucket_count: &Rc) -> crate::Result<()> { - if bucket_count.load(std::sync::atomic::Ordering::Relaxed) > 65000 { - return Err(TantivyError::InvalidArgument( - "Aborting aggregation because too many buckets were created".to_string(), - )); +#[derive(Clone)] +pub(crate) struct BucketCount { + /// The counter which is shared between the aggregations for one request. + pub(crate) bucket_count: Rc, + pub(crate) max_bucket_count: u32, +} + +impl Default for BucketCount { + fn default() -> Self { + Self { + bucket_count: Default::default(), + max_bucket_count: MAX_BUCKET_COUNT, + } + } +} + +impl BucketCount { + pub(crate) fn validate_bucket_count(&self) -> crate::Result<()> { + if self.get_count() > self.max_bucket_count { + return Err(TantivyError::InvalidArgument( + "Aborting aggregation because too many buckets were created".to_string(), + )); + } + Ok(()) + } + pub(crate) fn add_count(&self, count: u32) { + self.bucket_count + .fetch_add(count, std::sync::atomic::Ordering::Relaxed); + } + pub(crate) fn get_count(&self) -> u32 { + self.bucket_count.load(std::sync::atomic::Ordering::Relaxed) } - Ok(()) }