From 5e3da8d26250b9f06a3209a35e26e296890fc6a2 Mon Sep 17 00:00:00 2001 From: Robin <91134475+robmeth@users.noreply.github.com> Date: Fri, 7 Jul 2023 10:22:15 +0200 Subject: [PATCH] feat: discretize table (#327) Closes #143. ### Summary of Changes * Added a class `Discretizer` in `safeds.data.tabular.transformation` that wraps the [`KBinsDiscretizer` of `scikit-learn`](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.KBinsDiscretizer.html) * Made the class a subclass of `TableTransformer` * The `__init__` for now only has a parameter `number_of_bins` to control how many bins are created * If `number_of_bins` is less than 2, it raises a `ValueError` --- .../data/tabular/transformation/__init__.py | 2 + .../tabular/transformation/_discretizer.py | 205 ++++++++++++ .../transformation/test_discretizer.py | 301 ++++++++++++++++++ 3 files changed, 508 insertions(+) create mode 100644 src/safeds/data/tabular/transformation/_discretizer.py create mode 100644 tests/safeds/data/tabular/transformation/test_discretizer.py diff --git a/src/safeds/data/tabular/transformation/__init__.py b/src/safeds/data/tabular/transformation/__init__.py index edf45242b..215573798 100644 --- a/src/safeds/data/tabular/transformation/__init__.py +++ b/src/safeds/data/tabular/transformation/__init__.py @@ -1,5 +1,6 @@ """Classes for transforming tabular data.""" +from ._discretizer import Discretizer from ._imputer import Imputer from ._label_encoder import LabelEncoder from ._one_hot_encoder import OneHotEncoder @@ -14,5 +15,6 @@ "InvertibleTableTransformer", "TableTransformer", "RangeScaler", + "Discretizer", "StandardScaler", ] diff --git a/src/safeds/data/tabular/transformation/_discretizer.py b/src/safeds/data/tabular/transformation/_discretizer.py new file mode 100644 index 000000000..38a6cf0f0 --- /dev/null +++ b/src/safeds/data/tabular/transformation/_discretizer.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +from sklearn.preprocessing import 
KBinsDiscretizer as sk_KBinsDiscretizer + +from safeds.data.tabular.containers import Table +from safeds.data.tabular.transformation._table_transformer import TableTransformer +from safeds.exceptions import NonNumericColumnError, TransformerNotFittedError, UnknownColumnNameError + + +class Discretizer(TableTransformer): + """ + The Discretizer bins continuous data into intervals. + + Parameters + ---------- + number_of_bins: float + The number of bins to be created. + + Raises + ------ + ValueError + If the given number_of_bins is less than 2. + """ + + def __init__(self, number_of_bins: float = 5): + self._column_names: list[str] | None = None + self._wrapped_transformer: sk_KBinsDiscretizer | None = None + + if number_of_bins < 2: + raise ValueError("Parameter 'number_of_bins' must be >= 2.") + self._number_of_bins = number_of_bins + + def fit(self, table: Table, column_names: list[str] | None) -> Discretizer: + """ + Learn a transformation for a set of columns in a table. + + This transformer is not modified. + + Parameters + ---------- + table : Table + The table used to fit the transformer. + column_names : list[str] | None + The list of columns from the table used to fit the transformer. If `None`, all columns are used. + + Returns + ------- + fitted_transformer : TableTransformer + The fitted transformer. + + Raises + ------ + ValueError + If the table is empty. + NonNumericColumnError + If one of the columns that should be fitted is non-numeric. + UnknownColumnNameError + If one of the columns that should be fitted is not in the table. 
+ """ + if table.number_of_rows == 0: + raise ValueError("The Discretizer cannot be fitted because the table contains 0 rows") + + if column_names is None: + column_names = table.column_names + else: + missing_columns = set(column_names) - set(table.column_names) + if len(missing_columns) > 0: + raise UnknownColumnNameError( + sorted( + missing_columns, + key={val: ix for ix, val in enumerate(column_names)}.__getitem__, + ), + ) + + for column in column_names: + if not table.get_column(column).type.is_numeric(): + raise NonNumericColumnError(f"{column} is of type {table.get_column(column).type}.") + + wrapped_transformer = sk_KBinsDiscretizer(n_bins=self._number_of_bins, encode="ordinal") + wrapped_transformer.fit(table._data[column_names]) + + result = Discretizer(self._number_of_bins) + result._wrapped_transformer = wrapped_transformer + result._column_names = column_names + + return result + + def transform(self, table: Table) -> Table: + """ + Apply the learned transformation to a table. + + The table is not modified. + + Parameters + ---------- + table : Table + The table to which the learned transformation is applied. + + Returns + ------- + transformed_table : Table + The transformed table. + + Raises + ------ + TransformerNotFittedError + If the transformer has not been fitted yet. + ValueError + If the table is empty. + UnknownColumnNameError + If one of the columns that should be transformed is not in the table. + NonNumericColumnError + If one of the columns that should be transformed is non-numeric. 
+ """ + # Transformer has not been fitted yet + if self._wrapped_transformer is None or self._column_names is None: + raise TransformerNotFittedError + + if table.number_of_rows == 0: + raise ValueError("The table cannot be transformed because it contains 0 rows") + + # Input table does not contain all columns used to fit the transformer + missing_columns = set(self._column_names) - set(table.column_names) + if len(missing_columns) > 0: + raise UnknownColumnNameError( + sorted( + missing_columns, + key={val: ix for ix, val in enumerate(self._column_names)}.__getitem__, + ), + ) + + for column in self._column_names: + if not table.get_column(column).type.is_numeric(): + raise NonNumericColumnError(f"{column} is of type {table.get_column(column).type}.") + + data = table._data.copy() + data.columns = table.column_names + data[self._column_names] = self._wrapped_transformer.transform(data[self._column_names]) + return Table._from_pandas_dataframe(data) + + def is_fitted(self) -> bool: + """ + Check if the transformer is fitted. + + Returns + ------- + is_fitted : bool + Whether the transformer is fitted. + """ + return self._wrapped_transformer is not None + + def get_names_of_added_columns(self) -> list[str]: + """ + Get the names of all new columns that have been added by the Discretizer. + + Returns + ------- + added_columns : list[str] + A list of names of the added columns, ordered as they will appear in the table. + + Raises + ------ + TransformerNotFittedError + If the transformer has not been fitted yet. + """ + if not self.is_fitted(): + raise TransformerNotFittedError + return [] + + # (Must implement abstract method, cannot instantiate class otherwise.) + def get_names_of_changed_columns(self) -> list[str]: + """ + Get the names of all columns that may have been changed by the Discretizer. + + Returns + ------- + changed_columns : list[str] + The list of (potentially) changed column names, as passed to fit. 
+ + Raises + ------ + TransformerNotFittedError + If the transformer has not been fitted yet. + """ + if self._column_names is None: + raise TransformerNotFittedError + return self._column_names + + def get_names_of_removed_columns(self) -> list[str]: + """ + Get the names of all columns that have been removed by the Discretizer. + + Returns + ------- + removed_columns : list[str] + A list of names of the removed columns, ordered as they appear in the table the Discretizer was fitted on. + + Raises + ------ + TransformerNotFittedError + If the transformer has not been fitted yet. + """ + if not self.is_fitted(): + raise TransformerNotFittedError + return [] diff --git a/tests/safeds/data/tabular/transformation/test_discretizer.py b/tests/safeds/data/tabular/transformation/test_discretizer.py new file mode 100644 index 000000000..b4a69971b --- /dev/null +++ b/tests/safeds/data/tabular/transformation/test_discretizer.py @@ -0,0 +1,301 @@ +import pytest +from safeds.data.tabular.containers import Table +from safeds.data.tabular.transformation import Discretizer +from safeds.exceptions import NonNumericColumnError, TransformerNotFittedError, UnknownColumnNameError + + +class TestInit: + def test_should_raise_value_error(self) -> None: + with pytest.raises(ValueError, match="Parameter 'number_of_bins' must be >= 2."): + _ = Discretizer(1) + + +class TestFit: + @pytest.mark.parametrize( + ("table", "columns", "error", "error_message"), + [ + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + }, + ), + ["col2"], + UnknownColumnNameError, + r"Could not find column\(s\) 'col2'", + ), + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + "col2": [0.0, 5.0, 5.0, 10.0], + "col3": [0.0, 5.0, 5.0, 10.0], + }, + ), + ["col4", "col5"], + UnknownColumnNameError, + r"Could not find column\(s\) 'col4, col5'", + ), + (Table(), ["col2"], ValueError, "The Discretizer cannot be fitted because the table contains 0 rows"), + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + "col2": ["a", "b", 
"c", "d"], + }, + ), + ["col2"], + NonNumericColumnError, + "Tried to do a numerical operation on one or multiple non-numerical columns: \ncol2 is of type String.", + ), + ], + ids=["UnknownColumnNameError", "multiple missing columns", "ValueError", "NonNumericColumnError"], + ) + def test_should_raise_errors( + self, + table: Table, + columns: list[str], + error: type[Exception], + error_message: str, + ) -> None: + with pytest.raises(error, match=error_message): + Discretizer().fit(table, columns) + + def test_should_not_change_original_transformer(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + transformer = Discretizer() + transformer.fit(table, None) + + assert transformer._wrapped_transformer is None + assert transformer._column_names is None + + +class TestTransform: + @pytest.mark.parametrize( + ("table_to_transform", "columns", "error", "error_message"), + [ + ( + Table( + { + "col2": ["a", "b", "c"], + }, + ), + ["col1"], + UnknownColumnNameError, + r"Could not find column\(s\) 'col1'", + ), + ( + Table( + { + "col2": ["a", "b", "c"], + }, + ), + ["col3", "col1"], + UnknownColumnNameError, + r"Could not find column\(s\) 'col3, col1'", + ), + (Table(), ["col1", "col3"], ValueError, "The table cannot be transformed because it contains 0 rows"), + ( + Table( + { + "col1": ["a", "b", "c", "d"], + }, + ), + ["col1"], + NonNumericColumnError, + "Tried to do a numerical operation on one or multiple non-numerical columns: \ncol1 is of type String.", + ), + ], + ids=["UnknownColumnNameError", "multiple missing columns", "ValueError", "NonNumericColumnError"], + ) + def test_should_raise_errors( + self, + table_to_transform: Table, + columns: list[str], + error: type[Exception], + error_message: str, + ) -> None: + table_to_fit = Table( + { + "col1": [0.0, 5.0, 10.0], + "col3": [0.0, 5.0, 10.0], + }, + ) + + transformer = Discretizer().fit(table_to_fit, columns) + + with pytest.raises(error, match=error_message): + 
transformer.transform(table_to_transform) + + def test_should_raise_if_not_fitted(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + transformer = Discretizer() + + with pytest.raises(TransformerNotFittedError, match=r"The transformer has not been fitted yet."): + transformer.transform(table) + + +class TestIsFitted: + def test_should_return_false_before_fitting(self) -> None: + transformer = Discretizer() + assert not transformer.is_fitted() + + def test_should_return_true_after_fitting(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + transformer = Discretizer() + fitted_transformer = transformer.fit(table, None) + assert fitted_transformer.is_fitted() + + +class TestFitAndTransform: + @pytest.mark.parametrize( + ("table", "column_names", "expected"), + [ + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + }, + ), + None, + Table( + { + "col1": [0.0, 2.0, 2.0, 3.0], + }, + ), + ), + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + "col2": [0.0, 5.0, 5.0, 10.0], + }, + ), + ["col1"], + Table( + { + "col1": [0.0, 2.0, 2.0, 3.0], + "col2": [0.0, 5.0, 5.0, 10.0], + }, + ), + ), + ], + ids=["None", "col1"], + ) + def test_should_return_transformed_table( + self, + table: Table, + column_names: list[str] | None, + expected: Table, + ) -> None: + assert Discretizer().fit_and_transform(table, column_names) == expected + + @pytest.mark.parametrize( + ("table", "number_of_bins", "expected"), + [ + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + }, + ), + 2, + Table( + { + "col1": [0, 1.0, 1.0, 1.0], + }, + ), + ), + ( + Table( + { + "col1": [0.0, 5.0, 5.0, 10.0], + }, + ), + 10, + Table( + { + "col1": [0.0, 4.0, 4.0, 7.0], + }, + ), + ), + ], + ids=["2", "10"], + ) + def test_should_return_transformed_table_with_correct_number_of_bins( + self, + table: Table, + number_of_bins: int, + expected: Table, + ) -> None: + assert Discretizer(number_of_bins).fit_and_transform(table, ["col1"]) == expected + + def 
test_should_not_change_original_table(self) -> None: + table = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + Discretizer().fit_and_transform(table) + + expected = Table( + { + "col1": [0.0, 5.0, 10.0], + }, + ) + + assert table == expected + + def test_get_names_of_added_columns(self) -> None: + transformer = Discretizer() + with pytest.raises(TransformerNotFittedError, match=r"The transformer has not been fitted yet."): + transformer.get_names_of_added_columns() + + table = Table( + { + "a": [0.0], + }, + ) + transformer = transformer.fit(table, None) + assert transformer.get_names_of_added_columns() == [] + + def test_get_names_of_changed_columns(self) -> None: + transformer = Discretizer() + with pytest.raises(TransformerNotFittedError, match=r"The transformer has not been fitted yet."): + transformer.get_names_of_changed_columns() + table = Table( + { + "a": [0.0], + }, + ) + transformer = transformer.fit(table, None) + assert transformer.get_names_of_changed_columns() == ["a"] + + def test_get_names_of_removed_columns(self) -> None: + transformer = Discretizer() + with pytest.raises(TransformerNotFittedError, match=r"The transformer has not been fitted yet."): + transformer.get_names_of_removed_columns() + + table = Table( + { + "a": [0.0], + }, + ) + transformer = transformer.fit(table, None) + assert transformer.get_names_of_removed_columns() == []