Skip to content

Commit

Permalink
Add countplot function to replace removed barplot functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Mar 13, 2015
1 parent da20f4e commit 66aee90
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 1 deletion.
48 changes: 47 additions & 1 deletion seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2193,6 +2193,42 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
""").format(**_categorical_docs)


def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
orient=None, color=None, palette=None, saturation=.75,
ax=None, **kwargs):

estimator = len
ci = None
n_boot = 0
units = None
errcolor = None

if orient is None and y is None:
orient = "v"

if x is None and y is not None:
x = y
elif y is None and x is not None:
y = x
elif x is not None and y is not None:
raise TypeError("Cannot pass values for both `x` and `y`")
else:
raise TypeError("Must pass valus for either `x` or `y`")

plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units,
orient, color, palette, saturation,
errcolor)

plotter.value_label = "count"

if ax is None:
ax = plt.gca()

plotter.plot(ax, kwargs)
return ax


def factorplot(x=None, y=None, hue=None, data=None, row=None, col=None,
col_wrap=None, estimator=np.mean, ci=95, n_boot=1000,
units=None, order=None, hue_order=None, row_order=None,
Expand Down Expand Up @@ -2222,10 +2258,20 @@ def factorplot(x=None, y=None, hue=None, data=None, row=None, col=None,
err = "Plot kind '{}' is not recognized".format(kind)
raise ValueError(err)

# Alias the input variables to determine categorical order and palette
# correctly in the case of a count plot
if kind == "count":
if x is None and y is not None:
x_, y_ = y, y
elif y is None and x is not None:
x_, y_ = x, x
else:
x_, y_ = x, y

# Determine the order for the whole dataset, which will be used in all
# facets to ensure representation of all data in the final plot
p = _CategoricalPlotter()
p.establish_variables(x, y, hue, data, orient, order, hue_order)
p.establish_variables(x_, y_, hue, data, orient, order, hue_order)
order = p.group_names
hue_order = p.hue_names

Expand Down
45 changes: 45 additions & 0 deletions seaborn/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,6 +1981,41 @@ def test_simple_pointplots(self):
plt.close("all")


class TestCountPlot(CategoricalFixture):

def test_plot_elements(self):

ax = cat.countplot("g", data=self.df)
nt.assert_equal(ax.patches, self.g.unique().size)
for p in ax.patches:
nt.assert_equal(p.get_y(), 0)
nt.assert_equal(p.get_height(),
self.g.size() / self.g.unique().size)

ax = cat.countplot(y="g", data=self.df)
nt.assert_equal(ax.patches, self.g.unique().size)
for p in ax.patches:
nt.assert_equal(p.get_x(), 0)
nt.assert_equal(p.get_width(),
self.g.size() / self.g.unique().size)

ax = cat.countplot("g", hue="h", data=self.df)
nt.assert_equal(ax.patches,
self.g.unique().size * self.h.unique().size)

ax = cat.countplot(y="g", hue="h", data=self.df)
nt.assert_equal(ax.patches,
self.g.unique().size * self.h.unique().size)

def test_input_error(self):

with nt.assert_raises(TypeError):
cat.countplot()

with nt.assert_raises(TypeError):
cat.countplot(x="g", y="h", data=self.df)


class TestFactorPlot(CategoricalFixture):

def test_facet_organization(self):
Expand Down Expand Up @@ -2022,6 +2057,16 @@ def test_plot_elements(self):
nt.assert_equal(len(g.ax.patches), want_elements)
nt.assert_equal(len(g.ax.lines), want_elements)

g = cat.factorplot("g", data=self.df, kind="count")
want_elements = self.g.unique().size
nt.assert_equal(len(g.ax.patches), want_elements)
nt.assert_equal(len(g.ax.lines), 0)

g = cat.factorplot("g", hue="h", data=self.df, kind="count")
want_elements = self.g.unique().size * self.h.unique().size
nt.assert_equal(len(g.ax.patches), want_elements)
nt.assert_equal(len(g.ax.lines), 0)

g = cat.factorplot("g", "y", data=self.df, kind="box")
want_artists = self.g.unique().size
nt.assert_equal(len(g.ax.artists), want_artists)
Expand Down

5 comments on commit 66aee90

@JWarmenhoven
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I understand correctly that this functionality replaced the option to get a count of a category by using factorplot and not providing an 'y' variable?
(as mentioned in http://stanford.edu/~mwaskom/software/seaborn/tutorial/categorical_linear_models.html)

sns.factorplot("class", data=titanic, palette="PuBuGn_d");

The above line generates an error.

@mwaskom
Copy link
Owner Author

@mwaskom mwaskom commented on 66aee90 Apr 6, 2015 via email

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mwaskom
Copy link
Owner Author

@mwaskom mwaskom commented on 66aee90 Apr 6, 2015

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way you can see the release notes for the development version here: http://stanford.edu/~mwaskom/software/seaborn-dev/whatsnew.html

@JWarmenhoven
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, thanks! Amazing library by the way.
I am thinking about looking into some kind of Pareto version of Factorplot.

@JWarmenhoven
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick and dirty concept hack:

df = pd.DataFrame({'Complaint Category': ['B','A','C','A','B','D','B','B','A','D','E'],
                   'Handling Time': [2, 8, 3, 9, 4, 6, 7, 8, 7, 5, 1]})
order=df['Complaint Category'].value_counts().index

g = sns.factorplot(x='Complaint Category', data=df, kind='count', ci=None, order=order)
leftax = g._left_axes
rightax = leftax[0].twinx()
rightax.set_ylabel('Cumulative %')
rightax.plot((df['Complaint Category'].value_counts().cumsum().values/
              df['Complaint Category'].count())*100, '-k.')
rightax.grid(b=None)
rightax.set_ylim(bottom=0)

index

order=df.groupby('Complaint Category')['Handling Time'].sum().order(ascending=False).index

g=sns.factorplot(x='Complaint Category', y='Handling Time', data=df, estimator=sum,
                 kind='bar', ci=None, order=order)
leftax = g._left_axes
rightax = leftax[0].twinx()
rightax.set_ylabel('Cumulative %')
rightax.plot((df.groupby('Complaint Category')['Handling Time'].sum().cumsum().values/
              df['Handling Time'].sum())*100, '-k.')
rightax.grid(b=None)
rightax.set_ylim(bottom=0)

index

Please sign in to comment.