diff --git a/src/main/java/io/anserini/search/SimpleImpactSearcher.java b/src/main/java/io/anserini/search/SimpleImpactSearcher.java
index 9f353ccfd7..78ccfd2439 100644
--- a/src/main/java/io/anserini/search/SimpleImpactSearcher.java
+++ b/src/main/java/io/anserini/search/SimpleImpactSearcher.java
@@ -17,10 +17,15 @@ package io.anserini.search;
 import io.anserini.index.Constants;
+import org.apache.lucene.analysis.Analyzer;
+import io.anserini.analysis.AnalyzerUtils;
+import io.anserini.index.IndexCollection;
 import io.anserini.index.IndexReaderUtils;
 import io.anserini.rerank.RerankerCascade;
 import io.anserini.rerank.RerankerContext;
 import io.anserini.rerank.ScoredDocuments;
+import io.anserini.rerank.lib.Rm3Reranker;
+import io.anserini.rerank.lib.RocchioReranker;
 import io.anserini.rerank.lib.ScoreTiesAdjusterReranker;
 import io.anserini.search.query.BagOfWordsQueryGenerator;
 import io.anserini.search.query.QueryEncoder;
@@ -38,6 +43,8 @@ import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.similarities.Similarity;
 import org.apache.lucene.store.FSDirectory;
 
+import java.util.ArrayList;
+import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
 
 import ai.onnxruntime.OrtException;
 
@@ -54,7 +61,7 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 
-
+import java.util.stream.Collectors;
 /**
  * Class that exposes basic search functionality, designed specifically to provide the bridge between Java and Python
  * via pyjnius. Note that methods are named according to Python conventions (e.g., snake case instead of camel case).
@@ -67,10 +74,13 @@ public class SimpleImpactSearcher implements Closeable {
   protected IndexReader reader;
   protected Similarity similarity;
   protected BagOfWordsQueryGenerator generator;
+  protected Analyzer analyzer;
   protected RerankerCascade cascade;
   protected IndexSearcher searcher = null;
   protected boolean backwardsCompatibilityLucene8;
   private QueryEncoder queryEncoder = null;
+  protected boolean useRM3;
+  protected boolean useRocchio;
 
   /**
    * This class is meant to serve as the bridge between Anserini and Pyserini.
    * Note that we are adopting Python naming conventions here on purpose.
@@ -103,6 +113,42 @@ protected SimpleImpactSearcher() {
    * @throws IOException if errors encountered during initialization
    */
   public SimpleImpactSearcher(String indexDir) throws IOException {
+    this(indexDir, IndexCollection.DEFAULT_ANALYZER);
+  }
+
+  /**
+   * Creates a {@code SimpleImpactSearcher}.
+   *
+   * @param indexDir index directory
+   * @param queryEncoder query encoder
+   * @throws IOException if errors encountered during initialization
+   */
+  public SimpleImpactSearcher(String indexDir, String queryEncoder) throws IOException {
+    this(indexDir, IndexCollection.DEFAULT_ANALYZER);
+    this.set_onnx_query_encoder(queryEncoder);
+  }
+
+  /**
+   * Creates a {@code SimpleImpactSearcher}.
+   *
+   * @param indexDir index directory
+   * @param queryEncoder query encoder
+   * @param analyzer Analyzer
+   * @throws IOException if errors encountered during initialization
+   */
+  public SimpleImpactSearcher(String indexDir, String queryEncoder, Analyzer analyzer) throws IOException {
+    this(indexDir, analyzer);
+    this.set_onnx_query_encoder(queryEncoder);
+  }
+
+  /**
+   * Creates a {@code SimpleImpactSearcher}.
+   *
+   * @param indexDir index directory
+   * @param analyzer Analyzer
+   * @throws IOException if errors encountered during initialization
+   */
+  public SimpleImpactSearcher(String indexDir, Analyzer analyzer) throws IOException {
     Path indexPath = Paths.get(indexDir);
 
     if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
@@ -118,7 +164,10 @@ public SimpleImpactSearcher(String indexDir) throws IOException {
 
     // Default to using ImpactSimilarity.
     this.similarity = new ImpactSimilarity();
+    this.analyzer = analyzer;
     this.generator = new BagOfWordsQueryGenerator();
+    this.useRM3 = false;
+    this.useRocchio = false;
     cascade = new RerankerCascade();
     cascade.add(new ScoreTiesAdjusterReranker());
   }
@@ -129,7 +178,7 @@ public SimpleImpactSearcher(String indexDir) throws IOException {
    * @param encoder the query encoder
    */
   public void set_onnx_query_encoder(String encoder) {
-    if (empty_encoder()) {
+    if (emptyEncoder()) {
       try {
         this.queryEncoder = (QueryEncoder) Class.forName("io.anserini.search.query." + encoder + "QueryEncoder")
             .getConstructor().newInstance();
@@ -139,10 +188,197 @@ public void set_onnx_query_encoder(String encoder) {
     }
   }
 
-  private boolean empty_encoder(){
+  private boolean emptyEncoder(){
     return this.queryEncoder == null;
   }
 
+  /**
+   * Sets the analyzer used.
+   *
+   * @param analyzer analyzer to use
+   */
+  public void set_analyzer(Analyzer analyzer) {
+    this.analyzer = analyzer;
+  }
+
+  /**
+   * Returns the analyzer used.
+   *
+   * @return analyzer used
+   */
+  public Analyzer get_analyzer(){
+    return this.analyzer;
+  }
+
+  /**
+   * Determines if RM3 query expansion is enabled.
+   *
+   * @return true if RM3 query expansion is enabled; false otherwise.
+   */
+  public boolean use_rm3() {
+    return useRM3;
+  }
+
+  /**
+   * Disables RM3 query expansion.
+   */
+  public void unset_rm3() {
+    this.useRM3 = false;
+    cascade = new RerankerCascade();
+    cascade.add(new ScoreTiesAdjusterReranker());
+  }
+
+  /**
+   * Enables RM3 query expansion with default parameters.
+   */
+  public void set_rm3() {
+    SearchCollection.Args defaults = new SearchCollection.Args();
+    set_rm3(Integer.parseInt(defaults.rm3_fbTerms[0]), Integer.parseInt(defaults.rm3_fbDocs[0]),
+        Float.parseFloat(defaults.rm3_originalQueryWeight[0]));
+  }
+
+  /**
+   * Enables RM3 query expansion with default parameters.
+   *
+   * @param collectionClass class for on-the-fly document parsing if index does not contain docvectors
+   */
+  public void set_rm3(String collectionClass) {
+    SearchCollection.Args defaults = new SearchCollection.Args();
+    set_rm3(collectionClass, Integer.parseInt(defaults.rm3_fbTerms[0]), Integer.parseInt(defaults.rm3_fbDocs[0]),
+        Float.parseFloat(defaults.rm3_originalQueryWeight[0]));
+  }
+
+  /**
+   * Enables RM3 query expansion with specified parameters.
+   *
+   * @param fbTerms number of expansion terms
+   * @param fbDocs number of expansion documents
+   * @param originalQueryWeight weight to assign to the original query
+   */
+  public void set_rm3(int fbTerms, int fbDocs, float originalQueryWeight) {
+    set_rm3(null, fbTerms, fbDocs, originalQueryWeight, false, true);
+  }
+
+  /**
+   * Enables RM3 query expansion with specified parameters.
+   *
+   * @param collectionClass class for on-the-fly document parsing if index does not contain docvectors
+   * @param fbTerms number of expansion terms
+   * @param fbDocs number of expansion documents
+   * @param originalQueryWeight weight to assign to the original query
+   */
+  public void set_rm3(String collectionClass, int fbTerms, int fbDocs, float originalQueryWeight) {
+    set_rm3(collectionClass, fbTerms, fbDocs, originalQueryWeight, false, true);
+  }
+
+  /**
+   * Enables RM3 query expansion with specified parameters.
+   *
+   * @param collectionClass class for on-the-fly document parsing if index does not contain docvectors
+   * @param fbTerms number of expansion terms
+   * @param fbDocs number of expansion documents
+   * @param originalQueryWeight weight to assign to the original query
+   * @param outputQuery flag to print original and expanded queries
+   * @param filterTerms whether to filter terms to be English only
+   */
+  public void set_rm3(String collectionClass, int fbTerms, int fbDocs, float originalQueryWeight, boolean outputQuery, boolean filterTerms) {
+    Class clazz = null;
+    try {
+      if (collectionClass != null) {
+        clazz = Class.forName("io.anserini.collection." + collectionClass);
+      }
+    } catch (ClassNotFoundException e) {
+      LOG.error("collectionClass: " + collectionClass + " not found!");
+    }
+
+    useRM3 = true;
+    cascade = new RerankerCascade("rm3");
+    cascade.add(new Rm3Reranker(this.analyzer, clazz, Constants.CONTENTS,
+        fbTerms, fbDocs, originalQueryWeight, outputQuery, filterTerms));
+    cascade.add(new ScoreTiesAdjusterReranker());
+  }
+
+  /**
+   * Determines if Rocchio query expansion is enabled.
+   *
+   * @return true if Rocchio query expansion is enabled; false otherwise.
+   */
+  public boolean use_rocchio() {
+    return useRocchio;
+  }
+
+  /**
+   * Disables Rocchio query expansion.
+   */
+  public void unset_rocchio() {
+    this.useRocchio = false;
+    cascade = new RerankerCascade();
+    cascade.add(new ScoreTiesAdjusterReranker());
+  }
+
+  /**
+   * Enables Rocchio query expansion with default parameters.
+   */
+  public void set_rocchio() {
+    SearchCollection.Args defaults = new SearchCollection.Args();
+    set_rocchio(null, Integer.parseInt(defaults.rocchio_topFbTerms[0]), Integer.parseInt(defaults.rocchio_topFbDocs[0]),
+        Integer.parseInt(defaults.rocchio_bottomFbTerms[0]), Integer.parseInt(defaults.rocchio_bottomFbDocs[0]),
+        Float.parseFloat(defaults.rocchio_alpha[0]), Float.parseFloat(defaults.rocchio_beta[0]),
+        Float.parseFloat(defaults.rocchio_gamma[0]), false, false);
+  }
+
+  /**
+   * Enables Rocchio query expansion with default parameters.
+   *
+   * @param collectionClass class for on-the-fly document parsing if index does not contain docvectors
+   */
+  public void set_rocchio(String collectionClass) {
+    SearchCollection.Args defaults = new SearchCollection.Args();
+    set_rocchio(collectionClass, Integer.parseInt(defaults.rocchio_topFbTerms[0]), Integer.parseInt(defaults.rocchio_topFbDocs[0]),
+        Integer.parseInt(defaults.rocchio_bottomFbTerms[0]), Integer.parseInt(defaults.rocchio_bottomFbDocs[0]),
+        Float.parseFloat(defaults.rocchio_alpha[0]), Float.parseFloat(defaults.rocchio_beta[0]),
+        Float.parseFloat(defaults.rocchio_gamma[0]), false, false);
+  }
+
+  /**
+   * Enables Rocchio query expansion with specified parameters.
+   *
+   * @param collectionClass class for on-the-fly document parsing if index does not contain docvectors
+   * @param topFbTerms number of relevant expansion terms
+   * @param topFbDocs number of relevant expansion documents
+   * @param bottomFbTerms number of nonrelevant expansion terms
+   * @param bottomFbDocs number of nonrelevant expansion documents
+   * @param alpha weight to assign to the original query
+   * @param beta weight to assign to the relevant document vectors
+   * @param gamma weight to assign to the nonrelevant document vectors
+   * @param outputQuery flag to print original and expanded queries
+   * @param useNegative flag to use negative feedback
+   */
+  public void set_rocchio(String collectionClass, int topFbTerms, int topFbDocs, int bottomFbTerms, int bottomFbDocs, float alpha, float beta, float gamma, boolean outputQuery, boolean useNegative) {
+    Class clazz = null;
+    try {
+      if (collectionClass != null) {
+        clazz = Class.forName("io.anserini.collection." + collectionClass);
+      }
+    } catch (ClassNotFoundException e) {
+      LOG.error("collectionClass: " + collectionClass + " not found!");
+    }
+
+    useRocchio = true;
+    cascade = new RerankerCascade("rocchio");
+    cascade.add(new RocchioReranker(this.analyzer, clazz, Constants.CONTENTS,
+        topFbTerms, topFbDocs, bottomFbTerms, bottomFbDocs, alpha, beta, gamma, outputQuery, useNegative));
+    cascade.add(new ScoreTiesAdjusterReranker());
+  }
+
+  /**
+   * Returns the {@link Similarity} (i.e., scoring function) currently being used.
+   *
+   * @return the {@link Similarity} currently being used
+   */
+  public Similarity get_similarity() {
+    return similarity;
+  }
 
   /**
    * Returns the number of documents in the index.
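
Usage sketch for the new expansion toggles (illustrative, not part of the patch; the index path is hypothetical, and the explicit RM3 parameters simply mirror the set_rm3(int, int, float) overload above):

    import io.anserini.search.SimpleImpactSearcher;

    public class ExpansionDemo {
      public static void main(String[] args) throws Exception {
        // Hypothetical index path; any Anserini impact index should work here.
        SimpleImpactSearcher searcher = new SimpleImpactSearcher("indexes/sample-impact-index");

        // RM3 with explicit parameters: fbTerms, fbDocs, originalQueryWeight.
        searcher.set_rm3(10, 10, 0.5f);
        SimpleImpactSearcher.Result[] rm3Hits = searcher.search("lobster roll", 10);
        searcher.unset_rm3();  // restores the default ScoreTiesAdjusterReranker-only cascade

        // Rocchio with the library defaults drawn from SearchCollection.Args.
        searcher.set_rocchio();
        SimpleImpactSearcher.Result[] rocchioHits = searcher.search("lobster roll", 10);
        searcher.unset_rocchio();

        System.out.println(rm3Hits.length + " " + rocchioHits.length);
        searcher.close();
      }
    }
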
@@ -159,6 +395,20 @@ public int get_total_num_docs() {
     return searcher.getIndexReader().maxDoc();
   }
 
+  /**
+   * Helper function to convert a Map<String, Integer> into a Map<String, Float>.
+   *
+   * @param input the map to be transformed
+   * @return a map in the form of Map<String, Float>
+   */
+  private Map<String, Float> intToFloat(Map<String, Integer> input) {
+    Map<String, Float> transformed = new HashMap<>();
+    for (Map.Entry<String, Integer> entry : input.entrySet()) {
+      transformed.put(entry.getKey(), entry.getValue().floatValue());
+    }
+    return transformed;
+  }
+
   /**
    * Closes this searcher.
    */
@@ -174,13 +424,73 @@ public void close() throws IOException {
 
   /**
    * Searches in batch using multiple threads.
    *
-   * @param queries list of queries
+   * @param encoded_queries list of encoded queries
+   * @param qids list of unique query ids
+   * @param k number of hits
+   * @param threads number of threads
+   * @return a map of query id to search results
+   */
+  public Map<String, Result[]> batch_search(List<Map<String, Integer>> encoded_queries,
+                                            List<String> qids,
+                                            int k,
+                                            int threads) {
+    // Create the IndexSearcher here, if needed. We do it here because if we leave the creation to the search
+    // method, we might end up with a race condition as multiple threads try to concurrently create the IndexSearcher.
+    if (searcher == null) {
+      searcher = new IndexSearcher(reader);
+      searcher.setSimilarity(similarity);
+    }
+
+    ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(threads);
+    ConcurrentHashMap<String, Result[]> results = new ConcurrentHashMap<>();
+
+    int queryCnt = encoded_queries.size();
+    for (int q = 0; q < queryCnt; ++q) {
+      Map<String, Integer> query = encoded_queries.get(q);
+      String qid = qids.get(q);
+      executor.execute(() -> {
+        try {
+          results.put(qid, search(query, k));
+        } catch (IOException e) {
+          throw new CompletionException(e);
+        } catch (OrtException e) {
+          throw new CompletionException(e);
+        }
+      });
+    }
+
+    executor.shutdown();
+
+    try {
+      // Wait for existing tasks to terminate
+      while (!executor.awaitTermination(1, TimeUnit.MINUTES)) {
+        // Opportunity to perform status logging, but no-op here because logging interferes with Python tqdm
+      }
+    } catch (InterruptedException ie) {
+      // (Re-)Cancel if current thread also interrupted
+      executor.shutdownNow();
+      // Preserve interrupt status
+      Thread.currentThread().interrupt();
+    }
+
+    if (queryCnt != executor.getCompletedTaskCount()) {
+      throw new RuntimeException("queryCount = " + queryCnt +
+          " is not equal to completedTaskCount = " + executor.getCompletedTaskCount());
+    }
+
+    return results;
+  }
+
+  /**
+   * Searches in batch using multiple threads.
+   *
+   * @param queries list of String queries
+   * @param qids list of unique query ids
+   * @param k number of hits
+   * @param threads number of threads
+   * @return a map of query id to search results
+   */
-  public Map<String, Result[]> batch_search(List<Map<String, Float>> queries,
+  public Map<String, Result[]> batch_search_queries(List<String> queries,
                                             List<String> qids,
                                             int k,
                                             int threads) {
@@ -196,13 +506,15 @@ public Map<String, Result[]> batch_search(List<Map<String, Float>> queries,
 
     int queryCnt = queries.size();
     for (int q = 0; q < queryCnt; ++q) {
-      Map<String, Float> query = queries.get(q);
+      String query = queries.get(q);
       String qid = qids.get(q);
       executor.execute(() -> {
         try {
           results.put(qid, search(query, k));
         } catch (IOException e) {
           throw new CompletionException(e);
+        } catch (OrtException e) {
+          throw new CompletionException(e);
         }
       });
     }
@@ -236,8 +548,39 @@ public Map<String, Result[]> batch_search(List<Map<String, Float>> queries,
    * @throws OrtException if errors encountered during encoding
    * @return encoded query
    */
-  public Map<String, Float> encode_with_onnx(String queryString) throws OrtException {
-    Map<String, Float> encodedQ = this.queryEncoder.getTokenWeightMap(queryString);
+  public Map<String, Integer> encodeWithOnnx(String queryString) throws OrtException {
+    // If no query encoder is set, assume the input is an already-encoded query whose tokens split on whitespace.
+    if (this.queryEncoder == null){
+      Analyzer whiteSpaceAnalyzer = new WhitespaceAnalyzer();
+      List<String> queryTokens = AnalyzerUtils.analyze(whiteSpaceAnalyzer, queryString);
+      Map<String, Integer> queryTokensFreq = queryTokens.stream().collect(Collectors.toMap(
+          e -> e, (a) -> 1, Integer::sum));
+      return queryTokensFreq;
+    }
+
+    Map<String, Integer> encodedQ = this.queryEncoder.getEncodedQueryMap(queryString);
+    return encodedQ;
+  }
+
+  /**
+   * Converts a token weight map into its encoded string form, repeating each token by its weight.
+   *
+   * @param queryWeight query weight map
+   * @throws OrtException if errors encountered during encoding
+   * @return encoded query
+   */
+  public String encodeWithOnnx(Map<String, Integer> queryWeight) throws OrtException {
+    List<String> encodedQuery = new ArrayList<>();
+    for (Map.Entry<String, Integer> entry : queryWeight.entrySet()) {
+      String token = entry.getKey();
+      Integer tokenWeight = entry.getValue();
+      for (int i = 0; i < tokenWeight; ++i) {
+        encodedQuery.add(token);
+      }
+    }
+    return String.join(" ", encodedQuery);
+  }
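
A round-trip sketch of the two encodeWithOnnx overloads (illustrative, not part of the patch; the index path is hypothetical). With no ONNX encoder set, the String overload splits on whitespace and counts token frequencies; the Map overload inverts that:

    import java.util.Map;
    import io.anserini.search.SimpleImpactSearcher;

    public class EncodeDemo {
      public static void main(String[] args) throws Exception {
        // Hypothetical index path; no ONNX encoder is set on this instance.
        SimpleImpactSearcher searcher = new SimpleImpactSearcher("indexes/sample-impact-index");

        Map<String, Integer> weights = searcher.encodeWithOnnx("text text city");
        // weights is {text=2, city=1}

        String restored = searcher.encodeWithOnnx(weights);
        // restored contains "text" twice and "city" once; token order is unspecified
        System.out.println(weights + " -> " + restored);
        searcher.close();
      }
    }
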
@@ -245,29 +588,65 @@ public Map<String, Float> encode_with_onnx(String queryString) throws OrtException {
   /**
    * Searches the collection, returning 10 hits by default.
    *
-   * @param q query
+   * @param encoded_q encoded query
    * @return array of search results
    * @throws IOException if error encountered during search
+   * @throws OrtException if error encountered during search
    */
-  public Result[] search(Map<String, Float> q) throws IOException {
+  public Result[] search(Map<String, Integer> encoded_q) throws IOException, OrtException {
+    return search(encoded_q, 10);
+  }
+
+  /**
+   * Searches the collection, returning 10 hits by default.
+   *
+   * @param q raw string query
+   * @return array of search results
+   * @throws IOException if error encountered during search
+   * @throws OrtException if error encountered during search
+   */
+  public Result[] search(String q) throws IOException, OrtException {
     return search(q, 10);
   }
 
   /**
    * Searches the collection.
    *
-   * @param q query
+   * @param encoded_q encoded query
+   * @param k number of hits
+   * @return array of search results
+   * @throws IOException if error encountered during search
+   * @throws OrtException if error encountered during search
+   */
+  public Result[] search(Map<String, Integer> encoded_q, int k) throws IOException, OrtException {
+    Map<String, Float> float_encoded_q = intToFloat(encoded_q);
+    Query query = generator.buildQuery(Constants.CONTENTS, float_encoded_q);
+    String encodedQuery = encodeWithOnnx(encoded_q);
+    return _search(query, encodedQuery, k);
+  }
+
+  /**
+   * Searches the collection.
+   *
+   * @param q string query
    * @param k number of hits
    * @return array of search results
    * @throws IOException if error encountered during search
+   * @throws OrtException if error encountered during search
    */
-  public Result[] search(Map<String, Float> q, int k) throws IOException {
-    Query query = generator.buildQuery(Constants.CONTENTS, q);
-    return _search(query, k);
+  public Result[] search(String q, int k) throws IOException, OrtException {
+    // Make the encoded query from the raw query.
+    Map<String, Integer> encoded_q = encodeWithOnnx(q);
+
+    // Transform the map type for the query generator.
+    Map<String, Float> float_encoded_q = intToFloat(encoded_q);
+    Query query = generator.buildQuery(Constants.CONTENTS, float_encoded_q);
+    String encodedQuery = encodeWithOnnx(encoded_q);
+    return _search(query, encodedQuery, k);
   }
 
   // internal implementation
-  protected Result[] _search(Query query, int k) throws IOException {
+  protected Result[] _search(Query query, String encodedQuery, int k) throws IOException, OrtException {
     // Create an IndexSearcher only once. Note that the object is thread safe.
     if (searcher == null) {
       searcher = new IndexSearcher(reader);
@@ -278,6 +657,10 @@ protected Result[] _search(Query query, int k) throws IOException {
     searchArgs.arbitraryScoreTieBreak = this.backwardsCompatibilityLucene8;
     searchArgs.hits = k;
 
+    // The encoded query can be tokenized using the whitespace analyzer.
+    Analyzer whiteSpaceAnalyzer = new WhitespaceAnalyzer();
+    List<String> queryTokens = AnalyzerUtils.analyze(whiteSpaceAnalyzer, encodedQuery);
+
     TopDocs rs;
     RerankerContext context;
     if (this.backwardsCompatibilityLucene8) {
@@ -286,7 +669,7 @@ protected Result[] _search(Query query, int k) throws IOException {
       rs = searcher.search(query, k, BREAK_SCORE_TIES_BY_DOCID, true);
     }
     context = new RerankerContext<>(searcher, null, query, null,
-        null, null, null, searchArgs);
+        encodedQuery, queryTokens, null, searchArgs);
 
     ScoredDocuments hits = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context);
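
Pulling the new public surface together, an end-to-end sketch (illustrative, not part of the patch; the index path and queries are made up, and the encoder name is the one exercised by the tests below):

    import java.util.List;
    import java.util.Map;
    import io.anserini.search.SimpleImpactSearcher;

    public class SearcherDemo {
      public static void main(String[] args) throws Exception {
        // Hypothetical index path; the encoder name matches the ONNX encoder used in the tests.
        SimpleImpactSearcher searcher =
            new SimpleImpactSearcher("indexes/msmarco-passage-splade-pp-ed", "SpladePlusPlusEnsembleDistil");

        // Raw string query: encoded through ONNX, then searched.
        SimpleImpactSearcher.Result[] hits = searcher.search("what is a lobster roll", 10);

        // Pre-encoded query: token -> quantized integer weight.
        Map<String, Integer> encoded = Map.of("lobster", 120, "roll", 51);
        SimpleImpactSearcher.Result[] hits2 = searcher.search(encoded, 10);

        // Batched string queries fan out over a fixed-size thread pool.
        Map<String, SimpleImpactSearcher.Result[]> results = searcher.batch_search_queries(
            List.of("lobster roll", "smoked meat"), List.of("q1", "q2"), 10, 4);

        System.out.println(hits.length + " " + hits2.length + " " + results.size());
        searcher.close();
      }
    }
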
diff --git a/src/main/java/io/anserini/search/query/QueryEncoder.java b/src/main/java/io/anserini/search/query/QueryEncoder.java
index 8b91c60aae..0f65b6bd6b 100644
--- a/src/main/java/io/anserini/search/query/QueryEncoder.java
+++ b/src/main/java/io/anserini/search/query/QueryEncoder.java
@@ -56,7 +56,7 @@ protected static long[] convertTokensToIds(BertFullTokenizer tokenizer, List<String> tokens)
-  protected String generateEncodedQuery(Map<String, Float> tokenWeightMap) {
+  public String generateEncodedQuery(Map<String, Float> tokenWeightMap) {
     /*
      * This function generates the encoded query.
      */
@@ -72,7 +72,23 @@ protected String generateEncodedQuery(Map<String, Float> tokenWeightMap) {
     return String.join(" ", encodedQuery);
   }
 
-  static Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeights, DefaultVocabulary vocab) {
+  public Map<String, Integer> getEncodedQueryMap(Map<String, Float> tokenWeightMap) throws OrtException {
+    Map<String, Integer> encodedQuery = new HashMap<>();
+    for (Map.Entry<String, Float> entry : tokenWeightMap.entrySet()) {
+      String token = entry.getKey();
+      Float tokenWeight = entry.getValue();
+      int weightQuanted = Math.round(tokenWeight / weightRange * quantRange);
+      encodedQuery.put(token, weightQuanted);
+    }
+    return encodedQuery;
+  }
+
+  public Map<String, Integer> getEncodedQueryMap(String query) throws OrtException {
+    Map<String, Float> tokenWeightMap = getTokenWeightMap(query);
+    return getEncodedQueryMap(tokenWeightMap);
+  }
+
+  protected static Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeights, DefaultVocabulary vocab) {
     /*
      * This function returns a map of token to its weight.
      */
@@ -87,6 +103,5 @@ static Map<String, Float> getTokenWeightMap(long[] indexes, float[] computedWeights, DefaultVocabulary vocab) {
     return tokenWeightMap;
   }
 
-  public abstract Map<String, Float> getTokenWeightMap(String query) throws OrtException;
-
+  protected abstract Map<String, Float> getTokenWeightMap(String query) throws OrtException;
 }
\ No newline at end of file
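
The quantization in getEncodedQueryMap buckets a real-valued token weight via Math.round(tokenWeight / weightRange * quantRange). A worked check against the expected values in SimpleImpactSearcherTest below, assuming weightRange = 5 and quantRange = 256 (these constants are not shown in this patch, but they are consistent with all three expected values):

    public class QuantCheck {
      public static void main(String[] args) {
        // Assumed constants; not visible in this patch.
        float weightRange = 5f;
        int quantRange = 256;
        float[] weights = {3.05345f, 0.59636426f, 2.9012794f};  // "here", "a", "test"
        for (float w : weights) {
          System.out.println(w + " -> " + Math.round(w / weightRange * quantRange));
        }
        // Prints 156, 31, 149, matching EXPECTED_ENCODED_QUERY in the updated test.
      }
    }
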
diff --git a/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java b/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java
index d4afd70bff..7f29a2f1c9 100644
--- a/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java
+++ b/src/main/java/io/anserini/search/query/SpladePlusPlusEnsembleDistilQueryEncoder.java
@@ -47,7 +47,7 @@ public String encode(String query) throws OrtException {
   }
 
   @Override
-  public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
+  protected Map<String, Float> getTokenWeightMap(String query) throws OrtException {
     List<String> queryTokens = new ArrayList<>();
     queryTokens.add("[CLS]");
     queryTokens.addAll(tokenizer.tokenize(query));
diff --git a/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java b/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java
index d8d8ba48aa..c5baa440fe 100644
--- a/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java
+++ b/src/main/java/io/anserini/search/query/SpladePlusPlusSelfDistilQueryEncoder.java
@@ -46,7 +46,7 @@ public String encode(String query) throws OrtException {
   }
 
   @Override
-  public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
+  protected Map<String, Float> getTokenWeightMap(String query) throws OrtException {
     List<String> queryTokens = new ArrayList<>();
     queryTokens.add("[CLS]");
     queryTokens.addAll(tokenizer.tokenize(query));
diff --git a/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java b/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java
index 95bf44d874..80ca87d153 100644
--- a/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java
+++ b/src/main/java/io/anserini/search/query/UniCoilQueryEncoder.java
@@ -87,7 +87,7 @@ private Map<String, Float> getTokenWeightMap(List<String> tokens, float[] computedWeights) {
   }
 
   @Override
-  public Map<String, Float> getTokenWeightMap(String query) throws OrtException {
+  protected Map<String, Float> getTokenWeightMap(String query) throws OrtException {
     List<String> queryTokens = new ArrayList<>();
     queryTokens.add("[CLS]");
     queryTokens.addAll(tokenizer.tokenize(query));
diff --git a/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene8Test.java b/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene8Test.java
index 3a48453eb1..206ea90bf8 100644
--- a/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene8Test.java
+++ b/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene8Test.java
@@ -33,8 +33,8 @@ public void testSearch1() throws Exception {
 
     SimpleImpactSearcher.Result[] hits;
 
-    Map<String, Float> query = new HashMap<>();
-    query.put("##ing", 1.0f);
+    Map<String, Integer> query = new HashMap<>();
+    query.put("##ing", 1);
 
     hits = searcher.search(query, 10);
     assertEquals(1, hits.length);
@@ -42,7 +42,7 @@ public void testSearch1() throws Exception {
     assertEquals(2, (int) hits[0].score);
 
     query = new HashMap<>();
-    query.put("test", 1.0f);
+    query.put("test", 1);
     hits = searcher.search(query, 10);
     assertEquals(1, hits.length);
     assertEquals("2000000", hits[0].docid);
diff --git a/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene9Test.java b/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene9Test.java
index c30d26eaad..de4dd66c53 100644
--- a/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene9Test.java
+++ b/src/test/java/io/anserini/search/SimpleImpactSearcherPrebuiltLucene9Test.java
@@ -33,8 +33,8 @@ public void testSearch1() throws Exception {
 
     SimpleImpactSearcher.Result[] hits;
 
-    Map<String, Float> query = new HashMap<>();
-    query.put("##ing", 1.0f);
+    Map<String, Integer> query = new HashMap<>();
+    query.put("##ing", 1);
 
     hits = searcher.search(query, 10);
     assertEquals(1, hits.length);
@@ -42,7 +42,7 @@ public void testSearch1() throws Exception {
     assertEquals(2, (int) hits[0].score);
 
     query = new HashMap<>();
-    query.put("test", 1.0f);
+    query.put("test", 1);
     hits = searcher.search(query, 10);
     assertEquals(1, hits.length);
     assertEquals("2000000", hits[0].docid);
diff --git a/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java b/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java
index 86379d3f7b..3348273593 100644
--- a/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java
+++ b/src/test/java/io/anserini/search/SimpleImpactSearcherTest.java
@@ -23,15 +23,17 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import io.anserini.search.SimpleImpactSearcher.Result;
+
 public class SimpleImpactSearcherTest extends IndexerTestBase {
-  private static Map<String, Float> EXPECTED_ENCODED_QUERY = new HashMap<>();
+  private static Map<String, Integer> EXPECTED_ENCODED_QUERY = new HashMap<>();
 
   static {
-    EXPECTED_ENCODED_QUERY.put("here", 3.05345f);
-    EXPECTED_ENCODED_QUERY.put("a", 0.59636426f);
-    EXPECTED_ENCODED_QUERY.put("test", 2.9012794f);
+    EXPECTED_ENCODED_QUERY.put("here", 156);
+    EXPECTED_ENCODED_QUERY.put("a", 31);
+    EXPECTED_ENCODED_QUERY.put("test", 149);
   }
 
   @Test
@@ -109,14 +111,22 @@ public void testGetRaw() throws Exception {
   public void testSearch1() throws Exception {
     SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
 
-    Map<String, Float> testQuery = new HashMap<>();
-    testQuery.put("test", 1.2f);
+    Map<String, Integer> testQuery = new HashMap<>();
+    testQuery.put("test", 1);
 
     SimpleImpactSearcher.Result[] hits = searcher.search(testQuery, 10);
+    SimpleImpactSearcher.Result[] hits_string = searcher.search("test", 10);
+    assertEquals(hits_string.length, hits.length);
+    assertEquals(hits_string[0].docid, hits[0].docid);
+    assertEquals(hits_string[0].lucene_docid, hits[0].lucene_docid);
+    assertEquals(hits_string[0].score, hits[0].score, 10e-6);
+    assertEquals(hits_string[0].contents, hits[0].contents);
+    assertEquals(hits_string[0].raw, hits[0].raw);
+
     assertEquals(1, hits.length);
     assertEquals("doc3", hits[0].docid);
     assertEquals(2, hits[0].lucene_docid);
-    assertEquals(1.2f, hits[0].score, 10e-6);
+    assertEquals(1.0f, hits[0].score, 10e-6);
     assertEquals("here is a test", hits[0].contents);
     assertEquals("{\"contents\": \"here is a test\"}", hits[0].raw);
 
@@ -135,8 +145,8 @@ public void testSearch1() throws Exception {
   public void testSearch2() throws Exception {
     SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
 
-    Map<String, Float> testQuery = new HashMap<>();
-    testQuery.put("text", 1.2f);
+    Map<String, Integer> testQuery = new HashMap<>();
+    testQuery.put("text", 1);
 
     SimpleImpactSearcher.Result[] results;
 
@@ -144,7 +154,7 @@ public void testSearch2() throws Exception {
     assertEquals(1, results.length);
     assertEquals("doc1", results[0].docid);
    assertEquals(0, results[0].lucene_docid);
-    assertEquals(2.4f, results[0].score, 10e-6);
+    assertEquals(2.0f, results[0].score, 10e-6);
city.", results[0].contents); assertEquals("{\"contents\": \"here is some text here is some more text. city.\"}", results[0].raw); @@ -154,17 +164,17 @@ public void testSearch2() throws Exception { assertEquals(0, results[0].lucene_docid); assertEquals("doc2", results[1].docid); assertEquals(1, results[1].lucene_docid); - assertEquals(2.4f, results[0].score, 10e-6); - assertEquals(1.2f, results[1].score, 10e-6); + assertEquals(2.0f, results[0].score, 10e-6); + assertEquals(1.0f, results[1].score, 10e-6); - Map testQuery2 = new HashMap<>(); - testQuery2.put("test", 0.125f); + Map testQuery2 = new HashMap<>(); + testQuery2.put("test", 1); results = searcher.search(testQuery2); assertEquals(1, results.length); assertEquals("doc3", results[0].docid); assertEquals(2, results[0].lucene_docid); - assertEquals(0.125f, results[0].score, 10e-6); + assertEquals(1.0f, results[0].score, 10e-6); searcher.close(); } @@ -172,13 +182,13 @@ public void testSearch2() throws Exception { @Test public void testBatchSearch() throws Exception { SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); - Map testQuery1 = new HashMap<>(); - testQuery1.put("tests", 0.1f); - testQuery1.put("test", 0.1f); - Map testQuery2 = new HashMap<>(); - testQuery2.put("more", 1.5f); + Map testQuery1 = new HashMap<>(); + testQuery1.put("tests", 1); + testQuery1.put("test", 1); + Map testQuery2 = new HashMap<>(); + testQuery2.put("more", 3); - List> queries = new ArrayList<>(); + List> queries = new ArrayList<>(); queries.add(testQuery1); queries.add(testQuery2); @@ -205,14 +215,127 @@ public void testTotalNumDocuments() throws Exception { assertEquals(3 ,searcher.get_total_num_docs()); } + @Test + public void testOnnxEncodedQuery() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + Map testQuery1 = new HashMap<>(); + testQuery1.put("text", 2); + String encodedQuery = searcher.encodeWithOnnx(testQuery1); + assertEquals("text text" ,encodedQuery); + } + @Test public void testOnnxEncoder() throws Exception{ SimpleImpactSearcher searcher = new SimpleImpactSearcher(); searcher.set_onnx_query_encoder("SpladePlusPlusEnsembleDistil"); - Map encoded_query = searcher.encode_with_onnx("here is a test"); + Map encoded_query = searcher.encodeWithOnnx("here is a test"); assertEquals(encoded_query.get("here"), EXPECTED_ENCODED_QUERY.get("here"), 2e-4); assertEquals(encoded_query.get("a"), EXPECTED_ENCODED_QUERY.get("a"), 2e-4); assertEquals(encoded_query.get("test"), EXPECTED_ENCODED_QUERY.get("test"), 2e-4); } + + @Test + public void testSearch3() throws Exception { + SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString()); + searcher.set_rm3(); + assertTrue(searcher.use_rm3()); + + Result[] results; + + Map testQuery1 = new HashMap<>(); + testQuery1.put("text", 1); + + results = searcher.search(testQuery1, 1); + assertEquals(1, results.length); + assertEquals("doc1", results[0].docid); + assertEquals(0, results[0].lucene_docid); + assertEquals(1.0f, results[0].score, 10e-5); + + Map testQuery2 = new HashMap<>(); + testQuery2.put("test", 1); + + results = searcher.search(testQuery2); + assertEquals(1, results.length); + assertEquals("doc3", results[0].docid); + assertEquals(2, results[0].lucene_docid); + assertEquals(0.5f, results[0].score, 10e-5); + + Map testQuery3 = new HashMap<>(); + testQuery3.put("more", 1); + + results = searcher.search(testQuery3); + + assertEquals(2, results.length); + assertEquals("doc1", results[0].docid); 
+    assertEquals(0, results[0].lucene_docid);
+    assertEquals(0.5f, results[0].score, 10e-5);
+    assertEquals("doc2", results[1].docid);
+    assertEquals(1, results[1].lucene_docid);
+    assertEquals(0.5f, results[1].score, 10e-5);
+
+    searcher.unset_rm3();
+    assertFalse(searcher.use_rm3());
+
+    results = searcher.search(testQuery1, 1);
+    assertEquals(1, results.length);
+    assertEquals("doc1", results[0].docid);
+    assertEquals(0, results[0].lucene_docid);
+    assertEquals(2.0f, results[0].score, 10e-5);
+
+    searcher.close();
+  }
+
+  @Test
+  public void testSearch4() throws Exception {
+    // This adds Rocchio on top of "testSearch1".
+    SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
+    searcher.set_rocchio();
+    assertTrue(searcher.use_rocchio());
+
+    Result[] results;
+
+    Map<String, Integer> testQuery1 = new HashMap<>();
+    testQuery1.put("text", 1);
+
+    results = searcher.search(testQuery1, 1);
+    assertEquals(1, results.length);
+    assertEquals("doc1", results[0].docid);
+    assertEquals(0, results[0].lucene_docid);
+    assertEquals(2.0f, results[0].score, 10e-5);
+
+    Map<String, Integer> testQuery2 = new HashMap<>();
+    testQuery2.put("test", 1);
+
+    results = searcher.search(testQuery2);
+    assertEquals(1, results.length);
+    assertEquals("doc3", results[0].docid);
+    assertEquals(2, results[0].lucene_docid);
+    assertEquals(1.0f, results[0].score, 10e-5);
+
+    Map<String, Integer> testQuery3 = new HashMap<>();
+    testQuery3.put("more", 1);
+
+    results = searcher.search(testQuery3);
+    assertEquals(2, results.length);
+    assertEquals("doc1", results[0].docid);
+    assertEquals(0, results[0].lucene_docid);
+    assertEquals(1.0f, results[0].score, 10e-5);
+    assertEquals("doc2", results[1].docid);
+    assertEquals(1, results[1].lucene_docid);
+    assertEquals(1.0f, results[1].score, 10e-5);
+
+    searcher.unset_rocchio();
+    assertFalse(searcher.use_rocchio());
+
+    results = searcher.search(testQuery1, 1);
+    assertEquals(1, results.length);
+    assertEquals("doc1", results[0].docid);
+    assertEquals(0, results[0].lucene_docid);
+    assertEquals(2.0f, results[0].score, 10e-5);
+
+    searcher.close();
+  }
+
 }
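
Finally, a sketch of the encoded-map batch entry point, mirroring testBatchSearch above (illustrative, not part of the patch; the index path and query ids are made up):

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import io.anserini.search.SimpleImpactSearcher;

    public class BatchDemo {
      public static void main(String[] args) throws Exception {
        // Hypothetical index path.
        SimpleImpactSearcher searcher = new SimpleImpactSearcher("indexes/sample-impact-index");

        Map<String, Integer> q1 = new HashMap<>();
        q1.put("tests", 1);
        q1.put("test", 1);
        Map<String, Integer> q2 = new HashMap<>();
        q2.put("more", 3);

        List<Map<String, Integer>> queries = new ArrayList<>();
        queries.add(q1);
        queries.add(q2);

        // Two worker threads; results are keyed by the supplied query ids.
        Map<String, SimpleImpactSearcher.Result[]> results =
            searcher.batch_search(queries, List.of("q1", "q2"), 10, 2);
        System.out.println(results.get("q1").length + " " + results.get("q2").length);
        searcher.close();
      }
    }
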