From 20a2da999f74303bebbb40d7665a50f801a9168d Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Tue, 10 Sep 2024 09:49:26 +0200 Subject: [PATCH 01/11] Port NLTKDocumentSplitter from dC to Haystack --- .../preprocessors/nltk_document_splitter.py | 226 ++++++++++++ haystack/components/preprocessors/utils.py | 221 ++++++++++++ pyproject.toml | 2 + ...plitting-enhancement-6ef6f59bc277662c.yaml | 4 + .../test_nltk_document_splitter.py | 335 ++++++++++++++++++ 5 files changed, 788 insertions(+) create mode 100644 haystack/components/preprocessors/nltk_document_splitter.py create mode 100644 haystack/components/preprocessors/utils.py create mode 100644 releasenotes/notes/nltk-document-splitting-enhancement-6ef6f59bc277662c.yaml create mode 100644 test/components/preprocessors/test_nltk_document_splitter.py diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py new file mode 100644 index 0000000000..985850fefd --- /dev/null +++ b/haystack/components/preprocessors/nltk_document_splitter.py @@ -0,0 +1,226 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from copy import deepcopy +from typing import Dict, List, Literal, Tuple + +from haystack import Document, component, logging +from haystack.components.preprocessors.document_splitter import DocumentSplitter +from haystack.components.preprocessors.utils import Language, SentenceSplitter +from haystack.lazy_imports import LazyImport + +logger = logging.getLogger(__name__) + +with LazyImport("Run 'pip install nltk'") as nltk_imports: + import nltk + + +@component +class NLTKDocumentSplitter(DocumentSplitter): + def __init__( + self, + split_by: Literal["word", "sentence", "page", "passage"] = "word", + split_length: int = 200, + split_overlap: int = 0, + split_threshold: int = 0, + respect_sentence_boundary: bool = False, + language: Language = "en", + use_split_rules: bool = True, + extend_abbreviations: bool = True, + ): + """ + Splits your documents using NLTK to respect sentence boundaries. + + Initialize the NLTKDocumentSplitter. + + :param split_by: Select the unit for splitting your documents. Choose from `word` for splitting by spaces (" "), + `sentence` for splitting by NLTK sentence tokenizer, `page` for splitting by the form feed ("\\f") or + `passage` for splitting by double line breaks ("\\n\\n"). + :param split_length: The maximum number of units in each split. + :param split_overlap: The number of overlapping units for each split. + :param split_threshold: The minimum number of units per split. If a split has fewer units + than the threshold, it's attached to the previous split. + :param respect_sentence_boundary: Choose whether to respect sentence boundaries when splitting by "word". + If True, uses NLTK to detect sentence boundaries, ensuring splits occur only between sentences. + :param language: Choose the language for the NLTK tokenizer. The default is English ("en"). + :param use_split_rules: Choose whether to use additional split rules when splitting by `sentence`. + :param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list + of curated abbreviations, if available. + This is currently supported for English ("en") and German ("de"). 
+ """ + super(NLTKDocumentSplitter, self).__init__( + split_by=split_by, split_length=split_length, split_overlap=split_overlap, split_threshold=split_threshold + ) + nltk_imports.check() + + if respect_sentence_boundary and split_by != "word": + logger.warning( + "The 'respect_sentence_boundary' option is only supported for `split_by='word'`. " + "The option `respect_sentence_boundary` will be set to `False`." + ) + respect_sentence_boundary = False + self.respect_sentence_boundary = respect_sentence_boundary + self.sentence_splitter = SentenceSplitter( + language=language, + use_split_rules=use_split_rules, + extend_abbreviations=extend_abbreviations, + keep_white_spaces=True, + ) + self.language = language + + def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage", "page"]) -> List[str]: + if split_by == "page": + self.split_at = "\f" + units = text.split(self.split_at) + elif split_by == "passage": + self.split_at = "\n\n" + units = text.split(self.split_at) + elif split_by == "sentence": + # whitespace is preserved while splitting text into sentences when using keep_white_spaces=True + # so split_at is set to an empty string + self.split_at = "" + result = self.sentence_splitter.split_sentences(text) + units = [sentence["sentence"] for sentence in result] + elif split_by == "word": + self.split_at = " " + units = text.split(self.split_at) + else: + raise NotImplementedError( + "DocumentSplitter only supports 'word', 'sentence', 'page' or 'passage' split_by options." + ) + + # Add the delimiter back to all units except the last one + for i in range(len(units) - 1): + units[i] += self.split_at + return units + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]) -> Dict[str, List[Document]]: + """ + Split documents into smaller parts. + + Splits documents by the unit expressed in `split_by`, with a length of `split_length` + and an overlap of `split_overlap`. + + :param documents: The documents to split. + + :returns: A dictionary with the following key: + - `documents`: List of documents with the split texts. Each document includes: + - A metadata field source_id to track the original document. + - A metadata field page_number to track the original page number. + - All other metadata copied from the original document. + + :raises TypeError: if the input is not a list of Documents. + :raises ValueError: if the content of a document is None. + """ + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): + raise TypeError("DocumentSplitter expects a List of Documents as input.") + + split_docs = [] + for doc in documents: + if doc.content is None: + raise ValueError( + f"DocumentSplitter only works with text documents but content for document ID {doc.id} is None." 
+ ) + + if self.respect_sentence_boundary: + units = self._split_into_units(doc.content, "sentence") + text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount( + sentences=units, split_length=self.split_length, split_overlap=self.split_overlap + ) + else: + units = self._split_into_units(doc.content, self.split_by) + text_splits, splits_pages, splits_start_idxs = self._concatenate_units( + elements=units, + split_length=self.split_length, + split_overlap=self.split_overlap, + split_threshold=self.split_threshold, + ) + metadata = deepcopy(doc.meta) + metadata["source_id"] = doc.id + split_docs += self._create_docs_from_splits( + text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata + ) + return {"documents": split_docs} + + @staticmethod + def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int: + """ + Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`. + """ + # If the split_overlap is 0, we don't need to keep any sentences + if split_overlap == 0: + return 0 + + num_sentences_to_keep = 0 + num_words = 0 + for sent in reversed(sentences): + num_words += len(sent.split()) + # If the number of words is larger than the split_length then don't add any more sentences + if num_words > split_length: + break + num_sentences_to_keep += 1 + if num_words > split_overlap: + break + return num_sentences_to_keep + + def _concatenate_sentences_based_on_word_amount( + self, sentences: List[str], split_length: int, split_overlap: int + ) -> Tuple[List[str], List[int], List[int]]: + """ + Groups the sentences into chunks of `split_length` words while respecting sentence boundaries. 
+ """ + # Chunk information + chunk_word_count = 0 + chunk_starting_page_number = 1 + chunk_start_idx = 0 + current_chunk: List[str] = [] + # Output lists + split_start_page_numbers = [] + list_of_splits: List[List[str]] = [] + split_start_indices = [] + + for sentence_idx, sentence in enumerate(sentences): + current_chunk.append(sentence) + chunk_word_count += len(sentence.split()) + next_sentence_word_count = ( + len(sentences[sentence_idx + 1].split()) if sentence_idx < len(sentences) - 1 else 0 + ) + + # Number of words in the current chunk plus the next sentence is larger than the split_length + # or we reached the last sentence + if (chunk_word_count + next_sentence_word_count) > split_length or sentence_idx == len(sentences) - 1: + # Save current chunk and start a new one + list_of_splits.append(current_chunk) + split_start_page_numbers.append(chunk_starting_page_number) + split_start_indices.append(chunk_start_idx) + + # Get the number of sentences that overlap with the next chunk + num_sentences_to_keep = self._number_of_sentences_to_keep( + sentences=current_chunk, split_length=split_length, split_overlap=split_overlap + ) + # Set up information for the new chunk + if num_sentences_to_keep > 0: + # Processed sentences are the ones that are not overlapping with the next chunk + processed_sentences = current_chunk[:-num_sentences_to_keep] + chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences) + chunk_start_idx += len("".join(processed_sentences)) + # Next chunk starts with the sentences that were overlapping with the previous chunk + current_chunk = current_chunk[-num_sentences_to_keep:] + chunk_word_count = sum([len(s.split()) for s in current_chunk]) + else: + # Here processed_sentences is the same as current_chunk since there is no overlap + chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk) + chunk_start_idx += len("".join(current_chunk)) + current_chunk = [] + chunk_word_count = 0 + + # Concatenate the sentences together within each split + text_splits = [] + for split in list_of_splits: + text = "".join(split) + if len(text) > 0: + text_splits.append(text) + + return text_splits, split_start_page_numbers, split_start_indices diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py new file mode 100644 index 0000000000..3537d3b484 --- /dev/null +++ b/haystack/components/preprocessors/utils.py @@ -0,0 +1,221 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import logging +import re +from pathlib import Path +from typing import Any, Dict, List, Literal, Tuple + +from haystack.lazy_imports import LazyImport + +with LazyImport("Run 'pip install nltk'") as nltk_imports: + import nltk + +nltk_imports.check() + +logger = logging.getLogger(__name__) + +Language = Literal[ + "ru", "sl", "es", "sv", "tr", "cs", "da", "nl", "en", "et", "fi", "fr", "de", "el", "it", "no", "pl", "pt", "ml" +] +ISO639_TO_NLTK = { + "ru": "russian", + "sl": "slovene", + "es": "spanish", + "sv": "swedish", + "tr": "turkish", + "cs": "czech", + "da": "danish", + "nl": "dutch", + "en": "english", + "et": "estonian", + "fi": "finnish", + "fr": "french", + "de": "german", + "el": "greek", + "it": "italian", + "no": "norwegian", + "pl": "polish", + "pt": "portuguese", + "ml": "malayalam", +} + + +class CustomPunktLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): + # The following adjustment of PunktSentenceTokenizer is inspired by: + # 
https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer + # It is needed for preserving whitespace while splitting text into sentences. + _period_context_fmt = r""" + %(SentEndChars)s # a potential sentence ending + \s* # match potential whitespace [ \t\n\x0B\f\r] + (?=(?P + %(NonWord)s # either other punctuation + | + (?P\S+) # or some other token - original version: \s+(?P\S+) + ))""" + + def period_context_re(self) -> re.Pattern: + """ + Compiles and returns a regular expression to find contexts including possible sentence boundaries. + + :returns: A compiled regular expression pattern. + """ + try: + return self._re_period_context # type: ignore + except: # noqa: E722 + self._re_period_context = re.compile( + self._period_context_fmt + % { + "NonWord": self._re_non_word_chars, + # SentEndChars might be followed by closing brackets, so we match them here. + "SentEndChars": self._re_sent_end_chars + r"[\)\]}]*", + }, + re.UNICODE | re.VERBOSE, + ) + return self._re_period_context + + +def load_sentence_tokenizer( + language: Language, keep_white_spaces: bool = False +) -> nltk.tokenize.punkt.PunktSentenceTokenizer: + """ + Utility function to load the nltk sentence tokenizer. + + :param language: The language for the tokenizer. + :param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences. + :returns: nltk sentence tokenizer. + """ + try: + nltk.data.find("tokenizers/punkt_tab") + except LookupError: + try: + nltk.download("punkt_tab") + except FileExistsError as error: + logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: %s", error) + + language_name = ISO639_TO_NLTK.get(language) + + if language_name is not None: + sentence_tokenizer = nltk.data.load(f"tokenizers/punkt_tab/{language_name}.pickle") + else: + logger.error( + "PreProcessor couldn't find the default sentence tokenizer model for %s. " + " Using English instead. You may train your own model and use the 'tokenizer_model_folder' parameter.", + language, + ) + sentence_tokenizer = nltk.data.load("tokenizers/punkt_tab/english.pickle") + + if keep_white_spaces: + sentence_tokenizer._lang_vars = CustomPunktLanguageVars() + + return sentence_tokenizer + + +class SentenceSplitter: # pylint: disable=too-few-public-methods + """ + SentenceSplitter splits a text into sentences using the nltk sentence tokenizer + """ + + def __init__( + self, + language: Language = "en", + use_split_rules: bool = True, + extend_abbreviations: bool = True, + keep_white_spaces: bool = False, + ) -> None: + """ + Initializes the SentenceSplitter with the specified language, split rules, and abbreviation handling. + + :param language: The language for the tokenizer. Default is "en". + :param use_split_rules: If True, the additional split rules are used. If False, the rules are not used. + :param extend_abbreviations: If True, the abbreviations used by NLTK's PunktTokenizer are extended by a list + of curated abbreviations if available. If False, the default abbreviations are used. + Currently supported languages are: en, de. + :param keep_white_spaces: If True, the tokenizer will keep white spaces between sentences. 
+ """ + self.language = language + self.sentence_tokenizer = load_sentence_tokenizer(language, keep_white_spaces=keep_white_spaces) + self.use_split_rules = use_split_rules + if extend_abbreviations: + abbreviations = self._read_abbreviations(language) + self.sentence_tokenizer._params.abbrev_types.update(abbreviations) + self.keep_white_spaces = keep_white_spaces + + def split_sentences(self, text: str) -> List[Dict[str, Any]]: + """ + Splits a text into sentences including references to original char positions for each split. + + :param text: The text to split. + :returns: list of sentences with positions. + """ + sentence_spans = list(self.sentence_tokenizer.span_tokenize(text)) + if self.use_split_rules: + sentence_spans = self._apply_split_rules(text, sentence_spans) + + sentences = [{"sentence": text[start:end], "start": start, "end": end} for start, end in sentence_spans] + return sentences + + def _apply_split_rules(self, text: str, sentence_spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + new_sentence_spans = [] + quote_spans = [match.span() for match in re.finditer(r"\W(\"+|\'+).*?\1", text)] + while sentence_spans: + span = sentence_spans.pop(0) + next_span = sentence_spans[0] if len(sentence_spans) > 0 else None + while next_span and self._needs_join(text, span, next_span, quote_spans): + sentence_spans.pop(0) + span = (span[0], next_span[1]) + next_span = sentence_spans[0] if len(sentence_spans) > 0 else None + start, end = span + new_sentence_spans.append((start, end)) + return new_sentence_spans + + def _needs_join( + self, text: str, span: Tuple[int, int], next_span: Tuple[int, int], quote_spans: List[Tuple[int, int]] + ) -> bool: + """ + Checks if the spans need to be joined as parts of one sentence. + + :param text: The text containing the spans. + :param span: The current sentence span within text. + :param next_span: The next sentence span within text. + :param quote_spans: All quoted spans within text. + :returns: True if the spans needs to be joined. + """ + start, end = span + next_start, next_end = next_span + + # sentence. sentence"\nsentence -> no split (end << quote_end) + # sentence.", sentence -> no split (end < quote_end) + # sentence?", sentence -> no split (end < quote_end) + if any(quote_start < end < quote_end for quote_start, quote_end in quote_spans): + # sentence boundary is inside a quote + return True + + # sentence." sentence -> split (end == quote_end) + # sentence?" sentence -> no split (end == quote_end) + if any(quote_start < end == quote_end and text[quote_end - 2] == "?" for quote_start, quote_end in quote_spans): + # question is cited + return True + + if re.search(r"(^|\n)\s*\d{1,2}\.$", text[start:end]) is not None: + # sentence ends with a numeration + return True + + # next sentence starts with a bracket or we return False + return re.search(r"^\s*[\(\[]", text[next_start:next_end]) + + def _read_abbreviations(self, language: Language) -> List[str]: + """ + Reads the abbreviations for a given language from the abbreviations file. + + :param language: The language to read the abbreviations for. + :returns: List of abbreviations. + """ + abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{language}.txt" + if not abbreviations_file.exists(): + logger.warning("No abbreviations file found for language %s. 
Using default abbreviations.", language) + return [] + + abbreviations = abbreviations_file.read_text().split("\n") + return abbreviations diff --git a/pyproject.toml b/pyproject.toml index e5c525d2c1..b25425d129 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,8 @@ extra-dependencies = [ "python-pptx", # PPTXToDocument "python-docx", # DocxToDocument + "nltk", # NLTKDocumentSplitter + # OpenAPI "jsonref", # OpenAPIServiceConnector, OpenAPIServiceToFunctions "openapi3", diff --git a/releasenotes/notes/nltk-document-splitting-enhancement-6ef6f59bc277662c.yaml b/releasenotes/notes/nltk-document-splitting-enhancement-6ef6f59bc277662c.yaml new file mode 100644 index 0000000000..d97027b30b --- /dev/null +++ b/releasenotes/notes/nltk-document-splitting-enhancement-6ef6f59bc277662c.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + Introduced a new NLTK document splitting component, enhancing document preprocessing capabilities. This feature allows for fine-grained control over the splitting of documents into smaller parts based on configurable criteria such as word count, sentence boundaries, and page breaks. It supports multiple languages and offers options for handling sentence boundaries and abbreviations, facilitating better handling of various document types for further processing tasks. diff --git a/test/components/preprocessors/test_nltk_document_splitter.py b/test/components/preprocessors/test_nltk_document_splitter.py new file mode 100644 index 0000000000..94f9ce77df --- /dev/null +++ b/test/components/preprocessors/test_nltk_document_splitter.py @@ -0,0 +1,335 @@ +from typing import List + +import pytest +from haystack import Document +from pytest import LogCaptureFixture + +from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter + + +def test_init_warning_message(caplog: LogCaptureFixture) -> None: + _ = NLTKDocumentSplitter(split_by="page", respect_sentence_boundary=True) + assert "The 'respect_sentence_boundary' option is only supported for" in caplog.text + + +class TestDocumentSplitterSplitIntoUnits: + def test_document_splitter_split_into_units_word(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en" + ) + + text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything." + units = document_splitter._split_into_units(text=text, split_by="word") + + assert units == [ + "Moonlight ", + "shimmered ", + "softly, ", + "wolves ", + "howled ", + "nearby, ", + "night ", + "enveloped ", + "everything.", + ] + + def test_document_splitter_split_into_units_sentence(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="sentence", split_length=2, split_overlap=0, split_threshold=0, language="en" + ) + + text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night." + units = document_splitter._split_into_units(text=text, split_by="sentence") + + assert units == [ + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. ", + "It was a dark night.", + ] + + def test_document_splitter_split_into_units_passage(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="passage", split_length=2, split_overlap=0, split_threshold=0, language="en" + ) + + text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\n\nIt was a dark night." 
+ units = document_splitter._split_into_units(text=text, split_by="passage") + + assert units == [ + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\n\n", + "It was a dark night.", + ] + + def test_document_splitter_split_into_units_page(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="page", split_length=2, split_overlap=0, split_threshold=0, language="en" + ) + + text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\fIt was a dark night." + units = document_splitter._split_into_units(text=text, split_by="page") + + assert units == [ + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.\f", + "It was a dark night.", + ] + + def test_document_splitter_split_into_units_raise_error(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en" + ) + + text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything." + + with pytest.raises(NotImplementedError): + document_splitter._split_into_units(text=text, split_by="invalid") # type: ignore + + +class TestDocumentSplitterNumberOfSentencesToKeep: + @pytest.mark.parametrize( + "sentences, expected_num_sentences", + [ + (["Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0), + ([" It was a dark night ..."], 0), + ([" The moon was full."], 1), + ], + ) + def test_number_of_sentences_to_keep(self, sentences: List[str], expected_num_sentences: int) -> None: + num_sentences = NLTKDocumentSplitter._number_of_sentences_to_keep( + sentences=sentences, split_length=5, split_overlap=2 + ) + assert num_sentences == expected_num_sentences + + def test_number_of_sentences_to_keep_split_overlap_zero(self) -> None: + sentences = [ + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything.", + " It was a dark night ...", + " The moon was full.", + ] + num_sentences = NLTKDocumentSplitter._number_of_sentences_to_keep( + sentences=sentences, split_length=5, split_overlap=0 + ) + assert num_sentences == 0 + + +class TestDocumentSplitterRun: + def test_run_type_error(self) -> None: + document_splitter = NLTKDocumentSplitter() + with pytest.raises(TypeError): + document_splitter.run(documents=Document(content="Moonlight shimmered softly.")) # type: ignore + + def test_run_value_error(self) -> None: + document_splitter = NLTKDocumentSplitter() + with pytest.raises(ValueError): + document_splitter.run(documents=[Document(content=None)]) + + def test_run_split_by_sentence_1(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="sentence", + split_length=2, + split_overlap=0, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + ) + + text = ( + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night ... " + "The moon was full." + ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 2 + assert ( + documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped " + "everything. It was a dark night ... " + ) + assert documents[1].content == "The moon was full." 
+ + def test_run_split_by_sentence_2(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="sentence", + split_length=1, + split_overlap=0, + split_threshold=0, + language="en", + use_split_rules=False, + extend_abbreviations=True, + ) + + text = ( + "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + "This is another test sentence. (This is a third test sentence.) " + "This is the last test sentence." + ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 4 + assert ( + documents[0].content + == "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + ) + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "This is another test sentence. " + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "(This is a third test sentence.) " + assert documents[2].meta["page_number"] == 1 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + assert documents[3].content == "This is the last test sentence." + assert documents[3].meta["page_number"] == 1 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + def test_run_split_by_sentence_3(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="sentence", + split_length=1, + split_overlap=0, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + ) + + text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 4 + assert documents[0].content == "Sentence on page 1.\f" + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "Sentence on page 2. \f" + assert documents[1].meta["page_number"] == 2 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "Sentence on page 3. \f\f " + assert documents[2].meta["page_number"] == 3 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + assert documents[3].content == "Sentence on page 5." + assert documents[3].meta["page_number"] == 5 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + def test_run_split_by_sentence_4(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="sentence", + split_length=2, + split_overlap=1, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + ) + + text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." 
+ documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 3 + assert documents[0].content == "Sentence on page 1.\fSentence on page 2. \f" + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "Sentence on page 2. \fSentence on page 3. \f\f " + assert documents[1].meta["page_number"] == 2 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "Sentence on page 3. \f\f Sentence on page 5." + assert documents[2].meta["page_number"] == 3 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + + +class TestDocumentSplitterRespectSentenceBoundary: + def test_run_split_by_word_respect_sentence_boundary(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="word", + split_length=3, + split_overlap=0, + split_threshold=0, + language="en", + respect_sentence_boundary=True, + ) + + text = ( + "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night.\f" + "The moon was full." + ) + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 3 + assert documents[0].content == "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. " + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "It was a dark night.\f" + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "The moon was full." + assert documents[2].meta["page_number"] == 2 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + + def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="word", + split_length=13, + split_overlap=3, + split_threshold=0, + language="en", + respect_sentence_boundary=True, + use_split_rules=False, + extend_abbreviations=False, + ) + text = ( + "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + "This is another test sentence. (This is a third test sentence.) " + "This is the last test sentence." + ) + documents = document_splitter.run([Document(content=text)])["documents"] + assert len(documents) == 3 + assert ( + documents[0].content + == "This is a test sentence with many many words that exceeds the split length and should not be repeated. " + ) + assert "This is a test sentence with many many words" not in documents[1].content + assert "This is a test sentence with many many words" not in documents[2].content + + def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page_breaks(self) -> None: + document_splitter = NLTKDocumentSplitter( + split_by="word", + split_length=5, + split_overlap=1, + split_threshold=0, + language="en", + use_split_rules=True, + extend_abbreviations=True, + respect_sentence_boundary=True, + ) + + text = "Sentence on page 1.\fSentence on page 2. 
\fSentence on page 3. \f\f Sentence on page 5." + documents = document_splitter.run(documents=[Document(content=text)])["documents"] + + assert len(documents) == 4 + assert documents[0].content == "Sentence on page 1.\f" + assert documents[0].meta["page_number"] == 1 + assert documents[0].meta["split_id"] == 0 + assert documents[0].meta["split_idx_start"] == text.index(documents[0].content) + assert documents[1].content == "Sentence on page 1.\fSentence on page 2. \f" + assert documents[1].meta["page_number"] == 1 + assert documents[1].meta["split_id"] == 1 + assert documents[1].meta["split_idx_start"] == text.index(documents[1].content) + assert documents[2].content == "Sentence on page 2. \fSentence on page 3. \f\f " + assert documents[2].meta["page_number"] == 2 + assert documents[2].meta["split_id"] == 2 + assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) + assert documents[3].content == "Sentence on page 3. \f\f Sentence on page 5." + assert documents[3].meta["page_number"] == 3 + assert documents[3].meta["split_id"] == 3 + assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) From 4cbbcdfea5f43dab94ed8400fcbac05d39cb742f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Sep 2024 11:38:23 +0200 Subject: [PATCH 02/11] Improve pydocs --- .../preprocessors/nltk_document_splitter.py | 18 ++++++++++++++++++ haystack/components/preprocessors/utils.py | 7 +++++++ 2 files changed, 25 insertions(+) diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py index 985850fefd..9501d333cd 100644 --- a/haystack/components/preprocessors/nltk_document_splitter.py +++ b/haystack/components/preprocessors/nltk_document_splitter.py @@ -70,6 +70,14 @@ def __init__( self.language = language def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage", "page"]) -> List[str]: + """ + Splits the text into units based on the specified split_by parameter. + + :param text: The text to split. + :param split_by: The unit to split the text by. Choose from "word", "sentence", "passage", or "page". + :returns: A list of units. + """ + if split_by == "page": self.split_at = "\f" units = text.split(self.split_at) @@ -148,6 +156,11 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]: def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int: """ Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`. + + :param sentences: The list of sentences to split. + :param split_length: The maximum number of words in each split. + :param split_overlap: The number of overlapping words in each split. + :returns: The number of sentences to keep in the next chunk. """ # If the split_overlap is 0, we don't need to keep any sentences if split_overlap == 0: @@ -170,6 +183,11 @@ def _concatenate_sentences_based_on_word_amount( ) -> Tuple[List[str], List[int], List[int]]: """ Groups the sentences into chunks of `split_length` words while respecting sentence boundaries. + + :param sentences: The list of sentences to split. + :param split_length: The maximum number of words in each split. + :param split_overlap: The number of overlapping words in each split. + :returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices. 
""" # Chunk information chunk_word_count = 0 diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index 3537d3b484..0c3f0098fa 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -157,6 +157,13 @@ def split_sentences(self, text: str) -> List[Dict[str, Any]]: return sentences def _apply_split_rules(self, text: str, sentence_spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """ + Applies additional split rules to the sentence spans. + + :param text: The text to split. + :param sentence_spans: The list of sentence spans to split. + :returns: The list of sentence spans after applying the split rules. + """ new_sentence_spans = [] quote_spans = [match.span() for match in re.finditer(r"\W(\"+|\'+).*?\1", text)] while sentence_spans: From 4e5d7bfdea9e9c8297b78f3de38dd171a36fa15e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Sep 2024 11:41:16 +0200 Subject: [PATCH 03/11] Use haystack logging --- haystack/components/preprocessors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index 0c3f0098fa..e267c38842 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -2,11 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 -import logging import re from pathlib import Path from typing import Any, Dict, List, Literal, Tuple +from haystack import logging from haystack.lazy_imports import LazyImport with LazyImport("Run 'pip install nltk'") as nltk_imports: From 4f7b26aaf77df407c263ef5e1e0516a27e745dfa Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Sep 2024 11:48:23 +0200 Subject: [PATCH 04/11] Add NLTKDocumentSplitter to __init__.py --- haystack/components/preprocessors/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/haystack/components/preprocessors/__init__.py b/haystack/components/preprocessors/__init__.py index d39151f3c7..f7e132077a 100644 --- a/haystack/components/preprocessors/__init__.py +++ b/haystack/components/preprocessors/__init__.py @@ -4,6 +4,7 @@ from .document_cleaner import DocumentCleaner from .document_splitter import DocumentSplitter +from .nltk_document_splitter import NLTKDocumentSplitter from .text_cleaner import TextCleaner -__all__ = ["DocumentSplitter", "DocumentCleaner", "TextCleaner"] +__all__ = ["DocumentSplitter", "DocumentCleaner", "TextCleaner", "NLTKDocumentSplitter"] From d37202eeb3dd5feca070c2b2582aa523be2727d0 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Sep 2024 13:48:35 +0200 Subject: [PATCH 05/11] Use haystack logging, rename test classes --- haystack/components/preprocessors/utils.py | 10 +++++----- .../preprocessors/test_nltk_document_splitter.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index e267c38842..c5166606a9 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -92,17 +92,17 @@ def load_sentence_tokenizer( try: nltk.download("punkt_tab") except FileExistsError as error: - logger.debug("NLTK punkt tokenizer seems to be already downloaded. Error message: %s", error) + logger.debug("NLTK punkt tokenizer seems to be already downloaded. 
Error message: {error}", error=error) language_name = ISO639_TO_NLTK.get(language) if language_name is not None: sentence_tokenizer = nltk.data.load(f"tokenizers/punkt_tab/{language_name}.pickle") else: - logger.error( - "PreProcessor couldn't find the default sentence tokenizer model for %s. " + logger.warning( + "PreProcessor couldn't find the default sentence tokenizer model for {language}. " " Using English instead. You may train your own model and use the 'tokenizer_model_folder' parameter.", - language, + language=language, ) sentence_tokenizer = nltk.data.load("tokenizers/punkt_tab/english.pickle") @@ -221,7 +221,7 @@ def _read_abbreviations(self, language: Language) -> List[str]: """ abbreviations_file = Path(__file__).parent.parent / f"data/abbreviations/{language}.txt" if not abbreviations_file.exists(): - logger.warning("No abbreviations file found for language %s. Using default abbreviations.", language) + logger.warning("No abbreviations file found for {language}.Using default abbreviations.", language=language) return [] abbreviations = abbreviations_file.read_text().split("\n") diff --git a/test/components/preprocessors/test_nltk_document_splitter.py b/test/components/preprocessors/test_nltk_document_splitter.py index 94f9ce77df..1140e57a5f 100644 --- a/test/components/preprocessors/test_nltk_document_splitter.py +++ b/test/components/preprocessors/test_nltk_document_splitter.py @@ -12,7 +12,7 @@ def test_init_warning_message(caplog: LogCaptureFixture) -> None: assert "The 'respect_sentence_boundary' option is only supported for" in caplog.text -class TestDocumentSplitterSplitIntoUnits: +class TestNLTKDocumentSplitterSplitIntoUnits: def test_document_splitter_split_into_units_word(self) -> None: document_splitter = NLTKDocumentSplitter( split_by="word", split_length=3, split_overlap=0, split_threshold=0, language="en" @@ -83,7 +83,7 @@ def test_document_splitter_split_into_units_raise_error(self) -> None: document_splitter._split_into_units(text=text, split_by="invalid") # type: ignore -class TestDocumentSplitterNumberOfSentencesToKeep: +class TestNLTKDocumentSplitterNumberOfSentencesToKeep: @pytest.mark.parametrize( "sentences, expected_num_sentences", [ @@ -110,7 +110,7 @@ def test_number_of_sentences_to_keep_split_overlap_zero(self) -> None: assert num_sentences == 0 -class TestDocumentSplitterRun: +class TestNLTKDocumentSplitterRun: def test_run_type_error(self) -> None: document_splitter = NLTKDocumentSplitter() with pytest.raises(TypeError): @@ -245,7 +245,7 @@ def test_run_split_by_sentence_4(self) -> None: assert documents[2].meta["split_idx_start"] == text.index(documents[2].content) -class TestDocumentSplitterRespectSentenceBoundary: +class TestNLTKDocumentSplitterRespectSentenceBoundary: def test_run_split_by_word_respect_sentence_boundary(self) -> None: document_splitter = NLTKDocumentSplitter( split_by="word", From db073922b8bc2cfecc57ad7dbc0fcf16d542d37a Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Sep 2024 14:29:33 +0200 Subject: [PATCH 06/11] Fixing _needs_join return --- haystack/components/preprocessors/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index c5166606a9..443b3a3e5a 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -210,7 +210,7 @@ def _needs_join( return True # next sentence starts with a bracket or we return False - return re.search(r"^\s*[\(\[]", 
text[next_start:next_end]) + return re.search(r"^\s*[\(\[]", text[next_start:next_end]) is not None def _read_abbreviations(self, language: Language) -> List[str]: """ From a113d56a7e0a25f682b17cd1c4128502f61bed28 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 10 Sep 2024 14:50:44 +0200 Subject: [PATCH 07/11] Linting --- .../components/preprocessors/nltk_document_splitter.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py index 9501d333cd..b11ebd0c71 100644 --- a/haystack/components/preprocessors/nltk_document_splitter.py +++ b/haystack/components/preprocessors/nltk_document_splitter.py @@ -8,13 +8,9 @@ from haystack import Document, component, logging from haystack.components.preprocessors.document_splitter import DocumentSplitter from haystack.components.preprocessors.utils import Language, SentenceSplitter -from haystack.lazy_imports import LazyImport logger = logging.getLogger(__name__) -with LazyImport("Run 'pip install nltk'") as nltk_imports: - import nltk - @component class NLTKDocumentSplitter(DocumentSplitter): @@ -52,7 +48,6 @@ def __init__( super(NLTKDocumentSplitter, self).__init__( split_by=split_by, split_length=split_length, split_overlap=split_overlap, split_threshold=split_threshold ) - nltk_imports.check() if respect_sentence_boundary and split_by != "word": logger.warning( @@ -226,7 +221,7 @@ def _concatenate_sentences_based_on_word_amount( chunk_start_idx += len("".join(processed_sentences)) # Next chunk starts with the sentences that were overlapping with the previous chunk current_chunk = current_chunk[-num_sentences_to_keep:] - chunk_word_count = sum([len(s.split()) for s in current_chunk]) + chunk_word_count = sum(len(s.split()) for s in current_chunk) else: # Here processed_sentences is the same as current_chunk since there is no overlap chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk) From df7ac6ba5849ec79d6f17b226dc1ca7fe93b8dc9 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 11 Sep 2024 16:44:17 +0200 Subject: [PATCH 08/11] PR feedback --- haystack/components/preprocessors/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index 443b3a3e5a..6313ec7674 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -138,7 +138,7 @@ def __init__( self.sentence_tokenizer = load_sentence_tokenizer(language, keep_white_spaces=keep_white_spaces) self.use_split_rules = use_split_rules if extend_abbreviations: - abbreviations = self._read_abbreviations(language) + abbreviations = SentenceSplitter._read_abbreviations(language) self.sentence_tokenizer._params.abbrev_types.update(abbreviations) self.keep_white_spaces = keep_white_spaces @@ -169,7 +169,7 @@ def _apply_split_rules(self, text: str, sentence_spans: List[Tuple[int, int]]) - while sentence_spans: span = sentence_spans.pop(0) next_span = sentence_spans[0] if len(sentence_spans) > 0 else None - while next_span and self._needs_join(text, span, next_span, quote_spans): + while next_span and SentenceSplitter._needs_join(text, span, next_span, quote_spans): sentence_spans.pop(0) span = (span[0], next_span[1]) next_span = sentence_spans[0] if len(sentence_spans) > 0 else None @@ -177,8 +177,9 @@ def _apply_split_rules(self, text: str, sentence_spans: 
List[Tuple[int, int]]) - new_sentence_spans.append((start, end)) return new_sentence_spans + @staticmethod def _needs_join( - self, text: str, span: Tuple[int, int], next_span: Tuple[int, int], quote_spans: List[Tuple[int, int]] + text: str, span: Tuple[int, int], next_span: Tuple[int, int], quote_spans: List[Tuple[int, int]] ) -> bool: """ Checks if the spans need to be joined as parts of one sentence. @@ -212,7 +213,8 @@ def _needs_join( # next sentence starts with a bracket or we return False return re.search(r"^\s*[\(\[]", text[next_start:next_end]) is not None - def _read_abbreviations(self, language: Language) -> List[str]: + @staticmethod + def _read_abbreviations(language: Language) -> List[str]: """ Reads the abbreviations for a given language from the abbreviations file. From 696b269578a5db2af621a3805531e3a2e161f975 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 09:22:02 +0200 Subject: [PATCH 09/11] More static methods --- haystack/components/preprocessors/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index 6313ec7674..08968853ea 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -151,12 +151,13 @@ def split_sentences(self, text: str) -> List[Dict[str, Any]]: """ sentence_spans = list(self.sentence_tokenizer.span_tokenize(text)) if self.use_split_rules: - sentence_spans = self._apply_split_rules(text, sentence_spans) + sentence_spans = SentenceSplitter._apply_split_rules(text, sentence_spans) sentences = [{"sentence": text[start:end], "start": start, "end": end} for start, end in sentence_spans] return sentences - def _apply_split_rules(self, text: str, sentence_spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + @staticmethod + def _apply_split_rules(text: str, sentence_spans: List[Tuple[int, int]]) -> List[Tuple[int, int]]: """ Applies additional split rules to the sentence spans. From 5b86408494f8b3dae95a9e51d2295b241f5f6ef4 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 11:09:54 +0200 Subject: [PATCH 10/11] Increase test coverage --- .../test_nltk_document_splitter.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/components/preprocessors/test_nltk_document_splitter.py b/test/components/preprocessors/test_nltk_document_splitter.py index 1140e57a5f..6614c82c3b 100644 --- a/test/components/preprocessors/test_nltk_document_splitter.py +++ b/test/components/preprocessors/test_nltk_document_splitter.py @@ -5,6 +5,7 @@ from pytest import LogCaptureFixture from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter +from haystack.components.preprocessors.utils import SentenceSplitter def test_init_warning_message(caplog: LogCaptureFixture) -> None: @@ -333,3 +334,30 @@ def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page assert documents[3].meta["page_number"] == 3 assert documents[3].meta["split_id"] == 3 assert documents[3].meta["split_idx_start"] == text.index(documents[3].content) + + +class TestSentenceSplitter: + def test_apply_split_rules_second_while_loop(self) -> None: + text = "This is a test. (With a parenthetical statement.) And another sentence." 
+ spans = [(0, 15), (16, 50), (51, 74)] + result = SentenceSplitter._apply_split_rules(text, spans) + assert len(result) == 2 + assert result == [(0, 50), (51, 74)] + + def test_apply_split_rules_no_join(self) -> None: + text = "This is a test. This is another test. And a third test." + spans = [(0, 15), (16, 36), (37, 54)] + result = SentenceSplitter._apply_split_rules(text, spans) + assert len(result) == 3 + assert result == [(0, 15), (16, 36), (37, 54)] + + @pytest.mark.parametrize( + "text,span,next_span,quote_spans,expected", + [ + # triggers sentence boundary is inside a quote + ('He said, "Hello World." Then left.', (0, 15), (16, 23), [(9, 23)], True) + ], + ) + def test_needs_join_cases(self, text, span, next_span, quote_spans, expected): + result = SentenceSplitter._needs_join(text, span, next_span, quote_spans) + assert result == expected, f"Expected {expected} for input: {text}, {span}, {next_span}, {quote_spans}" From efd329eac52923655686319d71bee21e51b4261f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Sep 2024 11:13:50 +0200 Subject: [PATCH 11/11] Compile pattern --- haystack/components/preprocessors/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/haystack/components/preprocessors/utils.py b/haystack/components/preprocessors/utils.py index 08968853ea..ba4d89585b 100644 --- a/haystack/components/preprocessors/utils.py +++ b/haystack/components/preprocessors/utils.py @@ -41,6 +41,8 @@ "ml": "malayalam", } +QUOTE_SPANS_RE = re.compile(r"\W(\"+|\'+).*?\1") + class CustomPunktLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): # The following adjustment of PunktSentenceTokenizer is inspired by: @@ -166,7 +168,7 @@ def _apply_split_rules(text: str, sentence_spans: List[Tuple[int, int]]) -> List :returns: The list of sentence spans after applying the split rules. """ new_sentence_spans = [] - quote_spans = [match.span() for match in re.finditer(r"\W(\"+|\'+).*?\1", text)] + quote_spans = [match.span() for match in QUOTE_SPANS_RE.finditer(text)] while sentence_spans: span = sentence_spans.pop(0) next_span = sentence_spans[0] if len(sentence_spans) > 0 else None
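
The following is a minimal usage sketch of the component introduced by this series, assuming `nltk` is installed (`pip install nltk`); the import paths, parameters, and metadata fields are the ones added in the patches above, and the sample text is illustrative only.

from haystack import Document
from haystack.components.preprocessors import NLTKDocumentSplitter

# Split by word count, but only at sentence boundaries detected by NLTK.
splitter = NLTKDocumentSplitter(
    split_by="word",
    split_length=200,
    split_overlap=20,
    respect_sentence_boundary=True,
    language="en",
    use_split_rules=True,
    extend_abbreviations=True,
)

doc = Document(content="Sentence on page 1.\fSentence on page 2. It was a dark night.")
result = splitter.run(documents=[doc])

for split in result["documents"]:
    # Each split carries provenance metadata: source_id, page_number, split_id, split_idx_start.
    print(split.meta["source_id"], split.meta["page_number"], split.meta["split_id"], repr(split.content))

The lower-level SentenceSplitter utility exercised by the tests can also be used on its own; a sketch under the same assumptions:

from haystack.components.preprocessors.utils import SentenceSplitter

# split_sentences returns one dict per sentence, keeping the original character offsets.
sentence_splitter = SentenceSplitter(
    language="en", use_split_rules=True, extend_abbreviations=True, keep_white_spaces=True
)
for span in sentence_splitter.split_sentences("Moonlight shimmered softly. It was a dark night."):
    print(span["start"], span["end"], repr(span["sentence"]))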