diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java index 69adb79cb2b84..70135c2d51e73 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java @@ -18,18 +18,22 @@ */ package org.elasticsearch.search.aggregations; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Streamable; import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import static java.util.Collections.emptyMap; @@ -49,6 +53,8 @@ public final class InternalAggregations extends Aggregations implements Streamab } }; + private List topLevelPipelineAggregators = Collections.emptyList(); + private InternalAggregations() { } @@ -60,18 +66,42 @@ public InternalAggregations(List aggregations) { } /** - * Reduces the given list of aggregations + * Constructs a new aggregation providing its {@link InternalAggregation}s and {@link SiblingPipelineAggregator}s + */ + public InternalAggregations(List aggregations, List topLevelPipelineAggregators) { + super(aggregations); + this.topLevelPipelineAggregators = Objects.requireNonNull(topLevelPipelineAggregators); + } + + /** + * Returns the top-level pipeline aggregators. + * Note that top-level pipeline aggregators become normal aggregation once the final reduction has been performed, after which they + * become part of the list of {@link InternalAggregation}s. + */ + List getTopLevelPipelineAggregators() { + return topLevelPipelineAggregators; + } + + /** + * Reduces the given list of aggregations as well as the top-level pipeline aggregators extracted from the first + * {@link InternalAggregations} object found in the list. + * Note that top-level pipeline aggregators are reduced only as part of the final reduction phase, otherwise they are left untouched. */ - public static InternalAggregations reduce(List aggregationsList, ReduceContext context) { - return reduce(aggregationsList, null, context); + public static InternalAggregations reduce(List aggregationsList, + ReduceContext context) { + if (aggregationsList.isEmpty()) { + return null; + } + InternalAggregations first = aggregationsList.get(0); + return reduce(aggregationsList, first.topLevelPipelineAggregators, context); } /** - * Reduces the given list of aggregations as well as the provided sibling pipeline aggregators. - * Note that sibling pipeline aggregators are ignored when non final reduction is performed. + * Reduces the given list of aggregations as well as the provided top-level pipeline aggregators. + * Note that top-level pipeline aggregators are reduced only as part of the final reduction phase, otherwise they are left untouched. */ public static InternalAggregations reduce(List aggregationsList, - List siblingPipelineAggregators, + List topLevelPipelineAggregators, ReduceContext context) { if (aggregationsList.isEmpty()) { return null; @@ -98,15 +128,14 @@ public static InternalAggregations reduce(List aggregation reducedAggregations.add(first.reduce(aggregations, context)); } - if (siblingPipelineAggregators != null) { - if (context.isFinalReduce()) { - for (SiblingPipelineAggregator pipelineAggregator : siblingPipelineAggregators) { - InternalAggregation newAgg = pipelineAggregator.doReduce(new InternalAggregations(reducedAggregations), context); - reducedAggregations.add(newAgg); - } + if (context.isFinalReduce()) { + for (SiblingPipelineAggregator pipelineAggregator : topLevelPipelineAggregators) { + InternalAggregation newAgg = pipelineAggregator.doReduce(new InternalAggregations(reducedAggregations), context); + reducedAggregations.add(newAgg); } + return new InternalAggregations(reducedAggregations); } - return new InternalAggregations(reducedAggregations); + return new InternalAggregations(reducedAggregations, topLevelPipelineAggregators); } public static InternalAggregations readAggregations(StreamInput in) throws IOException { @@ -121,11 +150,20 @@ public void readFrom(StreamInput in) throws IOException { if (aggregations.isEmpty()) { aggregationsAsMap = emptyMap(); } + if (in.getVersion().onOrAfter(Version.V_6_7_0)) { + this.topLevelPipelineAggregators = in.readList( + stream -> (SiblingPipelineAggregator)in.readNamedWriteable(PipelineAggregator.class)); + } else { + this.topLevelPipelineAggregators = Collections.emptyList(); + } } @Override @SuppressWarnings("unchecked") public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteableList((List)aggregations); + if (out.getVersion().onOrAfter(Version.V_6_7_0)) { + out.writeNamedWriteableList(topLevelPipelineAggregators); + } } } diff --git a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java index 43654823914b4..55787dfc53a35 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java +++ b/server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java @@ -36,10 +36,11 @@ import org.elasticsearch.search.suggest.Suggest; import java.io.IOException; +import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; -import static java.util.Collections.emptyList; import static org.elasticsearch.common.lucene.Lucene.readTopDocs; import static org.elasticsearch.common.lucene.Lucene.writeTopDocs; @@ -54,7 +55,7 @@ public final class QuerySearchResult extends SearchPhaseResult { private DocValueFormat[] sortValueFormats; private InternalAggregations aggregations; private boolean hasAggs; - private List pipelineAggregators; + private List pipelineAggregators = Collections.emptyList(); private Suggest suggest; private boolean searchTimedOut; private Boolean terminatedEarly = null; @@ -80,7 +81,6 @@ public QuerySearchResult queryResult() { return this; } - public void searchTimedOut(boolean searchTimedOut) { this.searchTimedOut = searchTimedOut; } @@ -204,7 +204,7 @@ public List pipelineAggregators() { } public void pipelineAggregators(List pipelineAggregators) { - this.pipelineAggregators = pipelineAggregators; + this.pipelineAggregators = Objects.requireNonNull(pipelineAggregators); } public Suggest suggest() { @@ -338,7 +338,7 @@ public void writeToNoId(StreamOutput out) throws IOException { out.writeBoolean(true); aggregations.writeTo(out); } - out.writeNamedWriteableList(pipelineAggregators == null ? emptyList() : pipelineAggregators); + out.writeNamedWriteableList(pipelineAggregators); if (suggest == null) { out.writeBoolean(false); } else { diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java index 100c36ea744a5..2165895974e27 100644 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionTests.java @@ -412,7 +412,8 @@ public void testCCSRemoteReduce() throws Exception { OriginalIndices localIndices = local ? new OriginalIndices(new String[]{"index"}, SearchRequest.DEFAULT_INDICES_OPTIONS) : null; int totalClusters = numClusters + (local ? 1 : 0); TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0); - Function reduceContext = finalReduce -> null; + Function reduceContext = + finalReduce -> new InternalAggregation.ReduceContext(null, null, finalReduce); try (MockTransportService service = MockTransportService.createNewService(settings, Version.CURRENT, threadPool, null)) { service.start(); service.acceptIncomingRequests(); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java new file mode 100644 index 0000000000000..3212c18cf278f --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java @@ -0,0 +1,155 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.aggregations.bucket.histogram.InternalDateHistogramTests; +import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; +import org.elasticsearch.search.aggregations.bucket.terms.StringTermsTests; +import org.elasticsearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValueTests; +import org.elasticsearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder; +import org.elasticsearch.search.aggregations.pipeline.SiblingPipelineAggregator; +import org.elasticsearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.VersionUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class InternalAggregationsTests extends ESTestCase { + + private final NamedWriteableRegistry registry = new NamedWriteableRegistry( + new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + + public void testReduceEmptyAggs() { + List aggs = Collections.emptyList(); + InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(null, null, randomBoolean()); + assertNull(InternalAggregations.reduce(aggs, Collections.emptyList(), reduceContext)); + } + + public void testNonFinalReduceTopLevelPipelineAggs() throws IOException { + InternalAggregation terms = new StringTerms("name", BucketOrder.key(true), + 10, 1, Collections.emptyList(), Collections.emptyMap(), DocValueFormat.RAW, 25, false, 10, Collections.emptyList(), 0); + List aggs = Collections.singletonList(new InternalAggregations(Collections.singletonList(terms))); + List topLevelPipelineAggs = new ArrayList<>(); + MaxBucketPipelineAggregationBuilder maxBucketPipelineAggregationBuilder = new MaxBucketPipelineAggregationBuilder("test", "test"); + topLevelPipelineAggs.add((SiblingPipelineAggregator)maxBucketPipelineAggregationBuilder.create()); + InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(null, null, false); + InternalAggregations reducedAggs = InternalAggregations.reduce(aggs, topLevelPipelineAggs, reduceContext); + assertEquals(1, reducedAggs.getTopLevelPipelineAggregators().size()); + assertEquals(1, reducedAggs.aggregations.size()); + } + + public void testFinalReduceTopLevelPipelineAggs() throws IOException { + InternalAggregation terms = new StringTerms("name", BucketOrder.key(true), + 10, 1, Collections.emptyList(), Collections.emptyMap(), DocValueFormat.RAW, 25, false, 10, Collections.emptyList(), 0); + + MaxBucketPipelineAggregationBuilder maxBucketPipelineAggregationBuilder = new MaxBucketPipelineAggregationBuilder("test", "test"); + SiblingPipelineAggregator siblingPipelineAggregator = (SiblingPipelineAggregator) maxBucketPipelineAggregationBuilder.create(); + InternalAggregation.ReduceContext reduceContext = new InternalAggregation.ReduceContext(null, null, true); + final InternalAggregations reducedAggs; + if (randomBoolean()) { + InternalAggregations aggs = new InternalAggregations(Collections.singletonList(terms), + Collections.singletonList(siblingPipelineAggregator)); + reducedAggs = InternalAggregations.reduce(Collections.singletonList(aggs), reduceContext); + } else { + InternalAggregations aggs = new InternalAggregations(Collections.singletonList(terms)); + List topLevelPipelineAggs = Collections.singletonList(siblingPipelineAggregator); + reducedAggs = InternalAggregations.reduce(Collections.singletonList(aggs), topLevelPipelineAggs, reduceContext); + } + assertEquals(0, reducedAggs.getTopLevelPipelineAggregators().size()); + assertEquals(2, reducedAggs.aggregations.size()); + } + + public void testSerialization() throws Exception { + List aggsList = new ArrayList<>(); + if (randomBoolean()) { + StringTermsTests stringTermsTests = new StringTermsTests(); + stringTermsTests.init(); + stringTermsTests.setUp(); + aggsList.add(stringTermsTests.createTestInstance()); + } + if (randomBoolean()) { + InternalDateHistogramTests dateHistogramTests = new InternalDateHistogramTests(); + dateHistogramTests.setUp(); + aggsList.add(dateHistogramTests.createTestInstance()); + } + if (randomBoolean()) { + InternalSimpleValueTests simpleValueTests = new InternalSimpleValueTests(); + aggsList.add(simpleValueTests.createTestInstance()); + } + List topLevelPipelineAggs = new ArrayList<>(); + if (randomBoolean()) { + if (randomBoolean()) { + topLevelPipelineAggs.add((SiblingPipelineAggregator)new MaxBucketPipelineAggregationBuilder("name1", "bucket1").create()); + } + if (randomBoolean()) { + topLevelPipelineAggs.add((SiblingPipelineAggregator)new AvgBucketPipelineAggregationBuilder("name2", "bucket2").create()); + } + if (randomBoolean()) { + topLevelPipelineAggs.add((SiblingPipelineAggregator)new SumBucketPipelineAggregationBuilder("name3", "bucket3").create()); + } + } + InternalAggregations aggregations = new InternalAggregations(aggsList, topLevelPipelineAggs); + writeToAndReadFrom(aggregations, 0); + } + + private void writeToAndReadFrom(InternalAggregations aggregations, int iteration) throws IOException { + Version version = VersionUtils.randomVersion(random()); + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.setVersion(version); + aggregations.writeTo(out); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(out.bytes().toBytesRef().bytes), registry)) { + in.setVersion(version); + InternalAggregations deserialized = InternalAggregations.readAggregations(in); + assertEquals(aggregations.aggregations, deserialized.aggregations); + if (aggregations.getTopLevelPipelineAggregators() == null) { + assertEquals(0, deserialized.getTopLevelPipelineAggregators().size()); + } else { + if (version.before(Version.V_6_7_0)) { + assertEquals(0, deserialized.getTopLevelPipelineAggregators().size()); + } else { + assertEquals(aggregations.getTopLevelPipelineAggregators().size(), + deserialized.getTopLevelPipelineAggregators().size()); + for (int i = 0; i < aggregations.getTopLevelPipelineAggregators().size(); i++) { + SiblingPipelineAggregator siblingPipelineAggregator1 = aggregations.getTopLevelPipelineAggregators().get(i); + SiblingPipelineAggregator siblingPipelineAggregator2 = deserialized.getTopLevelPipelineAggregators().get(i); + assertArrayEquals(siblingPipelineAggregator1.bucketsPaths(), siblingPipelineAggregator2.bucketsPaths()); + assertEquals(siblingPipelineAggregator1.name(), siblingPipelineAggregator2.name()); + } + } + } + if (iteration < 2) { + //serialize this enough times to make sure that we are able to write again what we read + writeToAndReadFrom(deserialized, iteration + 1); + } + } + } + } +}