Skip to content

Commit

Permalink
LTR refactoring for modularization (#636)
Browse files Browse the repository at this point in the history
Modularization LTR reranking and split ltr doc
  • Loading branch information
stephaniewhoo committed Jun 3, 2021
1 parent 90521b0 commit c7b37d6
Show file tree
Hide file tree
Showing 8 changed files with 565 additions and 474 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Pyserini: Learning-To-Rank Baseline for MS MARCO Passage
# Pyserini: Learning-To-Rank Reranking Baseline for MS MARCO Passage

This guide contains instructions for running learning-to-rank baseline on the [MS MARCO *passage* reranking task](https://microsoft.github.io/msmarco/).
Learning-to-rank serves as a second stage reranker after BM25 retrieval.
Expand Down Expand Up @@ -52,15 +52,17 @@ tar -xzvf runs/msmarco-passage-ltr-mrr-v1.tar.gz -C runs
Next we can run our inference script to get our reranking result.

```bash
python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py \
python -m pyserini.ltr.search_msmarco_passage \
--input runs/run.msmarco-passage.bm25tuned.txt \
--input-format tsv \
--model runs/msmarco-passage-ltr-mrr-v1 \
--index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 \
--output runs/run.ltr.msmarco-passage.tsv
```

Here, our model is trained to maximize MRR@10.
Here, our model is trained to maximize MRR@10.

Note that we can also train other models from scratch follow [training guide](experiments-ltr-msmarco-passage-training.md), and replace `--model` argument with your trained model dir.

Inference speed will vary, on orca, it takes ~0.25s/query.

Expand Down Expand Up @@ -99,20 +101,6 @@ Average precision or AP (also called mean average precision, MAP) and recall@100
AP captures aspects of both precision and recall in a single metric, and is the most common metric used by information retrieval researchers.
On the other hand, recall@1000 provides the upper bound effectiveness of downstream reranking modules (i.e., rerankers are useless if there isn't a relevant document in the results).

## Training the Model From Scratch

```bash
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz -P collections/msmarco-passage/
gzip -d collections/msmarco-passage/qidpidtriples.train.full.2.tsv.gz
```
First download the file which has training triples and uncompress it.

```bash
python scripts/ltr_msmarco-passage/train_ltr_model.py \
--index indexes/index-msmarco-passage-ltr-20210519-e25e33f
```
The above scripts will train a model at `runs/` with your running date in the file name. You can use this as the `--ltr_model_path` parameter for `predict_passage.py`.

## Building the Index From Scratch

Equivalently, we can preprocess collection and queries with our scripts:
Expand Down
62 changes: 62 additions & 0 deletions docs/experiments-ltr-msmarco-passage-training.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Pyserini: Train Learning-To-Rank Reranking Models for MS MARCO Passage

## Data Preprocessing

Please first follow the [Pyserini BM25 retrieval guide](experiments-msmarco-passage.md) to obtain our reranking candidate.

```bash
wget https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz -P collections/msmarco-passage/
gzip -d collections/msmarco-passage/qidpidtriples.train.full.2.tsv.gz
```
Then, download the file which has training triples and uncompress it.

Next, we're going to use `collections/msmarco-ltr-passage/` as the working directory to download pre processed data.

```bash
mkdir collections/msmarco-ltr-passage/

python scripts/ltr_msmarco-passage/convert_queries.py \
--input collections/msmarco-passage/queries.eval.small.tsv \
--output collections/msmarco-ltr-passage/queries.eval.small.json

python scripts/ltr_msmarco-passage/convert_queries.py \
--input collections/msmarco-passage/queries.dev.small.tsv \
--output collections/msmarco-ltr-passage/queries.dev.small.json

python scripts/ltr_msmarco-passage/convert_queries.py \
--input collections/msmarco-passage/queries.train.tsv \
--output collections/msmarco-ltr-passage/queries.train.json
```

The above scripts convert queries to json objects with `text`, `text_unlemm`, `raw`, and `text_bert_tok` fields.
The first two scripts take ~1 min and the third one is a bit longer (~1.5h).

```bash
python -c "from pyserini.search import SimpleSearcher; SimpleSearcher.from_prebuilt_index('msmarco-passage-ltr')"
```

We run the above commands to obtain pre-built index in cache.

Note you can also build index from scratch follow [this guide](./experiments-ltr-msmarco-passage-reranking.md#L104).

```bash
wget https://www.dropbox.com/s/vlrfcz3vmr4nt0q/ibm_model.tar.gz -P collections/msmarco-ltr-passage/
tar -xzvf collections/msmarco-ltr-passage/ibm_model.tar.gz -C collections/msmarco-ltr-passage/
```
Download pretrained IBM models:

## Training the Model From Scratch
```bash
python scripts/ltr_msmarco-passage/train_ltr_model.py \
--index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3
```
The above scripts will train a model at `runs/` with your running date in the file name. You can use this as the `--model` parameter for [reranking](experiments-ltr-msmarco-passage-reranking.md#L58).

Number of negative samples used in training can be changed by `--neg-sample`, by default is 10.

## Change the Optmization Goal of Your Trained Model
The script trains a model which optimizes MRR@10 by default.

You can change the `mrr_at_10` of [this function](../scripts/ltr_msmarco-passage/train_ltr_model.py#L621) and [here](../scripts/ltr_msmarco-passage/train_ltr_model.py#L358) to `recall_at_20` to train a model which optimizes recall@20.

You can also self defined a function format like [this](../scripts/ltr_msmarco-passage/train_ltr_model.py#L300) and change corresponding places mentioned above to have different optimization goal.
7 changes: 2 additions & 5 deletions integrations/sparse/test_ltr_msmarco_passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,8 @@ def test_reranking(self):
os.system(f'tar -xzvf ltr_test/{ibm_model_tar_name} -C ltr_test')
#queries process
os.system('python scripts/ltr_msmarco-passage/convert_queries.py --input tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt --output ltr_test/queries.dev.small.json')
if(os.getcwd().endswith('sparse')):
os.system(f'python ../../scripts/ltr_msmarco-passage/rerank_with_ltr_model.py --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
else:
os.system(f'python scripts/ltr_msmarco-passage/rerank_with_ltr_model.py --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
os.system(f'python -m pyserini.ltr.search_msmarco_passage --input ltr_test/{inp} --input-format tsv --model ltr_test/msmarco-passage-ltr-mrr-v1 --index ~/.cache/pyserini/indexes/index-msmarco-passage-ltr-20210519-e25e33f.a5de642c268ac1ed5892c069bdc29ae3 --ibm-model ltr_test/ibm_model/ --queries ltr_test --output ltr_test/{outp}')
result = subprocess.check_output(f'python tools/scripts/msmarco/msmarco_passage_eval.py tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt ltr_test/{outp}', shell=True).decode(sys.stdout.encoding)
a,b = result.find('#####################\nMRR @10:'), result.find('\nQueriesRanked: 6980\n#####################\n')
mrr = result[a+31:b]
self.assertAlmostEqual(float(mrr),0.24709612498294367, delta=0.000001)
Expand Down
18 changes: 18 additions & 0 deletions pyserini/ltr/search_msmarco_passage/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from ._search_msmarco_passage import MsmarcoPassageLtrSearcher
__all__ = ['MsmarcoPassageLtrSearcher']
238 changes: 238 additions & 0 deletions pyserini/ltr/search_msmarco_passage/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

import argparse
import json
import multiprocessing
import os
import pickle
import time

import numpy as np
import pandas as pd
from tqdm import tqdm
from pyserini.ltr.search_msmarco_passage._search_msmarco_passage import MsmarcoPassageLtrSearcher
from pyserini.ltr import *

"""
Running prediction on candidates
"""
def dev_data_loader(file, format, top=100):
if format == 'tsv':
dev = pd.read_csv(file, sep="\t",
names=['qid', 'pid', 'rank'],
dtype={'qid': 'S','pid': 'S', 'rank':'i',})
elif format == 'trec':
dev = pd.read_csv(file, sep="\s+",
names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'],
usecols=['qid', 'pid', 'rank'],
dtype={'qid': 'S','pid': 'S', 'rank':'i',})
else:
raise Exception('unknown parameters')
assert dev['qid'].dtype == np.object
assert dev['pid'].dtype == np.object
assert dev['rank'].dtype == np.int32
dev = dev[dev['rank']<=top]
dev_qrel = pd.read_csv('tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt', sep=" ",
names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'],
dtype={'qid': 'S','pid': 'S', 'rel':'i'})
assert dev['qid'].dtype == np.object
assert dev['pid'].dtype == np.object
assert dev['rank'].dtype == np.int32
dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left')
dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

print(dev.shape)
print(dev.index.get_level_values('qid').drop_duplicates().shape)
print(dev.groupby('qid').count().mean())
print(dev.head(10))
print(dev.info())

dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

recall_point = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
recall_curve = {k: [] for k in recall_point}
for qid, group in tqdm(dev.groupby('qid')):
group = group.reset_index()
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
total_rel = dev_rel_num.loc[qid]
query_recall = [0 for k in recall_point]
for t in group.sort_values('rank').itertuples():
if t.rel > 0:
for i, p in enumerate(recall_point):
if t.rank <= p:
query_recall[i] += 1
for i, p in enumerate(recall_point):
if total_rel > 0:
recall_curve[p].append(query_recall[i] / total_rel)
else:
recall_curve[p].append(0.)

for k, v in recall_curve.items():
avg = np.mean(v)
print(f'recall@{k}:{avg}')

return dev, dev_qrel


def query_loader():
queries = {}
with open(f'{args.queries}/queries.dev.small.json') as f:
for line in f:
query = json.loads(line)
qid = query.pop('id')
query['analyzed'] = query['analyzed'].split(" ")
query['text'] = query['text_unlemm'].split(" ")
query['text_unlemm'] = query['text_unlemm'].split(" ")
query['text_bert_tok'] = query['text_bert_tok'].split(" ")
queries[qid] = query
return queries


def eval_mrr(dev_data):
score_tie_counter = 0
score_tie_query = set()
MRR = []
for qid, group in tqdm(dev_data.groupby('qid')):
group = group.reset_index()
rank = 0
prev_score = None
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
# stable sort is also used in LightGBM

for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
if prev_score is not None and abs(t.score - prev_score) < 1e-8:
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
rank += 1
if t.rel > 0:
MRR.append(1.0 / rank)
break
elif rank == 10 or rank == len(group):
MRR.append(0.)
break

score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
print(score_tie)
mrr_10 = np.mean(MRR).item()
print(f'MRR@10:{mrr_10} with {len(MRR)} queries')
return {'score_tie': score_tie, 'mrr_10': mrr_10}


def eval_recall(dev_qrel, dev_data):
dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

score_tie_counter = 0
score_tie_query = set()

recall_point = [10,20,50,100,200,250,300,333,400,500,1000]
recall_curve = {k: [] for k in recall_point}
for qid, group in tqdm(dev_data.groupby('qid')):
group = group.reset_index()
rank = 0
prev_score = None
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
# stable sort is also used in LightGBM
total_rel = dev_rel_num.loc[qid]
query_recall = [0 for k in recall_point]
for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
if prev_score is not None and abs(t.score - prev_score) < 1e-8:
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
rank += 1
if t.rel > 0:
for i, p in enumerate(recall_point):
if rank <= p:
query_recall[i] += 1
for i, p in enumerate(recall_point):
if total_rel > 0:
recall_curve[p].append(query_recall[i] / total_rel)
else:
recall_curve[p].append(0.)

score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
print(score_tie)
res = {'score_tie': score_tie}

for k, v in recall_curve.items():
avg = np.mean(v)
print(f'recall@{k}:{avg}')
res[f'recall@{k}'] = avg

return res


def output(file, dev_data):
score_tie_counter = 0
score_tie_query = set()
output_file = open(file,'w')

for qid, group in tqdm(dev_data.groupby('qid')):
group = group.reset_index()
rank = 0
prev_score = None
assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
# stable sort is also used in LightGBM

for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
if prev_score is not None and abs(t.score - prev_score) < 1e-8:
score_tie_counter += 1
score_tie_query.add(qid)
prev_score = t.score
rank += 1
output_file.write(f"{qid}\t{t.pid}\t{rank}\n")

score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
print(score_tie)

if __name__ == "__main__":
os.environ["ANSERINI_CLASSPATH"] = "./pyserini/resources/jars"
parser = argparse.ArgumentParser(description='Learning to rank reranking')
parser.add_argument('--input', required=True)
parser.add_argument('--reranking-top', type=int, default=1000)
parser.add_argument('--input-format', required=True)
parser.add_argument('--model', required=True)
parser.add_argument('--index', required=True)
parser.add_argument('--output', required=True)
parser.add_argument('--ibm-model',default='./collections/msmarco-ltr-passage/ibm_model/')
parser.add_argument('--queries',default='./collections/msmarco-ltr-passage/')

args = parser.parse_args()
searcher = MsmarcoPassageLtrSearcher(args.model, args.ibm_model, args.index)
searcher.add_fe()
print("load dev")
dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.reranking_top)
print("load queries")
queries = query_loader()

batch_info = searcher.search(dev, queries)
del dev, queries

eval_res = eval_mrr(batch_info)
eval_recall(dev_qrel, batch_info)
output(args.output, batch_info)
print('Done!')


Loading

0 comments on commit c7b37d6

Please sign in to comment.