diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 5341612edd..c1456c0150 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1185,19 +1185,20 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None: - grouping_vars = [v for v in PROPERTIES if v not in "xy"] - grouping_vars += ["col", "row", "group"] pair_vars = spec._pair_spec.get("structure", {}) for layer in layers: - data = layer["data"] mark = layer["mark"] stat = layer["stat"] if stat is None: continue + target_vars = getattr(stat, "target_vars", "xy") + + grouping_vars = [v for v in PROPERTIES if v not in target_vars] + grouping_vars += ["col", "row", "group"] iter_axes = itertools.product(*[ pair_vars.get(axis, [axis]) for axis in "xy" diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py index d175273e78..e8d3068a99 100644 --- a/seaborn/_stats/aggregation.py +++ b/seaborn/_stats/aggregation.py @@ -1,6 +1,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar, Callable +from typing import ClassVar, Callable, Iterable import pandas as pd from pandas import DataFrame @@ -21,6 +21,9 @@ class Agg(Stat): ---------- func : str or callable Name of a :class:`pandas.Series` method or a vector -> scalar function. + target_vars : list of strings + Variables to perform the aggregation on. Defaults to x or y, depending on + orientation. See Also -------- @@ -32,6 +35,7 @@ class Agg(Stat): """ func: str | Callable[[Vector], float] = "mean" + target_vars: Iterable[str] = ("x", "y") group_by_orient: ClassVar[bool] = True @@ -39,11 +43,11 @@ def __call__( self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale], ) -> DataFrame: - var = {"x": "y", "y": "x"}.get(orient) + vars = [v for v in self.target_vars if v != orient] res = ( groupby - .agg(data, {var: self.func}) - .dropna(subset=[var]) + .agg(data, {var: self.func for var in vars}) + .dropna(subset=vars) .reset_index(drop=True) ) return res