-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement a FSspec based ArtifactLoader #100
Changes from 3 commits
facbc71
64777cd
36730c0
3157354
3f26bd9
ce7f0dd
35141f4
ed83dd9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,10 +6,17 @@ | |
import copy | ||
import inspect | ||
import os | ||
import shutil | ||
import weakref | ||
from abc import ABCMeta, abstractmethod | ||
from dataclasses import dataclass, field | ||
from pathlib import Path | ||
from tempfile import mkdtemp | ||
from typing import Any, Callable, Generic, Iterable, Iterator, Literal, TypeVar | ||
|
||
from fsspec import core, filesystem, utils | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please import the specific APIs from these submodules, not the whole modules themselves. |
||
from fsspec.implementations.local import LocalFileSystem | ||
|
||
from nnbench.context import Context | ||
|
||
T = TypeVar("T") | ||
|
@@ -110,22 +117,72 @@ class LocalArtifactLoader(ArtifactLoader): | |
Parameters | ||
---------- | ||
path : str | os.PathLike[str] | ||
The file system pathto the artifact. | ||
The file system path to the artifact. | ||
""" | ||
|
||
def __init__(self, path: str | os.PathLike[str]): | ||
def __init__(self, path: str | os.PathLike[str]) -> None: | ||
self._path = path | ||
|
||
def load(self): | ||
def load(self) -> Path: | ||
""" | ||
Returns the path to the artifact on the local file system. | ||
""" | ||
return self._path | ||
return Path(self._path).resolve() | ||
|
||
|
||
class S3ArtifactLoader(ArtifactLoader): | ||
# TODO: Implement this and other common ArtifactLoders here or in a util | ||
pass | ||
class FilePathArtifactLoader(ArtifactLoader): | ||
""" | ||
ArtifactLoader for loading artifacts using fsspec-supported file systems. | ||
|
||
This allows for loading from various file systems like local, S3, GCS, etc., | ||
by using a unified API provided by fsspec. | ||
|
||
Parameters | ||
---------- | ||
path : str | os.PathLike[str] | ||
The path to the artifact, which can include a protocol specifier (like 's3://') for remote access. | ||
to_localdir : str | os.PathLike[str] | None | ||
The local directory to which remote artifacts will be downloaded. If provided, the model data will be persisted. Otherwise, local artifacts are cleaned. | ||
storage_options : dict[str, Any] | None | ||
Storage options for remote storage. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
path: str | os.PathLike[str], | ||
to_localdir: str | os.PathLike[str] | None = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should have a different name, it currently sounds like a bool. Maybe |
||
storage_options: dict[str, Any] | None = None, | ||
) -> None: | ||
self.source_path = str(path) | ||
if to_localdir: | ||
self.target_path = Path(to_localdir) | ||
delete = False | ||
else: | ||
self.target_path = Path(mkdtemp()) | ||
delete = True | ||
self._finalizer = weakref.finalize(self, self._cleanup, delete=delete) | ||
self.storage_options = storage_options or {} | ||
|
||
def load(self) -> Path: | ||
""" | ||
Loads the artifact, downloading it if necessary, and returns the local path. | ||
|
||
Returns | ||
------- | ||
Path | ||
The path to the artifact on the local filesystem. | ||
""" | ||
fs, _, _ = core.get_fs_token_paths(self.source_path) | ||
if isinstance(fs, LocalFileSystem): | ||
return Path(self.source_path).resolve() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this caching a la "check if it exists, only download if it doesn't"? (Otherwise it's probably currently missing) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More or less. My understanding is, that this class, as a general filesystem class, should also be able to use local filepaths. If it is a remote path, we download and put it either (a) a specified local dir, or (b) a tmp dir. If a) the files are persisted on disk and the filepath can be used (by the user) in other places. If b) the cleanup task deletes the tmp dir and the model to not clutter the filesystem. That is how it is currently implemented. Do you agree with this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You might want to copy files into your target folder even if they are local? Seems like you could just omit the branch here. |
||
else: | ||
fs = filesystem(utils.get_protocol(self.source_path), **self.storage_options) | ||
fs.get(self.source_path, self.target_path, recursive=True) | ||
return Path(self.target_path).resolve() | ||
|
||
def _cleanup(self, delete: bool) -> None: | ||
if delete: | ||
shutil.rmtree(self.target_path) | ||
|
||
|
||
class Artifact(Generic[T], metaclass=ABCMeta): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from pathlib import Path | ||
|
||
from nnbench.types import FilePathArtifactLoader | ||
|
||
|
||
def test_load_local_file(local_file: Path, tmp_path: Path) -> None: | ||
test_dir = tmp_path / "test_load_dir" | ||
loader = FilePathArtifactLoader(str(local_file), str(test_dir)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should work without the string casts. |
||
loaded_path: Path = loader.load() | ||
assert loaded_path.exists() | ||
assert loaded_path.read_text() == "Test content" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my opinion, we should make this optional (i.e. by importing
fsspec
only locally in the artifact load method), and facilitate local loading with builtin methods.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree