Skip to content

Commit

Permalink
feat: Add TextDocumentSplitter that splits by word, sentence, passage…
Browse files Browse the repository at this point in the history
… (2.0) (#5870)

* draft split by word, sentence, passage

* naive way to split sentences without nltk

* reno

* add tests

* make input list of docs, review feedback

* add source_id and more validation

* update docstrings

* add split delimiters back to strings

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
  • Loading branch information
julian-risch and dfokina authored Sep 27, 2023
1 parent 6665e8e commit 4413675
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 0 deletions.
3 changes: 3 additions & 0 deletions haystack/preview/components/preprocessors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter

__all__ = ["TextDocumentSplitter"]
104 changes: 104 additions & 0 deletions haystack/preview/components/preprocessors/text_document_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from copy import deepcopy
from typing import List, Dict, Any, Literal

from more_itertools import windowed

from haystack.preview import component, Document, default_from_dict, default_to_dict


@component
class TextDocumentSplitter:
    """
    Splits a list of text documents into a list of text documents with shorter texts.
    This is useful for splitting documents with long texts that otherwise would not fit into the maximum text length of language models.
    """

    def __init__(
        self, split_by: Literal["word", "sentence", "passage"] = "word", split_length: int = 200, split_overlap: int = 0
    ):
        """
        :param split_by: The unit by which the document should be split. Choose from "word" for splitting by " ",
            "sentence" for splitting by ".", or "passage" for splitting by "\n\n".
        :param split_length: The maximum number of units in each split.
        :param split_overlap: The number of units that each split should overlap.
        :raises ValueError: If any parameter is outside its valid range.
        """
        # Validate everything before assigning, so a failed constructor never leaves a
        # half-initialized instance behind (the original assigned split_by first).
        if split_by not in ["word", "sentence", "passage"]:
            raise ValueError("split_by must be one of 'word', 'sentence' or 'passage'.")
        if split_length <= 0:
            raise ValueError("split_length must be greater than 0.")
        if split_overlap < 0:
            raise ValueError("split_overlap must be greater than or equal to 0.")
        if split_overlap >= split_length:
            # windowed() requires a positive step (= split_length - split_overlap); reject this
            # here with a clear message instead of a cryptic error at run() time.
            raise ValueError("split_overlap must be less than split_length.")
        self.split_by = split_by
        self.split_length = split_length
        self.split_overlap = split_overlap

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Splits the documents by split_by after split_length units with an overlap of split_overlap units.
        Returns a list of documents with the split texts.
        A metadata field "source_id" is added to each document to keep track of the original document that was split.
        :param documents: The documents to split.
        :return: A list of documents with the split texts.
        :raises TypeError: If the input is not a non-empty list of Documents.
        :raises ValueError: If a document's text is None.
        """
        # Check every element, not just the first one, so a mixed list fails fast with a
        # clear error instead of blowing up mid-loop on a non-Document item.
        if not isinstance(documents, list) or not documents or not all(isinstance(doc, Document) for doc in documents):
            raise TypeError("TextDocumentSplitter expects a List of Documents as input.")
        split_docs = []
        for doc in documents:
            if doc.text is None:
                raise ValueError(
                    f"TextDocumentSplitter only works with text documents but document.text for document ID {doc.id} is None."
                )
            units = self._split_into_units(doc.text, self.split_by)
            text_splits = self._concatenate_units(units, self.split_length, self.split_overlap)
            for txt in text_splits:
                # Copy per split (not once per source document) so the output Documents do not
                # share a single metadata dict: mutating one split's metadata must not leak
                # into its siblings.
                metadata = deepcopy(doc.metadata)
                metadata["source_id"] = doc.id
                split_docs.append(Document(text=txt, metadata=metadata))
        return {"documents": split_docs}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(
            self, split_by=self.split_by, split_length=self.split_length, split_overlap=self.split_overlap
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TextDocumentSplitter":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage"]) -> List[str]:
        """
        Split `text` into units on the delimiter implied by `split_by`, keeping the delimiter
        attached to the end of each unit (except the last) so the splits can be joined back
        into the original text.
        """
        if split_by == "passage":
            split_at = "\n\n"
        elif split_by == "sentence":
            split_at = "."
        elif split_by == "word":
            split_at = " "
        else:
            # Defensive: unreachable through __init__, which already validates split_by.
            raise NotImplementedError(
                "TextDocumentSplitter only supports 'passage', 'sentence' or 'word' split_by options."
            )
        units = text.split(split_at)
        # Add the delimiter back to all units except the last one
        for i in range(len(units) - 1):
            units[i] += split_at
        return units

    def _concatenate_units(self, elements: List[str], split_length: int, split_overlap: int) -> List[str]:
        """
        Concatenates the elements into parts of split_length units, with consecutive parts
        overlapping by split_overlap units. The final window may be padded with None by
        windowed(); those fill values are dropped, and empty splits are skipped.
        """
        text_splits = []
        segments = windowed(elements, n=split_length, step=split_length - split_overlap)
        for seg in segments:
            current_units = [unit for unit in seg if unit is not None]
            txt = "".join(current_units)
            if len(txt) > 0:
                text_splits.append(txt)
        return text_splits
4 changes: 4 additions & 0 deletions releasenotes/notes/preprocessor-2-0-9828d930562fa3f5.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
preview:
- |
Add the `TextDocumentSplitter` component for Haystack 2.0 that splits a Document with long text into multiple Documents with shorter texts. Thereby the texts match the maximum length that the language models in Embedders or other components can process.
Empty file.
157 changes: 157 additions & 0 deletions test/preview/components/preprocessors/test_text_document_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import pytest

from haystack.preview import Document
from haystack.preview.components.preprocessors import TextDocumentSplitter


class TestTextDocumentSplitter:
    """
    Unit tests for TextDocumentSplitter: input validation, the three split_by modes,
    overlap handling, (de)serialization, and source_id metadata bookkeeping.
    """

    @pytest.mark.unit
    def test_non_text_document(self):
        # A Document whose text is None must be rejected with a ValueError naming the document ID.
        with pytest.raises(
            ValueError, match="TextDocumentSplitter only works with text documents but document.text for document ID"
        ):
            splitter = TextDocumentSplitter()
            splitter.run(documents=[Document()])

    @pytest.mark.unit
    def test_single_doc(self):
        # Passing a bare Document (not wrapped in a list) is a TypeError.
        with pytest.raises(TypeError, match="TextDocumentSplitter expects a List of Documents as input."):
            splitter = TextDocumentSplitter()
            splitter.run(documents=Document())

    @pytest.mark.unit
    def test_empty_list(self):
        # An empty list is rejected rather than silently returning no documents.
        with pytest.raises(TypeError, match="TextDocumentSplitter expects a List of Documents as input."):
            splitter = TextDocumentSplitter()
            splitter.run(documents=[])

    @pytest.mark.unit
    def test_unsupported_split_by(self):
        # Constructor validation: split_by outside the supported literals.
        with pytest.raises(ValueError, match="split_by must be one of 'word', 'sentence' or 'passage'."):
            TextDocumentSplitter(split_by="unsupported")

    @pytest.mark.unit
    def test_unsupported_split_length(self):
        # Constructor validation: split_length must be positive.
        with pytest.raises(ValueError, match="split_length must be greater than 0."):
            TextDocumentSplitter(split_length=0)

    @pytest.mark.unit
    def test_unsupported_split_overlap(self):
        # Constructor validation: split_overlap must be non-negative.
        with pytest.raises(ValueError, match="split_overlap must be greater than or equal to 0."):
            TextDocumentSplitter(split_overlap=-1)

    @pytest.mark.unit
    def test_split_by_word(self):
        # 19 words split at length 10 -> two documents; the " " delimiter stays attached,
        # hence the trailing space on the first split.
        splitter = TextDocumentSplitter(split_by="word", split_length=10)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 2
        assert result["documents"][0].text == "This is a text with some words. There is a "
        assert result["documents"][1].text == "second sentence. And there is a third sentence."

    @pytest.mark.unit
    def test_split_by_word_multiple_input_docs(self):
        # Each input document is split independently; splits are concatenated in input order.
        splitter = TextDocumentSplitter(split_by="word", split_length=10)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                ),
                Document(
                    text="This is a different text with some words. There is a second sentence. And there is a third sentence. And there is a fourth sentence."
                ),
            ]
        )
        assert len(result["documents"]) == 5
        assert result["documents"][0].text == "This is a text with some words. There is a "
        assert result["documents"][1].text == "second sentence. And there is a third sentence."
        assert result["documents"][2].text == "This is a different text with some words. There is "
        assert result["documents"][3].text == "a second sentence. And there is a third sentence. And "
        assert result["documents"][4].text == "there is a fourth sentence."

    @pytest.mark.unit
    def test_split_by_sentence(self):
        # Naive sentence splitting on "."; the delimiter stays with the preceding sentence,
        # so following splits keep their leading space.
        splitter = TextDocumentSplitter(split_by="sentence", split_length=1)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 3
        assert result["documents"][0].text == "This is a text with some words."
        assert result["documents"][1].text == " There is a second sentence."
        assert result["documents"][2].text == " And there is a third sentence."

    @pytest.mark.unit
    def test_split_by_passage(self):
        # Passages split on "\n\n"; the delimiter stays attached to the preceding passage.
        splitter = TextDocumentSplitter(split_by="passage", split_length=1)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n And another passage."
                )
            ]
        )
        assert len(result["documents"]) == 3
        assert result["documents"][0].text == "This is a text with some words. There is a second sentence.\n\n"
        assert result["documents"][1].text == "And there is a third sentence.\n\n"
        assert result["documents"][2].text == " And another passage."

    @pytest.mark.unit
    def test_split_by_word_with_overlap(self):
        # With overlap=2 the last two words of the first split ("is a ") reappear at the
        # start of the second split.
        splitter = TextDocumentSplitter(split_by="word", split_length=10, split_overlap=2)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 2
        assert result["documents"][0].text == "This is a text with some words. There is a "
        assert result["documents"][1].text == "is a second sentence. And there is a third sentence."

    @pytest.mark.unit
    def test_to_dict(self):
        # Default init parameters round-trip into the serialized form.
        splitter = TextDocumentSplitter()
        data = splitter.to_dict()
        assert data == {
            "type": "TextDocumentSplitter",
            "init_parameters": {"split_by": "word", "split_length": 200, "split_overlap": 0},
        }

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        # Non-default init parameters are preserved in the serialized form.
        splitter = TextDocumentSplitter(split_by="passage", split_length=100, split_overlap=1)
        data = splitter.to_dict()
        assert data == {
            "type": "TextDocumentSplitter",
            "init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
        }

    @pytest.mark.unit
    def test_from_dict(self):
        # Deserialization restores all three init parameters.
        data = {
            "type": "TextDocumentSplitter",
            "init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
        }
        splitter = TextDocumentSplitter.from_dict(data)
        assert splitter.split_by == "passage"
        assert splitter.split_length == 100
        assert splitter.split_overlap == 1

    @pytest.mark.unit
    def test_source_id_stored_in_metadata(self):
        # Every split carries the originating document's id in metadata["source_id"].
        splitter = TextDocumentSplitter(split_by="word", split_length=10)
        doc1 = Document(text="This is a text with some words.")
        doc2 = Document(text="This is a different text with some words.")
        result = splitter.run(documents=[doc1, doc2])
        assert result["documents"][0].metadata["source_id"] == doc1.id
        assert result["documents"][1].metadata["source_id"] == doc2.id

0 comments on commit 4413675

Please sign in to comment.