Skip to content

Commit

Permalink
[Backport 2.x][Fix] add non-null check for queryBuilder in NeuralQuer…
Browse files Browse the repository at this point in the history
…yEnricherProcessor (#619)

* [Fix] add non-null check for queryBuilder in NeuralQueryEnricherProcessor (#615)

* fix: add non-null check in NeuralQueryEnricherProcessor

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* add change log

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* optimize assert in ut

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

---------

Signed-off-by: zhichao-aws <zhichaog@amazon.com>
(cherry picked from commit b97dbe8)

* update pr number

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

* fix test compile

Signed-off-by: zhichao-aws <zhichaog@amazon.com>

---------

Signed-off-by: zhichao-aws <zhichaog@amazon.com>
  • Loading branch information
zhichao-aws authored Mar 4, 2024
1 parent c422065 commit 90ec7cc
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
Fix typo for sparse encoding processor factory([#600](https://github.com/opensearch-project/neural-search/pull/600))
Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#619](https://github.com/opensearch-project/neural-search/pull/619))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ private NeuralQueryEnricherProcessor(
@Override
public SearchRequest processRequest(SearchRequest searchRequest) {
QueryBuilder queryBuilder = searchRequest.source().query();
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
/* Use null check for the case where users are using empty query body. i.e. GET /index_name/_search */
if (queryBuilder != null) {
queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldDefaultIdMap));
}
return searchRequest;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@
import java.util.Collections;
import java.util.Map;

import org.apache.http.util.EntityUtils;
import org.junit.Before;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
Expand Down Expand Up @@ -55,6 +61,27 @@ public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() {
}
}

@SneakyThrows
public void testNeuralQueryEnricherProcessor_whenGetEmptyQueryBody_thenSuccess() {
String modelId = null;
try {
initializeIndexIfNotExist(index);
modelId = prepareModel();
createSearchRequestProcessor(modelId, search_pipeline);
createPipelineProcessor(modelId, ingest_pipeline, ProcessorType.TEXT_EMBEDDING);
updateIndexSettings(index, Settings.builder().put("index.search.default_pipeline", search_pipeline));
Request request = new Request("POST", "/" + index + "/_search");
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
String responseBody = EntityUtils.toString(response.getEntity());
Map<String, Object> responseInMap = XContentHelper.convertToMap(XContentType.JSON.xContent(), responseBody, false);
assertFalse(responseInMap.isEmpty());
assertEquals(3, ((Map) responseInMap.get("hits")).size());
} finally {
wipeOfTestResources(index, ingest_pipeline, modelId, search_pipeline);
}
}

@SneakyThrows
public void testNeuralQueryEnricherProcessor_whenHybridQueryBuilderAndNoModelIdPassed_thenSuccess() {
String modelId = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ public void testProcessRequest_whenVisitingQueryBuilder_thenSuccess() throws Exc
assertEquals(processSearchRequest, searchRequest);
}

public void testProcessRequest_whenVisitingEmptyQueryBody_thenSuccess() throws Exception {
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
SearchRequest searchRequest = new SearchRequest();
searchRequest.source(new SearchSourceBuilder());
assertNull(searchRequest.source().query());
NeuralQueryEnricherProcessor processor = createTestProcessor(factory);
SearchRequest processSearchRequest = processor.processRequest(searchRequest);
// should do nothing
assertNull(processSearchRequest.source().query());
}

public void testType() throws Exception {
NeuralQueryEnricherProcessor.Factory factory = new NeuralQueryEnricherProcessor.Factory();
NeuralQueryEnricherProcessor processor = createTestProcessor(factory);
Expand Down

0 comments on commit 90ec7cc

Please sign in to comment.