diff --git a/seaborn/categorical.py b/seaborn/categorical.py index f48a50f06e..7b6f50eabf 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -2193,6 +2193,42 @@ def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, """).format(**_categorical_docs) +def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None, + orient=None, color=None, palette=None, saturation=.75, + ax=None, **kwargs): + + estimator = len + ci = None + n_boot = 0 + units = None + errcolor = None + + if orient is None and y is None: + orient = "v" + + if x is None and y is not None: + x = y + elif y is None and x is not None: + y = x + elif x is not None and y is not None: + raise TypeError("Cannot pass values for both `x` and `y`") + else: + raise TypeError("Must pass valus for either `x` or `y`") + + plotter = _BarPlotter(x, y, hue, data, order, hue_order, + estimator, ci, n_boot, units, + orient, color, palette, saturation, + errcolor) + + plotter.value_label = "count" + + if ax is None: + ax = plt.gca() + + plotter.plot(ax, kwargs) + return ax + + def factorplot(x=None, y=None, hue=None, data=None, row=None, col=None, col_wrap=None, estimator=np.mean, ci=95, n_boot=1000, units=None, order=None, hue_order=None, row_order=None, @@ -2222,10 +2258,20 @@ def factorplot(x=None, y=None, hue=None, data=None, row=None, col=None, err = "Plot kind '{}' is not recognized".format(kind) raise ValueError(err) + # Alias the input variables to determine categorical order and palette + # correctly in the case of a count plot + if kind == "count": + if x is None and y is not None: + x_, y_ = y, y + elif y is None and x is not None: + x_, y_ = x, x + else: + x_, y_ = x, y + # Determine the order for the whole dataset, which will be used in all # facets to ensure representation of all data in the final plot p = _CategoricalPlotter() - p.establish_variables(x, y, hue, data, orient, order, hue_order) + p.establish_variables(x_, y_, hue, data, orient, order, hue_order) order = p.group_names hue_order = p.hue_names diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index 8cc517b65f..904b4c6665 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -1981,6 +1981,41 @@ def test_simple_pointplots(self): plt.close("all") +class TestCountPlot(CategoricalFixture): + + def test_plot_elements(self): + + ax = cat.countplot("g", data=self.df) + nt.assert_equal(ax.patches, self.g.unique().size) + for p in ax.patches: + nt.assert_equal(p.get_y(), 0) + nt.assert_equal(p.get_height(), + self.g.size() / self.g.unique().size) + + ax = cat.countplot(y="g", data=self.df) + nt.assert_equal(ax.patches, self.g.unique().size) + for p in ax.patches: + nt.assert_equal(p.get_x(), 0) + nt.assert_equal(p.get_width(), + self.g.size() / self.g.unique().size) + + ax = cat.countplot("g", hue="h", data=self.df) + nt.assert_equal(ax.patches, + self.g.unique().size * self.h.unique().size) + + ax = cat.countplot(y="g", hue="h", data=self.df) + nt.assert_equal(ax.patches, + self.g.unique().size * self.h.unique().size) + + def test_input_error(self): + + with nt.assert_raises(TypeError): + cat.countplot() + + with nt.assert_raises(TypeError): + cat.countplot(x="g", y="h", data=self.df) + + class TestFactorPlot(CategoricalFixture): def test_facet_organization(self): @@ -2022,6 +2057,16 @@ def test_plot_elements(self): nt.assert_equal(len(g.ax.patches), want_elements) nt.assert_equal(len(g.ax.lines), want_elements) + g = cat.factorplot("g", data=self.df, kind="count") + want_elements = self.g.unique().size + nt.assert_equal(len(g.ax.patches), want_elements) + nt.assert_equal(len(g.ax.lines), 0) + + g = cat.factorplot("g", hue="h", data=self.df, kind="count") + want_elements = self.g.unique().size * self.h.unique().size + nt.assert_equal(len(g.ax.patches), want_elements) + nt.assert_equal(len(g.ax.lines), 0) + g = cat.factorplot("g", "y", data=self.df, kind="box") want_artists = self.g.unique().size nt.assert_equal(len(g.ax.artists), want_artists)