LTR refactoring for modularization #636

Merged · 8 commits · Jun 3, 2021
@@ -1,4 +1,4 @@
# Pyserini: Learning-To-Rank Baseline for MS MARCO Passage
# Pyserini: Learning-To-Rank Reranking Baseline for MS MARCO Passage

This guide contains instructions for running the learning-to-rank baseline on the [MS MARCO *passage* reranking task](https://microsoft.github.io/msmarco/).
Learning-to-rank serves as a second stage reranker after BM25 retrieval.
@@ -52,15 +52,17 @@ tar -xzvf runs/msmarco-passage-ltr-mrr-v1.tar.gz -C runs
Next, we can run our inference script to get the reranking results.

```bash
python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py \
python -m pyserini.ltr.search_msmarco_passage \
--input runs/run.msmarco-passage.bm25tuned.txt \
--input-format tsv \
--model runs/msmarco-passage-ltr-mrr-v1 \
--index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 \
--output runs/run.ltr.msmarco-passage.tsv
```

Here, our model is trained to maximize MRR@10.
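To verify the MRR@10 of the reranked run, one option is to score it with the MS MARCO passage eval script (a minimal sketch; the qrels file is the dev subset shipped under `tools/topics-and-qrels/`):

```bash
python tools/scripts/msmarco/msmarco_passage_eval.py \
  tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt \
  runs/run.ltr.msmarco-passage.tsv
```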

Note that we can also train other models from scratch by following the [training guide](experiments-ltr-msmarco-passage-training.md), and replace the `--model` argument with your trained model directory, as in the sketch below.
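For example (a minimal sketch; `runs/my-ltr-model` is a hypothetical name standing in for whatever directory your training run produced):

```bash
python -m pyserini.ltr.search_msmarco_passage \
  --input runs/run.msmarco-passage.bm25tuned.txt \
  --input-format tsv \
  --model runs/my-ltr-model \
  --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 \
  --output runs/run.ltr.msmarco-passage.my-model.tsv
```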

Inference speed will vary; on orca, it takes ~0.25s/query.

@@ -99,20 +101,6 @@ Average precision or AP (also called mean average precision, MAP) and recall@100
AP captures aspects of both precision and recall in a single metric, and is the most common metric used by information retrieval researchers.
On the other hand, recall@1000 provides the upper bound effectiveness of downstream reranking modules (i.e., rerankers are useless if there isn't a relevant document in the results).

## Training the Model From Scratch

```bash
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz -P collections/msmarco-passage/
gzip -d collections/msmarco-passage/qidpidtriples.train.full.2.tsv.gz
```
First, download the file containing the training triples and uncompress it.

```bash
python scripts/ltr_msmarco-passage/train_ltr_model.py \
--index indexes/index-msmarco-passage-ltr-20210519-e25e33f
```
The above script will train a model under `runs/` with the run date in the file name. You can use this as the `--ltr_model_path` parameter for `predict_passage.py`.

## Building the Index From Scratch

Equivalently, we can preprocess the collection and queries with our scripts:
62 changes: 62 additions & 0 deletions docs/experiments-ltr-msmarco-passage-training.md
@@ -0,0 +1,62 @@
# Pyserini: Train Learning-To-Rank Reranking Models for MS MARCO Passage

## Data Preprocessing

Please first follow the [Pyserini BM25 retrieval guide](experiments-msmarco-passage.md) to obtain the reranking candidates.

```bash
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz -P collections/msmarco-passage/
gzip -d collections/msmarco-passage/qidpidtriples.train.full.2.tsv.gz
```
The above commands download the file containing the training triples and uncompress it.

Next, we're going to use `collections/msmarco-ltr-passage/` as the working directory for the preprocessed data.

```bash
mkdir collections/msmarco-ltr-passage/

python scripts/ltr_msmarco-passage/convert_queries.py \
--input collections/msmarco-passage/queries.eval.small.tsv \
--output collections/msmarco-ltr-passage/queries.eval.small.json

python scripts/ltr_msmarco-passage/convert_queries.py \
--input collections/msmarco-passage/queries.dev.small.tsv \
--output collections/msmarco-ltr-passage/queries.dev.small.json

python scripts/ltr_msmarco-passage/convert_queries.py \
--input collections/msmarco-passage/queries.train.tsv \
--output collections/msmarco-ltr-passage/queries.train.json
```

The above scripts convert queries to JSON objects with `text`, `text_unlemm`, `raw`, and `text_bert_tok` fields.
The first two runs take ~1 minute each; the third takes a bit longer (~1.5 hours).

```bash
python -c "from pyserini.search import SimpleSearcher; SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')"
```

We run the above command to download the pre-built index into the cache.

Note that you can also build the index from scratch by following [this guide](./experiments-ltr-msmarco-passage-reranking.md#L104).

```bash
wget https://www.dropbox.com/s/vlrfcz3vmr4nt0q/ibm_model.tar.gz -P collections/msmarco-ltr-passage/
tar -xzvf collections/msmarco-ltr-passage/ibm_model.tar.gz -C collections/msmarco-ltr-passage/
```
The above commands download and extract the pretrained IBM models.

## Training the Model From Scratch
```bash
python scripts/ltr_msmarco-passage/train_ltr_model.py \
--index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3
```
The above script will train a model under `runs/` with the run date in the file name. You can use this directory as the `--model` parameter for [reranking](experiments-ltr-msmarco-passage-reranking.md#L58).

The number of negative samples used in training can be changed with `--neg-sample`; the default is 10.
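For example, to train with more negatives per positive (a minimal sketch; the value 20 is only illustrative):

```bash
python scripts/ltr_msmarco-passage/train_ltr_model.py \
  --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 \
  --neg-sample 20
```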

## Change the Optimization Goal of Your Trained Model
The script trains a model which optimizes MRR@10 by default.

You can change `mrr_at_10` in [this function](../scripts/ltr_msmarco-passage/train_ltr_model.py#L621) and [here](../scripts/ltr_msmarco-passage/train_ltr_model.py#L358) to `recall_at_20` to train a model that optimizes recall@20.

You can also define your own function, following the format of [this one](../scripts/ltr_msmarco-passage/train_ltr_model.py#L300), and change the corresponding places mentioned above to use a different optimization goal.
7 changes: 2 additions & 5 deletions integrations/sparse/test_ltr_msmarco_passage.py
@@ -47,11 +47,8 @@ def test_reranking(self):
os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
# process queries
os.system('python scripts/ltr_msmarco-passage/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ltr_test/queries.dev.small.json')
if(os.getcwd().endswith('sparse')):
os.system(f'python ../../scripts/ltr_msmarco-passage/rerank_with_ltr_model.py --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
else:
os.system(f'python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
os.system(f'python -m pyserini.ltr.search_msmarco_passage --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
a,b = result.find('#####################\nMRR @10:'), result.find('\nQueriesRanked: 6980\n#####################\n')
mrr = result[a+31:b]
self.assertAlmostEqual(float(mrr),0.24709612498294367, delta=0.000001)
18 changes: 18 additions & 0 deletions pyserini/ltr/search_msmarco_passage/__init__.py
@@ -0,0 +1,18 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ._search_msmarco_passage import MsmarcoPassageLtrSearcher
__all__ = ['MsmarcoPassageLtrSearcher']
238 changes: 238 additions & 0 deletions pyserini/ltr/search_msmarco_passage/__main__.py
@@ -0,0 +1,238 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

import argparse
import json
import multiprocessing
import os
import pickle
import time

import numpy as np
import pandas as pd
from tqdm import tqdm
from pyserini.ltr.search_msmarco_passage._search_msmarco_passage import MsmarcoPassageLtrSearcher
from pyserini.ltr import *

"""
Running prediction on candidates
"""
def dev_data_loader(file, format, top=100):
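# Load first-stage candidates (MS MARCO tsv or TREC run format), keep the top-k per query,
# join them with the dev qrels, and print a recall curve of the input run.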
if format == 'tsv':
dev = pd.read_csv(file, sep="\t",
names=['qid', 'pid', 'rank'],
dtype={'qid': 'S','pid': 'S', 'rank':'i',})
elif format == 'trec':
dev = pd.read_csv(file, sep="\s+",
names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'],
usecols=['qid', 'pid', 'rank'],
dtype={'qid': 'S','pid': 'S', 'rank':'i',})
else:
raise Exception('unknown parameters')
assert dev['qid'].dtype == np.object
assert dev['pid'].dtype == np.object
assert dev['rank'].dtype == np.int32
dev = dev[dev['rank']<=top]
dev_qrel = pd.read_csv('tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt', sep=" ",
names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'],
dtype={'qid': 'S','pid': 'S', 'rel':'i'})
assert dev['qid'].dtype == np.object
assert dev['pid'].dtype == np.object
assert dev['rank'].dtype == np.int32
dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left')
dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

print(dev.shape)
print(dev.index.get_level_values('qid').drop_duplicates().shape)
print(dev.groupby('qid').count().mean())
print(dev.head(10))
print(dev.info())

dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

recall_point = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
recall_curve = {k: [] for k in recall_point}
for qid, group in tqdm(dev.groupby('qid')):
group = group.reset_index()
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
total_rel = dev_rel_num.loc[qid]
query_recall = [0 for k in recall_point]
for t in group.sort_values('rank').itertuples():
if t.rel > 0:
for i, p in enumerate(recall_point):
if t.rank <= p:
query_recall[i] += 1
for i, p in enumerate(recall_point):
if total_rel > 0:
recall_curve[p].append(query_recall[i] / total_rel)
else:
recall_curve[p].append(0.)

for k, v in recall_curve.items():
avg = np.mean(v)
print(f'recall@{k}:{avg}')

return dev, dev_qrel


def query_loader():
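# Load preprocessed dev queries (JSON lines produced by convert_queries.py),
# splitting each tokenized field into a list of tokens, keyed by query id.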
queries = {}
with open(f'{args.queries}/queries.dev.small.json') as f:
for line in f:
query = json.loads(line)
qid = query.pop('id')
query['analyzed'] = query['analyzed'].split(" ")
query['text'] = query['text_unlemm'].split(" ")
query['text_unlemm'] = query['text_unlemm'].split(" ")
query['text_bert_tok'] = query['text_bert_tok'].split(" ")
queries[qid] = query
return queries


def eval_mrr(dev_data):
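# Compute MRR@10 over the reranked candidates; score ties are counted for diagnostics.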
score_tie_counter = 0
score_tie_query = set()
MRR = []
for qid, group in tqdm(dev_data.groupby('qid')):
group = group.reset_index()
rank = 0
prev_score = None
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
# stable sort is also used in LightGBM

for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
if prev_score is not None and abs(t.score - prev_score) < 1e-8:
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
rank += 1
if t.rel > 0:
MRR.append(1.0 / rank)
break
elif rank == 10 or rank == len(group):
MRR.append(0.)
break

score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
print(score_tie)
mrr_10 = np.mean(MRR).item()
print(f'MRR@10:{mrr_10} with {len(MRR)} queries')
return {'score_tie': score_tie, 'mrr_10': mrr_10}


def eval_recall(dev_qrel, dev_data):
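# Compute recall at a range of cutoffs over the reranked candidates.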
dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

score_tie_counter = 0
score_tie_query = set()

recall_point = [10,20,50,100,200,250,300,333,400,500,1000]
recall_curve = {k: [] for k in recall_point}
for qid, group in tqdm(dev_data.groupby('qid')):
group = group.reset_index()
rank = 0
prev_score = None
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
# stable sort is also used in LightGBM
total_rel = dev_rel_num.loc[qid]
query_recall = [0 for k in recall_point]
for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
if prev_score is not None and abs(t.score - prev_score) < 1e-8:
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
rank += 1
if t.rel > 0:
for i, p in enumerate(recall_point):
if rank <= p:
query_recall[i] += 1
for i, p in enumerate(recall_point):
if total_rel > 0:
recall_curve[p].append(query_recall[i] / total_rel)
else:
recall_curve[p].append(0.)

score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
print(score_tie)
res = {'score_tie': score_tie}

for k, v in recall_curve.items():
avg = np.mean(v)
print(f'recall@{k}:{avg}')
res[f'recall@{k}'] = avg

return res


def output(file, dev_data):
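# Write the reranked run in MS MARCO tsv format: qid <tab> pid <tab> rank.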
score_tie_counter = 0
score_tie_query = set()
output_file = open(file,'w')

for qid, group in tqdm(dev_data.groupby('qid')):
group = group.reset_index()
rank = 0
prev_score = None
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
# stable sort is also used in LightGBM

for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
if prev_score is not None and abs(t.score - prev_score) < 1e-8:
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
rank += 1
output_file.write(f"{qid}\t{t.pid}\t{rank}\n")

score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
print(score_tie)

if __name__ == "__main__":
os.environ["ANSERINI_CLASSPATH"] = "./pyserini/resources/jars"
parser = argparse.ArgumentParser(description='Learning to rank reranking')
parser.add_argument('--input', required=True)
parser.add_argument('--reranking-top', type=int, default=1000)
parser.add_argument('--input-format', required=True)
parser.add_argument('--model', required=True)
parser.add_argument('--index', required=True)
parser.add_argument('--output', required=True)
parser.add_argument('--ibm-model',default='./collections/msmarco-ltr-passage/ibm_model/')
parser.add_argument('--queries',default='./collections/msmarco-ltr-passage/')

args = parser.parse_args()
searcher = MsmarcoPassageLtrSearcher(args.model, args.ibm_model, args.index)
searcher.add_fe()
print("load dev")
dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.reranking_top)
print("load queries")
queries = query_loader()

batch_info = searcher.search(dev, queries)
del dev, queries

eval_res = eval_mrr(batch_info)
eval_recall(dev_qrel, batch_info)
output(args.output, batch_info)
print('Done!')

