LTR document searcher and docs to reproduce #835

@@ -0,0 +1,117 @@
# Pyserini: Learning-To-Rank Reranking Baseline for MS MARCO Document

This guide contains instructions for running a learning-to-rank (LTR) baseline on the [MS MARCO *document* reranking task](https://microsoft.github.io/msmarco/).
Learning-to-rank serves as a second-stage reranker after BM25 retrieval.
Note that we use a sliding-window segmentation and the maxP strategy here.

## Data Preprocessing

We're going to use the repository's root directory as the working directory.

First, we need to download and extract the MS MARCO document dataset:

```bash
mkdir collections/msmarco-doc
wget https://git.uwaterloo.ca/jimmylin/doc2query-data/raw/master/T5-doc/msmarco-docs.tsv.gz -P collections/msmarco-doc
wget https://git.uwaterloo.ca/jimmylin/doc2query-data/raw/master/T5-doc/msmarco_doc_passage_ids.txt -P collections/msmarco-doc
```

We need to generate a collection of passage segments from the documents. Here, we use a segment size of 3 and a stride of 1 (a sketch of this segmentation follows the command):

```bash
python scripts/ltr_msmarco-document/convert_msmarco_passage_doc_to_anserini.py \
  --original_docs_path collections/msmarco-doc/msmarco-docs.tsv.gz \
  --doc_ids_path collections/msmarco-doc/msmarco_doc_passage_ids.txt \
  --output_docs_path collections/msmarco-doc/msmarco_pass_doc.jsonl
```
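
To make the segmentation concrete, here is a minimal sketch of a sliding window with size 3 and stride 1. It is an illustration only: it assumes the windowed units are sentences and skips the ids and JSONL output that the actual conversion script handles.

```python
# Minimal sketch of sliding-window segmentation (size 3, stride 1).
# Illustration only: assumes the units are sentences and skips all I/O.
def sliding_window_segments(sentences, size=3, stride=1):
    segments = []
    for start in range(0, max(len(sentences) - size + 1, 1), stride):
        segments.append(" ".join(sentences[start:start + size]))
    return segments

print(sliding_window_segments(["s1.", "s2.", "s3.", "s4."]))
# ['s1. s2. s3.', 's2. s3. s4.']
```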

Let's first retrieve the top 10,000 bag-of-words (BM25) hits for the segments as our LTR reranking candidates:
```bash
python scripts/ltr_msmarco-passage/convert_collection_to_jsonl.py \
  --collection-path collections/msmarco-doc/msmarco_pass_doc.jsonl \
  --output-folder collections/msmarco-doc/msmarco_pass_doc/

python -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator \
  -threads 21 -input collections/msmarco-doc/msmarco_pass_doc \
  -index indexes/lucene-index-msmarco-doc-passage -storePositions -storeDocvectors -storeRaw

python -m pyserini.search --topics msmarco-doc-dev \
  --index indexes/lucene-index-msmarco-doc-passage \
  --output collections/msmarco-doc/run.msmarco-pass-doc.bm25.txt \
  --bm25 --output-format trec --hits 10000
```

Now, we prepare the queries for LTR:

```bash
mkdir collections/msmarco-ltr-document

python scripts/ltr_msmarco-passage/convert_queries.py \
  --input tools/topics-and-qrels/topics.msmarco-doc.dev.txt \
  --output collections/msmarco-ltr-document/queries.dev.small.json
```

Prepare the LTR index:

```bash
python scripts/ltr_msmarco-document/convert_passage_doc.py \
  --input collections/msmarco-doc/msmarco_pass_doc.jsonl \
  --output collections/msmarco-ltr-document/ltr_msmarco_pass_doc.jsonl \
  --proc_qty 10
```

The above scripts convert the collection and queries into JSON files with `text_unlemm`, `analyzed`, `text_bert_tok`, and `raw` fields.
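
For illustration, a converted record might look like the following (pretty-printed here, and with hypothetical field values, only to make the schema concrete):

```json
{"id": "D1555982#3",
 "text_unlemm": "the hot glowing surfaces of stars emit energy",
 "analyzed": "hot glow surfac star emit energi",
 "text_bert_tok": "the hot glowing surfaces of stars emit energy",
 "raw": "The hot glowing surfaces of stars emit energy."}
```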

Next, we need to convert the MS MARCO JSON collection into Anserini's jsonl files (which have one JSON object per line):

```bash
python scripts/ltr_msmarco-passage/convert_collection_to_jsonl.py \
  --collection-path collections/msmarco-ltr-document/ltr_msmarco_pass_doc.jsonl \
  --output-folder collections/msmarco-ltr-document/ltr_msmarco_pass_doc_jsonl
```
We can now index these docs as a `JsonCollection` using Anserini with the pretokenized option:

```bash
python -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator \
  -threads 21 -input collections/msmarco-ltr-document/ltr_msmarco_pass_doc_jsonl \
  -index indexes/lucene-index-msmarco-document-ltr -storePositions -storeDocvectors -storeRaw -pretokenized
```

Note that the `-pretokenized` option makes Anserini use a whitespace analyzer, so that it does not break our preprocessed tokenization.
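
To see why this matters, compare what Anserini's default analyzer would do to an already-tokenized field with a plain whitespace split. This is a minimal sketch using Pyserini's analysis API; the token outputs in the comments are indicative, not exact:

```python
# Sketch: the default (stemming) analyzer would re-analyze our pretokenized
# fields, while whitespace splitting leaves them untouched.
from pyserini.analysis import Analyzer, get_lucene_analyzer

analyzer = Analyzer(get_lucene_analyzer())        # Anserini's default analyzer
print(analyzer.analyze('hot glowing surfaces'))   # e.g. ['hot', 'glow', 'surfac']
print('hot glow surfac'.split())                  # ['hot', 'glow', 'surfac'], unchanged
```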

Download the pretrained IBM models:

```bash
wget https://www.dropbox.com/s/vlrfcz3vmr4nt0q/ibm_model.tar.gz -P collections/msmarco-ltr-document/
tar -xzvf collections/msmarco-ltr-document/ibm_model.tar.gz -C collections/msmarco-ltr-document/
```

Download our pretrained LTR model:

```bash
wget https://www.dropbox.com/s/ffl2bfw4cd5ngyz/msmarco-passage-ltr-mrr-v1.tar.gz -P runs/
tar -xzvf runs/msmarco-passage-ltr-mrr-v1.tar.gz -C runs
```

Now we have everything ready and can run inference. The LTR model outputs rankings at the segment level, so we use a second script to obtain document-level results with the maxP strategy (a sketch of this aggregation follows the commands):

```bash
python -m pyserini.ltr.search_msmarco_document \
  --input collections/msmarco-doc/run.msmarco-pass-doc.bm25.txt \
  --input-format trec \
  --model runs/msmarco-passage-ltr-mrr-v1 \
  --index indexes/lucene-index-msmarco-document-ltr --output runs/run.ltr.doc-pas.trec

python scripts/ltr_msmarco-document/generate_document_score_withmaxP.py \
  --input runs/run.ltr.doc-pas.trec \
  --output runs/run.ltr.doc_level.tsv
```
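
To clarify the maxP aggregation, here is a minimal sketch (not the actual `generate_document_score_withmaxP.py` script): a document's score is the maximum LTR score over its segments. It assumes segment ids of the form `docid#segment`:

```python
# Minimal sketch of maxP: keep, per query, each document's best segment score.
# Assumption: segment ids look like "<docid>#<segment index>".
from collections import defaultdict

def maxp(segment_run):
    """segment_run: iterable of (qid, segment_id, score) tuples."""
    doc_scores = defaultdict(dict)
    for qid, seg_id, score in segment_run:
        docid = seg_id.split('#')[0]
        prev = doc_scores[qid].get(docid, float('-inf'))
        doc_scores[qid][docid] = max(prev, score)
    return doc_scores

run = [('q1', 'D100#0', 1.2), ('q1', 'D100#1', 2.5), ('q1', 'D200#0', 1.9)]
print(dict(maxp(run)))  # {'q1': {'D100': 2.5, 'D200': 1.9}}
```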

Finally, we evaluate the document-level run with the official MS MARCO evaluation script:

```bash
python tools/scripts/msmarco/msmarco_doc_eval.py \
  --judgments tools/topics-and-qrels/qrels.msmarco-doc.dev.txt \
  --run runs/run.ltr.doc_level.tsv
```
The above evaluation should give results like the following:

```bash
#####################
MRR @100: 0.3105532197278601
QueriesRanked: 5193
#####################
```
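
For reference, MRR@100 here is the mean over ranked queries of the reciprocal rank of the first relevant document within the top 100 hits. A minimal sketch of the metric (not the official evaluation script):

```python
# Minimal sketch of MRR@k: mean of 1/rank of the first relevant hit per query,
# looking only at the top k results (contributes 0 if none is relevant).
def mrr_at_k(ranked_docids_by_qid, relevant_by_qid, k=100):
    reciprocal_ranks = []
    for qid, docids in ranked_docids_by_qid.items():
        rr = 0.0
        for rank, docid in enumerate(docids[:k], start=1):
            if docid in relevant_by_qid.get(qid, set()):
                rr = 1.0 / rank
                break
        reciprocal_ranks.append(rr)
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

print(mrr_at_k({'q1': ['D2', 'D1'], 'q2': ['D9']}, {'q1': {'D1'}, 'q2': {'D1'}}))  # 0.25
```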

@@ -0,0 +1,18 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ._search_msmarco_document import MsmarcoDocumentLtrSearcher

__all__ = ['MsmarcoDocumentLtrSearcher']
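
For reference, the exported class can be used directly; this is a minimal usage sketch, with the constructor arguments and methods inferred from `__main__.py` below and illustrative paths:

```python
# Usage sketch (paths illustrative; see __main__.py below for the full pipeline).
from pyserini.ltr.search_msmarco_document import MsmarcoDocumentLtrSearcher

searcher = MsmarcoDocumentLtrSearcher(
    'runs/msmarco-passage-ltr-mrr-v1',                # pretrained LTR model
    './collections/msmarco-ltr-document/ibm_model/',  # pretrained IBM models
    'indexes/lucene-index-msmarco-document-ltr')      # pretokenized LTR index
searcher.add_fe()  # register feature extractors
# batch_info = searcher.search(dev_candidates, queries)
```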

@@ -0,0 +1,217 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

> Review comments on this line:
>
> * this really suggests that this script should be in […]
> * For that matter, shouldn't […]
> * We could reorganize the LTR pipeline according to #611, in order to have something like […]. If the main goes to scripts/, it will become something like […]. They have a lot of duplicated code, since the document experiment uses passage-like segments in the LTR process. Or I could change the code in […].
> * I think it should be […]. The reason is that […]

import argparse
import json
import multiprocessing
import os
import pickle
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

from pyserini.ltr.search_msmarco_document._search_msmarco_document import MsmarcoDocumentLtrSearcher
from pyserini.ltr import *
from pyserini.index import IndexReader

""" | ||
Running prediction on candidates | ||
""" | ||
def dev_data_loader(file, format, index, top): | ||
if format == 'tsv': | ||
dev = pd.read_csv(file, sep="\t", | ||
names=['qid', 'pid', 'rank'], | ||
dtype={'qid': 'S','pid': 'S', 'rank':'i',}) | ||
elif format == 'trec': | ||
dev = pd.read_csv(file, sep="\s+", | ||
names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'], | ||
usecols=['qid', 'pid', 'rank'], | ||
dtype={'qid': 'S','pid': 'S', 'rank':'i',}) | ||
else: | ||
raise Exception('unknown parameters') | ||
assert dev['qid'].dtype == np.object | ||
assert dev['pid'].dtype == np.object | ||
assert dev['rank'].dtype == np.int32 | ||
dev = dev[dev['rank']<=top] | ||
print(dev.shape) | ||
dev_qrel = pd.read_csv('tools/topics-and-qrels/qrels.msmarco-doc.dev.txt', sep="\t", | ||
names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'], | ||
dtype={'qid': 'S','pid': 'S', 'rel':'i'}) | ||
assert dev['qid'].dtype == np.object | ||
assert dev['pid'].dtype == np.object | ||
assert dev['rank'].dtype == np.int32 | ||
dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left') | ||
dev['rel'] = dev['rel'].fillna(0).astype(np.int32) | ||
dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid']) | ||
|
||
print(dev.shape) | ||
print(dev.index.get_level_values('qid').drop_duplicates().shape) | ||
print(dev.groupby('qid').count().mean()) | ||
print(dev.head(10)) | ||
print(dev.info()) | ||
|
||
return dev, dev_qrel | ||

def query_loader():
    # Load preprocessed queries: one JSON object per line, with the pretokenized
    # fields produced by convert_queries.py; split each field into token lists.
    queries = {}
    with open(f'{args.queries}/queries.dev.small.json') as f:
        for line in f:
            query = json.loads(line)
            qid = query.pop('id')
            query['analyzed'] = query['analyzed'].split(" ")
            query['text'] = query['text_unlemm'].split(" ")
            query['text_unlemm'] = query['text_unlemm'].split(" ")
            query['text_bert_tok'] = query['text_bert_tok'].split(" ")
            queries[qid] = query
    return queries


def eval_mrr(dev_data):
    # Compute MRR@10 over the reranked candidates, counting score ties.
    score_tie_counter = 0
    score_tie_query = set()
    MRR = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM

        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                MRR.append(1.0 / rank)
                break
            elif rank == 10 or rank == len(group):
                MRR.append(0.)
                break

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    mrr_10 = np.mean(MRR).item()
    print(f'MRR@10:{mrr_10} with {len(MRR)} queries')
    return {'score_tie': score_tie, 'mrr_10': mrr_10}


def eval_recall(dev_qrel, dev_data):
    # Compute a recall curve at several cutoffs over the reranked candidates.
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

    score_tie_counter = 0
    score_tie_query = set()

    recall_point = [10, 20, 50, 100, 200, 250, 300, 333, 400, 500, 1000]
    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM
        total_rel = dev_rel_num.loc[qid]
        query_recall = [0 for k in recall_point]
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                for i, p in enumerate(recall_point):
                    if rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            if total_rel > 0:
                recall_curve[p].append(query_recall[i] / total_rel)
            else:
                recall_curve[p].append(0.)

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    res = {'score_tie': score_tie}

    for k, v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}:{avg}')
        res[f'recall@{k}'] = avg

    return res


def output(file, dev_data, format):
    # Write the reranked run to disk in either MS MARCO tsv or TREC format.
    score_tie_counter = 0
    score_tie_query = set()
    output_file = open(file, 'w')

    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM

        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if (format == 'tsv'):
                output_file.write(f"{qid}\t{t.pid}\t{rank}\n")
            else:
                output_file.write(f"{qid}\tq0\t{t.pid}\t{rank}\t{t.score}\tltr\n")
    output_file.close()

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)

if __name__ == "__main__":
    os.environ["ANSERINI_CLASSPATH"] = "./pyserini/resources/jars"
    parser = argparse.ArgumentParser(description='Learning to rank reranking')
    parser.add_argument('--input', required=True)
    parser.add_argument('--reranking-top', type=int, default=10000)
    parser.add_argument('--input-format', required=True)
    parser.add_argument('--model', required=True)
    parser.add_argument('--index', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--ibm-model', default='./collections/msmarco-ltr-document/ibm_model/')
    parser.add_argument('--queries', default='./collections/msmarco-ltr-document/')
    parser.add_argument('--output-format', default='trec')

    args = parser.parse_args()
    print("load dev")
    dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.index, args.reranking_top)
    print("load queries")
    queries = query_loader()
    searcher = MsmarcoDocumentLtrSearcher(args.model, args.ibm_model, args.index)
    searcher.add_fe()

    batch_info = searcher.search(dev, queries)
    del dev, queries

    eval_res = eval_mrr(batch_info)
    eval_recall(dev_qrel, batch_info)
    output(args.output, batch_info, args.output_format)
    print('Done!')

> Review comment: There's a lot of code duplication here with: https://github.com/castorini/pyserini/blob/master/pyserini/ltr/search_msmarco_passage/__main__.py