Skip to content

Commit 41301c1

Browse files
committed
BUG: support dtypes in column_dtypes for to_records()
1 parent 95f8dca commit 41301c1

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

Diff for: pandas/core/frame.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1716,7 +1716,8 @@ def to_records(self, index=True, convert_datetime64=None,
17161716
# string naming a type.
17171717
if dtype_mapping is None:
17181718
formats.append(v.dtype)
1719-
elif isinstance(dtype_mapping, (type, compat.string_types)):
1719+
elif isinstance(dtype_mapping, (type, np.dtype,
1720+
compat.string_types)):
17201721
formats.append(dtype_mapping)
17211722
else:
17221723
element = "row" if i < index_len else "column"

Diff for: pandas/tests/frame/test_convert_to.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
from pandas.compat import long
1212

13-
from pandas import DataFrame, MultiIndex, Series, Timestamp, compat, date_range
13+
from pandas import (DataFrame, MultiIndex, Series, Timestamp, compat,
14+
date_range, CategoricalDtype)
1415
from pandas.tests.frame.common import TestData
1516
import pandas.util.testing as tm
1617

@@ -220,6 +221,12 @@ def test_to_records_with_categorical(self):
220221
dtype=[("index", "<i8"), ("A", "<U"),
221222
("B", "<U"), ("C", "<U")])),
222223
224+
# Pass in a dtype instance.
225+
(dict(column_dtypes=np.dtype('unicode')),
226+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
227+
dtype=[("index", "<i8"), ("A", "<U"),
228+
("B", "<U"), ("C", "<U")])),
229+
223230
# Pass in a dictionary (name-only).
224231
(dict(column_dtypes={"A": np.int8, "B": np.float32, "C": "<U2"}),
225232
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
@@ -249,6 +256,12 @@ def test_to_records_with_categorical(self):
249256
dtype=[("index", "<i8"), ("A", "i1"),
250257
("B", "<f4"), ("C", "O")])),
251258
259+
# Pass in a dictionary of dtype instances; names not in the mapping default to the array dtype.
260+
(dict(column_dtypes={"A": np.dtype('int8'), "B": np.dtype('float32')}),
261+
np.rec.array([("0", "1", "0.2", "a"), ("1", "2", "1.5", "bc")],
262+
dtype=[("index", "<i8"), ("A", "i1"),
263+
("B", "<f4"), ("C", "O")]))])
264+
252265
# Mixture of everything.
253266
(dict(column_dtypes={"A": np.int8, "B": np.float32},
254267
index_dtypes="<U2"),
@@ -258,17 +271,26 @@ def test_to_records_with_categorical(self):
258271

259272
# Invalid dtype values.
260273
(dict(index=False, column_dtypes=list()),
261-
"Invalid dtype \\[\\] specified for column A"),
274+
(ValueError, "Invalid dtype \\[\\] specified for column A")),
262275

263276
(dict(index=False, column_dtypes={"A": "int32", "B": 5}),
264-
"Invalid dtype 5 specified for column B"),
277+
(ValueError, "Invalid dtype 5 specified for column B")),
278+
279+
# NumPy can't handle extension array (EA) dtypes, so check that an error is raised.
280+
(dict(index=False, column_dtypes={"A": "int32",
281+
"B": CategoricalDtype(['a', 'b'])}),
282+
(ValueError, 'Invalid dtype category specified for column B')),
283+
284+
# Check that bad types raise
285+
(dict(index=False, column_dtypes={"A": "int32", "B": "foo"}),
286+
(TypeError, 'data type "foo" not understood')),
265287
])
266288
def test_to_records_dtype(self, kwargs, expected):
267289
# see gh-18146
268290
df = DataFrame({"A": [1, 2], "B": [0.2, 1.5], "C": ["a", "bc"]})
269291

270-
if isinstance(expected, str):
271-
with pytest.raises(ValueError, match=expected):
292+
if not isinstance(expected, np.recarray):
293+
with pytest.raises(expected[0], match=expected[1]):
272294
df.to_records(**kwargs)
273295
else:
274296
result = df.to_records(**kwargs)

0 commit comments

Comments
 (0)