
feat: Polars generic dataset (take 3) #170

Merged
merged 12 commits on Aug 30, 2023
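
The new dataset resolves the polars reader/writer from the `file_format` string (`polars.read_{file_format}` / `polars.DataFrame.write_{file_format}`). A minimal usage sketch, assuming a hypothetical local `test.csv` path:

>>> import polars as pl
>>> from kedro_datasets.polars import GenericDataSet
>>>
>>> data_set = GenericDataSet(filepath="test.csv", file_format="csv")
>>> data_set.save(pl.DataFrame({"a": [1, 2], "b": [3, 4]}))
>>> data_set.load().columns
['a', 'b']
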
1 change: 1 addition & 0 deletions kedro-datasets/RELEASE.md
100644 → 100755
@@ -2,6 +2,7 @@

## Major features and improvements
* Added automatic inference of file format for `pillow.ImageDataSet` to be passed to `save()`
* Added `polars.GenericDataSet`, a generic dataset backed by [polars](https://www.pola.rs/), a lightning-fast dataframe library written in Rust.

## Bug fixes and other changes

3 changes: 2 additions & 1 deletion kedro-datasets/kedro_datasets/polars/__init__.py
@@ -1,8 +1,9 @@
"""``AbstractDataSet`` implementations that produce pandas DataFrames."""

__all__ = ["CSVDataSet"]
__all__ = ["CSVDataSet", "GenericDataSet"]

from contextlib import suppress

with suppress(ImportError):
from .csv_dataset import CSVDataSet
from .generic_dataset import GenericDataSet
258 changes: 258 additions & 0 deletions kedro-datasets/kedro_datasets/polars/generic_dataset.py
@@ -0,0 +1,258 @@
"""``GenericDataSet`` loads/saves data from/to a data file using an underlying
filesystem (e.g.: local, S3, GCS). It uses polars to dynamically select the
appropriate read/write method for the target file format.
"""
from copy import deepcopy
from io import BytesIO
from pathlib import PurePosixPath
from typing import Any, Dict

import fsspec
import polars as pl
from kedro.io.core import (
AbstractVersionedDataSet,
DataSetError,
Version,
get_filepath_str,
get_protocol_and_path,
)

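# "overwrite" saves through the matching polars writer; "ignore" makes the
# dataset read-only (useful for formats polars can read but not write).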
ACCEPTED_WRITE_MODES = ["overwrite", "ignore"]

ACCEPTED_WRITE_FILE_FORMATS = [
"csv",
"ipc",
"parquet",
"json",
"ndjson",
"avro",
]
# always a superset of ACCEPTED_WRITE_FILE_FORMATS
ACCEPTED_READ_FILE_FORMATS = ACCEPTED_WRITE_FILE_FORMATS + [
"excel",
"delta",
]
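# "excel" and "delta" can only be loaded: no matching polars writer is
# resolved for them, so datasets using these formats should set
# write_mode="ignore".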


# pylint: disable=too-many-instance-attributes
class GenericDataSet(AbstractVersionedDataSet[pl.DataFrame, pl.DataFrame]):
"""`polars.GenericDataSet` loads/saves data from/to a data file using an underlying
filesystem (e.g.: local, S3, GCS). It uses polars to dynamically select the
appropriate type of read/write target on a best effort basis.
Example usage for the
`YAML API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog.html#use-the-data-catalog-with-the-yaml-api>`_:
.. code-block:: yaml

    cars:
      type: polars.GenericDataSet
      file_format: parquet
      filepath: s3://data/01_raw/company/cars.parquet
      load_args:
        low_memory: True
      save_args:
        compression: "snappy"

Example usage for the
`Python API <https://kedro.readthedocs.io/en/stable/data/\
data_catalog.html#use-the-data-catalog-with-the-code-api>`_:
::
>>> from kedro_datasets.polars import GenericDataSet
>>> import polars as pl
>>>
>>> data = pl.DataFrame({'col1': [1, 2], 'col2': [4, 5],
...                      'col3': [5, 6]})
>>>
>>> data_set = GenericDataSet(filepath="test.parquet", file_format='parquet')
>>> data_set.save(data)
>>> reloaded = data_set.load()
>>> assert data.frame_equal(reloaded)
"""

DEFAULT_LOAD_ARGS = {} # type: Dict[str, Any]
DEFAULT_SAVE_ARGS = {} # type: Dict[str, Any]

# pylint: disable=too-many-arguments
def __init__(
self,
filepath: str,
file_format: str,
write_mode: str = "overwrite",
load_args: Dict[str, Any] = None,
save_args: Dict[str, Any] = None,
version: Version = None,
credentials: Dict[str, Any] = None,
fs_args: Dict[str, Any] = None,
):
"""Creates a new instance of ``GenericDataSet`` pointing to a concrete data file
on a specific filesystem. The appropriate polars load/save methods are
dynamically identified by string matching on a best effort basis.
Args:
filepath: Filepath in POSIX format to a file prefixed with a protocol like
`s3://`.
If prefix is not provided, `file` protocol (local filesystem)
will be used.
The prefix should be any protocol supported by ``fsspec``.
Key assumption: the first argument of either load/save method points to
a filepath/buffer/IO-type location. Some read/write targets, such as
'clipboard' or 'records', will fail since they do not take a
filepath-like argument.
file_format: String used to match the appropriate load/save method on a
best effort basis. For example, if 'csv' is passed, the `polars.read_csv`
and `polars.DataFrame.write_csv` methods will be identified. An error
will be raised unless at least one matching `read_{file_format}` or
`write_{file_format}` method is found.
write_mode: String which determines the behaviour of the dataset;
defaults to "overwrite". Accepted values are "overwrite" and "ignore";
use "ignore" when you want to read a file format that polars does not
provide write support for.
load_args: polars options for loading files.
Here you can find all available arguments:
https://pola-rs.github.io/polars/py-polars/html/reference/io.html
All defaults are preserved.
save_args: Polars options for saving files.
Here you can find all available arguments:
https://pola-rs.github.io/polars/py-polars/html/reference/io.html
All defaults are preserved.
version: If specified, should be an instance of
``kedro.io.core.Version``. If its ``load`` attribute is
None, the latest version will be loaded. If its ``save``
attribute is None, save version will be autogenerated.
credentials: Credentials required to get access to the underlying filesystem.
E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
fs_args: Extra arguments to pass into underlying filesystem class constructor
(e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
to pass to the filesystem's `open` method through nested keys
`open_args_load` and `open_args_save`.
Here you can find all available arguments for `open`:
https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
All defaults are preserved, except `mode`, which is set to `r` when loading
and to `w` when saving.
Raises:
DataSetError: Will be raised if no appropriate read or write method is
identified for the given 'file_format'.
"""

self._file_format = file_format.lower()

_fs_args = deepcopy(fs_args) or {}
_fs_open_args_load = _fs_args.pop("open_args_load", {})
_fs_open_args_save = _fs_args.pop("open_args_save", {})
_credentials = deepcopy(credentials) or {}

protocol, path = get_protocol_and_path(filepath)
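# For local paths, create any missing parent directories on write.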
if protocol == "file":
_fs_args.setdefault("auto_mkdir", True)

self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

super().__init__(
filepath=PurePosixPath(path),
version=version,
exists_function=self._fs.exists,
glob_function=self._fs.glob,
)

self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)

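# polars writers emit bytes (see _save), so save targets are opened in
# binary mode by default.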
_fs_open_args_save.setdefault("mode", "wb")
self._fs_open_args_load = _fs_open_args_load
self._fs_open_args_save = _fs_open_args_save
self._write_mode = write_mode
self._assert_write_mode()

def _assert_write_mode(self) -> None:
"""Check that the write mode is supported."""
if self._write_mode not in ("overwrite", "ignore"):
raise DataSetError(
f"Write mode `{self._write_mode}` is not supported. "
"Allowed values are: overwrite, ignore."
)

def _ensure_file_system_target(self) -> None:
"""Check if format is supported by file system target"""

if self._file_format not in ACCEPTED_READ_FILE_FORMATS:
raise DataSetError(
f"Unable to retrieve 'polars.read_{self._file_format}' method, please"
" ensure that your "
"'file_format' parameter has been defined correctly as per the Polars"
" API"
" https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)

def _load(self) -> pl.DataFrame:  # pylint: disable=inconsistent-return-statements

self._ensure_file_system_target()

load_path = get_filepath_str(self._get_load_path(), self._protocol)
load_method = getattr(pl, f"read_{self._file_format}", None)
if load_method:
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return load_method(fs_file, **self._load_args)

def _save(self, data: pl.DataFrame) -> None:
if (
self._write_mode == "overwrite"
and self._file_format not in ACCEPTED_WRITE_FILE_FORMATS
):
if self._file_format in ACCEPTED_READ_FILE_FORMATS:
raise DataSetError(
f"This file format is read-only: '{self._file_format}' "
f"If you want only to read, change write_mode to 'ignore'"
)
raise DataSetError(
f"Unable to retrieve 'polars.DataFrame.write_{self._file_format}' "
"method, please "
"ensure that your 'file_format' parameter has been defined correctly as"
" per the Polars API "
"https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)
if self._write_mode == "ignore":
raise DataSetError(f"Write mode '{self._write_mode}' is read-only.")
self._ensure_file_system_target()

save_path = get_filepath_str(self._get_save_path(), self._protocol)
save_method = getattr(data, f"write_{self._file_format}", None)

if save_method:
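# Serialise into an in-memory buffer first, then stream the bytes through
# fsspec, so remote targets (e.g. S3) work without polars knowing about them.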
buf = BytesIO()
save_method(file=buf, **self._save_args)
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
fs_file.write(buf.getvalue())
self._invalidate_cache()

def _exists(self) -> bool:
try:
load_path = get_filepath_str(self._get_load_path(), self._protocol)
except DataSetError:
return False

return self._fs.exists(load_path)

def _describe(self) -> Dict[str, Any]:
return {
"file_format": self._file_format,
"filepath": self._filepath,
"write_mode": self._write_mode,
"protocol": self._protocol,
"load_args": self._load_args,
"save_args": self._save_args,
"version": self._version,
}

def _release(self) -> None:
super()._release()
self._invalidate_cache()

def _invalidate_cache(self) -> None:
"""Invalidate underlying filesystem caches."""
filepath = get_filepath_str(self._filepath, self._protocol)
self._fs.invalidate_cache(filepath)
10 changes: 8 additions & 2 deletions kedro-datasets/setup.py
@@ -7,7 +7,7 @@
SPARK = "pyspark>=2.2, <4.0"
HDFS = "hdfs>=2.5.8, <3.0"
S3FS = "s3fs>=0.3.0, <0.5"
POLARS = "polars~=0.17.0"
POLARS = "polars~=0.18.0"
DELTA = "delta-spark~=1.2.1"


@@ -51,7 +51,13 @@ def _collect_requirements(requires):
"plotly.PlotlyDataSet": [PANDAS, "plotly>=4.8.0, <6.0"],
"plotly.JSONDataSet": ["plotly>=4.8.0, <6.0"],
}
polars_require = {"polars.CSVDataSet": [POLARS]}
polars_require = {
"polars.CSVDataSet": [POLARS],
"polars.GenericDataSet": [POLARS, "pyarrow>=4.0", "xlsx2csv>=0.8.0", "deltalake>=0.6.2"],
}
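# The new "polars.GenericDataSet" extra can then be installed with, e.g.:
#   pip install "kedro-datasets[polars.GenericDataSet]"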
redis_require = {"redis.PickleDataSet": ["redis~=4.1"]}
snowflake_require = {
"snowflake.SnowparkTableDataSet": [
2 changes: 1 addition & 1 deletion kedro-datasets/test_requirements.txt
@@ -33,7 +33,7 @@ pandas-gbq>=0.12.0, <0.18.0
pandas>=1.3, <2 # 1.3 for read_xml/to_xml, <2 for compatibility with Spark < 3.4
Pillow~=9.0
plotly>=4.8.0, <6.0
polars~=0.15.13
polars[xlsx2csv, deltalake]~=0.18.0
pre-commit>=2.9.2, <3.0 # The hook `mypy` requires pre-commit version 2.9.2.
psutil==5.8.0
pyarrow~=8.0