diff --git a/pyserini/analysis/_base.py b/pyserini/analysis/_base.py
index 7b4d3e391..7ca17c5ec 100644
--- a/pyserini/analysis/_base.py
+++ b/pyserini/analysis/_base.py
@@ -49,9 +49,10 @@
 JAnalyzerUtils = autoclass('io.anserini.analysis.AnalyzerUtils')
 JDefaultEnglishAnalyzer = autoclass('io.anserini.analysis.DefaultEnglishAnalyzer')
 JTweetAnalyzer = autoclass('io.anserini.analysis.TweetAnalyzer')
+JHuggingFaceTokenizerAnalyzer = autoclass('io.anserini.analysis.HuggingFaceTokenizerAnalyzer')
 
 
-def get_lucene_analyzer(language='en', stemming=True, stemmer='porter', stopwords=True) -> JAnalyzer:
+def get_lucene_analyzer(language: str = 'en', stemming: bool = True, stemmer: str = 'porter', stopwords: bool = True, huggingFaceTokenizer: str = None) -> JAnalyzer:
     """Create a Lucene ``Analyzer`` with specific settings.
 
     Parameters
@@ -64,6 +65,8 @@ def get_lucene_analyzer(language='en', stemming=True, stemmer='porter', stopword
         Stemmer to use.
     stopwords : bool
         Set to filter stopwords.
+    huggingFaceTokenizer : str
+        A Hugging Face model ID or a path to a ``tokenizer.json`` file.
 
     Returns
     -------
@@ -112,6 +115,8 @@ def get_lucene_analyzer(language='en', stemming=True, stemmer='porter', stopword
         return JTurkishAnalyzer()
     elif language.lower() == 'tweet':
         return JTweetAnalyzer()
+    elif language.lower() == 'hgf_tokenizer':
+        return JHuggingFaceTokenizerAnalyzer(huggingFaceTokenizer)
     elif language.lower() == 'en':
         if stemming:
             if stopwords:
diff --git a/tests/test_analysis.py b/tests/test_analysis.py
index 8aa19179a..e4c7ea8df 100644
--- a/tests/test_analysis.py
+++ b/tests/test_analysis.py
@@ -96,6 +96,12 @@ def test_analysis(self):
         tokens = analyzer.analyze('City buses are running on time.')
         self.assertEqual(tokens, ['citi', 'buse', 'ar', 'run', 'on', 'time'])
 
+        # HuggingFace analyzer, with the BERT WordPiece tokenizer.
+        analyzer = Analyzer(get_lucene_analyzer(language='hgf_tokenizer', huggingFaceTokenizer='bert-base-uncased'))
+        self.assertTrue(isinstance(analyzer, Analyzer))
+        tokens = analyzer.analyze('This tokenizer generates wordpiece tokens')
+        self.assertEqual(tokens, ['this', 'token', '##izer', 'generates', 'word', '##piece', 'token', '##s'])
+
     def test_invalid_analyzer_wrapper(self):
         # Invalid JAnalyzer, make sure we get an exception.
         with self.assertRaises(TypeError):
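For reference, a minimal usage sketch of the new `hgf_tokenizer` option, mirroring the added test. It assumes a Pyserini build whose bundled Anserini jar includes `HuggingFaceTokenizerAnalyzer`, and network access to fetch the `bert-base-uncased` tokenizer on first use:

```python
from pyserini.analysis import Analyzer, get_lucene_analyzer

# Route analysis through a HuggingFace tokenizer instead of a Lucene
# language analyzer: language='hgf_tokenizer' selects
# JHuggingFaceTokenizerAnalyzer, and huggingFaceTokenizer accepts either
# a model id on the HuggingFace Hub or a path to a tokenizer.json file.
analyzer = Analyzer(get_lucene_analyzer(language='hgf_tokenizer',
                                        huggingFaceTokenizer='bert-base-uncased'))

# WordPiece subword tokens, as asserted in the test above:
# ['this', 'token', '##izer', 'generates', 'word', '##piece', 'token', '##s']
print(analyzer.analyze('This tokenizer generates wordpiece tokens'))
```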