Skip to content

Commit

Permalink
handle classic retriever on the shard
Browse files Browse the repository at this point in the history
  • Loading branch information
jdconrad committed Nov 13, 2023
1 parent 5b4a214 commit 32b7372
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,9 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc
SearchExecutionContext searchExecutionContext = context.getSearchExecutionContext();
context.from(source.from());
context.size(source.size());
if (source.retrieverbuilder() != null) {
source.retrieverbuilder().doBuildSearchContext(context);
}
Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
QueryBuilder query = source.query();
if (query != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rank.RankBuilder;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.retriever.ClassicRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.searchafter.SearchAfterBuilder;
Expand Down Expand Up @@ -118,7 +119,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R
public static final ParseField SLICE = new ParseField("slice"); // global
public static final ParseField POINT_IN_TIME = new ParseField("pit"); // global
public static final ParseField RUNTIME_MAPPINGS_FIELD = new ParseField("runtime_mappings"); // global
public static final ParseField RETRIEVER = new ParseField("retriever"); // global
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever"); // global

private static final boolean RANK_SUPPORTED = Booleans.parseBoolean(System.getProperty("es.search.rank_supported"), true);

Expand All @@ -136,6 +137,8 @@ public static HighlightBuilder highlight() {
return new HighlightBuilder();
}

private RetrieverBuilder<?> retrieverBuilder;

private List<SubSearchSourceBuilder> subSearchSourceBuilders = new ArrayList<>();

private QueryBuilder postQueryBuilder;
Expand Down Expand Up @@ -207,6 +210,9 @@ public SearchSourceBuilder() {}
* Read from a stream.
*/
public SearchSourceBuilder(StreamInput in) throws IOException {
if (in.getTransportVersion().onOrAfter(TransportVersions.RETRIEVERS_ADDED)) {
retrieverBuilder = in.readOptionalNamedWriteable(RetrieverBuilder.class);
}
aggregations = in.readOptionalWriteable(AggregatorFactories.Builder::new);
explain = in.readOptionalBoolean();
fetchSourceContext = in.readOptionalWriteable(FetchSourceContext::readFrom);
Expand Down Expand Up @@ -282,6 +288,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException {

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.RETRIEVERS_ADDED)) {
out.writeOptionalNamedWriteable(retrieverBuilder);
}
out.writeOptionalWriteable(aggregations);
out.writeOptionalBoolean(explain);
out.writeOptionalWriteable(fetchSourceContext);
Expand Down Expand Up @@ -372,6 +381,21 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

/**
* Sets the {@link RetrieverBuilder} for this request.
*/
public SearchSourceBuilder retrieverBuilder(RetrieverBuilder<?> retrieverBuilder) {
this.retrieverBuilder = retrieverBuilder;
return this;
}

/**
* Gets the {@link RetrieverBuilder} for this request.
*/
public RetrieverBuilder<?> retrieverbuilder() {
return retrieverBuilder;
}

/**
* Sets the query for this request.
*/
Expand Down Expand Up @@ -1275,7 +1299,6 @@ private SearchSourceBuilder parseXContent(XContentParser parser, boolean checkTr
);
}

RetrieverBuilder<?> retrieverBuilder = null;
SearchUsage searchUsage = new SearchUsage();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
Expand Down Expand Up @@ -1341,7 +1364,7 @@ private SearchSourceBuilder parseXContent(XContentParser parser, boolean checkTr
);
}
} else if (token == XContentParser.Token.START_OBJECT) {
if (RETRIEVER.match(currentFieldName, parser.getDeprecationHandler())) {
if (RETRIEVER_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
retrieverBuilder = RetrieverBuilder.parseTopLevelRetrieverBuilder(
parser,
new RetrieverParserContext(searchUsage::trackSectionUsage, searchUsage::trackQueryUsage)
Expand Down Expand Up @@ -1602,7 +1625,12 @@ private SearchSourceBuilder parseXContent(XContentParser parser, boolean checkTr
}
}
if (retrieverBuilder != null) {
retrieverBuilder.extractToSearchSourceBuilder(this);
if (retrieverBuilder instanceof ClassicRetrieverBuilder) {
retrieverBuilder.validate(this);
} else {
retrieverBuilder.extractToSearchSourceBuilder(this);
retrieverBuilder = null;
}
}
searchUsageConsumer.accept(searchUsage);
return this;
Expand Down Expand Up @@ -1633,6 +1661,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t
builder.field(TERMINATE_AFTER_FIELD.getPreferredName(), terminateAfter);
}

if (retrieverBuilder != null) {
builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
}

if (subSearchSourceBuilders.isEmpty() == false) {
if (subSearchSourceBuilders.size() == 1) {
builder.field(QUERY_FIELD.getPreferredName(), subSearchSourceBuilders.get(0).getQueryBuilder());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,27 @@

package org.elasticsearch.search.retriever;

import org.apache.lucene.search.FieldDoc;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.InnerHitContextBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.search.SearchException;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.builder.SubSearchSourceBuilder;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.collapse.CollapseContext;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.rescore.RescorerBuilder;
import org.elasticsearch.search.searchafter.SearchAfterBuilder;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
Expand All @@ -30,8 +37,11 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

public final class ClassicRetrieverBuilder extends RetrieverBuilder<ClassicRetrieverBuilder> {

Expand Down Expand Up @@ -363,6 +373,40 @@ public ClassicRetrieverBuilder collapseBuilder(CollapseBuilder collapseBuilder)
return this;
}

public void doValidate(SearchSourceBuilder searchSourceBuilder) {
if (searchSourceBuilder.query() != null) {
throw new IllegalStateException("[query] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.searchAfter() != null) {
throw new IllegalStateException("[search_after] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER) {
throw new IllegalStateException("[terminate_after] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.sorts() != null) {
throw new IllegalStateException("[sort] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.minScore() != null) {
throw new IllegalStateException("[min_score] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.postFilter() != null) {
throw new IllegalStateException("[post_filter] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.rescores() != null) {
throw new IllegalStateException("[rescore] cannot be declared as a retriever value and as a global value");
}

if (searchSourceBuilder.collapse() != null) {
throw new IllegalStateException("[collapse] cannot be declared as a retriever value and as a global value");
}
}

public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
if (queryBuilder != null) {
searchSourceBuilder.subSearches().add(new SubSearchSourceBuilder(queryBuilder));
Expand Down Expand Up @@ -422,4 +466,80 @@ public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuild
throw new IllegalStateException("[collapse] cannot be declared as a retriever value and as a global value");
}
}

@Override
public SearchContext doBuildSearchContext(SearchContext searchContext) {
SearchShardTarget shardTarget = searchContext.shardTarget();
SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext();
Map<String, InnerHitContextBuilder> innerHitBuilders = new HashMap<>();
if (queryBuilder != null) {
InnerHitContextBuilder.extractInnerHits(queryBuilder, innerHitBuilders);
searchExecutionContext.setAliasFilter(searchContext.request().getAliasFilter().getQueryBuilder());
searchContext.parsedQuery(searchExecutionContext.toQuery(queryBuilder));
}
if (postFilterQueryBuilder != null) {
InnerHitContextBuilder.extractInnerHits(postFilterQueryBuilder, innerHitBuilders);
searchContext.parsedPostFilter(searchExecutionContext.toQuery(postFilterQueryBuilder));
}
if (innerHitBuilders.isEmpty() == false) {
for (Map.Entry<String, InnerHitContextBuilder> entry : innerHitBuilders.entrySet()) {
try {
entry.getValue().build(searchContext, searchContext.innerHits());
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to build inner_hits", e);
}
}
}
if (sortBuilders != null) {
try {
Optional<SortAndFormats> optionalSort = SortBuilder.buildSort(sortBuilders, searchExecutionContext);
if (optionalSort.isPresent()) {
searchContext.sort(optionalSort.get());
}
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create sort elements", e);
}
}
if (minScore != null) {
searchContext.minimumScore(minScore);
}
searchContext.terminateAfter(terminateAfter);
if (rescorerBuilders != null) {
try {
for (RescorerBuilder<?> rescore : rescorerBuilders) {
searchContext.addRescore(rescore.buildContext(searchExecutionContext));
}
} catch (IOException e) {
throw new SearchException(shardTarget, "failed to create RescoreSearchContext", e);
}
}
if (collapseBuilder != null) {
if (searchContext.scrollContext() != null) {
throw new SearchException(shardTarget, "cannot use `collapse` in a scroll context");
}
if (searchContext.rescore() != null && searchContext.rescore().isEmpty() == false) {
throw new SearchException(shardTarget, "cannot use `collapse` in conjunction with `rescore`");
}
final CollapseContext collapseContext = collapseBuilder.build(searchExecutionContext);
searchContext.collapse(collapseContext);
}
if (searchAfterBuilder != null) {
if (searchAfterBuilder == null) {
return null;
}
Object[] searchAfterValues = searchAfterBuilder.getSortValues();
if (searchContext.scrollContext() != null) {
throw new SearchException(shardTarget, "`search_after` cannot be used in a scroll context.");
}
if (searchContext.from() > 0) {
throw new SearchException(shardTarget, "`from` parameter must be set to 0 when `search_after` is used.");
}

String collapseField = collapseBuilder != null ? collapseBuilder.getField() : null;
FieldDoc fieldDoc = SearchAfterBuilder.buildFieldDoc(searchContext.sort(), searchAfterValues, collapseField);
searchContext.searchAfter(fieldDoc);
}

return searchContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
Expand Down Expand Up @@ -256,6 +257,11 @@ public int hashCode() {
);
}

@Override
public void doValidate(SearchSourceBuilder searchSourceBuilder) {
throw new UnsupportedOperationException();
}

@Override
public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
// TODO: add support for multiple knn retrievers per search request
Expand All @@ -272,4 +278,9 @@ public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuild
}
}

@Override
public SearchContext doBuildSearchContext(SearchContext searchContext) {
throw new UnsupportedOperationException();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.xcontent.AbstractObjectParser;
import org.elasticsearch.xcontent.FilterXContentParserWrapper;
import org.elasticsearch.xcontent.NamedObjectNotFoundException;
Expand Down Expand Up @@ -252,6 +253,12 @@ public RB _name(String _name) {
return (RB) this;
}

public final void validate(SearchSourceBuilder searchSourceBuilder) {
doValidate(searchSourceBuilder);
}

public abstract void doValidate(SearchSourceBuilder searchSourceBuilder);

public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
doExtractToSearchSourceBuilder(searchSourceBuilder);

Expand All @@ -265,4 +272,10 @@ public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceB
}

public abstract void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder);

public SearchContext buildSearchContext(SearchContext searchContext) {
return doBuildSearchContext(searchContext);
}

public abstract SearchContext doBuildSearchContext(SearchContext searchContext);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.xcontent.ObjectParser;
Expand Down Expand Up @@ -109,6 +110,11 @@ protected RRFRetrieverBuilder shallowCopyInstance() {
return new RRFRetrieverBuilder(this);
}

@Override
public void doValidate(SearchSourceBuilder searchSourceBuilder) {
throw new UnsupportedOperationException();
}

@Override
public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
for (RetrieverBuilder<?> retrieverBuilder : retrieverBuilders) {
Expand All @@ -121,4 +127,9 @@ public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuild
throw new IllegalStateException("[rank] cannot be declared as a retriever value and as a global value");
}
}

@Override
public SearchContext doBuildSearchContext(SearchContext searchContext) {
throw new UnsupportedOperationException();
}
}

0 comments on commit 32b7372

Please sign in to comment.