Skip to content

Commit

Permalink
Added pagination support to top_hits aggregation by adding from o…
Browse files Browse the repository at this point in the history
…ption.

Closes elastic#6299
  • Loading branch information
martijnvg committed May 26, 2014
1 parent 3f2f1f0 commit a989229
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ This aggregator can't hold any sub-aggregators and therefor can only be used as

==== Options

* `from` - The index from which to include matching hits.
* `size` - The maximum number of top matching hits to return per bucket. By default the top three matching hits are returned.
* `sort` - How the top matching hits should be sorted. By default the hits are sorted by the score of the main query.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public static void registerStreams() {
AggregationStreams.registerStream(STREAM, TYPE.stream());
}

private int from;
private int size;
private Sort sort;
private TopDocs topDocs;
Expand All @@ -62,8 +63,9 @@ public static void registerStreams() {
InternalTopHits() {
}

public InternalTopHits(String name, int size, Sort sort, TopDocs topDocs, InternalSearchHits searchHits) {
public InternalTopHits(String name, int from, int size, Sort sort, TopDocs topDocs, InternalSearchHits searchHits) {
this.name = name;
this.from = from;
this.size = size;
this.sort = sort;
this.topDocs = topDocs;
Expand Down Expand Up @@ -104,7 +106,7 @@ public InternalAggregation reduce(ReduceContext reduceContext) {

try {
int[] tracker = new int[shardHits.length];
TopDocs reducedTopDocs = TopDocs.merge(sort, size, shardDocs);
TopDocs reducedTopDocs = TopDocs.merge(sort, from, size, shardDocs);
InternalSearchHit[] hits = new InternalSearchHit[reducedTopDocs.scoreDocs.length];
for (int i = 0; i < reducedTopDocs.scoreDocs.length; i++) {
ScoreDoc scoreDoc = reducedTopDocs.scoreDocs[i];
Expand All @@ -119,6 +121,7 @@ public InternalAggregation reduce(ReduceContext reduceContext) {
@Override
public void readFrom(StreamInput in) throws IOException {
name = in.readString();
from = in.readVInt();
size = in.readVInt();
topDocs = Lucene.readTopDocs(in);
if (topDocs instanceof TopFieldDocs) {
Expand All @@ -130,6 +133,7 @@ public void readFrom(StreamInput in) throws IOException {
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(name);
out.writeVInt(from);
out.writeVInt(size);
Lucene.writeTopDocs(out, topDocs, 0);
searchHits.writeTo(out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
import org.elasticsearch.search.aggregations.Aggregation;

/**
* Accumulation of the most relevant hits for a bucket this aggregation falls into.
*/
public interface TopHits extends Aggregation {

/**
* @return The top matching hits for the bucket
*/
SearchHits getHits();

}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
searchHitFields.sortValues(fieldDoc.fields);
}
}
return new InternalTopHits(name, topHitsContext.size(), topHitsContext.sort(), topDocs, fetchResult.hits());
return new InternalTopHits(name, topHitsContext.from(), topHitsContext.size(), topHitsContext.sort(), topDocs, fetchResult.hits());
}
}

Expand All @@ -102,10 +102,10 @@ public void collect(int docId, long bucketOrdinal) throws IOException {
TopDocsCollector topDocsCollector = topDocsCollectors.get(bucketOrdinal);
if (topDocsCollector == null) {
Sort sort = topHitsContext.sort();
int size = topHitsContext.size();
int topN = topHitsContext.from() + topHitsContext.size();
topDocsCollectors.put(
bucketOrdinal,
topDocsCollector = sort != null ? TopFieldCollector.create(sort, size, true, topHitsContext.trackScores(), true, false) : TopScoreDocCollector.create(size, false)
topDocsCollector = sort != null ? TopFieldCollector.create(sort, topN, true, topHitsContext.trackScores(), true, false) : TopScoreDocCollector.create(topN, false)
);
topDocsCollector.setNextReader(currentContext);
topDocsCollector.setScorer(currentScorer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,17 @@ public TopHitsBuilder(String name) {
}

/**
* The number of search hits to return. Defaults to <tt>10</tt>.
* The index to start to return hits from. Defaults to <tt>0</tt>.
*/
public TopHitsBuilder setFrom(int from) {
sourceBuilder().from(from);
return this;
}


/**
* The number of search hits to return. Defaults to <tt>10</tt>.
*/
public TopHitsBuilder setSize(int size) {
sourceBuilder().size(size);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class TopHitsContext extends SearchContext {
// the to hits are returned per bucket.
private final static int DEFAULT_SIZE = 3;

private int from;
private int size = DEFAULT_SIZE;
private Sort sort;

Expand Down Expand Up @@ -440,12 +441,13 @@ public SearchContext updateRewriteQuery(Query rewriteQuery) {

@Override
public int from() {
return context.from();
return from;
}

@Override
public SearchContext from(int from) {
throw new UnsupportedOperationException("Not supported");
this.from = from;
return this;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public AggregatorFactory parse(String aggregationName, XContentParser parser, Se
currentFieldName = parser.currentName();
} else if (token.isValue()) {
switch (currentFieldName) {
case "from":
topHitsContext.from(parser.intValue());
break;
case "size":
topHitsContext.size(parser.intValue());
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ public void setupSuiteScopeCluster() throws Exception {
.endObject()));
}

// Use routing to make sure all docs are in the same shard for consistent scoring
builders.add(client().prepareIndex("idx", "field-collapsing", "1").setSource(jsonBuilder()
.startObject()
.field("group", "a")
Expand Down Expand Up @@ -168,6 +167,159 @@ public void testBasics() throws Exception {
}
}

@Test
public void testPagination() throws Exception {
SearchResponse response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)).setSize(2)
)
)
.get();

assertSearchResponse(response);

Terms terms = response.getAggregations().get("terms");
assertThat(terms, notNullValue());
assertThat(terms.getName(), equalTo("terms"));
assertThat(terms.getBuckets().size(), equalTo(5));

Terms.Bucket bucket = terms.getBucketByKey("val0");
assertThat(bucket, notNullValue());
assertThat(bucket.getDocCount(), equalTo(10l));
TopHits topHits = bucket.getAggregations().get("hits");
SearchHits hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(2));
assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(10l));
assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(9l));

response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))
.setSize(2)
.setFrom(2)
)
)
.get();

assertSearchResponse(response);

terms = response.getAggregations().get("terms");
bucket = terms.getBucketByKey("val0");
assertThat(bucket, notNullValue());
assertThat(bucket.getDocCount(), equalTo(10l));
topHits = bucket.getAggregations().get("hits");
hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(2));
assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(8l));
assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(7l));

response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))
.setSize(2)
.setFrom(4)
)
)
.get();

assertSearchResponse(response);

terms = response.getAggregations().get("terms");
bucket = terms.getBucketByKey("val0");
assertThat(bucket, notNullValue());
assertThat(bucket.getDocCount(), equalTo(10l));
topHits = bucket.getAggregations().get("hits");
hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(2));
assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(6l));
assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(5l));

response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))
.setSize(2)
.setFrom(6)
)
)
.get();

assertSearchResponse(response);

terms = response.getAggregations().get("terms");
bucket = terms.getBucketByKey("val0");
assertThat(bucket, notNullValue());
assertThat(bucket.getDocCount(), equalTo(10l));
topHits = bucket.getAggregations().get("hits");
hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(2));
assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(4l));
assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(3l));

response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))
.setSize(2)
.setFrom(8)
)
)
.get();

assertSearchResponse(response);

terms = response.getAggregations().get("terms");
bucket = terms.getBucketByKey("val0");
assertThat(bucket, notNullValue());
assertThat(bucket.getDocCount(), equalTo(10l));
topHits = bucket.getAggregations().get("hits");
hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(2));
assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(2l));
assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(1l));

response = client().prepareSearch("idx").setTypes("type")
.addAggregation(terms("terms")
.executionHint(randomExecutionHint())
.field(TERMS_AGGS_FIELD)
.subAggregation(
topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC))
.setSize(2)
.setFrom(10)
)
)
.get();

assertSearchResponse(response);

terms = response.getAggregations().get("terms");
bucket = terms.getBucketByKey("val0");
assertThat(bucket, notNullValue());
assertThat(bucket.getDocCount(), equalTo(10l));
topHits = bucket.getAggregations().get("hits");
hits = topHits.getHits();
assertThat(hits.totalHits(), equalTo(10l));
assertThat(hits.getHits().length, equalTo(0));
}

@Test
public void testSortByBucket() throws Exception {
SearchResponse response = client().prepareSearch("idx").setTypes("type")
Expand Down

0 comments on commit a989229

Please sign in to comment.