diff --git a/src/safeds/data/tabular/containers/_time_series.py b/src/safeds/data/tabular/containers/_time_series.py index fee39fd94..a755885e1 100644 --- a/src/safeds/data/tabular/containers/_time_series.py +++ b/src/safeds/data/tabular/containers/_time_series.py @@ -6,6 +6,7 @@ import matplotlib.pyplot as plt import pandas as pd +import seaborn as sns from safeds.data.image.containers import Image from safeds.data.tabular.containers import Column, Row, Table, TaggedTable @@ -36,7 +37,7 @@ def _from_tagged_table( Parameters ---------- - table : TaggedTable + tagged_table: TaggedTable The tagged table. time_name: str Name of the time column. @@ -906,3 +907,150 @@ def plot_lagplot(self, lag: int) -> Image: plt.close() # Prevents the figure from being displayed directly buffer.seek(0) return Image.from_bytes(buffer.read()) + + def plot_lineplot(self, x_column_name: str | None = None, y_column_name: str | None = None) -> Image: + """ + + Plot the time series target or the given column(s) as line plot. + + The function will take the time column as the default value for y_column_name and the target column as the + default value for x_column_name. + + Parameters + ---------- + x_column_name: + The column name of the column to be plotted on the x-Axis, default is the time column. + y_column_name: + The column name of the column to be plotted on the y-Axis, default is the target column. + + Returns + ------- + plot: + The plot as an image. + + Raises + ------ + NonNumericColumnError + If the time series given columns contain non-numerical values. + + UnknownColumnNameError + If one of the given names does not exist in the table + + Examples + -------- + >>> from safeds.data.tabular.containers import TimeSeries + >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], ) + >>> image = table.plot_lineplot() + + """ + self._data.index.name = "index" + if x_column_name is not None and not self.get_column(x_column_name).type.is_numeric(): + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") + + if y_column_name is None: + y_column_name = self.target.name + + elif y_column_name not in self._data.columns: + raise UnknownColumnNameError([y_column_name]) + + if x_column_name is None: + x_column_name = self.time.name + + if not self.get_column(y_column_name).type.is_numeric(): + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") + + fig = plt.figure() + ax = sns.lineplot( + data=self._data, + x=x_column_name, + y=y_column_name, + ) + ax.set(xlabel=x_column_name, ylabel=y_column_name) + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), + rotation=45, + horizontalalignment="right", + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + plt.tight_layout() + + buffer = io.BytesIO() + fig.savefig(buffer, format="png") + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + self._data = self._data.reset_index() + return Image.from_bytes(buffer.read()) + + def plot_scatterplot( + self, + x_column_name: str | None = None, + y_column_name: str | None = None, + ) -> Image: + """ + Plot the time series target or the given column(s) as scatter plot. + + The function will take the time column as the default value for x_column_name and the target column as the + default value for y_column_name. + + Parameters + ---------- + x_column_name: + The column name of the column to be plotted on the x-Axis. + y_column_name: + The column name of the column to be plotted on the y-Axis. + + Returns + ------- + plot: + The plot as an image. + + Raises + ------ + NonNumericColumnError + If the time series given columns contain non-numerical values. + + UnknownColumnNameError + If one of the given names does not exist in the table + + Examples + -------- + >>> from safeds.data.tabular.containers import TimeSeries + >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], ) + >>> image = table.plot_scatterplot() + + """ + self._data.index.name = "index" + if x_column_name is not None and not self.get_column(x_column_name).type.is_numeric(): + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") + + if y_column_name is None: + y_column_name = self.target.name + elif y_column_name not in self._data.columns: + raise UnknownColumnNameError([y_column_name]) + if x_column_name is None: + x_column_name = self.time.name + + if not self.get_column(y_column_name).type.is_numeric(): + raise NonNumericColumnError("The time series plotted column contains non-numerical columns.") + + fig = plt.figure() + ax = sns.scatterplot( + data=self._data, + x=x_column_name, + y=y_column_name, + ) + ax.set(xlabel=x_column_name, ylabel=y_column_name) + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), + rotation=45, + horizontalalignment="right", + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + plt.tight_layout() + + buffer = io.BytesIO() + fig.savefig(buffer, format="png") + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + self._data = self._data.reset_index() + return Image.from_bytes(buffer.read()) diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature.png new file mode 100644 index 000000000..98dbeaa93 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature_x.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature_x.png new file mode 100644 index 000000000..4d4de31bf Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature_x.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature_y.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature_y.png new file mode 100644 index 000000000..b0d4d1918 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_plot_feature_y.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_return_table.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_return_table.png new file mode 100644 index 000000000..f40ea7854 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_return_table.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_return_table_both.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_return_table_both.png new file mode 100644 index 000000000..8c10979a0 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_lineplot/test_should_return_table_both.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature.png new file mode 100644 index 000000000..d52aa8a85 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_both_set.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_both_set.png new file mode 100644 index 000000000..9d6035eaf Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_both_set.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_only_x.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_only_x.png new file mode 100644 index 000000000..f44732c27 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_only_x.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_only_y_optional.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_only_y_optional.png new file mode 100644 index 000000000..d52aa8a85 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_plot_feature_only_y_optional.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_return_table.png b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_return_table.png new file mode 100644 index 000000000..f92f23bea Binary files /dev/null and b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/__snapshots__/test_plot_scatterplot/test_should_return_table.png differ diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lineplot.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lineplot.py new file mode 100644 index 000000000..a4816e4ec --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_lineplot.py @@ -0,0 +1,265 @@ +import pytest +from safeds.data.tabular.containers import TimeSeries +from safeds.exceptions import NonNumericColumnError, UnknownColumnNameError +from syrupy import SnapshotAssertion + + +def test_should_return_table(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_lineplot() + assert plot == snapshot_png + + +def test_should_raise_if_column_contains_non_numerical_values_x() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + NonNumericColumnError, + match=( + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns." + ), + ): + table.plot_lineplot(x_column_name="feature_1") + + +def test_should_return_table_both(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_lineplot(x_column_name="feature_1", y_column_name="target") + assert plot == snapshot_png + + +def test_should_plot_feature_y(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_lineplot(y_column_name="feature_1") + assert plot == snapshot_png + + +def test_should_plot_feature_x(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_lineplot(x_column_name="feature_1") + assert plot == snapshot_png + + +def test_should_plot_feature(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_lineplot(x_column_name="feature_1") + assert plot == snapshot_png + + +def test_should_raise_if_column_contains_non_numerical_values() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + NonNumericColumnError, + match=( + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns." + ), + ): + table.plot_lineplot(x_column_name="target") + + +@pytest.mark.parametrize( + ("time_series", "name", "error", "error_msg"), + [ + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_1", + NonNumericColumnError, + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns.", + ), + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_3", + UnknownColumnNameError, + r"Could not find column\(s\) 'feature_3'.", + ), + ], + ids=["feature_not_numerical", "feature_does_not_exist"], +) +def test_should_raise_error_optional_parameter( + time_series: TimeSeries, + name: str, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises( + error, + match=error_msg, + ): + time_series.plot_lineplot(x_column_name=name) + + +@pytest.mark.parametrize( + ("time_series", "name", "error", "error_msg"), + [ + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_1", + NonNumericColumnError, + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns.", + ), + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_3", + UnknownColumnNameError, + r"Could not find column\(s\) 'feature_3'.", + ), + ], + ids=["feature_not_numerical", "feature_does_not_exist"], +) +def test_should_raise_error_optional_parameter_y( + time_series: TimeSeries, + name: str, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises( + error, + match=error_msg, + ): + time_series.plot_lineplot(y_column_name=name) + + +def test_should_raise_if_column_does_not_exist_x() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + UnknownColumnNameError, + match=r"Could not find column\(s\) '2'.", + ): + table.plot_lineplot(x_column_name="target", y_column_name="2") + + +def test_should_raise_if_column_does_not_exist_y() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + UnknownColumnNameError, + match=r"Could not find column\(s\) '2'.", + ): + table.plot_lineplot(x_column_name="2", y_column_name="target") diff --git a/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_scatterplot.py b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_scatterplot.py new file mode 100644 index 000000000..822b3d755 --- /dev/null +++ b/tests/safeds/data/tabular/containers/_table/_tagged_table/_time_series/test_plot_scatterplot.py @@ -0,0 +1,265 @@ +import pytest +from safeds.data.tabular.containers import TimeSeries +from safeds.exceptions import NonNumericColumnError, UnknownColumnNameError +from syrupy import SnapshotAssertion + + +def test_should_return_table(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_scatterplot() + assert plot == snapshot_png + + +def test_should_plot_feature(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_scatterplot(y_column_name="feature_1") + assert plot == snapshot_png + + +def test_should_plot_feature_only_x(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_scatterplot(x_column_name="feature_1") + assert plot == snapshot_png + + +def test_should_plot_feature_only_y_optional(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [10, 9, 8, 7, 6, 5, 4, 3, 2, 1], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_scatterplot(y_column_name="feature_1") + assert plot == snapshot_png + + +def test_should_plot_feature_both_set(snapshot_png: SnapshotAssertion) -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 1, 2, 1, 2, 1, 2, 1, 1], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + plot = table.plot_scatterplot(x_column_name="feature_1", y_column_name="target") + assert plot == snapshot_png + + +def test_should_raise_if_column_contains_non_numerical_values() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + NonNumericColumnError, + match=( + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns." + ), + ): + table.plot_scatterplot(y_column_name="feature_1") + + +def test_should_raise_if_column_contains_non_numerical_values_x() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + NonNumericColumnError, + match=( + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns." + ), + ): + table.plot_scatterplot(x_column_name="feature_1") + + +@pytest.mark.parametrize( + ("time_series", "name", "error", "error_msg"), + [ + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_1", + NonNumericColumnError, + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns.", + ), + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_3", + UnknownColumnNameError, + r"Could not find column\(s\) 'feature_3'.", + ), + ], + ids=["feature_not_numerical", "feature_does_not_exist"], +) +def test_should_raise_error_optional_parameter( + time_series: TimeSeries, + name: str, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises( + error, + match=error_msg, + ): + time_series.plot_scatterplot(x_column_name=name) + + +@pytest.mark.parametrize( + ("time_series", "name", "error", "error_msg"), + [ + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_1", + NonNumericColumnError, + r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThe time series plotted" + r" column" + r" contains" + r" non-numerical columns.", + ), + ( + TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], + }, + target_name="target", + time_name="time", + feature_names=None, + ), + "feature_3", + UnknownColumnNameError, + r"Could not find column\(s\) 'feature_3'.", + ), + ], + ids=["feature_not_numerical", "feature_does_not_exist"], +) +def test_should_raise_error_optional_parameter_y( + time_series: TimeSeries, + name: str, + error: type[Exception], + error_msg: str, +) -> None: + with pytest.raises( + error, + match=error_msg, + ): + time_series.plot_scatterplot(y_column_name=name) + + +def test_should_raise_if_column_does_not_exist_y() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + UnknownColumnNameError, + match=r"Could not find column\(s\) '2'.", + ): + table.plot_scatterplot(x_column_name="target", y_column_name="2") + + +def test_should_raise_if_column_does_not_exist_x() -> None: + table = TimeSeries( + { + "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + "target": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + }, + target_name="target", + time_name="time", + feature_names=None, + ) + with pytest.raises( + UnknownColumnNameError, + match=r"Could not find column\(s\) '2'.", + ): + table.plot_scatterplot(x_column_name="2")