From cc1620c649eff00d99924d47c3ee50a2aa8ea505 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Thu, 12 Oct 2023 04:57:17 +0800 Subject: [PATCH 1/4] [bug fix] Fix async actions are left in neural_sparse query (#438) * add serialization and deserialization Signed-off-by: zhichao-aws * hash, equals. + UT Signed-off-by: zhichao-aws * tidy Signed-off-by: zhichao-aws * add test Signed-off-by: zhichao-aws --------- Signed-off-by: zhichao-aws (cherry picked from commit 51e6c00770d27fb4eabc20c38bdeff23c5c45997) --- .../query/NeuralSparseQueryBuilder.java | 26 +++++++++++- .../query/NeuralSparseQueryBuilderTests.java | 42 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 3e181c73f..a0b54d4e5 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -84,6 +84,11 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { this.fieldName = in.readString(); this.queryText = in.readString(); this.modelId = in.readString(); + this.maxTokenScore = in.readOptionalFloat(); + if (in.readBoolean()) { + Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); + this.queryTokensSupplier = () -> queryTokens; + } } @Override @@ -91,6 +96,13 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeString(queryText); out.writeString(modelId); + out.writeOptionalFloat(maxTokenScore); + if (queryTokensSupplier != null && queryTokensSupplier.get() != null) { + out.writeBoolean(true); + out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); + } else { + out.writeBoolean(false); + } } @Override @@ -257,15 +269,25 @@ private static void validateQueryTokens(Map queryTokens) { protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; + if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false; + if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) - .append(modelId, obj.modelId); + .append(modelId, obj.modelId) + .append(maxTokenScore, obj.maxTokenScore); + if (queryTokensSupplier != null) { + equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); + } return equalsBuilder.isEquals(); } @Override protected int doHashCode() { - return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode(); + HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore); + if (queryTokensSupplier != null) { + builder.append(queryTokensSupplier.get()); + } + return builder.toHashCode(); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 7ff6ca0cb..2ba8981f4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -26,6 +26,7 @@ import lombok.SneakyThrows; import org.opensearch.client.Client; +import org.opensearch.common.SetOnce; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -262,6 +263,23 @@ public void testStreams() { NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput); assertEquals(original, copy); + + SetOnce> queryTokensSetOnce = new SetOnce<>(); + queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f)); + original.queryTokensSupplier(queryTokensSetOnce::get); + + streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + copy = new NeuralSparseQueryBuilder(filterStreamInput); + assertEquals(original, copy); } public void testHashAndEquals() { @@ -275,6 +293,8 @@ public void testHashAndEquals() { float boost2 = 3.8f; String queryName1 = "query-1"; String queryName2 = "query-2"; + Map queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f); + Map queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f); NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) @@ -329,6 +349,22 @@ public void testHashAndEquals() { .boost(boost1) .queryName(queryName2); + // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .queryTokensSupplier(() -> queryTokens1); + + // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1) + .queryTokensSupplier(() -> queryTokens2); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -352,6 +388,12 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens); + assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode()); } @SneakyThrows From 5830b8e6b77ba7f6f197b432c629b739d3e822c3 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 30 Oct 2023 11:59:13 +0800 Subject: [PATCH 2/4] rm max_token_score Signed-off-by: zhichao-aws --- .../neuralsearch/query/NeuralSparseQueryBuilder.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index a0b54d4e5..2b66d023f 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -84,7 +84,6 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { this.fieldName = in.readString(); this.queryText = in.readString(); this.modelId = in.readString(); - this.maxTokenScore = in.readOptionalFloat(); if (in.readBoolean()) { Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); this.queryTokensSupplier = () -> queryTokens; @@ -96,7 +95,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeString(queryText); out.writeString(modelId); - out.writeOptionalFloat(maxTokenScore); if (queryTokensSupplier != null && queryTokensSupplier.get() != null) { out.writeBoolean(true); out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); @@ -273,8 +271,7 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) - .append(modelId, obj.modelId) - .append(maxTokenScore, obj.maxTokenScore); + .append(modelId, obj.modelId); if (queryTokensSupplier != null) { equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); } @@ -283,7 +280,7 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) { @Override protected int doHashCode() { - HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore); + HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId); if (queryTokensSupplier != null) { builder.append(queryTokensSupplier.get()); } From b1b0f63986e7280403cf70d56cc418d045d8ad12 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 30 Oct 2023 13:23:55 +0800 Subject: [PATCH 3/4] add changelog Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b92edd850..d5cba89c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +Fix async actions are left in neural_sparse query ([438](https://github.com/opensearch-project/neural-search/pull/438)) ### Infrastructure ### Documentation ### Maintenance From 48b1a4046abfe0b73b0031314702fc6e29367635 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Tue, 14 Nov 2023 10:56:57 +0800 Subject: [PATCH 4/4] tidy Signed-off-by: zhichao-aws --- .../query/NeuralSparseQueryBuilder.java | 13 +++++++------ .../query/NeuralSparseQueryBuilderTests.java | 6 +++--- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 2b66d023f..86859a054 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -9,6 +9,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.function.Supplier; import lombok.AllArgsConstructor; @@ -95,7 +96,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeString(queryText); out.writeString(modelId); - if (queryTokensSupplier != null && queryTokensSupplier.get() != null) { + if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) { out.writeBoolean(true); out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); } else { @@ -266,13 +267,13 @@ private static void validateQueryTokens(Map queryTokens) { @Override protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; - if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false; - if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false; + if (Objects.isNull(obj) || getClass() != obj.getClass()) return false; + if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false; + if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) .append(modelId, obj.modelId); - if (queryTokensSupplier != null) { + if (!Objects.isNull(queryTokensSupplier)) { equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); } return equalsBuilder.isEquals(); @@ -281,7 +282,7 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) { @Override protected int doHashCode() { HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId); - if (queryTokensSupplier != null) { + if (!Objects.isNull(queryTokensSupplier)) { builder.append(queryTokensSupplier.get()); } return builder.toHashCode(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 2ba8981f4..f3fa3264d 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -268,11 +268,11 @@ public void testStreams() { queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f)); original.queryTokensSupplier(queryTokensSetOnce::get); - streamOutput = new BytesStreamOutput(); - original.writeTo(streamOutput); + BytesStreamOutput streamOutput2 = new BytesStreamOutput(); + original.writeTo(streamOutput2); filterStreamInput = new NamedWriteableAwareStreamInput( - streamOutput.bytes().streamInput(), + streamOutput2.bytes().streamInput(), new NamedWriteableRegistry( List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) )