Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add moving average plot #836

Merged
merged 30 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9660618
added moving average plot
Gerhardsa0 Jun 11, 2024
dd7b388
added moving average plot
Gerhardsa0 Jun 11, 2024
547325a
linter change
Gerhardsa0 Jun 11, 2024
ce250b9
style: apply automated linter fixes
megalinter-bot Jun 11, 2024
2eb7fc4
Merge branch 'main' into 521-feat-add-moving-average-plot-1
Gerhardsa0 Jun 11, 2024
9c882d9
for numerical columns the plot shows the values inbetween
Gerhardsa0 Jun 20, 2024
05fd84d
Merge remote-tracking branch 'origin/521-feat-add-moving-average-plot…
Gerhardsa0 Jun 20, 2024
d99898e
style: apply automated linter fixes
megalinter-bot Jun 20, 2024
ff91a74
fixed snapshots
Gerhardsa0 Jun 20, 2024
2590cee
style: apply automated linter fixes
megalinter-bot Jun 20, 2024
b5bad15
for some reason the numerical grouped snapshot fails
Gerhardsa0 Jun 20, 2024
abe17be
Merge remote-tracking branch 'origin/521-feat-add-moving-average-plot…
Gerhardsa0 Jun 20, 2024
2c76a38
style: apply automated linter fixes
megalinter-bot Jun 20, 2024
9f4c7a9
Merge branch 'main' into 521-feat-add-moving-average-plot-1
Gerhardsa0 Jun 20, 2024
1d781cc
added GRU layer
Gerhardsa0 Jun 20, 2024
23ea410
Merge remote-tracking branch 'origin/521-feat-add-moving-average-plot…
Gerhardsa0 Jun 20, 2024
8dd5b87
fixed docs
Gerhardsa0 Jun 20, 2024
565ca55
added missing values for plot
Gerhardsa0 Jun 20, 2024
9133db9
linter changes and better error message
Gerhardsa0 Jun 20, 2024
f360168
style: apply automated linter fixes
megalinter-bot Jun 20, 2024
3ffc640
linter changes and better error message
Gerhardsa0 Jun 20, 2024
c7002c5
Merge remote-tracking branch 'origin/521-feat-add-moving-average-plot…
Gerhardsa0 Jun 20, 2024
b88dfeb
removed grouped numerical
Gerhardsa0 Jun 20, 2024
8e6c178
style: apply automated linter fixes
megalinter-bot Jun 20, 2024
dcf27f3
changed one example
Gerhardsa0 Jun 20, 2024
b459805
Merge remote-tracking branch 'origin/521-feat-add-moving-average-plot…
Gerhardsa0 Jun 20, 2024
0d484f5
changed moving average plot, so it does not takes missing values
Gerhardsa0 Jun 20, 2024
2bd14bc
changed moving average plot, so it does not takes missing values
Gerhardsa0 Jun 20, 2024
7b31d53
Merge branch 'main' into 521-feat-add-moving-average-plot-1
Gerhardsa0 Jun 21, 2024
d557ce4
style: apply automated linter fixes
megalinter-bot Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/safeds/data/tabular/plotting/_table_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,41 @@ def scatter_plot(self, x_name: str, y_names: list[str]) -> Image:

return _figure_to_image(fig)

def moving_average_plot(self, x_name: str, y_name: str, window_size: int) -> Image:
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
import matplotlib.pyplot as plt
import numpy as np
import polars as pl

_plot_validation(self._table, x_name, [y_name])
# Calculate the moving average
mean_col = pl.col(y_name).mean().alias(y_name)
grouped = self._table._lazy_frame.sort(x_name).group_by(x_name).agg(mean_col).collect()
data = grouped
moving_average = data.select([pl.col(y_name).rolling_mean(window_size).alias("moving_average")])
# set up the arrays for plotting
y_data_with_nan = moving_average["moving_average"].to_numpy()
nan_mask = ~np.isnan(y_data_with_nan)
y_data = y_data_with_nan[nan_mask]
x_data = data[x_name].to_numpy()[nan_mask]
fig, ax = plt.subplots()
ax.plot(x_data, y_data, label="moving average")
ax.set(
xlabel=x_name,
ylabel=y_name,
)
ax.legend()
if self._table.get_column(x_name).is_temporal:
ax.set_xticks(x_data) # Set x-ticks to the x data points
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(
ax.get_xticklabels(),
rotation=45,
horizontalalignment="right",
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
fig.tight_layout()

return _figure_to_image(fig)


def _plot_validation(table: Table, x_name: str, y_names: list[str]) -> None:
y_names.append(x_name)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
58 changes: 58 additions & 0 deletions tests/safeds/data/tabular/plotting/test_moving_average_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import datetime

import pytest
from safeds.data.tabular.containers import Table
from syrupy import SnapshotAssertion


@pytest.mark.parametrize(
("table", "x_name", "y_name", "window_size"),
[
(Table({"A": [1, 2, 3], "B": [2, 4, 7]}), "A", "B", 2),
# (Table({"A": [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], "B": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]}), "A", "B", 2),
(
Table(
{
"time": [
datetime.date(2022, 1, 10),
datetime.date(2022, 1, 10),
datetime.date(2022, 1, 11),
datetime.date(2022, 1, 11),
datetime.date(2022, 1, 12),
datetime.date(2022, 1, 12),
],
"A": [10, 5, 20, 2, 1, 1],
},
),
"time",
"A",
2,
),
(
Table(
{
"time": [
datetime.date(2022, 1, 9),
datetime.date(2022, 1, 10),
datetime.date(2022, 1, 11),
datetime.date(2022, 1, 12),
],
"A": [10, 5, 20, 2],
},
),
"time",
"A",
2,
),
],
ids=["numerical", "date grouped", "date"],
sibre28 marked this conversation as resolved.
Show resolved Hide resolved
)
def test_should_match_snapshot(
sibre28 marked this conversation as resolved.
Show resolved Hide resolved
table: Table,
x_name: str,
y_name: str,
window_size: int,
snapshot_png_image: SnapshotAssertion,
) -> None:
line_plot = table.plot.moving_average_plot(x_name, y_name, window_size)
assert line_plot == snapshot_png_image