diff --git a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java index 9e3f9425d352b..2ea40c396f782 100644 --- a/server/src/main/java/org/opensearch/index/query/QueryShardContext.java +++ b/server/src/main/java/org/opensearch/index/query/QueryShardContext.java @@ -79,6 +79,8 @@ import org.opensearch.script.ScriptFactory; import org.opensearch.script.ScriptService; import org.opensearch.search.aggregations.AggregatorFactory; +import org.opensearch.search.aggregations.metrics.MaxAggregatorFactory; +import org.opensearch.search.aggregations.metrics.MinAggregatorFactory; import org.opensearch.search.aggregations.metrics.SumAggregatorFactory; import org.opensearch.search.aggregations.support.AggregationUsageService; import org.opensearch.search.aggregations.support.ValuesSourceRegistry; @@ -585,7 +587,7 @@ private Map>> getStarTreePredicates(QueryBuilder qu } public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType compositeIndexFieldInfo, AggregatorFactory aggregatorFactory) { - String field = null; + String field; Map> supportedMetrics = compositeIndexFieldInfo.getMetrics() .stream() .collect(Collectors.toMap(Metric::getField, Metric::getMetrics)); @@ -595,14 +597,16 @@ public boolean validateStarTreeMetricSuport(CompositeDataCubeFieldType composite return false; } - // TODO: increment supported aggregation type if (aggregatorFactory instanceof SumAggregatorFactory) { field = ((SumAggregatorFactory) aggregatorFactory).getField(); - if (supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM)) { - return true; - } + return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.SUM); + } else if (aggregatorFactory instanceof MaxAggregatorFactory) { + field = ((MaxAggregatorFactory) aggregatorFactory).getField(); + return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.MAX); + } else if (aggregatorFactory instanceof MinAggregatorFactory) { + field = ((MinAggregatorFactory) aggregatorFactory).getField(); + return supportedMetrics.containsKey(field) && supportedMetrics.get(field).contains(MetricStat.MIN); } - return false; } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java index 8108b8a726856..bb4301c34d089 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java @@ -34,12 +34,16 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; +import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.util.Bits; import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.DoubleArray; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues; +import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils; import org.opensearch.index.fielddata.NumericDoubleValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; import org.opensearch.search.DocValueFormat; @@ -51,6 +55,8 @@ import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.startree.OriginalOrStarTreeQuery; +import org.opensearch.search.startree.StarTreeQuery; import java.io.IOException; import java.util.Arrays; @@ -120,6 +126,16 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc throw new CollectionTerminatedException(); } } + + if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) { + StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery(); + return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree()); + } + return getDefaultLeafCollector(ctx, sub); + } + + private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx); final NumericDoubleValues values = MultiValueMode.MAX.select(allValues); @@ -143,6 +159,34 @@ public void collect(int doc, long bucket) throws IOException { }; } + private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) + throws IOException { + StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree); + String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName(); + String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, "max"); + SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocValuesIteratorMap().get(metricName); + + final BigArrays bigArrays = context.bigArrays(); + final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx); + return new LeafBucketCollectorBase(sub, allValues) { + + @Override + public void collect(int doc, long bucket) throws IOException { + if (bucket >= maxes.size()) { + long from = maxes.size(); + maxes = bigArrays.grow(maxes, bucket + 1); + maxes.fill(from, maxes.size(), Double.NEGATIVE_INFINITY); + } + if (values.advanceExact(doc)) { + final double value = Double.longBitsToDouble(values.nextValue()); + double max = maxes.get(bucket); + max = Math.max(max, value); + maxes.set(bucket, max); + } + } + }; + } + @Override public double metric(long owningBucketOrd) { if (valuesSource == null || owningBucketOrd >= maxes.size()) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java index 4fe936c8b7797..0d537745126d3 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregatorFactory.java @@ -52,7 +52,7 @@ * * @opensearch.internal */ -class MaxAggregatorFactory extends ValuesSourceAggregatorFactory { +public class MaxAggregatorFactory extends ValuesSourceAggregatorFactory { static void registerAggregators(ValuesSourceRegistry.Builder builder) { builder.register( diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java index 946057e42ac88..8303e5e025b61 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java @@ -34,12 +34,16 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PointValues; +import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.util.Bits; import org.opensearch.common.lease.Releasables; import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.DoubleArray; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues; +import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils; import org.opensearch.index.fielddata.NumericDoubleValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; import org.opensearch.search.DocValueFormat; @@ -51,6 +55,8 @@ import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.aggregations.support.ValuesSourceConfig; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.startree.OriginalOrStarTreeQuery; +import org.opensearch.search.startree.StarTreeQuery; import java.io.IOException; import java.util.Map; @@ -119,6 +125,16 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc throw new CollectionTerminatedException(); } } + + if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) { + StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery(); + return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree()); + } + return getDefaultLeafCollector(ctx, sub); + } + + private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) + throws IOException { final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx); final NumericDoubleValues values = MultiValueMode.MIN.select(allValues); @@ -138,10 +154,38 @@ public void collect(int doc, long bucket) throws IOException { mins.set(bucket, min); } } + }; + } + private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) + throws IOException { + StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree); + String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName(); + String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, "min"); + SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocValuesIteratorMap().get(metricName); + + final BigArrays bigArrays = context.bigArrays(); + final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx); + return new LeafBucketCollectorBase(sub, allValues) { + + @Override + public void collect(int doc, long bucket) throws IOException { + if (bucket >= mins.size()) { + long from = mins.size(); + mins = bigArrays.grow(mins, bucket + 1); + mins.fill(from, mins.size(), Double.POSITIVE_INFINITY); + } + if (values.advanceExact(doc)) { + final double value = Double.longBitsToDouble(values.nextValue()); + double min = mins.get(bucket); + min = Math.min(min, value); + mins.set(bucket, min); + } + } }; } + @Override public double metric(long owningBucketOrd) { if (valuesSource == null || owningBucketOrd >= mins.size()) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java index 58fbe5edefd12..2b7c8a4cc8c9c 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregatorFactory.java @@ -52,7 +52,7 @@ * * @opensearch.internal */ -class MinAggregatorFactory extends ValuesSourceAggregatorFactory { +public class MinAggregatorFactory extends ValuesSourceAggregatorFactory { static void registerAggregators(ValuesSourceRegistry.Builder builder) { builder.register( diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java index 3415a10679b3b..b4af1d4768d11 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java @@ -92,6 +92,10 @@ public ScoreMode scoreMode() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + if (valuesSource == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } + if (context.query() instanceof OriginalOrStarTreeQuery && ((OriginalOrStarTreeQuery) context.query()).isStarTreeUsed()) { StarTreeQuery starTreeQuery = ((OriginalOrStarTreeQuery) context.query()).getStarTreeQuery(); return getStarTreeLeafCollector(ctx, sub, starTreeQuery.getStarTree()); @@ -100,9 +104,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc } private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - if (valuesSource == null) { - return LeafBucketCollector.NO_OP_COLLECTOR; - } final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); final CompensatedSum kahanSummation = new CompensatedSum(0, 0); @@ -134,15 +135,14 @@ public void collect(int doc, long bucket) throws IOException { private LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException { - final BigArrays bigArrays = context.bigArrays(); - final CompensatedSum kahanSummation = new CompensatedSum(0, 0); - StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree); String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName(); String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, "sum"); - SortedNumericDocValues values = (SortedNumericDocValues) starTreeValues.getMetricDocValuesIteratorMap().get(metricName); + final BigArrays bigArrays = context.bigArrays(); + final CompensatedSum kahanSummation = new CompensatedSum(0, 0); + return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long bucket) throws IOException { diff --git a/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java b/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java index bc8ef51bfb537..c79eb0f8709ae 100644 --- a/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java +++ b/server/src/main/java/org/opensearch/search/startree/OriginalOrStarTreeQuery.java @@ -45,12 +45,12 @@ public void visit(QueryVisitor queryVisitor) { @Override public boolean equals(Object o) { - return false; + return true; } @Override public int hashCode() { - return 0; + return originalQuery.hashCode(); } @Override diff --git a/server/src/test/java/org/opensearch/index/codec/composite99/datacube/startree/StarTreeDocValuesFormatTests.java b/server/src/test/java/org/opensearch/index/codec/composite99/datacube/startree/StarTreeDocValuesFormatTests.java index b0667f6d44fe4..f06e369a72a17 100644 --- a/server/src/test/java/org/opensearch/index/codec/composite99/datacube/startree/StarTreeDocValuesFormatTests.java +++ b/server/src/test/java/org/opensearch/index/codec/composite99/datacube/startree/StarTreeDocValuesFormatTests.java @@ -72,8 +72,9 @@ */ @LuceneTestCase.SuppressSysoutChecks(bugUrl = "we log a lot on purpose") public class StarTreeDocValuesFormatTests extends BaseDocValuesFormatTestCase { - MapperService mapperService = null; + StarTreeFieldConfiguration.StarTreeBuildMode buildMode; + MapperService mapperService; public StarTreeDocValuesFormatTests(StarTreeFieldConfiguration.StarTreeBuildMode buildMode) { this.buildMode = buildMode; @@ -105,13 +106,14 @@ public void teardown() throws IOException { @Override protected Codec getCodec() { final Logger testLogger = LogManager.getLogger(StarTreeDocValuesFormatTests.class); - + Codec codec; try { - createMapperService(getExpandedMapping()); + mapperService = createMapperService(getExpandedMapping()); + codec = new Composite99Codec(Lucene99Codec.Mode.BEST_SPEED, mapperService, testLogger); } catch (IOException e) { throw new RuntimeException(e); } - Codec codec = new Composite99Codec(Lucene99Codec.Mode.BEST_SPEED, mapperService, testLogger); + return codec; } @@ -195,7 +197,7 @@ public void testStarTreeDocValues() throws IOException { directory.close(); } - private XContentBuilder getExpandedMapping() throws IOException { + public static XContentBuilder getExpandedMapping() throws IOException { return topMapping(b -> { b.startObject("composite"); b.startObject("startree"); @@ -215,6 +217,8 @@ private XContentBuilder getExpandedMapping() throws IOException { b.field("name", "field"); b.startArray("stats"); b.value("sum"); + b.value("max"); + b.value("min"); b.value("count"); b.endArray(); b.endObject(); @@ -236,13 +240,14 @@ private XContentBuilder getExpandedMapping() throws IOException { }); } - private XContentBuilder topMapping(CheckedConsumer buildFields) throws IOException { + private static XContentBuilder topMapping(CheckedConsumer buildFields) throws IOException { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("_doc"); buildFields.accept(builder); return builder.endObject().endObject(); } - private void createMapperService(XContentBuilder builder) throws IOException { + public static MapperService createMapperService(XContentBuilder builder) throws IOException { + MapperService mapperService = null; IndexMetadata indexMetadata = IndexMetadata.builder("test") .settings( Settings.builder() @@ -261,5 +266,6 @@ private void createMapperService(XContentBuilder builder) throws IOException { "test" ); mapperService.merge(indexMetadata, MapperService.MergeReason.INDEX_TEMPLATE); + return mapperService; } } diff --git a/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java new file mode 100644 index 0000000000000..fdbf5e4de4297 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/aggregations/startree/MetricAggregatorTests.java @@ -0,0 +1,174 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.aggregations.startree; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SegmentReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.common.lucene.Lucene; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.index.codec.composite.CompositeIndexFieldInfo; +import org.opensearch.index.codec.composite.CompositeIndexReader; +import org.opensearch.index.codec.composite.composite99.Composite99Codec; +import org.opensearch.index.codec.composite99.datacube.startree.StarTreeDocValuesFormatTests; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregatorTestCase; +import org.junit.After; +import org.junit.Before; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder; +import org.opensearch.search.aggregations.metrics.InternalAvg; +import org.opensearch.search.aggregations.metrics.InternalMax; +import org.opensearch.search.aggregations.metrics.InternalMin; +import org.opensearch.search.aggregations.metrics.InternalSum; +import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder; +import org.opensearch.search.aggregations.metrics.MinAggregationBuilder; +import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; +import org.opensearch.search.startree.OriginalOrStarTreeQuery; +import org.opensearch.search.startree.StarTreeQuery; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.function.Predicate; + +import static org.opensearch.search.aggregations.AggregationBuilders.avg; +import static org.opensearch.search.aggregations.AggregationBuilders.max; +import static org.opensearch.search.aggregations.AggregationBuilders.min; +import static org.opensearch.search.aggregations.AggregationBuilders.sum; +import static org.opensearch.test.InternalAggregationTestCase.DEFAULT_MAX_BUCKETS; + +public class MetricAggregatorTests extends AggregatorTestCase { + + private static final String FIELD_NAME = "field"; + private static final NumberFieldMapper.NumberType DEFAULT_FIELD_TYPE = NumberFieldMapper.NumberType.LONG; + private static final MappedFieldType DEFAULT_MAPPED_FIELD = new NumberFieldMapper.NumberFieldType(FIELD_NAME, DEFAULT_FIELD_TYPE); + + @Before + public void setup() { + FeatureFlags.initializeFeatureFlags(Settings.builder().put(FeatureFlags.STAR_TREE_INDEX, true).build()); + } + + @After + public void teardown() throws IOException { + FeatureFlags.initializeFeatureFlags(Settings.EMPTY); + } + + protected Codec getCodec() { + final Logger testLogger = LogManager.getLogger(MetricAggregatorTests.class); + + MapperService mapperService; + try { + mapperService = StarTreeDocValuesFormatTests.createMapperService(StarTreeDocValuesFormatTests.getExpandedMapping()); + } catch (IOException e) { + throw new RuntimeException(e); + } + return new Composite99Codec(Lucene99Codec.Mode.BEST_SPEED, mapperService, testLogger); + } + + public void testStarTreeDocValues() throws IOException { + Directory directory = newDirectory(); + IndexWriterConfig conf = newIndexWriterConfig(null); + conf.setCodec(getCodec()); + conf.setMergePolicy(newLogMergePolicy()); + RandomIndexWriter iw = new RandomIndexWriter(random(), directory, conf); + Document doc = new Document(); + doc.add(new SortedNumericDocValuesField("sndv", 1)); + doc.add(new SortedNumericDocValuesField("dv", 1)); + doc.add(new SortedNumericDocValuesField("field", 1)); + iw.addDocument(doc); + doc = new Document(); + doc.add(new SortedNumericDocValuesField("sndv", 1)); + doc.add(new SortedNumericDocValuesField("dv", 2)); + doc.add(new SortedNumericDocValuesField("field", 3)); + iw.addDocument(doc); + doc = new Document(); + iw.forceMerge(1); + doc.add(new SortedNumericDocValuesField("sndv", 2)); + doc.add(new SortedNumericDocValuesField("dv", 4)); + doc.add(new SortedNumericDocValuesField("field", 6)); + iw.addDocument(doc); + doc = new Document(); + doc.add(new SortedNumericDocValuesField("sndv", 3)); + doc.add(new SortedNumericDocValuesField("dv", 6)); + doc.add(new SortedNumericDocValuesField("field", 9)); + iw.addDocument(doc); + iw.forceMerge(1); + iw.close(); + + DirectoryReader ir = DirectoryReader.open(directory); + initValuesSourceRegistry(); + assertEquals(ir.leaves().size(), 1); + LeafReaderContext context = ir.leaves().get(0); + + + SegmentReader reader = Lucene.segmentReader(context.reader()); + IndexSearcher indexSearcher = newSearcher(reader, true, true); + CompositeIndexReader starTreeDocValuesReader = (CompositeIndexReader) reader.getDocValuesReader(); + List compositeIndexFields = starTreeDocValuesReader.getCompositeIndexFields(); + + CompositeIndexFieldInfo starTree = compositeIndexFields.get(0); + + SumAggregationBuilder sumAggregationBuilder = sum("_name").field(FIELD_NAME); + MaxAggregationBuilder maxAggregationBuilder = max("_name").field(FIELD_NAME); + MinAggregationBuilder minAggregationBuilder = min("_name").field(FIELD_NAME); + + // match-all query + Query defaultQeury = new MatchAllDocsQuery(); + StarTreeQuery starTreeQuery = new StarTreeQuery(starTree, null); // no predicates + testCase(indexSearcher, defaultQeury, starTreeQuery, sumAggregationBuilder, verifyAggregation(InternalSum::getValue), 19); + testCase(indexSearcher, defaultQeury, starTreeQuery, maxAggregationBuilder, verifyAggregation(InternalMax::getValue), 9); + testCase(indexSearcher, defaultQeury, starTreeQuery, minAggregationBuilder, verifyAggregation(InternalMin::getValue), 1); + + // numeric-terms query + defaultQeury = new TermQuery(new Term("sndv", "1")); + Map>> compositePredicateMap = new HashMap<>(); + compositePredicateMap.put("sndv", List.of(dimVal -> dimVal == 1)); + starTreeQuery = new StarTreeQuery(starTree, compositePredicateMap); + testCase(indexSearcher, defaultQeury, starTreeQuery, sumAggregationBuilder, verifyAggregation(InternalSum::getValue), 4); + testCase(indexSearcher, defaultQeury, starTreeQuery, maxAggregationBuilder, verifyAggregation(InternalMax::getValue), 3); + testCase(indexSearcher, defaultQeury, starTreeQuery, minAggregationBuilder, verifyAggregation(InternalMin::getValue), 1); + + ir.close(); + directory.close(); + } + + private void testCase(IndexSearcher searcher, + Query defaultQuery, StarTreeQuery starTreeQuery, T builder, BiConsumer verify, long expectedValue) throws IOException { + OriginalOrStarTreeQuery originalOrStarTreeQuery = new OriginalOrStarTreeQuery(starTreeQuery, defaultQuery); + V aggregation = searchAndReduceStarTree(createIndexSettings(), searcher, originalOrStarTreeQuery, builder, DEFAULT_MAX_BUCKETS, false, DEFAULT_MAPPED_FIELD); + verify.accept(aggregation, expectedValue); + } + + + BiConsumer verifyAggregation(Function valueExtractor) { + return (aggregation, expectedValue) -> assertEquals(expectedValue, valueExtractor.apply(aggregation), 0f); + } +} diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 4abd7fbea9cff..9d816e864c50b 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -37,6 +37,7 @@ import org.apache.lucene.document.Field; import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.document.LatLonDocValuesField; +import org.apache.lucene.document.LongField; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StoredField; @@ -141,6 +142,7 @@ import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.lookup.SearchLookup; +import org.opensearch.search.startree.StarTreeQuery; import org.opensearch.test.InternalAggregationTestCase; import org.opensearch.test.OpenSearchTestCase; import org.junit.After; @@ -650,6 +652,74 @@ protected A searchAndReduc doAssertReducedMultiBucketConsumer(internalAgg, reduceBucketConsumer); return internalAgg; } + protected A searchAndReduceStarTree( + IndexSettings indexSettings, + IndexSearcher searcher, + Query query, + AggregationBuilder builder, + int maxBucket, + boolean hasNested, + MappedFieldType... fieldTypes + ) throws IOException { + final IndexReaderContext ctx = searcher.getTopReaderContext(); + final PipelineTree pipelines = builder.buildPipelineTree(); + List aggs = new ArrayList<>(); + if (hasNested) { + query = Queries.filtered(query, Queries.newNonNestedFilter()); + } + + MultiBucketConsumer bucketConsumer = new MultiBucketConsumer( + maxBucket, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + C root = createAggregator(query, builder, searcher, bucketConsumer, fieldTypes); + + if (randomBoolean() && searcher.getIndexReader().leaves().size() > 0) { + assertTrue(ctx instanceof LeafReaderContext); + final LeafReaderContext compCTX = (LeafReaderContext) ctx; + final int size = compCTX.leaves().size(); + final ShardSearcher[] subSearchers = new ShardSearcher[size]; + for (int searcherIDX = 0; searcherIDX < subSearchers.length; searcherIDX++) { + final LeafReaderContext leave = compCTX.leaves().get(searcherIDX); + subSearchers[searcherIDX] = new ShardSearcher(leave, compCTX); + } + for (ShardSearcher subSearcher : subSearchers) { + MultiBucketConsumer shardBucketConsumer = new MultiBucketConsumer( + maxBucket, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + C a = createAggregator(query, builder, subSearcher, indexSettings, shardBucketConsumer, fieldTypes); + a.preCollection(); + Weight weight = subSearcher.createWeight(query, ScoreMode.COMPLETE, 1f); + + assertTrue(weight.getQuery() instanceof StarTreeQuery); + subSearcher.search(weight, a); + a.postCollection(); + aggs.add(a.buildTopLevel()); + } + } else { + root.preCollection(); + searcher.search(query, root); + root.postCollection(); + aggs.add(root.buildTopLevel()); + } + + MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumer( + maxBucket, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction( + root.context().bigArrays(), + getMockScriptService(), + reduceBucketConsumer, + pipelines + ); + + @SuppressWarnings("unchecked") + A internalAgg = (A) aggs.get(0).reduce(aggs, context); + doAssertReducedMultiBucketConsumer(internalAgg, reduceBucketConsumer); + return internalAgg; + } protected void doAssertReducedMultiBucketConsumer(Aggregation agg, MultiBucketConsumerService.MultiBucketConsumer bucketConsumer) { InternalAggregationTestCase.assertMultiBucketConsumer(agg, bucketConsumer);