From c9cdcc148cd176becfc1456c9f27ab90aa4bfcf5 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 11 Mar 2024 09:15:07 -0700 Subject: [PATCH 1/3] Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator (#624) * Adding two phase iterator Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/query/HybridQuery.java | 9 +- .../query/HybridQueryBuilder.java | 5 +- .../neuralsearch/query/HybridQueryScorer.java | 228 +++++++++++++++- .../neuralsearch/query/HybridQueryWeight.java | 86 +++++- .../HybridScoreBlockBoundaryPropagator.java | 99 +++++++ .../search/HybridTopScoreDocCollector.java | 45 +++- .../query/HybridAggregationProcessor.java | 82 ++++++ .../search/query/HybridCollectorManager.java | 253 ++++++++++++++++++ .../query/HybridQueryPhaseSearcher.java | 195 ++------------ .../query/HybridQueryScorerTests.java | 99 ++++++- ...bridScoreBlockBoundaryPropagatorTests.java | 113 ++++++++ .../HybridTopScoreDocCollectorTests.java | 109 ++++++++ .../HybridAggregationProcessorTests.java | 233 ++++++++++++++++ .../query/HybridCollectorManagerTests.java | 201 ++++++++++++++ .../query/HybridQueryPhaseSearcherTests.java | 69 ++--- 16 files changed, 1573 insertions(+), 254 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 528231b07..8dcdc721b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java index 8846f6977..01d271cdd 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQuery.java @@ -12,7 +12,6 @@ import java.util.List; import java.util.Objects; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; @@ -77,12 +76,12 @@ public String toString(String field) { /** * Re-writes queries into primitive queries. Callers are expected to call rewrite multiple times if necessary, * until the rewritten query is the same as the original query. 
- * @param reader + * @param indexSearcher * @return * @throws IOException */ @Override - public Query rewrite(IndexReader reader) throws IOException { + public Query rewrite(IndexSearcher indexSearcher) throws IOException { if (subQueries.isEmpty()) { return new MatchNoDocsQuery("empty HybridQuery"); } @@ -90,7 +89,7 @@ public Query rewrite(IndexReader reader) throws IOException { boolean actuallyRewritten = false; List rewrittenSubQueries = new ArrayList<>(); for (Query subQuery : subQueries) { - Query rewrittenSub = subQuery.rewrite(reader); + Query rewrittenSub = subQuery.rewrite(indexSearcher); /* we keep rewrite sub-query unless it's not equal to itself, it may take multiple levels of recursive calls queries need to be rewritten from high-level clauses into lower-level clauses because low-level clauses perform better. For hybrid query we need to track progress of re-write for all sub-queries */ @@ -102,7 +101,7 @@ public Query rewrite(IndexReader reader) throws IOException { return new HybridQuery(rewrittenSubQueries); } - return super.rewrite(reader); + return super.rewrite(indexSearcher); } /** diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index aa4242c2e..60d9fd639 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -27,7 +27,6 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.QueryShardException; -import org.opensearch.index.query.Rewriteable; import org.opensearch.index.query.QueryBuilderVisitor; import lombok.Getter; @@ -54,7 +53,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder queries private Collection toQueries(Collection queryBuilders, QueryShardContext context) throws QueryShardException { List queries = queryBuilders.stream().map(qb -> { try { - return Rewriteable.rewrite(qb, context).toQuery(context); + return qb.rewrite(context).toQuery(context); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java index 5abfd0b5e..f31d0abd9 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryScorer.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -18,10 +19,13 @@ import org.apache.lucene.search.DisjunctionDISIApproximation; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import lombok.Getter; +import org.apache.lucene.util.PriorityQueue; /** * Class abstracts functionality of Scorer for hybrid query. 
When iterating over documents in increasing @@ -40,12 +44,56 @@ public final class HybridQueryScorer extends Scorer { private final Map> queryToIndex; - public HybridQueryScorer(Weight weight, List subScorers) throws IOException { + private final DocIdSetIterator approximation; + private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator; + private final TwoPhase twoPhase; + + public HybridQueryScorer(final Weight weight, final List subScorers) throws IOException { + this(weight, subScorers, ScoreMode.TOP_SCORES); + } + + HybridQueryScorer(final Weight weight, final List subScorers, final ScoreMode scoreMode) throws IOException { super(weight); this.subScorers = Collections.unmodifiableList(subScorers); subScores = new float[subScorers.size()]; this.queryToIndex = mapQueryToIndex(); this.subScorersPQ = initializeSubScorersPQ(); + boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES; + + this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ); + if (scoreMode == ScoreMode.TOP_SCORES) { + this.disjunctionBlockPropagator = new HybridScoreBlockBoundaryPropagator(subScorers); + } else { + this.disjunctionBlockPropagator = null; + } + + boolean hasApproximation = false; + float sumMatchCost = 0; + long sumApproxCost = 0; + // Compute matchCost as the average over the matchCost of the subScorers. + // This is weighted by the cost, which is an expected number of matching documents. + for (DisiWrapper w : subScorersPQ) { + long costWeight = (w.cost <= 1) ? 1 : w.cost; + sumApproxCost += costWeight; + if (w.twoPhaseView != null) { + hasApproximation = true; + sumMatchCost += w.matchCost * costWeight; + } + } + if (!hasApproximation) { // no sub scorer supports approximations + twoPhase = null; + } else { + final float matchCost = sumMatchCost / sumApproxCost; + twoPhase = new TwoPhase(approximation, matchCost, subScorersPQ, needsScores); + } + } + + @Override + public int advanceShallow(int target) throws IOException { + if (disjunctionBlockPropagator != null) { + return disjunctionBlockPropagator.advanceShallow(target); + } + return super.advanceShallow(target); } /** @@ -55,7 +103,10 @@ public HybridQueryScorer(Weight weight, List subScorers) throws IOExcept */ @Override public float score() throws IOException { - DisiWrapper topList = subScorersPQ.topList(); + return score(getSubMatches()); + } + + private float score(DisiWrapper topList) throws IOException { float totalScore = 0.0f; for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) { // check if this doc has match in the subQuery. If not, add score as 0.0 and continue @@ -67,13 +118,30 @@ public float score() throws IOException { return totalScore; } + DisiWrapper getSubMatches() throws IOException { + if (twoPhase == null) { + return subScorersPQ.topList(); + } else { + return twoPhase.getSubMatches(); + } + } + /** * Return a DocIdSetIterator over matching documents. 
* @return DocIdSetIterator object */ @Override public DocIdSetIterator iterator() { - return new DisjunctionDISIApproximation(this.subScorersPQ); + if (twoPhase != null) { + return TwoPhaseIterator.asDocIdSetIterator(twoPhase); + } else { + return approximation; + } + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return twoPhase; } /** @@ -93,12 +161,28 @@ public float getMaxScore(int upTo) throws IOException { }).max(Float::compare).orElse(0.0f); } + @Override + public void setMinCompetitiveScore(float minScore) throws IOException { + if (disjunctionBlockPropagator != null) { + disjunctionBlockPropagator.setMinCompetitiveScore(minScore); + } + + for (Scorer scorer : subScorers) { + if (Objects.nonNull(scorer)) { + scorer.setMinCompetitiveScore(minScore); + } + } + } + /** * Returns the doc ID that is currently being scored. * @return document id */ @Override public int docID() { + if (subScorersPQ.size() == 0) { + return DocIdSetIterator.NO_MORE_DOCS; + } return subScorersPQ.top().doc; } @@ -169,4 +253,142 @@ private DisiPriorityQueue initializeSubScorersPQ() { } return subScorersPQ; } + + @Override + public Collection getChildren() throws IOException { + ArrayList children = new ArrayList<>(); + for (DisiWrapper scorer = getSubMatches(); scorer != null; scorer = scorer.next) { + children.add(new ChildScorable(scorer.scorer, "SHOULD")); + } + return children; + } + + /** + * Object returned by {@link Scorer#twoPhaseIterator()} to provide an approximation of a {@link DocIdSetIterator}. + * After calling {@link DocIdSetIterator#nextDoc()} or {@link DocIdSetIterator#advance(int)} on the iterator + * returned by approximation(), you need to check {@link TwoPhaseIterator#matches()} to confirm if the retrieved + * document ID is a match. 
Implementation inspired by identical class for + * DisjunctionScorer + */ + static class TwoPhase extends TwoPhaseIterator { + private final float matchCost; + // list of verified matches on the current doc + DisiWrapper verifiedMatches; + // priority queue of approximations on the current doc that have not been verified yet + final PriorityQueue unverifiedMatches; + DisiPriorityQueue subScorers; + boolean needsScores; + + private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) { + super(approximation); + this.matchCost = matchCost; + this.subScorers = subScorers; + unverifiedMatches = new PriorityQueue<>(subScorers.size()) { + @Override + protected boolean lessThan(DisiWrapper a, DisiWrapper b) { + return a.matchCost < b.matchCost; + } + }; + this.needsScores = needsScores; + } + + DisiWrapper getSubMatches() throws IOException { + for (DisiWrapper wrapper : unverifiedMatches) { + if (wrapper.twoPhaseView.matches()) { + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; + } + } + unverifiedMatches.clear(); + return verifiedMatches; + } + + @Override + public boolean matches() throws IOException { + verifiedMatches = null; + unverifiedMatches.clear(); + + for (DisiWrapper wrapper = subScorers.topList(); wrapper != null;) { + DisiWrapper next = wrapper.next; + + if (Objects.isNull(wrapper.twoPhaseView)) { + // implicitly verified, move it to verifiedMatches + wrapper.next = verifiedMatches; + verifiedMatches = wrapper; + + if (!needsScores) { + // we can stop here + return true; + } + } else { + unverifiedMatches.add(wrapper); + } + wrapper = next; + } + + if (Objects.nonNull(verifiedMatches)) { + return true; + } + + // verify subs that have an two-phase iterator + // least-costly ones first + while (unverifiedMatches.size() > 0) { + DisiWrapper wrapper = unverifiedMatches.pop(); + if (wrapper.twoPhaseView.matches()) { + wrapper.next = null; + verifiedMatches = wrapper; + return true; + } + } + return false; + } + + @Override + public float matchCost() { + return matchCost; + } + } + + /** + * A DocIdSetIterator which is a disjunction of the approximations of the provided iterators and supports + * sub iterators that return empty results + */ + static class HybridSubqueriesDISIApproximation extends DocIdSetIterator { + final DocIdSetIterator docIdSetIterator; + final DisiPriorityQueue subIterators; + + public HybridSubqueriesDISIApproximation(final DisiPriorityQueue subIterators) { + docIdSetIterator = new DisjunctionDISIApproximation(subIterators); + this.subIterators = subIterators; + } + + @Override + public long cost() { + return docIdSetIterator.cost(); + } + + @Override + public int docID() { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return docIdSetIterator.docID(); + } + + @Override + public int nextDoc() throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return docIdSetIterator.nextDoc(); + } + + @Override + public int advance(final int target) throws IOException { + if (subIterators.size() == 0) { + return NO_MORE_DOCS; + } + return docIdSetIterator.advance(target); + } + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index 69ee5015f..facb79694 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -5,10 +5,12 @@ package 
org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.stream.Collectors; +import lombok.RequiredArgsConstructor; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -16,14 +18,16 @@ import org.apache.lucene.search.MatchesUtils; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; +import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; + /** * Calculates query weights and build query scorers for hybrid query. */ public final class HybridQueryWeight extends Weight { - private final HybridQuery queries; // The Weights for our subqueries, in 1-1 correspondence private final List weights; @@ -34,7 +38,6 @@ public final class HybridQueryWeight extends Weight { */ public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(hybridQuery); - this.queries = hybridQuery; weights = hybridQuery.getSubQueries().stream().map(q -> { try { return searcher.createWeight(q, scoreMode, boost); @@ -65,6 +68,20 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { return MatchesUtils.fromSubMatches(mis); } + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + List scorerSuppliers = new ArrayList<>(); + for (Weight w : weights) { + ScorerSupplier ss = w.scorerSupplier(context); + scorerSuppliers.add(ss); + } + + if (scorerSuppliers.isEmpty()) { + return null; + } + return new HybridScorerSupplier(scorerSuppliers, this, scoreMode); + } + /** * Create the scorer used to score our associated Query * @@ -75,19 +92,12 @@ public Matches matches(LeafReaderContext context, int doc) throws IOException { */ @Override public Scorer scorer(LeafReaderContext context) throws IOException { - List scorers = weights.stream().map(w -> { - try { - return w.scorer(context); - } catch (IOException e) { - throw new RuntimeException(e); - } - }).collect(Collectors.toList()); - // if there are no matches in any of the scorers (sub-queries) we need to return - // scorer as null to avoid problems with disi result iterators - if (scorers.stream().allMatch(Objects::isNull)) { + ScorerSupplier supplier = scorerSupplier(context); + if (supplier == null) { return null; } - return new HybridQueryScorer(this, scorers); + supplier.setTopLevelScoringClause(); + return supplier.get(Long.MAX_VALUE); } /** @@ -98,6 +108,10 @@ public Scorer scorer(LeafReaderContext context) throws IOException { */ @Override public boolean isCacheable(LeafReaderContext ctx) { + if (weights.size() > MAX_NUMBER_OF_SUB_QUERIES) { + // this situation should never happen, but in case it do such query will not be cached + return false; + } return weights.stream().allMatch(w -> w.isCacheable(ctx)); } @@ -113,4 +127,50 @@ public boolean isCacheable(LeafReaderContext ctx) { public Explanation explain(LeafReaderContext context, int doc) throws IOException { throw new UnsupportedOperationException("Explain is not supported"); } + + @RequiredArgsConstructor + static class HybridScorerSupplier extends ScorerSupplier { + private long cost = -1; + private final List scorerSuppliers; + private final Weight weight; + private final ScoreMode scoreMode; + + @Override + public Scorer get(long 
leadCost) throws IOException { + List tScorers = new ArrayList<>(); + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + tScorers.add(ss.get(leadCost)); + } else { + tScorers.add(null); + } + } + return new HybridQueryScorer(weight, tScorers, scoreMode); + } + + @Override + public long cost() { + if (cost == -1) { + long cost = 0; + for (ScorerSupplier ss : scorerSuppliers) { + if (Objects.nonNull(ss)) { + cost += ss.cost(); + } + } + this.cost = cost; + } + return cost; + } + + @Override + public void setTopLevelScoringClause() throws IOException { + for (ScorerSupplier ss : scorerSuppliers) { + // sub scorers need to be able to skip too as calls to setMinCompetitiveScore get + // propagated + if (Objects.nonNull(ss)) { + ss.setTopLevelScoringClause(); + } + } + } + }; } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java new file mode 100644 index 000000000..6b47a098d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagator.java @@ -0,0 +1,99 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.Objects; + +/** + * This class functions as a utility for propagating block boundaries within disjunctions. + * In disjunctions, where a match occurs if any subclause matches, a common approach might involve returning + * the minimum block boundary across all clauses. However, this method can introduce performance challenges, + * particularly when dealing with high minimum competitive scores and clauses with low scores that no longer + * significantly contribute to the iteration process. Therefore, this class computes block boundaries solely for clauses + * with a maximum score equal to or exceeding the minimum competitive score, or for the clause with the maximum + * score if such a clause is absent. + */ +public class HybridScoreBlockBoundaryPropagator { + + private static final Comparator MAX_SCORE_COMPARATOR = Comparator.comparing((Scorer s) -> { + try { + return s.getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } catch (IOException e) { + throw new RuntimeException(e); + } + }).thenComparing(s -> s.iterator().cost()); + + private final Scorer[] scorers; + private final float[] maxScores; + private int leadIndex = 0; + + HybridScoreBlockBoundaryPropagator(final Collection scorers) throws IOException { + this.scorers = scorers.stream().filter(Objects::nonNull).toArray(Scorer[]::new); + for (Scorer scorer : this.scorers) { + scorer.advanceShallow(0); + } + Arrays.sort(this.scorers, MAX_SCORE_COMPARATOR); + + maxScores = new float[this.scorers.length]; + for (int i = 0; i < this.scorers.length; ++i) { + maxScores[i] = this.scorers[i].getMaxScore(DocIdSetIterator.NO_MORE_DOCS); + } + } + + /** See {@link Scorer#advanceShallow(int)}. */ + int advanceShallow(int target) throws IOException { + // For scorers that are below the lead index, just propagate. + for (int i = 0; i < leadIndex; ++i) { + Scorer s = scorers[i]; + if (s.docID() < target) { + s.advanceShallow(target); + } + } + + // For scorers above the lead index, we take the minimum + // boundary. 
+ Scorer leadScorer = scorers[leadIndex]; + int upTo = leadScorer.advanceShallow(Math.max(leadScorer.docID(), target)); + + for (int i = leadIndex + 1; i < scorers.length; ++i) { + Scorer scorer = scorers[i]; + if (scorer.docID() <= target) { + upTo = Math.min(scorer.advanceShallow(target), upTo); + } + } + + // If the maximum scoring clauses are beyond `target`, then we use their + // docID as a boundary. It helps not consider them when computing the + // maximum score and get a lower score upper bound. + for (int i = scorers.length - 1; i > leadIndex; --i) { + Scorer scorer = scorers[i]; + if (scorer.docID() > target) { + upTo = Math.min(upTo, scorer.docID() - 1); + } else { + break; + } + } + return upTo; + } + + /** + * Set the minimum competitive score to filter out clauses that score less than this threshold. + * + * @see Scorer#setMinCompetitiveScore + */ + void setMinCompetitiveScore(float minScore) throws IOException { + // Update the lead index if necessary + while (leadIndex < maxScores.length - 1 && minScore > maxScores[leadIndex]) { + leadIndex++; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 8b7a12d29..4418841f4 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -19,7 +20,6 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TopScoreDocCollector; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.PriorityQueue; import org.opensearch.neuralsearch.query.HybridQueryScorer; @@ -47,20 +47,55 @@ public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThreshol } @Override - public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + public LeafCollector getLeafCollector(LeafReaderContext context) { docBase = context.docBase; - return new TopScoreDocCollector.ScorerLeafCollector() { + return new LeafCollector() { HybridQueryScorer compoundQueryScorer; @Override public void setScorer(Scorable scorer) throws IOException { - super.setScorer(scorer); - compoundQueryScorer = (HybridQueryScorer) scorer; + if (scorer instanceof HybridQueryScorer) { + log.debug("passed scorer is of type HybridQueryScorer, saving it for collecting documents and scores"); + compoundQueryScorer = (HybridQueryScorer) scorer; + } else { + compoundQueryScorer = getHybridQueryScorer(scorer); + if (Objects.isNull(compoundQueryScorer)) { + log.error( + String.format(Locale.ROOT, "cannot find scorer of type HybridQueryScorer in a hierarchy of scorer %s", scorer) + ); + } + } + } + + private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException { + if (scorer == null) { + return null; + } + if (scorer instanceof HybridQueryScorer) { + return (HybridQueryScorer) scorer; + } + for (Scorable.ChildScorable childScorable : scorer.getChildren()) { + HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); + if (Objects.nonNull(hybridQueryScorer)) { + log.debug( + String.format( + Locale.ROOT, + "found hybrid query scorer, it's child of 
scorer %s", + childScorable.child.getClass().getSimpleName() + ) + ); + return hybridQueryScorer; + } + } + return null; } @Override public void collect(int doc) throws IOException { + if (Objects.isNull(compoundQueryScorer)) { + throw new IllegalArgumentException("scorers are null for all sub-queries in hybrid query"); + } float[] subScoresByQuery = compoundQueryScorer.hybridScores(); // iterate over results for each query if (compoundScores == null) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java new file mode 100644 index 000000000..4e9070748 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.AllArgsConstructor; +import org.apache.lucene.search.CollectorManager; +import org.opensearch.search.aggregations.AggregationInitializationException; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QueryPhaseExecutionException; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.io.IOException; +import java.util.List; + +import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; + +/** + * Defines logic for pre- and post-phases of document scores collection. Responsible for registering custom + * collector manager for hybris query (pre phase) and reducing results (post phase) + */ +@AllArgsConstructor +public class HybridAggregationProcessor implements AggregationProcessor { + + private final AggregationProcessor delegateAggsProcessor; + + @Override + public void preProcess(SearchContext context) { + delegateAggsProcessor.preProcess(context); + + if (isHybridQuery(context.query(), context)) { + // adding collector manager for hybrid query + CollectorManager collectorManager; + try { + collectorManager = HybridCollectorManager.createHybridCollectorManager(context); + } catch (IOException exception) { + throw new AggregationInitializationException("could not initialize hybrid aggregation processor", exception); + } + context.queryCollectorManagers().put(HybridCollectorManager.class, collectorManager); + } + } + + @Override + public void postProcess(SearchContext context) { + if (isHybridQuery(context.query(), context)) { + // for case when concurrent search is not enabled (default as of 2.12 release) reduce for collector + // managers is not called + // (https://github.com/opensearch-project/OpenSearch/blob/2.12/server/src/main/java/org/opensearch/search/query/QueryPhase.java#L333-L373) + // and we have to call it manually. 
This is required as we format final + // result of hybrid query in {@link HybridTopScoreCollector#reduce} + // when concurrent search is enabled then reduce method is called as part of the search {@see + // ConcurrentQueryPhaseSearcher#searchWithCollectorManager} + // corresponding call in Lucene + // https://github.com/apache/lucene/blob/branch_9_10/lucene/core/src/java/org/apache/lucene/search/IndexSearcher.java#L700 + if (!context.shouldUseConcurrentSearch()) { + reduceCollectorResults(context); + } + updateQueryResult(context.queryResult(), context); + } + + delegateAggsProcessor.postProcess(context); + } + + private void reduceCollectorResults(SearchContext context) { + CollectorManager collectorManager = context.queryCollectorManagers().get(HybridCollectorManager.class); + try { + collectorManager.reduce(List.of()).reduce(context.queryResult()); + } catch (IOException e) { + throw new QueryPhaseExecutionException(context.shardTarget(), "failed to execute hybrid query aggregation processor", e); + } + } + + private void updateQueryResult(final QuerySearchResult queryResult, final SearchContext searchContext) { + boolean isSingleShard = searchContext.numberOfShards() == 1; + if (isSingleShard) { + searchContext.size(queryResult.queryResult().topDocs().topDocs.scoreDocs.length); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java new file mode 100644 index 000000000..a5de898ab --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java @@ -0,0 +1,253 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.MultiCollectorWrapper; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; +import org.opensearch.search.sort.SortAndFormats; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; + +/** + * Collector manager based on HybridTopScoreDocCollector that allows users to parallelize counting the number of hits. + * In most cases it will be wrapped in MultiCollectorManager. 
+ */
+@RequiredArgsConstructor
+public abstract class HybridCollectorManager implements CollectorManager {
+
+    private final int numHits;
+    private final HitsThresholdChecker hitsThresholdChecker;
+    private final boolean isSingleShard;
+    private final int trackTotalHitsUpTo;
+    private final SortAndFormats sortAndFormats;
+
+    /**
+     * Create new instance of HybridCollectorManager depending on the concurrent search being enabled or disabled.
+     * @param searchContext
+     * @return
+     * @throws IOException
+     */
+    public static CollectorManager createHybridCollectorManager(final SearchContext searchContext) throws IOException {
+        final IndexReader reader = searchContext.searcher().getIndexReader();
+        final int totalNumDocs = Math.max(0, reader.numDocs());
+        boolean isSingleShard = searchContext.numberOfShards() == 1;
+        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
+        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
+
+        return searchContext.shouldUseConcurrentSearch()
+            ? new HybridCollectorConcurrentSearchManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort()
+            )
+            : new HybridCollectorNonConcurrentManager(
+                numDocs,
+                new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())),
+                isSingleShard,
+                trackTotalHitsUpTo,
+                searchContext.sort()
+            );
+    }
+
+    @Override
+    public Collector newCollector() {
+        Collector hybridcollector = new HybridTopScoreDocCollector(numHits, hitsThresholdChecker);
+        return hybridcollector;
+    }
+
+    /**
+     * Reduce the results from hybrid scores collector into a format specific for hybrid search query:
+     * - start
+     * - sub-query-delimiter
+     * - scores
+     * - stop
+     * Ignore other collectors if they are present in the context
+     * @param collectors collection of collectors after they have been executed and have collected documents and scores
+     * @return search results that can be reduced by the caller
+     */
+    @Override
+    public ReduceableSearchResult reduce(Collection collectors) {
+        final List hybridTopScoreDocCollectors = new ArrayList<>();
+        // check if collector for hybrid query scores is part of this search context. It can be wrapped into MultiCollectorWrapper
+        // in case multiple collector managers are registered. 
We use hybrid scores collector to format scores into + // format specific for hybrid search query: start, sub-query-delimiter, scores, stop + for (final Collector collector : collectors) { + if (collector instanceof MultiCollectorWrapper) { + for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) { + if (sub instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) sub); + } + } + } else if (collector instanceof HybridTopScoreDocCollector) { + hybridTopScoreDocCollectors.add((HybridTopScoreDocCollector) collector); + } + } + + if (!hybridTopScoreDocCollectors.isEmpty()) { + HybridTopScoreDocCollector hybridTopScoreDocCollector = hybridTopScoreDocCollectors.stream() + .findFirst() + .orElseThrow(() -> new IllegalStateException("cannot collect results of hybrid search query")); + List topDocs = hybridTopScoreDocCollector.topDocs(); + TopDocs newTopDocs = getNewTopDocs(getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard), topDocs); + float maxScore = getMaxScore(topDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); + return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); }; + } + throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors"); + } + + private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { + ScoreDoc[] scoreDocs = new ScoreDoc[0]; + if (Objects.nonNull(topDocs)) { + // for a single shard case we need to do score processing at coordinator level. + // this is workaround for current core behaviour, for single shard fetch phase is executed + // right after query phase and processors are called after actual fetch is done + // find any valid doc Id, or set it to -1 if there is not a single match + int delimiterDocId = topDocs.stream() + .filter(Objects::nonNull) + .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) + .map(topDoc -> topDoc.scoreDocs) + .filter(scoreDoc -> scoreDoc.length > 0) + .map(scoreDoc -> scoreDoc[0].doc) + .findFirst() + .orElse(-1); + if (delimiterDocId == -1) { + return new TopDocs(totalHits, scoreDocs); + } + // format scores using following template: + // doc_id | magic_number_1 + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_2 + // ... + // doc_id | magic_number_1 + List result = new ArrayList<>(); + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + for (TopDocs topDoc : topDocs) { + if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + continue; + } + result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); + result.addAll(Arrays.asList(topDoc.scoreDocs)); + } + result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); + scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); + } + return new TopDocs(totalHits, scoreDocs); + } + + private TotalHits getTotalHits(int trackTotalHitsUpTo, final List topDocs, final boolean isSingleShard) { + final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED + ? 
TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO + : TotalHits.Relation.EQUAL_TO; + if (topDocs == null || topDocs.isEmpty()) { + return new TotalHits(0, relation); + } + + List scoreDocs = topDocs.stream() + .map(topdDoc -> topdDoc.scoreDocs) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + Set uniqueDocIds = new HashSet<>(); + for (ScoreDoc[] scoreDocsArray : scoreDocs) { + uniqueDocIds.addAll(Arrays.stream(scoreDocsArray).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toList())); + } + long maxTotalHits = uniqueDocIds.size(); + + return new TotalHits(maxTotalHits, relation); + } + + private float getMaxScore(final List topDocs) { + if (topDocs.isEmpty()) { + return 0.0f; + } else { + return topDocs.stream() + .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) + .map(scoreDoc -> scoreDoc.score) + .max(Float::compare) + .get(); + } + } + + private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { + return sortAndFormats == null ? null : sortAndFormats.formats; + } + + /** + * Implementation of the HybridCollector that reuses instance of collector on each even call. This allows caller to + * use saved state of collector + */ + static class HybridCollectorNonConcurrentManager extends HybridCollectorManager { + private final Collector scoreCollector; + + public HybridCollectorNonConcurrentManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null"); + } + + @Override + public Collector newCollector() { + return scoreCollector; + } + + @Override + public ReduceableSearchResult reduce(Collection collectors) { + assert collectors.isEmpty() : "reduce on HybridCollectorNonConcurrentManager called with non-empty collectors"; + return super.reduce(List.of(scoreCollector)); + } + } + + /** + * Implementation of the HybridCollector that doesn't save collector's state and return new instance of every + * call of newCollector + */ + static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager { + + public HybridCollectorConcurrentSearchManager( + int numHits, + HitsThresholdChecker hitsThresholdChecker, + boolean isSingleShard, + int trackTotalHitsUpTo, + SortAndFormats sortAndFormats + ) { + super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index bf05fdc9d..6461c698e 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -4,46 +4,26 @@ */ package org.opensearch.neuralsearch.search.query; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; -import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; -import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext; - import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; import java.util.LinkedList; 
import java.util.List; -import java.util.Objects; -import org.apache.lucene.index.IndexReader; +import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHitCountCollector; -import org.apache.lucene.search.TotalHits; -import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; -import org.opensearch.neuralsearch.search.HitsThresholdChecker; -import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; -import org.opensearch.search.DocValueFormat; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.query.QueryCollectorContext; import org.opensearch.search.query.QueryPhase; import org.opensearch.search.query.QueryPhaseSearcherWrapper; -import org.opensearch.search.query.QuerySearchResult; -import org.opensearch.search.query.TopDocsCollectorContext; -import org.opensearch.search.rescore.RescoreContext; -import org.opensearch.search.sort.SortAndFormats; - -import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; @@ -66,15 +46,17 @@ public boolean searchWith( final boolean hasFilterCollector, final boolean hasTimeout ) throws IOException { - if (isHybridQuery(query, searchContext)) { + if (!isHybridQuery(query, searchContext)) { + validateQuery(searchContext, query); + return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); + } else { Query hybridQuery = extractHybridQuery(searchContext, query); - return searchWithCollector(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); + return super.searchWith(searchContext, searcher, hybridQuery, collectors, hasFilterCollector, hasTimeout); } - validateQuery(searchContext, query); - return super.searchWith(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout); } - private boolean isHybridQuery(final Query query, final SearchContext searchContext) { + @VisibleForTesting + static boolean isHybridQuery(final Query query, final SearchContext searchContext) { if (query instanceof HybridQuery) { return true; } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { @@ -103,7 +85,7 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte // we have already checked if query in instance of Boolean in higher level else if condition return ((BooleanQuery) query).clauses() .stream() - .filter(clause -> clause.getQuery() instanceof HybridQuery == false) + .filter(clause -> !(clause.getQuery() instanceof HybridQuery)) .allMatch(clause -> { return clause.getOccur() == BooleanClause.Occur.FILTER && clause.getQuery() instanceof FieldExistsQuery @@ -113,16 +95,17 @@ private boolean isHybridQuery(final Query query, final SearchContext searchConte return false; } - private boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { + private 
static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) { return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query); } - private boolean isWrappedHybridQuery(final Query query) { + private static boolean isWrappedHybridQuery(final Query query) { return query instanceof BooleanQuery && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery); } - private Query extractHybridQuery(final SearchContext searchContext, final Query query) { + @VisibleForTesting + protected Query extractHybridQuery(final SearchContext searchContext, final Query query) { if (hasNestedFieldOrNestedDocs(query, searchContext) && isWrappedHybridQuery(query) && ((BooleanQuery) query).clauses().size() > 0) { @@ -180,152 +163,14 @@ private void validateNestedBooleanQuery(final Query query, final int level) { } } - @VisibleForTesting - protected boolean searchWithCollector( - final SearchContext searchContext, - final ContextIndexSearcher searcher, - final Query query, - final LinkedList collectors, - final boolean hasFilterCollector, - final boolean hasTimeout - ) throws IOException { - log.debug("searching with custom doc collector, shard {}", searchContext.shardTarget().getShardId()); - - final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector); - collectors.addFirst(topDocsFactory); - if (searchContext.size() == 0) { - final TotalHitCountCollector collector = new TotalHitCountCollector(); - searcher.search(query, collector); - return false; - } - final IndexReader reader = searchContext.searcher().getIndexReader(); - int totalNumDocs = Math.max(0, reader.numDocs()); - int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs); - final boolean shouldRescore = !searchContext.rescore().isEmpty(); - if (shouldRescore) { - for (RescoreContext rescoreContext : searchContext.rescore()) { - numDocs = Math.max(numDocs, rescoreContext.getWindowSize()); - } - } - - final QuerySearchResult queryResult = searchContext.queryResult(); - - final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( - numDocs, - new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) - ); - - searcher.search(query, collector); - - if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { - queryResult.terminatedEarly(false); - } - - setTopDocsInQueryResult(queryResult, collector, searchContext); - - return shouldRescore; - } - - private void setTopDocsInQueryResult( - final QuerySearchResult queryResult, - final HybridTopScoreDocCollector collector, - final SearchContext searchContext - ) { - final List topDocs = collector.topDocs(); - final float maxScore = getMaxScore(topDocs); - final boolean isSingleShard = searchContext.numberOfShards() == 1; - final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); - final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); - queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); - } - - private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { - ScoreDoc[] scoreDocs = new ScoreDoc[0]; - if (Objects.nonNull(topDocs)) { - // for a single shard case we need to do score processing at coordinator level. 
- // this is workaround for current core behaviour, for single shard fetch phase is executed - // right after query phase and processors are called after actual fetch is done - // find any valid doc Id, or set it to -1 if there is not a single match - int delimiterDocId = topDocs.stream() - .filter(Objects::nonNull) - .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)) - .map(topDoc -> topDoc.scoreDocs) - .filter(scoreDoc -> scoreDoc.length > 0) - .map(scoreDoc -> scoreDoc[0].doc) - .findFirst() - .orElse(-1); - if (delimiterDocId == -1) { - return new TopDocs(totalHits, scoreDocs); - } - // format scores using following template: - // doc_id | magic_number_1 - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_2 - // ... - // doc_id | magic_number_1 - List result = new ArrayList<>(); - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - for (TopDocs topDoc : topDocs) { - if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) { - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - continue; - } - result.add(createDelimiterElementForHybridSearchResults(delimiterDocId)); - result.addAll(Arrays.asList(topDoc.scoreDocs)); - } - result.add(createStartStopElementForHybridSearchResults(delimiterDocId)); - scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new); - } - return new TopDocs(totalHits, scoreDocs); - } - - private TotalHits getTotalHits(final SearchContext searchContext, final List topDocs, final boolean isSingleShard) { - int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo(); - final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED - ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO - : TotalHits.Relation.EQUAL_TO; - if (topDocs == null || topDocs.isEmpty()) { - return new TotalHits(0, relation); - } - long maxTotalHits = topDocs.get(0).totalHits.value; - int totalSize = 0; - for (TopDocs topDoc : topDocs) { - maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value); - if (isSingleShard) { - totalSize += topDoc.totalHits.value + 1; - } - } - // add 1 qty per each sub-query and + 2 for start and stop delimiters - totalSize += 2; - if (isSingleShard) { - // for single shard we need to update total size as this is how many docs are fetched in Fetch phase - searchContext.size(totalSize); - } - - return new TotalHits(maxTotalHits, relation); - } - - private float getMaxScore(final List topDocs) { - if (topDocs.isEmpty()) { - return 0.0f; - } else { - return topDocs.stream() - .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0]) - .map(scoreDoc -> scoreDoc.score) - .max(Float::compare) - .get(); - } - } - - private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) { - return sortAndFormats == null ? 
null : sortAndFormats.formats; - } - private int getMaxDepthLimit(final SearchContext searchContext) { Settings indexSettings = searchContext.getQueryShardContext().getIndexSettings().getSettings(); return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings).intValue(); } + + @Override + public AggregationProcessor aggregationProcessor(SearchContext searchContext) { + AggregationProcessor coreAggProcessor = super.aggregationProcessor(searchContext); + return new HybridAggregationProcessor(coreAggProcessor); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java index 1a2e3f26e..a0a4c8ca3 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryScorerTests.java @@ -21,6 +21,7 @@ import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.apache.lucene.tests.util.TestUtil; @@ -219,13 +220,101 @@ public void testMaxScoreFailures_whenScorerThrowsException_thenFail() { when(scorer.iterator()).thenReturn(iterator(docs)); when(scorer.getMaxScore(anyInt())).thenThrow(new IOException("Test exception")); - HybridQueryScorer hybridQueryScorerWithAllNonNullSubScorers = new HybridQueryScorer(weight, Arrays.asList(scorer)); + IOException runtimeException = expectThrows(IOException.class, () -> new HybridQueryScorer(weight, Arrays.asList(scorer))); + assertTrue(runtimeException.getMessage().contains("Test exception")); + } + + @SneakyThrows + public void testApproximationIterator_whenSubScorerSupportsApproximation_thenSuccessful() { + final int maxDoc = TestUtil.nextInt(random(), 10, 1_000); + final int numDocs = TestUtil.nextInt(random(), 1, maxDoc / 2); + final Set uniqueDocs = new HashSet<>(); + while (uniqueDocs.size() < numDocs) { + uniqueDocs.add(random().nextInt(maxDoc)); + } + final int[] docs = new int[numDocs]; + int i = 0; + for (int doc : uniqueDocs) { + docs[i++] = doc; + } + Arrays.sort(docs); + final float[] scores1 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores1[i] = random().nextFloat(); + } + final float[] scores2 = new float[numDocs]; + for (i = 0; i < numDocs; ++i) { + scores2[i] = random().nextFloat(); + } + + Weight weight = mock(Weight.class); - RuntimeException runtimeException = expectThrows( - RuntimeException.class, - () -> hybridQueryScorerWithAllNonNullSubScorers.getMaxScore(Integer.MAX_VALUE) + HybridQueryScorer queryScorer = new HybridQueryScorer( + weight, + Arrays.asList( + scorerWithTwoPhaseIterator(docs, scores1, fakeWeight(new MatchAllDocsQuery()), maxDoc), + scorerWithTwoPhaseIterator(docs, scores2, fakeWeight(new MatchNoDocsQuery()), maxDoc) + ) ); - assertTrue(runtimeException.getMessage().contains("Test exception")); + + int doc = -1; + int idx = 0; + while (doc != DocIdSetIterator.NO_MORE_DOCS) { + doc = queryScorer.iterator().nextDoc(); + if (idx == docs.length) { + assertEquals(DocIdSetIterator.NO_MORE_DOCS, doc); + } else { + assertEquals(docs[idx], doc); + assertEquals(scores1[idx] + scores2[idx], queryScorer.score(), 0.001f); + } + idx++; + } + } + + protected static Scorer scorerWithTwoPhaseIterator(final int[] docs, final float[] scores, Weight weight, int maxDoc) { + final DocIdSetIterator iterator = DocIdSetIterator.all(maxDoc); + 
return new Scorer(weight) { + + int lastScoredDoc = -1; + + public DocIdSetIterator iterator() { + return TwoPhaseIterator.asDocIdSetIterator(twoPhaseIterator()); + } + + @Override + public int docID() { + return iterator.docID(); + } + + @Override + public float score() { + assertNotEquals("score() called twice on doc " + docID(), lastScoredDoc, docID()); + lastScoredDoc = docID(); + final int idx = Arrays.binarySearch(docs, docID()); + return scores[idx]; + } + + @Override + public float getMaxScore(int upTo) { + return Float.MAX_VALUE; + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return new TwoPhaseIterator(iterator) { + + @Override + public boolean matches() { + return Arrays.binarySearch(docs, iterator.docID()) >= 0; + } + + @Override + public float matchCost() { + return 10; + } + }; + } + }; } private Pair generateDocuments(int maxDocId) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java new file mode 100644 index 000000000..5bf0948ea --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridScoreBlockBoundaryPropagatorTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class HybridScoreBlockBoundaryPropagatorTests extends OpenSearchQueryTestCase { + + public void testAdvanceShallow_whenMinCompetitiveScoreSet_thenSuccessful() throws IOException { + Scorer scorer1 = new MockScorer(10, 0.6f); + Scorer scorer2 = new MockScorer(40, 1.5f); + Scorer scorer3 = new MockScorer(30, 2f); + Scorer scorer4 = new MockScorer(120, 4f); + + List scorers = Arrays.asList(scorer1, scorer2, scorer3, scorer4); + Collections.shuffle(scorers, random()); + HybridScoreBlockBoundaryPropagator propagator = new HybridScoreBlockBoundaryPropagator(scorers); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.1f); + assertEquals(10, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(0.8f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.4f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(1.9f); + assertEquals(30, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(2.5f); + assertEquals(120, propagator.advanceShallow(0)); + + propagator.setMinCompetitiveScore(7f); + assertEquals(120, propagator.advanceShallow(0)); + } + + private static class MockWeight extends Weight { + + MockWeight() { + super(new MatchNoDocsQuery()); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return null; + } + + @Override + public Scorer scorer(LeafReaderContext context) { + return null; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + } + + private static class MockScorer extends Scorer { + + final int boundary; + final float maxScore; + + MockScorer(int boundary, float maxScore) throws IOException { + super(new 
MockWeight()); + this.boundary = boundary; + this.maxScore = maxScore; + } + + @Override + public int docID() { + return 0; + } + + @Override + public float score() { + throw new UnsupportedOperationException(); + } + + @Override + public DocIdSetIterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public void setMinCompetitiveScore(float minCompetitiveScore) {} + + @Override + public float getMaxScore(int upTo) throws IOException { + return maxScore; + } + + @Override + public int advanceShallow(int target) { + assert target <= boundary; + return boundary; + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java index b67a1ee05..ad5a955c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollectorTests.java @@ -11,6 +11,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -27,12 +28,15 @@ import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.apache.lucene.util.PriorityQueue; import org.opensearch.index.mapper.TextFieldMapper; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.QueryShardContext; @@ -399,4 +403,109 @@ public void testTrackTotalHits_whenTotalHitsSetIntegerMaxValue_thenSuccessful() reader.close(); directory.close(); } + + @SneakyThrows + public void testCompoundScorer_whenHybridScorerIsChildScorer_thenSuccessful() { + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + Weight subQueryWeight = mock(Weight.class); + Scorer subQueryScorer = mock(Scorer.class); + when(subQueryScorer.getWeight()).thenReturn(subQueryWeight); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(subQueryScorer.iterator()).thenReturn(iterator); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(subQueryScorer)); + + Scorer scorer = mock(Scorer.class); + Collection childrenCollectors = List.of(new Scorable.ChildScorable(hybridQueryScorer, "MUST")); + when(scorer.getChildren()).thenReturn(childrenCollectors); + leafCollector.setScorer(scorer); + int nextDoc = hybridQueryScorer.iterator().nextDoc(); + leafCollector.collect(nextDoc); + + assertNotNull(hybridTopScoreDocCollector.getCompoundScores()); + PriorityQueue[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores(); + assertEquals(1, compoundScoresPQ.length); + PriorityQueue scoreDoc = compoundScoresPQ[0]; + assertNotNull(scoreDoc); + assertNotNull(scoreDoc.top()); + + w.close(); + reader.close(); + directory.close(); + } + + @SneakyThrows + public void testCompoundScorer_whenHybridScorerIsTopLevelScorer_thenSuccessful() { + HybridTopScoreDocCollector hybridTopScoreDocCollector = new HybridTopScoreDocCollector( + NUM_DOCS, + new HitsThresholdChecker(Integer.MAX_VALUE) + ); + + final Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_1, FIELD_1_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_2, FIELD_2_VALUE, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, DOC_ID_3, FIELD_3_VALUE, ft)); + w.commit(); + + DirectoryReader reader = DirectoryReader.open(w); + + LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); + LeafCollector leafCollector = hybridTopScoreDocCollector.getLeafCollector(leafReaderContext); + + assertNotNull(leafCollector); + + Weight weight = mock(Weight.class); + Weight subQueryWeight = mock(Weight.class); + Scorer subQueryScorer = mock(Scorer.class); + when(subQueryScorer.getWeight()).thenReturn(subQueryWeight); + DocIdSetIterator iterator = mock(DocIdSetIterator.class); + when(subQueryScorer.iterator()).thenReturn(iterator); + + HybridQueryScorer hybridQueryScorer = new HybridQueryScorer(weight, Arrays.asList(subQueryScorer)); + + leafCollector.setScorer(hybridQueryScorer); + int nextDoc = hybridQueryScorer.iterator().nextDoc(); + leafCollector.collect(nextDoc); + + assertNotNull(hybridTopScoreDocCollector.getCompoundScores()); + PriorityQueue[] compoundScoresPQ = hybridTopScoreDocCollector.getCompoundScores(); + assertEquals(1, compoundScoresPQ.length); + PriorityQueue scoreDoc = compoundScoresPQ[0]; + assertNotNull(scoreDoc); + assertNotNull(scoreDoc.top()); + + w.close(); + reader.close(); + directory.close(); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java new file mode 100644 index 000000000..f44e762f0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessorTests.java @@ -0,0 +1,233 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import lombok.SneakyThrows; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static 
org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +public class HybridAggregationProcessorTests extends OpenSearchQueryTestCase { + + static final String TEXT_FIELD_NAME = "field"; + static final String TERM_QUERY_TEXT = "keyword"; + + @SneakyThrows + public void testAggregationProcessorDelegate_whenPreAndPostAreCalled_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + hybridAggregationProcessor.preProcess(searchContext); + verify(mockAggsProcessorDelegate).preProcess(any()); + + hybridAggregationProcessor.postProcess(searchContext); + verify(mockAggsProcessorDelegate).postProcess(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = 
spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verify(hybridCollectorManagerSpy).reduce(any()); + } + + @SneakyThrows + public void testCollectorManager_whenHybridQueryAndConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertEquals(1, classCollectorManagerMap.size()); + assertTrue(classCollectorManagerMap.containsKey(HybridCollectorManager.class)); + CollectorManager hybridCollectorManager = classCollectorManagerMap.get( + HybridCollectorManager.class + ); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + // set captor on collector manager to track if reduce has been called + CollectorManager hybridCollectorManagerSpy = spy(hybridCollectorManager); + classCollectorManagerMap.put(HybridCollectorManager.class, hybridCollectorManagerSpy); + + hybridAggregationProcessor.postProcess(searchContext); + + verifyNoInteractions(hybridCollectorManagerSpy); + } + + @SneakyThrows + public void testCollectorManager_whenNotHybridQueryAndNotConcurrentSearch_thenSuccessful() { + AggregationProcessor mockAggsProcessorDelegate = mock(AggregationProcessor.class); + HybridAggregationProcessor hybridAggregationProcessor = new HybridAggregationProcessor(mockAggsProcessorDelegate); + + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TermQueryBuilder 
termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Query termQuery = termSubQuery.toQuery(mockQueryShardContext); + + when(searchContext.query()).thenReturn(termQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + hybridAggregationProcessor.preProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + + // setup query result for post processing + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridAggregationProcessor.postProcess(searchContext); + + assertTrue(classCollectorManagerMap.isEmpty()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java new file mode 100644 index 000000000..65d6f3d8a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -0,0 +1,201 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.search.query; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import lombok.SneakyThrows; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.analysis.MockAnalyzer; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.BoostingQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import 
org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryWeight; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector; +import org.opensearch.search.internal.ContextIndexSearcher; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.search.query.ReduceableSearchResult; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; + +public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { + + private static final String TEXT_FIELD_NAME = "field"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String QUERY1 = "hello"; + private static final float DELTA_FOR_ASSERTION = 0.001f; + + @SneakyThrows + public void testNewCollector_whenNotConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorNonConcurrentManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertSame(collector, secondCollector); + } + + @SneakyThrows + public void testNewCollector_whenConcurrentSearch_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + TermQueryBuilder 
termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1); + HybridQuery hybridQuery = new HybridQuery(List.of(termSubQuery.toQuery(mockQueryShardContext))); + + when(searchContext.query()).thenReturn(hybridQuery); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(true); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + assertNotNull(hybridCollectorManager); + assertTrue(hybridCollectorManager instanceof HybridCollectorManager.HybridCollectorConcurrentSearchManager); + + Collector collector = hybridCollectorManager.newCollector(); + assertNotNull(collector); + assertTrue(collector instanceof HybridTopScoreDocCollector); + + Collector secondCollector = hybridCollectorManager.newCollector(); + assertNotSame(collector, secondCollector); + } + + @SneakyThrows + public void testReduce_whenMatchedDocs_thenSuccessful() { + SearchContext searchContext = mock(SearchContext.class); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery hybridQueryWithTerm = new HybridQuery( + List.of(QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY1).toQuery(mockQueryShardContext)) + ); + when(searchContext.query()).thenReturn(hybridQueryWithTerm); + ContextIndexSearcher indexSearcher = mock(ContextIndexSearcher.class); + IndexReader indexReader = mock(IndexReader.class); + when(indexReader.numDocs()).thenReturn(3); + when(indexSearcher.getIndexReader()).thenReturn(indexReader); + when(searchContext.searcher()).thenReturn(indexSearcher); + when(searchContext.size()).thenReturn(1); + + Map, CollectorManager> classCollectorManagerMap = new HashMap<>(); + when(searchContext.queryCollectorManagers()).thenReturn(classCollectorManagerMap); + when(searchContext.shouldUseConcurrentSearch()).thenReturn(false); + + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + final IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + int docId1 = RandomizedTest.randomInt(); + int docId2 = RandomizedTest.randomInt(); + int docId3 = RandomizedTest.randomInt(); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId1, TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId2, TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, docId3, TEST_DOC_TEXT3, ft)); + w.flush(); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + CollectorManager hybridCollectorManager = HybridCollectorManager.createHybridCollectorManager(searchContext); + HybridTopScoreDocCollector collector = (HybridTopScoreDocCollector) hybridCollectorManager.newCollector(); + + Weight weight = new HybridQueryWeight(hybridQueryWithTerm, searcher, ScoreMode.TOP_SCORES, BoostingQueryBuilder.DEFAULT_BOOST); + collector.setWeight(weight); + LeafReaderContext leafReaderContext = searcher.getIndexReader().leaves().get(0); + LeafCollector leafCollector = collector.getLeafCollector(leafReaderContext); + BulkScorer scorer = weight.bulkScorer(leafReaderContext); + scorer.score(leafCollector, leafReaderContext.reader().getLiveDocs()); + leafCollector.finish(); + + Object results = hybridCollectorManager.reduce(List.of()); + + assertNotNull(results); + ReduceableSearchResult reduceableSearchResult = ((ReduceableSearchResult) results); + QuerySearchResult querySearchResult = new QuerySearchResult(); + reduceableSearchResult.reduce(querySearchResult); + TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs(); + + assertNotNull(topDocsAndMaxScore); + assertEquals(1, topDocsAndMaxScore.topDocs.totalHits.value); + assertEquals(TotalHits.Relation.EQUAL_TO, topDocsAndMaxScore.topDocs.totalHits.relation); + float maxScore = topDocsAndMaxScore.maxScore; + assertTrue(maxScore > 0); + ScoreDoc[] scoreDocs = topDocsAndMaxScore.topDocs.scoreDocs; + assertEquals(4, scoreDocs.length); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[0].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_DELIMITER, scoreDocs[1].score, DELTA_FOR_ASSERTION); + assertEquals(maxScore, scoreDocs[2].score, DELTA_FOR_ASSERTION); + assertEquals(MAGIC_NUMBER_START_STOP, scoreDocs[3].score, DELTA_FOR_ASSERTION); + + w.close(); + reader.close(); + directory.close(); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e609eec05..2aebbb5d8 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.UUID; +import java.util.stream.Collectors; import org.apache.lucene.document.FieldType; import org.apache.lucene.document.TextField; @@ -61,6 +63,7 @@ import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.AggregationProcessor; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import 
org.opensearch.search.query.QueryCollectorContext; @@ -159,7 +162,7 @@ public void testQueryType_whenQueryIsHybrid_thenCallHybridDocCollector() { releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean()); } @SneakyThrows @@ -226,7 +229,7 @@ public void testQueryType_whenQueryIsNotHybrid_thenDoNotCallHybridDocCollector() releaseResources(directory, w, reader); - verify(hybridQueryPhaseSearcher, never()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + verify(hybridQueryPhaseSearcher, never()).extractHybridQuery(any(), any()); } @SneakyThrows @@ -305,17 +308,8 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() { assertEquals(1, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(4, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(1, compoundTopDocs.size()); - TopDocs subQueryTopDocs = compoundTopDocs.get(0); - assertEquals(1, subQueryTopDocs.totalHits.value); - assertNotNull(subQueryTopDocs.scoreDocs); - assertEquals(1, subQueryTopDocs.scoreDocs.length); - ScoreDoc scoreDoc = subQueryTopDocs.scoreDocs[0]; + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; assertNotNull(scoreDoc); int actualDocId = Integer.parseInt(reader.document(scoreDoc.doc).getField("id").stringValue()); assertEquals(docId1, actualDocId); @@ -403,24 +397,10 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes assertEquals(4, topDocs.totalHits.value); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertEquals(10, scoreDocs.length); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - assertNotNull(compoundTopDocs); - assertEquals(3, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); - - TopDocs subQueryTopDocs3 = compoundTopDocs.get(2); - List expectedIds3 = List.of(docId1, docId2, docId3, docId4); - assertQueryResults(subQueryTopDocs3, expectedIds3, reader); + assertEquals(4, scoreDocs.length); + List expectedIds = List.of(0, 1, 2, 3); + List actualDocIds = Arrays.stream(scoreDocs).map(sd -> sd.doc).collect(Collectors.toList()); + assertEquals(expectedIds, actualDocIds); releaseResources(directory, w, reader); } @@ -726,20 +706,10 @@ public void testWrappedHybridQuery_whenHybridWrappedIntoBoolBecauseOfNested_then assertTrue(topDocs.totalHits.value > 0); ScoreDoc[] scoreDocs = topDocs.scoreDocs; assertNotNull(scoreDocs); - assertTrue(scoreDocs.length > 0); - assertTrue(isHybridQueryStartStopElement(scoreDocs[0])); - assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1])); - List compoundTopDocs = getSubQueryResultsForSingleShard(topDocs); - 
assertNotNull(compoundTopDocs); - assertEquals(2, compoundTopDocs.size()); - - TopDocs subQueryTopDocs1 = compoundTopDocs.get(0); - List expectedIds1 = List.of(docId1); - assertQueryResults(subQueryTopDocs1, expectedIds1, reader); - - TopDocs subQueryTopDocs2 = compoundTopDocs.get(1); - List expectedIds2 = List.of(); - assertQueryResults(subQueryTopDocs2, expectedIds2, reader); + assertEquals(1, scoreDocs.length); + ScoreDoc scoreDoc = scoreDocs[0]; + assertTrue(scoreDoc.score > 0); + assertEquals(0, scoreDoc.doc); releaseResources(directory, w, reader); } @@ -831,6 +801,15 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = new HybridQueryPhaseSearcher(); + SearchContext searchContext = mock(SearchContext.class); + AggregationProcessor aggregationProcessor = hybridQueryPhaseSearcher.aggregationProcessor(searchContext); + assertNotNull(aggregationProcessor); + assertTrue(aggregationProcessor instanceof HybridAggregationProcessor); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); From f04c058fc5ab193342c583cf820cd6cb72be42ea Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 12 Mar 2024 10:04:39 -0700 Subject: [PATCH 2/3] Adding aggregations in hybrid query (#630) * Adding aggregations in hybrid query Signed-off-by: Martin Gaievski --- CHANGELOG.md | 4 +- .../processor/combination/ScoreCombiner.java | 18 +- .../query/HybridAggregationProcessor.java | 2 +- .../query/HybridQueryPhaseSearcher.java | 48 +- .../neuralsearch/util/HybridQueryUtil.java | 71 +++ .../processor/NormalizationProcessorIT.java | 8 +- .../ScoreCombinationTechniqueTests.java | 2 +- .../query/HybridQueryAggregationsIT.java | 597 ++++++++++++++++++ .../query/HybridQueryPhaseSearcherTests.java | 80 ++- .../util/AggregationsTestUtils.java | 43 ++ .../util/HybridQueryUtilTests.java | 100 +++ .../neuralsearch/BaseNeuralSearchIT.java | 138 +++- 12 files changed, 1044 insertions(+), 67 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dcdc721b..120640aa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,19 +8,17 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements ### Bug Fixes - Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438)) -- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490)) -- Hybrid query and nested type fields ([#498](https://github.com/opensearch-project/neural-search/pull/498)) - Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578)) - Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615)) ### Infrastructure ### Documentation ### Maintenance -- Added 
support for jdk-21 ([#500](https://github.com/opensearch-project/neural-search/pull/500))) ### Refactoring ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.12...2.x) ### Features ### Enhancements +- Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630)) ### Bug Fixes - Fix runtime exceptions in hybrid query for case when sub-query scorer return TwoPhase iterator that is incompatible with DISI iterator ([#624](https://github.com/opensearch-project/neural-search/pull/624)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index c9e0551e2..278d2fdfc 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import org.apache.lucene.search.ScoreDoc; @@ -131,13 +132,18 @@ private void updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits)); } + /** + * Get max hits as number of unique doc ids from results of all sub-queries + * @param topDocsPerSubQuery list of topDocs objects for one shard + * @return number of unique doc ids + */ protected int getMaxHits(final List topDocsPerSubQuery) { - int maxHits = 0; - for (TopDocs topDocs : topDocsPerSubQuery) { - int hits = topDocs.scoreDocs.length; - maxHits = Math.max(maxHits, hits); - } - return maxHits; + Set docIds = topDocsPerSubQuery.stream() + .filter(topDocs -> Objects.nonNull(topDocs.scoreDocs)) + .flatMap(topDocs -> Arrays.stream(topDocs.scoreDocs)) + .map(scoreDoc -> scoreDoc.doc) + .collect(Collectors.toSet()); + return docIds.size(); } private TotalHits getTotalHits(final List topDocsPerSubQuery, int maxHits) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java index 4e9070748..7f36b09be 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridAggregationProcessor.java @@ -16,7 +16,7 @@ import java.io.IOException; import java.util.List; -import static org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher.isHybridQuery; +import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery; /** * Defines logic for pre- and post-phases of document scores collection. 
Responsible for registering custom diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index 6461c698e..4d8b429df 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -11,11 +11,9 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.FieldExistsQuery; import org.apache.lucene.search.Query; import org.opensearch.common.settings.Settings; import org.opensearch.index.mapper.MapperService; -import org.opensearch.index.mapper.SeqNoFieldMapper; import org.opensearch.index.search.NestedHelper; import org.opensearch.neuralsearch.query.HybridQuery; import org.opensearch.search.aggregations.AggregationProcessor; @@ -27,6 +25,8 @@ import lombok.extern.log4j.Log4j2; +import static org.opensearch.neuralsearch.util.HybridQueryUtil.isHybridQuery; + /** * Custom search implementation to be used at {@link QueryPhase} for Hybrid Query search. For queries other than Hybrid the * upstream standard implementation of searcher is called. @@ -34,10 +34,6 @@ @Log4j2 public class HybridQueryPhaseSearcher extends QueryPhaseSearcherWrapper { - public HybridQueryPhaseSearcher() { - super(); - } - public boolean searchWith( final SearchContext searchContext, final ContextIndexSearcher searcher, @@ -55,46 +51,6 @@ public boolean searchWith( } } - @VisibleForTesting - static boolean isHybridQuery(final Query query, final SearchContext searchContext) { - if (query instanceof HybridQuery) { - return true; - } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) { - /* Checking if this is a hybrid query that is wrapped into a Bool query by core Opensearch code - https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370. - main reason for that is performance optimization, at time of writing we are ok with loosing on performance if that's unblocks - hybrid query for indexes with nested field types. - in such case we consider query a valid hybrid query. Later in the code we will extract it and execute as a main query for - this search request. 
-             below is sample structure of such query:
-
-            Boolean {
-                should: {
-                    hybrid: {
-                        sub_query1 {}
-                        sub_query2 {}
-                    }
-                }
-                filter: {
-                    exists: {
-                        field: "_primary_term"
-                    }
-                }
-            }
-            TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */
-            // we have already checked if query in instance of Boolean in higher level else if condition
-            return ((BooleanQuery) query).clauses()
-                .stream()
-                .filter(clause -> !(clause.getQuery() instanceof HybridQuery))
-                .allMatch(clause -> {
-                    return clause.getOccur() == BooleanClause.Occur.FILTER
-                        && clause.getQuery() instanceof FieldExistsQuery
-                        && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField());
-                });
-        }
-        return false;
-    }
-
     private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
         return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query);
     }
diff --git a/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java
new file mode 100644
index 000000000..689cbedca
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/util/HybridQueryUtil.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.util;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.FieldExistsQuery;
+import org.apache.lucene.search.Query;
+import org.opensearch.index.mapper.SeqNoFieldMapper;
+import org.opensearch.index.search.NestedHelper;
+import org.opensearch.neuralsearch.query.HybridQuery;
+import org.opensearch.search.internal.SearchContext;
+
+/**
+ * Utility class for anything related to hybrid query
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public class HybridQueryUtil {
+
+    public static boolean isHybridQuery(final Query query, final SearchContext searchContext) {
+        if (query instanceof HybridQuery) {
+            return true;
+        } else if (isWrappedHybridQuery(query) && hasNestedFieldOrNestedDocs(query, searchContext)) {
+            /* Checking if this is a hybrid query that is wrapped into a Bool query by core OpenSearch code
+            https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/DefaultSearchContext.java#L367-L370.
+            The main reason for that is performance optimization; at the time of writing we are ok with losing some performance if that unblocks
+            hybrid query for indexes with nested field types.
+            In such a case we consider the query a valid hybrid query. Later in the code we will extract it and execute it as the main query for
+            this search request.
+             below is sample structure of such query:
+
+            Boolean {
+                should: {
+                    hybrid: {
+                        sub_query1 {}
+                        sub_query2 {}
+                    }
+                }
+                filter: {
+                    exists: {
+                        field: "_primary_term"
+                    }
+                }
+            }
+            TODO Need to add logic for passing hybrid sub-queries through the same logic in core to ensure there is no latency regression */
+            // we have already checked if query is an instance of Boolean in the higher level else if condition
+            return ((BooleanQuery) query).clauses()
+                .stream()
+                .filter(clause -> clause.getQuery() instanceof HybridQuery == false)
+                .allMatch(clause -> {
+                    return clause.getOccur() == BooleanClause.Occur.FILTER
+                        && clause.getQuery() instanceof FieldExistsQuery
+                        && SeqNoFieldMapper.PRIMARY_TERM_NAME.equals(((FieldExistsQuery) clause.getQuery()).getField());
+                });
+        }
+        return false;
+    }
+
+    private static boolean hasNestedFieldOrNestedDocs(final Query query, final SearchContext searchContext) {
+        return searchContext.mapperService().hasNested() && new NestedHelper(searchContext.mapperService()).mightMatchNestedDocs(query);
+    }
+
+    private static boolean isWrappedHybridQuery(final Query query) {
+        return query instanceof BooleanQuery
+            && ((BooleanQuery) query).clauses().stream().anyMatch(clauseQuery -> clauseQuery.getQuery() instanceof HybridQuery);
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
index b1f0de9d3..e4f2c77ae 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
@@ -52,6 +52,8 @@ public class NormalizationProcessorIT extends BaseNeuralSearchIT {
     private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
     private final float[] testVector4 = createRandomVector(TEST_DIMENSION);
+    private final float[] testVector5 = createRandomVector(TEST_DIMENSION);
+    private final float[] testVector6 = createRandomVector(TEST_DIMENSION);
 
     @Before
     public void setUp() throws Exception {
@@ -318,7 +320,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
                 TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME,
                 "5",
                 Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-                Collections.singletonList(Floats.asList(testVector4).toArray()),
+                Collections.singletonList(Floats.asList(testVector5).toArray()),
                 Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
                 Collections.singletonList(TEST_DOC_TEXT4)
             );
@@ -365,7 +367,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
                 TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
                 "5",
                 Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-                Collections.singletonList(Floats.asList(testVector4).toArray()),
+                Collections.singletonList(Floats.asList(testVector5).toArray()),
                 Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
                 Collections.singletonList(TEST_DOC_TEXT4)
             );
@@ -373,7 +375,7 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
                 TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
                 "6",
                 Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1),
-                Collections.singletonList(Floats.asList(testVector4).toArray()),
+                Collections.singletonList(Floats.asList(testVector6).toArray()),
                 Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
                 Collections.singletonList(TEST_DOC_TEXT5)
             );
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java
b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index d2c1ddb4f..4f76c666e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -63,7 +63,7 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); - assertEquals(3, queryTopDocs.get(0).getScoreDocs().size()); + assertEquals(5, queryTopDocs.get(0).getScoreDocs().size()); assertEquals(.5, queryTopDocs.get(0).getScoreDocs().get(0).score, DELTA_FOR_SCORE_ASSERTION); assertEquals(1, queryTopDocs.get(0).getScoreDocs().get(0).doc); assertEquals(.5, queryTopDocs.get(0).getScoreDocs().get(1).score, DELTA_FOR_SCORE_ASSERTION); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java new file mode 100644 index 000000000..e51a4562d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryAggregationsIT.java @@ -0,0 +1,597 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.PipelineAggregatorBuilders; +import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval; +import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MinBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.IntStream; + +import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; + +/** + * Integration tests for base scenarios when aggregations are combined with hybrid query + */ +public class HybridQueryAggregationsIT extends BaseNeuralSearchIT { + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = + "test-neural-aggs-pipeline-multi-doc-index-multiple-shards"; + private static final String 
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-neural-aggs-multi-doc-index-single-shard"; + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT5 = "welcome"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; + private static final String TEST_DOC_TEXT5 = "People keep telling me orange but I still prefer pink"; + private static final String TEST_DOC_TEXT6 = "She traveled because it cost the same as therapy and was a lot more enjoyable"; + private static final String INTEGER_FIELD_1 = "doc_index"; + private static final int INTEGER_FIELD_1_VALUE = 1234; + private static final int INTEGER_FIELD_2_VALUE = 2345; + private static final int INTEGER_FIELD_3_VALUE = 3456; + private static final int INTEGER_FIELD_4_VALUE = 4567; + private static final String KEYWORD_FIELD_1 = "doc_keyword"; + private static final String KEYWORD_FIELD_1_VALUE = "workable"; + private static final String KEYWORD_FIELD_2_VALUE = "angry"; + private static final String KEYWORD_FIELD_3_VALUE = "likeable"; + private static final String KEYWORD_FIELD_4_VALUE = "entire"; + private static final String DATE_FIELD_1 = "doc_date"; + private static final String DATE_FIELD_1_VALUE = "01/03/1995"; + private static final String DATE_FIELD_2_VALUE = "05/02/2015"; + private static final String DATE_FIELD_3_VALUE = "07/23/2007"; + private static final String DATE_FIELD_4_VALUE = "08/21/2012"; + private static final String INTEGER_FIELD_PRICE = "doc_price"; + private static final int INTEGER_FIELD_PRICE_1_VALUE = 130; + private static final int INTEGER_FIELD_PRICE_2_VALUE = 100; + private static final int INTEGER_FIELD_PRICE_3_VALUE = 200; + private static final int INTEGER_FIELD_PRICE_4_VALUE = 25; + private static final int INTEGER_FIELD_PRICE_5_VALUE = 30; + private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; + private static final String BUCKET_AGG_DOC_COUNT_FIELD = "doc_count"; + private static final String KEY = "key"; + private static final String BUCKET_AGG_KEY_AS_STRING = "key_as_string"; + private static final String SUM_AGGREGATION_NAME = "sum_aggs"; + private static final String MAX_AGGREGATION_NAME = "max_aggs"; + private static final String DATE_AGGREGATION_NAME = "date_aggregation"; + private static final String GENERIC_AGGREGATION_NAME = "my_aggregation"; + private static final String BUCKETS_AGGREGATION_NAME_1 = "date_buckets_1"; + private static final String BUCKETS_AGGREGATION_NAME_2 = "date_buckets_2"; + private static final String BUCKETS_AGGREGATION_NAME_3 = "date_buckets_3"; + private static final String BUCKETS_AGGREGATION_NAME_4 = "date_buckets_4"; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + @Override + public boolean isUpdateClusterSettings() { + return false; + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + @SneakyThrows + public void testPipelineAggs_whenConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testAvgSumMinMaxAggs(); + } + + @SneakyThrows + public void 
testPipelineAggs_whenConcurrentSearchDisabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testAvgSumMinMaxAggs(); + } + + @SneakyThrows + public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testMaxAggsOnSingleShardCluster(); + } + + @SneakyThrows + public void testMetricAggsOnSingleShard_whenMaxAggsAndConcurrentSearchDisabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testMaxAggsOnSingleShardCluster(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenConcurrentSearchDisabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", false); + testDateRange(); + } + + @SneakyThrows + public void testBucketAndNestedAggs_whenConcurrentSearchEnabled_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + testDateRange(); + } + + @SneakyThrows + public void testAggregationNotSupportedConcurrentSearch_whenUseSamplerAgg_thenSuccessful() { + updateClusterSettings("search.concurrent_segment_search.enabled", true); + + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.sampler(GENERIC_AGGREGATION_NAME) + .shardSize(2) + .subAggregation(AggregationBuilders.terms(BUCKETS_AGGREGATION_NAME_1).field(KEYWORD_FIELD_1)); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggsBuilder), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 3 + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + Map aggValue = getAggregationValues(aggregations, GENERIC_AGGREGATION_NAME); + assertEquals(2, aggValue.size()); + assertEquals(3, aggValue.get(BUCKET_AGG_DOC_COUNT_FIELD)); + Map nestedAggs = getAggregationValues(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertNotNull(nestedAggs); + assertEquals(0, nestedAggs.get("doc_count_error_upper_bound")); + List> buckets = getAggregationBuckets(aggValue, BUCKETS_AGGREGATION_NAME_1); + assertEquals(2, buckets.size()); + + Map firstBucket = buckets.get(0); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("likeable", firstBucket.get(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(1, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("workable", secondBucket.get(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + private void testAvgSumMinMaxAggs() { + try { + prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.dateHistogram(GENERIC_AGGREGATION_NAME) + .calendarInterval(DateHistogramInterval.YEAR) + .field(DATE_FIELD_1) + .subAggregation(AggregationBuilders.sum(SUM_AGGREGATION_NAME).field(INTEGER_FIELD_1)); + + BucketMetricsPipelineAggregationBuilder aggAvgBucket = PipelineAggregatorBuilders + .avgBucket(BUCKETS_AGGREGATION_NAME_1, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggSumBucket = PipelineAggregatorBuilders + .sumBucket(BUCKETS_AGGREGATION_NAME_2, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggMinBucket = PipelineAggregatorBuilders + 
.minBucket(BUCKETS_AGGREGATION_NAME_3, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + BucketMetricsPipelineAggregationBuilder aggMaxBucket = PipelineAggregatorBuilders + .maxBucket(BUCKETS_AGGREGATION_NAME_4, GENERIC_AGGREGATION_NAME + ">" + SUM_AGGREGATION_NAME); + + Map searchResponseAsMapAggsBoolQuery = executeQueryAndGetAggsResults( + List.of(aggsBuilder, aggAvgBucket, aggSumBucket, aggMinBucket, aggMaxBucket), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 3 + ); + + assertResultsOfPipelineSumtoDateHistogramAggs(searchResponseAsMapAggsBoolQuery); + + // test only aggregation without query (handled as match_all query) + Map searchResponseAsMapAggsNoQuery = executeQueryAndGetAggsResults( + List.of(aggsBuilder, aggAvgBucket), + null, + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 6 + ); + + assertResultsOfPipelineSumtoDateHistogramAggsForMatchAllQuery(searchResponseAsMapAggsNoQuery); + + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + private void testMaxAggsOnSingleShardCluster() throws Exception { + try { + prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + AggregationBuilder aggsBuilder = AggregationBuilders.max(MAX_AGGREGATION_NAME).field(INTEGER_FIELD_1); + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + List.of(aggsBuilder) + ); + + assertHitResultsFromQuery(2, searchResponseAsMap); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + assertTrue(aggregations.containsKey(MAX_AGGREGATION_NAME)); + double maxAggsValue = getAggregationValue(aggregations, MAX_AGGREGATION_NAME); + assertTrue(maxAggsValue >= 0); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, null, null, SEARCH_PIPELINE); + } + } + + private void testDateRange() throws IOException { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + + AggregationBuilder aggsBuilder = AggregationBuilders.dateRange(DATE_AGGREGATION_NAME) + .field(DATE_FIELD_1) + .format("MM-yyyy") + .addRange("01-2014", "02-2024"); + + Map searchResponseAsMap = executeQueryAndGetAggsResults( + List.of(aggsBuilder), + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + 3 + ); + + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + List> buckets = getAggregationBuckets(aggregations, DATE_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(1, buckets.size()); + + Map bucket = buckets.get(0); + + assertEquals(6, bucket.size()); + assertEquals("01-2014", bucket.get("from_as_string")); + assertEquals(2, bucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals("02-2024",
bucket.get("to_as_string")); + assertTrue(bucket.containsKey("from")); + assertTrue(bucket.containsKey("to")); + assertTrue(bucket.containsKey(KEY)); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + private void initializeIndexIfNotExist(String indexName) throws IOException { + if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(DATE_FIELD_1), 3), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_1_VALUE, INTEGER_FIELD_PRICE_1_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_1_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_1_VALUE) + ); + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_2_VALUE, INTEGER_FIELD_PRICE_2_VALUE), + List.of(), + List.of(), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_2_VALUE) + ); + addKnnDoc( + indexName, + "3", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2), + List.of(), + List.of(), + List.of(INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_PRICE_3_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_2_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_3_VALUE) + ); + addKnnDoc( + indexName, + "4", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT4), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_3_VALUE, INTEGER_FIELD_PRICE_4_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_3_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_2_VALUE) + ); + addKnnDoc( + indexName, + "5", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT5), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_3_VALUE, INTEGER_FIELD_PRICE_5_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_4_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_4_VALUE) + ); + addKnnDoc( + indexName, + "6", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT6), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1, INTEGER_FIELD_PRICE), + List.of(INTEGER_FIELD_4_VALUE, INTEGER_FIELD_PRICE_6_VALUE), + List.of(KEYWORD_FIELD_1), + List.of(KEYWORD_FIELD_4_VALUE), + List.of(DATE_FIELD_1), + List.of(DATE_FIELD_4_VALUE) + ); + } + } + + @SneakyThrows + private void initializeIndexWithOneShardIfNotExists(String indexName) { + if (!indexExists(indexName)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_1), List.of(KEYWORD_FIELD_1), List.of(), 1), + "" + ); + + addKnnDoc( + indexName, + "1", + List.of(), + List.of(), + 
Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_1_VALUE), + List.of(), + List.of(), + List.of(), + List.of() + ); + + addKnnDoc( + indexName, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_2_VALUE), + List.of(), + List.of(), + List.of(), + List.of() + ); + } + } + + @SneakyThrows + void prepareResources(String indexName, String pipelineName) { + initializeIndexIfNotExist(indexName); + createSearchPipelineWithResultsPostProcessor(pipelineName); + } + + @SneakyThrows + void prepareResourcesForSingleShardIndex(String indexName, String pipelineName) { + initializeIndexWithOneShardIfNotExists(indexName); + createSearchPipelineWithResultsPostProcessor(pipelineName); + } + + private void assertResultsOfPipelineSumtoDateHistogramAggs(Map searchResponseAsMap) { + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + double aggValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_1); + assertEquals(3517.5, aggValue, DELTA_FOR_SCORE_ASSERTION); + + double sumValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_2); + assertEquals(7035.0, sumValue, DELTA_FOR_SCORE_ASSERTION); + + double minValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_3); + assertEquals(1234.0, minValue, DELTA_FOR_SCORE_ASSERTION); + + double maxValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_4); + assertEquals(5801.0, maxValue, DELTA_FOR_SCORE_ASSERTION); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + Map firstBucket = buckets.get(0); + assertEquals(4, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(4, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(4, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } + + private void assertResultsOfPipelineSumtoDateHistogramAggsForMatchAllQuery(Map searchResponseAsMap) { + Map aggregations = getAggregations(searchResponseAsMap); + assertNotNull(aggregations); + + double aggValue = getAggregationValue(aggregations, BUCKETS_AGGREGATION_NAME_1); + assertEquals(3764.5, aggValue, DELTA_FOR_SCORE_ASSERTION); + + List> buckets = getAggregationBuckets(aggregations, GENERIC_AGGREGATION_NAME); + assertNotNull(buckets); + assertEquals(21, buckets.size()); + + // check content of few buckets + Map 
firstBucket = buckets.get(0); + assertEquals(4, firstBucket.size()); + assertEquals("01/01/1995", firstBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(1, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(1234.0, getAggregationValue(firstBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(firstBucket.containsKey(KEY)); + + Map secondBucket = buckets.get(1); + assertEquals(4, secondBucket.size()); + assertEquals("01/01/1996", secondBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(0, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(0.0, getAggregationValue(secondBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(secondBucket.containsKey(KEY)); + + Map lastBucket = buckets.get(buckets.size() - 1); + assertEquals(4, lastBucket.size()); + assertEquals("01/01/2015", lastBucket.get(BUCKET_AGG_KEY_AS_STRING)); + assertEquals(2, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD)); + assertEquals(5801.0, getAggregationValue(lastBucket, SUM_AGGREGATION_NAME), DELTA_FOR_SCORE_ASSERTION); + assertTrue(lastBucket.containsKey(KEY)); + } + + private Map executeQueryAndGetAggsResults(final List aggsBuilders, String indexName, int expectedHitsNumber) { + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + return executeQueryAndGetAggsResults(aggsBuilders, hybridQueryBuilderNeuralThenTerm, indexName, expectedHitsNumber); + } + + private Map executeQueryAndGetAggsResults( + final List aggsBuilders, + QueryBuilder queryBuilder, + String indexName, + int expectedHits + ) { + Map searchResponseAsMap = search( + indexName, + queryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilders + ); + + assertHitResultsFromQuery(expectedHits, searchResponseAsMap); + return searchResponseAsMap; + } + + private void assertHitResultsFromQuery(int expected, Map searchResponseAsMap) { + assertEquals(expected, getHitCount(searchResponseAsMap)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(expected, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index 2aebbb5d8..055301832 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -66,12 +66,12 @@ import org.opensearch.search.aggregations.AggregationProcessor; import 
org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.QueryCollectorContext; -import org.opensearch.search.query.QuerySearchResult; import com.carrotsearch.randomizedtesting.RandomizedTest; import lombok.SneakyThrows; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String VECTOR_FIELD_NAME = "vectorField"; @@ -810,6 +810,82 @@ public void testAggsProcessor_whenGettingAggsProcessor_thenSuccess() { assertTrue(aggregationProcessor instanceof HybridAggregationProcessor); } + @SneakyThrows + public void testAggregations_whenMetricAggregation_thenSuccessful() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? 
IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + releaseResources(directory, w, reader); + + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWith(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); diff --git a/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java b/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java new file mode 100644 index 000000000..fbb53a918 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/AggregationsTestUtils.java @@ -0,0 +1,43 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import java.util.List; +import java.util.Map; + +/** + * Utility class for routines associated with aggregations testing + */ +public class AggregationsTestUtils { + + public static List> getNestedHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (List>) hitsMap.get("hits"); + } + + public static Map getTotalHits(Map searchResponseAsMap) { + Map hitsMap = (Map) searchResponseAsMap.get("hits"); + return (Map)
hitsMap.get("total"); + } + + public static Map getAggregations(final Map searchResponseAsMap) { + Map aggsMap = (Map) searchResponseAsMap.get("aggregations"); + return aggsMap; + } + + public static T getAggregationValue(final Map aggsMap, final String aggName) { + Map aggValues = (Map) aggsMap.get(aggName); + return (T) aggValues.get("value"); + } + + public static T getAggregationBuckets(final Map aggsMap, final String aggName) { + Map aggValues = (Map) aggsMap.get(aggName); + return (T) aggValues.get("buckets"); + } + + public static T getAggregationValues(final Map aggsMap, final String aggName) { + return (T) aggsMap.get(aggName); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java new file mode 100644 index 000000000..be9dbc2cc --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/HybridQueryUtilTests.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import lombok.SneakyThrows; +import org.apache.lucene.search.Query; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.index.mapper.TextFieldMapper; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; +import org.opensearch.neuralsearch.query.HybridQuery; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.search.internal.SearchContext; + +import java.util.List; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HybridQueryUtilTests extends OpenSearchQueryTestCase { + + private static final String TERM_QUERY_TEXT = "keyword"; + private static final String RANGE_FIELD = "date_range"; + private static final String FROM_TEXT = "123"; + private static final String TO_TEXT = "456"; + private static final String TEXT_FIELD_NAME = "field"; + + @SneakyThrows + public void testIsHybridQueryCheck_whenQueryIsHybridQueryInstance_thenSuccess() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQuery query = new HybridQuery( + List.of( + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext), + QueryBuilders.rangeQuery(RANGE_FIELD) + .from(FROM_TEXT) + .to(TO_TEXT) + .rewrite(mockQueryShardContext) + .rewrite(mockQueryShardContext) + .toQuery(mockQueryShardContext), + QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT).toQuery(mockQueryShardContext) + ) + ); + SearchContext searchContext = mock(SearchContext.class); + + assertTrue(HybridQueryUtil.isHybridQuery(query, searchContext)); + } + + @SneakyThrows + public void testIsHybridQueryCheck_whenHybridWrappedIntoBoolAndNoNested_thenSuccess() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + HybridQueryBuilder hybridQueryBuilder = new
HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)); + hybridQueryBuilder.add( + QueryBuilders.rangeQuery(RANGE_FIELD).from(FROM_TEXT).to(TO_TEXT).rewrite(mockQueryShardContext).rewrite(mockQueryShardContext) + ); + + Query booleanQuery = QueryBuilders.boolQuery() + .should(hybridQueryBuilder) + .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) + .toQuery(mockQueryShardContext); + SearchContext searchContext = mock(SearchContext.class); + when(searchContext.mapperService()).thenReturn(mapperService); + + assertFalse(HybridQueryUtil.isHybridQuery(booleanQuery, searchContext)); + } + + @SneakyThrows + public void testIsHybridQueryCheck_whenNoHybridQuery_thenSuccess() { + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Query booleanQuery = QueryBuilders.boolQuery() + .should(QueryBuilders.termQuery(TEXT_FIELD_NAME, TERM_QUERY_TEXT)) + .should( + QueryBuilders.rangeQuery(RANGE_FIELD) + .from(FROM_TEXT) + .to(TO_TEXT) + .rewrite(mockQueryShardContext) + .rewrite(mockQueryShardContext) + ) + .toQuery(mockQueryShardContext); + SearchContext searchContext = mock(SearchContext.class); + when(searchContext.mapperService()).thenReturn(mapperService); + + assertFalse(HybridQueryUtil.isHybridQuery(booleanQuery, searchContext)); + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index ffbbed2bc..622327fa7 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -413,14 +413,54 @@ protected Map search( final int resultSize, final Map requestParams ) { - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); - queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + return search(index, queryBuilder, rescorer, resultSize, requestParams, null); + } + + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams, + List aggs + ) { + return search(index, queryBuilder, rescorer, resultSize, requestParams, aggs, null); + } + + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams, + List aggs, + QueryBuilder postFilterBuilder + ) { + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + + if (queryBuilder != null) { + builder.field("query"); + queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + } if (rescorer != null) { builder.startObject("rescore").startObject("query").field("query_weight", 0.0f).field("rescore_query"); rescorer.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject().endObject(); } + if (Objects.nonNull(aggs)) { + builder.startObject("aggs"); + for (Object agg : aggs) { + builder.value(agg); + } + builder.endObject(); + } + if (Objects.nonNull(postFilterBuilder)) { + builder.field("post_filter"); + postFilterBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + } builder.endObject(); @@ -463,6 +503,35 @@ protected void addKnnDoc( addKnnDoc(index, docId, 
vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList()); } + @SneakyThrows + protected void addKnnDoc( + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts, + List nestedFieldNames, + List> nestedFields + ) { + addKnnDoc( + index, + docId, + vectorFieldNames, + vectors, + textFieldNames, + texts, + nestedFieldNames, + nestedFields, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList() + ); + } + /** * Add a set of knn vectors and text to an index * @@ -484,7 +553,13 @@ protected void addKnnDoc( final List textFieldNames, final List texts, final List nestedFieldNames, - final List> nestedFields + final List> nestedFields, + final List integerFieldNames, + final List integerFieldValues, + final List keywordFieldNames, + final List keywordFieldValues, + final List dateFieldNames, + final List dateFieldValues ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -505,6 +580,18 @@ protected void addKnnDoc( } builder.endObject(); } + + for (int i = 0; i < integerFieldNames.size(); i++) { + builder.field(integerFieldNames.get(i), integerFieldValues.get(i)); + } + + for (int i = 0; i < keywordFieldNames.size(); i++) { + builder.field(keywordFieldNames.get(i), keywordFieldValues.get(i)); + } + + for (int i = 0; i < dateFieldNames.size(); i++) { + builder.field(dateFieldNames.get(i), dateFieldValues.get(i)); + } builder.endObject(); request.setJsonEntity(builder.toString()); @@ -667,6 +754,25 @@ protected String buildIndexConfiguration( final List knnFieldConfigs, final List nestedFields, final int numberOfShards + ) { + return buildIndexConfiguration( + knnFieldConfigs, + nestedFields, + Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), + numberOfShards + ); + } + + @SneakyThrows + protected String buildIndexConfiguration( + final List knnFieldConfigs, + final List nestedFields, + final List intFields, + final List keywordFields, + final List dateFields, + final int numberOfShards ) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() @@ -688,9 +794,31 @@ protected String buildIndexConfiguration( .endObject() .endObject(); } + // treat the list in a manner that first element is always the type name and all others are keywords + if (!nestedFields.isEmpty()) { + String nestedFieldName = nestedFields.get(0); + xContentBuilder.startObject(nestedFieldName).field("type", "nested"); + if (nestedFields.size() > 1) { + xContentBuilder.startObject("properties"); + for (int i = 1; i < nestedFields.size(); i++) { + String innerNestedTypeField = nestedFields.get(i); + xContentBuilder.startObject(innerNestedTypeField).field("type", "keyword").endObject(); + } + xContentBuilder.endObject(); + } + xContentBuilder.endObject(); + } + + for (String intField : intFields) { + xContentBuilder.startObject(intField).field("type", "integer").endObject(); + } + + for (String keywordField : keywordFields) { + xContentBuilder.startObject(keywordField).field("type", "keyword").endObject(); + } - for (String nestedField : nestedFields) { - xContentBuilder.startObject(nestedField).field("type", "nested").endObject(); + for (String dateField : dateFields) { + xContentBuilder.startObject(dateField).field("type", "date").field("format", 
"MM/dd/yyyy").endObject(); } xContentBuilder.endObject().endObject().endObject(); From 759a971404c31dd302527b0c3f00ce347a12c48c Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Wed, 13 Mar 2024 17:20:23 -0500 Subject: [PATCH 3/3] Fix Failing build due to slf4j from k-NN (#634) Signed-off-by: Naveen Tatikonda --- build.gradle | 8 +++++--- qa/build.gradle | 4 +++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/build.gradle b/build.gradle index 496edad4f..1bf27837c 100644 --- a/build.gradle +++ b/build.gradle @@ -250,7 +250,9 @@ dependencies { zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" secureIntegTestPluginArchive group: 'org.opensearch.plugin', name:'opensearch-security', version: "${opensearch_build}" - compileOnly fileTree(dir: knnJarDirectory, include: '*.jar') + compileOnly fileTree(dir: knnJarDirectory, include: "opensearch-knn-${opensearch_build}.jar") + compileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' + compileOnly group: 'commons-lang', name: 'commons-lang', version: '2.6' api group: 'org.opensearch', name:'opensearch-ml-client', version: "${opensearch_build}" testFixturesImplementation "org.opensearch.test:framework:${opensearch_version}" implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0' @@ -265,8 +267,8 @@ dependencies { runtimeOnly group: 'org.json', name: 'json', version: '20231013' testFixturesImplementation "org.opensearch:common-utils:${version}" testFixturesImplementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.14.0' - testFixturesCompileOnly group: 'com.google.guava', name: 'guava', version:'32.0.1-jre' - testFixturesCompileOnly fileTree(dir: knnJarDirectory, include: '*.jar') + testFixturesCompileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' + testFixturesCompileOnly fileTree(dir: knnJarDirectory, include: "opensearch-knn-${opensearch_build}.jar") } // In order to add the jar to the classpath, we need to unzip the diff --git a/qa/build.gradle b/qa/build.gradle index 511dcc442..6b7667e62 100644 --- a/qa/build.gradle +++ b/qa/build.gradle @@ -31,7 +31,9 @@ dependencies { api "org.opensearch:opensearch:${opensearch_version}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-knn', version: "${opensearch_build}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-ml-plugin', version: "${opensearch_build}" - compileOnly fileTree(dir: knnJarDirectory, include: '*.jar') + compileOnly fileTree(dir: knnJarDirectory, include: "opensearch-knn-${opensearch_build}.jar") + compileOnly group: 'com.google.guava', name: 'guava', version:'32.1.3-jre' + compileOnly group: 'commons-lang', name: 'commons-lang', version: '2.6' api "org.apache.logging.log4j:log4j-api:${versions.log4j}" api "org.apache.logging.log4j:log4j-core:${versions.log4j}" api "junit:junit:${versions.junit}"