Skip to content

Commit

Permalink
feature: implement default model id for neural sparse
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws committed Feb 26, 2024
1 parent e6202af commit 8257486
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ private NeuralQueryEnricherProcessor(
@Override
public SearchRequest processRequest(SearchRequest searchRequest) {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
/* Use null check to pass the case where users are using empty query body. i.e. GET /index_name/_search */
if (queryBuilder != null) {
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
}
return searchRequest;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
Expand All @@ -32,6 +33,7 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -72,6 +74,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private String queryText;
private String modelId;
private Supplier<Map<String, Float>> queryTokensSupplier;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

/**
* Constructor from stream input
Expand All @@ -83,7 +86,11 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
this.modelId = in.readOptionalString();
} else {
this.modelId = in.readString();
}
if (in.readBoolean()) {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
Expand All @@ -94,7 +101,11 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeString(queryText);
out.writeString(modelId);
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);
}
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
Expand All @@ -108,7 +119,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
if (modelId != null) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand Down Expand Up @@ -152,11 +165,12 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw
sparseEncodingQueryBuilder.queryText(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME)
);
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
);

if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
);
}
return sparseEncodingQueryBuilder;
}

Expand Down Expand Up @@ -291,4 +305,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
}
}

0 comments on commit 8257486

Please sign in to comment.