Skip to content

Commit 869be8d

Browse files
topper-123jreback
authored andcommitted
Added test for _get_dtype_type. (#16899)
1 parent 142b5b6 commit 869be8d

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

Diff for: pandas/tests/dtypes/test_common.py

+37
Original file line numberDiff line numberDiff line change
@@ -568,3 +568,40 @@ def test__get_dtype(input_param, result):
568568
def test__get_dtype_fails(input_param):
569569
# python objects
570570
pytest.raises(TypeError, com._get_dtype, input_param)
571+
572+
573+
@pytest.mark.parametrize('input_param,result', [
574+
(int, np.dtype(int).type),
575+
('int32', np.int32),
576+
(float, np.dtype(float).type),
577+
('float64', np.float64),
578+
(np.dtype('float64'), np.float64),
579+
(str, np.dtype(str).type),
580+
(pd.Series([1, 2], dtype=np.dtype('int16')), np.int16),
581+
(pd.Series(['a', 'b']), np.object_),
582+
(pd.Index([1, 2], dtype='int64'), np.int64),
583+
(pd.Index(['a', 'b']), np.object_),
584+
('category', com.CategoricalDtypeType),
585+
(pd.Categorical(['a', 'b']).dtype, com.CategoricalDtypeType),
586+
(pd.Categorical(['a', 'b']), com.CategoricalDtypeType),
587+
(pd.CategoricalIndex(['a', 'b']).dtype, com.CategoricalDtypeType),
588+
(pd.CategoricalIndex(['a', 'b']), com.CategoricalDtypeType),
589+
(pd.DatetimeIndex([1, 2]), np.datetime64),
590+
(pd.DatetimeIndex([1, 2]).dtype, np.datetime64),
591+
('<M8[ns]', np.datetime64),
592+
(pd.DatetimeIndex([1, 2], tz='Europe/London'), com.DatetimeTZDtypeType),
593+
(pd.DatetimeIndex([1, 2], tz='Europe/London').dtype,
594+
com.DatetimeTZDtypeType),
595+
('datetime64[ns, Europe/London]', com.DatetimeTZDtypeType),
596+
(pd.SparseSeries([1, 2], dtype='int32'), np.int32),
597+
(pd.SparseSeries([1, 2], dtype='int32').dtype, np.int32),
598+
(PeriodDtype(freq='D'), com.PeriodDtypeType),
599+
('period[D]', com.PeriodDtypeType),
600+
(IntervalDtype(), com.IntervalDtypeType),
601+
(None, type(None)),
602+
(1, type(None)),
603+
(1.2, type(None)),
604+
(pd.DataFrame([1, 2]), type(None)), # composite dtype
605+
])
606+
def test__get_dtype_type(input_param, result):
607+
assert com._get_dtype_type(input_param) == result

0 commit comments

Comments
 (0)