Skip to content

Commit

Permalink
General refactoring (#2269)
Browse files Browse the repository at this point in the history
+ Refactored SearchHnswDenseVectors and SearchInvertedDenseVectors,
  cleaning up code paths and main error handling.
+ Added test cases, cleaned up existing test cases.
+ Other minor cleanup.
  • Loading branch information
lintool committed Nov 25, 2023
1 parent 95a546c commit 3435e8a
Show file tree
Hide file tree
Showing 33 changed files with 762 additions and 182 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ Key:
+ MF = "multifield" baseline (Lucene analyzer)
+ U1 = uniCOIL (noexp)
+ S1 = SPLADE-distill CoCodenser-medium
+ S2 = SPLADE++ (CoCondenser-EnsembleDistil)
+ S2 = SPLADE++ CoCondenser-EnsembleDistil

| Corpus | F1 | F2 | MF | U1 | S1 | S2 |
|-------------------------|:-----------------------------------------------------------------------------:|:--------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------:|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ target/appassembler/bin/SearchInvertedDenseVectors \
-topics tools/topics-and-qrels/topics.dl19-passage.cos-dpr-distil.jsonl.gz \
-topicReader JsonIntVector \
-output runs/run.msmarco-passage-cos-dpr-distil.cos-dpr-distil-fw-40.topics.dl19-passage.cos-dpr-distil.jsonl.txt \
-topicField vector -encoding fw -fw.q 40 -hits 1000 &
-topicField vector -threads 16 -encoding fw -fw.q 40 -hits 1000 &
```

Evaluation can be performed using `trec_eval`:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ target/appassembler/bin/SearchInvertedDenseVectors \
-topics tools/topics-and-qrels/topics.dl19-passage.cos-dpr-distil.jsonl.gz \
-topicReader JsonIntVector \
-output runs/run.msmarco-passage-cos-dpr-distil.cos-dpr-distil-lexlsh-600.topics.dl19-passage.cos-dpr-distil.jsonl.txt \
-topicField vector -encoding lexlsh -lexlsh.b 600 -hits 1000 &
-topicField vector -threads 16 -encoding lexlsh -lexlsh.b 600 -hits 1000 &
```

Evaluation can be performed using `trec_eval`:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ target/appassembler/bin/SearchInvertedDenseVectors \
-topics tools/topics-and-qrels/topics.dl20.cos-dpr-distil.jsonl.gz \
-topicReader JsonIntVector \
-output runs/run.msmarco-passage-cos-dpr-distil.cos-dpr-distil-fw-40.topics.dl20.cos-dpr-distil.jsonl.txt \
-topicField vector -encoding fw -fw.q 40 -hits 1000 &
-topicField vector -threads 16 -encoding fw -fw.q 40 -hits 1000 &
```

Evaluation can be performed using `trec_eval`:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ target/appassembler/bin/SearchInvertedDenseVectors \
-topics tools/topics-and-qrels/topics.dl20.cos-dpr-distil.jsonl.gz \
-topicReader JsonIntVector \
-output runs/run.msmarco-passage-cos-dpr-distil.cos-dpr-distil-lexlsh-600.topics.dl20.cos-dpr-distil.jsonl.txt \
-topicField vector -encoding lexlsh -lexlsh.b 600 -hits 1000 &
-topicField vector -threads 16 -encoding lexlsh -lexlsh.b 600 -hits 1000 &
```

Evaluation can be performed using `trec_eval`:
Expand Down
4 changes: 2 additions & 2 deletions docs/regressions/regressions-msmarco-doc-docTTTTTquery.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ For default parameters (`k1=0.9`, `b=0.4`):
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-doc-docTTTTTquery/ \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-doc-docTTTTTquery.bm25-default.txt \
-format msmarco \
-bm25 -hits 100
Expand All @@ -159,7 +159,7 @@ For tuned parameters (`k1=4.68`, `b=0.87`):
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-doc-docTTTTTquery/ \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-doc-docTTTTTquery.bm25-tuned.txt \
-format msmarco \
-bm25 -bm25.k1 4.68 -bm25.b 0.87 -hits 100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ The MaxP passage retrieval functionality is available in `SearchCollection`.
To generate an MS MARCO submission with the BM25 default parameters, corresponding to "BM25 (default)" above:

```bash
$ sh target/appassembler/bin/SearchCollection -topicReader TsvString \
$ sh target/appassembler/bin/SearchCollection -topicreader TsvString \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-index indexes/lucene-index.msmarco-doc-segmented-docTTTTTquery/ \
-output runs/run.msmarco-doc-segmented-docTTTTTquery.bm25-default.txt -format msmarco \
Expand All @@ -131,7 +131,7 @@ Note that the above command uses `-format msmarco` to directly generate a run in
To generate an MS MARCO submission with the BM25 tuned parameters, corresponding to "BM25 (tuned)" above:

```bash
$ sh target/appassembler/bin/SearchCollection -topicReader TsvString \
$ sh target/appassembler/bin/SearchCollection -topicreader TsvString \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-index indexes/lucene-index.msmarco-doc-segmented-docTTTTTquery/ \
-output runs/run.msmarco-doc-segmented-docTTTTTquery.bm25-tuned.txt -format msmarco \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ The following command generates a comparable run:
target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-doc-segmented-unicoil/ \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.unicoil.tsv.gz \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-doc-segmented-unicoil.msmarco-doc.dev.txt \
-format msmarco \
-impact -pretokenized -hits 10000 -selectMaxPassage -selectMaxPassage.delimiter "#" -selectMaxPassage.hits 100
Expand Down
4 changes: 2 additions & 2 deletions docs/regressions/regressions-msmarco-doc-segmented.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ The MaxP passage retrieval functionality is available in `SearchCollection`.
To generate an MS MARCO submission with the BM25 default parameters, corresponding to "BM25 (default)" above:

```bash
$ target/appassembler/bin/SearchCollection -topicReader TsvString \
$ target/appassembler/bin/SearchCollection -topicreader TsvString \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-index indexes/lucene-index.msmarco-doc-segmented/ \
-output runs/run.msmarco-doc-segmented.bm25-default.txt -format msmarco \
Expand All @@ -132,7 +132,7 @@ Note that the above command uses `-format msmarco` to directly generate a run in
To generate an MS MARCO submission with the BM25 tuned parameters, corresponding to "BM25 (tuned)" above:

```bash
$ target/appassembler/bin/SearchCollection -topicReader TsvString \
$ target/appassembler/bin/SearchCollection -topicreader TsvString \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-index indexes/lucene-index.msmarco-doc-segmented/ \
-output runs/run.msmarco-doc-segmented.bm25-tuned.txt -format msmarco \
Expand Down
4 changes: 2 additions & 2 deletions docs/regressions/regressions-msmarco-doc.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ For default parameters (`k1=0.9`, `b=0.4`):
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-doc/ \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-doc.bm25-default.txt \
-format msmarco \
-bm25 -hits 100
Expand All @@ -194,7 +194,7 @@ For tuned parameters (`k1=4.46`, `b=0.82`):
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-doc/ \
-topics tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-doc.bm25-tuned.txt \
-format msmarco \
-bm25 -bm25.k1 4.46 -bm25.b 0.82 -hits 100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ target/appassembler/bin/SearchInvertedDenseVectors \
-topics tools/topics-and-qrels/topics.msmarco-passage.dev-subset.cos-dpr-distil.jsonl.gz \
-topicReader JsonIntVector \
-output runs/run.msmarco-passage-cos-dpr-distil.cos-dpr-distil-fw-40.topics.msmarco-passage.dev-subset.cos-dpr-distil.jsonl.txt \
-topicField vector -encoding fw -fw.q 40 -hits 1000 &
-topicField vector -threads 16 -encoding fw -fw.q 40 -hits 1000 &
```

Evaluation can be performed using `trec_eval`:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ target/appassembler/bin/SearchInvertedDenseVectors \
-topics tools/topics-and-qrels/topics.msmarco-passage.dev-subset.cos-dpr-distil.jsonl.gz \
-topicReader JsonIntVector \
-output runs/run.msmarco-passage-cos-dpr-distil.cos-dpr-distil-lexlsh-600.topics.msmarco-passage.dev-subset.cos-dpr-distil.jsonl.txt \
-topicField vector -encoding lexlsh -lexlsh.b 600 -hits 1000 &
-topicField vector -threads 16 -encoding lexlsh -lexlsh.b 600 -hits 1000 &
```

Evaluation can be performed using `trec_eval`:
Expand Down
4 changes: 2 additions & 2 deletions docs/regressions/regressions-msmarco-passage-docTTTTTquery.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ For parameters `k1=0.82`, `b=0.68`:
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-passage-docTTTTTquery/ \
-topics tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-passage-docTTTTTquery.1 \
-format msmarco \
-bm25 -bm25.k1 0.82 -bm25.b 0.68
Expand All @@ -185,7 +185,7 @@ For parameters `k1=2.18`, `b=0.86`:
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-passage-docTTTTTquery/ \
-topics tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-passage-docTTTTTquery.2 \
-format msmarco \
-bm25 -bm25.k1 2.18 -bm25.b 0.86
Expand Down
4 changes: 2 additions & 2 deletions docs/regressions/regressions-msmarco-passage.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ For default parameters (`k1=0.9`, `b=0.4`):
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-passage/ \
-topics tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-passage.bm25.default.tsv \
-format msmarco \
-bm25
Expand All @@ -162,7 +162,7 @@ For tuned parameters (`k1=0.82`, `b=0.68`):
$ sh target/appassembler/bin/SearchCollection \
-index indexes/lucene-index.msmarco-passage/ \
-topics tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt \
-topicReader TsvInt \
-topicreader TsvInt \
-output runs/run.msmarco-passage.bm25.tuned.tsv \
-format msmarco \
-bm25 -bm25.k1 0.82 -bm25.b 0.68
Expand Down
105 changes: 71 additions & 34 deletions src/main/java/io/anserini/search/SearchHnswDenseVectors.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
Expand Down Expand Up @@ -113,6 +116,9 @@ public static class Args {
@Option(name ="-encoder", metaVar = "[encoder]", usage = "Dense encoder to use.")
public String encoder = null;

@Option(name = "-options", usage = "Print information about options.")
public Boolean options = false;

// ---------------------------------------------
// Simple built-in support for passage retrieval
// ---------------------------------------------
Expand Down Expand Up @@ -146,32 +152,36 @@ public static class Args {
private final IndexSearcher searcher;
private final VectorQueryGenerator generator;
private final DenseEncoder queryEncoder;
private final SortedMap<K, String> queries = new TreeMap<>();
private final ConcurrentSkipListMap<K, String> results = new ConcurrentSkipListMap<>();

public SearchHnswDenseVectors(Args args) throws IOException {
this.args = args;
Path indexPath = Paths.get(args.index);

if (!Files.exists(indexPath) || !Files.isDirectory(indexPath) || !Files.isReadable(indexPath)) {
throw new IllegalArgumentException(String.format("Index path '%s' does not exist or is not a directory.", args.index));
}

LOG.info("============ Initializing HNSW Searcher ============");
LOG.info("Index: " + indexPath);
LOG.info("Index: " + args.index);
LOG.info("Topics: " + Arrays.toString(args.topics));
LOG.info("Query generator: " + args.queryGenerator);
LOG.info("Encoder: " + args.encoder);
LOG.info("Threads: " + args.threads);

this.reader = DirectoryReader.open(FSDirectory.open(indexPath));
// We might not be able to successfully create a reader for a variety of reasons, anything from path doesn't exist
// to corrupt index. Gather all possible exceptions together as an unchecked exception to make initialization and
// error reporting clearer.
try {
this.reader = DirectoryReader.open(FSDirectory.open(Paths.get(args.index)));
} catch (IOException e) {
throw new IllegalArgumentException(String.format("\"%s\" does not appear to be a valid index.", args.index));
}

this.searcher = new IndexSearcher(this.reader);

try {
this.generator = (VectorQueryGenerator) Class
.forName(String.format("io.anserini.search.query.%s", args.queryGenerator))
.getConstructor().newInstance();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException("Unable to load QueryGenerator: " + args.queryGenerator);
throw new IllegalArgumentException(String.format("Unable to load QueryGenerator \"%s\".", args.queryGenerator));
}

if (args.encoder != null) {
Expand All @@ -180,51 +190,65 @@ public SearchHnswDenseVectors(Args args) throws IOException {
.forName(String.format("io.anserini.encoder.dense.%sEncoder", args.encoder))
.getConstructor().newInstance();
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException("Unable to load encoder: " + args.encoder);
throw new IllegalArgumentException(String.format("Unable to load Encoder \"%s\".", args.encoder));
}
} else {
queryEncoder = null;
}

}

@Override
public void close() throws IOException {
reader.close();
}

@SuppressWarnings("unchecked")
@Override
public void run() {
// Same as above: we might not be able to successfully read topics for a variety of reasons. Gather all possible
// exceptions together as an unchecked exception to make initialization and error reporting clearer.
SortedMap<K, Map<String, String>> topics = new TreeMap<>();
for (String file : args.topics) {
Path topicsFilePath = Paths.get(file);
for (String singleTopicsFile : args.topics) {
Path topicsFilePath = Paths.get(singleTopicsFile);
if (!Files.exists(topicsFilePath) || !Files.isRegularFile(topicsFilePath) || !Files.isReadable(topicsFilePath)) {
throw new IllegalArgumentException("Topics file : " + topicsFilePath + " does not exist or is not a (readable) file.");
throw new IllegalArgumentException(String.format("\"%s\" does not appear to be a valid topics file.", topicsFilePath));
}
try {
@SuppressWarnings("unchecked")
TopicReader<K> tr = (TopicReader<K>) Class
.forName(String.format("io.anserini.search.topicreader.%sTopicReader", args.topicReader))
.getConstructor(Path.class).newInstance(topicsFilePath);

topics.putAll(tr.read());
} catch (Exception e) {
e.printStackTrace();
throw new IllegalArgumentException("Unable to load topic reader: " + args.topicReader);
throw new IllegalArgumentException(String.format("Unable to load topic reader \"%s\".", args.topicReader));
}
}

// Now iterate through all the topics to pick out the right field with proper exception handling.
try {
for (Map.Entry<K, Map<String, String>> entry : topics.entrySet()) {
K qid = entry.getKey();
String query = entry.getValue().get(args.topicField);
assert query != null;

this.queries.put(qid, query);
}
} catch (AssertionError|Exception e) {
throw new IllegalArgumentException(String.format("Unable to read topic field \"%s\".", args.topicField));
}
}

@Override
public void close() throws IOException {
reader.close();
}

@SuppressWarnings("unchecked")
@Override
public void run() {
LOG.info("============ Launching Search Threads ============");
final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(args.threads);
final AtomicInteger cnt = new AtomicInteger();

final long start = System.nanoTime();
for (Map.Entry<K, Map<String, String>> entry : topics.entrySet()) {
for (Map.Entry<K, String> entry : queries.entrySet()) {
K qid = entry.getKey();

// This is the per-query execution, in parallel.
executor.execute(() -> {
String queryString = entry.getValue().get(args.topicField);
String queryString = entry.getValue();
ScoredDocuments docs;

try {
Expand Down Expand Up @@ -259,9 +283,9 @@ public void run() {
}
final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS);

LOG.info(topics.size() + " queries processed in " +
LOG.info(queries.size() + " queries processed in " +
DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") +
String.format(" = ~%.2f q/s", topics.size()/(durationMillis/1000.0)));
String.format(" = ~%.2f q/s", queries.size()/(durationMillis/1000.0)));

// Now we write the results to a run file.
try {
Expand Down Expand Up @@ -300,9 +324,22 @@ public static void main(String[] args) throws Exception {
try {
parser.parseArgument(args);
} catch (CmdLineException e) {
System.err.println(e.getMessage());
parser.printUsage(System.err);
System.err.println("Example: SearchHnswDenseVectors" + parser.printExample(OptionHandlerFilter.REQUIRED));
if (searchArgs.options) {
System.err.printf("Options for %s:\n\n", SearchHnswDenseVectors.class.getSimpleName());
parser.printUsage(System.err);

List<String> required = new ArrayList<>();
parser.getOptions().forEach((option) -> {
if (option.option.required()) {
required.add(option.option.toString());
}
});

System.err.printf("\nRequired options are %s\n", required);
} else {
System.err.printf("Error: %s. For help, use \"-options\" to print out information about options.\n", e.getMessage());
}

return;
}

Expand All @@ -315,7 +352,7 @@ public static void main(String[] args) throws Exception {
searcher.run();
searcher.close();
} catch (IllegalArgumentException e) {
System.err.println(e.getMessage());
System.err.printf("Error: %s\n", e.getMessage());
return;
}

Expand Down
Loading

0 comments on commit 3435e8a

Please sign in to comment.