@@ -568,3 +568,40 @@ def test__get_dtype(input_param, result):
568
568
def test__get_dtype_fails (input_param ):
569
569
# python objects
570
570
pytest .raises (TypeError , com ._get_dtype , input_param )
571
+
572
+
573
+ @pytest .mark .parametrize ('input_param,result' , [
574
+ (int , np .dtype (int ).type ),
575
+ ('int32' , np .int32 ),
576
+ (float , np .dtype (float ).type ),
577
+ ('float64' , np .float64 ),
578
+ (np .dtype ('float64' ), np .float64 ),
579
+ (str , np .dtype (str ).type ),
580
+ (pd .Series ([1 , 2 ], dtype = np .dtype ('int16' )), np .int16 ),
581
+ (pd .Series (['a' , 'b' ]), np .object_ ),
582
+ (pd .Index ([1 , 2 ], dtype = 'int64' ), np .int64 ),
583
+ (pd .Index (['a' , 'b' ]), np .object_ ),
584
+ ('category' , com .CategoricalDtypeType ),
585
+ (pd .Categorical (['a' , 'b' ]).dtype , com .CategoricalDtypeType ),
586
+ (pd .Categorical (['a' , 'b' ]), com .CategoricalDtypeType ),
587
+ (pd .CategoricalIndex (['a' , 'b' ]).dtype , com .CategoricalDtypeType ),
588
+ (pd .CategoricalIndex (['a' , 'b' ]), com .CategoricalDtypeType ),
589
+ (pd .DatetimeIndex ([1 , 2 ]), np .datetime64 ),
590
+ (pd .DatetimeIndex ([1 , 2 ]).dtype , np .datetime64 ),
591
+ ('<M8[ns]' , np .datetime64 ),
592
+ (pd .DatetimeIndex ([1 , 2 ], tz = 'Europe/London' ), com .DatetimeTZDtypeType ),
593
+ (pd .DatetimeIndex ([1 , 2 ], tz = 'Europe/London' ).dtype ,
594
+ com .DatetimeTZDtypeType ),
595
+ ('datetime64[ns, Europe/London]' , com .DatetimeTZDtypeType ),
596
+ (pd .SparseSeries ([1 , 2 ], dtype = 'int32' ), np .int32 ),
597
+ (pd .SparseSeries ([1 , 2 ], dtype = 'int32' ).dtype , np .int32 ),
598
+ (PeriodDtype (freq = 'D' ), com .PeriodDtypeType ),
599
+ ('period[D]' , com .PeriodDtypeType ),
600
+ (IntervalDtype (), com .IntervalDtypeType ),
601
+ (None , type (None )),
602
+ (1 , type (None )),
603
+ (1.2 , type (None )),
604
+ (pd .DataFrame ([1 , 2 ]), type (None )), # composite dtype
605
+ ])
606
+ def test__get_dtype_type (input_param , result ):
607
+ assert com ._get_dtype_type (input_param ) == result
0 commit comments