diff --git a/kedro-datasets/kedro_datasets/polars/generic_dataset.py b/kedro-datasets/kedro_datasets/polars/generic_dataset.py index 2c166a653..2eb4f627c 100644 --- a/kedro-datasets/kedro_datasets/polars/generic_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/generic_dataset.py @@ -172,7 +172,7 @@ def _assert_write_mode(self) -> None: """Check that the write mode is supported.""" if self._write_mode not in ("overwrite", "ignore"): raise DataSetError( - f"Write mode {self._write_mode} not supported. " + f"Write mode `{self._write_mode}` is not supported. " "Allowed values are: overwrite, ignore." ) @@ -181,11 +181,14 @@ def _ensure_file_system_target(self) -> None: if self._file_format not in ACCEPTED_READ_FILE_FORMATS: raise DataSetError( - f"Cannot create a dataset of file_format '{self._file_format}' as it " - f"does not support a filepath target/source." + f"Unable to retrieve 'polars.read_{self._file_format}' method, please" + " ensure that your " + "'file_format' parameter has been defined correctly as per the Polars" + " API" + " https://pola-rs.github.io/polars/py-polars/html/reference/io.html" ) - def _load(self) -> pl.DataFrame: + def _load(self) -> pl.DataFrame: # pylint: disable= inconsistent-return-statements self._ensure_file_system_target() @@ -194,11 +197,6 @@ def _load(self) -> pl.DataFrame: if load_method: with self._fs.open(load_path, **self._fs_open_args_load) as fs_file: return load_method(fs_file, **self._load_args) - raise DataSetError( - f"Unable to retrieve 'polars.read_{self._file_format}' method, please ensure that your " - "'file_format' parameter has been defined correctly as per the Polars API " - "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" - ) def _save(self, data: pl.DataFrame) -> None: if ( @@ -230,13 +228,6 @@ def _save(self, data: pl.DataFrame) -> None: with self._fs.open(save_path, **self._fs_open_args_save) as fs_file: fs_file.write(buf.getvalue()) self._invalidate_cache() - else: - raise DataSetError( - f"Unable to retrieve 'polars.DataFrame.write_{self._file_format}' method, please " - "ensure that your 'file_format' parameter has been defined correctly as " - "per the Polars API " - "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" - ) def _exists(self) -> bool: try: diff --git a/kedro-datasets/tests/polars/test_generic_dataset.py b/kedro-datasets/tests/polars/test_generic_dataset.py index 8c7a223f7..62009443a 100644 --- a/kedro-datasets/tests/polars/test_generic_dataset.py +++ b/kedro-datasets/tests/polars/test_generic_dataset.py @@ -79,6 +79,18 @@ def filepath_excel(tmp_path): return tmp_path / "test.xlsx" +@pytest.fixture +def parquet_data_set_ignore(dummy_dataframe: pl.DataFrame, filepath_parquet): + dummy_dataframe.write_parquet(filepath_parquet) + + return GenericDataSet( + filepath=filepath_parquet.as_posix(), + file_format="parquet", + write_mode="ignore", + load_args={"low_memory": True}, + ) + + @pytest.fixture def excel_data_set_ignore(dummy_dataframe: pl.DataFrame, filepath_excel): pd_df = dummy_dataframe.to_pandas() @@ -97,13 +109,23 @@ def excel_data_set_overwrite(dummy_dataframe: pl.DataFrame, filepath_excel): pd_df.to_excel(filepath_excel, index=False) return GenericDataSet( - filepath=filepath_excel.as_posix(), - file_format="excel", - write_mode="overwrite", + filepath=filepath_excel.as_posix(), file_format="excel", write_mode="overwrite" ) class TestGenericExcelDataSet: + def test_assert_write_mode(self): + pattern = ( + "Write mode `test` is not supported. " + "Allowed values are: overwrite, ignore." + ) + with pytest.raises(DataSetError, match=pattern): + GenericDataSet( + filepath="test.xlsx", + file_format="excel", + write_mode="test", + ) + def test_load(self, excel_data_set_ignore): df = excel_data_set_ignore.load() assert df.shape == (2, 3) @@ -161,6 +183,10 @@ def test_catalog_release(self, mocker): class TestGenericParquetDataSetVersioned: + def test_load_args(self, parquet_data_set_ignore): + df = parquet_data_set_ignore.load() + assert df.shape == (2, 3) + def test_save_and_load(self, versioned_parquet_data_set, dummy_dataframe): """Test saving and reloading the data set.""" versioned_parquet_data_set.save(dummy_dataframe) @@ -517,18 +543,18 @@ def test_bad_file_format_argument(self): ds = GenericDataSet(filepath="test.kedro", file_format="kedro") pattern = ( - "Cannot create a dataset of file_format 'kedro' as" - " it does not support a filepath target/source." + "Unable to retrieve 'polars.DataFrame.write_kedro' method, please " + "ensure that your 'file_format' parameter has been defined correctly as " + "per the Polars API " + "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" ) - with pytest.raises(DataSetError, match=pattern): - _ = ds.load() + ds.save(pd.DataFrame([1])) pattern2 = ( - "Unable to retrieve 'polars.DataFrame.write_kedro' method, please " - "ensure that your 'file_format' parameter has been defined correctly as " - "per the Polars API " + "Unable to retrieve 'polars.read_kedro' method, please ensure that your " + "'file_format' parameter has been defined correctly as per the Polars API " "https://pola-rs.github.io/polars/py-polars/html/reference/io.html" ) with pytest.raises(DataSetError, match=pattern2): - ds.save(pd.DataFrame([1])) + ds.load()