feat: Polars generic dataset (take 3) #170

Merged · 12 commits · Aug 30, 2023
2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
100644 → 100755
@@ -1,5 +1,7 @@
# Upcoming Release
## Major features and improvements
* Added `polars.GenericDataSet`, a dataset backed by [polars](https://www.pola.rs/), a lightning-fast DataFrame library written entirely in Rust.

## Bug fixes and other changes
## Community contributions

4 changes: 3 additions & 1 deletion kedro-datasets/kedro_datasets/polars/__init__.py
@@ -5,7 +5,9 @@

# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
CSVDataSet: Any
GenericDataSet: Any

__getattr__, __dir__, __all__ = lazy.attach(
    __name__, submod_attrs={"csv_dataset": ["CSVDataSet"]}
    __name__,
    submod_attrs={"csv_dataset": ["CSVDataSet"], "generic_dataset": ["GenericDataSet"]},
)
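
For context: `lazy` here is the `lazy_loader` package that kedro-datasets already uses, so importing `kedro_datasets.polars` stays cheap and the `generic_dataset` submodule (and polars itself) is only imported on first attribute access. A minimal sketch of the resulting behaviour (variable name is illustrative):

import kedro_datasets.polars  # cheap: does not import polars yet

# First attribute access triggers the import of
# kedro_datasets.polars.generic_dataset via the generated __getattr__.
dataset_cls = kedro_datasets.polars.GenericDataSet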
205 changes: 205 additions & 0 deletions kedro-datasets/kedro_datasets/polars/generic_dataset.py
@@ -0,0 +1,205 @@
"""``GenericDataSet`` loads/saves data from/to a data file using an underlying
filesystem (e.g.: local, S3, GCS). It uses polars to handle the
type of read/write target.
"""
from copy import deepcopy
from io import BytesIO
from pathlib import PurePosixPath
from typing import Any, Dict

import fsspec
import polars as pl
from kedro.io.core import (
AbstractVersionedDataSet,
DataSetError,
Version,
get_filepath_str,
get_protocol_and_path,
)


# pylint: disable=too-many-instance-attributes
class GenericDataSet(AbstractVersionedDataSet[pl.DataFrame, pl.DataFrame]):
"""`polars.GenericDataSet` loads/saves data from/to a data file using an underlying
filesystem (e.g.: local, S3, GCS). It uses polars to dynamically select the
appropriate type of read/write target on a best effort basis.
Example usage for the
`YAML API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:
.. code-block:: yaml
cars:
type: polars.GenericDataSet
file_format: parquet
filepath: s3://data/01_raw/company/cars.parquet
load_args:
low_memory: True
save_args:
compression: "snappy"

Example usage for the
`Python API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog.html#use-the-data-catalog-with-the-code-api>`_:
::
>>> from kedro_datasets.polars import GenericDataSet
>>> import polars as pl
>>>
>>> data = pl.DataFrame({'col1': [1, 2], 'col2': [4, 5],
>>> 'col3': [5, 6]})
>>>
>>> data_set = GenericDataSet(filepath="test.parquet", file_format='parquet')
>>> data_set.save(data)
>>> reloaded = data_set.load()
>>> assert data.frame_equal(reloaded)
"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any]

# pylint: disable=too-many-arguments
def __init__(
self,
filepath: str,
file_format: str,
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
version: Version = None,
credentials: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
):
"""Creates a new instance of ``GenericDataSet`` pointing to a concrete data file
on a specific filesystem. The appropriate polars load/save methods are
dynamically identified by string matching on a best effort basis.
Args:
filepath: Filepath in POSIX format to a file prefixed with a protocol like
`s3://`.
If prefix is not provided, `file` protocol (local filesystem)
will be used.
The prefix should be any protocol supported by ``fsspec``.
Key assumption: The first argument of either load/save method points to
a filepath/buffer/io type location. There are some read/write targets
such as 'clipboard' or 'records' that will fail since they do not take a
filepath like argument.
file_format: String which is used to match the appropriate load/save method
on a best effort basis. For example if 'csv' is passed the
`polars.read_csv` and
`polars.DataFrame.write_csv` methods will be identified. An error will
be raised unless
at least one matching `read_{file_format}` or `write_{file_format}`.
load_args: polars options for loading files.
Here you can find all available arguments:
https://pola-rs.github.io/polars/py-polars/html/reference/io.html
All defaults are preserved.
save_args: Polars options for saving files.
Here you can find all available arguments:
https://pola-rs.github.io/polars/py-polars/html/reference/io.html
All defaults are preserved.
version: If specified, should be an instance of
``kedro.io.core.Version``. If its ``load`` attribute is
None, the latest version will be loaded. If its ``save``
attribute is None, save version will be autogenerated.
credentials: Credentials required to get access to the underlying filesystem.
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
fs_args: Extra arguments to pass into underlying filesystem class constructor
(e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
to pass to the filesystem's `open` method through nested keys
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
Raises:
DataSetError: Will be raised if at least less than one appropriate
read or write methods are identified.
"""

self._file_format = file_format.lower()

_fs_args = deepcopy(fs_args) or {}
_fs_open_args_load = _fs_args.pop("open_args_load", {})
_fs_open_args_save = _fs_args.pop("open_args_save", {})
_credentials = deepcopy(credentials) or {}

protocol, path = get_protocol_and_path(filepath)
if protocol == "file":
_fs_args.setdefault("auto_mkdir", True)

self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

super().__init__(
filepath=PurePosixPath(path),
version=version,
exists_function=self._fs.exists,
glob_function=self._fs.glob,
)

self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

_fs_open_args_save.setdefault("mode", "wb")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save

def _load(self) -> pl.DataFrame: # pylint: disable= inconsistent-return-statements
load_path = get_filepath_str(self._get_load_path(), self._protocol)
load_method = getattr(pl, f"read_{self._file_format}", None)

if not load_method:
raise DataSetError(
f"Unable to retrieve 'polars.read_{self._file_format}' method, please"
" ensure that your "
"'file_format' parameter has been defined correctly as per the Polars"
" API"
" https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return load_method(fs_file, **self._load_args)

def _save(self, data: pl.DataFrame) -> None:
save_path = get_filepath_str(self._get_save_path(), self._protocol)
save_method = getattr(data, f"write_{self._file_format}", None)

if not save_method:
raise DataSetError(
f"Unable to retrieve 'polars.DataFrame.write_{self._file_format}' "
"method, please "
"ensure that your 'file_format' parameter has been defined correctly as"
" per the Polars API "
"https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)
buf = BytesIO()
save_method(buf, **self._save_args)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
fs_file.write(buf.getvalue())
self._invalidate_cache()

def _exists(self) -> bool:
try:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
except DataSetError:
return False

return self._fs.exists(load_path)

def _describe(self) -> Dict[str, Any]:
return {
"file_format": self._file_format,
"filepath": self._filepath,
"protocol": self._protocol,
"load_args": self._load_args,
"save_args": self._save_args,
"version": self._version,
}

def _release(self) -> None:
super()._release()
self._invalidate_cache()

def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)
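
To make the best-effort matching concrete, the following sketch (illustrative, not part of the diff; the `file_format` value is chosen arbitrarily) shows how `_load` and `_save` resolve their targets:

import polars as pl

file_format = "ipc"

# _load looks the reader up on the polars module...
load_method = getattr(pl, f"read_{file_format}", None)   # -> pl.read_ipc

# ..._save looks the writer up on the DataFrame instance.
df = pl.DataFrame({"a": [1, 2]})
save_method = getattr(df, f"write_{file_format}", None)  # -> df.write_ipc

assert load_method is not None and save_method is not None

Any `read_*`/`write_*` pair whose first argument accepts a file-like object (csv, parquet, ipc, json, ...) round-trips this way; targets like 'clipboard' do not, as the docstring warns.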
12 changes: 9 additions & 3 deletions kedro-datasets/setup.py
@@ -7,7 +7,7 @@
SPARK = "pyspark>=2.2, <3.4"
HDFS = "hdfs>=2.5.8, <3.0"
S3FS = "s3fs>=0.3.0, <0.5"
POLARS = "polars~=0.17.0"
POLARS = "polars~=0.18.0"
DELTA = "delta-spark~=1.2.1"


@@ -52,7 +52,13 @@ def _collect_requirements(requires):
"plotly.PlotlyDataSet": [PANDAS, "plotly>=4.8.0, <6.0"],
"plotly.JSONDataSet": ["plotly>=4.8.0, <6.0"],
}
polars_require = {"polars.CSVDataSet": [POLARS]}
polars_require = {
"polars.CSVDataSet": [POLARS],
"polars.GenericDataSet":
[
POLARS, "pyarrow>=4.0", "xlsx2csv>=0.8.0", "deltalake >= 0.6.2"
],
}
redis_require = {"redis.PickleDataSet": ["redis~=4.1"]}
snowflake_require = {
"snowflake.SnowparkTableDataSet": [
@@ -179,7 +185,7 @@ def _collect_requirements(requires):
"pandas~=1.3 # 1.3 for read_xml/to_xml",
"Pillow~=9.0",
"plotly>=4.8.0, <6.0",
"polars~=0.15.13",
"polars[xlsx2csv, deltalake]~=0.18.0",
"pre-commit>=2.9.2, <3.0", # The hook `mypy` requires pre-commit version 2.9.2.
"pyarrow>=1.0; python_version < '3.11'",
"pyarrow>=7.0; python_version >= '3.11'", # Adding to avoid numpy build errors
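
A note on the new optional dependencies: `pyarrow`, `xlsx2csv` and `deltalake` back some of the polars readers the dataset can dispatch to (in the polars 0.18 line, `read_excel` relies on xlsx2csv and `read_delta` on deltalake). A hedged sketch with hypothetical file paths; whether a given format actually round-trips still depends on the matched method accepting a file-like first argument:

from kedro_datasets.polars import GenericDataSet

# Resolves to polars.read_excel, which uses xlsx2csv in polars 0.18.
excel_ds = GenericDataSet(filepath="data/01_raw/cars.xlsx", file_format="excel")

# Resolves to polars.read_delta, which requires the deltalake package.
delta_ds = GenericDataSet(filepath="data/01_raw/cars_delta", file_format="delta")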