diff --git a/CHANGELOG.md b/CHANGELOG.md index 332a809294..78aea6681c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased 2.x] ### Added +- Added `minScore` and `maxDistance` to `KnnQuery` ([#1166](https://github.com/opensearch-project/opensearch-java/pull/1166)) ### Dependencies diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java index 596752f47c..59f03cc1ab 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQuery.java @@ -23,7 +23,12 @@ public class KnnQuery extends QueryBase implements QueryVariant { private final String field; private final float[] vector; - private final int k; + @Nullable + private final Integer k; + @Nullable + private final Float minScore; + @Nullable + private final Float maxDistance; @Nullable private final Query filter; @@ -32,7 +37,9 @@ private KnnQuery(Builder builder) { this.field = ApiTypeHelper.requireNonNull(builder.field, this, "field"); this.vector = ApiTypeHelper.requireNonNull(builder.vector, this, "vector"); - this.k = ApiTypeHelper.requireNonNull(builder.k, this, "k"); + this.k = builder.k; + this.minScore = builder.minScore; + this.maxDistance = builder.maxDistance; this.filter = builder.filter; } @@ -66,13 +73,29 @@ public final float[] vector() { } /** - * Required - The number of neighbors the search of each graph will return. + * Optional - The number of neighbors the search of each graph will return. * @return The number of neighbors to return. */ - public final int k() { + public final Integer k() { return this.k; } + /** + * Optional - The minimum score allowed for the returned search results. + * @return The minimum score allowed for the returned search results. + */ + private final Float minScore() { + return this.minScore; + } + + /** + * Optional - The maximum distance allowed between the vector and each of the returned search results. + * @return The maximum distance allowed between the vector and each ofthe returned search results. + */ + private final Float maxDistance() { + return this.maxDistance; + } + /** * Optional - A query to filter the results of the query. * @return The filter query. @@ -97,7 +120,17 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } generator.writeEnd(); - generator.write("k", this.k); + if (this.k != null) { + generator.write("k", this.k); + } + + if (this.minScore != null) { + generator.write("min_score", this.minScore); + } + + if (this.maxDistance != null) { + generator.write("max_distance", this.maxDistance); + } if (this.filter != null) { generator.writeKey("filter"); @@ -108,7 +141,7 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { } public Builder toBuilder() { - return toBuilder(new Builder()).field(field).vector(vector).k(k).filter(filter); + return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter); } /** @@ -122,6 +155,10 @@ public static class Builder extends QueryBase.AbstractBuilder implement @Nullable private Integer k; @Nullable + private Float minScore; + @Nullable + private Float maxDistance; + @Nullable private Query filter; /** @@ -156,6 +193,28 @@ public Builder k(@Nullable Integer k) { return this; } + /** + * Optional - The minimum score allowed for the returned search results. + * + * @param minScore The minimum score allowed for the returned search results. + * @return This builder. + */ + public Builder minScore(@Nullable Float minScore) { + this.minScore = minScore; + return this; + } + + /** + * Optional - The maximum distance allowed between the vector and each of the returned search results. + * + * @param maxDistance The maximum distance allowed between the vector and each ofthe returned search results. + * @return This builder. + */ + public Builder maxDistance(@Nullable Float maxDistance) { + this.maxDistance = maxDistance; + return this; + } + /** * Optional - A query to filter the results of the knn query. * @@ -201,6 +260,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer op) b.vector(vector); }, JsonpDeserializer.arrayDeserializer(JsonpDeserializer.floatDeserializer()), "vector"); op.add(Builder::k, JsonpDeserializer.integerDeserializer(), "k"); + op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score"); + op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance"); op.add(Builder::filter, Query._DESERIALIZER, "filter"); op.setKey(Builder::field, JsonpDeserializer.stringDeserializer()); diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java index a8a3fd779b..941f5224d7 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/_types/query_dsl/KnnQueryTest.java @@ -14,7 +14,7 @@ public class KnnQueryTest extends ModelTestCase { @Test public void toBuilder() { - KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).build(); + KnnQuery origin = new KnnQuery.Builder().field("field").vector(new float[] { 1.0f }).k(1).minScore(0.0f).maxDistance(1.0f).build(); KnnQuery copied = origin.toBuilder().build(); assertEquals(toJson(copied), toJson(origin)); diff --git a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java index f647ae56f0..2251ff4c4a 100644 --- a/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java +++ b/java-client/src/test/java/org/opensearch/client/opensearch/model/VariantsTest.java @@ -282,7 +282,7 @@ public void testHybridQuery() { assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); - assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); + assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k()); } @Test @@ -304,6 +304,6 @@ public void testHybridQueryFromJson() { assertEquals(100, searchRequest.query().hybrid().queries().get(1).neural().k()); assertEquals("passage_embedding", searchRequest.query().hybrid().queries().get(2).knn().field()); assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().vector().length); - assertEquals(2, searchRequest.query().hybrid().queries().get(2).knn().k()); + assertEquals(Integer.valueOf(2), searchRequest.query().hybrid().queries().get(2).knn().k()); } }