Skip to content

Commit

Permalink
Fix hue_order in PairGrid
Browse files Browse the repository at this point in the history
This fixes #472.

This also changes the default `hue_order` to use the same `categorical_order`
rules as elsewhere in seaborn (cf. #361).
  • Loading branch information
mwaskom committed May 9, 2015
1 parent 377f056 commit b0a975e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 10 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.6.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,5 @@ Bug fixes
- Fixed a bug in :class:`FacetGrid` and :class:`PairGrid` that led to incorrect legend labels when levels of the ``hue`` variable appeared in ``hue_order`` but not in the data.

- Fixed a bug in :meth:`FacetGrid.set_xticklabels` or :meth:`FacetGrid.set_yticklabels` when ``col_wrap`` is being used.

- Fixed a bug in :class:`PairGrid` where the ``hue_order`` parameter was ignored.
48 changes: 38 additions & 10 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,14 +915,11 @@ def __init__(self, data, hue=None, hue_order=None, palette=None,
# Sort out the hue variable
self._hue_var = hue
if hue is None:
self.hue_names = None
self.hue_names = ["_nolegend_"]
self.hue_vals = pd.Series(["_nolegend_"] * len(data),
index=data.index)
else:
if hue_order is None:
hue_names = np.unique(np.sort(data[hue]))
else:
hue_names = hue_order
hue_names = utils.categorical_order(data[hue], hue_order)
if dropna:
# Filter NA from the list of unique hue names
hue_names = list(filter(pd.notnull, hue_names))
Expand Down Expand Up @@ -954,7 +951,14 @@ def map(self, func, **kwargs):
for i, y_var in enumerate(self.y_vars):
for j, x_var in enumerate(self.x_vars):
hue_grouped = self.data.groupby(self.hue_vals)
for k, (label_k, data_k) in enumerate(hue_grouped):
for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])

ax = self.axes[i, j]
plt.sca(ax)

Expand Down Expand Up @@ -1008,11 +1012,22 @@ def map_diag(self, func, **kwargs):
# Special-case plt.hist with stacked bars
if func is plt.hist:
plt.sca(ax)
vals = [v.values for g, v in hue_grouped]
vals = []
for label in self.hue_names:
# Attempt to get data for this level, allowing for empty
try:
vals.append(hue_grouped.get_group(label))
except KeyError:
vals.append(np.array([]))
func(vals, color=self.palette, histtype="barstacked",
**kwargs)
else:
for k, (label_k, data_k) in enumerate(hue_grouped):
for k, label_k in enumerate(self.hue_names):
# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])
plt.sca(ax)
func(data_k, label=label_k,
color=self.palette[k], **kwargs)
Expand All @@ -1034,7 +1049,13 @@ def map_lower(self, func, **kwargs):
kw_color = kwargs.pop("color", None)
for i, j in zip(*np.tril_indices_from(self.axes, -1)):
hue_grouped = self.data.groupby(self.hue_vals)
for k, (label_k, data_k) in enumerate(hue_grouped):
for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])

ax = self.axes[i, j]
plt.sca(ax)
Expand Down Expand Up @@ -1071,7 +1092,14 @@ def map_upper(self, func, **kwargs):
for i, j in zip(*np.triu_indices_from(self.axes, 1)):

hue_grouped = self.data.groupby(self.hue_vals)
for k, (label_k, data_k) in enumerate(hue_grouped):

for k, label_k in enumerate(self.hue_names):

# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = np.array([])

ax = self.axes[i, j]
plt.sca(ax)
Expand Down
52 changes: 52 additions & 0 deletions seaborn/tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,58 @@ def test_hue_kws(self):
for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
nt.assert_equal(line.get_marker(), marker)

g = ag.PairGrid(self.df, hue="a", hue_kws=kws,
hue_order=list("dcab"))
g.map(plt.plot)

for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
nt.assert_equal(line.get_marker(), marker)

plt.close("all")

def test_hue_order(self):
    """Check that each PairGrid mapping method draws hue levels in ``hue_order``.

    For every mapping method a fresh grid is built with a non-sorted
    ``hue_order``, ``plt.plot`` is mapped onto it, and the lines found on
    one representative axes are compared, in drawing order, against the
    rows of ``self.df`` belonging to the corresponding hue level.
    """
    order = list("dcab")

    # (mapping method, axes cell to inspect, column expected as the
    # line's x data, column expected as the line's y data)
    cases = [
        ("map", (1, 0), "x", "y"),
        ("map_diag", (0, 0), "x", "x"),
        ("map_lower", (1, 0), "x", "y"),
        ("map_upper", (0, 1), "y", "x"),
    ]

    for method_name, cell, x_col, y_col in cases:
        g = ag.PairGrid(self.df, hue="a", hue_order=order)
        getattr(g, method_name)(plt.plot)

        # Lines are paired with levels positionally, so the k-th line
        # drawn must hold the data for the k-th entry of ``order``.
        for line, level in zip(g.axes[cell].lines, order):
            x, y = line.get_xydata().T
            in_level = self.df.a == level
            npt.assert_array_equal(x, self.df.loc[in_level, x_col])
            npt.assert_array_equal(y, self.df.loc[in_level, y_col])

        plt.close("all")

def test_nondefault_index(self):

df = self.df.copy().set_index("b")
Expand Down

0 comments on commit b0a975e

Please sign in to comment.