def histogram_2d(
    self,
    x_name: str,
    y_name: str,
    *,
    x_max_bin_count: int = 10,
    y_max_bin_count: int = 10,
    theme: Literal["dark", "light"] = "light",
) -> Image:
    """
    Plot two columns of the table against each other as a 2D histogram.

    Parameters
    ----------
    x_name:
        Name of the column shown on the x-axis.
    y_name:
        Name of the column shown on the y-axis.
    x_max_bin_count:
        Upper limit for the number of bins along the x-axis. Must be at least 1. Defaults to 10.
    y_max_bin_count:
        Upper limit for the number of bins along the y-axis. Must be at least 1. Defaults to 10.
    theme:
        Color theme of the plot, either "dark" or "light". Defaults to "light".

    Returns
    -------
    plot:
        The rendered plot as an image.

    Raises
    ------
    ColumnNotFoundError
        If a column does not exist.
    OutOfBoundsError
        If x_max_bin_count or y_max_bin_count is less than 1.
    ColumnTypeError
        If a column is not numeric.

    Examples
    --------
    >>> from safeds.data.tabular.containers import Table
    >>> table = Table(
    ...     {
    ...         "a": [1, 2, 3, 4, 5],
    ...         "b": [2, 3, 4, 5, 6],
    ...     }
    ... )
    >>> image = table.plot.histogram_2d("a", "b")
    """
    # Bin counts are checked before the columns, so an invalid bin count is
    # reported even when the column names are wrong as well.
    _check_bounds("x_max_bin_count", x_max_bin_count, lower_bound=_ClosedBound(1))
    _check_bounds("y_max_bin_count", y_max_bin_count, lower_bound=_ClosedBound(1))
    _plot_validation(self._table, x_name, [y_name])

    # Imported only after validation has passed.
    import matplotlib.pyplot as plt

    # Any theme other than "dark" falls back to matplotlib's default style.
    style_name = "dark_background" if theme == "dark" else "default"

    with plt.style.context(style_name):
        fig, ax = plt.subplots()

        x_values = self._table.get_column(x_name)._series
        y_values = self._table.get_column(y_name)._series
        bin_counts = (x_max_bin_count, y_max_bin_count)
        ax.hist2d(x=x_values, y=y_values, bins=bin_counts)

        ax.set_xlabel(x_name)
        ax.set_ylabel(y_name)
        ax.tick_params(axis="x", labelrotation=45)

        fig.tight_layout()

        return _figure_to_image(fig)
import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions import ColumnNotFoundError, ColumnTypeError, OutOfBoundsError
from syrupy import SnapshotAssertion

from tests.helpers import os_mac, skip_if_os


@pytest.mark.parametrize(
    ("table", "col1", "col2"),
    [
        pytest.param(
            Table({"A": [1, 2, 3], "B": [2, 4, 7]}),
            "A",
            "B",
            id="functional",
        ),
        pytest.param(
            Table(
                {
                    "A": [1, 0.99, 0.99, 2],
                    "B": [1, 0.99, 1.01, 2],
                },
            ),
            "A",
            "B",
            id="overlapping",
        ),
        pytest.param(
            Table(
                {
                    "A": [1, 0.99, 0.99, 2],
                    "B": [1, 0.99, 1.01, 2],
                    "C": [2, 2.99, 2.01, 3],
                },
            ),
            "A",
            "B",
            id="multiple",
        ),
    ],
)
def test_should_match_snapshot(
    table: Table,
    col1: str,
    col2: str,
    snapshot_png_image: SnapshotAssertion,
) -> None:
    """The default (light) 2D histogram must equal the stored PNG snapshot."""
    skip_if_os([os_mac])
    plot = table.plot.histogram_2d(col1, col2)
    assert plot == snapshot_png_image


@pytest.mark.parametrize(
    ("table", "col1", "col2"),
    [
        pytest.param(
            Table({"A": [1, 2, 3], "B": [2, 4, 7]}),
            "C",
            "A",
            id="First argument doesn't exist",
        ),
        pytest.param(
            Table({"A": [1, 2, 3], "B": [2, 4, 7]}),
            "B",
            "C",
            id="Second argument doesn't exist",
        ),
        pytest.param(
            Table({"A": [1, 2, 3], "B": [2, 4, 7]}),
            "C",
            "D",
            id="Both arguments do not exist",
        ),
        pytest.param(Table(), "C", "D", id="empty"),
    ],
)
def test_should_raise_if_column_does_not_exist(table: Table, col1: str, col2: str) -> None:
    """Requesting a column that is not in the table raises ColumnNotFoundError."""
    with pytest.raises(ColumnNotFoundError):
        table.plot.histogram_2d(col1, col2)


@pytest.mark.parametrize(
    ("table", "col1", "col2"),
    [
        pytest.param(
            Table({"A": [1, 2, 3], "B": [2, 4, 7]}),
            "A",
            "B",
            id="dark_theme",
        ),
    ],
)
def test_should_match_snapshot_dark_theme(
    table: Table,
    col1: str,
    col2: str,
    snapshot_png_image: SnapshotAssertion,
) -> None:
    """The dark-themed 2D histogram must equal the stored PNG snapshot."""
    skip_if_os([os_mac])
    plot = table.plot.histogram_2d(col1, col2, theme="dark")
    assert plot == snapshot_png_image


def test_should_raise_if_value_not_in_range_x() -> None:
    """An x bin count of 0 is rejected with OutOfBoundsError."""
    data = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]})
    with pytest.raises(OutOfBoundsError):
        data.plot.histogram_2d("col1", "col2", x_max_bin_count=0)


def test_should_raise_if_value_not_in_range_y() -> None:
    """A y bin count of 0 is rejected with OutOfBoundsError."""
    data = Table({"col1": [1, 2, 1], "col2": [1, 2, 4]})
    with pytest.raises(OutOfBoundsError):
        data.plot.histogram_2d("col1", "col2", y_max_bin_count=0)


def test_should_raise_if_column_is_not_numeric() -> None:
    """String columns are rejected with ColumnTypeError."""
    data = Table({"col1": ["a"], "col2": ["b"]})
    with pytest.raises(ColumnTypeError):
        data.plot.histogram_2d("col1", "col2")