diff --git a/pyproject.toml b/pyproject.toml
index b2b67e7..f20b3aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,6 +48,7 @@ dev = [
     "psutil",
     "pyarrow",
     "pytest>=7.4.0",
+    "fsspec",
 ]
 docs = [
     "black",
@@ -85,7 +86,7 @@ strict_optional = false
 warn_unreachable = true
 
 [[tool.mypy.overrides]]
-module = ["tabulate", "yaml"]
+module = ["tabulate", "yaml", "fsspec"]
 ignore_missing_imports = true
 
 [tool.ruff]
diff --git a/src/nnbench/types.py b/src/nnbench/types.py
index 0d66885..cad760a 100644
--- a/src/nnbench/types.py
+++ b/src/nnbench/types.py
@@ -6,12 +6,31 @@
 import copy
 import inspect
 import os
+import shutil
+import weakref
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Iterable, Iterator, Literal, TypeVar
+from pathlib import Path
+from tempfile import mkdtemp
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    Iterable,
+    Iterator,
+    Literal,
+    TypeVar,
+)
 
 from nnbench.context import Context
 
+try:
+    import fsspec
+
+    HAS_FSSPEC = True
+except ImportError:
+    HAS_FSSPEC = False
+
 T = TypeVar("T")
 Variable = tuple[str, type, Any]
 
@@ -110,22 +129,85 @@ class LocalArtifactLoader(ArtifactLoader):
     Parameters
     ----------
     path : str | os.PathLike[str]
-        The file system pathto the artifact.
+        The file system path to the artifact.
     """
 
-    def __init__(self, path: str | os.PathLike[str]):
+    def __init__(self, path: str | os.PathLike[str]) -> None:
         self._path = path
 
-    def load(self):
+    def load(self) -> Path:
         """
         Returns the path to the artifact on the local file system.
         """
-        return self._path
+        return Path(self._path).resolve()
 
 
-class S3ArtifactLoader(ArtifactLoader):
-    # TODO: Implement this and other common ArtifactLoders here or in a util
-    pass
+class FilePathArtifactLoader(ArtifactLoader):
+    """
+    ArtifactLoader for loading artifacts using fsspec-supported file systems.
+
+    This allows for loading from various file systems like local, S3, GCS, etc.,
+    by using a unified API provided by fsspec.
+
+    Parameters
+    ----------
+    path : str | os.PathLike[str]
+        The path to the artifact, which can include a protocol specifier (like 's3://') for remote access.
+    destination : str | os.PathLike[str] | None
+        The local directory to which remote artifacts will be downloaded. If provided, the
+        downloaded data is persisted. Otherwise, a temporary directory is used, which is
+        removed once the loader is garbage collected or the interpreter exits.
+    storage_options : dict[str, Any] | None
+        Storage options passed on to the fsspec file system, e.g. credentials for remote storage.
+    """
+
+    def __init__(
+        self,
+        path: str | os.PathLike[str],
+        destination: str | os.PathLike[str] | None = None,
+        storage_options: dict[str, Any] | None = None,
+    ) -> None:
+        self.source_path = str(path)
+        if destination:
+            self.target_path = str(Path(destination).resolve())
+            delete = False
+        else:
+            self.target_path = str(Path(mkdtemp()).resolve())
+            delete = True
+        # The finalizer callback must not hold a reference to `self` (e.g. via a bound
+        # method), otherwise the loader would be kept alive until interpreter exit.
+        self._finalizer = weakref.finalize(self, self._cleanup, self.target_path, delete)
+        self.storage_options = storage_options or {}
+
+    def load(self) -> Path:
+        """
+        Loads the artifact and returns the local path.
+
+        Returns
+        -------
+        Path
+            The path to the artifact on the local filesystem.
+
+        Raises
+        ------
+        ImportError
+            When fsspec is not installed.
+        """
+        if not HAS_FSSPEC:
+            raise ImportError(
+                f"class {self.__class__.__name__} requires `fsspec` to be installed. "
+                "You can install it by running `python -m pip install --upgrade fsspec`"
+            )
+        protocol = fsspec.utils.get_protocol(self.source_path)
+        fs = fsspec.filesystem(protocol, **self.storage_options)
+        fs.get(self.source_path, self.target_path, recursive=True)
+        return Path(self.target_path).resolve()
+
+    @staticmethod
+    def _cleanup(target_path: str, delete: bool) -> None:
+        """Removes the temporary download directory if the loader created it."""
+        if delete:
+            shutil.rmtree(target_path, ignore_errors=True)
 
 
 class Artifact(Generic[T], metaclass=ABCMeta):
diff --git a/tests/conftest.py b/tests/conftest.py
index cca77f3..a120e63 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,3 +13,10 @@
 def testfolder() -> str:
     """A test directory for benchmark collection."""
     return str(HERE / "benchmarks")
+
+
+@pytest.fixture
+def local_file(tmp_path: Path) -> Path:
+    file_path = tmp_path / "test_file.txt"
+    file_path.write_text("Test content")
+    return file_path
diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py
new file mode 100644
index 0000000..4f48aeb
--- /dev/null
+++ b/tests/test_artifacts.py
@@ -0,0 +1,11 @@
+from pathlib import Path
+
+from nnbench.types import FilePathArtifactLoader
+
+
+def test_load_local_file(local_file: Path, tmp_path: Path) -> None:
+    test_dir = tmp_path / "test_load_dir"
+    loader = FilePathArtifactLoader(local_file, test_dir)
+    loaded_path: Path = loader.load()
+    assert loaded_path.exists()
+    assert loaded_path.read_text() == "Test content"