Skip to content

Commit

Permalink
Add negative feedback option to Rocchio (#1218)
Browse files Browse the repository at this point in the history
* add negative rocchio reranker in pyserini

* add test for negative feedback in Rocchio
  • Loading branch information
yuki617 authored Jun 22, 2022
1 parent abd2790 commit 6914743
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
6 changes: 5 additions & 1 deletion pyserini/search/lucene/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def define_search_args(parser):

parser.add_argument('--rm3', action='store_true', help="Use RM3")
parser.add_argument('--rocchio', action='store_true', help="Use Rocchio")
parser.add_argument('--rocchio-use-negative', action='store_true', help="Use nonrelevant labels in Rocchio")
parser.add_argument('--qld', action='store_true', help="Use QLD")

parser.add_argument('--language', type=str, help='language code for BM25, e.g. zh for Chinese', default='en')
Expand Down Expand Up @@ -182,7 +183,10 @@ def define_search_args(parser):

if args.rocchio:
search_rankers.append('rocchio')
searcher.set_rocchio()
if args.rocchio_use_negative:
searcher.set_rocchio(gamma=0.15, use_negative=True)
else:
searcher.set_rocchio()

fields = dict()
if args.fields:
Expand Down
6 changes: 4 additions & 2 deletions pyserini/search/lucene/_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def is_using_rm3(self) -> bool:
"""Check if RM3 query expansion is being performed."""
return self.object.useRM3()

def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10, rocchio_alpha=1, rocchio_beta=0.75, rocchio_gamma=0, rocchio_output_query=False):
def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10, alpha=1, beta=0.75, gamma=0, output_query=False, use_negative=False):
"""Configure Rocchio query expansion.
Parameters
Expand All @@ -262,9 +262,11 @@ def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, botto
Rocchio parameter for weight to assign to the nonrelevant document vector.
output_query : bool
Print the original and expanded queries as debug output.
use_negative : bool
Whether to use nonrelevant (negative) labels in Rocchio feedback.
"""
if self.object.reader.getTermVectors(0):
self.object.setRocchio(top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs, rocchio_alpha, rocchio_beta, rocchio_gamma, rocchio_output_query)
self.object.setRocchio(top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs, alpha, beta, gamma, output_query, use_negative)
else:
raise TypeError("Rocchio is not supported for indexes without document vectors.")

Expand Down
13 changes: 12 additions & 1 deletion tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_rocchio(self):
self.assertAlmostEqual(hits[9].score, 4.21740, places=5)

self.searcher.set_rocchio(top_fb_terms=10, top_fb_docs=8, bottom_fb_terms=10,
bottom_fb_docs=8, rocchio_alpha=0.4, rocchio_beta=0.5, rocchio_gamma=0.1)
bottom_fb_docs=8, alpha=0.4, beta=0.5, gamma=0.1, output_query=False, use_negative=True)
self.assertTrue(self.searcher.is_using_rocchio())

hits = self.searcher.search('information retrieval')
Expand All @@ -272,6 +272,17 @@ def test_rocchio(self):
self.assertEqual(hits[9].docid, 'CACM-1032')
self.assertAlmostEqual(hits[9].score, 2.57510, places=5)

self.searcher.set_rocchio(top_fb_terms=10, top_fb_docs=8, bottom_fb_terms=10,
bottom_fb_docs=8, alpha=0.4, beta=0.5, gamma=0.1, output_query=False, use_negative=False)
self.assertTrue(self.searcher.is_using_rocchio())

hits = self.searcher.search('information retrieval')

self.assertEqual(hits[0].docid, 'CACM-3134')
self.assertAlmostEqual(hits[0].score, 4.03900, places=5)
self.assertEqual(hits[9].docid, 'CACM-1032')
self.assertAlmostEqual(hits[9].score, 2.91550, places=5)

with self.assertRaises(TypeError):
self.no_vec_searcher.set_rocchio()

Expand Down

0 comments on commit 6914743

Please sign in to comment.