Skip to content

Commit

Permalink
Add Count stat (#3086)
Browse files Browse the repository at this point in the history
* Add Count stat

* Rename _stats/histogram -> _stats/counting

* Call-out the different handling of numeric variables in Count/Hist

* Simplify Count tests for backcompat
  • Loading branch information
mwaskom authored Oct 16, 2022
1 parent a02b6bf commit 95dc377
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 7 deletions.
121 changes: 121 additions & 0 deletions doc/_docstrings/objects.Count.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "89113d6b-70b9-4ebe-9910-10a80eab246e",
"metadata": {
"tags": [
"hide"
]
},
"outputs": [],
"source": [
"import seaborn.objects as so\n",
"from seaborn import load_dataset\n",
"tips = load_dataset(\"tips\")"
]
},
{
"cell_type": "raw",
"id": "daf6ff78-df24-4541-ba72-73fb9eddb50d",
"metadata": {},
"source": [
"The transform counts distinct observations of the orientation variable and defines a new variable on the opposite axis:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "390f2fd3-0596-40e3-b262-163b3a90d055",
"metadata": {},
"outputs": [],
"source": [
"so.Plot(tips, x=\"day\").add(so.Bar(), so.Count())"
]
},
{
"cell_type": "raw",
"id": "813fb4a5-db68-4b51-b236-5b5628ebba47",
"metadata": {},
"source": [
"When additional mapping variables are defined, they are also used to define groups:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "76a4ae70-e914-4f54-b979-ce1b79374fc3",
"metadata": {},
"outputs": [],
"source": [
"so.Plot(tips, x=\"day\", color=\"sex\").add(so.Bar(), so.Count(), so.Dodge())"
]
},
{
"cell_type": "raw",
"id": "2973dee1-5aee-4768-846d-22d220faf170",
"metadata": {},
"source": [
"Unlike :class:`Hist`, numeric data are not binned before counting:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f94c5f0-680e-4d8a-a1c9-70876980dd1c",
"metadata": {},
"outputs": [],
"source": [
"so.Plot(tips, x=\"size\").add(so.Bar(), so.Count())"
]
},
{
"cell_type": "raw",
"id": "11acd5e6-f477-4eb1-b1d7-72f4582bca45",
"metadata": {},
"source": [
"When the `y` variable is defined, the counts are assigned to the `x` variable:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "924e0e35-210f-4f65-83b4-4aebe41ad264",
"metadata": {},
"outputs": [],
"source": [
"so.Plot(tips, y=\"size\").add(so.Bar(), so.Count())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0229fa39-b6dc-48da-9a25-31e25ed34ebc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "py310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Stat objects

Agg
Est
Count
Hist
Perc
PolyFit
Expand Down
2 changes: 2 additions & 0 deletions doc/whatsnew/v0.12.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ v0.12.1 (Unreleased)

- |Feature| Added the :class:`objects.Perc` stat (:pr:`3063`).

- |Feature| Added the :class:`objects.Count` stat (:pr:`3086`).

- |Feature| The :class:`objects.Band` and :class:`objects.Range` marks will now cover the full extent of the data if `min` / `max` variables are not explicitly assigned or added in a transform (:pr:`3056`).

- |Enhancement| |Defaults| The :class:`objects.Jitter` move now applies a small amount of jitter by default (:pr:`3066`).
Expand Down
4 changes: 2 additions & 2 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __call__(
res = (
groupby
.agg(data, {var: self.func})
.dropna()
.dropna(subset=[var])
.reset_index(drop=True)
)
return res
Expand Down Expand Up @@ -86,7 +86,7 @@ def __call__(
res = (
groupby
.apply(data, self._process, var, engine)
.dropna(subset=["x", "y"])
.dropna(subset=[var])
.reset_index(drop=True)
)

Expand Down
40 changes: 38 additions & 2 deletions seaborn/_stats/histogram.py → seaborn/_stats/counting.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,52 @@
from __future__ import annotations
from dataclasses import dataclass
from warnings import warn
from typing import ClassVar

import numpy as np
import pandas as pd
from pandas import DataFrame

from seaborn._core.groupby import GroupBy
from seaborn._core.scales import Scale
from seaborn._stats.base import Stat

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from numpy.typing import ArrayLike


@dataclass
class Count(Stat):
    """
    Count distinct observations within groups.

    The number of rows in each group is stored in the variable for the axis
    opposite the orientation (``y`` when ``orient`` is ``"x"``, and ``x``
    when ``orient`` is ``"y"``).

    See Also
    --------
    Hist : A more fully-featured transform including binning and/or normalization.

    Examples
    --------
    .. include:: ../docstrings/objects.Count.rst

    """
    # NOTE(review): presumably tells the caller to include the orient variable
    # among the grouping variables -- confirm against the Stat base class.
    group_by_orient: ClassVar[bool] = True

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
        """
        Return one row per group with the group size on the non-orient axis.

        Parameters
        ----------
        data
            Input observations; must contain a column named by ``orient``.
            The frame is not modified.
        groupby
            Grouping object whose ``agg`` method performs the aggregation.
        orient
            Either ``"x"`` or ``"y"``; the axis holding the values to count.
        scales
            Unused by this stat.

        Raises
        ------
        KeyError
            If ``orient`` is not ``"x"`` or ``"y"``.

        """
        # Index directly (rather than .get) so an unexpected orient fails
        # here instead of silently producing a column named None.
        var = {"x": "y", "y": "x"}[orient]

        # The aggregation needs a column to reduce; seed the count variable
        # with the orient values. `assign` returns a new frame, so the
        # caller's data (possibly a slice of another frame) is left untouched.
        counted = data.assign(**{var: data[orient]})

        res = (
            groupby
            .agg(counted, {var: len})
            .dropna(subset=["x", "y"])
            .reset_index(drop=True)
        )
        return res


@dataclass
class Hist(Stat):
"""
Expand Down Expand Up @@ -167,10 +201,12 @@ def _normalize(self, data):

return data.assign(**{self.stat: hist})

def __call__(self, data, groupby, orient, scales):
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

scale_type = scales[orient].__class__.__name__.lower()
grouping_vars = [v for v in data if v in groupby.order]
grouping_vars = [str(v) for v in data if v in groupby.order]
if not grouping_vars or self.common_bins is True:
bin_kws = self._define_bin_params(data, orient, scale_type)
data = groupby.apply(data, self._eval, orient, bin_kws)
Expand Down
2 changes: 1 addition & 1 deletion seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# We have moved univariate histogram computation over to the new Hist class,
# but still use the older Histogram for bivariate computation.
from ._statistics import ECDF, Histogram, KDE
from ._stats.histogram import Hist
from ._stats.counting import Hist

from .axisgrid import (
FacetGrid,
Expand Down
2 changes: 1 addition & 1 deletion seaborn/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

from seaborn._stats.base import Stat # noqa: F401
from seaborn._stats.aggregation import Agg, Est # noqa: F401
from seaborn._stats.histogram import Hist # noqa: F401
from seaborn._stats.counting import Count, Hist # noqa: F401
from seaborn._stats.order import Perc # noqa: F401
from seaborn._stats.regression import PolyFit # noqa: F401

Expand Down
40 changes: 39 additions & 1 deletion tests/_stats/test_histogram.py → tests/_stats/test_counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,45 @@
from numpy.testing import assert_array_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.histogram import Hist
from seaborn._stats.counting import Hist, Count


class TestCount:
    """Checks that Count reproduces pandas group sizes."""

    @pytest.fixture
    def df(self, rng):
        """Random 30-row frame with numeric x/y and two categorical columns."""
        n = 30
        # Keep this key order: each entry draws from `rng` in sequence.
        columns = {
            "x": rng.uniform(0, 7, n).round(),
            "y": rng.normal(size=n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        }
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Build a GroupBy over every column except the one opposite `orient`."""
        value_var = {"x": "y", "y": "x"}[orient]
        grouping_cols = [name for name in df if name != value_var]
        return GroupBy(grouping_cols)

    def test_single_grouper(self, df):
        """With only the orient column, counts match a plain groupby size."""
        orient = "x"
        frame = df[["x"]]
        grouper = self.get_groupby(frame, orient)
        result = Count()(frame, grouper, orient, {})
        expected = frame.groupby("x").size()
        counts = result.sort_values("x")["y"]
        assert_array_equal(counts, expected)

    def test_multiple_groupers(self, df):
        """An extra grouping column splits the counts accordingly."""
        orient = "x"
        frame = df[["x", "group"]].sort_values("group")
        grouper = self.get_groupby(frame, orient)
        result = Count()(frame, grouper, orient, {})
        keys = ["x", "group"]
        expected = frame.groupby(keys).size()
        counts = result.sort_values(keys)["y"]
        assert_array_equal(counts, expected)


class TestHist:
Expand Down

0 comments on commit 95dc377

Please sign in to comment.