Skip to content

Commit

Permalink
add javadocs and a bit of clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
jdconrad committed Feb 2, 2024
1 parent 0434bd5 commit 2bd1882
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* A knn retriever is used to represent a knn search
* with some elements to specify parameters for that knn search.
*/
public final class KnnRetrieverBuilder extends RetrieverBuilder<KnnRetrieverBuilder> {

public static final String NAME = "knn";
Expand Down Expand Up @@ -106,7 +110,7 @@ public KnnRetrieverBuilder(
}

@Override
public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(field, queryVector, queryVectorBuilder, k, numCands, similarity);
if (preFilterQueryBuilders != null) {
knnSearchBuilder.addFilterQueries(preFilterQueryBuilders);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@
import java.util.Locale;
import java.util.Objects;

/**
* A retriever represents an API element that returns an ordered list of top
* documents. These can be obtained from a query, from another retriever, etc.
* Internally, a {@link RetrieverBuilder} is just a wrapper for other search
* elements that are extracted into a {@link SearchSourceBuilder}. The advantage
* retrievers have is in the API they appear as a tree-like structure enabling
* easier reasoning about what a search does.
*
* This is the base class for all other retrievers. This class does not support
* serialization and is expected to be fully extracted to a {@link SearchSourceBuilder}
* prior to any transport calls.
*/
public abstract class RetrieverBuilder<RB extends RetrieverBuilder<RB>> {

public static final NodeFeature NODE_FEATURE = new NodeFeature("retrievers");
Expand All @@ -37,13 +49,18 @@ protected static void declareBaseParserFields(
String name,
AbstractObjectParser<? extends RetrieverBuilder<?>, RetrieverParserContext> parser
) {
parser.declareObjectArray(RetrieverBuilder::preFilterQueryBuilders, (p, c) -> {
parser.declareObjectArray((r, v) -> r.preFilterQueryBuilders = v, (p, c) -> {
QueryBuilder preFilterQueryBuilder = AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage);
c.trackSectionUsage(name + ":" + PRE_FILTER_FIELD.getPreferredName());
return preFilterQueryBuilder;
}, PRE_FILTER_FIELD);
}

/**
* This method parsers a top-level retriever within a search and tracks its own depth. Currently, the
* maximum depth allowed is limited to 2 as a compound retriever cannot currently contain another
* compound retriever.
*/
public static RetrieverBuilder<?> parseTopLevelRetrieverBuilder(XContentParser parser, RetrieverParserContext context)
throws IOException {
parser = new FilterXContentParserWrapper(parser) {
Expand Down Expand Up @@ -153,19 +170,16 @@ protected static RetrieverBuilder<?> parseInnerRetrieverBuilder(XContentParser p

protected List<QueryBuilder> preFilterQueryBuilders = new ArrayList<>();

public List<QueryBuilder> preFilterQueryBuilders() {
/**
* Gets the filters for this retriever.
*/
public List<QueryBuilder> getPreFilterQueryBuilders() {
return preFilterQueryBuilders;
}

@SuppressWarnings("unchecked")
public RB preFilterQueryBuilders(List<QueryBuilder> preFilterQueryBuilders) {
this.preFilterQueryBuilders = preFilterQueryBuilders;
return (RB) this;
}

public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
doExtractToSearchSourceBuilder(searchSourceBuilder);
}

public abstract void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder);
/**
* This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}.
* Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}.
*/
public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

import java.util.Set;

/**
* Each retriever is given its own {@link NodeFeature} so new
* retrievers can be added individually with additional functionality.
*/
public class RetrieversFeatures implements FeatureSpecification {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
import java.io.IOException;
import java.util.List;

/**
* A standard retriever is used to represent anything that is a query along
* with some elements to specify parameters for that query.
*/
public final class StandardRetrieverBuilder extends RetrieverBuilder<StandardRetrieverBuilder> {

public static final String NAME = "standard";
Expand All @@ -44,37 +48,37 @@ public final class StandardRetrieverBuilder extends RetrieverBuilder<StandardRet
);

static {
PARSER.declareObject(StandardRetrieverBuilder::queryBuilder, (p, c) -> {
PARSER.declareObject((r, v) -> r.queryBuilder = v, (p, c) -> {
QueryBuilder queryBuilder = AbstractQueryBuilder.parseTopLevelQuery(p, c::trackQueryUsage);
c.trackSectionUsage(NAME + ":" + QUERY_FIELD.getPreferredName());
return queryBuilder;
}, QUERY_FIELD);

PARSER.declareField(StandardRetrieverBuilder::searchAfterBuilder, (p, c) -> {
PARSER.declareField((r, v) -> r.searchAfterBuilder = v, (p, c) -> {
SearchAfterBuilder searchAfterBuilder = SearchAfterBuilder.fromXContent(p);
c.trackSectionUsage(NAME + ":" + SEARCH_AFTER_FIELD.getPreferredName());
return searchAfterBuilder;
}, SEARCH_AFTER_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);

PARSER.declareField(StandardRetrieverBuilder::terminateAfter, (p, c) -> {
PARSER.declareField((r, v) -> r.terminateAfter = v, (p, c) -> {
int terminateAfter = p.intValue();
c.trackSectionUsage(NAME + ":" + TERMINATE_AFTER_FIELD.getPreferredName());
return terminateAfter;
}, TERMINATE_AFTER_FIELD, ObjectParser.ValueType.INT);

PARSER.declareField(StandardRetrieverBuilder::sortBuilders, (p, c) -> {
PARSER.declareField((r, v) -> r.sortBuilders = v, (p, c) -> {
List<SortBuilder<?>> sortBuilders = SortBuilder.fromXContent(p);
c.trackSectionUsage(NAME + ":" + SORT_FIELD.getPreferredName());
return sortBuilders;
}, SORT_FIELD, ObjectParser.ValueType.OBJECT_ARRAY);

PARSER.declareField(StandardRetrieverBuilder::minScore, (p, c) -> {
PARSER.declareField((r, v) -> r.minScore = v, (p, c) -> {
float minScore = p.floatValue();
c.trackSectionUsage(NAME + ":" + MIN_SCORE_FIELD.getPreferredName());
return minScore;
}, MIN_SCORE_FIELD, ObjectParser.ValueType.FLOAT);

PARSER.declareField(StandardRetrieverBuilder::collapseBuilder, (p, c) -> {
PARSER.declareField((r, v) -> r.collapseBuilder = v, (p, c) -> {
CollapseBuilder collapseBuilder = CollapseBuilder.fromXContent(p);
if (collapseBuilder.getField() != null) {
c.trackSectionUsage(COLLAPSE_FIELD.getPreferredName());
Expand All @@ -99,62 +103,9 @@ public static StandardRetrieverBuilder fromXContent(XContentParser parser, Retri
private Float minScore;
private CollapseBuilder collapseBuilder;

public QueryBuilder queryBuilder() {
return queryBuilder;
}

public StandardRetrieverBuilder queryBuilder(QueryBuilder queryBuilder) {
this.queryBuilder = queryBuilder;
return this;
}

public SearchAfterBuilder searchAfterBuilder() {
return searchAfterBuilder;
}

public StandardRetrieverBuilder searchAfterBuilder(SearchAfterBuilder searchAfterBuilder) {
this.searchAfterBuilder = searchAfterBuilder;
return this;
}

public int terminateAfter() {
return terminateAfter;
}

public StandardRetrieverBuilder terminateAfter(int terminateAfter) {
this.terminateAfter = terminateAfter;
return this;
}

public List<SortBuilder<?>> sortBuilders() {
return sortBuilders;
}

public StandardRetrieverBuilder sortBuilders(List<SortBuilder<?>> sortBuilders) {
this.sortBuilders = sortBuilders;
return this;
}

public Float minScore() {
return minScore;
}

public StandardRetrieverBuilder minScore(Float minScore) {
this.minScore = minScore;
return this;
}

public CollapseBuilder collapseBuilder() {
return collapseBuilder;
}

public StandardRetrieverBuilder collapseBuilder(CollapseBuilder collapseBuilder) {
this.collapseBuilder = collapseBuilder;
return this;
}

public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
if (preFilterQueryBuilders().isEmpty() == false) {
@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
if (preFilterQueryBuilders.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();

for (QueryBuilder preFilterQueryBuilder : preFilterQueryBuilders) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

import java.util.Set;

/**
* A set of features specifically for the rrf plugin.
*/
public class RRFFeatureSpecification implements FeatureSpecification {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@
import java.util.Collections;
import java.util.List;

/**
* An rrf retriever is used to represent an rrf rank element, but
* as a tree-like structure. This retriever is a compound retriever
* meaning it has a set of child retrievers that each return a set of
* top docs that will then be combined and ranked according to the rrf
* formula.
*/
public final class RRFRetrieverBuilder extends RetrieverBuilder<RRFRetrieverBuilder> {

public static final NodeFeature NODE_FEATURE = new NodeFeature(RRFRankPlugin.NAME + "_retriever");
Expand All @@ -36,15 +43,15 @@ public final class RRFRetrieverBuilder extends RetrieverBuilder<RRFRetrieverBuil
);

static {
PARSER.declareObjectArray((v, l) -> v.retrieverBuilders = l, (p, c) -> {
PARSER.declareObjectArray((r, v) -> r.retrieverBuilders = v, (p, c) -> {
p.nextToken();
String name = p.currentName();
RetrieverBuilder<?> retrieverBuilder = (RetrieverBuilder<?>) p.namedObject(RetrieverBuilder.class, name, c);
p.nextToken();
return retrieverBuilder;
}, RETRIEVERS_FIELD);
PARSER.declareInt((b, v) -> b.windowSize = v, WINDOW_SIZE_FIELD);
PARSER.declareInt((b, v) -> b.rankConstant = v, RANK_CONSTANT_FIELD);
PARSER.declareInt((r, v) -> r.windowSize = v, WINDOW_SIZE_FIELD);
PARSER.declareInt((r, v) -> r.rankConstant = v, RANK_CONSTANT_FIELD);

RetrieverBuilder.declareBaseParserFields(RRFRankPlugin.NAME, PARSER);
}
Expand All @@ -64,13 +71,13 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
private int rankConstant = RRFRankBuilder.DEFAULT_RANK_CONSTANT;

@Override
public void doExtractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder) {
for (RetrieverBuilder<?> retrieverBuilder : retrieverBuilders) {
if (preFilterQueryBuilders.isEmpty() == false) {
retrieverBuilder.preFilterQueryBuilders().addAll(preFilterQueryBuilders);
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}

retrieverBuilder.doExtractToSearchSourceBuilder(searchSourceBuilder);
retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder);
}

if (searchSourceBuilder.rankBuilder() == null) {
Expand Down

0 comments on commit 2bd1882

Please sign in to comment.