From 1ebe6dda261218b5a1481dcfeaeb47cfc83d0292 Mon Sep 17 00:00:00 2001 From: Arthur Chen <36494787+ArthurChen189@users.noreply.github.com> Date: Thu, 28 Dec 2023 19:13:55 -0500 Subject: [PATCH] Add reproduction script for "End-to-End Retrieval with Learned Dense and Sparse Representations Using Lucene" (#2317) --- .../java/io/anserini/index/IndexInfo.java | 56 ++++------- .../io/anserini/search/HnswDenseSearcher.java | 24 ++++- .../e2e_sparse_dense_lucene/reproduction.py | 92 +++++++++++++++++++ .../e2e_sparse_dense_lucene/pre-encoded.yaml | 78 ++++++++++++++++ .../io/anserini/index/PrebuiltIndexTest.java | 2 +- 5 files changed, 211 insertions(+), 41 deletions(-) create mode 100644 src/main/python/e2e_sparse_dense_lucene/reproduction.py create mode 100644 src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml diff --git a/src/main/java/io/anserini/index/IndexInfo.java b/src/main/java/io/anserini/index/IndexInfo.java index 767493ef8b..bf3567be55 100644 --- a/src/main/java/io/anserini/index/IndexInfo.java +++ b/src/main/java/io/anserini/index/IndexInfo.java @@ -20,65 +20,43 @@ public enum IndexInfo { MSMARCO_V1_PASSAGE("msmarco-v1-passage", "Lucene index of the MS MARCO V1 passage corpus. (Lucene 9)", "lucene-index.msmarco-v1-passage.20221004.252b5e.tar.gz", - "lucene-index.msmarco-v1-passage.20221004.252b5e.README.md", new String[] { "https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-index.msmarco-v1-passage.20221004.252b5e.tar.gz" }, - "c697b18c9a0686ca760583e615dbe450", "2170758938", "352316036", "8841823", - "2660824", false), + "c697b18c9a0686ca760583e615dbe450"), CACM("cacm", "Lucene index of the CACM corpus. (Lucene 9)", "lucene-index.cacm.tar.gz", new String[] { "https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.20221005.252b5e.tar.gz" }, - "cfe14d543c6a27f4d742fb2d0099b8e0", - "2347197", - "320968", - "3204", - "14363"); + "cfe14d543c6a27f4d742fb2d0099b8e0"), + + MSMARCO_V1_PASSAGE_COS_DPR_DISTIL("msmarco-v1-passage-cos-dpr-distil", + "Lucene index of the MS MARCO V1 passage corpus encoded by cos-DPR Distil. (Lucene 9)", + "lucene-hnsw.msmarco-v1-passage-cos-dpr-distil.20231124.9d3427.tar.gz", + new String[] { + "https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-hnsw.msmarco-v1-passage-cos-dpr-distil.20231124.9d3427.tar.gz" }, + "7aa825e292a411abbe1585fb4d9f20ee"), + + MSMARCO_V1_PASSAGE_SPLADE_PP_ED("msmarco-v1-passage-splade-pp-ed", + "Lucene impact index of the MS MARCO passage corpus encoded by SPLADE++ CoCondenser-EnsembleDistil. (Lucene 9)", + "lucene-index.msmarco-v1-passage-splade-pp-ed.20230524.a59610.tar.gz", + new String[] { + "https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-index.msmarco-v1-passage-splade-pp-ed.20230524.a59610.tar.gz" }, + "4b3c969033cbd017306df42ce134c395"); public final String indexName; public final String description; public final String filename; - public final String readme; public final String[] urls; public final String md5; - public final String size; - public final String totalTerms; - public final String totalDocs; - public final String totalUniqueTerms; - public final boolean downloaded; - - // constructor with all 11 fields - IndexInfo(String indexName, String description, String filename, String readme, String[] urls, String md5, - String size, String totalTerms, String totalDocs, String totalUniqueTerms, boolean downloaded) { - this.indexName = indexName; - this.description = description; - this.filename = filename; - this.readme = readme; - this.urls = urls; - this.md5 = md5; - this.size = size; - this.totalTerms = totalTerms; - this.totalDocs = totalDocs; - this.totalUniqueTerms = totalUniqueTerms; - this.downloaded = downloaded; - } - // constructor with 9 fields - IndexInfo(String indexName, String description, String filename, String[] urls, String md5, String size, - String totalTerms, String totalDocs, String totalUniqueTerms) { + IndexInfo(String indexName, String description, String filename, String[] urls, String md5) { this.indexName = indexName; this.description = description; this.filename = filename; - this.readme = ""; this.urls = urls; this.md5 = md5; - this.size = size; - this.totalTerms = totalTerms; - this.totalDocs = totalDocs; - this.totalUniqueTerms = totalUniqueTerms; - this.downloaded = false; } public static boolean contains(String indexName) { diff --git a/src/main/java/io/anserini/search/HnswDenseSearcher.java b/src/main/java/io/anserini/search/HnswDenseSearcher.java index c3b29a207e..822a8724a4 100644 --- a/src/main/java/io/anserini/search/HnswDenseSearcher.java +++ b/src/main/java/io/anserini/search/HnswDenseSearcher.java @@ -20,6 +20,8 @@ import io.anserini.encoder.dense.DenseEncoder; import io.anserini.index.Constants; import io.anserini.search.query.VectorQueryGenerator; +import io.anserini.util.PrebuiltIndexHandler; + import org.apache.commons.lang3.time.DurationFormatUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -36,6 +38,8 @@ import javax.annotation.Nullable; import java.io.Closeable; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.List; import java.util.SortedMap; @@ -79,8 +83,26 @@ public HnswDenseSearcher(Args args) { // We might not be able to successfully create a reader for a variety of reasons, anything from path doesn't exist // to corrupt index. Gather all possible exceptions together as an unchecked exception to make initialization and // error reporting clearer. + Path indexPath = Path.of(args.index); + PrebuiltIndexHandler indexHandler = new PrebuiltIndexHandler(args.index); + if (!Files.exists(indexPath)) { + // it doesn't exist locally, we try to download it from remote + try { + indexHandler.initialize(); + indexHandler.download(); + indexPath = Path.of(indexHandler.decompressIndex()); + } catch (IOException e) { + throw new RuntimeException("MD5 checksum does not match!"); + } catch (Exception e) { + throw new IllegalArgumentException(String.format("\"%s\" does not appear to be a valid index.", args.index)); + } + } else { + // if it exists locally, we use it + indexPath = Paths.get(args.index); + } + try { - this.reader = DirectoryReader.open(FSDirectory.open(Paths.get(args.index))); + this.reader = DirectoryReader.open(FSDirectory.open(indexPath)); } catch (IOException e) { throw new IllegalArgumentException(String.format("\"%s\" does not appear to be a valid index.", args.index)); } diff --git a/src/main/python/e2e_sparse_dense_lucene/reproduction.py b/src/main/python/e2e_sparse_dense_lucene/reproduction.py new file mode 100644 index 0000000000..06900390df --- /dev/null +++ b/src/main/python/e2e_sparse_dense_lucene/reproduction.py @@ -0,0 +1,92 @@ +# Anserini: A toolkit for reproducible information retrieval research built on Lucene +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml +from typing import Union, Dict, List, Optional, Any +import os +import subprocess +TOPIC_NAMES = ['msmarco-passage-dev-subset', 'dl19-passage', 'dl20-passage'] +EVAL_CMD_MAP = { + 'map': '-m map -c -l 2', # AP + 'ndcg_cut_10': '-m ndcg_cut.10 -c', # nDCG@10 + 'recall_1000_msmarco': '-c -m recall.1000', # R@1000 for MS MARCO + 'recall_1000': '-m recall.1000 -c -l 2', # R@1000 + 'recip_rank': '-c -M 10 -m recip_rank' # RR@10 +} +TOPIC_EVAL_MAP = { + 'msmarco-passage-dev-subset': ['recip_rank', 'recall_1000_msmarco'], + 'dl19-passage': ['map', 'ndcg_cut_10', 'recall_1000'], + 'dl20-passage': ['map', 'ndcg_cut_10', 'recall_1000'] +} + + +def get_output_run_file_name(topic: str, name: str): + return f'runs/{topic}_{name}.txt' + + +def get_search_command(model_name: str, cmd_template: str, topics: List[str]): + outputs = [get_output_run_file_name( + topic_name, model_name) for topic_name in TOPIC_NAMES] + + for topic, output in zip(topics, outputs): + cmd = cmd_template.format(topic=topic, output=output) + yield cmd + + +def get_eval_command(param: str, qrel: str, run_file: str, cmd_template: str): + cmd = cmd_template.format( + param=param, qrel=qrel, output=run_file) + yield cmd + + +def main(config): + # print all search commands + for model_name, model_config in config['collections'].items(): + print("running model: ", model_name) + # # search + # for cmd in get_search_command(model_name, model_config['search_command'], model_config['topics']): + # p = subprocess.Popen( + # cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + # stdout, stderr = p.communicate() + # if stderr: + # print(stderr.decode('utf-8')) + + # eval + expected_results = model_config['results'] + run_files = [get_output_run_file_name( + topic_name, model_name) for topic_name in TOPIC_NAMES] + eval_cmd = model_config['eval_command'] + metric_precision = model_config['metric_precision'] + + for run_file, topic_name, qrel in zip(run_files, TOPIC_NAMES, model_config['qrels']): + for metric in TOPIC_EVAL_MAP[topic_name]: + for cmd in get_eval_command(EVAL_CMD_MAP[metric], qrel, run_file, eval_cmd): + p = subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = p.communicate() + stdout = [out.strip() + for out in stdout.decode('utf-8').split('\t')] + actual_result = round(float(stdout[-1]), metric_precision) + expected_result = expected_results[topic_name][metric] + assert actual_result == expected_result, f'{model_name} {topic_name} {metric} {actual_result} != {expected_result}, expected: {expected_results[topic_name]}' + print( + f"{topic_name} {metric} {actual_result} == {expected_result}") + print(f"{model_name} passed!") + print("="*50) + + +if __name__ == '__main__': + with open('src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + main(config) diff --git a/src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml b/src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml new file mode 100644 index 0000000000..3ad58970c9 --- /dev/null +++ b/src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml @@ -0,0 +1,78 @@ +--- +collections: + bm25: + name: bm25 + search_command: target/appassembler/bin/SearchCollection -index msmarco-v1-passage -topicReader TsvInt -topics {topic} -output {output} -bm25 -parallelism 12 + topics: + - tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt + - tools/topics-and-qrels/topics.dl19-passage.txt + - tools/topics-and-qrels/topics.dl20.txt + qrels: + - tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt + - tools/topics-and-qrels/qrels.dl19-passage.txt + - tools/topics-and-qrels/qrels.dl20-passage.txt + + eval_command: tools/eval/trec_eval.9.0.4/trec_eval {param} {qrel} {output} + results: + msmarco-passage-dev-subset: + recip_rank: 0.184 + recall_1000_msmarco: 0.853 + dl19-passage: + map: 0.301 + ndcg_cut_10: 0.506 + recall_1000: 0.750 + dl20-passage: + map: 0.286 + ndcg_cut_10: 0.480 + recall_1000: 0.786 + metric_precision: 3 + cosdpr-distil: + name: cosdpr-distil + search_command: target/appassembler/bin/SearchHnswDenseVectors -index msmarco-v1-passage-cos-dpr-distil -topicReader TsvInt -topics {topic} -output {output} -generator VectorQueryGenerator -topicField title -threads 12 -hits 1000 -efSearch 1000 -encoder CosDprDistil + topics: + - tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt + - tools/topics-and-qrels/topics.dl19-passage.txt + - tools/topics-and-qrels/topics.dl20.txt + qrels: + - tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt + - tools/topics-and-qrels/qrels.dl19-passage.txt + - tools/topics-and-qrels/qrels.dl20-passage.txt + eval_command: tools/eval/trec_eval.9.0.4/trec_eval {param} {qrel} {output} + results: + msmarco-passage-dev-subset: + recip_rank: 0.389 + recall_1000_msmarco: 0.975 + dl19-passage: + map: 0.466 + ndcg_cut_10: 0.725 + recall_1000: 0.822 + dl20-passage: + map: 0.487 + ndcg_cut_10: 0.703 + recall_1000: 0.852 + metric_precision: 3 + splade-pp-ed: + name: splade-pp-ed + search_command: target/appassembler/bin/SearchCollection -index msmarco-v1-passage-splade-pp-ed -topicReader TsvInt -topics {topic} -output {output} -impact -pretokenized -parallelism 12 -encoder SpladePlusPlusEnsembleDistil + topics: + - tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt + - tools/topics-and-qrels/topics.dl19-passage.txt + - tools/topics-and-qrels/topics.dl20.txt + qrels: + - tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt + - tools/topics-and-qrels/qrels.dl19-passage.txt + - tools/topics-and-qrels/qrels.dl20-passage.txt + eval_command: tools/eval/trec_eval.9.0.4/trec_eval {param} {qrel} {output} + results: + msmarco-passage-dev-subset: + recip_rank: 0.383 + recall_1000_msmarco: 0.983 + dl19-passage: + map: 0.505 + ndcg_cut_10: 0.731 + recall_1000: 0.873 + dl20-passage: + map: 0.500 + ndcg_cut_10: 0.720 + recall_1000: 0.900 + metric_precision: 3 \ No newline at end of file diff --git a/src/test/java/io/anserini/index/PrebuiltIndexTest.java b/src/test/java/io/anserini/index/PrebuiltIndexTest.java index c00aed4f23..478ea81837 100644 --- a/src/test/java/io/anserini/index/PrebuiltIndexTest.java +++ b/src/test/java/io/anserini/index/PrebuiltIndexTest.java @@ -56,6 +56,6 @@ public void testUrls() { // test number of prebuilt-indexes @Test public void testNumPrebuiltIndexes() { - assert IndexInfo.values().length == 2; + assert IndexInfo.values().length == 4; } }