Skip to content

Commit

Permalink
BUG: support dtypes in column_dtypes for to_records()
Browse files Browse the repository at this point in the history
  • Loading branch information
qwhelan committed Jan 26, 2019
1 parent 95f8dca commit 41301c1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
3 changes: 2 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,7 +1716,8 @@ def to_records(self, index=True, convert_datetime64=None,
# string naming a type.
if dtype_mapping is None:
formats.append(v.dtype)
elif isinstance(dtype_mapping, (type, compat.string_types)):
elif isinstance(dtype_mapping, (type, np.dtype,
compat.string_types)):
formats.append(dtype_mapping)
else:
element = "row" if i < index_len else "column"
Expand Down
32 changes: 27 additions & 5 deletions pandas/tests/frame/test_convert_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from pandas.compat import long

from pandas import DataFrame, MultiIndex, Series, Timestamp, compat, date_range
from pandas import (DataFrame, MultiIndex, Series, Timestamp, compat,
date_range, CategoricalDtype)
from pandas.tests.frame.common import TestData
import pandas.util.testing as tm

Expand Down Expand Up @@ -220,6 +221,12 @@ def test_to_records_with_categorical(self):
dtype=[("index", "<i8"), ("A", "<U"),
("B", "<U"), ("C", "<U")])),
# Pass in a dtype instance.
(dict(column_dtypes=np.dtype('unicode')),
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
dtype=[("index", "<i8"), ("A", "<U"),
("B", "<U"), ("C", "<U")])),
# Pass in a dictionary (name-only).
(dict(column_dtypes={"A": np.int8, "B": np.float32, "C": "<U2"}),
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
Expand Down Expand Up @@ -249,6 +256,12 @@ def test_to_records_with_categorical(self):
dtype=[("index", "<i8"), ("A", "i1"),
("B", "<f4"), ("C", "O")])),
# Names / indices not in dtype mapping default to array dtype.
(dict(column_dtypes={"A": np.dtype('int8'), "B": np.dtype('float32')}),
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
dtype=[("index", "<i8"), ("A", "i1"),
("B", "<f4"), ("C", "O")]))])

# Mixture of everything.
(dict(column_dtypes={"A": np.int8, "B": np.float32},
index_dtypes="<U2"),
Expand All @@ -258,17 +271,26 @@ def test_to_records_with_categorical(self):

# Invalid dtype values.
(dict(index=False, column_dtypes=list()),
"Invalid dtype \\[\\] specified for column A"),
(ValueError, "Invalid dtype \\[\\] specified for column A")),

(dict(index=False, column_dtypes={"A": "int32", "B": 5}),
"Invalid dtype 5 specified for column B"),
(ValueError, "Invalid dtype 5 specified for column B")),

# NumPy can't handle extension array (EA) dtypes, so check that an error is raised.
(dict(index=False, column_dtypes={"A": "int32",
"B": CategoricalDtype(['a', 'b'])}),
(ValueError, 'Invalid dtype category specified for column B')),

# Check that bad types raise
(dict(index=False, column_dtypes={"A": "int32", "B": "foo"}),
(TypeError, 'data type "foo" not understood')),
])
def test_to_records_dtype(self, kwargs, expected):
# see gh-18146
df = DataFrame({"A": [1, 2], "B": [0.2, 1.5], "C": ["a", "bc"]})

if isinstance(expected, str):
with pytest.raises(ValueError, match=expected):
if not isinstance(expected, np.recarray):
with pytest.raises(expected[0], match=expected[1]):
df.to_records(**kwargs)
else:
result = df.to_records(**kwargs)
Expand Down

0 comments on commit 41301c1

Please sign in to comment.