Skip to content

Commit

Permalink
Add "filter" to neural query (#932) (#934)
Browse files Browse the repository at this point in the history
* 🩹 add filter to neural query



* CHANGELOG.md



* 💚 spotless fix



* 💚 spotless fix, jdk8 check



* 💚 spotless fix



* Update CHANGELOG.md




---------




(cherry picked from commit bc78613)

Signed-off-by: Lorenzo Caenazzo <lorenzo.caenazzo@optionfactory.net>
Signed-off-by: Grogdunn <frzlollo@gmail.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Andriy Redko <drreta@gmail.com>
  • Loading branch information
3 people authored Apr 10, 2024
1 parent d233863 commit c4d8a2e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Support weight function in function score query ([#880](https://github.com/opensearch-project/opensearch-java/pull/880))
- Fix pattern replace by making flag and replacement optional as on api ([#895](https://github.com/opensearch-project/opensearch-java/pull/895))
- Client with Java 8 runtime and Apache HttpClient 5 Transport fails with java.lang.NoSuchMethodError: java.nio.ByteBuffer.flip()Ljava/nio/ByteBuffer ([#920](https://github.com/opensearch-project/opensearch-java/pull/920))
- Add missed field "filter" to NeuralQuery model class

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public class NeuralQuery extends QueryBase implements QueryVariant {
private final int k;
@Nullable
private final String modelId;
@Nullable
private final Query filter;

private NeuralQuery(NeuralQuery.Builder builder) {
super(builder);
Expand All @@ -35,6 +37,7 @@ private NeuralQuery(NeuralQuery.Builder builder) {
this.queryText = ApiTypeHelper.requireNonNull(builder.queryText, this, "queryText");
this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k");
this.modelId = builder.modelId;
this.filter = builder.filter;
}

public static NeuralQuery of(Function<NeuralQuery.Builder, ObjectBuilder<NeuralQuery>> fn) {
Expand Down Expand Up @@ -93,6 +96,16 @@ public final String modelId() {
return this.modelId;
}

/**
* Optional - A query to filter the results of the query.
*
* @return The filter query.
*/
@Nullable
public final Query filter() {
return this.filter;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);
Expand All @@ -107,11 +120,16 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {

generator.write("k", this.k);

if (this.filter != null) {
generator.writeKey("filter");
this.filter.serialize(generator, mapper);
}

generator.writeEnd();
}

public Builder toBuilder() {
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId);
return new Builder().field(field).queryText(queryText).k(k).modelId(modelId).filter(filter);
}

/**
Expand All @@ -123,6 +141,8 @@ public static class Builder extends QueryBase.AbstractBuilder<NeuralQuery.Builde
private Integer k;
@Nullable
private String modelId;
@Nullable
private Query filter;

/**
* Required - The target field.
Expand Down Expand Up @@ -169,6 +189,17 @@ public NeuralQuery.Builder k(@Nullable Integer k) {
return this;
}

/**
* Optional - A query to filter the results of the knn query.
*
* @param filter The filter query.
* @return This builder.
*/
public NeuralQuery.Builder filter(@Nullable Query filter) {
this.filter = filter;
return this;
}

@Override
protected NeuralQuery.Builder self() {
return this;
Expand Down Expand Up @@ -198,6 +229,7 @@ protected static void setupNeuralQueryDeserializer(ObjectDeserializer<NeuralQuer
op.add(NeuralQuery.Builder::queryText, JsonpDeserializer.stringDeserializer(), "query_text");
op.add(NeuralQuery.Builder::modelId, JsonpDeserializer.stringDeserializer(), "model_id");
op.add(NeuralQuery.Builder::k, JsonpDeserializer.integerDeserializer(), "k");
op.add(NeuralQuery.Builder::filter, Query._DESERIALIZER, "filter");

op.setKey(NeuralQuery.Builder::field, JsonpDeserializer.stringDeserializer());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
public class NeuralQueryTest extends ModelTestCase {
@Test
public void toBuilder() {
NeuralQuery origin = new NeuralQuery.Builder().field("field").queryText("queryText").k(1).build();
NeuralQuery origin = new NeuralQuery.Builder().field("field")
.queryText("queryText")
.k(1)
.filter(IdsQuery.of(builder -> builder.values("Some_ID")).toQuery())
.build();
NeuralQuery copied = origin.toBuilder().build();

assertEquals(toJson(copied), toJson(origin));
Expand Down

0 comments on commit c4d8a2e

Please sign in to comment.