diff --git a/.github/workflows/test_aggregations.yml b/.github/workflows/test_aggregations.yml new file mode 100644 index 000000000..3bd5a66c0 --- /dev/null +++ b/.github/workflows/test_aggregations.yml @@ -0,0 +1,71 @@ +name: Run Additional Tests for Neural Search +on: + schedule: + - cron: '0 0 * * *' # every night + push: + branches: + - "*" + - "feature/**" + pull_request: + branches: + - "*" + - "feature/**" + +jobs: + Get-CI-Image-Tag: + uses: opensearch-project/opensearch-build/.github/workflows/get-ci-image-tag.yml@main + with: + product: opensearch + + Check-neural-search-linux: + needs: Get-CI-Image-Tag + strategy: + matrix: + java: [11, 17, 21] + os: [ubuntu-latest] + + name: Integ Tests Linux + runs-on: ${{ matrix.os }} + container: + # using the same image which is used by opensearch-build team to build the OpenSearch Distribution + # this image tag is subject to change as more dependencies and updates will arrive over time + image: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-version-linux }} + # need to switch to root so that github actions can install runner binary on container without permission issues. + options: --user root + + + steps: + - name: Checkout neural-search + uses: actions/checkout@v1 + + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.java }} + + - name: Run tests + run: | + chown -R 1000:1000 `pwd` + su `id -un 1000` -c "./gradlew ':integTest' -Dtest_aggs=true --tests \"org.opensearch.neuralsearch.query.aggregation.*IT\"" + + Check-neural-search-windows: + strategy: + matrix: + java: [11, 17, 21] + os: [windows-latest] + + name: Integ Tests Windows + runs-on: ${{ matrix.os }} + + steps: + - name: Checkout neural-search + uses: actions/checkout@v1 + + - name: Setup Java ${{ matrix.java }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.java }} + + - name: Run tests + run: | + ./gradlew ':integTest' -Dtest_aggs=true --tests "org.opensearch.neuralsearch.query.aggregation.*IT" diff --git a/CHANGELOG.md b/CHANGELOG.md index 2db64156a..03e1b0cac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578)) - Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615)) - Bug test for other BWC tests +- Add max_token_score field placeholder in NeuralSparseQueryBuilder to fix the rolling-upgrade from 2.x nodes bwc tests. ([#696](https://github.com/opensearch-project/neural-search/pull/696)) ### Infrastructure +- Adding integration tests for scenario of hybrid query with aggregations ([#632](https://github.com/opensearch-project/neural-search/pull/632)) ### Documentation ### Maintenance ### Refactoring @@ -19,7 +21,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.13...2.x) ### Features ### Enhancements +- BWC tests for text chunking processor ([#661](https://github.com/opensearch-project/neural-search/pull/661)) - Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670)) +- Allowing query by raw tokens in neural_sparse query ([#693](https://github.com/opensearch-project/neural-search/pull/693)) ### Bug Fixes - Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663)) ### Infrastructure diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 62187e543..58562826d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,10 +31,13 @@ To send us a pull request, please: 1. Fork the repository. 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. -3. Ensure local tests pass. -4. Commit to your fork using clear commit messages. -5. Send us a pull request, answering any default questions in the pull request interface. -6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. +3. Include tests that check your new feature or bug fix. Ideally, we're looking for unit, integration, and BWC tests, but that depends on how big and critical your change is. +If you're adding an integration test and it is using local ML models, please make sure that the number of model deployments is limited, and you're using the smallest possible model. +Each model deployment consumes resources, and having too many models may cause unexpected test failures. +4. Ensure local tests pass. +5. Commit to your fork using clear commit messages. +6. Send us a pull request, answering any default questions in the pull request interface. +7. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index acbb00883..47ae31be6 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -181,6 +181,11 @@ Additionally, to run integration tests on multi nodes with security enabled, run ./gradlew :integTest -Dsecurity.enabled=true -PnumNodes=3 ``` +Some integration tests are skipped by default, mainly to save time and resources. A special parameter is required to include those tests in the executed test suite. For example, the following command enables additional tests for aggregations when they are bundled with hybrid queries +``` +./gradlew :integTest -PnumNodes=3 -Dtest_aggs=true +``` + Integration tests can be run with remote cluster. For that run the following command and replace host/port/cluster name values with ones for the target cluster: ``` diff --git a/build.gradle b/build.gradle index 8f5c67c63..be234cfe6 100644 --- a/build.gradle +++ b/build.gradle @@ -308,6 +308,12 @@ task integTest(type: RestIntegTestTask) { description = "Run tests against a cluster" testClassesDirs = sourceSets.test.output.classesDirs classpath = sourceSets.test.runtimeClasspath + boolean runCompleteAggsTestSuite = Boolean.parseBoolean(System.getProperty('test_aggs', "false")) + if (!runCompleteAggsTestSuite) { + filter { + excludeTestsMatching "org.opensearch.neuralsearch.query.aggregation.*IT" + } + } } tasks.named("check").configure { dependsOn(integTest) } diff --git a/qa/restart-upgrade/build.gradle b/qa/restart-upgrade/build.gradle index 1a6d0a104..8fca43f3a 100644 --- a/qa/restart-upgrade/build.gradle +++ b/qa/restart-upgrade/build.gradle @@ -65,7 +65,7 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { @@ -83,6 +83,13 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -107,7 +114,7 @@ task testAgainstNewCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.is_old_cluster', 'false' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { @@ -125,6 +132,13 @@ task testAgainstNewCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java index c2d2657f4..bdbba92e8 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRestartUpgradeRestTestCase.java @@ -11,11 +11,11 @@ import org.junit.Before; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.NEURAL_SEARCH_BWC_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.CLIENT_TIMEOUT_VALUE; -import static org.opensearch.neuralsearch.TestUtils.RESTART_UPGRADE_OLD_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.BWC_VERSION; -import static org.opensearch.neuralsearch.TestUtils.generateModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.CLIENT_TIMEOUT_VALUE; +import static org.opensearch.neuralsearch.util.TestUtils.RESTART_UPGRADE_OLD_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.BWC_VERSION; +import static org.opensearch.neuralsearch.util.TestUtils.generateModelId; import org.opensearch.test.rest.OpenSearchRestTestCase; public abstract class AbstractRestartUpgradeRestTestCase extends BaseNeuralSearchIT { @@ -99,4 +99,11 @@ protected void createPipelineForSparseEncodingProcessor(final String modelId, fi ); createPipelineProcessor(requestBody, pipelineName, modelId); } + + protected void createPipelineForTextChunkingProcessor(String pipelineName) throws Exception { + String requestBody = Files.readString( + Path.of(classLoader.getResource("processor/PipelineForTextChunkingProcessorConfiguration.json").toURI()) + ); + createPipelineProcessor(requestBody, pipelineName, ""); + } } diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 48735182a..f5289fe79 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -11,12 +11,12 @@ import java.util.List; import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.getModelId; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index e6749d778..1d9dde2c6 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class MultiModalSearchIT extends AbstractRestartUpgradeRestTestCase { diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java index 876b2b0d7..02edb486c 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java @@ -4,12 +4,12 @@ */ package org.opensearch.neuralsearch.bwc; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; import org.opensearch.common.settings.Settings; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java index 22bd4a281..8ec54711a 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java @@ -10,10 +10,10 @@ import java.util.Map; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.neuralsearch.TestUtils; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; +import org.opensearch.neuralsearch.util.TestUtils; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; public class NeuralSparseSearchIT extends AbstractRestartUpgradeRestTestCase { diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java index ec5938cd9..27ca7f42d 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.getModelId; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class SemanticSearchIT extends AbstractRestartUpgradeRestTestCase { diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/TextChunkingProcessorIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/TextChunkingProcessorIT.java new file mode 100644 index 000000000..bab2d78d5 --- /dev/null +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/TextChunkingProcessorIT.java @@ -0,0 +1,75 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.bwc; + +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.index.query.MatchAllQueryBuilder; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; + +public class TextChunkingProcessorIT extends AbstractRestartUpgradeRestTestCase { + + private static final String PIPELINE_NAME = "pipeline-text-chunking"; + private static final String INPUT_FIELD = "body"; + private static final String OUTPUT_FIELD = "body_chunk"; + private static final String TEST_INDEX_SETTING_PATH = "processor/ChunkingIndexSettings.json"; + private static final String TEST_INGEST_TEXT = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + List expectedPassages = List.of( + "This is an example document to be chunked. The document ", + "contains a single paragraph, two sentences and 24 tokens by ", + "standard tokenizer in OpenSearch." + ); + + // Test rolling-upgrade text chunking processor + // Create Text Chunking Processor, Ingestion Pipeline and add document + // Validate process, pipeline and document count in restart-upgrade scenario + public void testTextChunkingProcessor_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + String indexName = getIndexNameForTest(); + if (isRunningAgainstOldCluster()) { + createPipelineForTextChunkingProcessor(PIPELINE_NAME); + createChunkingIndex(indexName); + addDocument(indexName, "0", INPUT_FIELD, TEST_INGEST_TEXT, null, null); + validateTestIndex(indexName, OUTPUT_FIELD, 1, expectedPassages); + } else { + try { + addDocument(indexName, "1", INPUT_FIELD, TEST_INGEST_TEXT, null, null); + validateTestIndex(indexName, OUTPUT_FIELD, 2, expectedPassages); + } finally { + wipeOfTestResources(indexName, PIPELINE_NAME, null, null); + } + } + } + + private void createChunkingIndex(String indexName) throws Exception { + URL documentURLPath = classLoader.getResource(TEST_INDEX_SETTING_PATH); + Objects.requireNonNull(documentURLPath); + String indexSetting = Files.readString(Path.of(documentURLPath.toURI())); + createIndexWithConfiguration(indexName, indexSetting, PIPELINE_NAME); + } + + private void validateTestIndex(String indexName, String fieldName, int documentCount, Object expected) { + int docCount = getDocCount(indexName); + assertEquals(documentCount, docCount); + MatchAllQueryBuilder query = new MatchAllQueryBuilder(); + Map searchResults = search(indexName, query, 10); + assertNotNull(searchResults); + Map document = getFirstInnerHit(searchResults); + assertNotNull(document); + Object documentSource = document.get("_source"); + assert (documentSource instanceof Map); + @SuppressWarnings("unchecked") + Map documentSourceMap = (Map) documentSource; + assert (documentSourceMap).containsKey(fieldName); + Object ingestOutputs = documentSourceMap.get(fieldName); + assertEquals(expected, ingestOutputs); + } +} diff --git a/qa/restart-upgrade/src/test/resources/processor/ChunkingIndexSettings.json b/qa/restart-upgrade/src/test/resources/processor/ChunkingIndexSettings.json new file mode 100644 index 000000000..956ffc585 --- /dev/null +++ b/qa/restart-upgrade/src/test/resources/processor/ChunkingIndexSettings.json @@ -0,0 +1,17 @@ +{ + "settings":{ + "default_pipeline": "%s", + "number_of_shards": 3, + "number_of_replicas": 1 + }, + "mappings": { + "properties": { + "body": { + "type": "text" + }, + "body_chunk": { + "type": "text" + } + } + } +} diff --git a/qa/restart-upgrade/src/test/resources/processor/PipelineForTextChunkingProcessorConfiguration.json b/qa/restart-upgrade/src/test/resources/processor/PipelineForTextChunkingProcessorConfiguration.json new file mode 100644 index 000000000..6c727b3b4 --- /dev/null +++ b/qa/restart-upgrade/src/test/resources/processor/PipelineForTextChunkingProcessorConfiguration.json @@ -0,0 +1,18 @@ +{ + "description": "An example fixed token length chunker pipeline with standard tokenizer", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk" + }, + "algorithm": { + "fixed_token_length": { + "token_limit": 10, + "tokenizer": "standard" + } + } + } + } + ] +} diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle index 591e83d58..eedea2d2d 100644 --- a/qa/rolling-upgrade/build.gradle +++ b/qa/rolling-upgrade/build.gradle @@ -83,6 +83,13 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -126,6 +133,13 @@ task testAgainstOneThirdUpgradedCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -150,7 +164,7 @@ task testAgainstTwoThirdsUpgradedCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { @@ -168,6 +182,13 @@ task testAgainstTwoThirdsUpgradedCluster(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' @@ -210,6 +231,13 @@ task testRollingUpgrade(type: StandaloneRestIntegTestTask) { } } + // Excluding the text chunking processor test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10") || ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.TextChunkingProcessorIT.*" + } + } + nonInputProperties.systemProperty('tests.rest.cluster', "${-> testClusters."${baseName}".allHttpSocketURI.join(",")}") nonInputProperties.systemProperty('tests.clustername', "${-> testClusters."${baseName}".getName()}") systemProperty 'tests.security.manager', 'false' diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java index 16ed2d229..fffc878e8 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/AbstractRollingUpgradeTestCase.java @@ -11,14 +11,14 @@ import org.junit.Before; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.NEURAL_SEARCH_BWC_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.OLD_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.MIXED_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.UPGRADED_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.ROLLING_UPGRADE_FIRST_ROUND; -import static org.opensearch.neuralsearch.TestUtils.BWCSUITE_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.BWC_VERSION; -import static org.opensearch.neuralsearch.TestUtils.generateModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.OLD_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.MIXED_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.UPGRADED_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.ROLLING_UPGRADE_FIRST_ROUND; +import static org.opensearch.neuralsearch.util.TestUtils.BWCSUITE_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.BWC_VERSION; +import static org.opensearch.neuralsearch.util.TestUtils.generateModelId; import org.opensearch.test.rest.OpenSearchRestTestCase; public abstract class AbstractRollingUpgradeTestCase extends BaseNeuralSearchIT { @@ -130,4 +130,11 @@ protected void createPipelineForSparseEncodingProcessor(String modelId, String p ); createPipelineProcessor(requestBody, pipelineName, modelId); } + + protected void createPipelineForTextChunkingProcessor(String pipelineName) throws Exception { + String requestBody = Files.readString( + Path.of(classLoader.getResource("processor/PipelineForTextChunkingProcessorConfiguration.json").toURI()) + ); + createPipelineProcessor(requestBody, pipelineName, ""); + } } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java index 292540820..903ffc9be 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/HybridSearchIT.java @@ -10,12 +10,12 @@ import java.util.List; import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java index b91ec1322..e10ddd17e 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/MultiModalSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_IMAGE_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class MultiModalSearchIT extends AbstractRollingUpgradeTestCase { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java index 281c78821..c9897447e 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.bwc; import org.opensearch.common.settings.Settings; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; @@ -13,9 +13,9 @@ import java.nio.file.Path; import java.util.List; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; public class NeuralQueryEnricherProcessorIT extends AbstractRollingUpgradeTestCase { // add prefix to avoid conflicts with other IT class, since we don't wipe resources after first round diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java index 70513686b..0801ea201 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralSparseSearchIT.java @@ -10,11 +10,11 @@ import java.util.Map; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.neuralsearch.TestUtils; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import org.opensearch.neuralsearch.util.TestUtils; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; public class NeuralSparseSearchIT extends AbstractRollingUpgradeTestCase { @@ -36,7 +36,7 @@ public class NeuralSparseSearchIT extends AbstractRollingUpgradeTestCase { // Test rolling-upgrade test sparse embedding processor // Create Sparse Encoding Processor, Ingestion Pipeline and add document - // Validate process , pipeline and document count in restart-upgrade scenario + // Validate process , pipeline and document count in rolling-upgrade scenario public void testSparseEncodingProcessor_E2EFlow() throws Exception { waitForClusterHealthGreen(NODES_BWC_CLUSTER); switch (getClusterType()) { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java index 51e548474..b9f7b15a9 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/SemanticSearchIT.java @@ -7,9 +7,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Map; -import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.TestUtils.getModelId; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; +import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class SemanticSearchIT extends AbstractRollingUpgradeTestCase { diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/TextChunkingProcessorIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/TextChunkingProcessorIT.java new file mode 100644 index 000000000..f2451c480 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/TextChunkingProcessorIT.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.bwc; + +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.index.query.MatchAllQueryBuilder; +import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; + +public class TextChunkingProcessorIT extends AbstractRollingUpgradeTestCase { + + private static final String PIPELINE_NAME = "pipeline-text-chunking"; + private static final String INPUT_FIELD = "body"; + private static final String OUTPUT_FIELD = "body_chunk"; + private static final String TEST_INDEX_SETTING_PATH = "processor/ChunkingIndexSettings.json"; + private static final int NUM_DOCS_PER_ROUND = 1; + private static final String TEST_INGEST_TEXT = + "This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch."; + + List expectedPassages = List.of( + "This is an example document to be chunked. The document ", + "contains a single paragraph, two sentences and 24 tokens by ", + "standard tokenizer in OpenSearch." + ); + + // Test rolling-upgrade text chunking processor + // Create Text Chunking Processor, Ingestion Pipeline and add document + // Validate process, pipeline and document count in rolling-upgrade scenario + public void testTextChunkingProcessor_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + String indexName = getIndexNameForTest(); + switch (getClusterType()) { + case OLD: + createPipelineForTextChunkingProcessor(PIPELINE_NAME); + createChunkingIndex(indexName); + addDocument(indexName, "0", INPUT_FIELD, TEST_INGEST_TEXT, null, null); + break; + case MIXED: + int totalDocsCountMixed; + if (isFirstMixedRound()) { + totalDocsCountMixed = NUM_DOCS_PER_ROUND; + validateTestIndex(indexName, OUTPUT_FIELD, totalDocsCountMixed, expectedPassages); + addDocument(indexName, "1", INPUT_FIELD, TEST_INGEST_TEXT, null, null); + } else { + totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; + validateTestIndex(indexName, OUTPUT_FIELD, totalDocsCountMixed, expectedPassages); + } + break; + case UPGRADED: + try { + int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND; + addDocument(indexName, "2", INPUT_FIELD, TEST_INGEST_TEXT, null, null); + validateTestIndex(indexName, OUTPUT_FIELD, totalDocsCountUpgraded, expectedPassages); + } finally { + wipeOfTestResources(indexName, PIPELINE_NAME, null, null); + } + break; + default: + throw new IllegalStateException("Unexpected value: " + getClusterType()); + } + } + + private void createChunkingIndex(String indexName) throws Exception { + URL documentURLPath = classLoader.getResource(TEST_INDEX_SETTING_PATH); + Objects.requireNonNull(documentURLPath); + String indexSetting = Files.readString(Path.of(documentURLPath.toURI())); + createIndexWithConfiguration(indexName, indexSetting, PIPELINE_NAME); + } + + private void validateTestIndex(String indexName, String fieldName, int documentCount, Object expected) { + int docCount = getDocCount(indexName); + assertEquals(documentCount, docCount); + MatchAllQueryBuilder query = new MatchAllQueryBuilder(); + Map searchResults = search(indexName, query, 10); + assertNotNull(searchResults); + Map document = getFirstInnerHit(searchResults); + assertNotNull(document); + Object documentSource = document.get("_source"); + assert (documentSource instanceof Map); + @SuppressWarnings("unchecked") + Map documentSourceMap = (Map) documentSource; + assert (documentSourceMap).containsKey(fieldName); + Object ingestOutputs = documentSourceMap.get(fieldName); + assertEquals(expected, ingestOutputs); + } +} diff --git a/qa/rolling-upgrade/src/test/resources/processor/ChunkingIndexSettings.json b/qa/rolling-upgrade/src/test/resources/processor/ChunkingIndexSettings.json new file mode 100644 index 000000000..956ffc585 --- /dev/null +++ b/qa/rolling-upgrade/src/test/resources/processor/ChunkingIndexSettings.json @@ -0,0 +1,17 @@ +{ + "settings":{ + "default_pipeline": "%s", + "number_of_shards": 3, + "number_of_replicas": 1 + }, + "mappings": { + "properties": { + "body": { + "type": "text" + }, + "body_chunk": { + "type": "text" + } + } + } +} diff --git a/qa/rolling-upgrade/src/test/resources/processor/PipelineForSparseEncodingProcessorConfiguration.json b/qa/rolling-upgrade/src/test/resources/processor/PipelineForSparseEncodingProcessorConfiguration.json index d9a358c24..fe885a0a2 100644 --- a/qa/rolling-upgrade/src/test/resources/processor/PipelineForSparseEncodingProcessorConfiguration.json +++ b/qa/rolling-upgrade/src/test/resources/processor/PipelineForSparseEncodingProcessorConfiguration.json @@ -1,13 +1,13 @@ { - "description": "An sparse encoding ingest pipeline", - "processors": [ - { - "sparse_encoding": { - "model_id": "%s", - "field_map": { - "passage_text": "passage_embedding" - } + "description": "An sparse encoding ingest pipeline", + "processors": [ + { + "sparse_encoding": { + "model_id": "%s", + "field_map": { + "passage_text": "passage_embedding" } } - ] - } + } + ] +} diff --git a/qa/rolling-upgrade/src/test/resources/processor/PipelineForTextChunkingProcessorConfiguration.json b/qa/rolling-upgrade/src/test/resources/processor/PipelineForTextChunkingProcessorConfiguration.json new file mode 100644 index 000000000..6c727b3b4 --- /dev/null +++ b/qa/rolling-upgrade/src/test/resources/processor/PipelineForTextChunkingProcessorConfiguration.json @@ -0,0 +1,18 @@ +{ + "description": "An example fixed token length chunker pipeline with standard tokenizer", + "processors" : [ + { + "text_chunking": { + "field_map": { + "body": "body_chunk" + }, + "algorithm": { + "fixed_token_length": { + "token_limit": 10, + "tokenizer": "standard" + } + } + } + } + ] +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 58e912788..6c3b06967 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -46,7 +47,7 @@ import lombok.extern.log4j.Log4j2; /** - * SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML SPARSE_ENCODING model + * SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model * or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed * to Lucene FeatureQuery wrapped by Lucene BooleanQuery. */ @@ -62,20 +63,26 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder> queryTokensSupplier; private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0; + public static void initialize(MLCommonsClientAccessor mlClient) { + NeuralSparseQueryBuilder.ML_CLIENT = mlClient; + } + /** * Constructor from stream input * @@ -91,24 +98,36 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { } else { this.modelId = in.readString(); } + this.maxTokenScore = in.readOptionalFloat(); if (in.readBoolean()) { Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); this.queryTokensSupplier = () -> queryTokens; } + // to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API + // after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead + if (StringUtils.EMPTY.equals(this.queryText)) { + this.queryText = null; + } + if (StringUtils.EMPTY.equals(this.modelId)) { + this.modelId = null; + } } @Override protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(fieldName); - out.writeString(queryText); + out.writeString(this.fieldName); + // to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API + // after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead + out.writeString(StringUtils.defaultString(this.queryText, StringUtils.EMPTY)); if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { out.writeOptionalString(this.modelId); } else { - out.writeString(this.modelId); + out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY)); } - if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) { + out.writeOptionalFloat(maxTokenScore); + if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) { out.writeBoolean(true); - out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); + out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); } else { out.writeBoolean(false); } @@ -118,10 +137,16 @@ protected void doWriteTo(StreamOutput out) throws IOException { protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); - xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + if (Objects.nonNull(queryText)) { + xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + } if (Objects.nonNull(modelId)) { xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); } + if (Objects.nonNull(maxTokenScore)) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore); + if (Objects.nonNull(queryTokensSupplier) && Objects.nonNull(queryTokensSupplier.get())) { + xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokensSupplier.get()); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -131,9 +156,20 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws * The expected parsing form looks like: * "SAMPLE_FIELD": { * "query_text": "string", - * "model_id": "string" + * "model_id": "string", + * "max_token_score": float (optional) * } * + * or + * "SAMPLE_FIELD": { + * "query_tokens": { + * "token_a": float, + * "token_b": float, + * ... + * } + * } + * + * * @param parser XContentParser * @return NeuralQueryBuilder * @throws IOException can be thrown by parser @@ -161,16 +197,40 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw } requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); - requireValue( - sparseEncodingQueryBuilder.queryText(), - String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) - ); - if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + if (Objects.isNull(sparseEncodingQueryBuilder.queryTokensSupplier())) { requireValue( - sparseEncodingQueryBuilder.modelId(), - String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) + sparseEncodingQueryBuilder.queryText(), + String.format( + Locale.ROOT, + "either %s field or %s field must be provided for [%s] query", + QUERY_TEXT_FIELD.getPreferredName(), + QUERY_TOKENS_FIELD.getPreferredName(), + NAME + ) + ); + if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format( + Locale.ROOT, + "using %s, %s field must be provided for [%s] query", + QUERY_TEXT_FIELD.getPreferredName(), + MODEL_ID_FIELD.getPreferredName(), + NAME + ) + ); + } + } + + if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.queryText())) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "%s field can not be empty", QUERY_TEXT_FIELD.getPreferredName()) ); } + if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.modelId())) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName())); + } + return sparseEncodingQueryBuilder; } @@ -189,12 +249,17 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui sparseEncodingQueryBuilder.queryText(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { sparseEncodingQueryBuilder.modelId(parser.text()); + } else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue()); } else { throw new ParsingException( parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName) ); } + } else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + Map queryTokens = parser.map(HashMap::new, XContentParser::floatValue); + sparseEncodingQueryBuilder.queryTokensSupplier(() -> queryTokens); } else { throw new ParsingException( parser.getTokenLocation(), @@ -227,6 +292,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws return new NeuralSparseQueryBuilder().fieldName(fieldName) .queryText(queryText) .modelId(modelId) + .maxTokenScore(maxTokenScore) .queryTokensSupplier(queryTokensSetOnce::get); } @@ -281,12 +347,13 @@ private static void validateQueryTokens(Map queryTokens) { protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (this == obj) return true; if (Objects.isNull(obj) || getClass() != obj.getClass()) return false; - if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false; - if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false; + if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false; + if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) - .append(modelId, obj.modelId); - if (!Objects.isNull(queryTokensSupplier)) { + .append(modelId, obj.modelId) + .append(maxTokenScore, obj.maxTokenScore); + if (Objects.nonNull(queryTokensSupplier)) { equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); } return equalsBuilder.isEquals(); @@ -294,8 +361,8 @@ protected boolean doEquals(NeuralSparseQueryBuilder obj) { @Override protected int doHashCode() { - HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId); - if (!Objects.isNull(queryTokensSupplier)) { + HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore); + if (queryTokensSupplier != null) { builder.append(queryTokensSupplier.get()); } return builder.toHashCode(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java index 0f4c49f27..baee337ce 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java @@ -4,9 +4,9 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.util.Collections; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index f6f59958f..05eb6829c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.ArrayList; @@ -164,7 +164,7 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu } @SneakyThrows - public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful() { + public void testQueryMatches_whenMultipleShards_thenSuccessful() { String modelId = null; try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); @@ -223,55 +223,37 @@ public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful // verify that all ids are unique assertEquals(Set.copyOf(ids).size(), ids.size()); - } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, null, modelId, SEARCH_PIPELINE); - } - } - - @SneakyThrows - public void testResultProcessor_whenMultipleShardsAndNoMatches_thenSuccessful() { - try { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); - createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + // verify case when there are partial match + HybridQueryBuilder hybridQueryBuilderPartialMatch = new HybridQueryBuilder(); + hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); - Map searchResponseAsMap = search( + Map searchResponseAsMapPartialMatch = search( TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, + hybridQueryBuilderPartialMatch, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertQueryResults(searchResponseAsMap, 0, true); - } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, null, null, SEARCH_PIPELINE); - } - } - - @SneakyThrows - public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessful() { - try { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); - createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + assertQueryResults(searchResponseAsMapPartialMatch, 4, true, Range.between(0.33f, 1.0f)); - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); - hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + // verify case when query doesn't have a match + HybridQueryBuilder hybridQueryBuilderNoMatches = new HybridQueryBuilder(); + hybridQueryBuilderNoMatches.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6)); + hybridQueryBuilderNoMatches.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); - Map searchResponseAsMap = search( + Map searchResponseAsMapNoMatches = search( TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, + hybridQueryBuilderNoMatches, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertQueryResults(searchResponseAsMap, 4, true, Range.between(0.33f, 1.0f)); + assertQueryResults(searchResponseAsMapNoMatches, 0, true); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, null, modelId, SEARCH_PIPELINE); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index dd185e227..7c443a825 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -39,7 +39,7 @@ import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 2f880ce74..5d88ffed9 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -21,7 +21,7 @@ import org.opensearch.action.OriginalIndices; import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.index.shard.ShardId; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index b3478984c..e1360474c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -6,11 +6,11 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; -import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.util.TestUtils.assertWeightedScores; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.Arrays; @@ -24,9 +24,9 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import com.google.common.primitives.Floats; import lombok.SneakyThrows; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 4f76c666e..da9b34f22 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -4,7 +4,7 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import java.util.List; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 175ea08fe..ff1a2001c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -4,10 +4,10 @@ */ package org.opensearch.neuralsearch.processor; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.assertHybridSearchResults; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.Arrays; @@ -20,9 +20,9 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import com.google.common.primitives.Floats; import lombok.SneakyThrows; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java index dd517aa17..d85865bb5 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextChunkingProcessorIT.java @@ -24,7 +24,6 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; public class TextChunkingProcessorIT extends BaseNeuralSearchIT { private static final String INDEX_NAME = "text_chunking_test_index"; @@ -197,20 +196,7 @@ private void createPipelineProcessor(String pipelineName) throws Exception { URL pipelineURLPath = classLoader.getResource(PIPELINE_CONFIGS_BY_NAME.get(pipelineName)); Objects.requireNonNull(pipelineURLPath); String requestBody = Files.readString(Path.of(pipelineURLPath.toURI())); - Response pipelineCreateResponse = makeRequest( - client(), - "PUT", - "/_ingest/pipeline/" + pipelineName, - null, - toHttpEntity(String.format(LOCALE, requestBody)), - ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) - ); - Map node = XContentHelper.convertToMap( - XContentType.JSON.xContent(), - EntityUtils.toString(pipelineCreateResponse.getEntity()), - false - ); - assertEquals("true", node.get("acknowledged").toString()); + createPipelineProcessor(requestBody, pipelineName, ""); } private void createTextChunkingIndex(String indexName, String pipelineName) throws Exception { @@ -222,13 +208,13 @@ private void createTextChunkingIndex(String indexName, String pipelineName) thro private void ingestDocument(String documentPath) throws Exception { URL documentURLPath = classLoader.getResource(documentPath); Objects.requireNonNull(documentURLPath); - String ingestDocument = Files.readString(Path.of(documentURLPath.toURI())); + String document = Files.readString(Path.of(documentURLPath.toURI())); Response response = makeRequest( client(), "POST", INDEX_NAME + "/_doc?refresh", null, - toHttpEntity(ingestDocument), + toHttpEntity(document), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); Map map = XContentHelper.convertToMap( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java index 43d629c71..cf1473293 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java @@ -34,29 +34,19 @@ public void setUp() throws Exception { updateClusterSettings(); } - public void testEmbeddingProcessor_whenIngestingDocumentWithSourceMatchingTextMapping_thenSuccessful() throws Exception { + public void testEmbeddingProcessor_whenIngestingDocumentWithOrWithoutSourceMatchingMapping_thenSuccessful() throws Exception { String modelId = null; try { modelId = uploadModel(); loadModel(modelId); createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING); createTextImageEmbeddingIndex(); + // verify doc with mapping ingestDocumentWithTextMappedToEmbeddingField(); assertEquals(1, getDocCount(INDEX_NAME)); - } finally { - wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); - } - } - - public void testEmbeddingProcessor_whenIngestingDocumentWithSourceWithoutMatchingInMapping_thenSuccessful() throws Exception { - String modelId = null; - try { - modelId = uploadModel(); - loadModel(modelId); - createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING); - createTextImageEmbeddingIndex(); + // verify doc without mapping ingestDocumentWithoutMappedFields(); - assertEquals(1, getDocCount(INDEX_NAME)); + assertEquals(2, getDocCount(INDEX_NAME)); } finally { wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java index 33620d149..fcb946d84 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/MLOpenSearchRerankProcessorIT.java @@ -12,8 +12,6 @@ import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.message.BasicHeader; -import org.junit.After; -import org.junit.Before; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; @@ -26,7 +24,7 @@ import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; @Log4j2 public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { @@ -36,33 +34,19 @@ public class MLOpenSearchRerankProcessorIT extends BaseNeuralSearchIT { private final static String TEXT_REP_1 = "Jacques loves fish. Fish make Jacques happy"; private final static String TEXT_REP_2 = "Fish like to eat plankton"; private final static String INDEX_CONFIG = "{\"mappings\": {\"properties\": {\"text_representation\": {\"type\": \"text\"}}}}"; - private String modelId; - - @After - @SneakyThrows - public void tearDown() { - super.tearDown(); - /* this is required to minimize chance of model not being deployed due to open memory CB, - * this happens in case we leave model from previous test case. We use new model for every test, and old model - * can be undeployed and deleted to free resources after each test case execution. - */ - deleteModel(modelId); - deleteSearchPipeline(PIPELINE_NAME); - deleteIndex(INDEX_NAME); - } - - @Before - @SneakyThrows - public void setup() { - modelId = uploadTextSimilarityModel(); - loadModel(modelId); - } @SneakyThrows public void testCrossEncoderRerankProcessor() { - createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); - setupIndex(); - runQueries(); + String modelId = null; + try { + modelId = uploadTextSimilarityModel(); + loadModel(modelId); + createSearchPipelineViaConfig(modelId, PIPELINE_NAME, "processor/RerankMLOpenSearchPipelineConfiguration.json"); + setupIndex(); + runQueries(); + } finally { + wipeOfTestResources(INDEX_NAME, null, modelId, PIPELINE_NAME); + } } private String uploadTextSimilarityModel() throws Exception { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java index 4647ebf5f..2a2fc7f34 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -27,25 +27,21 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.stream.IntStream; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; -import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; +import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQuery; /** * Integration tests for base scenarios when aggregations are combined with hybrid query */ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT { - private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = - "test-neural-aggs-pipeline-multi-doc-index-multiple-shards"; - private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-hybrid-aggs-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-hybrid-aggs-multi-doc-index-single-shard"; private static final String TEST_QUERY_TEXT3 = "hello"; private static final String TEST_QUERY_TEXT4 = "everyone"; private static final String TEST_QUERY_TEXT5 = "welcome"; @@ -53,7 +49,7 @@ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT { private static final String TEST_DOC_TEXT2 = "Hi to this place"; private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; - private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-aggregation-pipeline"; private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; private static final String TEST_DOC_TEXT5 = "People keep telling me orange but I still prefer pink"; private static final String TEST_DOC_TEXT6 = "She traveled because it cost the same as therapy and was a lot more enjoyable"; @@ -786,29 +782,6 @@ private Map executeQueryAndGetAggsResults( return searchResponseAsMap; } - private void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { - assertEquals(expected, getHitCount(searchResponseAsMap)); - - List> hits1NestedList = getNestedHits(searchResponseAsMap); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hits1NestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - - // verify that scores are in desc order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); - - Map total = getTotalHits(searchResponseAsMap); - assertNotNull(total.get("value")); - assertEquals(expected, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - } - private HybridQueryBuilder createHybridQueryBuilder(boolean isComplex) { if (isComplex) { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java index 7beb02dcc..8ff552698 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryBuilderTests.java @@ -12,7 +12,7 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; -import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 12fa534dd..be6942232 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -7,11 +7,11 @@ import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.opensearch.index.query.QueryBuilders.matchQuery; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; -import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; import java.io.IOException; import java.util.ArrayList; @@ -566,28 +566,17 @@ public void testRequestCache_whenMultipleShardsQueryReturnResults_thenSuccessful @SneakyThrows public void testWrappedQueryWithFilter_whenIndexAliasHasFilterAndIndexWithNestedFields_thenSuccess() { - String modelId = null; String alias = "alias_with_filter"; try { initializeIndexIfNotExist(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME); - modelId = prepareModel(); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); // create alias for index QueryBuilder aliasFilter = QueryBuilders.boolQuery() .mustNot(QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); createIndexAlias(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, alias, aliasFilter); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( - TEST_KNN_VECTOR_FIELD_NAME_1, - TEST_QUERY_TEXT, - "", - modelId, - 5, - null, - null - ); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); Map searchResponseAsMap = search( alias, @@ -608,34 +597,23 @@ public void testWrappedQueryWithFilter_whenIndexAliasHasFilterAndIndexWithNested assertEquals(RELATION_EQUAL_TO, total.get("relation")); } finally { deleteIndexAlias(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, alias); - wipeOfTestResources(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, null, modelId, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME, null, null, SEARCH_PIPELINE); } } @SneakyThrows public void testWrappedQueryWithFilter_whenIndexAliasHasFilters_thenSuccess() { - String modelId = null; String alias = "alias_with_filter"; try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); - modelId = prepareModel(); createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); // create alias for index QueryBuilder aliasFilter = QueryBuilders.boolQuery() .mustNot(QueryBuilders.matchQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); createIndexAlias(TEST_MULTI_DOC_INDEX_NAME, alias, aliasFilter); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( - TEST_KNN_VECTOR_FIELD_NAME_1, - TEST_QUERY_TEXT, - "", - modelId, - 5, - null, - null - ); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); Map searchResponseAsMap = search( alias, @@ -656,7 +634,7 @@ public void testWrappedQueryWithFilter_whenIndexAliasHasFilters_thenSuccess() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } finally { deleteIndexAlias(TEST_MULTI_DOC_INDEX_NAME, alias); - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, modelId, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); } } @@ -892,7 +870,15 @@ private void addDocsToIndex(final String testMultiDocIndexName) { Collections.singletonList(TEST_TEXT_FIELD_NAME_1), Collections.singletonList(TEST_DOC_TEXT2) ); - assertEquals(3, getDocCount(testMultiDocIndexName)); + addKnnDoc( + testMultiDocIndexName, + "4", + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + assertEquals(4, getDocCount(testMultiDocIndexName)); } private List> getNestedHits(Map searchResponseAsMap) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java new file mode 100644 index 000000000..7d33d07fe --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryPostFilterIT.java @@ -0,0 +1,577 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import lombok.SneakyThrows; +import org.junit.BeforeClass; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.MatchNoneQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQuery; + +public class HybridQueryPostFilterIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-hybrid-post-filter-multi-doc-index-multiple-shards"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = + "test-hybrid-post-filter-multi-doc-index-single-shard"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-post-filter-pipeline"; + private static final String INTEGER_FIELD_1_STOCK = "stock"; + private static final String TEXT_FIELD_1_NAME = "name"; + private static final String KEYWORD_FIELD_2_CATEGORY = "category"; + private static final String TEXT_FIELD_VALUE_1_DUNES = "Dunes part 1"; + private static final String TEXT_FIELD_VALUE_2_DUNES = "Dunes part 2"; + private static final String TEXT_FIELD_VALUE_3_MI_1 = "Mission Impossible 1"; + private static final String TEXT_FIELD_VALUE_4_MI_2 = "Mission Impossible 2"; + private static final String TEXT_FIELD_VALUE_5_TERMINAL = "The Terminal"; + private static final String TEXT_FIELD_VALUE_6_AVENGERS = "Avengers"; + private static final int INTEGER_FIELD_STOCK_1_25 = 25; + private static final int INTEGER_FIELD_STOCK_2_22 = 22; + private static final int INTEGER_FIELD_STOCK_3_256 = 256; + private static final int INTEGER_FIELD_STOCK_4_25 = 25; + private static final int INTEGER_FIELD_STOCK_5_20 = 20; + private static final String KEYWORD_FIELD_CATEGORY_1_DRAMA = "Drama"; + private static final String KEYWORD_FIELD_CATEGORY_2_ACTION = "Action"; + private static final String KEYWORD_FIELD_CATEGORY_3_SCI_FI = "Sci-fi"; + private static final String STOCK_AVG_AGGREGATION_NAME = "avg_stock_size"; + private static final int SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER = 1; + private static final int SHARDS_COUNT_IN_MULTI_NODE_CLUSTER = 3; + private static final int LTE_OF_RANGE_IN_HYBRID_QUERY = 400; + private static final int GTE_OF_RANGE_IN_HYBRID_QUERY = 200; + private static final int LTE_OF_RANGE_IN_POST_FILTER_QUERY = 400; + private static final int GTE_OF_RANGE_IN_POST_FILTER_QUERY = 230; + + @BeforeClass + @SneakyThrows + public static void setUpCluster() { + // we need new instance because we're calling non-static methods from static method. + // main purpose is to minimize network calls, initialization is only needed once + HybridQueryPostFilterIT instance = new HybridQueryPostFilterIT(); + instance.initClient(); + instance.updateClusterSettings(); + } + + @SneakyThrows + public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPostFilterOnIndexWithSingleShard_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_SINGLE_NODE_CLUSTER); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchEnabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testPostFilterOnIndexWithMultipleShards_whenConcurrentSearchDisabled_thenSuccessful() { + try { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + prepareResourcesBeforeTestExecution(SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); + testPostFilterRangeQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterBoolQuery(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + testPostFilterMatchAllAndMatchNoneQueries(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "mission" + } + }, + { + "term":{ + "name":{ + "value":"part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + + }, + "post_filter":{ + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + } + }*/ + @SneakyThrows + private void testPostFilterRangeQuery(String indexName) { + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + QueryBuilder postFilterQuery = createQueryBuilderWithRangeQuery( + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY + ); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 1, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + } + + /*{ + "query": { + "hybrid":{ + "queries":[ + { + "match":{ + "name": "mission" + } + }, + { + "term":{ + "name":{ + "value":"part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + + }, + "aggs": { + "avg_stock_size": { + "avg": { "field": "stock" } + } + }, + "post_filter":{ + "bool":{ + "should":[ + { + "range": { + "stock": { + "gte": 230, + "lte": 400 + } + } + }, + { + "match":{ + "name":"impossible" + } + } + + ] + } + } + }*/ + @SneakyThrows + private void testPostFilterBoolQuery(String indexName) { + // Case 1 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Range and a + // Match Query). + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + QueryBuilder postFilterQuery = createQueryBuilderWithBoolShouldQuery( + "impossible", + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY + ); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + // Case 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query), aggregation (Average stock price + // `avg_stock_price`) and a post filter query (Range Query and a Match Query). + AggregationBuilder aggsBuilder = createAvgAggregation(); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + List.of(aggsBuilder), + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 2, 1, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map aggValue = getAggregationValues(aggregations, STOCK_AVG_AGGREGATION_NAME); + assertEquals(1, aggValue.size()); + // Case 3 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Bool Query with + // a must clause(Range Query and a Match Query)). + postFilterQuery = createQueryBuilderWithBoolMustQuery( + "terminal", + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY + ); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + // Case 4 A Query with a combination of hybrid query (Match Query, Range Query) and a post filter query (Bool Query with a should + // clause(Range Query and a Match Query)). + hybridQueryBuilder = createHybridQueryBuilderScenarioWithMatchAndRangeQuery("hero", 5000, 1000); + postFilterQuery = createQueryBuilderWithBoolShouldQuery( + "impossible", + LTE_OF_RANGE_IN_POST_FILTER_QUERY, + GTE_OF_RANGE_IN_POST_FILTER_QUERY + ); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + } + + /*{ + "query": { + "hybrid": { + "queries": [ + { + "match": { + "name": "mission" + } + }, + { + "term": { + "name": { + "value": "part" + } + } + }, + { + "range": { + "stock": { + "gte": 200, + "lte": 400 + } + } + } + ] + } + }, + "post_filter": { + "match_all": {} + } + }*/ + @SneakyThrows + private void testPostFilterMatchAllAndMatchNoneQueries(String indexName) { + // CASE 1 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Match ALL + // Query). + HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + QueryBuilder postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(true); + + Map searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 4, 3, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + + // CASE 2 A Query with a combination of hybrid query (Match Query, Term Query, Range Query) and a post filter query (Match NONE + // Query). + postFilterQuery = createPostFilterQueryBuilderWithMatchAllOrNoneQuery(false); + searchResponseAsMap = search( + indexName, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + null, + postFilterQuery + ); + assertHybridQueryResults(searchResponseAsMap, 0, 0, GTE_OF_RANGE_IN_POST_FILTER_QUERY, LTE_OF_RANGE_IN_POST_FILTER_QUERY); + } + + private void assertHybridQueryResults( + Map searchResponseAsMap, + int resultsExpected, + int postFilterResultsValidationExpected, + int lte, + int gte + ) { + assertHitResultsFromQuery(resultsExpected, searchResponseAsMap); + List> hitsNestedList = getNestedHits(searchResponseAsMap); + + List docIndexes = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + Map source = (Map) oneHit.get("_source"); + int docIndex = (int) source.get(INTEGER_FIELD_1_STOCK); + docIndexes.add(docIndex); + } + assertEquals(postFilterResultsValidationExpected, docIndexes.stream().filter(docIndex -> docIndex < lte || docIndex > gte).count()); + } + + @SneakyThrows + void prepareResourcesBeforeTestExecution(int numShards) { + if (numShards == 1) { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, numShards); + } else { + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, numShards); + } + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + } + + @SneakyThrows + private void initializeIndexIfNotExists(String indexName, int numShards) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(), + numShards + ), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_2_DUNES), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_1_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_1_DUNES), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_2_22), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "3", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_3_MI_1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_3_256), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_4_MI_2), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_4_25), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_2_ACTION), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_5_TERMINAL), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_1_DRAMA), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEXT_FIELD_1_NAME), + Collections.singletonList(TEXT_FIELD_VALUE_6_AVENGERS), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1_STOCK), + List.of(INTEGER_FIELD_STOCK_5_20), + List.of(KEYWORD_FIELD_2_CATEGORY), + List.of(KEYWORD_FIELD_CATEGORY_3_SCI_FI), + List.of(), + List.of() + ); + } + } + + private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) { + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder).add(termQueryBuilder).add(rangeQueryBuilder); + return hybridQueryBuilder; + } + + private HybridQueryBuilder createHybridQueryBuilderScenarioWithMatchAndRangeQuery(String text, int lte, int gte) { + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder).add(rangeQueryBuilder); + return hybridQueryBuilder; + } + + private QueryBuilder createQueryBuilderWithRangeQuery(int lte, int gte) { + return QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + } + + private QueryBuilder createQueryBuilderWithBoolShouldQuery(String query, int lte, int gte) { + QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, query); + return QueryBuilders.boolQuery().should(rangeQuery).should(matchQuery); + } + + private QueryBuilder createQueryBuilderWithBoolMustQuery(String query, int lte, int gte) { + QueryBuilder rangeQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_1_STOCK).gte(gte).lte(lte); + QueryBuilder matchQuery = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, query); + return QueryBuilders.boolQuery().must(rangeQuery).must(matchQuery); + } + + private QueryBuilder createPostFilterQueryBuilderWithMatchAllOrNoneQuery(boolean isMatchAll) { + return isMatchAll ? QueryBuilders.matchAllQuery() : new MatchNoneQueryBuilder(); + } + + private AggregationBuilder createAvgAggregation() { + return AggregationBuilders.avg(STOCK_AVG_AGGREGATION_NAME).field(INTEGER_FIELD_1_STOCK); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index dd63abbea..1fa7e94c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -12,7 +12,7 @@ import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; -import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index 6f1e5f27e..2e4c766aa 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -4,11 +4,11 @@ */ package org.opensearch.neuralsearch.query; -import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; -import static org.opensearch.neuralsearch.TestUtils.TEST_DIMENSION; -import static org.opensearch.neuralsearch.TestUtils.TEST_SPACE_TYPE; -import static org.opensearch.neuralsearch.TestUtils.createRandomVector; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; import java.util.Collections; import java.util.List; @@ -27,7 +27,6 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index"; private static final String TEST_MULTI_VECTOR_FIELD_INDEX_NAME = "test-neural-multi-vector-field-index"; - private static final String TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME = "test-neural-text-and-vector-field-index"; private static final String TEST_NESTED_INDEX_NAME = "test-neural-nested-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_QUERY_TEXT = "Hello world"; @@ -45,67 +44,40 @@ public void setUp() throws Exception { } /** - * Tests basic query: + * Tests basic query with boost parameter: * { * "query": { * "neural": { * "text_knn": { * "query_text": "Hello world", * "model_id": "dcsdcasd", - * "k": 1 + * "k": 1, + * "boost": 2.0 * } * } * } * } - */ - @SneakyThrows - public void testBasicQuery() { - String modelId = null; - try { - initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); - modelId = prepareModel(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( - TEST_KNN_VECTOR_FIELD_NAME_1, - TEST_QUERY_TEXT, - "", - modelId, - 1, - null, - null - ); - Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); - Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); - - assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); - } finally { - wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); - } - } - - /** - * Tests basic query with boost parameter: + * and query with image query part * { * "query": { * "neural": { * "text_knn": { * "query_text": "Hello world", + * "query_image": "base64_1234567890", * "model_id": "dcsdcasd", - * "k": 1, - * "boost": 2.0 + * "k": 1 * } * } * } * } */ @SneakyThrows - public void testBoostQuery() { + public void testQueryWithBoostAndImageQuery() { String modelId = null; try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); modelId = prepareModel(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + NeuralQueryBuilder neuralQueryBuilderTextQuery = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, "", @@ -116,13 +88,33 @@ public void testBoostQuery() { ); final float boost = 2.0f; - neuralQueryBuilder.boost(boost); - Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); - Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + neuralQueryBuilderTextQuery.boost(boost); + Map searchResponseAsMapTextQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilderTextQuery, 1); + Map firstInnerHitTextQuery = getFirstInnerHit(searchResponseAsMapTextQuery); - assertEquals("1", firstInnerHit.get("_id")); + assertEquals("1", firstInnerHitTextQuery.get("_id")); float expectedScore = 2 * computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); + assertEquals(expectedScore, objectToFloat(firstInnerHitTextQuery.get("_score")), DELTA_FOR_SCORE_ASSERTION); + + NeuralQueryBuilder neuralQueryBuilderMultimodalQuery = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + TEST_IMAGE_TEXT, + modelId, + 1, + null, + null + ); + Map searchResponseAsMapMultimodalQuery = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilderMultimodalQuery, 1); + Map firstInnerHitMultimodalQuery = getFirstInnerHit(searchResponseAsMapMultimodalQuery); + + assertEquals("1", firstInnerHitMultimodalQuery.get("_id")); + float expectedScoreMultimodalQuery = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals( + expectedScoreMultimodalQuery, + objectToFloat(firstInnerHitMultimodalQuery.get("_score")), + DELTA_FOR_SCORE_ASSERTION + ); } finally { wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); } @@ -200,6 +192,27 @@ public void testRescoreQuery() { * } * } * } + * and bool should with BM25 and neural query: + * { + * "query": { + * "bool" : { + * "should": [ + * "neural": { + * "field_1": { + * "query_text": "Hello world", + * "model_id": "dcsdcasd", + * "k": 1 + * }, + * }, + * "match": { + * "field_2": { + * "query": "Hello world" + * } + * } + * ] + * } + * } + * } */ @SneakyThrows public void testBooleanQuery_withMultipleNeuralQueries() { @@ -207,8 +220,8 @@ public void testBooleanQuery_withMultipleNeuralQueries() { try { initializeIndexIfNotExist(TEST_MULTI_VECTOR_FIELD_INDEX_NAME); modelId = prepareModel(); - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - + // verify two neural queries wrapped into bool + BoolQueryBuilder boolQueryBuilderTwoNeuralQueries = new BoolQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, @@ -228,50 +241,21 @@ public void testBooleanQuery_withMultipleNeuralQueries() { null ); - boolQueryBuilder.should(neuralQueryBuilder1).should(neuralQueryBuilder2); + boolQueryBuilderTwoNeuralQueries.should(neuralQueryBuilder1).should(neuralQueryBuilder2); - Map searchResponseAsMap = search(TEST_MULTI_VECTOR_FIELD_INDEX_NAME, boolQueryBuilder, 1); - Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + Map searchResponseAsMapTwoNeuralQueries = search( + TEST_MULTI_VECTOR_FIELD_INDEX_NAME, + boolQueryBuilderTwoNeuralQueries, + 1 + ); + Map firstInnerHitTwoNeuralQueries = getFirstInnerHit(searchResponseAsMapTwoNeuralQueries); - assertEquals("1", firstInnerHit.get("_id")); + assertEquals("1", firstInnerHitTwoNeuralQueries.get("_id")); float expectedScore = 2 * computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); - } finally { - wipeOfTestResources(TEST_MULTI_VECTOR_FIELD_INDEX_NAME, null, modelId, null); - } - } - - /** - * Tests bool should with BM25 and neural query: - * { - * "query": { - * "bool" : { - * "should": [ - * "neural": { - * "field_1": { - * "query_text": "Hello world", - * "model_id": "dcsdcasd", - * "k": 1 - * }, - * }, - * "match": { - * "field_2": { - * "query": "Hello world" - * } - * } - * ] - * } - * } - * } - */ - @SneakyThrows - public void testBooleanQuery_withNeuralAndBM25Queries() { - String modelId = null; - try { - initializeIndexIfNotExist(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME); - modelId = prepareModel(); - BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + assertEquals(expectedScore, objectToFloat(firstInnerHitTwoNeuralQueries.get("_score")), DELTA_FOR_SCORE_ASSERTION); + // verify bool with one neural and one bm25 query + BoolQueryBuilder boolQueryBuilderMixOfNeuralAndBM25 = new BoolQueryBuilder(); NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, @@ -284,16 +268,20 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); - boolQueryBuilder.should(neuralQueryBuilder).should(matchQueryBuilder); + boolQueryBuilderMixOfNeuralAndBM25.should(neuralQueryBuilder).should(matchQueryBuilder); - Map searchResponseAsMap = search(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME, boolQueryBuilder, 1); - Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + Map searchResponseAsMapMixOfNeuralAndBM25 = search( + TEST_MULTI_VECTOR_FIELD_INDEX_NAME, + boolQueryBuilderMixOfNeuralAndBM25, + 1 + ); + Map firstInnerHitMixOfNeuralAndBM25 = getFirstInnerHit(searchResponseAsMapMixOfNeuralAndBM25); - assertEquals("1", firstInnerHit.get("_id")); + assertEquals("1", firstInnerHitMixOfNeuralAndBM25.get("_id")); float minExpectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); - assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); + assertTrue(minExpectedScore < objectToFloat(firstInnerHitMixOfNeuralAndBM25.get("_score"))); } finally { - wipeOfTestResources(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME, null, modelId, null); + wipeOfTestResources(TEST_MULTI_VECTOR_FIELD_INDEX_NAME, null, modelId, null); } } @@ -389,47 +377,6 @@ public void testFilterQuery() { } } - /** - * Tests basic query for multimodal: - * { - * "query": { - * "neural": { - * "text_knn": { - * "query_text": "Hello world", - * "query_image": "base64_1234567890", - * "model_id": "dcsdcasd", - * "k": 1 - * } - * } - * } - * } - */ - @SneakyThrows - public void testMultimodalQuery() { - String modelId = null; - try { - initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); - modelId = prepareModel(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( - TEST_KNN_VECTOR_FIELD_NAME_1, - TEST_QUERY_TEXT, - TEST_IMAGE_TEXT, - modelId, - 1, - null, - null - ); - Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); - Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); - - assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); - assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA_FOR_SCORE_ASSERTION); - } finally { - wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); - } - } - @SneakyThrows private void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -458,7 +405,9 @@ private void initializeIndexIfNotExist(String indexName) { TEST_MULTI_VECTOR_FIELD_INDEX_NAME, "1", List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), - List.of(Floats.asList(testVector).toArray(), Floats.asList(testVector).toArray()) + List.of(Floats.asList(testVector).toArray(), Floats.asList(testVector).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_QUERY_TEXT) ); assertEquals(1, getDocCount(TEST_MULTI_VECTOR_FIELD_INDEX_NAME)); } @@ -477,22 +426,6 @@ private void initializeIndexIfNotExist(String indexName) { assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME)); } - if (TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME)) { - prepareKnnIndex( - TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME, - Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)) - ); - addKnnDoc( - TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME, - "1", - Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), - Collections.singletonList(Floats.asList(testVector).toArray()), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_QUERY_TEXT) - ); - assertEquals(1, getDocCount(TEST_TEXT_AND_VECTOR_FIELD_INDEX_NAME)); - } - if (TEST_MULTI_DOC_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME)) { prepareKnnIndex( TEST_MULTI_DOC_INDEX_NAME, diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 89bcd57d7..4d2fe540d 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -9,10 +9,12 @@ import static org.mockito.Mockito.mock; import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; -import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.util.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MAX_TOKEN_SCORE_FIELD; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME; import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD; +import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TOKENS_FIELD; import java.io.IOException; import java.util.List; @@ -22,6 +24,10 @@ import java.util.function.BiConsumer; import java.util.function.Supplier; +import org.apache.commons.lang.StringUtils; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.junit.Before; import org.opensearch.Version; import org.opensearch.client.Client; @@ -37,9 +43,11 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; @@ -54,6 +62,7 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase { private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final float BOOST = 1.8f; private static final String QUERY_NAME = "queryName"; + private static final Float MAX_TOKEN_SCORE = 123f; private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); @Before @@ -88,6 +97,32 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); } + @SneakyThrows + public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_tokens": { + "token_a": float_score_a, + "token_b": float_score_b + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS_SUPPLIER.get()) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TOKENS_SUPPLIER.get(), sparseEncodingQueryBuilder.queryTokensSupplier().get()); + } + @SneakyThrows public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { /* @@ -121,6 +156,32 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName()); } + @SneakyThrows + public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "max_token_score": 123.0 + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); + assertWarnings("Deprecated field [max_token_score] used, this field is unused and will be removed entirely"); + } + @SneakyThrows public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { /* @@ -243,12 +304,56 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); } + @SneakyThrows + public void testFromXContent_whenBuildWithEmptyQuery_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), StringUtils.EMPTY) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithEmptyModelId_thenFail() { + /* + { + "VECTOR_FIELD": { + "model_id": "" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), StringUtils.EMPTY) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); + } + @SuppressWarnings("unchecked") @SneakyThrows public void testToXContent() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) .modelId(MODEL_ID) - .queryText(QUERY_TEXT); + .queryText(QUERY_TEXT) + .maxTokenScore(MAX_TOKEN_SCORE) + .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); XContentBuilder builder = XContentFactory.jsonBuilder(); builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -273,18 +378,32 @@ public void testToXContent() { assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); + assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0); + Map parsedQueryTokens = (Map) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName()); + assertEquals(QUERY_TOKENS_SUPPLIER.get().keySet(), parsedQueryTokens.keySet()); + for (Map.Entry entry : QUERY_TOKENS_SUPPLIER.get().entrySet()) { + assertEquals(entry.getValue(), parsedQueryTokens.get(entry.getKey()).floatValue(), 0); + } + } + + public void testStreams_whenCurrentVersion_thenSuccess() { + setUpClusterService(Version.CURRENT); + testStreams(); + testStreamsWithQueryTokensOnly(); } public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() { setUpClusterService(Version.V_2_12_0); testStreams(); + testStreamsWithQueryTokensOnly(); } @SneakyThrows - public void testStreams() { + private void testStreams() { NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); + original.maxTokenScore(MAX_TOKEN_SCORE); original.modelId(MODEL_ID); original.boost(BOOST); original.queryName(QUERY_NAME); @@ -306,11 +425,11 @@ public void testStreams() { queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f)); original.queryTokensSupplier(queryTokensSetOnce::get); - BytesStreamOutput streamOutput2 = new BytesStreamOutput(); - original.writeTo(streamOutput2); + streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); filterStreamInput = new NamedWriteableAwareStreamInput( - streamOutput2.bytes().streamInput(), + streamOutput.bytes().streamInput(), new NamedWriteableRegistry( List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) ) @@ -320,6 +439,26 @@ public void testStreams() { assertEquals(original, copy); } + @SneakyThrows + private void testStreamsWithQueryTokensOnly() { + NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); + original.fieldName(FIELD_NAME); + original.queryTokensSupplier(QUERY_TOKENS_SUPPLIER); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput); + assertEquals(original, copy); + } + public void testHashAndEquals() { String fieldName1 = "field 1"; String fieldName2 = "field 2"; @@ -327,6 +466,8 @@ public void testHashAndEquals() { String queryText2 = "query text 2"; String modelId1 = "model-1"; String modelId2 = "model-2"; + float maxTokenScore1 = 1.1f; + float maxTokenScore2 = 2.2f; float boost1 = 1.8f; float boost2 = 3.8f; String queryName1 = "query-1"; @@ -337,6 +478,7 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); @@ -344,18 +486,21 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); // Identical to sparseEncodingQueryBuilder_baseline except default boost and query name NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) - .modelId(modelId1); + .modelId(modelId1) + .maxTokenScore(maxTokenScore1); // Identical to sparseEncodingQueryBuilder_baseline except diff field name NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); @@ -363,6 +508,7 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText2) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); @@ -370,6 +516,7 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId2) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1); @@ -377,6 +524,7 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost2) .queryName(queryName1); @@ -384,13 +532,23 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName2); + // Identical to sparseEncodingQueryBuilder_baseline except diff max token score + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore2) + .boost(boost1) + .queryName(queryName1); + // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1) .queryTokensSupplier(() -> queryTokens1); @@ -399,10 +557,23 @@ public void testHashAndEquals() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) .modelId(modelId1) + .maxTokenScore(maxTokenScore1) .boost(boost1) .queryName(queryName1) .queryTokensSupplier(() -> queryTokens2); + // Identical to sparseEncodingQueryBuilder_baseline except null query text + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except null model id + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .boost(boost1) + .queryName(queryName1); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -427,11 +598,20 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode()); + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode()); assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens); assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullQueryText); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullQueryText.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullModelId); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullModelId.hashCode()); } @SneakyThrows @@ -486,4 +666,23 @@ private void setUpClusterService(Version version) { ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); NeuralSearchClusterUtil.instance().initialize(clusterService); } + + @SneakyThrows + public void testDoToQuery_successfulDoToQuery() { + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) + .maxTokenScore(MAX_TOKEN_SCORE) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); + QueryShardContext mockedQueryShardContext = mock(QueryShardContext.class); + MappedFieldType mockedMappedFieldType = mock(MappedFieldType.class); + doAnswer(invocation -> "rank_features").when(mockedMappedFieldType).typeName(); + doAnswer(invocation -> mockedMappedFieldType).when(mockedQueryShardContext).fieldMapper(any()); + + BooleanQuery.Builder targetQueryBuilder = new BooleanQuery.Builder(); + targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f), BooleanClause.Occur.SHOULD); + targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f), BooleanClause.Occur.SHOULD); + + assertEquals(sparseEncodingQueryBuilder.doToQuery(mockedQueryShardContext), targetQueryBuilder.build()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java index 0caca4f43..4790169e5 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java @@ -5,8 +5,8 @@ package org.opensearch.neuralsearch.query; import org.opensearch.neuralsearch.BaseNeuralSearchIT; -import static org.opensearch.neuralsearch.TestUtils.objectToFloat; - +import static org.opensearch.neuralsearch.util.TestUtils.objectToFloat; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomTokenWeightMap; import java.util.List; import java.util.Map; @@ -15,7 +15,7 @@ import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.util.TestUtils; import lombok.SneakyThrows; @@ -42,13 +42,14 @@ public void setUp() throws Exception { } /** - * Tests basic query: + * Tests basic query with boost: * { * "query": { * "neural_sparse": { * "text_sparse": { * "query_text": "Hello world a b", - * "model_id": "dcsdcasd" + * "model_id": "dcsdcasd", + * "boost": 2 * } * } * } @@ -62,12 +63,13 @@ public void testBasicQueryUsingQueryText() { modelId = prepareSparseEncodingModel(); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) .queryText(TEST_QUERY_TEXT) - .modelId(modelId); + .modelId(modelId) + .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } finally { wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); @@ -75,13 +77,18 @@ public void testBasicQueryUsingQueryText() { } /** - * Tests basic query: + * Tests basic query with boost: * { * "query": { * "neural_sparse": { * "text_sparse": { - * "query_text": "Hello world a b", - * "model_id": "dcsdcasd", + * "query_tokens": { + * "hello": float, + * "world": float, + * "a": float, + * "b": float, + * "c": float + * }, * "boost": 2 * } * } @@ -89,23 +96,21 @@ public void testBasicQueryUsingQueryText() { * } */ @SneakyThrows - public void testBoostQuery() { - String modelId = null; + public void testBasicQueryUsingQueryTokens() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); - modelId = prepareSparseEncodingModel(); + Map queryTokens = createRandomTokenWeightMap(TEST_TOKENS); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) - .modelId(modelId) + .queryTokensSupplier(() -> queryTokens) .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); assertEquals("1", firstInnerHit.get("_id")); - float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + float expectedScore = 2 * computeExpectedScore(testRankFeaturesDoc, sparseEncodingQueryBuilder.queryTokensSupplier().get()); assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); } finally { - wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, modelId, null); + wipeOfTestResources(TEST_BASIC_INDEX_NAME, null, null, null); } } @@ -213,11 +218,8 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() { * "model_id": "dcsdcasd" * } * }, - * "neural_sparse": { - * "field2": { - * "query_text": "Hello world a b", - * "model_id": "dcsdcasd" - * } + * "match": { + * "field2": "Hello world a b", * } * ] * } diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java new file mode 100644 index 000000000..521606fda --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BaseAggregationsWithHybridQueryIT.java @@ -0,0 +1,244 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query.aggregation; + +import lombok.SneakyThrows; +import org.junit.BeforeClass; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; + +public class BaseAggregationsWithHybridQueryIT extends BaseNeuralSearchIT { + protected static final String TEST_DOC_TEXT1 = "Hello world"; + protected static final String TEST_DOC_TEXT2 = "Hi to this place"; + protected static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + protected static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + protected static final String TEST_DOC_TEXT5 = "People keep telling me orange but I still prefer pink"; + protected static final String TEST_DOC_TEXT6 = "She traveled because it cost the same as therapy and was a lot more enjoyable"; + protected static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + protected static final String TEST_QUERY_TEXT3 = "hello"; + protected static final String TEST_QUERY_TEXT4 = "cost"; + protected static final String TEST_QUERY_TEXT5 = "welcome"; + protected static final String NESTED_TYPE_FIELD_USER = "user"; + protected static final String NESTED_FIELD_FIRSTNAME = "firstname"; + protected static final String NESTED_FIELD_LASTNAME = "lastname"; + protected static final String NESTED_FIELD_FIRSTNAME_JOHN = "john"; + protected static final String NESTED_FIELD_LASTNAME_BLACK = "black"; + protected static final String NESTED_FIELD_FIRSTNAME_FRODO = "frodo"; + protected static final String NESTED_FIELD_LASTNAME_BAGGINS = "baggins"; + protected static final String NESTED_FIELD_FIRSTNAME_MOHAMMED = "mohammed"; + protected static final String NESTED_FIELD_LASTNAME_EZAB = "ezab"; + protected static final String NESTED_FIELD_FIRSTNAME_SUN = "sun"; + protected static final String NESTED_FIELD_LASTNAME_WUKONG = "wukong"; + protected static final String NESTED_FIELD_FIRSTNAME_VASILISA = "vasilisa"; + protected static final String NESTED_FIELD_LASTNAME_WISE = "the wise"; + protected static final String INTEGER_FIELD_DOCINDEX = "doc_index"; + protected static final int INTEGER_FIELD_DOCINDEX_1234 = 1234; + protected static final int INTEGER_FIELD_DOCINDEX_2345 = 2345; + protected static final int INTEGER_FIELD_DOCINDEX_3456 = 3456; + protected static final int INTEGER_FIELD_DOCINDEX_4567 = 4567; + protected static final String KEYWORD_FIELD_DOCKEYWORD = "doc_keyword"; + protected static final String KEYWORD_FIELD_DOCKEYWORD_WORKABLE = "workable"; + protected static final String KEYWORD_FIELD_DOCKEYWORD_ANGRY = "angry"; + protected static final String KEYWORD_FIELD_DOCKEYWORD_LIKABLE = "likeable"; + protected static final String KEYWORD_FIELD_DOCKEYWORD_ENTIRE = "entire"; + protected static final String DATE_FIELD = "doc_date"; + protected static final String DATE_FIELD_01031995 = "01/03/1995"; + protected static final String DATE_FIELD_05022015 = "05/02/2015"; + protected static final String DATE_FIELD_07232007 = "07/23/2007"; + protected static final String DATE_FIELD_08212012 = "08/21/2012"; + protected static final String INTEGER_FIELD_PRICE = "doc_price"; + protected static final int INTEGER_FIELD_PRICE_130 = 130; + protected static final int INTEGER_FIELD_PRICE_100 = 100; + protected static final int INTEGER_FIELD_PRICE_200 = 200; + protected static final int INTEGER_FIELD_PRICE_25 = 25; + protected static final int INTEGER_FIELD_PRICE_30 = 30; + protected static final int INTEGER_FIELD_PRICE_350 = 350; + protected static final String BUCKET_AGG_DOC_COUNT_FIELD = "doc_count"; + protected static final String BUCKETS_AGGREGATION_NAME_1 = "date_buckets_1"; + protected static final String BUCKETS_AGGREGATION_NAME_2 = "date_buckets_2"; + protected static final String BUCKETS_AGGREGATION_NAME_3 = "date_buckets_3"; + protected static final String BUCKETS_AGGREGATION_NAME_4 = "date_buckets_4"; + protected static final String KEY = "key"; + protected static final String BUCKET_AGG_KEY_AS_STRING = "key_as_string"; + protected static final String SUM_AGGREGATION_NAME = "sum_aggs"; + protected static final String SUM_AGGREGATION_NAME_2 = "sum_aggs_2"; + protected static final String AVG_AGGREGATION_NAME = "avg_field"; + protected static final String GENERIC_AGGREGATION_NAME = "my_aggregation"; + protected static final String DATE_AGGREGATION_NAME = "date_aggregation"; + protected static final String CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH = "search.concurrent_segment_search.enabled"; + + @BeforeClass + @SneakyThrows + public static void setUpCluster() { + // we need new instance because we're calling non-static methods from static method. + // main purpose is to minimize network calls, initialization is only needed once + BaseAggregationsWithHybridQueryIT instance = new BaseAggregationsWithHybridQueryIT(); + instance.initClient(); + instance.updateClusterSettings(); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + protected void prepareResources(String indexName, String pipelineName) { + initializeIndexIfNotExist(indexName); + createSearchPipelineWithResultsPostProcessor(pipelineName); + } + + @SneakyThrows + protected void initializeIndexIfNotExist(String indexName) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + List.of(), + List.of(NESTED_TYPE_FIELD_USER, NESTED_FIELD_FIRSTNAME, NESTED_FIELD_LASTNAME), + List.of(INTEGER_FIELD_DOCINDEX), + List.of(KEYWORD_FIELD_DOCKEYWORD), + List.of(DATE_FIELD), + 3 + ), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(NESTED_TYPE_FIELD_USER), + List.of(Map.of(NESTED_FIELD_FIRSTNAME, NESTED_FIELD_FIRSTNAME_JOHN, NESTED_FIELD_LASTNAME, NESTED_FIELD_LASTNAME_BLACK)), + List.of(INTEGER_FIELD_DOCINDEX, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_DOCINDEX_1234, INTEGER_FIELD_PRICE_130), + List.of(KEYWORD_FIELD_DOCKEYWORD), + List.of(KEYWORD_FIELD_DOCKEYWORD_WORKABLE), + List.of(DATE_FIELD), + List.of(DATE_FIELD_01031995) + ); + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(NESTED_TYPE_FIELD_USER), + List.of(Map.of(NESTED_FIELD_FIRSTNAME, NESTED_FIELD_FIRSTNAME_FRODO, NESTED_FIELD_LASTNAME, NESTED_FIELD_LASTNAME_BAGGINS)), + List.of(INTEGER_FIELD_DOCINDEX, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_DOCINDEX_2345, INTEGER_FIELD_PRICE_100), + List.of(), + List.of(), + List.of(DATE_FIELD), + List.of(DATE_FIELD_05022015) + ); + addKnnDoc( + indexName, + "3", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2), + List.of(NESTED_TYPE_FIELD_USER), + List.of(Map.of(NESTED_FIELD_FIRSTNAME, NESTED_FIELD_FIRSTNAME_MOHAMMED, NESTED_FIELD_LASTNAME, NESTED_FIELD_LASTNAME_EZAB)), + List.of(INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_PRICE_200), + List.of(KEYWORD_FIELD_DOCKEYWORD), + List.of(KEYWORD_FIELD_DOCKEYWORD_ANGRY), + List.of(DATE_FIELD), + List.of(DATE_FIELD_07232007) + ); + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4), + List.of(NESTED_TYPE_FIELD_USER), + List.of(Map.of(NESTED_FIELD_FIRSTNAME, NESTED_FIELD_FIRSTNAME_SUN, NESTED_FIELD_LASTNAME, NESTED_FIELD_LASTNAME_WUKONG)), + List.of(INTEGER_FIELD_DOCINDEX, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_DOCINDEX_3456, INTEGER_FIELD_PRICE_25), + List.of(KEYWORD_FIELD_DOCKEYWORD), + List.of(KEYWORD_FIELD_DOCKEYWORD_LIKABLE), + List.of(DATE_FIELD), + List.of(DATE_FIELD_05022015) + ); + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5), + List.of(), + List.of(), + List.of(INTEGER_FIELD_DOCINDEX, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_DOCINDEX_3456, INTEGER_FIELD_PRICE_30), + List.of(KEYWORD_FIELD_DOCKEYWORD), + List.of(KEYWORD_FIELD_DOCKEYWORD_ENTIRE), + List.of(DATE_FIELD), + List.of(DATE_FIELD_08212012) + ); + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT6), + List.of(NESTED_TYPE_FIELD_USER), + List.of(Map.of(NESTED_FIELD_FIRSTNAME, NESTED_FIELD_FIRSTNAME_VASILISA, NESTED_FIELD_LASTNAME, NESTED_FIELD_LASTNAME_WISE)), + List.of(INTEGER_FIELD_DOCINDEX, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_DOCINDEX_4567, INTEGER_FIELD_PRICE_350), + List.of(KEYWORD_FIELD_DOCKEYWORD), + List.of(KEYWORD_FIELD_DOCKEYWORD_ENTIRE), + List.of(DATE_FIELD), + List.of(DATE_FIELD_08212012) + ); + } + } + + protected void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { + assertEquals(expected, getHitCount(searchResponseAsMap)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(expected, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java new file mode 100644 index 000000000..ce8854eed --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/BucketAggregationsWithHybridQueryIT.java @@ -0,0 +1,817 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query.aggregation; + +import lombok.SneakyThrows; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MinBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.StatsBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; + +/** + * Integration tests for bucket type aggregations when they are bundled with hybrid query + * Below is list of aggregations that are present in this test: + * - Adjacency matrix + * - Diversified sampler + * - Date histogram + * - Nested + * - Filter + * - Global + * - Sampler + * - Histogram + * - Significant terms + * - Terms + * + * Following aggs are tested by other integ tests: + * - Date range + * + * Below aggregations are not part of any test: + * - Filters + * - Geodistance + * - Geohash grid + * - Geohex grid + * - Geotile grid + * - IP range + * - Missing + * - Multi-terms + * - Range + * - Reverse nested + * - Significant text + */ +public class BucketAggregationsWithHybridQueryIT extends BaseAggregationsWithHybridQueryIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-aggs-bucket-multi-doc-index-multiple-shards"; + private static final String SEARCH_PIPELINE = "search-pipeline-bucket-aggs"; + + @SneakyThrows + public void testBucketAndNestedAggs_whenAdjacencyMatrix_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testAdjacencyMatrixAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenAdjacencyMatrix_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testAdjacencyMatrixAggs(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenDiversifiedSampler_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDiversifiedSampler(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDiversifiedSampler_thenFail() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + + testDiversifiedSampler(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenAvgNestedIntoFilter_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testAvgNestedIntoFilter(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenAvgNestedIntoFilter_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testAvgNestedIntoFilter(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenSumNestedIntoFilters_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testSumNestedIntoFilters(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenSumNestedIntoFilters_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testSumNestedIntoFilters(); + } + + @SneakyThrows + public void testBucketAggs_whenGlobalAggUsedWithQuery_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testGlobalAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenGlobalAggUsedWithQuery_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testGlobalAggs(); + } + + @SneakyThrows + public void testBucketAggs_whenHistogramAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testHistogramAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenHistogramAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testHistogramAggs(); + } + + @SneakyThrows + public void testBucketAggs_whenNestedAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testNestedAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenNestedAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testNestedAggs(); + } + + @SneakyThrows + public void testBucketAggs_whenSamplerAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testSampler(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenSamplerAgg_thenFail() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + + testSampler(); + } + + @SneakyThrows + public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs(); + } + + @SneakyThrows + public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketStatsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToBucketStatsAggs(); + } + + @SneakyThrows + public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketScriptAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketScriptedAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToBucketScriptedAggs(); + } + + @SneakyThrows + public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketScriptedAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenTermsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testTermsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenTermsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testTermsAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenSignificantTermsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testSignificantTermsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenSignificantTermsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testSignificantTermsAggs(); + } + + private void testAvgNestedIntoFilter() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.filter( + GENERIC_AGGREGATION_NAME, + QueryBuilders.rangeQuery(INTEGER_FIELD_DOCINDEX).lte(3000) + ).subAggregation(AggregationBuilders.avg(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX)); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + double avgValue = getAggregationValue(getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME), AVG_AGGREGATION_NAME); + assertEquals(1789.5, avgValue, DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testSumNestedIntoFilters() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.filters( + GENERIC_AGGREGATION_NAME, + QueryBuilders.rangeQuery(INTEGER_FIELD_DOCINDEX).lte(3000), + QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_WORKABLE) + ).otherBucket(true).subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX)); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(3, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(2, firstBucket.size()); + assertEquals(2, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(3579.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + + Map secondBucket = buckets.get(1); + assertEquals(2, secondBucket.size()); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + + Map thirdBucket = buckets.get(2); + assertEquals(2, thirdBucket.size()); + assertEquals(1, thirdBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(3456.0, getAggregationValue(thirdBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testGlobalAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + AggregationBuilder aggsBuilder = AggregationBuilders.global(GENERIC_AGGREGATION_NAME) + .subAggregation(AggregationBuilders.sum(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX)); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggsBuilder), + hybridQueryBuilderNeuralThenTerm, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + double avgValue = getAggregationValue(getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME), AVG_AGGREGATION_NAME); + assertEquals(15058.0, avgValue, DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testHistogramAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.histogram(GENERIC_AGGREGATION_NAME) + .field(INTEGER_FIELD_PRICE) + .interval(100); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(2, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(2, firstBucket.size()); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, (Double) firstBucket.get(KEY), DELTA_FOR_SCORE_ASSERTION); + + Map secondBucket = buckets.get(1); + assertEquals(2, secondBucket.size()); + assertEquals(2, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(100.0, (Double) secondBucket.get(KEY), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testNestedAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.nested(GENERIC_AGGREGATION_NAME, NESTED_TYPE_FIELD_USER) + .subAggregation( + AggregationBuilders.terms(BUCKETS_AGGREGATION_NAME_1) + .field(String.join(".", NESTED_TYPE_FIELD_USER, NESTED_FIELD_FIRSTNAME)) + ); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map nestedAgg = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(nestedAgg); + + assertEquals(3, nestedAgg.get(BUCKET_AGG_DOC_COUNT_FIELD)); + List> buckets = getAggregationBuckets(nestedAgg, BUCKETS_AGGREGATION_NAME_1); + + assertNotNull(buckets); + assertEquals(3, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(2, firstBucket.size()); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(NESTED_FIELD_FIRSTNAME_FRODO, firstBucket.get(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(2, secondBucket.size()); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(NESTED_FIELD_FIRSTNAME_JOHN, secondBucket.get(KEY)); + + Map thirdBucket = buckets.get(2); + assertEquals(2, thirdBucket.size()); + assertEquals(1, thirdBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(NESTED_FIELD_FIRSTNAME_SUN, thirdBucket.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testDiversifiedSampler() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.diversifiedSampler(GENERIC_AGGREGATION_NAME) + .field(KEYWORD_FIELD_DOCKEYWORD) + .shardSize(2) + .subAggregation(AggregationBuilders.terms(BUCKETS_AGGREGATION_NAME_1).field(KEYWORD_FIELD_DOCKEYWORD)); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map aggValue = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertEquals(2, aggValue.size()); + assertEquals(3, aggValue.get(BUCKET_AGG_DOC_COUNT_FIELD)); + Map nestedAggs = getAggregationValues(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertNotNull(nestedAggs); + assertEquals(0, nestedAggs.get("doc_count_error_upper_bound")); + List> buckets = getAggregationBuckets(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertEquals(2, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("likeable", firstBucket.get(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("workable", secondBucket.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testAdjacencyMatrixAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.adjacencyMatrix( + GENERIC_AGGREGATION_NAME, + Map.of( + "grpA", + QueryBuilders.matchQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_WORKABLE), + "grpB", + QueryBuilders.matchQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_ANGRY), + "grpC", + QueryBuilders.matchQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_LIKABLE) + ) + ); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(2, buckets.size()); + Map grpA = buckets.get(0); + assertEquals(1, grpA.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("grpA", grpA.get(KEY)); + Map grpC = buckets.get(1); + assertEquals(1, grpC.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("grpC", grpC.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testDateBucketedSumsPipelinedToBucketMinMaxSumAvgAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggDateHisto = AggregationBuilders.dateHistogram(GENERIC_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX)); + + BucketMetricsPipelineAggregationBuilder aggAvgBucket = PipelineAggregatorBuilders + .avgBucket(BUCKETS_AGGREGATION_NAME_1, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggSumBucket = PipelineAggregatorBuilders + .sumBucket(BUCKETS_AGGREGATION_NAME_2, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggMinBucket = PipelineAggregatorBuilders + .minBucket(BUCKETS_AGGREGATION_NAME_3, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggMaxBucket = PipelineAggregatorBuilders + .maxBucket(BUCKETS_AGGREGATION_NAME_4, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggDateHisto, aggAvgBucket, aggSumBucket, aggMinBucket, aggMaxBucket), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + assertResultsOfPipelineSumtoDateHistogramAggs(searchResponseAsMap); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void assertResultsOfPipelineSumtoDateHistogramAggs(Map searchResponseAsMap) { + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + double aggValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_1); + assertEquals(3517.5, aggValue, DELTA_FOR_SCORE_ASSERTION); + + double sumValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_2); + assertEquals(7035.0, sumValue, DELTA_FOR_SCORE_ASSERTION); + + double minValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_3); + assertEquals(1234.0, minValue, DELTA_FOR_SCORE_ASSERTION); + + double maxValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_4); + assertEquals(5801.0, maxValue, DELTA_FOR_SCORE_ASSERTION); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + Map firstBucket = buckets.get(0); + assertEquals(4, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(4, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(4, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } + + private void testDateBucketedSumsPipelinedToBucketStatsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggDateHisto = AggregationBuilders.dateHistogram(GENERIC_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX)); + + StatsBucketPipelineAggregationBuilder aggStatsBucket = PipelineAggregatorBuilders.statsBucket( + BUCKETS_AGGREGATION_NAME_1, + GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME + ); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggDateHisto, aggStatsBucket), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map statsAggs = getAggregationValues(aggregations, BUCKETS_AGGREGATION_NAME_1); + + assertNotNull(statsAggs); + + assertEquals(3517.5, (Double) statsAggs.get("avg"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(7035.0, (Double) statsAggs.get("sum"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, (Double) statsAggs.get("min"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(5801.0, (Double) statsAggs.get("max"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(2, (int) statsAggs.get("count")); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testDateBucketedSumsPipelinedToBucketScriptedAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggBuilder = AggregationBuilders.dateHistogram(DATE_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregations( + new AggregatorFactories.Builder().addAggregator( + AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX) + ) + .addAggregator( + AggregationBuilders.filter( + GENERIC_AGGREGATION_NAME, + QueryBuilders.boolQuery() + .should( + QueryBuilders.boolQuery() + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_WORKABLE)) + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_ANGRY)) + ) + .should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(KEYWORD_FIELD_DOCKEYWORD))) + ).subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME_2).field(INTEGER_FIELD_PRICE)) + ) + .addPipelineAggregator( + PipelineAggregatorBuilders.bucketScript( + BUCKETS_AGGREGATION_NAME_1, + Map.of("docNum", GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME_2, "totalNum", SUM_AGGREGATION_NAME), + new Script("params.docNum / params.totalNum") + ) + ) + ); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + List> buckets = getAggregationBuckets(aggregations, DATE_AGGREGATION_NAME); + + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + // first bucket have all the aggs values + Map firstBucket = buckets.get(0); + assertEquals(6, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.1053, getAggregationValue(firstBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map inBucketAggValues = getAggregationValues(firstBucket, GENERIC_AGGREGATION_NAME); + assertNotNull(inBucketAggValues); + assertEquals(1, inBucketAggValues.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(130.0, getAggregationValue(inBucketAggValues, SUM_AGGREGATION_NAME_2), DELTA_FOR_SCORE_ASSERTION); + + // second bucket is empty + Map secondBucket = buckets.get(1); + assertEquals(5, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertFalse(secondBucket.containsKey(BUCKETS_AGGREGATION_NAME_1)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map inSecondBucketAggValues = getAggregationValues(secondBucket, GENERIC_AGGREGATION_NAME); + assertNotNull(inSecondBucketAggValues); + assertEquals(0, inSecondBucketAggValues.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(inSecondBucketAggValues, SUM_AGGREGATION_NAME_2), DELTA_FOR_SCORE_ASSERTION); + + // last bucket has values + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(6, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0172, getAggregationValue(lastBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + + Map inLastBucketAggValues = getAggregationValues(lastBucket, GENERIC_AGGREGATION_NAME); + assertNotNull(inLastBucketAggValues); + assertEquals(1, inLastBucketAggValues.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(100.0, getAggregationValue(inLastBucketAggValues, SUM_AGGREGATION_NAME_2), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testSampler() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.sampler(GENERIC_AGGREGATION_NAME) + .shardSize(2) + .subAggregation(AggregationBuilders.terms(BUCKETS_AGGREGATION_NAME_1).field(KEYWORD_FIELD_DOCKEYWORD)); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map aggValue = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertEquals(2, aggValue.size()); + assertEquals(3, aggValue.get(BUCKET_AGG_DOC_COUNT_FIELD)); + Map nestedAggs = getAggregationValues(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertNotNull(nestedAggs); + assertEquals(0, nestedAggs.get("doc_count_error_upper_bound")); + List> buckets = getAggregationBuckets(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertEquals(2, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("likeable", firstBucket.get(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("workable", secondBucket.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testTermsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.terms(GENERIC_AGGREGATION_NAME).field(KEYWORD_FIELD_DOCKEYWORD); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + List> buckets = ((Map) getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME)).get( + "buckets" + ); + assertNotNull(buckets); + assertEquals(2, buckets.size()); + Map firstBucket = buckets.get(0); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(KEYWORD_FIELD_DOCKEYWORD_LIKABLE, firstBucket.get(KEY)); + Map secondBucket = buckets.get(1); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(KEYWORD_FIELD_DOCKEYWORD_WORKABLE, secondBucket.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testSignificantTermsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.significantTerms(GENERIC_AGGREGATION_NAME).field(KEYWORD_FIELD_DOCKEYWORD); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + + Map significantTermsAggregations = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + + assertNotNull(significantTermsAggregations); + assertEquals(3, (int) getAggregationValues(significantTermsAggregations, BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(11, (int) getAggregationValues(significantTermsAggregations, "bg_count")); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private Map executeQueryAndGetAggsResults(final Object aggsBuilder, String indexName) { + return executeQueryAndGetAggsResults(List.of(aggsBuilder), indexName); + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName, + int expectedHits + ) { + initializeIndexIfNotExist(indexName); + + Map searchResponseAsMap = search( + indexName, + queryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilders + ); + + assertHitResultsFromQuery(expectedHits, searchResponseAsMap); + return searchResponseAsMap; + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName + ) { + return executeQueryAndGetAggsResults(aggsBuilders, queryBuilder, indexName, 3); + } + + private Map executeQueryAndGetAggsResults(final List aggsBuilders, String indexName) { + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + return executeQueryAndGetAggsResults(aggsBuilders, hybridQueryBuilderNeuralThenTerm, indexName); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java new file mode 100644 index 000000000..36c853984 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/MetricAggregationsWithHybridQueryIT.java @@ -0,0 +1,495 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query.aggregation; + +import lombok.SneakyThrows; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; + +/** + * Integration tests for metric type aggregations when they are bundled with hybrid query + * Below is list of metric aggregations that are present in this test: + * - Average + * - Cardinality + * - Extended stats + * - Top hits + * - Percentile ranks + * - Scripted metric + * - Sum + * - Value count + * + * Following metric aggs are tested by other integ tests + * - Maximum + * + * + * Below aggregations are not part of any test + * - Geobounds + * - Matrix stats + * - Minimum + * - Percentile + * - Stats + */ +public class MetricAggregationsWithHybridQueryIT extends BaseAggregationsWithHybridQueryIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-aggs-metric-multi-doc-index-multiple-shards"; + private static final String SEARCH_PIPELINE = "search-pipeline-metric-aggs"; + + /** + * Tests complex query with multiple nested sub-queries: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "term": { + * "text": "word1" + * } + * }, + * { + * "term": { + * "text": "word3" + * } + * } + * ] + * } + * }, + * "aggs": { + * "max_index": { + * "max": { + * "field": "doc_index" + * } + * } + * } + * } + */ + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenAvgAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testAvgAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenCardinalityAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testCardinalityAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenCardinalityAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testCardinalityAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenExtendedStatsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testExtendedStatsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenExtendedStatsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testExtendedStatsAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenTopHitsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testTopHitsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenTopHitsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testTopHitsAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenPercentileRank_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testPercentileRankAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenPercentileRank_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testPercentileRankAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenPercentile_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testPercentileAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenPercentile_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testPercentileAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenScriptedMetrics_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testScriptedMetricsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenScriptedMetrics_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testScriptedMetricsAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenSumAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testSumAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenSumAgg_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testSumAggs(); + } + + @SneakyThrows + public void testMetricAggs_whenValueCount_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testValueCountAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenValueCount_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testValueCountAggs(); + } + + private void testAvgAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.avg(AVG_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(AVG_AGGREGATION_NAME)); + double maxAggsValue = getAggregationValue(aggregations, AVG_AGGREGATION_NAME); + assertEquals(maxAggsValue, 2345.0, DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testCardinalityAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + AggregationBuilder aggsBuilder = AggregationBuilders.cardinality(GENERIC_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + int aggsValue = getAggregationValue(aggregations, GENERIC_AGGREGATION_NAME); + assertEquals(aggsValue, 3); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testExtendedStatsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + AggregationBuilder aggsBuilder = AggregationBuilders.extendedStats(GENERIC_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + Map extendedStatsValues = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(extendedStatsValues); + + assertEquals((double) extendedStatsValues.get("max"), 3456.0, DELTA_FOR_SCORE_ASSERTION); + assertEquals((int) extendedStatsValues.get("count"), 3); + assertEquals((double) extendedStatsValues.get("sum"), 7035.0, DELTA_FOR_SCORE_ASSERTION); + assertEquals((double) extendedStatsValues.get("avg"), 2345.0, DELTA_FOR_SCORE_ASSERTION); + assertEquals((double) extendedStatsValues.get("variance"), 822880.666, DELTA_FOR_SCORE_ASSERTION); + assertEquals((double) extendedStatsValues.get("std_deviation"), 907.127, DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testTopHitsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + AggregationBuilder aggsBuilder = AggregationBuilders.topHits(GENERIC_AGGREGATION_NAME).size(4); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + Map aggsValues = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(aggsValues); + assertHitResultsFromQuery(3, aggsValues); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testScriptedMetricsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + // compute sum of all int fields that are not blank + AggregationBuilder aggsBuilder = AggregationBuilders.scriptedMetric(GENERIC_AGGREGATION_NAME) + .initScript(new Script("state.price = []")) + .mapScript( + new Script( + "state.price.add(doc[\"" + + INTEGER_FIELD_DOCINDEX + + "\"].size() == 0 ? 0 : doc." + + INTEGER_FIELD_DOCINDEX + + ".value)" + ) + ) + .combineScript(new Script("state.price.stream().mapToInt(Integer::intValue).sum()")) + .reduceScript(new Script("states.stream().mapToInt(Integer::intValue).sum()")); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + int aggsValue = getAggregationValue(aggregations, GENERIC_AGGREGATION_NAME); + assertEquals(7035, aggsValue); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testPercentileAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + AggregationBuilder aggsBuilder = AggregationBuilders.percentiles(GENERIC_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + assertHitResultsFromQuery(3, searchResponseAsMap); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + Map> aggsValues = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(aggsValues); + + Map values = aggsValues.get("values"); + assertNotNull(values); + assertEquals(7, values.size()); + assertEquals(1234.0, values.get("1.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, values.get("5.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, values.get("25.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(2345.0, values.get("50.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(3456.0, values.get("75.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(3456.0, values.get("95.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(3456.0, values.get("99.0"), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testPercentileRankAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + AggregationBuilder aggsBuilder = AggregationBuilders.percentileRanks(GENERIC_AGGREGATION_NAME, new double[] { 2000, 3000 }) + .field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + assertHitResultsFromQuery(3, searchResponseAsMap); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + Map> aggsValues = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(aggsValues); + Map values = aggsValues.get("values"); + assertNotNull(values); + assertEquals(33.333, values.get("2000.0"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(66.666, values.get("3000.0"), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testSumAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(SUM_AGGREGATION_NAME)); + double maxAggsValue = getAggregationValue(aggregations, SUM_AGGREGATION_NAME); + assertEquals(7035.0, maxAggsValue, DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testValueCountAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + AggregationBuilder aggsBuilder = AggregationBuilders.count(GENERIC_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggsBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + assertHitResultsFromQuery(3, searchResponseAsMap); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + assertTrue(aggregations.containsKey(GENERIC_AGGREGATION_NAME)); + assertEquals(3, (int) getAggregationValue(aggregations, GENERIC_AGGREGATION_NAME)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testSumAggsAndRangePostFilter() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder3); + + QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(INTEGER_FIELD_DOCINDEX).gte(3000).lte(5000); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + List.of(aggsBuilder), + rangeFilterQuery + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(SUM_AGGREGATION_NAME)); + double maxAggsValue = getAggregationValue(aggregations, SUM_AGGREGATION_NAME); + assertEquals(11602.0, maxAggsValue, DELTA_FOR_SCORE_ASSERTION); + + assertHitResultsFromQuery(2, searchResponseAsMap); + + // assert post-filter + List> hitsNestedList = getNestedHits(searchResponseAsMap); + + List docIndexes = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + assertNotNull(oneHit.get("_source")); + Map source = (Map) oneHit.get("_source"); + int docIndex = (int) source.get(INTEGER_FIELD_DOCINDEX); + docIndexes.add(docIndex); + } + assertEquals(0, docIndexes.stream().filter(docIndex -> docIndex < 3000 || docIndex > 5000).count()); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private Map executeQueryAndGetAggsResults(final Object aggsBuilder, String indexName) { + return executeQueryAndGetAggsResults(List.of(aggsBuilder), indexName); + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName, + int expectedHits + ) { + initializeIndexIfNotExist(indexName); + + Map searchResponseAsMap = search( + indexName, + queryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilders + ); + + assertHitResultsFromQuery(expectedHits, searchResponseAsMap); + return searchResponseAsMap; + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName + ) { + return executeQueryAndGetAggsResults(aggsBuilders, queryBuilder, indexName, 3); + } + + private Map executeQueryAndGetAggsResults(final List aggsBuilders, String indexName) { + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + return executeQueryAndGetAggsResults(aggsBuilders, hybridQueryBuilderNeuralThenTerm, indexName); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java new file mode 100644 index 000000000..168dce1e0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/aggregation/PipelineAggregationsWithHybridQueryIT.java @@ -0,0 +1,408 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query.aggregation; + +import lombok.SneakyThrows; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.pipeline.StatsBucketPipelineAggregationBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.SortOrder; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; + +/** + * Integration tests for pipeline type aggregations when they are bundled with hybrid query + * Below is list of aggregations that are present in this test: + * - bucket_sort + * - cumulative_sum + * + * Following metric aggs are tested by other integ tests: + * - min_bucket + * - max_bucket + * - sum_bucket + * - avg_bucket + * + * Below aggregations are not part of any test: + * - derivative + * - moving_avg + * - serial_diff + */ +public class PipelineAggregationsWithHybridQueryIT extends BaseAggregationsWithHybridQueryIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-aggs-pipeline-multi-doc-index-multiple-shards"; + private static final String SEARCH_PIPELINE = "search-pipeline-pipeline-aggs"; + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToBucketStatsAggs(); + } + + @SneakyThrows + public void testPipelineSiblingAggs_whenDateBucketedSumsPipelinedToBucketStatsAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketStatsAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToBucketScriptedAggs(); + } + + @SneakyThrows + public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketScriptedAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketScriptedAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToBucketSortAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToBucketSortAggs(); + } + + @SneakyThrows + public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToBucketSortAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToBucketSortAggs(); + } + + @SneakyThrows + public void testWithConcurrentSegmentSearch_whenDateBucketedSumsPipelinedToCumulativeSumAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, true); + testDateBucketedSumsPipelinedToCumulativeSumAggs(); + } + + @SneakyThrows + public void testPipelineParentAggs_whenDateBucketedSumsPipelinedToCumulativeSumAggs_thenSuccessful() { + updateClusterSettings(CLUSTER_SETTING_CONCURRENT_SEGMENT_SEARCH, false); + testDateBucketedSumsPipelinedToCumulativeSumAggs(); + } + + private void testDateBucketedSumsPipelinedToBucketStatsAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggDateHisto = AggregationBuilders.dateHistogram(GENERIC_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX)); + + StatsBucketPipelineAggregationBuilder aggStatsBucket = PipelineAggregatorBuilders.statsBucket( + BUCKETS_AGGREGATION_NAME_1, + GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME + ); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggDateHisto, aggStatsBucket), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map statsAggs = getAggregationValues(aggregations, BUCKETS_AGGREGATION_NAME_1); + + assertNotNull(statsAggs); + + assertEquals(3517.5, (Double) statsAggs.get("avg"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(7035.0, (Double) statsAggs.get("sum"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, (Double) statsAggs.get("min"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(5801.0, (Double) statsAggs.get("max"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(2, (int) statsAggs.get("count")); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testDateBucketedSumsPipelinedToBucketScriptedAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggBuilder = AggregationBuilders.dateHistogram(DATE_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregations( + new AggregatorFactories.Builder().addAggregator( + AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX) + ) + .addAggregator( + AggregationBuilders.filter( + GENERIC_AGGREGATION_NAME, + QueryBuilders.boolQuery() + .should( + QueryBuilders.boolQuery() + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_WORKABLE)) + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_ANGRY)) + ) + .should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(KEYWORD_FIELD_DOCKEYWORD))) + ).subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME_2).field(INTEGER_FIELD_PRICE)) + ) + .addPipelineAggregator( + PipelineAggregatorBuilders.bucketScript( + BUCKETS_AGGREGATION_NAME_1, + Map.of("docNum", GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME_2, "totalNum", SUM_AGGREGATION_NAME), + new Script("params.docNum / params.totalNum") + ) + ) + ); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + aggBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + List> buckets = getAggregationBuckets(aggregations, DATE_AGGREGATION_NAME); + + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + // first bucket have all the aggs values + Map firstBucket = buckets.get(0); + assertEquals(6, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.1053, getAggregationValue(firstBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map inBucketAggValues = getAggregationValues(firstBucket, GENERIC_AGGREGATION_NAME); + assertNotNull(inBucketAggValues); + assertEquals(1, inBucketAggValues.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(130.0, getAggregationValue(inBucketAggValues, SUM_AGGREGATION_NAME_2), DELTA_FOR_SCORE_ASSERTION); + + // second bucket is empty + Map secondBucket = buckets.get(1); + assertEquals(5, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertFalse(secondBucket.containsKey(BUCKETS_AGGREGATION_NAME_1)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map inSecondBucketAggValues = getAggregationValues(secondBucket, GENERIC_AGGREGATION_NAME); + assertNotNull(inSecondBucketAggValues); + assertEquals(0, inSecondBucketAggValues.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(inSecondBucketAggValues, SUM_AGGREGATION_NAME_2), DELTA_FOR_SCORE_ASSERTION); + + // last bucket has values + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(6, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0172, getAggregationValue(lastBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + + Map inLastBucketAggValues = getAggregationValues(lastBucket, GENERIC_AGGREGATION_NAME); + assertNotNull(inLastBucketAggValues); + assertEquals(1, inLastBucketAggValues.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(100.0, getAggregationValue(inLastBucketAggValues, SUM_AGGREGATION_NAME_2), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testDateBucketedSumsPipelinedToBucketSortAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggBuilder = AggregationBuilders.dateHistogram(DATE_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregations( + new AggregatorFactories.Builder().addAggregator( + AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX) + ) + .addPipelineAggregator( + PipelineAggregatorBuilders.bucketSort( + BUCKETS_AGGREGATION_NAME_1, + List.of(new FieldSortBuilder(SUM_AGGREGATION_NAME).order(SortOrder.DESC)) + ).size(5) + ) + ); + + QueryBuilder queryBuilder = QueryBuilders.boolQuery() + .should( + QueryBuilders.boolQuery() + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_WORKABLE)) + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_ANGRY)) + ) + .should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(KEYWORD_FIELD_DOCKEYWORD))); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggBuilder), + queryBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + List> buckets = getAggregationBuckets(aggregations, DATE_AGGREGATION_NAME); + + assertNotNull(buckets); + assertEquals(3, buckets.size()); + + // check content of few buckets + Map firstBucket = buckets.get(0); + assertEquals(4, firstBucket.size()); + assertEquals("01/01/2015", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(2345.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + // second bucket is empty + Map secondBucket = buckets.get(1); + assertEquals(4, secondBucket.size()); + assertEquals("01/01/1995", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + // last bucket has values + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(4, lastBucket.size()); + assertEquals("01/01/2007", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testDateBucketedSumsPipelinedToCumulativeSumAggs() throws IOException { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggBuilder = AggregationBuilders.dateHistogram(DATE_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD) + .subAggregations( + new AggregatorFactories.Builder().addAggregator( + AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_DOCINDEX) + ).addPipelineAggregator(PipelineAggregatorBuilders.cumulativeSum(BUCKETS_AGGREGATION_NAME_1, SUM_AGGREGATION_NAME)) + ); + + QueryBuilder queryBuilder = QueryBuilders.boolQuery() + .should( + QueryBuilders.boolQuery() + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_WORKABLE)) + .should(QueryBuilders.termQuery(KEYWORD_FIELD_DOCKEYWORD, KEYWORD_FIELD_DOCKEYWORD_ANGRY)) + ) + .should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(KEYWORD_FIELD_DOCKEYWORD))); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggBuilder), + queryBuilder, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + List> buckets = getAggregationBuckets(aggregations, DATE_AGGREGATION_NAME); + + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + Map firstBucket = buckets.get(0); + assertEquals(5, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, getAggregationValue(firstBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(5, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1234.0, getAggregationValue(secondBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + // last bucket is empty + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(5, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(2345.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertEquals(3579.0, getAggregationValue(lastBucket, BUCKETS_AGGREGATION_NAME_1), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private Map executeQueryAndGetAggsResults(final Object aggsBuilder, String indexName) { + return executeQueryAndGetAggsResults(List.of(aggsBuilder), indexName); + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName, + int expectedHits + ) { + initializeIndexIfNotExist(indexName); + + Map searchResponseAsMap = search( + indexName, + queryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilders + ); + + assertHitResultsFromQuery(expectedHits, searchResponseAsMap); + return searchResponseAsMap; + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName + ) { + return executeQueryAndGetAggsResults(aggsBuilders, queryBuilder, indexName, 3); + } + + private Map executeQueryAndGetAggsResults(final List aggsBuilders, String indexName) { + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + return executeQueryAndGetAggsResults(aggsBuilders, hybridQueryBuilderNeuralThenTerm, indexName); + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 32ba16696..baecf2932 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -56,14 +56,14 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import com.google.common.collect.ImmutableList; -import static org.opensearch.neuralsearch.TestUtils.MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_USER_AGENT; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_NORMALIZATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.DEFAULT_COMBINATION_METHOD; -import static org.opensearch.neuralsearch.TestUtils.PARAM_NAME_WEIGHTS; -import static org.opensearch.neuralsearch.TestUtils.MAX_RETRY; -import static org.opensearch.neuralsearch.TestUtils.MAX_TIME_OUT_INTERVAL; +import static org.opensearch.neuralsearch.util.TestUtils.MAX_TASK_RESULT_QUERY_TIME_IN_SECOND; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.MAX_RETRY; +import static org.opensearch.neuralsearch.util.TestUtils.MAX_TIME_OUT_INTERVAL; import lombok.AllArgsConstructor; import lombok.Getter; @@ -82,6 +82,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ProcessorType.TEXT_IMAGE_EMBEDDING, "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" ); + private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -114,6 +115,7 @@ protected void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); // default threshold for native circuit breaker is 90, it may be not enough on test runner machine updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100); + updateClusterSettings("plugins.ml_commons.jvm_heap_memory_threshold", 95); updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true); } @@ -633,7 +635,10 @@ protected void addKnnDoc( request.setJsonEntity(builder.toString()); Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + assertTrue( + request.getEndpoint() + ": failed", + SUCCESS_STATUSES.contains(RestStatus.fromCode(response.getStatusLine().getStatusCode())) + ); } @SneakyThrows diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java b/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java index 133f42daf..a43f77917 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/OpenSearchSecureRestTestCase.java @@ -7,11 +7,11 @@ import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_PER_ROUTE; import static org.opensearch.client.RestClientBuilder.DEFAULT_MAX_CONN_TOTAL; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; -import static org.opensearch.neuralsearch.TestUtils.NEURAL_SEARCH_BWC_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.OPENDISTRO_SECURITY; -import static org.opensearch.neuralsearch.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.SECURITY_AUDITLOG_PREFIX; -import static org.opensearch.neuralsearch.TestUtils.SKIP_DELETE_MODEL_INDEX; +import static org.opensearch.neuralsearch.util.TestUtils.NEURAL_SEARCH_BWC_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.OPENDISTRO_SECURITY; +import static org.opensearch.neuralsearch.util.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.SECURITY_AUDITLOG_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.SKIP_DELETE_MODEL_INDEX; import java.io.IOException; import java.util.Collections; diff --git a/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java similarity index 100% rename from src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java rename to src/testFixtures/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java similarity index 100% rename from src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java rename to src/testFixtures/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java similarity index 90% rename from src/testFixtures/java/org/opensearch/neuralsearch/TestUtils.java rename to src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index a6f4a3e0f..0534f85bf 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -2,13 +2,15 @@ * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch; +package org.opensearch.neuralsearch.util; import com.carrotsearch.randomizedtesting.RandomizedTest; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; import static org.opensearch.test.OpenSearchTestCase.randomFloat; import java.util.ArrayList; @@ -62,6 +64,7 @@ public class TestUtils { public static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; public static final String SPARSE_ENCODING_PROCESSOR = "sparse_encoding"; + public static final String TEXT_CHUNKING_PROCESSOR = "text_chunking"; public static final int MAX_TIME_OUT_INTERVAL = 3000; public static final int MAX_RETRY = 5; @@ -299,6 +302,29 @@ public static void assertFetchResultScores(FetchSearchResult fetchSearchResult, assertEquals(0.001f, minScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION); } + public static void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { + assertEquals(expected, getHitCount(searchResponseAsMap)); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(expected, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + private static List> getNestedHits(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return (List>) hitsMap.get("hits"); @@ -314,6 +340,13 @@ private static Optional getMaxScore(Map searchResponseAsM return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } + @SuppressWarnings("unchecked") + private static int getHitCount(final Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + List hitsList = (List) hitsMap.get("hits"); + return hitsList.size(); + } + public static String getModelId(Map pipeline, String processor) { assertNotNull(pipeline); ArrayList> processors = (ArrayList>) pipeline.get("processors"); @@ -326,5 +359,4 @@ public static String getModelId(Map pipeline, String processor) public static String generateModelId() { return "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8); } - }