Skip to content

Commit

Permalink
Add RM3 reranker to SimpleSearcher (#519)
Browse files Browse the repository at this point in the history
  • Loading branch information
Victor0118 authored and lintool committed Dec 16, 2018
1 parent d9e4bad commit a152359
Showing 1 changed file with 41 additions and 5 deletions.
46 changes: 41 additions & 5 deletions src/main/java/io/anserini/search/SimpleSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,15 @@
package io.anserini.search;

import io.anserini.index.generator.LuceneDocumentGenerator;
import io.anserini.rerank.RerankerCascade;
import io.anserini.rerank.RerankerContext;
import io.anserini.rerank.ScoredDocuments;
import io.anserini.rerank.lib.Rm3Reranker;
import io.anserini.rerank.lib.ScoreTiesAdjusterReranker;
import io.anserini.search.query.BagOfWordsQueryGenerator;
import io.anserini.search.similarity.F2ExpSimilarity;
import io.anserini.search.similarity.F2LogSimilarity;
import io.anserini.util.AnalyzerUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.analysis.Analyzer;
Expand All @@ -40,12 +46,16 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

import static io.anserini.index.generator.LuceneDocumentGenerator.FIELD_BODY;

public class SimpleSearcher implements Closeable {
private static final Logger LOG = LogManager.getLogger(SimpleSearcher.class);
private final IndexReader reader;
private Similarity similarity;
private Analyzer analyzer;
private RerankerCascade cascade;

protected class Result {
public String docid;
Expand All @@ -71,6 +81,26 @@ public SimpleSearcher(String indexDir) throws IOException {
this.reader = DirectoryReader.open(FSDirectory.open(indexPath));
this.similarity = new LMDirichletSimilarity(1000.0f);
this.analyzer = new EnglishAnalyzer();
setNormalReranker();
}

public void setRM3Reranker() {
setRM3Reranker(20, 50, 0.6f, false);
}

public void setRM3Reranker(int fbTerms, int fbDocs, float originalQueryWeight) {
setRM3Reranker(fbTerms, fbDocs, originalQueryWeight, false);
}

public void setNormalReranker() {
cascade = new RerankerCascade();
cascade.add(new ScoreTiesAdjusterReranker());
}

public void setRM3Reranker(int fbTerms, int fbDocs, float originalQueryWeight, boolean rm3_outputQuery) {
cascade = new RerankerCascade();
cascade.add(new Rm3Reranker(this.analyzer, FIELD_BODY, fbTerms, fbDocs, originalQueryWeight, rm3_outputQuery));
cascade.add(new ScoreTiesAdjusterReranker());
}

public void setLMDirichletSimilarity(float mu) {
Expand Down Expand Up @@ -116,15 +146,21 @@ public Result[] search(String q, int k) throws IOException {
Query query = new BagOfWordsQueryGenerator().buildQuery(LuceneDocumentGenerator.FIELD_BODY, analyzer, q);

TopDocs rs = searcher.search(query, k);
ScoreDoc[] hits = rs.scoreDocs;

Result[] results = new Result[hits.length];
for (int i = 0; i < hits.length; i++) {
Document doc = searcher.doc(hits[i].doc);
List<String> queryTokens = AnalyzerUtils.tokenize(analyzer, q);
SearchArgs searchArgs = new SearchArgs();
searchArgs.arbitraryScoreTieBreak = false;
searchArgs.hits = k;
RerankerContext context = new RerankerContext<>(searcher, null, query, null, q, queryTokens, null, searchArgs);
ScoredDocuments hits = cascade.run(ScoredDocuments.fromTopDocs(rs, searcher), context);

Result[] results = new Result[hits.ids.length];
for (int i = 0; i < hits.ids.length; i++) {
Document doc = hits.documents[i];
String docid = doc.getField(LuceneDocumentGenerator.FIELD_ID).stringValue();
IndexableField field = doc.getField(LuceneDocumentGenerator.FIELD_RAW);
String content = field == null ? null : field.stringValue();
results[i] = new Result(docid, hits[i].doc, hits[i].score, content);
results[i] = new Result(docid, hits.ids[i], hits.scores[i], content);
}

return results;
Expand Down

0 comments on commit a152359

Please sign in to comment.