Skip to content

Commit

Permalink
feat: Add FileExtensionClassifier to previews (deepset-ai#5514)
Browse files Browse the repository at this point in the history
* Add FileExtensionClassifier preview component

* Add release note

* PR feedback
  • Loading branch information
vblagoje authored and DosticJelena committed Aug 23, 2023
1 parent 099d064 commit 8ff821f
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 0 deletions.
1 change: 1 addition & 0 deletions haystack/preview/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from haystack.preview.components.audio.whisper_local import LocalWhisperTranscriber
from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
from haystack.preview.components.file_converters import TextFileToDocument
from haystack.preview.components.classifiers import FileExtensionClassifier
1 change: 1 addition & 0 deletions haystack/preview/components/classifiers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
82 changes: 82 additions & 0 deletions haystack/preview/components/classifiers/file_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import logging
import mimetypes
from collections import defaultdict
from pathlib import Path
from typing import List, Union, Optional

from haystack.preview import component

logger = logging.getLogger(__name__)


@component
class FileExtensionClassifier:
"""
A component that classifies files based on their MIME types read from their file extensions. This component
does not read the file contents, but rather uses the file extension to determine the MIME type of the file.
The FileExtensionClassifier takes a list of file paths and groups them by their MIME types.
The list of MIME types to consider is provided during the initialization of the component.
This component is particularly useful when working with a large number of files, and you
want to categorize them based on their MIME types.
"""

def __init__(self, mime_types: List[str]):
"""
Initialize the FileExtensionClassifier.
:param mime_types: A list of file mime types to consider when classifying
files (e.g. ["text/plain", "audio/x-wav", "image/jpeg"]).
"""
if not mime_types:
raise ValueError("The list of mime types cannot be empty.")

all_known_mime_types = all(self.is_valid_mime_type_format(mime_type) for mime_type in mime_types)
if not all_known_mime_types:
raise ValueError(f"The list of mime types contains unknown mime types: {mime_types}")

# save the init parameters for serialization
self.init_parameters = {"mime_types": mime_types}

component.set_output_types(self, unclassified=List[Path], **{mime_type: List[Path] for mime_type in mime_types})
self.mime_types = mime_types

def run(self, paths: List[Union[str, Path]]):
"""
Run the FileExtensionClassifier.
This method takes the input data, iterates through the provided file paths, checks the file
mime type of each file, and groups the file paths by their mime types.
:param paths: The input data containing the file paths to classify.
:return: The output data containing the classified file paths.
"""
mime_types = defaultdict(list)
for path in paths:
if isinstance(path, str):
path = Path(path)
mime_type = self.get_mime_type(path)
if mime_type in self.mime_types:
mime_types[mime_type].append(path)
else:
mime_types["unclassified"].append(path)

return mime_types

def get_mime_type(self, path: Path) -> Optional[str]:
"""
Get the MIME type of the provided file path.
:param path: The file path to get the MIME type for.
:return: The MIME type of the provided file path, or None if the MIME type cannot be determined.
"""
return mimetypes.guess_type(path.as_posix())[0]

def is_valid_mime_type_format(self, mime_type: str) -> bool:
"""
Check if the provided MIME type is in valid format
:param mime_type: The MIME type to check.
:return: True if the provided MIME type is a valid MIME type format, False otherwise.
"""
return mime_type in mimetypes.types_map.values()
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Adds FileExtensionClassifier to preview components.
Empty file.
89 changes: 89 additions & 0 deletions test/preview/components/classifiers/test_file_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import sys

import pytest

from haystack.preview.components.classifiers.file_classifier import FileExtensionClassifier
from test.preview.components.base import BaseTestComponent
from test.conftest import preview_samples_path


@pytest.mark.skipif(
sys.platform in ["win32", "cygwin"],
reason="Can't run on Windows Github CI, need access to registry to get mime types",
)
class TestFileExtensionClassifier(BaseTestComponent):
@pytest.mark.unit
def test_save_load(self, tmp_path):
self.assert_can_be_saved_and_loaded_in_pipeline(
FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"]), tmp_path
)

@pytest.mark.unit
def test_run(self, preview_samples_path):
"""
Test if the component runs correctly in the simplest happy path.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "txt" / "doc_2.txt",
preview_samples_path / "audio" / "the context for this answer is here.wav",
preview_samples_path / "images" / "apple.jpg",
]

classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = classifier.run(paths=file_paths)
assert output
assert len(output["text/plain"]) == 2
assert len(output["audio/x-wav"]) == 1
assert len(output["image/jpeg"]) == 1
assert not output["unclassified"]

@pytest.mark.unit
def test_no_files(self):
"""
Test that the component runs correctly when no files are provided.
"""
classifier = FileExtensionClassifier(mime_types=["text/plain", "audio/x-wav", "image/jpeg"])
output = classifier.run(paths=[])
assert not output

@pytest.mark.unit
def test_unlisted_extensions(self, preview_samples_path):
"""
Test that the component correctly handles files with non specified mime types.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "audio" / "ignored.mp3",
preview_samples_path / "audio" / "this is the content of the document.wav",
]
classifier = FileExtensionClassifier(mime_types=["text/plain"])
output = classifier.run(paths=file_paths)
assert len(output["text/plain"]) == 1
assert "mp3" not in output
assert len(output["unclassified"]) == 2
assert str(output["unclassified"][0]).endswith("ignored.mp3")
assert str(output["unclassified"][1]).endswith("this is the content of the document.wav")

@pytest.mark.unit
def test_no_extension(self, preview_samples_path):
"""
Test that the component ignores files with no extension.
"""
file_paths = [
preview_samples_path / "txt" / "doc_1.txt",
preview_samples_path / "txt" / "doc_2",
preview_samples_path / "txt" / "doc_2.txt",
]
classifier = FileExtensionClassifier(mime_types=["text/plain"])
output = classifier.run(paths=file_paths)
assert len(output["text/plain"]) == 2
assert len(output["unclassified"]) == 1

@pytest.mark.unit
def test_unknown_mime_type(self):
"""
Test that the component handles files with unknown mime types.
"""
with pytest.raises(ValueError, match="The list of mime types"):
FileExtensionClassifier(mime_types=["type_invalid"])

0 comments on commit 8ff821f

Please sign in to comment.