diff --git a/doc/source/whatsnew/v0.20.3.txt b/doc/source/whatsnew/v0.20.3.txt index 3ffc3c12852c0d..b2d382a3202a50 100644 --- a/doc/source/whatsnew/v0.20.3.txt +++ b/doc/source/whatsnew/v0.20.3.txt @@ -45,7 +45,8 @@ Conversion ^^^^^^^^^^ - Bug in pickle compat prior to the v0.20.x series, when ``UTC`` is a timezone in a Series/DataFrame/Index (:issue:`16608`) -- Bug in Series construction when passing a Series with ``dtype='category'`` (:issue:`16524`). +- Bug in ``Series`` construction when passing a ``Series`` with ``dtype='category'`` (:issue:`16524`). +- Bug in ``DataFrame.astype()`` when passing a ``Series`` as the ``dtype`` kwarg. (:issue:`16717`). Indexing ^^^^^^^^ diff --git a/pandas/core/generic.py b/pandas/core/generic.py index be90244119dc2d..87ce570e992b9e 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -3379,12 +3379,12 @@ def astype(self, dtype, copy=True, errors='raise', **kwargs): ------- casted : type of caller """ - if isinstance(dtype, collections.Mapping): + if is_dict_like(dtype): if self.ndim == 1: # i.e. Series - if len(dtype) > 1 or list(dtype.keys())[0] != self.name: + if len(dtype) > 1 or self.name not in dtype: raise KeyError('Only the Series name can be used for ' 'the key in Series dtype mappings.') - new_type = list(dtype.values())[0] + new_type = dtype[self.name] return self.astype(new_type, copy, errors, **kwargs) elif self.ndim > 2: raise NotImplementedError( diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index b99a6fabfa42b7..335b76ff2aadeb 100644 --- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -442,8 +442,9 @@ def test_astype_str(self): expected = DataFrame(['1.12345678901']) assert_frame_equal(result, expected) - def test_astype_dict(self): - # GH7271 + @pytest.mark.parametrize("dtype_class", [dict, Series]) + def test_astype_dict_like(self, dtype_class): + # GH7271 & GH16717 a = Series(date_range('2010-01-04', periods=5)) b = Series(range(5)) c = Series([0.0, 0.2, 0.4, 0.6, 0.8]) @@ -452,7 +453,8 @@ def test_astype_dict(self): original = df.copy(deep=True) # change type of a subset of columns - result = df.astype({'b': 'str', 'd': 'float32'}) + dt1 = dtype_class({'b': 'str', 'd': 'float32'}) + result = df.astype(dt1) expected = DataFrame({ 'a': a, 'b': Series(['0', '1', '2', '3', '4']), @@ -461,7 +463,8 @@ def test_astype_dict(self): assert_frame_equal(result, expected) assert_frame_equal(df, original) - result = df.astype({'b': np.float32, 'c': 'float32', 'd': np.float64}) + dt2 = dtype_class({'b': np.float32, 'c': 'float32', 'd': np.float64}) + result = df.astype(dt2) expected = DataFrame({ 'a': a, 'b': Series([0.0, 1.0, 2.0, 3.0, 4.0], dtype='float32'), @@ -471,19 +474,31 @@ def test_astype_dict(self): assert_frame_equal(df, original) # change all columns - assert_frame_equal(df.astype({'a': str, 'b': str, 'c': str, 'd': str}), + dt3 = dtype_class({'a': str, 'b': str, 'c': str, 'd': str}) + assert_frame_equal(df.astype(dt3), df.astype(str)) assert_frame_equal(df, original) # error should be raised when using something other than column labels # in the keys of the dtype dict - pytest.raises(KeyError, df.astype, {'b': str, 2: str}) - pytest.raises(KeyError, df.astype, {'e': str}) + dt4 = dtype_class({'b': str, 2: str}) + dt5 = dtype_class({'e': str}) + pytest.raises(KeyError, df.astype, dt4) + pytest.raises(KeyError, df.astype, dt5) assert_frame_equal(df, original) # if the dtypes provided are the same as the original dtypes, the # resulting DataFrame should be the same as the original DataFrame - equiv = df.astype({col: df[col].dtype for col in df.columns}) + dt6 = dtype_class({col: df[col].dtype for col in df.columns}) + equiv = df.astype(dt6) + assert_frame_equal(df, equiv) + assert_frame_equal(df, original) + + # GH 16717 + # if dtypes provided is empty, the resulting DataFrame + # should be the same as the original DataFrame + dt7 = dtype_class({}) + result = df.astype(dt7) assert_frame_equal(df, equiv) assert_frame_equal(df, original) diff --git a/pandas/tests/series/test_dtypes.py b/pandas/tests/series/test_dtypes.py index 9ab02a8c2aad73..2ec579842e33f2 100644 --- a/pandas/tests/series/test_dtypes.py +++ b/pandas/tests/series/test_dtypes.py @@ -152,24 +152,35 @@ def test_astype_unicode(self): reload(sys) # noqa sys.setdefaultencoding(former_encoding) - def test_astype_dict(self): + @pytest.mark.parametrize("dtype_class", [dict, Series]) + def test_astype_dict_like(self, dtype_class): # see gh-7271 s = Series(range(0, 10, 2), name='abc') - result = s.astype({'abc': str}) + dt1 = dtype_class({'abc': str}) + result = s.astype(dt1) expected = Series(['0', '2', '4', '6', '8'], name='abc') tm.assert_series_equal(result, expected) - result = s.astype({'abc': 'float64'}) + dt2 = dtype_class({'abc': 'float64'}) + result = s.astype(dt2) expected = Series([0.0, 2.0, 4.0, 6.0, 8.0], dtype='float64', name='abc') tm.assert_series_equal(result, expected) + dt3 = dtype_class({'abc': str, 'def': str}) with pytest.raises(KeyError): - s.astype({'abc': str, 'def': str}) + s.astype(dt3) + dt4 = dtype_class({0: str}) with pytest.raises(KeyError): - s.astype({0: str}) + s.astype(dt4) + + # GH16717 + # if dtypes provided is empty, it should error + dt5 = dtype_class({}) + with pytest.raises(KeyError): + s.astype(dt5) def test_astype_generic_timestamp_deprecated(self): # see gh-15524