Skip to content

Commit

Permalink
update bm25Vectorizer (#323)
Browse files Browse the repository at this point in the history
* Fix bug in vectorizer: vocabulary order was nondeterministic (unsorted set iteration)
* Make `norm` an optional parameter
* Add unit tests for the vectorizer
  • Loading branch information
MXueguang authored Jan 17, 2021
1 parent 889a976 commit 98f3236
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 17 deletions.
51 changes: 36 additions & 15 deletions pyserini/vectorizer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
#

import math
from typing import List
from typing import List, Optional
from sklearn.preprocessing import normalize

from scipy.sparse import csr_matrix
from sklearn.preprocessing import normalize

from pyserini import index, search
from pyserini.analysis import Analyzer, get_lucene_analyzer
from tqdm import tqdm


class Vectorizer:
Expand All @@ -42,12 +44,15 @@ def __init__(self, lucene_index_path: str, min_df: int = 1, verbose: bool = Fals
self.index_reader = index.IndexReader(lucene_index_path)
self.searcher = search.SimpleSearcher(lucene_index_path)
self.num_docs: int = self.searcher.num_docs
self.stats = self.index_reader.stats()
self.analyzer = Analyzer(get_lucene_analyzer())

# build vocabulary
self.vocabulary_ = set()
for term in self.index_reader.terms():
if term.df > self.min_df:
self.vocabulary_.add(term.term)
self.vocabulary_ = sorted(self.vocabulary_)

# build term to index mapping
self.term_to_index = {}
Expand All @@ -58,6 +63,17 @@ def __init__(self, lucene_index_path: str, min_df: int = 1, verbose: bool = Fals
if self.verbose:
print(f'Found {self.vocabulary_size} terms with min_df={self.min_df}')

def get_query_vector(self, query: str):
matrix_row, matrix_col, matrix_data = [], [], []
tokens = self.analyzer.analyze(query)
for term in tokens:
if term in self.vocabulary_:
matrix_row.append(0)
matrix_col.append(self.term_to_index[term])
matrix_data.append(1)
vectors = csr_matrix((matrix_data, (matrix_row, matrix_col)), shape=(1, self.vocabulary_size))
return vectors


class TfidfVectorizer(Vectorizer):
"""Wrapper class for tf-idf vectorizer implemented on top of Pyserini.
Expand All @@ -79,26 +95,25 @@ def __init__(self, lucene_index_path: str, min_df: int = 1, verbose: bool = Fals
for term in self.index_reader.terms():
self.idf_[term.term] = math.log(self.num_docs / term.df)

def get_vectors(self, docids: List[str]):
def get_vectors(self, docids: List[str], norm: Optional[str] = 'l2'):
"""Get the tf-idf vectors given a list of docids
Parameters
----------
norm : str
Normalize the sparse matrix
docids : List[str]
The piece of text to analyze.
Returns
-------
csr_matrix
L2 normalized sparse matrix representation of tf-idf vectors
Sparse matrix representation of tf-idf vectors
"""
matrix_row, matrix_col, matrix_data = [], [], []
num_docs = len(docids)

for index, doc_id in enumerate(docids):
if index % 1000 == 0 and num_docs > 1000 and self.verbose:
print(f'Vectorizing: {index}/{len(docids)}')

for index, doc_id in enumerate(tqdm(docids)):
# Term Frequency
tf = self.index_reader.get_document_vector(doc_id)
if tf is None:
Expand All @@ -115,7 +130,10 @@ def get_vectors(self, docids: List[str]):
matrix_data.append(tfidf)

vectors = csr_matrix((matrix_data, (matrix_row, matrix_col)), shape=(num_docs, self.vocabulary_size))
return normalize(vectors, norm='l2')

if norm:
return normalize(vectors, norm=norm)
return vectors


class BM25Vectorizer(Vectorizer):
Expand All @@ -134,25 +152,25 @@ class BM25Vectorizer(Vectorizer):
def __init__(self, lucene_index_path: str, min_df: int = 1, verbose: bool = False):
super().__init__(lucene_index_path, min_df, verbose)

def get_vectors(self, docids: List[str]):
def get_vectors(self, docids: List[str], norm: Optional[str] = 'l2'):
"""Get the BM25 vectors given a list of docids
Parameters
----------
norm : str
Normalize the sparse matrix
docids : List[str]
The piece of text to analyze.
Returns
-------
csr_matrix
L2 normalized sparse matrix representation of BM25 vectors
Sparse matrix representation of BM25 vectors
"""
matrix_row, matrix_col, matrix_data = [], [], []
num_docs = len(docids)

for index, doc_id in enumerate(docids):
if index % 1000 == 0 and num_docs > 1000 and self.verbose:
print(f'Vectorizing: {index}/{len(docids)}')
for index, doc_id in enumerate(tqdm(docids)):

# Term Frequency
tf = self.index_reader.get_document_vector(doc_id)
Expand All @@ -170,4 +188,7 @@ def get_vectors(self, docids: List[str]):
matrix_data.append(bm25_weight)

vectors = csr_matrix((matrix_data, (matrix_row, matrix_col)), shape=(num_docs, self.vocabulary_size))
return normalize(vectors, norm='l2')

if norm:
return normalize(vectors, norm=norm)
return vectors
23 changes: 21 additions & 2 deletions tests/test_index_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def setUp(self):
self.searcher = search.SimpleSearcher(self.index_path)
self.index_reader = index.IndexReader(self.index_path)

def test_tfidf_vectorizer(self):
def test_tfidf_vectorizer_train(self):
vectorizer = TfidfVectorizer(self.index_path, min_df=5)
train_docs = ['CACM-0239', 'CACM-0440', 'CACM-3168', 'CACM-3169']
train_labels = [1, 1, 0, 0]
Expand All @@ -62,7 +62,7 @@ def test_tfidf_vectorizer(self):
self.assertAlmostEqual(0.51837413, pred[1][0], places=8)
self.assertAlmostEqual(0.48162587, pred[1][1], places=8)

def test_bm25_vectorizer(self):
def test_bm25_vectorizer_train(self):
vectorizer = BM25Vectorizer(self.index_path, min_df=5)
train_docs = ['CACM-0239', 'CACM-0440', 'CACM-3168', 'CACM-3169']
train_labels = [1, 1, 0, 0]
Expand All @@ -77,6 +77,25 @@ def test_bm25_vectorizer(self):
self.assertAlmostEqual(0.48288416, pred[1][0], places=8)
self.assertAlmostEqual(0.51711584, pred[1][1], places=8)

def test_tfidf_vectorizer(self):
    """Spot-check un-normalized tf-idf weights for two known documents."""
    docids = ['CACM-0239', 'CACM-0440']
    tfidf = TfidfVectorizer(self.index_path, min_df=5)
    matrix = tfidf.get_vectors(docids, norm=None)
    # Two fixed entries of the sparse matrix with known expected values.
    self.assertAlmostEqual(matrix[0, 190], 2.907369334264736, places=8)
    self.assertAlmostEqual(matrix[1, 391], 0.07516490235060004, places=8)

def test_bm25_vectorizer(self):
    """Spot-check un-normalized BM25 weights for two known documents."""
    docids = ['CACM-0239', 'CACM-0440']
    bm25 = BM25Vectorizer(self.index_path, min_df=5)
    matrix = bm25.get_vectors(docids, norm=None)
    # Two fixed entries of the sparse matrix with known expected values.
    self.assertAlmostEqual(matrix[0, 190], 1.7513844966888428, places=8)
    self.assertAlmostEqual(matrix[1, 391], 0.03765463829040527, places=8)

def test_vectorizer_query(self):
    """A query vector counts analyzed terms ('query' appears twice)."""
    vectorizer = BM25Vectorizer(self.index_path, min_df=5)
    query_vector = vectorizer.get_query_vector('this is a query to test query vector')
    # Expected counts at known vocabulary column indices.
    expected = {2703: 2, 3078: 1, 3204: 1}
    for column, count in expected.items():
        self.assertEqual(query_vector[0, column], count)

def test_terms_count(self):
    # Iterate the entire index and verify the expected number of terms.
    term_total = sum(1 for _ in self.index_reader.terms())
    self.assertEqual(term_total, 14363)
Expand Down

0 comments on commit 98f3236

Please sign in to comment.