diff --git a/pyserini/search/lucene/__main__.py b/pyserini/search/lucene/__main__.py index d34f89517..e4b669fe8 100644 --- a/pyserini/search/lucene/__main__.py +++ b/pyserini/search/lucene/__main__.py @@ -88,6 +88,7 @@ def define_search_args(parser): parser.add_argument('--rm3', action='store_true', help="Use RM3") parser.add_argument('--rocchio', action='store_true', help="Use Rocchio") + parser.add_argument('--rocchio-use-negative', action='store_true', help="Use nonrelevant labels in Rocchio") parser.add_argument('--qld', action='store_true', help="Use QLD") parser.add_argument('--language', type=str, help='language code for BM25, e.g. zh for Chinese', default='en') @@ -182,7 +183,10 @@ def define_search_args(parser): if args.rocchio: search_rankers.append('rocchio') - searcher.set_rocchio() + if args.rocchio_use_negative: + searcher.set_rocchio(gamma=0.15, use_negative=True) + else: + searcher.set_rocchio() fields = dict() if args.fields: diff --git a/pyserini/search/lucene/_searcher.py b/pyserini/search/lucene/_searcher.py index 579356e2b..029bbc165 100644 --- a/pyserini/search/lucene/_searcher.py +++ b/pyserini/search/lucene/_searcher.py @@ -241,7 +241,7 @@ def is_using_rm3(self) -> bool: """Check if RM3 query expansion is being performed.""" return self.object.useRM3() - def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10, rocchio_alpha=1, rocchio_beta=0.75, rocchio_gamma=0, rocchio_output_query=False): + def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, bottom_fb_docs=10, alpha=1, beta=0.75, gamma=0, output_query=False, use_negative=False): """Configure Rocchio query expansion. Parameters @@ -262,9 +262,11 @@ def set_rocchio(self, top_fb_terms=10, top_fb_docs=10, bottom_fb_terms=10, botto Rocchio parameter for weight to assign to the nonrelevant document vector. rocchio_output_query : bool Print the original and expanded queries as debug output. + rocchio_use_negative : bool + Rocchio parameter to use negative labels. """ if self.object.reader.getTermVectors(0): - self.object.setRocchio(top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs, rocchio_alpha, rocchio_beta, rocchio_gamma, rocchio_output_query) + self.object.setRocchio(top_fb_terms, top_fb_docs, bottom_fb_terms, bottom_fb_docs, alpha, beta, gamma, output_query, use_negative) else: raise TypeError("Rocchio is not supported for indexes without document vectors.") diff --git a/tests/test_search.py b/tests/test_search.py index 96a90a421..4049a466c 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -262,7 +262,7 @@ def test_rocchio(self): self.assertAlmostEqual(hits[9].score, 4.21740, places=5) self.searcher.set_rocchio(top_fb_terms=10, top_fb_docs=8, bottom_fb_terms=10, - bottom_fb_docs=8, rocchio_alpha=0.4, rocchio_beta=0.5, rocchio_gamma=0.1) + bottom_fb_docs=8, alpha=0.4, beta=0.5, gamma=0.1, output_query=False, use_negative=True) self.assertTrue(self.searcher.is_using_rocchio()) hits = self.searcher.search('information retrieval') @@ -272,6 +272,17 @@ def test_rocchio(self): self.assertEqual(hits[9].docid, 'CACM-1032') self.assertAlmostEqual(hits[9].score, 2.57510, places=5) + self.searcher.set_rocchio(top_fb_terms=10, top_fb_docs=8, bottom_fb_terms=10, + bottom_fb_docs=8, alpha=0.4, beta=0.5, gamma=0.1, output_query=False, use_negative=False) + self.assertTrue(self.searcher.is_using_rocchio()) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 4.03900, places=5) + self.assertEqual(hits[9].docid, 'CACM-1032') + self.assertAlmostEqual(hits[9].score, 2.91550, places=5) + with self.assertRaises(TypeError): self.no_vec_searcher.set_rocchio()