Skip to content

Commit

Permalink
HSEARCH-5133 Implement Lucene sum aggregations
Browse files Browse the repository at this point in the history
  • Loading branch information
fax4ever committed Jul 5, 2024
1 parent bc40bf5 commit 6113c0a
Show file tree
Hide file tree
Showing 10 changed files with 428 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl;

import java.io.IOException;

import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.LongMultiValues;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.LongMultiValuesSource;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;

/**
 * A Lucene {@link Collector} that adds up every {@code long} value exposed by a
 * {@link LongMultiValuesSource} for the documents it collects.
 * <p>
 * The running total is held by this collector and updated by the leaf collectors
 * it hands out; read it through {@link #sum()}.
 */
public class SumCollector implements Collector {

    private final LongMultiValuesSource source;
    private long total = 0L;

    /**
     * @param valueSource The source from which per-segment values are obtained.
     */
    public SumCollector(LongMultiValuesSource valueSource) {
        this.source = valueSource;
    }

    /**
     * @return The total accumulated so far; {@code 0} if nothing has been collected.
     */
    public long sum() {
        return total;
    }

    @Override
    public ScoreMode scoreMode() {
        // Only field values are read; scores are never requested.
        return ScoreMode.COMPLETE_NO_SCORES;
    }

    @Override
    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
        return new SumLeafCollector( source.getValues( context ) );
    }

    /**
     * Per-segment collector that feeds the values of each collected document
     * into the enclosing collector's running total.
     */
    public class SumLeafCollector implements LeafCollector {
        private final LongMultiValues leafValues;

        /**
         * @param values The values of the segment this leaf collector works on.
         */
        public SumLeafCollector(LongMultiValues values) {
            this.leafValues = values;
        }

        @Override
        public void setScorer(Scorable scorer) {
            // Intentionally empty: the scorer is not used.
        }

        @Override
        public void collect(int doc) throws IOException {
            if ( !leafValues.advanceExact( doc ) ) {
                // advanceExact reported nothing to read for this document.
                return;
            }
            while ( leafValues.hasNextValue() ) {
                total += leafValues.nextValue();
            }
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl;

import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorExecutionContext;
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;

/**
 * A {@link CollectorFactory} producing {@link SumCollectorManager}s bound to a fixed
 * {@link JoiningLongMultiValuesSource}, and exposing the {@link CollectorKey} it was created with.
 */
public class SumCollectorFactory implements CollectorFactory<SumCollector, Long, SumCollectorManager> {

    private final JoiningLongMultiValuesSource source;
    private final CollectorKey<SumCollector, Long> key;

    /**
     * @param source The source of the values to sum.
     * @param key The key returned by {@link #getCollectorKey()}.
     */
    public SumCollectorFactory(JoiningLongMultiValuesSource source, CollectorKey<SumCollector, Long> key) {
        this.key = key;
        this.source = source;
    }

    /**
     * @return The key passed to the constructor.
     */
    @Override
    public CollectorKey<SumCollector, Long> getCollectorKey() {
        return key;
    }

    /**
     * Creates a new manager on each call.
     *
     * @param context The execution context; not used by this factory.
     * @return A new {@link SumCollectorManager} reading from this factory's source.
     */
    @Override
    public SumCollectorManager createCollectorManager(CollectorExecutionContext context) {
        return new SumCollectorManager( source );
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl;

import java.io.IOException;
import java.util.Collection;

import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;

import org.apache.lucene.search.CollectorManager;

/**
 * A Lucene {@link CollectorManager} that creates {@link SumCollector}s over a shared
 * {@link JoiningLongMultiValuesSource} and merges their partial sums into a single {@link Long}.
 */
public class SumCollectorManager implements CollectorManager<SumCollector, Long> {

    private final JoiningLongMultiValuesSource source;

    /**
     * @param source The source handed to every collector created by this manager.
     */
    public SumCollectorManager(JoiningLongMultiValuesSource source) {
        this.source = source;
    }

    /**
     * @return A new collector reading from this manager's source.
     */
    @Override
    public SumCollector newCollector() {
        return new SumCollector( source );
    }

    /**
     * Adds up the partial sums of all given collectors.
     *
     * @param collectors The collectors whose sums are to be merged.
     * @return The total of all partial sums; {@code 0} when there are no collectors.
     */
    @Override
    public Long reduce(Collection<SumCollector> collectors) throws IOException {
        // Plain long addition, same as accumulating in a loop starting from 0.
        return collectors.stream().mapToLong( SumCollector::sum ).sum();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static JoiningLongMultiValuesSource fromIntField(String field, NestedDocs
return fromField( field, nested );
}

private static JoiningLongMultiValuesSource fromField(String field, NestedDocsProvider nested) {
public static JoiningLongMultiValuesSource fromField(String field, NestedDocsProvider nested) {
return new FieldLongMultiValuesSource( field, nested );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,35 @@
*/
package org.hibernate.search.backend.lucene.search.aggregation.impl;

import java.util.Set;

import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.join.impl.NestedDocsProvider;
import org.hibernate.search.backend.lucene.search.extraction.impl.ExtractionRequirements;
import org.hibernate.search.backend.lucene.search.predicate.impl.PredicateRequestContext;
import org.hibernate.search.backend.lucene.search.query.impl.LuceneSearchQueryIndexScope;
import org.hibernate.search.engine.backend.session.spi.BackendSessionContext;
import org.hibernate.search.engine.search.common.NamedValues;
import org.hibernate.search.engine.search.query.spi.QueryParameters;

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;

public final class AggregationRequestContext {

private final LuceneSearchQueryIndexScope<?> queryIndexScope;
private final BackendSessionContext sessionContext;
private final Set<String> routingKeys;
private final ExtractionRequirements.Builder extractionRequirementsBuilder;
private final QueryParameters parameters;

public AggregationRequestContext(ExtractionRequirements.Builder extractionRequirementsBuilder,
public AggregationRequestContext(LuceneSearchQueryIndexScope<?> queryIndexScope, BackendSessionContext sessionContext,
Set<String> routingKeys, ExtractionRequirements.Builder extractionRequirementsBuilder,
QueryParameters parameters) {
this.queryIndexScope = queryIndexScope;
this.sessionContext = sessionContext;
this.routingKeys = routingKeys;
this.extractionRequirementsBuilder = extractionRequirementsBuilder;
this.parameters = parameters;
}
Expand All @@ -31,4 +45,13 @@ public <C extends Collector, T, CM extends CollectorManager<C, T>> void requireC
/**
 * @return The query parameters this context was created with, exposed as {@link NamedValues}.
 */
public NamedValues queryParameters() {
return parameters;
}

/**
 * Derives a predicate request context from this aggregation request context.
 * <p>
 * The result is built from this context's index scope, session context, routing keys
 * and query parameters, and then has the given path applied via {@code withNestedPath}.
 *
 * @param absolutePath The path passed to {@code withNestedPath}.
 * @return The resulting predicate request context.
 */
public PredicateRequestContext toPredicateRequestContext(String absolutePath) {
return PredicateRequestContext.withSession( queryIndexScope, sessionContext, routingKeys, parameters )
.withNestedPath( absolutePath );
}

/**
 * Creates a new {@link NestedDocsProvider} for the given nested document path and filter.
 * Both arguments are forwarded as-is to the {@link NestedDocsProvider} constructor.
 *
 * @param nestedDocumentPath The path of the nested documents.
 * @param nestedFilter The filter query for nested documents.
 * NOTE(review): at least one caller may pass {@code null} when no filter is configured
 * - confirm that the constructor accepts it.
 * @return A new provider instance; nothing is cached by this method.
 */
public NestedDocsProvider createNestedDocsProvider(String nestedDocumentPath, Query nestedFilter) {
return new NestedDocsProvider( nestedDocumentPath, nestedFilter );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ public LuceneSearchQuery<H> build() {
if ( aggregations != null ) {
aggregationExtractors = new LinkedHashMap<>();
AggregationRequestContext aggregationRequestContext =
new AggregationRequestContext( extractionRequirementsBuilder, parameters );
new AggregationRequestContext( scope, sessionContext, routingKeys, extractionRequirementsBuilder,
parameters );
for ( Map.Entry<AggregationKey<?>, LuceneSearchAggregation<?>> entry : aggregations.entrySet() ) {
aggregationExtractors.put( entry.getKey(), entry.getValue().request( aggregationRequestContext ) );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.hibernate.search.backend.lucene.logging.impl.Log;
import org.hibernate.search.backend.lucene.lowlevel.join.impl.NestedDocsProvider;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationRequestContext;
import org.hibernate.search.backend.lucene.search.aggregation.impl.LuceneSearchAggregation;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext;
Expand Down Expand Up @@ -41,6 +42,15 @@ protected NestedDocsProvider createNestedDocsProvider(AggregationExtractContext
return nestedDocsProvider;
}

/**
 * Creates the nested-docs provider to use at request time.
 *
 * @param context The aggregation request context used to build the provider
 * and the predicate context for the nested filter.
 * @return A provider for {@code nestedDocumentPath}, or {@code null} when this
 * aggregation has no nested document path.
 */
protected NestedDocsProvider createNestedDocsProvider(AggregationRequestContext context) {
    if ( nestedDocumentPath == null ) {
        // Nothing to join on when there is no nested document path.
        return null;
    }
    return context.createNestedDocsProvider( nestedDocumentPath,
            toNestedFilterQuery( context.toPredicateRequestContext( nestedDocumentPath ) ) );
}

/**
 * Turns the nested filter, if any, into a Lucene query.
 *
 * @param filterContext The context passed to the filter's {@code toQuery}.
 * @return The filter's query, or {@code null} when no nested filter is set.
 */
private Query toNestedFilterQuery(PredicateRequestContext filterContext) {
    if ( nestedFilter == null ) {
        return null;
    }
    return nestedFilter.toQuery( filterContext );
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Red Hat Inc. and Hibernate Authors
*/
package org.hibernate.search.backend.lucene.types.aggregation.impl;

import java.util.Set;

import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.SumCollector;
import org.hibernate.search.backend.lucene.lowlevel.aggregation.collector.impl.SumCollectorFactory;
import org.hibernate.search.backend.lucene.lowlevel.collector.impl.CollectorKey;
import org.hibernate.search.backend.lucene.lowlevel.docvalues.impl.JoiningLongMultiValuesSource;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationExtractContext;
import org.hibernate.search.backend.lucene.search.aggregation.impl.AggregationRequestContext;
import org.hibernate.search.backend.lucene.search.common.impl.AbstractLuceneCodecAwareSearchQueryElementFactory;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexScope;
import org.hibernate.search.backend.lucene.search.common.impl.LuceneSearchIndexValueFieldContext;
import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec;
import org.hibernate.search.backend.lucene.types.lowlevel.impl.LuceneNumericDomain;
import org.hibernate.search.engine.backend.types.converter.spi.ProjectionConverter;
import org.hibernate.search.engine.search.aggregation.spi.FieldMetricAggregationBuilder;
import org.hibernate.search.engine.search.common.ValueConvert;

/**
 * A metric aggregation over a numeric field: it registers a collector when the aggregation
 * is requested, then decodes and converts the collected result when it is extracted.
 * <p>
 * Only the {@code "sum"} operation is currently implemented; requesting any other
 * operation fails with an {@link IllegalStateException}.
 *
 * @param <F> The type of field values.
 * @param <E> The numeric type produced by the codec's {@link LuceneNumericDomain}
 * and accepted by the codec's {@code decode} method.
 * @param <K> The type of returned value. It can be {@code F}, {@link Double}
 * or a different type if value converters are used.
 */
public class LuceneMetricNumericFieldAggregation<F, E extends Number, K> extends AbstractLuceneNestableAggregation<K> {

    /** Name of the only operation this aggregation knows how to compute. */
    private static final String SUM_OPERATION = "sum";

    private final Set<String> indexNames;
    private final String absoluteFieldPath;
    private final AbstractLuceneNumericFieldCodec<F, E> codec;
    private final LuceneNumericDomain<E> numericDomain;
    private final ProjectionConverter<F, ? extends K> fromFieldValueConverter;
    private final String operation;
    private final CollectorKey<SumCollector, Long> collectorKey;

    LuceneMetricNumericFieldAggregation(Builder<F, E, K> builder) {
        super( builder );
        this.indexNames = builder.scope.hibernateSearchIndexNames();
        this.absoluteFieldPath = builder.field.absolutePath();
        this.codec = builder.codec;
        this.numericDomain = codec.getDomain();
        this.fromFieldValueConverter = builder.fromFieldValueConverter;
        this.operation = builder.operation;
        // A fresh key per aggregation instance, shared between request() and the extractor.
        this.collectorKey = CollectorKey.create();
    }

    /**
     * Registers the collector needed to compute this aggregation and returns the
     * extractor that will read its result.
     *
     * @param context The request context on which the collector is registered.
     * @return An extractor reading the result registered under this aggregation's collector key.
     * @throws IllegalStateException If the configured operation is not {@code "sum"}.
     */
    @Override
    public Extractor<K> request(AggregationRequestContext context) {
        if ( !SUM_OPERATION.equals( operation ) ) {
            // Fail fast: without a registered collector the extractor would have nothing to read.
            throw new IllegalStateException( "Unsupported metric aggregation operation: '" + operation
                    + "' on field '" + absoluteFieldPath + "'" );
        }
        JoiningLongMultiValuesSource source = JoiningLongMultiValuesSource.fromField(
                absoluteFieldPath, createNestedDocsProvider( context )
        );
        context.requireCollector( new SumCollectorFactory( source, collectorKey ) );
        return new LuceneNumericMetricFieldAggregationExtraction();
    }

    @Override
    public Set<String> indexNames() {
        return indexNames;
    }

    /**
     * Extractor that turns the collected {@code long} sum into the requested result type.
     */
    private class LuceneNumericMetricFieldAggregationExtraction implements Extractor<K> {

        /**
         * Reads the collected sum, maps it through the numeric domain, decodes it with the
         * codec and finally applies the projection converter.
         *
         * @param context The extraction context giving access to collector results.
         * @return The converted aggregation result.
         */
        @Override
        public K extract(AggregationExtractContext context) {
            Long sum = context.getFacets( collectorKey );
            // NOTE(review): SumCollector adds up raw long doc values. For domains where the
            // doc value is an encoding of the number rather than the number itself, the sum
            // of encodings may not map back to the sum of the values - confirm per domain.
            E encoded = numericDomain.sortedDocValueToTerm( sum );
            F decoded = codec.decode( encoded );
            return fromFieldValueConverter.fromDocumentValue( decoded, context.fromDocumentValueConvertContext() );
        }
    }

    /**
     * Query element factory creating the type selector for a given metric operation.
     *
     * @param <F> The type of field values.
     */
    public static class Factory<F>
            extends AbstractLuceneCodecAwareSearchQueryElementFactory<FieldMetricAggregationBuilder.TypeSelector,
                    F,
                    AbstractLuceneNumericFieldCodec<F, ?>> {

        private final String operation;

        /**
         * @param codec The codec of the field type this factory is attached to.
         * @param operation The name of the metric operation, e.g. {@code "sum"}.
         */
        public Factory(AbstractLuceneNumericFieldCodec<F, ?> codec, String operation) {
            super( codec );
            this.operation = operation;
        }

        @Override
        public FieldMetricAggregationBuilder.TypeSelector create(LuceneSearchIndexScope<?> scope,
                LuceneSearchIndexValueFieldContext<F> field) {
            return new TypeSelector<>( codec, scope, field, operation );
        }
    }

    /**
     * Selects the expected result type and creates the corresponding builder.
     */
    private static class TypeSelector<F> implements FieldMetricAggregationBuilder.TypeSelector {
        private final AbstractLuceneNumericFieldCodec<F, ?> codec;
        private final LuceneSearchIndexScope<?> scope;
        private final LuceneSearchIndexValueFieldContext<F> field;
        private final String operation;

        private TypeSelector(AbstractLuceneNumericFieldCodec<F, ?> codec,
                LuceneSearchIndexScope<?> scope, LuceneSearchIndexValueFieldContext<F> field,
                String operation) {
            this.codec = codec;
            this.scope = scope;
            this.field = field;
            this.operation = operation;
        }

        @Override
        public <T> Builder<F, ?, T> type(Class<T> expectedType, ValueConvert convert) {
            return new Builder<>( codec, scope, field,
                    field.type().projectionConverter( convert ).withConvertedType( expectedType, field ),
                    operation
            );
        }
    }

    /**
     * Builder carrying the codec, result converter and operation name into the aggregation.
     */
    private static class Builder<F, E extends Number, K> extends AbstractBuilder<K>
            implements FieldMetricAggregationBuilder<K> {

        private final AbstractLuceneNumericFieldCodec<F, E> codec;
        private final ProjectionConverter<F, ? extends K> fromFieldValueConverter;
        private final String operation;

        public Builder(AbstractLuceneNumericFieldCodec<F, E> codec, LuceneSearchIndexScope<?> scope,
                LuceneSearchIndexValueFieldContext<F> field,
                ProjectionConverter<F, ? extends K> fromFieldValueConverter,
                String operation) {
            super( scope, field );
            this.codec = codec;
            this.fromFieldValueConverter = fromFieldValueConverter;
            this.operation = operation;
        }

        @Override
        public LuceneMetricNumericFieldAggregation<F, E, K> build() {
            return new LuceneMetricNumericFieldAggregation<>( this );
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import org.hibernate.search.backend.lucene.search.predicate.impl.LucenePredicateTypeKeys;
import org.hibernate.search.backend.lucene.search.projection.impl.LuceneFieldProjection;
import org.hibernate.search.backend.lucene.types.aggregation.impl.LuceneMetricNumericFieldAggregation;
import org.hibernate.search.backend.lucene.types.aggregation.impl.LuceneNumericRangeAggregation;
import org.hibernate.search.backend.lucene.types.aggregation.impl.LuceneNumericTermsAggregation;
import org.hibernate.search.backend.lucene.types.codec.impl.AbstractLuceneNumericFieldCodec;
Expand Down Expand Up @@ -85,6 +86,8 @@ public LuceneIndexValueFieldType<F> toIndexFieldType() {
builder.aggregable( true );
builder.queryElementFactory( AggregationTypeKeys.TERMS, new LuceneNumericTermsAggregation.Factory<>( codec ) );
builder.queryElementFactory( AggregationTypeKeys.RANGE, new LuceneNumericRangeAggregation.Factory<>( codec ) );
builder.queryElementFactory( AggregationTypeKeys.SUM,
new LuceneMetricNumericFieldAggregation.Factory<>( codec, "sum" ) );
}

return builder.build();
Expand Down
Loading

0 comments on commit 6113c0a

Please sign in to comment.