diff --git a/doc/changelog.qmd b/doc/changelog.qmd index 54ee61cb4..657922196 100644 --- a/doc/changelog.qmd +++ b/doc/changelog.qmd @@ -13,6 +13,16 @@ title: Changelog - Fixed [](:class:`~plotnine.geom_text`) with `adjust_text` for some cases where the text are placed outside the panels. {{< issue 899 >}} +- The default aesthetics and aesthetic parameters of geoms (and stats) are + now transformed. So you can now set date values as aesthetic parameters. + e.g. + + ```python + geom_point(y=datetime(2024, 12, 11)) + ``` + + if the `y` aesthetic is mapped to datetime column in another layer. + ## v0.14.3 (2024-11-26) [![](https://zenodo.org/badge/DOI/10.5281/zenodo.14224336.svg)](https://doi.org/10.5281/zenodo.14224336) diff --git a/plotnine/geoms/geom.py b/plotnine/geoms/geom.py index 269063f9a..8d4733dc5 100644 --- a/plotnine/geoms/geom.py +++ b/plotnine/geoms/geom.py @@ -217,7 +217,7 @@ def use_defaults( data : Data used for drawing the geom. aes_modifiers : - Aesthetics + Aesthetics to evaluate Returns ------- diff --git a/plotnine/ggplot.py b/plotnine/ggplot.py index 44b36aa42..aba115d9b 100755 --- a/plotnine/ggplot.py +++ b/plotnine/ggplot.py @@ -406,7 +406,7 @@ def _build(self): layout.setup_panel_params(self.coordinates) # fill in the defaults - layers.use_defaults_after_scale() + layers.use_defaults_after_scale(scales) # Allow stats to modify the layer data layers.finish_statistics() diff --git a/plotnine/layer.py b/plotnine/layer.py index b8ef910de..4d7e240a7 100644 --- a/plotnine/layer.py +++ b/plotnine/layer.py @@ -370,6 +370,7 @@ def use_defaults( self, data: pd.DataFrame, aes_modifiers: dict[str, Any], + scales: Scales | None = None, ) -> pd.DataFrame: """ Prepare/modify data for plotting @@ -382,7 +383,17 @@ def use_defaults( Expression to evaluate and replace aesthetics in the data. """ - return self.geom.use_defaults(data, aes_modifiers) + old_columns = data.columns + data = self.geom.use_defaults(data, aes_modifiers) + if scales is not None: + # The default aesthetics and the aesthetic parameters are + # specified in userspace. When we add them we have to + # transform them. + new_columns = data.columns.difference(old_columns) + _data = scales.transform_df(self.data[new_columns]) + for col in new_columns: + data[col] = _data[col] + return data def finish_statistics(self): """ @@ -468,9 +479,9 @@ def compute_position(self, layout: Layout): for l in self: l.compute_position(layout) - def use_defaults_after_scale(self): + def use_defaults_after_scale(self, scales: Scales): for l in self: - l.use_defaults(l.data, l.mapping._scaled) + l.data = l.use_defaults(l.data, l.mapping._scaled, scales) def transform(self, scales: Scales): for l in self: diff --git a/tests/baseline_images/test_scale_internals/transform_aes_defaults_and_params.png b/tests/baseline_images/test_scale_internals/transform_aes_defaults_and_params.png new file mode 100644 index 000000000..dd79a29b0 Binary files /dev/null and b/tests/baseline_images/test_scale_internals/transform_aes_defaults_and_params.png differ diff --git a/tests/baseline_images/test_scale_internals/transform_datetime_aes_param.png b/tests/baseline_images/test_scale_internals/transform_datetime_aes_param.png new file mode 100644 index 000000000..6cd1f0870 Binary files /dev/null and b/tests/baseline_images/test_scale_internals/transform_datetime_aes_param.png differ diff --git a/tests/test_scale_internals.py b/tests/test_scale_internals.py index 2b35cb562..93792097c 100644 --- a/tests/test_scale_internals.py +++ b/tests/test_scale_internals.py @@ -53,6 +53,7 @@ scale_x_discrete, scale_x_log10, scale_y_continuous, + scale_y_log10, ) from plotnine.scales.scales import make_scale @@ -876,3 +877,33 @@ def test_discrete_scale_for_empty_layer(): ) p.draw_test() + + +def test_transform_aes_defaults_and_params(): + data = pd.DataFrame({"x": range(5), "y": 10 ** np.arange(5)}) + # Another set of points that should be inline with the first set + # both with and without the scale_y_log10. + data2 = pd.DataFrame({"x": [3.2, 3.4, 3.6, 3.8]}) + yparam = list(10 ** data2["x"]) + + p = ( + ggplot(data, aes("x", "y")) + + geom_point() + + geom_point(data=data2, y=yparam, color="red") + + scale_y_log10() + ) + assert p == "transform_aes_defaults_and_params" + + +def test_transform_datetime_aes_param(): + data = pd.DataFrame( + {"x": range(5), "y": [datetime(2024, i, 1) for i in range(1, 6)]} + ) + yparam: list[datetime] = list(data["y"] + timedelta(days=30)) + + p = ( + ggplot(data, aes("x", "y")) + + geom_point() + + geom_point(y=yparam, color="red") + ) + assert p == "transform_datetime_aes_param"