
abort aggregation when too many buckets are created
Validation happens at different phases depending on the aggregation type:
Term: during segment collection.
Histogram: at the end, when converting into intermediate buckets (empty buckets are preallocated for the range). Revisit after #1370.
Range: when validating the request.

update CHANGELOG
PSeitz committed May 12, 2022
1 parent 351290e commit cfa4d1b
Showing 12 changed files with 118 additions and 38 deletions.
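
The pattern that runs through the whole diff is a single bucket counter shared by every bucket aggregation of a request: an `Rc<AtomicU32>` that each collector bumps when it materializes buckets, paired with a `validate_bucket_count` check that aborts the aggregation once the total passes the hard limit of 65 000 shown in `segment_agg_result.rs` below. A minimal, self-contained sketch of that guard follows; the `AggResult` alias and plain-`String` error are illustrative stand-ins for tantivy's `crate::Result` and `TantivyError::InvalidArgument`, not the exact code.

```rust
use std::rc::Rc;
use std::sync::atomic::{AtomicU32, Ordering};

/// Hard cap mirroring the 65_000-bucket limit in `validate_bucket_count` below.
const BUCKET_LIMIT: u32 = 65_000;

/// Stand-in for `crate::Result<()>` / `TantivyError::InvalidArgument`.
type AggResult = Result<(), String>;

fn validate_bucket_count(bucket_count: &Rc<AtomicU32>) -> AggResult {
    if bucket_count.load(Ordering::Relaxed) > BUCKET_LIMIT {
        return Err("Aborting aggregation because too many buckets were created".to_string());
    }
    Ok(())
}

/// A bucket collector registers the buckets it just created and validates,
/// so the request aborts early instead of growing without bound.
fn register_buckets(bucket_count: &Rc<AtomicU32>, new_buckets: u32) -> AggResult {
    bucket_count.fetch_add(new_buckets, Ordering::Relaxed);
    validate_bucket_count(bucket_count)
}

fn main() -> AggResult {
    // One counter per aggregation request; every bucket collector holds an Rc clone.
    let bucket_count: Rc<AtomicU32> = Default::default();
    let handle_for_one_collector = Rc::clone(&bucket_count);

    register_buckets(&handle_for_one_collector, 60_000)?; // below the limit: ok
    assert!(register_buckets(&bucket_count, 10_000).is_err()); // 70_000 > 65_000: abort
    Ok(())
}
```

Range and histogram register whole batches at once (`buckets.len()`), while the terms aggregation bumps the counter one entry at a time as new term buckets appear, which is why the validation phase differs per aggregation type as described in the commit message.
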
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ Unreleased
- Add [histogram](https://github.com/quickwit-oss/tantivy/pull/1306) aggregation (@PSeitz)
- Add support for fastfield on text fields (@PSeitz)
- Add terms aggregation (@PSeitz)
- API Change: `SegmentCollector.collect` changed to return a `Result`.

Tantivy 0.17
================================
2 changes: 1 addition & 1 deletion examples/custom_collector.rs
@@ -102,7 +102,7 @@ struct StatsSegmentCollector {
impl SegmentCollector for StatsSegmentCollector {
type Fruit = Option<Stats>;

fn collect(&mut self, doc: u32, _score: Score) -> crate::Result<()> {
fn collect(&mut self, doc: u32, _score: Score) -> tantivy::Result<()> {
let value = self.fast_field_reader.get(doc) as f64;
self.stats.count += 1;
self.stats.sum += value;
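
The example above now implements the changed signature from the CHANGELOG entry: `collect` returns a `Result`, so per-document failures can be propagated with `?` instead of being unwrapped inside the collector. A toy mirror of that shape, using standalone stand-in types rather than tantivy's actual `SegmentCollector` trait:

```rust
type DocId = u32;
type Score = f32;
type CollectResult = Result<(), String>;

/// Toy stand-in for the changed trait: `collect` is now fallible.
trait SegmentCollectorLike {
    fn collect(&mut self, doc: DocId, score: Score) -> CollectResult;
}

struct StatsLikeCollector {
    /// Pretend fast-field column: one value per document.
    values: Vec<f64>,
    sum: f64,
    count: u64,
}

impl SegmentCollectorLike for StatsLikeCollector {
    fn collect(&mut self, doc: DocId, _score: Score) -> CollectResult {
        // Before the change, the only options here were `unwrap()` or a silent default.
        let value = self
            .values
            .get(doc as usize)
            .copied()
            .ok_or_else(|| format!("no fast-field value for doc {doc}"))?;
        self.sum += value;
        self.count += 1;
        Ok(())
    }
}

fn main() {
    let mut collector = StatsLikeCollector { values: vec![1.0, 2.0], sum: 0.0, count: 0 };
    assert!(collector.collect(0, 1.0).is_ok());
    assert!(collector.collect(5, 1.0).is_err()); // an out-of-range doc now surfaces as an error
}
```

In the real crate the error travels further up through `Collector::collect_segment` and `Weight::for_each`, as the later hunks in this commit show.
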
5 changes: 5 additions & 0 deletions src/aggregation/agg_req_with_accessor.rs
@@ -62,13 +62,15 @@ pub struct BucketAggregationWithAccessor {
pub(crate) field_type: Type,
pub(crate) bucket_agg: BucketAggregationType,
pub(crate) sub_aggregation: AggregationsWithAccessor,
pub(crate) bucket_count: Rc<AtomicU32>,
}

impl BucketAggregationWithAccessor {
fn try_from_bucket(
bucket: &BucketAggregationType,
sub_aggregation: &Aggregations,
reader: &SegmentReader,
bucket_count: Rc<AtomicU32>,
) -> crate::Result<BucketAggregationWithAccessor> {
let mut inverted_index = None;
let (accessor, field_type) = match &bucket {
@@ -97,6 +99,7 @@ impl BucketAggregationWithAccessor {
sub_aggregation: get_aggs_with_accessor_and_validate(&sub_aggregation, reader)?,
bucket_agg: bucket.clone(),
inverted_index,
bucket_count,
})
}
}
@@ -137,6 +140,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
aggs: &Aggregations,
reader: &SegmentReader,
) -> crate::Result<AggregationsWithAccessor> {
let bucket_count: Rc<AtomicU32> = Default::default();
let mut metrics = vec![];
let mut buckets = vec![];
for (key, agg) in aggs.iter() {
@@ -147,6 +151,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
&bucket.bucket_agg,
&bucket.sub_aggregation,
reader,
Rc::clone(&bucket_count),
)?,
)),
Aggregation::Metric(metric) => metrics.push((
9 changes: 8 additions & 1 deletion src/aggregation/bucket/histogram/histogram.rs
@@ -13,7 +13,9 @@ use crate::aggregation::f64_from_fastfield_u64;
use crate::aggregation::intermediate_agg_result::{
IntermediateAggregationResults, IntermediateBucketResult, IntermediateHistogramBucketEntry,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::segment_agg_result::{
validate_bucket_count, SegmentAggregationResultsCollector,
};
use crate::fastfield::{DynamicFastFieldReader, FastFieldReader};
use crate::schema::Type;
use crate::{DocId, TantivyError};
@@ -250,6 +252,11 @@ impl SegmentHistogramCollector {
);
};

agg_with_accessor
.bucket_count
.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
validate_bucket_count(&agg_with_accessor.bucket_count)?;

Ok(IntermediateBucketResult::Histogram { buckets })
}

26 changes: 20 additions & 6 deletions src/aggregation/bucket/range.rs
@@ -1,6 +1,9 @@
use std::fmt::Debug;
use std::ops::Range;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;

use fnv::FnvHashMap;
use serde::{Deserialize, Serialize};

use crate::aggregation::agg_req_with_accessor::{
@@ -9,8 +12,10 @@ use crate::aggregation::agg_req_with_accessor::{
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateRangeBucketEntry, IntermediateRangeBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key};
use crate::aggregation::segment_agg_result::{
validate_bucket_count, SegmentAggregationResultsCollector,
};
use crate::aggregation::{f64_from_fastfield_u64, f64_to_fastfield_u64, Key, SerializedKey};
use crate::fastfield::FastFieldReader;
use crate::schema::Type;
use crate::{DocId, TantivyError};
@@ -153,7 +158,7 @@ impl SegmentRangeCollector {
) -> crate::Result<IntermediateBucketResult> {
let field_type = self.field_type;

let buckets = self
let buckets: FnvHashMap<SerializedKey, IntermediateRangeBucketEntry> = self
.buckets
.into_iter()
.map(move |range_bucket| {
@@ -174,12 +179,13 @@ pub(crate) fn from_req_and_validate(
pub(crate) fn from_req_and_validate(
req: &RangeAggregation,
sub_aggregation: &AggregationsWithAccessor,
bucket_count: &Rc<AtomicU32>,
field_type: Type,
) -> crate::Result<Self> {
// The range input on the request is f64.
// We need to convert to u64 ranges, because we read the values as u64.
// The mapping from the conversion is monotonic so ordering is preserved.
let buckets = extend_validate_ranges(&req.ranges, &field_type)?
let buckets: Vec<_> = extend_validate_ranges(&req.ranges, &field_type)?
.iter()
.map(|range| {
let to = if range.end == u64::MAX {
@@ -212,6 +218,9 @@ impl SegmentRangeCollector {
})
.collect::<crate::Result<_>>()?;

bucket_count.fetch_add(buckets.len() as u32, std::sync::atomic::Ordering::Relaxed);
validate_bucket_count(bucket_count)?;

Ok(SegmentRangeCollector {
buckets,
field_type,
@@ -403,8 +412,13 @@ mod tests {
ranges,
};

SegmentRangeCollector::from_req_and_validate(&req, &Default::default(), field_type)
.expect("unexpected error")
SegmentRangeCollector::from_req_and_validate(
&req,
&Default::default(),
&Default::default(),
field_type,
)
.expect("unexpected error")
}

#[test]
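
The comment in `from_req_and_validate` leans on the f64-to-u64 conversion being monotonic, so the requested f64 range bounds keep their relative order once translated into fast-field u64 space. Below is a self-contained sketch of one common order-preserving mapping; it is illustrative only, and tantivy's own `f64_to_fastfield_u64` may differ in details such as NaN handling.

```rust
/// One common order-preserving f64 -> u64 mapping (a sketch, not tantivy's exact helper).
fn f64_to_u64_monotonic(val: f64) -> u64 {
    let bits = val.to_bits();
    if val.is_sign_negative() {
        // Negative floats: flipping every bit reverses their order and places
        // them below all positives.
        !bits
    } else {
        // Positive floats (and +0.0): setting the sign bit keeps their order
        // and places them above all negatives.
        bits ^ (1u64 << 63)
    }
}

fn main() {
    let xs = [-2.5_f64, -1.0, 0.0, 1.0, 2.5];
    let mapped: Vec<u64> = xs.iter().map(|&x| f64_to_u64_monotonic(x)).collect();
    // Ordering survives the conversion, so u64 range boundaries line up with the f64 request.
    assert!(mapped.windows(2).all(|w| w[0] < w[1]));
}
```
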
57 changes: 45 additions & 12 deletions src/aggregation/bucket/term_agg.rs
@@ -11,7 +11,9 @@ use crate::aggregation::agg_req_with_accessor::{
use crate::aggregation::intermediate_agg_result::{
IntermediateBucketResult, IntermediateTermBucketEntry, IntermediateTermBucketResult,
};
use crate::aggregation::segment_agg_result::SegmentAggregationResultsCollector;
use crate::aggregation::segment_agg_result::{
validate_bucket_count, SegmentAggregationResultsCollector,
};
use crate::error::DataCorruption;
use crate::fastfield::MultiValuedFastFieldReader;
use crate::schema::Type;
@@ -244,19 +246,23 @@ impl TermBuckets {
&mut self,
term_ids: &[u64],
doc: DocId,
bucket_with_accessor: &AggregationsWithAccessor,
bucket_with_accessor: &BucketAggregationWithAccessor,
blueprint: &Option<SegmentAggregationResultsCollector>,
) -> crate::Result<()> {
for &term_id in term_ids {
let entry = self
.entries
.entry(term_id as u32)
.or_insert_with(|| TermBucketEntry::from_blueprint(blueprint));
let entry = self.entries.entry(term_id as u32).or_insert_with(|| {
bucket_with_accessor
.bucket_count
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

TermBucketEntry::from_blueprint(blueprint)
});
entry.doc_count += 1;
if let Some(sub_aggregations) = entry.sub_aggregations.as_mut() {
sub_aggregations.collect(doc, bucket_with_accessor)?;
sub_aggregations.collect(doc, &bucket_with_accessor.sub_aggregation)?;
}
}
validate_bucket_count(&bucket_with_accessor.bucket_count)?;
Ok(())
}

@@ -441,25 +447,25 @@ impl SegmentTermCollector {
self.term_buckets.increment_bucket(
&vals1,
docs[0],
&bucket_with_accessor.sub_aggregation,
bucket_with_accessor,
&self.blueprint,
)?;
self.term_buckets.increment_bucket(
&vals2,
docs[1],
&bucket_with_accessor.sub_aggregation,
bucket_with_accessor,
&self.blueprint,
)?;
self.term_buckets.increment_bucket(
&vals3,
docs[2],
&bucket_with_accessor.sub_aggregation,
bucket_with_accessor,
&self.blueprint,
)?;
self.term_buckets.increment_bucket(
&vals4,
docs[3],
&bucket_with_accessor.sub_aggregation,
bucket_with_accessor,
&self.blueprint,
)?;
}
@@ -469,7 +475,7 @@ impl SegmentTermCollector {
self.term_buckets.increment_bucket(
&vals1,
doc,
&bucket_with_accessor.sub_aggregation,
bucket_with_accessor,
&self.blueprint,
)?;
}
@@ -1175,6 +1181,33 @@ mod tests {
Ok(())
}

#[test]
fn terms_aggregation_term_bucket_limit() -> crate::Result<()> {
let terms: Vec<String> = (0..100_000).map(|el| el.to_string()).collect();
let terms_per_segment = vec![terms.iter().map(|el| el.as_str()).collect()];

let index = get_test_index_from_terms(true, &terms_per_segment)?;

let agg_req: Aggregations = vec![(
"my_texts".to_string(),
Aggregation::Bucket(BucketAggregation {
bucket_agg: BucketAggregationType::Terms(TermsAggregation {
field: "string_id".to_string(),
min_doc_count: Some(0),
..Default::default()
}),
sub_aggregation: Default::default(),
}),
)]
.into_iter()
.collect();

let res = exec_request_with_query(agg_req, &index, None);
assert!(res.is_err());

Ok(())
}

#[test]
fn test_json_format() -> crate::Result<()> {
let agg_req: Aggregations = vec![(
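
Unlike range and histogram, the terms aggregation above only bumps the shared counter inside `or_insert_with`, i.e. when a term id materializes a brand-new bucket; repeated hits on an existing term do not count again. A small standalone sketch of that entry-API pattern, with toy types and the same limit and error message as `validate_bucket_count` in `segment_agg_result.rs` below:

```rust
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::atomic::{AtomicU32, Ordering};

const BUCKET_LIMIT: u32 = 65_000;

fn validate_bucket_count(bucket_count: &Rc<AtomicU32>) -> Result<(), String> {
    if bucket_count.load(Ordering::Relaxed) > BUCKET_LIMIT {
        return Err("Aborting aggregation because too many buckets were created".to_string());
    }
    Ok(())
}

/// Count one occurrence of `term_id`; the shared counter is only incremented
/// when the entry is created, i.e. when a new term bucket is materialized.
fn increment_bucket(
    entries: &mut HashMap<u32, u64>,
    term_id: u32,
    bucket_count: &Rc<AtomicU32>,
) -> Result<(), String> {
    let doc_count = entries.entry(term_id).or_insert_with(|| {
        bucket_count.fetch_add(1, Ordering::Relaxed);
        0
    });
    *doc_count += 1;
    validate_bucket_count(bucket_count)
}

fn main() -> Result<(), String> {
    let mut entries = HashMap::new();
    let bucket_count: Rc<AtomicU32> = Default::default();
    increment_bucket(&mut entries, 7, &bucket_count)?; // new bucket: counter becomes 1
    increment_bucket(&mut entries, 7, &bucket_count)?; // same term: counter stays at 1
    assert_eq!(bucket_count.load(Ordering::Relaxed), 1);
    Ok(())
}
```

The new `terms_aggregation_term_bucket_limit` test above exercises exactly this path with 100_000 distinct terms and asserts that the request errors out.
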
15 changes: 10 additions & 5 deletions src/aggregation/mod.rs
@@ -417,7 +417,9 @@ mod tests {
let mut schema_builder = Schema::builder();
let text_fieldtype = crate::schema::TextOptions::default()
.set_indexing_options(
TextFieldIndexing::default().set_index_option(IndexRecordOption::WithFreqs),
TextFieldIndexing::default()
.set_index_option(IndexRecordOption::Basic)
.set_fieldnorms(false),
)
.set_fast()
.set_stored();
@@ -435,7 +437,8 @@
);
let index = Index::create_in_ram(schema_builder.build());
{
let mut index_writer = index.writer_for_tests()?;
// let mut index_writer = index.writer_for_tests()?;
let mut index_writer = index.writer_with_num_threads(1, 30_000_000)?;
for values in segment_and_values {
for (i, term) in values {
let i = *i;
@@ -457,9 +460,11 @@
let segment_ids = index
.searchable_segment_ids()
.expect("Searchable segments failed.");
let mut index_writer = index.writer_for_tests()?;
index_writer.merge(&segment_ids).wait()?;
index_writer.wait_merging_threads()?;
if segment_ids.len() > 1 {
let mut index_writer = index.writer_for_tests()?;
index_writer.merge(&segment_ids).wait()?;
index_writer.wait_merging_threads()?;
}
}

Ok(index)
14 changes: 13 additions & 1 deletion src/aggregation/segment_agg_result.rs
@@ -4,6 +4,8 @@
//! merging.

use std::fmt::Debug;
use std::rc::Rc;
use std::sync::atomic::AtomicU32;

use super::agg_req::MetricAggregation;
use super::agg_req_with_accessor::{
@@ -16,7 +18,7 @@ use super::metric::{
};
use super::VecWithNames;
use crate::aggregation::agg_req::BucketAggregationType;
use crate::DocId;
use crate::{DocId, TantivyError};

pub(crate) const DOC_BLOCK_SIZE: usize = 64;
pub(crate) type DocBlock = [DocId; DOC_BLOCK_SIZE];
@@ -236,6 +238,7 @@ impl SegmentBucketResultCollector {
Ok(Self::Range(SegmentRangeCollector::from_req_and_validate(
range_req,
&req.sub_aggregation,
&req.bucket_count,
req.field_type,
)?))
}
@@ -273,3 +276,12 @@ impl SegmentBucketResultCollector {
Ok(())
}
}

pub(crate) fn validate_bucket_count(bucket_count: &Rc<AtomicU32>) -> crate::Result<()> {
if bucket_count.load(std::sync::atomic::Ordering::Relaxed) > 65000 {
return Err(TantivyError::InvalidArgument(
"Aborting aggregation because too many buckets were created".to_string(),
));
}
Ok(())
}
6 changes: 4 additions & 2 deletions src/collector/mod.rs
@@ -175,12 +175,14 @@ pub trait Collector: Sync + Send {
if let Some(alive_bitset) = reader.alive_bitset() {
weight.for_each(reader, &mut |doc, score| {
if alive_bitset.is_alive(doc) {
segment_collector.collect(doc, score).unwrap(); // TODO
segment_collector.collect(doc, score)?;
}
Ok(())
})?;
} else {
weight.for_each(reader, &mut |doc, score| {
segment_collector.collect(doc, score).unwrap(); // TODO
segment_collector.collect(doc, score)?;
Ok(())
})?;
}
Ok(segment_collector.harvest())
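
The hunk above, together with the `Weight::for_each` changes below, is what lets a failing `collect` abort the per-segment scoring loop: the callback now returns a `Result`, and the driving loop propagates the first error with `?` instead of unwrapping inside the closure. A standalone sketch of that control flow, using a toy scorer rather than tantivy's `Scorer`/`DocSet`:

```rust
type DocId = u32;
type Score = f32;
type CrateResult<T> = Result<T, String>;

/// Toy stand-in for a scorer that yields (doc, score) pairs.
struct ToyScorer {
    docs: Vec<(DocId, Score)>,
}

/// Mirrors the shape of the changed `for_each_scorer`: the callback is fallible
/// and the loop stops at the first error, so e.g. a bucket-limit violation
/// inside `SegmentCollector::collect` aborts the whole segment pass.
fn for_each_scorer(
    scorer: &mut ToyScorer,
    callback: &mut dyn FnMut(DocId, Score) -> CrateResult<()>,
) -> CrateResult<()> {
    for &(doc, score) in &scorer.docs {
        callback(doc, score)?;
    }
    Ok(())
}

fn main() {
    let mut scorer = ToyScorer { docs: vec![(0, 1.0), (1, 0.5), (2, 0.25)] };
    let mut collected = 0u32;
    let res = for_each_scorer(&mut scorer, &mut |_doc, _score| {
        collected += 1;
        if collected > 2 {
            return Err("too many buckets".to_string()); // aborts the loop
        }
        Ok(())
    });
    assert!(res.is_err());
    assert_eq!(collected, 3); // the loop stopped right at the failing doc
}
```
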
6 changes: 3 additions & 3 deletions src/query/boolean_query/boolean_weight.rs
@@ -186,17 +186,17 @@ impl Weight for BooleanWeight {
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>,
) -> crate::Result<()> {
let scorer = self.complex_scorer::<SumWithCoordsCombiner>(reader, 1.0)?;
match scorer {
SpecializedScorer::TermUnion(term_scorers) => {
let mut union_scorer =
Union::<TermScorer, SumWithCoordsCombiner>::from(term_scorers);
for_each_scorer(&mut union_scorer, callback);
for_each_scorer(&mut union_scorer, callback)?;
}
SpecializedScorer::Other(mut scorer) => {
for_each_scorer(scorer.as_mut(), callback);
for_each_scorer(scorer.as_mut(), callback)?;
}
}
Ok(())
4 changes: 2 additions & 2 deletions src/query/term_query/term_weight.rs
@@ -49,10 +49,10 @@ impl Weight for TermWeight {
fn for_each(
&self,
reader: &SegmentReader,
callback: &mut dyn FnMut(DocId, Score),
callback: &mut dyn FnMut(DocId, Score) -> crate::Result<()>,
) -> crate::Result<()> {
let mut scorer = self.specialized_scorer(reader, 1.0)?;
for_each_scorer(&mut scorer, callback);
for_each_scorer(&mut scorer, callback)?;
Ok(())
}
