Skip to content

Commit

Permalink
Add progress bar to filesystemblob loader, update pytest config for u…
Browse files Browse the repository at this point in the history
…nit tests (#4212)

This PR adds:

* Option to show a tqdm progress bar when using the file system blob loader
* Update pytest run configuration to be stricter
* Adding a new marker that checks that required pkgs exist
  • Loading branch information
eyurtsev authored May 8, 2023
1 parent f4c8502 commit aa11f7c
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 2 deletions.
56 changes: 54 additions & 2 deletions langchain/document_loaders/blob_loaders/file_system.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,38 @@
"""Use to load blobs from the local file system."""
from pathlib import Path
from typing import Iterable, Optional, Sequence, Union
from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar, Union

from langchain.document_loaders.blob_loaders.schema import Blob, BlobLoader

T = TypeVar("T")


def _make_iterator(
length_func: Callable[[], int], show_progress: bool = False
) -> Callable[[Iterable[T]], Iterator[T]]:
"""Create a function that optionally wraps an iterable in tqdm."""
if show_progress:
try:
from tqdm.auto import tqdm
except ImportError:
raise ImportError(
"You must install tqdm to use show_progress=True."
"You can install tqdm with `pip install tqdm`."
)

# Make sure to provide `total` here so that tqdm can show
# a progress bar that takes into account the total number of files.
def _with_tqdm(iterable: Iterable[T]) -> Iterator[T]:
"""Wrap an iterable in a tqdm progress bar."""
return tqdm(iterable, total=length_func())

iterator = _with_tqdm
else:
iterator = iter # type: ignore

return iterator


# PUBLIC API


Expand All @@ -26,6 +55,7 @@ def __init__(
*,
glob: str = "**/[!.]*",
suffixes: Optional[Sequence[str]] = None,
show_progress: bool = False,
) -> None:
"""Initialize with path to directory and how to glob over it.
Expand All @@ -36,6 +66,9 @@ def __init__(
suffixes: Provide to keep only files with these suffixes
Useful when wanting to keep files with different suffixes
Suffixes must include the dot, e.g. ".txt"
show_progress: If true, will show a progress bar as the files are loaded.
This forces an iteration through all matching files
to count them prior to loading them.
Examples:
Expand All @@ -60,14 +93,33 @@ def __init__(
self.path = _path
self.glob = glob
self.suffixes = set(suffixes or [])
self.show_progress = show_progress

def yield_blobs(
    self,
) -> Iterable[Blob]:
    """Lazily produce a blob for every file matching the configured pattern.

    When ``self.show_progress`` is set, the path stream is wrapped in a
    progress bar whose total comes from ``self.count_matching_files``.
    """
    wrap = _make_iterator(
        show_progress=self.show_progress,
        length_func=self.count_matching_files,
    )
    matching_paths = wrap(self._yield_paths())
    yield from map(Blob.from_path, matching_paths)

def _yield_paths(self) -> Iterable[Path]:
"""Yield paths that match the requested pattern."""
paths = self.path.glob(self.glob)
for path in paths:
if path.is_file():
if self.suffixes and path.suffix not in self.suffixes:
continue
yield Blob.from_path(str(path))
yield path

def count_matching_files(self) -> int:
    """Return the number of files matching the pattern, without loading them."""
    # Consume the path generator one item at a time so that no list of
    # paths is ever held in memory just for counting.
    return sum(1 for _ in self._yield_paths())
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,17 @@ omit = [
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config: any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
addopts = "--strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"requires: mark tests as requiring a specific library"
]
44 changes: 44 additions & 0 deletions tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Configuration for unit tests."""
from importlib import util
from typing import Dict, Sequence

import pytest
from pytest import Config, Function


def pytest_collection_modifyitems(config: Config, items: Sequence[Function]) -> None:
    """Add implementations for handling custom markers.

    At the moment, this adds support for a custom `requires` marker.

    The `requires` marker is used to denote tests that require one or more packages
    to be installed to run. If the package is not installed, the test is skipped.

    The `requires` marker syntax is:

    .. code-block:: python

        @pytest.mark.requires("package1", "package2")
        def test_something():
            ...
    """
    # Mapping from the name of a package to whether it is installed or not.
    # Used to avoid repeated calls to `util.find_spec`
    required_pkgs_info: Dict[str, bool] = {}

    for item in items:
        requires_marker = item.get_closest_marker("requires")
        if requires_marker is None:
            continue

        # Iterate through the list of required packages
        for pkg in requires_marker.args:
            # If we haven't yet checked whether the pkg is installed
            # let's check it and store the result.
            if pkg not in required_pkgs_info:
                try:
                    required_pkgs_info[pkg] = util.find_spec(pkg) is not None
                except ImportError:
                    # `find_spec` imports the parent of a dotted name and
                    # raises if that parent is missing; treat that as
                    # "not installed" instead of failing collection.
                    required_pkgs_info[pkg] = False

            if not required_pkgs_info[pkg]:
                # If the package is not installed, we immediately break
                # and mark the test as skipped.
                item.add_marker(pytest.mark.skip(reason=f"requires pkg: `{pkg}`"))
                break
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def test_file_names_exist(
loader = FileSystemBlobLoader(toy_dir, glob=glob, suffixes=suffixes)
blobs = list(loader.yield_blobs())

assert loader.count_matching_files() == len(relative_filenames)

file_names = sorted(str(blob.path) for blob in blobs)

expected_filenames = sorted(
Expand All @@ -99,3 +101,11 @@ def test_file_names_exist(
)

assert file_names == expected_filenames


@pytest.mark.requires("tqdm")
def test_show_progress(toy_dir: str) -> None:
    """Verify that file system loader works with a progress bar."""
    # The progress bar must be switched on explicitly; with the default
    # (show_progress=False) the tqdm code path would never be exercised.
    loader = FileSystemBlobLoader(toy_dir, show_progress=True)
    blobs = list(loader.yield_blobs())
    assert len(blobs) == loader.count_matching_files()

0 comments on commit aa11f7c

Please sign in to comment.