Commit
feat: Add TextDocumentSplitter that splits by word, sentence, passage (2.0) (#5870)

* draft split by word, sentence, passage
* naive way to split sentences without nltk
* reno
* add tests
* make input list of docs, review feedback
* add source_id and more validation
* update docstrings
* add split delimiters back to strings

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
1 parent 6665e8e, commit 4413675. Showing 5 changed files with 268 additions and 0 deletions.
@@ -0,0 +1,3 @@
from haystack.preview.components.preprocessors.text_document_splitter import TextDocumentSplitter

__all__ = ["TextDocumentSplitter"]
haystack/preview/components/preprocessors/text_document_splitter.py
104 additions & 0 deletions
@@ -0,0 +1,104 @@
from copy import deepcopy
from typing import List, Dict, Any, Literal

from more_itertools import windowed

from haystack.preview import component, Document, default_from_dict, default_to_dict


@component
class TextDocumentSplitter:
    """
    Splits a list of text documents into a list of text documents with shorter texts.
    This is useful for splitting documents with long texts that otherwise would not fit into the maximum text length of language models.
    """

    def __init__(
        self, split_by: Literal["word", "sentence", "passage"] = "word", split_length: int = 200, split_overlap: int = 0
    ):
        """
        :param split_by: The unit by which the document should be split. Choose from "word" for splitting by " ",
            "sentence" for splitting by ".", or "passage" for splitting by "\n\n".
        :param split_length: The maximum number of units in each split.
        :param split_overlap: The number of units that each split should overlap.
        """

        self.split_by = split_by
        if split_by not in ["word", "sentence", "passage"]:
            raise ValueError("split_by must be one of 'word', 'sentence' or 'passage'.")
        if split_length <= 0:
            raise ValueError("split_length must be greater than 0.")
        self.split_length = split_length
        if split_overlap < 0:
            raise ValueError("split_overlap must be greater than or equal to 0.")
        self.split_overlap = split_overlap

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Splits the documents by split_by after split_length units with an overlap of split_overlap units.
        Returns a list of documents with the split texts.
        A metadata field "source_id" is added to each document to keep track of the original document that was split.
        :param documents: The documents to split.
        :return: A list of documents with the split texts.
        """

        if not documents or not isinstance(documents, list) or not isinstance(documents[0], Document):
            raise TypeError("TextDocumentSplitter expects a List of Documents as input.")
        split_docs = []
        for doc in documents:
            if doc.text is None:
                raise ValueError(
                    f"TextDocumentSplitter only works with text documents but document.text for document ID {doc.id} is None."
                )
            units = self._split_into_units(doc.text, self.split_by)
            text_splits = self._concatenate_units(units, self.split_length, self.split_overlap)
            metadata = deepcopy(doc.metadata)
            metadata["source_id"] = doc.id
            split_docs += [Document(text=txt, metadata=metadata) for txt in text_splits]
        return {"documents": split_docs}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(
            self, split_by=self.split_by, split_length=self.split_length, split_overlap=self.split_overlap
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TextDocumentSplitter":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def _split_into_units(self, text: str, split_by: Literal["word", "sentence", "passage"]) -> List[str]:
        if split_by == "passage":
            split_at = "\n\n"
        elif split_by == "sentence":
            split_at = "."
        elif split_by == "word":
            split_at = " "
        else:
            raise NotImplementedError(
                "TextDocumentSplitter only supports 'passage', 'sentence' or 'word' split_by options."
            )
        units = text.split(split_at)
        # Add the delimiter back to all units except the last one
        for i in range(len(units) - 1):
            units[i] += split_at
        return units

    def _concatenate_units(self, elements: List[str], split_length: int, split_overlap: int) -> List[str]:
        """
        Concatenates the elements into parts of split_length units.
        """
        text_splits = []
        segments = windowed(elements, n=split_length, step=split_length - split_overlap)
        for seg in segments:
            current_units = [unit for unit in seg if unit is not None]
            txt = "".join(current_units)
            if len(txt) > 0:
                text_splits.append(txt)
        return text_splits
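For reference, the overlap behaviour in `_concatenate_units` comes from `more_itertools.windowed`: with `step = split_length - split_overlap`, each window starts `step` units after the previous one, so neighbouring windows repeat the last `split_overlap` units, and the final window is padded with `None` values that are filtered out before joining. A minimal standalone sketch (the unit list and parameter values are made up for illustration):

from more_itertools import windowed

units = ["one ", "two ", "three ", "four ", "five ", "six "]
split_length, split_overlap = 4, 1

# step = 3: consecutive windows share 1 unit; the last window is padded with None
for segment in windowed(units, n=split_length, step=split_length - split_overlap):
    print("".join(unit for unit in segment if unit is not None))
# prints "one two three four " and then "four five six "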
@@ -0,0 +1,4 @@
---
preview:
  - |
    Add the `TextDocumentSplitter` component for Haystack 2.0 that splits a Document with long text into multiple Documents with shorter texts, so that the texts fit the maximum input length of the language models used in Embedders or other components.
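As a rough usage sketch of the component described in this release note (the document text and parameter values are made up; the expected output follows from the splitting logic in text_document_splitter.py above):

from haystack.preview import Document
from haystack.preview.components.preprocessors import TextDocumentSplitter

splitter = TextDocumentSplitter(split_by="sentence", split_length=2, split_overlap=0)
result = splitter.run(documents=[Document(text="First sentence. Second sentence. Third sentence.")])

for doc in result["documents"]:
    # each split carries a "source_id" metadata field pointing back at the original Document
    print(doc.metadata["source_id"], repr(doc.text))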
Empty file.
test/preview/components/preprocessors/test_text_document_splitter.py
157 additions & 0 deletions
@@ -0,0 +1,157 @@
import pytest

from haystack.preview import Document
from haystack.preview.components.preprocessors import TextDocumentSplitter


class TestTextDocumentSplitter:
    @pytest.mark.unit
    def test_non_text_document(self):
        with pytest.raises(
            ValueError, match="TextDocumentSplitter only works with text documents but document.text for document ID"
        ):
            splitter = TextDocumentSplitter()
            splitter.run(documents=[Document()])

    @pytest.mark.unit
    def test_single_doc(self):
        with pytest.raises(TypeError, match="TextDocumentSplitter expects a List of Documents as input."):
            splitter = TextDocumentSplitter()
            splitter.run(documents=Document())

    @pytest.mark.unit
    def test_empty_list(self):
        with pytest.raises(TypeError, match="TextDocumentSplitter expects a List of Documents as input."):
            splitter = TextDocumentSplitter()
            splitter.run(documents=[])

    @pytest.mark.unit
    def test_unsupported_split_by(self):
        with pytest.raises(ValueError, match="split_by must be one of 'word', 'sentence' or 'passage'."):
            TextDocumentSplitter(split_by="unsupported")

    @pytest.mark.unit
    def test_unsupported_split_length(self):
        with pytest.raises(ValueError, match="split_length must be greater than 0."):
            TextDocumentSplitter(split_length=0)

    @pytest.mark.unit
    def test_unsupported_split_overlap(self):
        with pytest.raises(ValueError, match="split_overlap must be greater than or equal to 0."):
            TextDocumentSplitter(split_overlap=-1)

    @pytest.mark.unit
    def test_split_by_word(self):
        splitter = TextDocumentSplitter(split_by="word", split_length=10)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 2
        assert result["documents"][0].text == "This is a text with some words. There is a "
        assert result["documents"][1].text == "second sentence. And there is a third sentence."

    @pytest.mark.unit
    def test_split_by_word_multiple_input_docs(self):
        splitter = TextDocumentSplitter(split_by="word", split_length=10)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                ),
                Document(
                    text="This is a different text with some words. There is a second sentence. And there is a third sentence. And there is a fourth sentence."
                ),
            ]
        )
        assert len(result["documents"]) == 5
        assert result["documents"][0].text == "This is a text with some words. There is a "
        assert result["documents"][1].text == "second sentence. And there is a third sentence."
        assert result["documents"][2].text == "This is a different text with some words. There is "
        assert result["documents"][3].text == "a second sentence. And there is a third sentence. And "
        assert result["documents"][4].text == "there is a fourth sentence."

    @pytest.mark.unit
    def test_split_by_sentence(self):
        splitter = TextDocumentSplitter(split_by="sentence", split_length=1)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 3
        assert result["documents"][0].text == "This is a text with some words."
        assert result["documents"][1].text == " There is a second sentence."
        assert result["documents"][2].text == " And there is a third sentence."

    @pytest.mark.unit
    def test_split_by_passage(self):
        splitter = TextDocumentSplitter(split_by="passage", split_length=1)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n And another passage."
                )
            ]
        )
        assert len(result["documents"]) == 3
        assert result["documents"][0].text == "This is a text with some words. There is a second sentence.\n\n"
        assert result["documents"][1].text == "And there is a third sentence.\n\n"
        assert result["documents"][2].text == " And another passage."

    @pytest.mark.unit
    def test_split_by_word_with_overlap(self):
        splitter = TextDocumentSplitter(split_by="word", split_length=10, split_overlap=2)
        result = splitter.run(
            documents=[
                Document(
                    text="This is a text with some words. There is a second sentence. And there is a third sentence."
                )
            ]
        )
        assert len(result["documents"]) == 2
        assert result["documents"][0].text == "This is a text with some words. There is a "
        assert result["documents"][1].text == "is a second sentence. And there is a third sentence."

    @pytest.mark.unit
    def test_to_dict(self):
        splitter = TextDocumentSplitter()
        data = splitter.to_dict()
        assert data == {
            "type": "TextDocumentSplitter",
            "init_parameters": {"split_by": "word", "split_length": 200, "split_overlap": 0},
        }

    @pytest.mark.unit
    def test_to_dict_with_custom_init_parameters(self):
        splitter = TextDocumentSplitter(split_by="passage", split_length=100, split_overlap=1)
        data = splitter.to_dict()
        assert data == {
            "type": "TextDocumentSplitter",
            "init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
        }

    @pytest.mark.unit
    def test_from_dict(self):
        data = {
            "type": "TextDocumentSplitter",
            "init_parameters": {"split_by": "passage", "split_length": 100, "split_overlap": 1},
        }
        splitter = TextDocumentSplitter.from_dict(data)
        assert splitter.split_by == "passage"
        assert splitter.split_length == 100
        assert splitter.split_overlap == 1

    @pytest.mark.unit
    def test_source_id_stored_in_metadata(self):
        splitter = TextDocumentSplitter(split_by="word", split_length=10)
        doc1 = Document(text="This is a text with some words.")
        doc2 = Document(text="This is a different text with some words.")
        result = splitter.run(documents=[doc1, doc2])
        assert result["documents"][0].metadata["source_id"] == doc1.id
        assert result["documents"][1].metadata["source_id"] == doc2.id