Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KDE stat #3111

Merged
merged 6 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
export SHELL := /bin/bash

test:
pytest -n auto --cov=seaborn --cov=tests --cov-config=.coveragerc tests
pytest -n auto --cov=seaborn --cov=tests --cov-config=setup.cfg tests

lint:
flake8 seaborn
Expand Down
270 changes: 270 additions & 0 deletions doc/_docstrings/objects.KDE.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "dcc1ae12-bba4-4de9-af8d-543b3d65b42b",
"metadata": {
"tags": [
"hide"
]
},
"outputs": [],
"source": [
"import seaborn.objects as so\n",
"from seaborn import load_dataset\n",
"penguins = load_dataset(\"penguins\")"
]
},
{
"cell_type": "raw",
"id": "1042b991-1471-43bd-934c-43caae3cb2fa",
"metadata": {},
"source": [
"This stat estimates transforms observations into a smooth function representing the estimated density:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2406e2aa-7f0f-4a51-af59-4cef827d28d8",
"metadata": {},
"outputs": [],
"source": [
"p = so.Plot(penguins, x=\"flipper_length_mm\")\n",
"p.add(so.Area(), so.KDE())"
]
},
{
"cell_type": "raw",
"id": "44515f21-683b-420f-967b-4c7568c907d7",
"metadata": {},
"source": [
"Adjust the smoothing bandwidth to see more or fewer details:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e6ba5b-4dd2-4210-8cf0-de057dc71e2a",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Area(), so.KDE(bw_adjust=0.25))"
]
},
{
"cell_type": "raw",
"id": "fd665fe1-a5e4-4742-adc9-e40615d57d08",
"metadata": {},
"source": [
"The curve will extend beyond observed values in the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4cda1cb8-f663-4f94-aa24-6f1727a41031",
"metadata": {},
"outputs": [],
"source": [
"p2 = p.add(so.Bars(alpha=.3), so.Hist(\"density\"))\n",
"p2.add(so.Line(), so.KDE())"
]
},
{
"cell_type": "raw",
"id": "75235825-d522-4562-aacc-9b7413eabf5d",
"metadata": {},
"source": [
"Control the range of the density curve relative to the observations using `cut`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a7a9275e-9889-437d-bdc5-18653d2c92ef",
"metadata": {},
"outputs": [],
"source": [
"p2.add(so.Line(), so.KDE(cut=0))"
]
},
{
"cell_type": "raw",
"id": "6a885eeb-81ba-47c6-8402-1bef40544fd1",
"metadata": {},
"source": [
"When observations are assigned to the `y` variable, the density will be shown for `x`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38b3a0fb-54ff-493a-bd64-f83a12365723",
"metadata": {},
"outputs": [],
"source": [
"so.Plot(penguins, y=\"flipper_length_mm\").add(so.Area(), so.KDE())"
]
},
{
"cell_type": "raw",
"id": "59996340-168e-479f-a0c6-c7e1fcab0fb0",
"metadata": {},
"source": [
"Use `gridsize` to increase or decrease the resolution of the grid where the density is evaluated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23715820-7df9-40ba-9e74-f11564704dd0",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Dots(), so.KDE(gridsize=100))"
]
},
{
"cell_type": "raw",
"id": "4c9b6492-98c8-45ab-9f53-681cde2f767a",
"metadata": {},
"source": [
"Or pass `None` to evaluate the density at the original datapoints:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e1b6810-5c28-43aa-aa61-652521299b51",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Dots(), so.KDE(gridsize=None))"
]
},
{
"cell_type": "raw",
"id": "0970a56b-0cba-4c40-bb1b-b8e71739df5c",
"metadata": {},
"source": [
"Other variables will define groups for the estimation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f0ce0b6-5742-4bc0-9ac3-abedde923684",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Area(), so.KDE(), color=\"species\")"
]
},
{
"cell_type": "raw",
"id": "22204fcd-4b25-46e5-a170-02b1419c23d5",
"metadata": {},
"source": [
"By default, the density is normalized across all groups (i.e., the joint density is shown); pass `common_norm=False` to show conditional densities:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ad56958-dc45-4632-94d1-23039ad3ec58",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Area(), so.KDE(common_norm=False), color=\"species\")"
]
},
{
"cell_type": "raw",
"id": "b1627197-85d1-4476-b4ae-3e93044ee988",
"metadata": {},
"source": [
"Or pass a list of variables to condition on:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "58f63734-5afd-4d90-bbfb-fc39c8d1981f",
"metadata": {},
"outputs": [],
"source": [
"(\n",
" p.facet(\"sex\")\n",
" .add(so.Area(), so.KDE(common_norm=[\"col\"]), color=\"species\")\n",
")"
]
},
{
"cell_type": "raw",
"id": "2b7e018e-1374-4939-909c-e95f5ffd086e",
"metadata": {},
"source": [
"This stat can be combined with other transforms, such as :class:`Stack` (when `common_grid=True`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96e5b2d0-c7e2-47df-91f1-7f9ec0bb08a9",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Area(), so.KDE(), so.Stack(), color=\"sex\")"
]
},
{
"cell_type": "raw",
"id": "8500ff86-0b1f-4831-954b-08b6df690387",
"metadata": {},
"source": [
"Set `cumulative=True` to integrate the density:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26bb736e-7cfd-421e-b80d-42fa450e88c0",
"metadata": {},
"outputs": [],
"source": [
"p.add(so.Line(), so.KDE(cumulative=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8bfd9d2-ad60-4971-aa7f-71a285f44a20",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"language": "python",
"name": "py310"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ Stat objects
Est
Count
Hist
KDE
Perc
PolyFit

Expand Down
5 changes: 5 additions & 0 deletions doc/whatsnew/v0.12.2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@

v0.12.2 (Unreleased)
--------------------

- Added the :class:`objects.KDE` stat (:pr:`3111`).
15 changes: 14 additions & 1 deletion seaborn/_stats/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from typing import ClassVar, Any
import warnings

from typing import TYPE_CHECKING
if TYPE_CHECKING:
Expand All @@ -29,7 +30,7 @@ class Stat:
# value on the orient axis, but we would not in the latter case.
group_by_orient: ClassVar[bool] = False

def _check_param_one_of(self, param: Any, options: Iterable[Any]) -> None:
def _check_param_one_of(self, param: str, options: Iterable[Any]) -> None:
"""Raise when parameter value is not one of a specified set."""
value = getattr(self, param)
if value not in options:
Expand All @@ -41,6 +42,18 @@ def _check_param_one_of(self, param: Any, options: Iterable[Any]) -> None:
])
raise ValueError(err)

def _check_grouping_vars(
self, param: str, data_vars: list[str], stacklevel: int = 2,
) -> None:
"""Warn if vars are named in parameter without being present in the data."""
param_vars = getattr(self, param)
undefined = set(param_vars) - set(data_vars)
if undefined:
param = f"{self.__class__.__name__}.{param}"
names = ", ".join(f"{x!r}" for x in undefined)
msg = f"Undefined variable(s) passed for {param}: {names}."
warnings.warn(msg, stacklevel=stacklevel)

def __call__(
self,
data: DataFrame,
Expand Down
24 changes: 6 additions & 18 deletions seaborn/_stats/counting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from warnings import warn
from typing import ClassVar

import numpy as np
Expand Down Expand Up @@ -86,7 +85,6 @@ class Hist(Stat):

Notes
-----

The choice of bins for computing and plotting a histogram can exert
substantial influence on the insights that one is able to draw from the
visualization. If the bins are too large, they may erase important features.
Expand All @@ -100,7 +98,6 @@ class Hist(Stat):
by setting the total number of bins to use, the width of each bin, or the
specific locations where the bins should break.


Examples
--------
.. include:: ../docstrings/objects.Hist.rst
Expand Down Expand Up @@ -215,12 +212,8 @@ def __call__(
bin_groupby = GroupBy(grouping_vars)
else:
bin_groupby = GroupBy(self.common_bins)
undefined = set(self.common_bins) - set(grouping_vars)
if undefined:
param = f"{self.__class__.__name__}.common_bins"
names = ", ".join(f"{x!r}" for x in undefined)
msg = f"Undefined variables(s) passed to `{param}`: {names}."
warn(msg)
self._check_grouping_vars("common_bins", grouping_vars)

data = bin_groupby.apply(
data, self._get_bins_and_eval, orient, groupby, scale_type,
)
Expand All @@ -229,16 +222,11 @@ def __call__(
data = self._normalize(data)
else:
if self.common_norm is False:
norm_grouper = grouping_vars
norm_groupby = GroupBy(grouping_vars)
else:
norm_grouper = self.common_norm
undefined = set(self.common_norm) - set(grouping_vars)
if undefined:
param = f"{self.__class__.__name__}.common_norm"
names = ", ".join(f"{x!r}" for x in undefined)
msg = f"Undefined variables(s) passed to `{param}`: {names}."
warn(msg)
data = GroupBy(norm_grouper).apply(data, self._normalize)
norm_groupby = GroupBy(self.common_norm)
self._check_grouping_vars("common_norm", grouping_vars)
data = norm_groupby.apply(data, self._normalize)

other = {"x": "y", "y": "x"}[orient]
return data.assign(**{other: data[self.stat]})
Loading