Skip to content

Commit

Permalink
Add stripplot
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jan 14, 2015
1 parent 6e8d4f7 commit 1233017
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 41 deletions.
84 changes: 84 additions & 0 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def restyle_boxplot(self, artist_dict, color):
"""Take a drawn matplotlib boxplot and make it look nice."""
for box in artist_dict["boxes"]:
box.set_color(color)
box.set_zorder(.9)
box.set_edgecolor(self.gray)
box.set_linewidth(self.linewidth)
for whisk in artist_dict["whiskers"]:
Expand Down Expand Up @@ -661,6 +662,75 @@ def plot(self, ax):
self.annotate_axes(ax)


class _StripPlotter(_BoxPlotter):
    """1-d scatterplot with categorical organization."""
    def __init__(self, x, y, hue, data, order, hue_order,
                 jitter, split, orient, color, palette):
        """Initialize the plotter.

        Parameters
        ----------
        x, y, hue, data, order, hue_order, orient
            Passed straight through to ``self.establish_variables``.
        color, palette
            Passed straight through to ``self.establish_colors`` (together
            with a fixed third argument of 1).
        jitter : bool or float
            Half-width of the uniform noise added to the categorical
            position of every point. Any value that compares equal to 1
            (which includes ``True``) selects the default of 0.1; any other
            value is converted with ``float``, so ``False`` means no jitter.
        split : bool
            When true and hue levels are present, each hue level is drawn
            at its own offset and the jitter half-width is divided by the
            number of hue levels.

        """
        self.establish_variables(x, y, hue, data, orient, order, hue_order)
        self.establish_colors(color, palette, 1)

        # Set object attributes
        self.split = split
        self.width = .8

        # NOTE: ``True == 1`` in Python, so a numeric jitter of exactly 1
        # is also mapped onto the default here.
        if jitter == 1:  # Use a good default for `jitter = True`
            jlim = 0.1
        else:
            jlim = float(jitter)
        if self.hue_names is not None and split:
            jlim /= len(self.hue_names)

        # Callable taking a sample count; scipy's uniform(loc, scale) spans
        # [loc, loc + scale], i.e. [-jlim, jlim] here.
        self.jitterer = stats.uniform(-jlim, jlim * 2).rvs

    def _scatter_strip(self, ax, positions, values, kws):
        """Scatter one strip, mapping the categorical axis by orientation."""
        if self.orient == "v":
            ax.scatter(positions, values, **kws)
        else:
            ax.scatter(values, positions, **kws)

    def draw_stripplot(self, ax, kws):
        """Draw the points onto `ax`.

        Parameters
        ----------
        ax : object providing a ``scatter`` method
            Target of the drawing calls (presumably a matplotlib Axes).
        kws : dict
            Extra keyword arguments for every ``ax.scatter`` call. The
            ``color`` entry is always overridden and, when a hue variable
            is in use, so is ``label``. The dict itself is not modified.

        """
        # Work on a private copy so the "color"/"label" bookkeeping below
        # never leaks back into the dictionary owned by the caller.
        kws = dict(kws)

        for i, group_data in enumerate(self.plot_data):
            if self.plot_hues is None:

                # Determine the positions of the points
                strip_data = remove_na(group_data)
                jitter = self.jitterer(len(strip_data))
                kws["color"] = self.colors[i]

                # Draw the plot
                self._scatter_strip(ax, i + jitter, strip_data, kws)

            else:
                offsets = self.hue_offsets
                for j, hue_level in enumerate(self.hue_names):
                    hue_mask = self.plot_hues[i] == hue_level
                    if not hue_mask.any():
                        continue

                    # Determine the positions of the points
                    strip_data = remove_na(group_data[hue_mask])
                    pos = i + offsets[j] if self.split else i
                    jitter = self.jitterer(len(strip_data))
                    kws["color"] = self.colors[j]

                    # Only label one set of plots (those of the first
                    # group) so each hue level gets a single label
                    if i:
                        kws.pop("label", None)
                    else:
                        kws["label"] = hue_level

                    # Draw the plot
                    self._scatter_strip(ax, pos + jitter, strip_data, kws)

    def plot(self, ax, kws):
        """Make the plot: draw the strips, then annotate the axes."""
        self.draw_stripplot(ax, kws)
        self.annotate_axes(ax)

# TODO The horizontal representation should go from top to bottom


class _SwarmPlotter(_BoxPlotter):

def __init__(self):
Expand Down Expand Up @@ -814,6 +884,20 @@ def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
return ax


def stripplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
              jitter=False, split=True, orient=None, color=None, palette=None,
              ax=None, **kwargs):
    """Draw a categorical scatterplot (strip plot) and return its axes.

    All data, ordering, jitter, split, orientation and color arguments are
    handed to ``_StripPlotter``. Additional keyword arguments are passed to
    the plotter's ``plot`` method, which forwards them to ``ax.scatter``.

    Parameters
    ----------
    ax : axes object, optional
        Where to draw. When None, ``plt.gca()`` is used.

    Returns
    -------
    ax
        The axes that were drawn on.

    """
    strip_plotter = _StripPlotter(x, y, hue, data, order, hue_order,
                                  jitter, split, orient, color, palette)

    # Fall back to the current axes only after the plotter has been built
    target_ax = plt.gca() if ax is None else ax
    strip_plotter.plot(target_ax, kwargs)
    return target_ax


def boxplot_old(vals, groupby=None, names=None, join_rm=False, order=None,
color=None, alpha=None, fliersize=3, linewidth=1.5, widths=.8,
saturation=.7, label=None, ax=None, **kwargs):
Expand Down
174 changes: 133 additions & 41 deletions seaborn/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ class TestBoxPlotter(object):
default_kws = dict(x=None, y=None, hue=None, data=None,
order=None, hue_order=None,
orient=None, color=None, palette=None,
saturation=.75, alpha=None,
width=.8, fliersize=5, linewidth=None)
saturation=.75, width=.8,
fliersize=5, linewidth=None)

def test_wide_df_data(self):

Expand Down Expand Up @@ -498,6 +498,137 @@ def test_axes_annotation(self):
plt.close("all")


class TestStripPlotter(object):
    """Tests for point positions and colors drawn by ``dist.stripplot``."""
    rs = np.random.RandomState(30)
    n_total = 60
    y = pd.Series(rs.randn(n_total), name="y_data")
    # Floor division keeps the repeat counts integral even under true
    # division, where ``n_total / 3`` would be a float.
    g = pd.Series(np.repeat(list("abc"), n_total // 3), name="small")
    h = pd.Series(np.tile(list("mn"), n_total // 2), name="medium")
    df = pd.DataFrame(dict(y=y, g=g, h=h))

    def test_stripplot_vertical(self):

        pal = palettes.color_palette()

        ax = dist.stripplot("g", "y", data=self.df)
        for i, (_, vals) in enumerate(self.y.groupby(self.g)):

            x, y = ax.collections[i].get_offsets().T

            # No jitter by default: every point sits exactly on its group
            npt.assert_array_equal(x, np.ones(len(x)) * i)
            npt.assert_array_equal(y, vals)

            npt.assert_equal(ax.collections[i].get_facecolors()[0, :3], pal[i])

        plt.close("all")

    @skipif(not pandas_has_categoricals)
    def test_stripplot_horiztonal(self):

        df = self.df.copy()
        df.g = df.g.astype("category")

        ax = dist.stripplot("y", "g", data=df)
        for i, (_, vals) in enumerate(self.y.groupby(self.g)):

            x, y = ax.collections[i].get_offsets().T

            npt.assert_array_equal(x, vals)
            npt.assert_array_equal(y, np.ones(len(x)) * i)

        plt.close("all")

    def test_stripplot_jitter(self):

        pal = palettes.color_palette()

        ax = dist.stripplot("g", "y", data=self.df, jitter=True)
        for i, (_, vals) in enumerate(self.y.groupby(self.g)):

            x, y = ax.collections[i].get_offsets().T

            # ``jitter=True`` must keep points within +/- .1 of the group
            npt.assert_array_less(np.ones(len(x)) * i - .1, x)
            npt.assert_array_less(x, np.ones(len(x)) * i + .1)
            npt.assert_array_equal(y, vals)

            npt.assert_equal(ax.collections[i].get_facecolors()[0, :3], pal[i])

        plt.close("all")

    def test_split_nested_stripplot_vertical(self):

        pal = palettes.color_palette()

        ax = dist.stripplot("g", "y", "h", data=self.df)
        for i, (_, group_vals) in enumerate(self.y.groupby(self.g)):
            for j, (_, vals) in enumerate(group_vals.groupby(self.h)):

                # One collection per (group, hue level) pair, in draw order
                x, y = ax.collections[i * 2 + j].get_offsets().T

                npt.assert_array_equal(x, np.ones(len(x)) * i + [-.2, .2][j])
                npt.assert_array_equal(y, vals)

                fc = ax.collections[i * 2 + j].get_facecolors()[0, :3]
                npt.assert_equal(fc, pal[j])

        plt.close("all")

    @skipif(not pandas_has_categoricals)
    def test_split_nested_stripplot_horizontal(self):

        df = self.df.copy()
        df.g = df.g.astype("category")

        ax = dist.stripplot("y", "g", "h", data=df)
        for i, (_, group_vals) in enumerate(self.y.groupby(self.g)):
            for j, (_, vals) in enumerate(group_vals.groupby(self.h)):

                x, y = ax.collections[i * 2 + j].get_offsets().T

                npt.assert_array_equal(x, vals)
                npt.assert_array_equal(y, np.ones(len(x)) * i + [-.2, .2][j])

        plt.close("all")

    def test_unsplit_nested_stripplot_vertical(self):

        pal = palettes.color_palette()

        # With ``split=False`` the hue levels share the group position
        ax = dist.stripplot("g", "y", "h", data=self.df, split=False)
        for i, (_, group_vals) in enumerate(self.y.groupby(self.g)):
            for j, (_, vals) in enumerate(group_vals.groupby(self.h)):

                x, y = ax.collections[i * 2 + j].get_offsets().T

                npt.assert_array_equal(x, np.ones(len(x)) * i)
                npt.assert_array_equal(y, vals)

                fc = ax.collections[i * 2 + j].get_facecolors()[0, :3]
                npt.assert_equal(fc, pal[j])

        plt.close("all")

    @skipif(not pandas_has_categoricals)
    def test_unsplit_nested_stripplot_horizontal(self):

        df = self.df.copy()
        df.g = df.g.astype("category")

        ax = dist.stripplot("y", "g", "h", data=df, split=False)
        for i, (_, group_vals) in enumerate(self.y.groupby(self.g)):
            for j, (_, vals) in enumerate(group_vals.groupby(self.h)):

                x, y = ax.collections[i * 2 + j].get_offsets().T

                npt.assert_array_equal(x, vals)
                npt.assert_array_equal(y, np.ones(len(x)) * i)

        plt.close("all")


class TestBoxReshaping(object):
"""Tests for function that preps boxplot/violinplot data."""
n_total = 60
Expand Down Expand Up @@ -749,45 +880,6 @@ def test_bivariate_kde_series(self):
plt.close("all")


class TestViolinPlot(object):
    """Check how many artists ``dist.violinplot`` adds to the axes."""

    df = pd.DataFrame(dict(x=np.random.randn(60),
                           y=list("abcdef") * 10,
                           z=list("ab") * 29 + ["a", "c"]))

    @staticmethod
    def _check_artists(ax, n_collections, n_lines):
        """Assert the number of collections and lines present on `ax`."""
        nt.assert_equal(len(ax.collections), n_collections)
        nt.assert_equal(len(ax.lines), n_lines)

    def test_single_violin(self):

        axes = dist.violinplot(self.df.x)
        self._check_artists(axes, 1, 5)
        plt.close("all")

    def test_multi_violins(self):

        axes = dist.violinplot(self.df.x, self.df.y)
        self._check_artists(axes, 6, 30)
        plt.close("all")

    def test_multi_violins_single_obs(self):

        # Grouping by ``z`` leaves the "c" level with a single observation
        axes = dist.violinplot(self.df.x, self.df.z)
        self._check_artists(axes, 2, 11)
        plt.close("all")

        # Second case: one of the two inputs is a constant list
        samples = [np.random.randn(30), [0, 0, 0]]
        axes = dist.violinplot(samples)
        self._check_artists(axes, 1, 6)
        plt.close("all")

    @classmethod
    def teardown_class(cls):
        """Ensure that all figures are closed on exit."""
        plt.close("all")


class TestJointPlot(object):

rs = np.random.RandomState(sum(map(ord, "jointplot")))
Expand Down

0 comments on commit 1233017

Please sign in to comment.