ENH: allow get_dummies to accept dtype argument (#18330)
Scorpil authored and jreback committed Nov 22, 2017
1 parent bd145c8 commit fedc503
Showing 5 changed files with 243 additions and 191 deletions.
13 changes: 12 additions & 1 deletion doc/source/reshaping.rst
@@ -240,7 +240,7 @@ values will be set to ``NaN``.
df3
df3.unstack()
.. versionadded: 0.18.0
.. versionadded:: 0.18.0

Alternatively, unstack takes an optional ``fill_value`` argument, for specifying
the value of missing data.
@@ -634,6 +634,17 @@ When a column contains only one level, it will be omitted in the result.
pd.get_dummies(df, drop_first=True)
By default new columns will have ``np.uint8`` dtype. To choose another dtype, use the ``dtype`` argument:

.. ipython:: python
df = pd.DataFrame({'A': list('abc'), 'B': [1.1, 2.2, 3.3]})
pd.get_dummies(df, dtype=bool).dtypes
.. versionadded:: 0.22.0


.. _reshaping.factorize:

Factorizing values
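As a quick sketch of the behaviour documented in the reshaping.rst hunk above (assuming a pandas version that includes this change, where the default dummy dtype is ``np.uint8``), only the generated indicator columns take the requested dtype; pre-existing columns keep theirs:

    import pandas as pd

    df = pd.DataFrame({'A': list('abc'), 'B': [1.1, 2.2, 3.3]})

    # Default: indicator columns A_a/A_b/A_c are uint8, 'B' stays float64
    print(pd.get_dummies(df).dtypes)

    # dtype=bool changes only the generated indicator columns
    print(pd.get_dummies(df, dtype=bool).dtypes)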
15 changes: 15 additions & 0 deletions doc/source/whatsnew/v0.22.0.txt
@@ -17,6 +17,21 @@ New features
-
-


.. _whatsnew_0220.enhancements.get_dummies_dtype:

``get_dummies`` now supports a ``dtype`` argument
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`get_dummies` now accepts a ``dtype`` argument, which specifies a dtype for the new columns. The default remains ``uint8``. (:issue:`18330`)

.. ipython:: python

df = pd.DataFrame({'a': [1, 2], 'b': [3, 4], 'c': [5, 6]})
pd.get_dummies(df, columns=['c']).dtypes
pd.get_dummies(df, columns=['c'], dtype=bool).dtypes


.. _whatsnew_0220.enhancements.other:

Other Enhancements
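A practical reason to choose a dtype explicitly is memory: a ``uint8`` (or ``bool``) indicator matrix uses one byte per cell versus eight for ``float64``. A minimal sketch (the series and sizes here are illustrative only, assuming a pandas version that includes this change):

    import numpy as np
    import pandas as pd

    s = pd.Series(list('abca') * 1000)           # 4000 rows, 3 categories

    small = pd.get_dummies(s)                    # uint8 indicators (default)
    large = pd.get_dummies(s, dtype=np.float64)  # float64 indicators

    # 3 columns x 4000 rows: ~12 kB as uint8 vs ~96 kB as float64
    print(small.memory_usage(index=False).sum())
    print(large.memory_usage(index=False).sum())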
2 changes: 1 addition & 1 deletion pandas/core/generic.py
@@ -965,7 +965,7 @@ def _set_axis_name(self, name, axis=0, inplace=False):
inplace : bool
whether to modify `self` directly or return a copy
.. versionadded: 0.21.0
.. versionadded:: 0.21.0
Returns
-------
38 changes: 29 additions & 9 deletions pandas/core/reshape/reshape.py
@@ -10,7 +10,7 @@
from pandas.core.dtypes.common import (
_ensure_platform_int,
is_list_like, is_bool_dtype,
needs_i8_conversion, is_sparse)
needs_i8_conversion, is_sparse, is_object_dtype)
from pandas.core.dtypes.cast import maybe_promote
from pandas.core.dtypes.missing import notna

@@ -697,7 +697,7 @@ def _convert_level_number(level_num, columns):


def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
columns=None, sparse=False, drop_first=False):
columns=None, sparse=False, drop_first=False, dtype=None):
"""
Convert categorical variable into dummy/indicator variables
@@ -728,6 +728,11 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
.. versionadded:: 0.18.0
dtype : dtype, default np.uint8
Data type for new columns. Only a single dtype is allowed.
.. versionadded:: 0.22.0
Returns
-------
dummies : DataFrame or SparseDataFrame
@@ -783,6 +788,12 @@ def get_dummies(data, prefix=None, prefix_sep='_', dummy_na=False,
3 0 0
4 0 0
>>> pd.get_dummies(pd.Series(list('abc')), dtype=float)
a b c
0 1.0 0.0 0.0
1 0.0 1.0 0.0
2 0.0 0.0 1.0
See Also
--------
Series.str.get_dummies
@@ -835,20 +846,29 @@ def check_len(item, name):

dummy = _get_dummies_1d(data[col], prefix=pre, prefix_sep=sep,
dummy_na=dummy_na, sparse=sparse,
drop_first=drop_first)
drop_first=drop_first, dtype=dtype)
with_dummies.append(dummy)
result = concat(with_dummies, axis=1)
else:
result = _get_dummies_1d(data, prefix, prefix_sep, dummy_na,
sparse=sparse, drop_first=drop_first)
sparse=sparse,
drop_first=drop_first,
dtype=dtype)
return result


def _get_dummies_1d(data, prefix, prefix_sep='_', dummy_na=False,
sparse=False, drop_first=False):
sparse=False, drop_first=False, dtype=None):
# Series avoids inconsistent NaN handling
codes, levels = _factorize_from_iterable(Series(data))

if dtype is None:
dtype = np.uint8
dtype = np.dtype(dtype)

if is_object_dtype(dtype):
raise ValueError("dtype=object is not a valid dtype for get_dummies")

def get_empty_Frame(data, sparse):
if isinstance(data, Series):
index = data.index
@@ -903,18 +923,18 @@ def get_empty_Frame(data, sparse):
sp_indices = sp_indices[1:]
dummy_cols = dummy_cols[1:]
for col, ixs in zip(dummy_cols, sp_indices):
sarr = SparseArray(np.ones(len(ixs), dtype=np.uint8),
sarr = SparseArray(np.ones(len(ixs), dtype=dtype),
sparse_index=IntIndex(N, ixs), fill_value=0,
dtype=np.uint8)
dtype=dtype)
sparse_series[col] = SparseSeries(data=sarr, index=index)

out = SparseDataFrame(sparse_series, index=index, columns=dummy_cols,
default_fill_value=0,
dtype=np.uint8)
dtype=dtype)
return out

else:
dummy_mat = np.eye(number_of_cols, dtype=np.uint8).take(codes, axis=0)
dummy_mat = np.eye(number_of_cols, dtype=dtype).take(codes, axis=0)

if not dummy_na:
# reset NaN GH4446
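The implementation above normalizes the argument with ``np.dtype(dtype)`` (so string aliases work) and explicitly rejects ``object``. A small sketch of the resulting behaviour (assuming a pandas version that includes this change):

    import pandas as pd

    s = pd.Series(list('abca'))

    # Any numpy-compatible dtype spec is accepted, e.g. a string alias
    print(pd.get_dummies(s, dtype='float32').dtypes)

    # dtype=object raises, per the check added in _get_dummies_1d
    try:
        pd.get_dummies(s, dtype=object)
    except ValueError as exc:
        print(exc)   # dtype=object is not a valid dtype for get_dummies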
