forked from deepset-ai/haystack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add FileExtensionClassifier to previews (deepset-ai#5514)
* Add FileExtensionClassifier preview component * Add release note * PR feedback
- Loading branch information
Showing
6 changed files
with
177 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber | ||
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber | ||
from haystack.preview.components.file_converters import TextFileToDocument | ||
from haystack.preview.components.classifiers import FileExtensionClassifier |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier |
82 changes: 82 additions & 0 deletions
82
haystack/preview/components/classifiers/file_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import logging | ||
import mimetypes | ||
from collections import defaultdict | ||
from pathlib import Path | ||
from typing import List, Union, Optional | ||
|
||
from haystack.preview import component | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@component
class FileExtensionClassifier:
    """
    Group file paths by the MIME type guessed from their file extension.

    The file contents are never read: the MIME type is derived from the file name only, using the standard
    library ``mimetypes`` module. The MIME types to consider are given at initialization time; every path whose
    guessed MIME type is not in that list (including paths whose type cannot be guessed at all) is collected
    under the ``"unclassified"`` key.
    """

    def __init__(self, mime_types: List[str]):
        """
        Initialize the FileExtensionClassifier.

        :param mime_types: A list of file mime types to consider when classifying
            files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
        :raises ValueError: If ``mime_types`` is empty or contains a MIME type that is not known
            to the ``mimetypes`` module.
        """
        if not mime_types:
            raise ValueError("The list of mime types cannot be empty.")

        # Collect the offending entries so the error points at what is actually wrong
        # instead of echoing the whole input list.
        unknown_mime_types = [
            mime_type for mime_type in mime_types if not self.is_valid_mime_type_format(mime_type)
        ]
        if unknown_mime_types:
            raise ValueError(f"The list of mime types contains unknown mime types: {unknown_mime_types}")

        # save the init parameters for serialization
        self.init_parameters = {"mime_types": mime_types}

        # One output per requested MIME type, plus a catch-all for everything else.
        component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
        self.mime_types = mime_types

    def run(self, paths: List[Union[str, Path]]):
        """
        Classify the given paths by MIME type.

        Each path is converted to a ``Path`` if it was given as a string, its MIME type is guessed from
        its name, and the path is appended to the list stored under that MIME type. Paths whose MIME type
        is not one of the configured ``mime_types`` are appended to the ``"unclassified"`` list.

        :param paths: The file paths to classify.
        :return: A ``defaultdict(list)`` mapping each encountered key (a configured MIME type or
            ``"unclassified"``) to the list of matching paths. Keys with no matching path are absent
            until they are looked up.
        """
        mime_types = defaultdict(list)
        for path in paths:
            if isinstance(path, str):
                path = Path(path)
            mime_type = self.get_mime_type(path)
            # A guess of None (no or unknown extension) is never in the configured list,
            # so it falls through to "unclassified" as well.
            key = mime_type if mime_type in self.mime_types else "unclassified"
            mime_types[key].append(path)

        return mime_types

    def get_mime_type(self, path: Path) -> Optional[str]:
        """
        Get the MIME type of the provided file path.

        :param path: The file path to get the MIME type for.
        :return: The MIME type of the provided file path, or None if the MIME type cannot be determined.
        """
        guessed_type, _encoding = mimetypes.guess_type(path.as_posix())
        return guessed_type

    def is_valid_mime_type_format(self, mime_type: str) -> bool:
        """
        Check whether the provided MIME type is known to the ``mimetypes`` module.

        Despite the method name, this is a membership check against the registered MIME types,
        not a syntactic check of the ``type/subtype`` format.

        :param mime_type: The MIME type to check.
        :return: True if the provided MIME type is one of the values of ``mimetypes.types_map``, False otherwise.
        """
        # ``mimetypes.guess_type`` lazily initializes the module database on first use, which may add
        # system-specific entries to ``types_map``. Initialize it up front so validation sees the same
        # table that ``get_mime_type`` will use later.
        if not mimetypes.inited:
            mimetypes.init()
        return mime_type in mimetypes.types_map.values()
4 changes: 4 additions & 0 deletions
4
releasenotes/notes/add-file-extension-classifier-preview-40f31c27bbd7cff9.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
--- | ||
features: | ||
- | | ||
Adds FileExtensionClassifier to preview components. It groups file paths by the MIME type guessed from each file's extension. |
Empty file.
89 changes: 89 additions & 0 deletions
89
test/preview/components/classifiers/test_file_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import sys | ||
|
||
import pytest | ||
|
||
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier | ||
from test.preview.components.base import BaseTestComponent | ||
from test.conftest import preview_samples_path | ||
|
||
|
||
@pytest.mark.skipif(
    sys.platform in ["win32", "cygwin"],
    reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier(BaseTestComponent):
    """
    Unit tests for FileExtensionClassifier: serialization, grouping of paths by MIME type,
    handling of unlisted or missing extensions, and validation of the configured MIME types.
    """

    @pytest.mark.unit
    def test_save_load(self, tmp_path):
        """
        Test that the component can be saved and loaded as part of a pipeline.
        """
        self.assert_can_be_saved_and_loaded_in_pipeline(
            FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]), tmp_path
        )

    @pytest.mark.unit
    def test_run(self, preview_samples_path):
        """
        Test if the component runs correctly in the simplest happy path.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "txt" / "doc_2.txt",
            preview_samples_path / "audio" / "the context for this answer is here.wav",
            preview_samples_path / "images" / "apple.jpg",
        ]

        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=file_paths)
        assert output
        assert len(output["text/plain"]) == 2
        assert len(output["audio/x-wav"]) == 1
        assert len(output["image/jpeg"]) == 1
        assert not output["unclassified"]

    @pytest.mark.unit
    def test_no_files(self):
        """
        Test that the component runs correctly when no files are provided.
        """
        classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
        output = classifier.run(paths=[])
        assert not output

    @pytest.mark.unit
    def test_unlisted_extensions(self, preview_samples_path):
        """
        Test that files whose mime type was not requested end up in the unclassified output.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "audio" / "ignored.mp3",
            preview_samples_path / "audio" / "this is the content of the document.wav",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain"])
        output = classifier.run(paths=file_paths)
        assert len(output["text/plain"]) == 1
        # Output keys are mime types (or "unclassified"), never bare extensions, so check the
        # exact key set: no key may be created for the unrequested audio mime types.
        assert set(output) == {"text/plain", "unclassified"}
        assert len(output["unclassified"]) == 2
        assert str(output["unclassified"][0]).endswith("ignored.mp3")
        assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")

    @pytest.mark.unit
    def test_no_extension(self, preview_samples_path):
        """
        Test that files with no extension are routed to the unclassified output.
        """
        file_paths = [
            preview_samples_path / "txt" / "doc_1.txt",
            preview_samples_path / "txt" / "doc_2",
            preview_samples_path / "txt" / "doc_2.txt",
        ]
        classifier = FileExtensionClassifier(mime_types=["text/plain"])
        output = classifier.run(paths=file_paths)
        assert len(output["text/plain"]) == 2
        assert len(output["unclassified"]) == 1

    @pytest.mark.unit
    def test_unknown_mime_type(self):
        """
        Test that the component refuses to be initialized with a mime type it does not know.
        """
        with pytest.raises(ValueError, match="The list of mime types"):
            FileExtensionClassifier(mime_types=["type_invalid"])