Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement violin plots #900

Merged
merged 16 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions docs/tutorials/data_visualization.ipynb

Large diffs are not rendered by default.

68 changes: 67 additions & 1 deletion src/safeds/data/tabular/plotting/_column_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def box_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
if self._column.row_count > 0:
_check_column_is_numeric(self._column, operation="create a box plot")

import matplotlib.pyplot as plt

def _set_boxplot_colors(box: dict, theme: str) -> None:
Expand Down Expand Up @@ -127,6 +126,73 @@ def _set_boxplot_colors(box: dict, theme: str) -> None:

return _figure_to_image(fig)

def violin_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
Create a violin plot for the values in the column. This is only possible for numeric columns.

Muellersen marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
theme:
The color theme of the plot. Default is "light".

Returns
-------
plot:
The violin plot as an image.

Raises
------
TypeError
If the column is not numeric.

Examples
--------
>>> from safeds.data.tabular.containers import Column
>>> column = Column("test", [1, 2, 3])
>>> violinplot = column.plot.violin_plot()
"""
if self._column.row_count > 0:
_check_column_is_numeric(self._column, operation="create a violin plot")
from math import nan

import matplotlib.pyplot as plt

style = "dark_background" if theme == "dark" else "default"
with plt.style.context(style):
if theme == "dark":
plt.rcParams.update(
{
"text.color": "white",
"axes.labelcolor": "white",
"axes.edgecolor": "white",
"xtick.color": "white",
"ytick.color": "white",
"grid.color": "gray",
"grid.linewidth": 0.5,
},
)
else:
plt.rcParams.update(
{
"grid.linewidth": 0.5,
},
)

fig, ax = plt.subplots()
data = self._column._series.drop_nulls()
if len(data) == 0:
data = [nan, nan]
ax.violinplot(
data,
)

ax.set(title=self._column.name)

ax.yaxis.grid(visible=True)
fig.tight_layout()

return _figure_to_image(fig)

def histogram(self, *, max_bin_count: int = 10, theme: Literal["dark", "light"] = "light") -> Image:
"""
Create a histogram for the values in the column.
Expand Down
106 changes: 94 additions & 12 deletions src/safeds/data/tabular/plotting/_table_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,102 @@ def box_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image:
fig.delaxes(axs[number_of_rows - 1, i])

fig.tight_layout()
return _figure_to_image(fig)

def violin_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
Create a violin plot for every numerical column.

Muellersen marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
theme:
The color theme of the plot. Default is "light".

Returns
-------
plot:
The violin plot(s) as an image.

Raises
------
NonNumericColumnError
If the table contains only non-numerical columns.

Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2], "b": [3, 42]})
>>> image = table.plot.violin_plots()
"""
numerical_table = self._table.remove_non_numeric_columns()
if numerical_table.column_count == 0:
raise NonNumericColumnError("This table contains only non-numerical columns.")
from math import ceil

import matplotlib.pyplot as plt

style = "dark_background" if theme == "dark" else "default"
with plt.style.context(style):
if theme == "dark":
plt.rcParams.update(
{
"text.color": "white",
"axes.labelcolor": "white",
"axes.edgecolor": "white",
"xtick.color": "white",
"ytick.color": "white",
"grid.color": "gray",
"grid.linewidth": 0.5,
},
)
else:
plt.rcParams.update(
{
"grid.linewidth": 0.5,
},
)

columns = numerical_table.to_columns()
columns = [column._series.drop_nulls() for column in columns]
max_width = 3
number_of_columns = len(columns) if len(columns) <= max_width else max_width
number_of_rows = ceil(len(columns) / number_of_columns)

fig, axs = plt.subplots(nrows=number_of_rows, ncols=number_of_columns)
line = 0
for i, column in enumerate(columns):
data = column.to_list()

if i % number_of_columns == 0 and i != 0:
line += 1

if number_of_columns == 1:
axs.violinplot(
data,
)
axs.set_title(numerical_table.column_names[i])
break

style = "dark_background" if theme == "dark" else "default"
with plt.style.context(style):
if theme == "dark":
plt.rcParams.update(
{
"text.color": "white",
"axes.labelcolor": "white",
"axes.edgecolor": "white",
"xtick.color": "white",
"ytick.color": "white",
},
if number_of_rows == 1:
axs[i].violinplot(
data,
)
axs[i].set_title(numerical_table.column_names[i])

else:
axs[line, i % number_of_columns].violinplot(
data,
)
return _figure_to_image(fig)
axs[line, i % number_of_columns].set_title(numerical_table.column_names[i])

# removes unused ax indices, so there wont be empty plots
last_filled_ax_index = len(columns) % number_of_columns
for i in range(last_filled_ax_index, number_of_columns):
if number_of_rows != 1 and last_filled_ax_index != 0:
fig.delaxes(axs[number_of_rows - 1, i])

fig.tight_layout()
return _figure_to_image(fig)

def correlation_heatmap(self, *, theme: Literal["dark", "light"] = "light") -> Image:
"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import ColumnTypeError
from syrupy import SnapshotAssertion


@pytest.mark.parametrize(
"column",
[
Column("a", []),
Column("a", [0]),
Column("a", [0, 1]),
],
ids=[
"empty",
"one row",
"multiple rows",
],
)
def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None:
violin_plot = column.plot.violin_plot()
assert violin_plot == snapshot_png_image


@pytest.mark.parametrize(
"column",
[
Column("a", []),
Column("a", [0]),
Column("a", [0, 1]),
],
ids=[
"empty",
"one row",
"multiple rows",
],
)
def test_should_match_dark_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None:
violin_plot = column.plot.violin_plot(theme="dark")
assert violin_plot == snapshot_png_image


def test_should_raise_if_column_contains_non_numerical_values() -> None:
column = Column("a", ["A", "B", "C"])
with pytest.raises(ColumnTypeError):
column.plot.violin_plot()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
49 changes: 49 additions & 0 deletions tests/safeds/data/tabular/plotting/test_plot_violin_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.exceptions import NonNumericColumnError
from syrupy import SnapshotAssertion


@pytest.mark.parametrize(
"table",
[
Table({"A": [1, 2, 3]}),
Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}),
Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}),
],
ids=["one column", "four columns (some non-numeric)", "four columns (all numeric)"],
)
def test_should_match_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
violinplots = table.plot.violin_plots()
assert violinplots == snapshot_png_image


@pytest.mark.parametrize(
"table",
[
Table({"A": [1, 2, 3]}),
Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}),
Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}),
],
ids=["one column", "four columns (some non-numeric)", "four columns (all numeric)"],
)
def test_should_match_dark_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
violinplots = table.plot.violin_plots(theme="dark")
assert violinplots == snapshot_png_image


def test_should_raise_if_column_contains_non_numerical_values() -> None:
table = Table.from_dict({"A": ["1", "2", "3.5"], "B": ["0.2", "4", "77"]})
with pytest.raises(
NonNumericColumnError,
match=(
r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis table contains only"
r" non-numerical columns."
),
):
table.plot.violin_plots()


def test_should_fail_on_empty_table() -> None:
with pytest.raises(NonNumericColumnError):
Table().plot.violin_plots()