Skip to content

Commit

Permalink
Merge branch 'main' into neural_sparse_default_model_id
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichao-aws committed Mar 14, 2024
2 parents 1f12c1a + 759a971 commit 6946d3a
Show file tree
Hide file tree
Showing 26 changed files with 2,621 additions and 321 deletions.
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,20 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Enhancements
### Bug Fixes
- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438))
- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498))
- Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578))
- Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615))
### Infrastructure
### Documentation
### Maintenance
- Added support for jdk-21 ([#500](https://github.com/opensearch-project/neural-search/pull/500)))
### Refactoring

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.12...2.x)
### Features
- Enabled support for applying default modelId in neural sparse query ([#614](https://github.com/opensearch-project/neural-search/pull/614)
### Enhancements
- Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630))
### Bug Fixes
- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624))
### Infrastructure
### Documentation
### Maintenance
Expand Down
8 changes: 5 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ dependencies {
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}"
compileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
compileOnly fileTree(dir: knnJarDirectory, include: "opensearch-knn-${opensearch_build}.jar")
compileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre'
compileOnly group: 'commons-lang', name: 'commons-lang', version: '2.6'
api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}"
testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}"
implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
Expand All @@ -265,8 +267,8 @@ dependencies {
runtimeOnly group: 'org.json', name: 'json', version: '20231013'
testFixturesImplementation "org.opensearch:common-utils:${version}"
testFixturesImplementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0'
testFixturesCompileOnly group: 'com.google.guava', name: 'guava', version:'32.0.1-jre'
testFixturesCompileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
testFixturesCompileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre'
testFixturesCompileOnly fileTree(dir: knnJarDirectory, include: "opensearch-knn-${opensearch_build}.jar")
}

// In order to add the jar to the classpath, we need to unzip the
Expand Down
4 changes: 3 additions & 1 deletion qa/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ dependencies {
api "org.opensearch:opensearch:${opensearch_version}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}"
zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}"
compileOnly fileTree(dir: knnJarDirectory, include: '*.jar')
compileOnly fileTree(dir: knnJarDirectory, include: "opensearch-knn-${opensearch_build}.jar")
compileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre'
compileOnly group: 'commons-lang', name: 'commons-lang', version: '2.6'
api "org.apache.logging.log4j:log4j-api:${versions.log4j}"
api "org.apache.logging.log4j:log4j-core:${versions.log4j}"
api "junit:junit:${versions.junit}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.lucene.search.ScoreDoc;
Expand Down Expand Up @@ -131,13 +132,18 @@ private void updateQueryTopDocsWithCombinedScores(
compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits));
}

/**
* Get max hits as number of unique doc ids from results of all sub-queries
* @param topDocsPerSubQuery list of topDocs objects for one shard
* @return number of unique doc ids
*/
protected int getMaxHits(final List<TopDocs> topDocsPerSubQuery) {
int maxHits = 0;
for (TopDocs topDocs : topDocsPerSubQuery) {
int hits = topDocs.scoreDocs.length;
maxHits = Math.max(maxHits, hits);
}
return maxHits;
Set<Integer> docIds = topDocsPerSubQuery.stream()
.filter(topDocs -> Objects.nonNull(topDocs.scoreDocs))
.flatMap(topDocs -> Arrays.stream(topDocs.scoreDocs))
.map(scoreDoc -> scoreDoc.doc)
.collect(Collectors.toSet());
return docIds.size();
}

private TotalHits getTotalHits(final List<TopDocs> topDocsPerSubQuery, int maxHits) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import java.util.List;
import java.util.Objects;

import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -77,20 +76,20 @@ public String toString(String field) {
/**
* Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary,
* until the rewritten query is the same as the original query.
* @param reader
* @param indexSearcher
* @return
* @throws IOException
*/
@Override
public Query rewrite(IndexReader reader) throws IOException {
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
if (subQueries.isEmpty()) {
return new MatchNoDocsQuery("empty HybridQuery");
}

boolean actuallyRewritten = false;
List<Query> rewrittenSubQueries = new ArrayList<>();
for (Query subQuery : subQueries) {
Query rewrittenSub = subQuery.rewrite(reader);
Query rewrittenSub = subQuery.rewrite(indexSearcher);
/* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls
queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses
perform better. For hybrid query we need to track progress of re-write for all sub-queries */
Expand All @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException {
return new HybridQuery(rewrittenSubQueries);
}

return super.rewrite(reader);
return super.rewrite(indexSearcher);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.QueryShardException;
import org.opensearch.index.query.Rewriteable;
import org.opensearch.index.query.QueryBuilderVisitor;

import lombok.Getter;
Expand All @@ -54,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder<HybridQueryBu

private String fieldName;

private static final int MAX_NUMBER_OF_SUB_QUERIES = 5;
static final int MAX_NUMBER_OF_SUB_QUERIES = 5;

public HybridQueryBuilder(StreamInput in) throws IOException {
super(in);
Expand Down Expand Up @@ -290,7 +289,7 @@ private void writeQueries(StreamOutput out, List<? extends QueryBuilder> queries
private Collection<Query> toQueries(Collection<QueryBuilder> queryBuilders, QueryShardContext context) throws QueryShardException {
List<Query> queries = queryBuilders.stream().map(qb -> {
try {
return Rewriteable.rewrite(qb, context).toQuery(context);
return qb.rewrite(context).toQuery(context);
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand Down
Loading

0 comments on commit 6946d3a

Please sign in to comment.