Skip to content

Commit

Permalink
Add count, min, max, and sum aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
guilload committed Jan 16, 2023
1 parent 6ca9a47 commit f2dad19
Show file tree
Hide file tree
Showing 13 changed files with 522 additions and 71 deletions.
62 changes: 57 additions & 5 deletions src/aggregation/agg_req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ use serde::{Deserialize, Serialize};

pub use super::bucket::RangeAggregation;
use super::bucket::{HistogramAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
};
use super::VecWithNames;

/// The top-level aggregation request structure, which contains [`Aggregation`] and their user
Expand Down Expand Up @@ -237,27 +240,76 @@ impl BucketAggregationType {
/// called multi-value numeric metrics aggregation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum MetricAggregation {
/// Calculates the average.
/// Computes the average.
#[serde(rename = "avg")]
Average(AverageAggregation),
/// Counts the number of extracted values.
#[serde(rename = "value_count")]
Count(CountAggregation),
/// Finds the maximum value.
#[serde(rename = "max")]
Max(MaxAggregation),
/// Finds the minimum value.
#[serde(rename = "min")]
Min(MinAggregation),
/// Calculates stats sum, average, min, max, standard_deviation on a field.
#[serde(rename = "stats")]
Stats(StatsAggregation),
/// Computes the sum.
#[serde(rename = "sum")]
Sum(SumAggregation),
}

impl MetricAggregation {
fn get_fast_field_names(&self, fast_field_names: &mut HashSet<String>) {
match self {
MetricAggregation::Average(avg) => fast_field_names.insert(avg.field.to_string()),
MetricAggregation::Stats(stats) => fast_field_names.insert(stats.field.to_string()),
let fast_field_name = match self {
MetricAggregation::Average(avg) => avg.field_name(),
MetricAggregation::Count(count) => count.field_name(),
MetricAggregation::Max(max) => max.field_name(),
MetricAggregation::Min(min) => min.field_name(),
MetricAggregation::Stats(stats) => stats.field_name(),
MetricAggregation::Sum(sum) => sum.field_name(),
};
fast_field_names.insert(fast_field_name.to_string());
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_metric_aggregations_deser() {
let agg_req_json = r#"{
"price_avg": { "avg": { "field": "price" } },
"price_count": { "value_count": { "field": "price" } },
"price_max": { "max": { "field": "price" } },
"price_min": { "min": { "field": "price" } },
"price_stats": { "stats": { "field": "price" } },
"price_sum": { "sum": { "field": "price" } }
}"#;
let agg_req: Aggregations = serde_json::from_str(agg_req_json).unwrap();

assert!(
matches!(agg_req.get("price_avg").unwrap(), Aggregation::Metric(MetricAggregation::Average(avg)) if avg.field == "price")
);
assert!(
matches!(agg_req.get("price_count").unwrap(), Aggregation::Metric(MetricAggregation::Count(count)) if count.field == "price")
);
assert!(
matches!(agg_req.get("price_max").unwrap(), Aggregation::Metric(MetricAggregation::Max(max)) if max.field == "price")
);
assert!(
matches!(agg_req.get("price_min").unwrap(), Aggregation::Metric(MetricAggregation::Min(min)) if min.field == "price")
);
assert!(
matches!(agg_req.get("price_stats").unwrap(), Aggregation::Metric(MetricAggregation::Stats(stats)) if stats.field == "price")
);
assert!(
matches!(agg_req.get("price_sum").unwrap(), Aggregation::Metric(MetricAggregation::Sum(sum)) if sum.field == "price")
);
}

#[test]
fn serialize_to_json_test() {
let agg_req1: Aggregations = vec![(
Expand Down
11 changes: 9 additions & 2 deletions src/aggregation/agg_req_with_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ use fastfield_codecs::Column;

use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
use super::bucket::{HistogramAggregation, RangeAggregation, TermsAggregation};
use super::metric::{AverageAggregation, StatsAggregation};
use super::metric::{
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
SumAggregation,
};
use super::segment_agg_result::BucketCount;
use super::VecWithNames;
use crate::fastfield::{type_and_cardinality, MultiValuedFastFieldReader};
Expand Down Expand Up @@ -134,7 +137,11 @@ impl MetricAggregationWithAccessor {
) -> crate::Result<MetricAggregationWithAccessor> {
match &metric {
MetricAggregation::Average(AverageAggregation { field: field_name })
| MetricAggregation::Stats(StatsAggregation { field: field_name }) => {
| MetricAggregation::Count(CountAggregation { field: field_name })
| MetricAggregation::Max(MaxAggregation { field: field_name })
| MetricAggregation::Min(MinAggregation { field: field_name })
| MetricAggregation::Stats(StatsAggregation { field: field_name })
| MetricAggregation::Sum(SumAggregation { field: field_name }) => {
let (accessor, field_type) =
get_ff_reader_and_validate(reader, field_name, Cardinality::SingleValue)?;

Expand Down
40 changes: 32 additions & 8 deletions src/aggregation/agg_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl AggregationResults {
} else {
// Validation is be done during request parsing, so we can't reach this state.
Err(TantivyError::InternalError(format!(
"Can't find aggregation {:?} in sub_aggregations",
"Can't find aggregation {:?} in sub-aggregations",
name
)))
}
Expand Down Expand Up @@ -70,27 +70,51 @@ impl AggregationResult {
pub enum MetricResult {
/// Average metric result.
Average(SingleMetricResult),
/// Count metric result.
Count(SingleMetricResult),
/// Max metric result.
Max(SingleMetricResult),
/// Min metric result.
Min(SingleMetricResult),
/// Stats metric result.
Stats(Stats),
/// Sum metric result.
Sum(SingleMetricResult),
}

impl MetricResult {
fn get_value(&self, agg_property: &str) -> crate::Result<Option<f64>> {
match self {
MetricResult::Average(avg) => Ok(avg.value),
MetricResult::Count(count) => Ok(count.value),
MetricResult::Max(max) => Ok(max.value),
MetricResult::Min(min) => Ok(min.value),
MetricResult::Stats(stats) => stats.get_value(agg_property),
MetricResult::Sum(sum) => Ok(sum.value),
}
}
}
impl From<IntermediateMetricResult> for MetricResult {
fn from(metric: IntermediateMetricResult) -> Self {
match metric {
IntermediateMetricResult::Average(avg_data) => {
MetricResult::Average(avg_data.finalize().into())
IntermediateMetricResult::Average(intermediate_avg) => {
MetricResult::Average(intermediate_avg.finalize().into())
}
IntermediateMetricResult::Count(intermediate_count) => {
MetricResult::Count(intermediate_count.finalize().into())
}
IntermediateMetricResult::Max(intermediate_max) => {
MetricResult::Max(intermediate_max.finalize().into())
}
IntermediateMetricResult::Min(intermediate_min) => {
MetricResult::Min(intermediate_min.finalize().into())
}
IntermediateMetricResult::Stats(intermediate_stats) => {
MetricResult::Stats(intermediate_stats.finalize())
}
IntermediateMetricResult::Sum(intermediate_sum) => {
MetricResult::Sum(intermediate_sum.finalize().into())
}
}
}
}
Expand All @@ -100,13 +124,13 @@ impl From<IntermediateMetricResult> for MetricResult {
#[serde(untagged)]
pub enum BucketResult {
/// This is the range entry for a bucket, which contains a key, count, from, to, and optionally
/// sub_aggregations.
/// sub-aggregations.
Range {
/// The range buckets sorted by range.
buckets: BucketEntries<RangeBucketEntry>,
},
/// This is the histogram entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
/// sub-aggregations.
Histogram {
/// The buckets.
///
Expand Down Expand Up @@ -151,7 +175,7 @@ pub enum BucketEntries<T> {
}

/// This is the default entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
/// sub-aggregations.
///
/// # JSON Format
/// ```json
Expand Down Expand Up @@ -201,7 +225,7 @@ impl GetDocCount for BucketEntry {
}

/// This is the range entry for a bucket, which contains a key, count, and optionally
/// sub_aggregations.
/// sub-aggregations.
///
/// # JSON Format
/// ```json
Expand Down Expand Up @@ -237,7 +261,7 @@ pub struct RangeBucketEntry {
/// Number of documents in the bucket.
pub doc_count: u64,
#[serde(flatten)]
/// sub-aggregations in this bucket.
/// Sub-aggregations in this bucket.
pub sub_aggregation: AggregationResults,
/// The from range of the bucket. Equals `f64::MIN` when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down
62 changes: 53 additions & 9 deletions src/aggregation/intermediate_agg_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ use super::bucket::{
cut_off_buckets, get_agg_name_and_property, intermediate_histogram_buckets_to_final_buckets,
GetDocCount, Order, OrderTarget, SegmentHistogramBucketEntry, TermsAggregation,
};
use super::metric::{IntermediateAverage, IntermediateStats};
use super::metric::{
IntermediateAverage, IntermediateCount, IntermediateMax, IntermediateMin, IntermediateStats,
IntermediateSum,
};
use super::segment_agg_result::SegmentMetricResultCollector;
use super::{format_date, Key, SerializedKey, VecWithNames};
use crate::aggregation::agg_result::{AggregationResults, BucketEntries, BucketEntry};
Expand Down Expand Up @@ -204,22 +207,42 @@ pub enum IntermediateAggregationResult {
/// Holds the intermediate data for metric results
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum IntermediateMetricResult {
/// Intermediate average result
/// Intermediate average result.
Average(IntermediateAverage),
/// Intermediate stats result
/// Intermediate count result.
Count(IntermediateCount),
/// Intermediate max result.
Max(IntermediateMax),
/// Intermediate min result.
Min(IntermediateMin),
/// Intermediate stats result.
Stats(IntermediateStats),
/// Intermediate sum result.
Sum(IntermediateSum),
}

impl From<SegmentMetricResultCollector> for IntermediateMetricResult {
fn from(tree: SegmentMetricResultCollector) -> Self {
match tree {
SegmentMetricResultCollector::Stats(collector) => match collector.collecting_for {
super::metric::SegmentStatsType::Average => IntermediateMetricResult::Average(
IntermediateAverage::from_collector(collector),
),
super::metric::SegmentStatsType::Count => {
IntermediateMetricResult::Count(IntermediateCount::from_collector(collector))
}
super::metric::SegmentStatsType::Max => {
IntermediateMetricResult::Max(IntermediateMax::from_collector(collector))
}
super::metric::SegmentStatsType::Min => {
IntermediateMetricResult::Min(IntermediateMin::from_collector(collector))
}
super::metric::SegmentStatsType::Stats => {
IntermediateMetricResult::Stats(collector.stats)
}
super::metric::SegmentStatsType::Avg => IntermediateMetricResult::Average(
IntermediateAverage::from_collector(collector),
),
super::metric::SegmentStatsType::Sum => {
IntermediateMetricResult::Sum(IntermediateSum::from_collector(collector))
}
},
}
}
Expand All @@ -231,25 +254,46 @@ impl IntermediateMetricResult {
MetricAggregation::Average(_) => {
IntermediateMetricResult::Average(IntermediateAverage::default())
}
MetricAggregation::Count(_) => {
IntermediateMetricResult::Count(IntermediateCount::default())
}
MetricAggregation::Max(_) => IntermediateMetricResult::Max(IntermediateMax::default()),
MetricAggregation::Min(_) => IntermediateMetricResult::Min(IntermediateMin::default()),
MetricAggregation::Stats(_) => {
IntermediateMetricResult::Stats(IntermediateStats::default())
}
MetricAggregation::Sum(_) => IntermediateMetricResult::Sum(IntermediateSum::default()),
}
}
fn merge_fruits(&mut self, other: IntermediateMetricResult) {
match (self, other) {
(
IntermediateMetricResult::Average(avg_data_left),
IntermediateMetricResult::Average(avg_data_right),
IntermediateMetricResult::Average(avg_left),
IntermediateMetricResult::Average(avg_right),
) => {
avg_left.merge_fruits(avg_right);
}
(
IntermediateMetricResult::Count(count_left),
IntermediateMetricResult::Count(count_right),
) => {
avg_data_left.merge_fruits(avg_data_right);
count_left.merge_fruits(count_right);
}
(IntermediateMetricResult::Max(max_left), IntermediateMetricResult::Max(max_right)) => {
max_left.merge_fruits(max_right);
}
(IntermediateMetricResult::Min(min_left), IntermediateMetricResult::Min(min_right)) => {
min_left.merge_fruits(min_right);
}
(
IntermediateMetricResult::Stats(stats_left),
IntermediateMetricResult::Stats(stats_right),
) => {
stats_left.merge_fruits(stats_right);
}
(IntermediateMetricResult::Sum(sum_left), IntermediateMetricResult::Sum(sum_right)) => {
sum_left.merge_fruits(sum_right);
}
_ => {
panic!("incompatible fruit types in tree");
}
Expand Down
Loading

0 comments on commit f2dad19

Please sign in to comment.