Skip to content

Commit

Permalink
100 coverage generic dataset
Browse files Browse the repository at this point in the history
Signed-off-by: wmoreiraa <walber3@gmail.com>
  • Loading branch information
wmoreiraa authored and astrojuanlu committed Jul 13, 2023
1 parent c5895a2 commit 69555e1
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 27 deletions.
23 changes: 7 additions & 16 deletions kedro-datasets/kedro_datasets/polars/generic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _assert_write_mode(self) -> None:
"""Check that the write mode is supported."""
if self._write_mode not in ("overwrite", "ignore"):
raise DataSetError(
f"Write mode {self._write_mode} not supported. "
f"Write mode `{self._write_mode}` is not supported. "
"Allowed values are: overwrite, ignore."
)

Expand All @@ -181,11 +181,14 @@ def _ensure_file_system_target(self) -> None:

if self._file_format not in ACCEPTED_READ_FILE_FORMATS:
raise DataSetError(
f"Cannot create a dataset of file_format '{self._file_format}' as it "
f"does not support a filepath target/source."
f"Unable to retrieve 'polars.read_{self._file_format}' method, please"
" ensure that your "
"'file_format' parameter has been defined correctly as per the Polars"
" API"
" https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)

def _load(self) -> pl.DataFrame:
def _load(self) -> pl.DataFrame: # pylint: disable= inconsistent-return-statements

self._ensure_file_system_target()

Expand All @@ -194,11 +197,6 @@ def _load(self) -> pl.DataFrame:
if load_method:
with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
return load_method(fs_file, **self._load_args)
raise DataSetError(
f"Unable to retrieve 'polars.read_{self._file_format}' method, please ensure that your "
"'file_format' parameter has been defined correctly as per the Polars API "
"https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)

def _save(self, data: pl.DataFrame) -> None:
if (
Expand Down Expand Up @@ -230,13 +228,6 @@ def _save(self, data: pl.DataFrame) -> None:
with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
fs_file.write(buf.getvalue())
self._invalidate_cache()
else:
raise DataSetError(
f"Unable to retrieve 'polars.DataFrame.write_{self._file_format}' method, please "
"ensure that your 'file_format' parameter has been defined correctly as "
"per the Polars API "
"https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)

def _exists(self) -> bool:
try:
Expand Down
48 changes: 37 additions & 11 deletions kedro-datasets/tests/polars/test_generic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ def filepath_excel(tmp_path):
return tmp_path / "test.xlsx"


@pytest.fixture
def parquet_data_set_ignore(dummy_dataframe: pl.DataFrame, filepath_parquet):
dummy_dataframe.write_parquet(filepath_parquet)

return GenericDataSet(
filepath=filepath_parquet.as_posix(),
file_format="parquet",
write_mode="ignore",
load_args={"low_memory": True},
)


@pytest.fixture
def excel_data_set_ignore(dummy_dataframe: pl.DataFrame, filepath_excel):
pd_df = dummy_dataframe.to_pandas()
Expand All @@ -97,13 +109,23 @@ def excel_data_set_overwrite(dummy_dataframe: pl.DataFrame, filepath_excel):
pd_df.to_excel(filepath_excel, index=False)

return GenericDataSet(
filepath=filepath_excel.as_posix(),
file_format="excel",
write_mode="overwrite",
filepath=filepath_excel.as_posix(), file_format="excel", write_mode="overwrite"
)


class TestGenericExcelDataSet:
def test_assert_write_mode(self):
pattern = (
"Write mode `test` is not supported. "
"Allowed values are: overwrite, ignore."
)
with pytest.raises(DataSetError, match=pattern):
GenericDataSet(
filepath="test.xlsx",
file_format="excel",
write_mode="test",
)

def test_load(self, excel_data_set_ignore):
df = excel_data_set_ignore.load()
assert df.shape == (2, 3)
Expand Down Expand Up @@ -161,6 +183,10 @@ def test_catalog_release(self, mocker):


class TestGenericParquetDataSetVersioned:
def test_load_args(self, parquet_data_set_ignore):
df = parquet_data_set_ignore.load()
assert df.shape == (2, 3)

def test_save_and_load(self, versioned_parquet_data_set, dummy_dataframe):
"""Test saving and reloading the data set."""
versioned_parquet_data_set.save(dummy_dataframe)
Expand Down Expand Up @@ -517,18 +543,18 @@ def test_bad_file_format_argument(self):
ds = GenericDataSet(filepath="test.kedro", file_format="kedro")

pattern = (
"Cannot create a dataset of file_format 'kedro' as"
" it does not support a filepath target/source."
"Unable to retrieve 'polars.DataFrame.write_kedro' method, please "
"ensure that your 'file_format' parameter has been defined correctly as "
"per the Polars API "
"https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)

with pytest.raises(DataSetError, match=pattern):
_ = ds.load()
ds.save(pd.DataFrame([1]))

pattern2 = (
"Unable to retrieve 'polars.DataFrame.write_kedro' method, please "
"ensure that your 'file_format' parameter has been defined correctly as "
"per the Polars API "
"Unable to retrieve 'polars.read_kedro' method, please ensure that your "
"'file_format' parameter has been defined correctly as per the Polars API "
"https://pola-rs.github.io/polars/py-polars/html/reference/io.html"
)
with pytest.raises(DataSetError, match=pattern2):
ds.save(pd.DataFrame([1]))
ds.load()

0 comments on commit 69555e1

Please sign in to comment.