Skip to content

Commit

Permalink
Support field search and batched field search for MS MARCO (#988)
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinzhng authored Feb 11, 2020
1 parent d6f1abf commit eff7755
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 14 deletions.
30 changes: 28 additions & 2 deletions src/main/java/io/anserini/search/SearchMsmarco.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.OptionHandlerFilter;
import org.kohsuke.args4j.ParserProperties;
import org.kohsuke.args4j.spi.MapOptionHandler;

import java.io.File;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -73,6 +76,10 @@ public static class Args {

@Option(name = "-originalQueryWeight", metaVar = "[value]", usage = "RM3 parameter: weight to assign to the original query")
public float originalQueryWeight = 0.5f;

@Option(name = "-fields", metaVar = "[key=value]", handler = MapOptionHandler.class,
usage = "Fields to search with assigned float weight")
public Map<String, String> fields = new HashMap<>();
}

public static void main(String[] args) throws Exception {
Expand All @@ -88,6 +95,11 @@ public static void main(String[] args) throws Exception {
return;
}

Map<String, Float> fields = new HashMap<>();
retrieveArgs.fields.forEach((key, value) -> {
fields.put(key, Float.valueOf(value));
});

long totalStartTime = System.nanoTime();

SimpleSearcher searcher = new SimpleSearcher(retrieveArgs.index);
Expand All @@ -100,6 +112,10 @@ public static void main(String[] args) throws Exception {
+ " and originalQueryWeight=" + retrieveArgs.originalQueryWeight);
}

if (retrieveArgs.fields.size() > 0) {
System.out.println("Performing weighted field search with fields=" + retrieveArgs.fields);
}

PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(retrieveArgs.output), StandardCharsets.US_ASCII));

if (retrieveArgs.threads == 1) {
Expand All @@ -113,7 +129,12 @@ public static void main(String[] args) throws Exception {
String qid = split[0];
String query = split[1];

SimpleSearcher.Result[] hits = searcher.search(query, retrieveArgs.hits);
SimpleSearcher.Result[] hits;
if (retrieveArgs.fields.size() > 0) {
hits = searcher.searchFields(query, fields, retrieveArgs.hits);
} else {
hits = searcher.search(query, retrieveArgs.hits);
}

if (lineNumber % 100 == 0) {
double timePerQuery = (double) (System.nanoTime() - startTime) / (lineNumber + 1) / 1e9;
Expand All @@ -131,7 +152,12 @@ public static void main(String[] args) throws Exception {
List<String> queries = lines.stream().map(x -> x.trim().split("\t")[1]).collect(Collectors.toList());
List<String> qids = lines.stream().map(x -> x.trim().split("\t")[0]).collect(Collectors.toList());

Map<String, SimpleSearcher.Result[]> results = searcher.batchSearch(queries, qids, retrieveArgs.hits, -1, retrieveArgs.threads);
Map<String, SimpleSearcher.Result[]> results;
if (retrieveArgs.fields.size() > 0) {
results = searcher.batchSearchFields(queries, qids, retrieveArgs.hits, retrieveArgs.threads, fields);
} else {
results = searcher.batchSearch(queries, qids, retrieveArgs.hits, retrieveArgs.threads);
}

for (String qid : qids) {
SimpleSearcher.Result[] hits = results.get(qid);
Expand Down
39 changes: 28 additions & 11 deletions src/main/java/io/anserini/search/SimpleSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletionException;
Expand Down Expand Up @@ -188,10 +190,14 @@ public void close() throws IOException {
}

public Map<String, Result[]> batchSearch(List<String> queries, List<String> qids, int k, int threads) {
return batchSearch(queries, qids, k, -1, threads);
return batchSearch(queries, qids, k, -1, threads, new HashMap<>());
}

public Map<String, Result[]> batchSearch(List<String> queries, List<String> qids, int k, long t, int threads) {
public Map<String, Result[]> batchSearchFields(List<String> queries, List<String> qids, int k, int threads, Map<String, Float> fields) {
return batchSearch(queries, qids, k, -1, threads, fields);
}

public Map<String, Result[]> batchSearch(List<String> queries, List<String> qids, int k, long t, int threads, Map<String, Float> fields) {
ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(threads);
ConcurrentHashMap<String, Result[]> results = new ConcurrentHashMap<>();

Expand All @@ -203,7 +209,11 @@ public Map<String, Result[]> batchSearch(List<String> queries, List<String> qids
String qid = qids.get(q);
executor.execute(() -> {
try {
results.put(qid, search(query, k, t));
if (fields.size() > 0) {
results.put(qid, searchFields(query, fields, k, t));
} else {
results.put(qid, search(query, k, t));
}
} catch (IOException e) {
throw new CompletionException(e);
}
Expand Down Expand Up @@ -235,7 +245,7 @@ public Map<String, Result[]> batchSearch(List<String> queries, List<String> qids
throw new RuntimeException("queryCount = " + queryCnt +
" is not equal to completedTaskCount = " + executor.getCompletedTaskCount());
}

return results;
}

Expand Down Expand Up @@ -304,19 +314,26 @@ protected Result[] search(Query query, List<String> queryTokens, String queryStr
return results;
}

public Result[] searchFields(String q, Map<String, Float> fields, int k) throws IOException {
return searchFields(q, fields, k, -1);
}

// searching both the defaults contents fields and another field with weight boost
// this is used for MS MACRO experiments with query expansion.
// TODO: "fields" should probably changed to a map of fields to boosts for extensibility
public Result[] searchFields(String q, String f, float boost, int k) throws IOException {
// this is used for MS MARCO experiments with document expansion.
public Result[] searchFields(String q, Map<String, Float> fields, int k, long t) throws IOException {
IndexSearcher searcher = new IndexSearcher(reader);
searcher.setSimilarity(similarity);

Query queryContents = new BagOfWordsQueryGenerator().buildQuery(LuceneDocumentGenerator.FIELD_BODY, analyzer, q);
Query queryField = new BagOfWordsQueryGenerator().buildQuery(f, analyzer, q);
BooleanQuery query = new BooleanQuery.Builder()
.add(queryContents, BooleanClause.Occur.SHOULD)
.add(new BoostQuery(queryField, boost), BooleanClause.Occur.SHOULD).build();
BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder()
.add(queryContents, BooleanClause.Occur.SHOULD);

for (Map.Entry<String, Float> entry : fields.entrySet()) {
Query queryField = new BagOfWordsQueryGenerator().buildQuery(entry.getKey(), analyzer, q);
queryBuilder.add(new BoostQuery(queryField, entry.getValue()), BooleanClause.Occur.SHOULD);
}

BooleanQuery query = queryBuilder.build();
List<String> queryTokens = AnalyzerUtils.tokenize(analyzer, q);

return search(query, queryTokens, q, k, -1);
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/io/anserini/IndexerTestBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import java.nio.file.Path;

public class IndexerTestBase extends LuceneTestCase {
protected static Path tempDir1;
protected Path tempDir1;

// A very simple example of how to build an index.
private void buildTestIndex() throws IOException {
Expand Down
93 changes: 93 additions & 0 deletions src/test/java/io/anserini/search/SimpleSearcherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
import io.anserini.search.SimpleSearcher.Result;
import org.junit.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SimpleSearcherTest extends IndexerTestBase {

@Test
Expand Down Expand Up @@ -48,6 +53,8 @@ public void test1() throws Exception {
assertEquals("doc3", results[0].docid);
assertEquals(2, results[0].ldocid);
assertEquals(0.5648999810218811f, results[0].score, 10e-6);

searcher.close();
}

@Test
Expand All @@ -69,6 +76,8 @@ public void test2() throws Exception {
assertEquals("here is a test",
searcher.doc("doc3").getField("contents").stringValue());
assertEquals(null, searcher.doc(3));

searcher.close();
}

@Test
Expand All @@ -84,6 +93,90 @@ public void test3() throws Exception {
assertEquals("more texts", searcher.getContents("doc2"));
assertEquals("here is a test", searcher.getContents("doc3"));
assertEquals(null, searcher.getContents("doc42"));

searcher.close();
}

@Test
public void testSearch() throws Exception {
SimpleSearcher searcher = new SimpleSearcher(super.tempDir1.toString());

SimpleSearcher.Result[] hits = searcher.search("test", 10);
assertEquals(1, hits.length);
assertEquals("doc3", hits[0].docid);

searcher.close();
}

@Test
public void testBatchSearch() throws Exception {
SimpleSearcher searcher = new SimpleSearcher(super.tempDir1.toString());

List<String> queries = new ArrayList<>();
queries.add("test");
queries.add("more");

List<String> qids = new ArrayList<>();
qids.add("query_test");
qids.add("query_more");

Map<String, SimpleSearcher.Result[]> hits = searcher.batchSearch(queries, qids, 10, 2);
assertEquals(2, hits.size());

assertEquals(1, hits.get("query_test").length);
assertEquals("doc3", hits.get("query_test")[0].docid);

assertEquals(2, hits.get("query_more").length);
assertEquals("doc2", hits.get("query_more")[0].docid);
assertEquals("doc1", hits.get("query_more")[1].docid);

searcher.close();
}

@Test
public void testFieldedSearch() throws Exception {
SimpleSearcher searcher = new SimpleSearcher(super.tempDir1.toString());

Map<String, Float> fields = new HashMap<>();
fields.put("id", 1.0f);
fields.put("contents", 1.0f);

SimpleSearcher.Result[] hits = searcher.searchFields("doc1", fields, 10);
assertEquals(1, hits.length);
assertEquals("doc1", hits[0].docid);

hits = searcher.searchFields("test", fields, 10);
assertEquals(1, hits.length);
assertEquals("doc3", hits[0].docid);

searcher.close();
}

@Test
public void testFieldedBatchSearch() throws Exception {
SimpleSearcher searcher = new SimpleSearcher(super.tempDir1.toString());

List<String> queries = new ArrayList<>();
queries.add("doc1");
queries.add("test");

List<String> qids = new ArrayList<>();
qids.add("query_id");
qids.add("query_contents");

Map<String, Float> fields = new HashMap<>();
fields.put("id", 1.0f);
fields.put("contents", 1.0f);

Map<String, SimpleSearcher.Result[]> hits = searcher.batchSearchFields(queries, qids, 10, 2, fields);
assertEquals(2, hits.size());

assertEquals(1, hits.get("query_id").length);
assertEquals("doc1", hits.get("query_id")[0].docid);

assertEquals(1, hits.get("query_contents").length);
assertEquals("doc3", hits.get("query_contents")[0].docid);

searcher.close();
}
}

0 comments on commit eff7755

Please sign in to comment.