Skip to content

Commit

Permalink
Allow Plot.label to control title(s) (#2934)
Browse files Browse the repository at this point in the history
* Don't show facet variable names in facet titles

* Don't document Plot.label as accepting None as a value

* Allow Plot.label to control titles, including when faceting

* Don't include separator in labeled facet title

* Clean up title typing

* Fix legend test

* Fix legend contents typing after rebase

* Add theme update to Plot.clone and remove outdated todo
  • Loading branch information
mwaskom authored Aug 4, 2022
1 parent 762db89 commit a259ac5
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 29 deletions.
62 changes: 37 additions & 25 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from contextlib import contextmanager
from collections import abc
from collections.abc import Callable, Generator, Hashable
from typing import Any, Optional, cast
from typing import Any, cast

from cycler import cycler
import pandas as pd
Expand Down Expand Up @@ -151,7 +151,7 @@ class Plot:

_scales: dict[str, Scale]
_limits: dict[str, tuple[Any, Any]]
_labels: dict[str, str | Callable[[str], str] | None]
_labels: dict[str, str | Callable[[str], str]]
_theme: dict[str, Any]

_facet_spec: FacetSpec
Expand All @@ -176,6 +176,7 @@ def __init__(
raise TypeError(err)

self._data = PlotData(data, variables)

self._layers = []

self._scales = {}
Expand Down Expand Up @@ -249,8 +250,9 @@ def _clone(self) -> Plot:
new._layers.extend(self._layers)

new._scales.update(self._scales)
new._labels.update(self._labels)
new._limits.update(self._limits)
new._labels.update(self._labels)
new._theme.update(self._theme)

new._facet_spec.update(self._facet_spec)
new._pair_spec.update(self._pair_spec)
Expand Down Expand Up @@ -599,23 +601,27 @@ def limit(self, **limits: tuple[Any, Any]) -> Plot:
new._limits.update(limits)
return new

def label(self, **labels: str | Callable[[str], str] | None) -> Plot:
def label(self, *, title=None, **variables: str | Callable[[str], str]) -> Plot:
"""
Control the labels used for variables in the plot.
For coordinate variables, this sets the axis label.
For semantic variables, it sets the legend title.
Add or modify labels for axes, legends, and subplots.
Keywords correspond to variables defined in the plot.
Additional keywords correspond to variables defined in the plot.
Values can be one of the following types:
- string (used literally)
- string (used literally; pass "" to clear the default label)
- function (called on the default label)
- None (disables the label for this variable)
For coordinate variables, the value sets the axis label.
For semantic variables, the value sets the legend title.
For faceting variables, `title=` modifies the subplot-specific label,
while `col=` and/or `row=` add a label for the faceting variable.
When using a single subplot, `title=` sets its title.
"""
new = self._clone()
new._labels.update(labels)
if title is not None:
new._labels["title"] = title
new._labels.update(variables)
return new

def configure(
Expand Down Expand Up @@ -773,7 +779,7 @@ def __init__(self, pyplot: bool, theme: dict[str, Any]):
self._pyplot = pyplot
self._theme = theme
self._legend_contents: list[tuple[
tuple[str | None, str | int], list[Artist], list[str],
tuple[str, str | int], list[Artist], list[str],
]] = []
self._scales: dict[str, Scale] = {}

Expand Down Expand Up @@ -852,16 +858,17 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:

return common_data, layers

def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str | None:
def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str:

label: str | None
label: str
if var in p._labels:
manual_label = p._labels[var]
if callable(manual_label) and auto_label is not None:
label = manual_label(auto_label)
else:
# mypy needs a lot of help here, I'm not sure why
label = cast(Optional[str], manual_label)
label = cast(str, manual_label)
elif auto_label is None:
label = ""
else:
label = auto_label
return label
Expand Down Expand Up @@ -935,10 +942,13 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
# Let's have what we currently call "margin titles" but properly using the
# ax.set_title interface (see my gist)
title_parts = []
for dim in ["row", "col"]:
for dim in ["col", "row"]:
if sub[dim] is not None:
name = common.names.get(dim) # TODO None = val looks bad
title_parts.append(f"{name} = {sub[dim]}")
val = self._resolve_label(p, "title", f"{sub[dim]}")
if dim in p._labels:
key = self._resolve_label(p, dim, common.names.get(dim))
val = f"{key} {val}"
title_parts.append(val)

has_col = sub["col"] is not None
has_row = sub["row"] is not None
Expand All @@ -953,6 +963,9 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
title = " | ".join(title_parts)
title_text = ax.set_title(title)
title_text.set_visible(show_title)
elif not (has_col or has_row):
title = self._resolve_label(p, "title", None)
title_text = ax.set_title(title)

def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:

Expand Down Expand Up @@ -1445,7 +1458,7 @@ def _update_legend_contents(

# First pass: Identify the values that will be shown for each variable
schema: list[tuple[
tuple[str | None, str | int], list[str], tuple[list, list[str]]
tuple[str, str | int], list[str], tuple[list, list[str]]
]] = []
schema = []
for var in legend_vars:
Expand All @@ -1458,8 +1471,7 @@ def _update_legend_contents(
part_vars.append(var)
break
else:
auto_title = data.names[var]
title = self._resolve_label(p, var, auto_title)
title = self._resolve_label(p, var, data.names[var])
entry = (title, data.ids[var]), [var], (values, labels)
schema.append(entry)

Expand All @@ -1479,7 +1491,7 @@ def _make_legend(self, p: Plot) -> None:
# Input list has an entry for each distinct variable in each layer
# Output dict has an entry for each distinct variable
merged_contents: dict[
tuple[str | None, str | int], tuple[list[Artist], list[str]],
tuple[str, str | int], tuple[list[Artist], list[str]],
] = {}
for key, artists, labels in self._legend_contents:
# Key is (name, id); we need the id to resolve variable uniqueness,
Expand All @@ -1503,7 +1515,7 @@ def _make_legend(self, p: Plot) -> None:
self._figure,
handles,
labels,
title="" if name is None else name,
title=name,
loc="center left",
bbox_to_anchor=(.98, .55),
)
Expand Down
36 changes: 32 additions & 4 deletions tests/_core/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,34 @@ def test_labels_legend(self, long_df):
p = Plot(long_df, x="x", y="y", color="a").add(m).label(color=func).plot()
assert p._figure.legends[0].get_title().get_text() == label

def test_labels_facets(self):

data = {"a": ["b", "c"], "x": ["y", "z"]}
p = Plot(data).facet("a", "x").label(col=str.capitalize, row="$x$").plot()
axs = np.reshape(p._figure.axes, (2, 2))
for (i, j), ax in np.ndenumerate(axs):
expected = f"A {data['a'][j]} | $x$ {data['x'][i]}"
assert ax.get_title() == expected

def test_title_single(self):

label = "A"
p = Plot().label(title=label).plot()
assert p._figure.axes[0].get_title() == label

def test_title_facet_function(self):

titles = ["a", "b"]
p = Plot().facet(titles).label(title=str.capitalize).plot()
for i, ax in enumerate(p._figure.axes):
assert ax.get_title() == titles[i].upper()

cols, rows = ["a", "b"], ["x", "y"]
p = Plot().facet(cols, rows).label(title=str.capitalize).plot()
for i, ax in enumerate(p._figure.axes):
expected = " | ".join([cols[i % 2].upper(), rows[i // 2].upper()])
assert ax.get_title() == expected


class TestFacetInterface:

Expand Down Expand Up @@ -1152,7 +1180,7 @@ def check_facet_results_1d(self, p, df, dim, key, order=None):
for subplot, level in zip(p._subplots, order):
assert subplot[dim] == level
assert subplot[other_dim] is None
assert subplot["ax"].get_title() == f"{key} = {level}"
assert subplot["ax"].get_title() == f"{level}"
assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)})

def test_1d(self, long_df, dim):
Expand Down Expand Up @@ -1188,7 +1216,7 @@ def check_facet_results_2d(self, p, df, variables, order=None):
assert subplot["row"] == row_level
assert subplot["col"] == col_level
assert subplot["axes"].get_title() == (
f"{variables['row']} = {row_level} | {variables['col']} = {col_level}"
f"{col_level} | {row_level}"
)
assert_gridspec_shape(
subplot["axes"], len(levels["row"]), len(levels["col"])
Expand Down Expand Up @@ -1375,7 +1403,7 @@ def test_with_facets(self, long_df):
ax = subplot["ax"]
assert ax.get_xlabel() == x
assert ax.get_ylabel() == y_i
assert ax.get_title() == f"{col} = {col_i}"
assert ax.get_title() == f"{col_i}"
assert_gridspec_shape(ax, len(y), len(facet_levels))

@pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")])
Expand Down Expand Up @@ -1763,7 +1791,7 @@ def test_single_layer_common_unnamed_variable(self, xy):

labels = list(np.unique(s)) # assumes sorted order

assert e[0] == (None, id(s))
assert e[0] == ("", id(s))
assert e[-1] == labels

artists = e[1]
Expand Down

0 comments on commit a259ac5

Please sign in to comment.