diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 2ada386f76b87..941ad4612b73c 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -211,7 +211,7 @@ Strings Interval ^^^^^^^^ -- +- Bug in :func:`interval_range` where start and end numeric types were always cast to 64 bit (:issue:`57268`) - Indexing diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index 46d1ee49c22a0..7695193a15608 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -1101,9 +1101,23 @@ def interval_range( breaks: np.ndarray | TimedeltaIndex | DatetimeIndex if is_number(endpoint): + dtype: np.dtype = np.dtype("int64") if com.all_not_none(start, end, freq): + if ( + isinstance(start, (float, np.float16)) + or isinstance(end, (float, np.float16)) + or isinstance(freq, (float, np.float16)) + ): + dtype = np.dtype("float64") + elif ( + isinstance(start, (np.integer, np.floating)) + and isinstance(end, (np.integer, np.floating)) + and start.dtype == end.dtype + ): + dtype = start.dtype # 0.1 ensures we capture end breaks = np.arange(start, end + (freq * 0.1), freq) + breaks = maybe_downcast_numeric(breaks, dtype) else: # compute the period/start/end if unspecified (at most one) if periods is None: @@ -1122,7 +1136,7 @@ def interval_range( # expected "ndarray[Any, Any]" [ breaks = maybe_downcast_numeric( breaks, # type: ignore[arg-type] - np.dtype("int64"), + dtype, ) else: # delegate to the appropriate range function @@ -1131,4 +1145,9 @@ def interval_range( else: breaks = timedelta_range(start=start, end=end, periods=periods, freq=freq) - return IntervalIndex.from_breaks(breaks, name=name, closed=closed) + return IntervalIndex.from_breaks( + breaks, + name=name, + closed=closed, + dtype=IntervalDtype(subtype=breaks.dtype, closed=closed), + ) diff --git a/pandas/tests/indexes/interval/test_interval_range.py b/pandas/tests/indexes/interval/test_interval_range.py index e8de59f84bcc6..7aea481b49221 100644 --- a/pandas/tests/indexes/interval/test_interval_range.py +++ b/pandas/tests/indexes/interval/test_interval_range.py @@ -220,6 +220,20 @@ def test_float_subtype(self, start, end, freq): expected = "int64" if is_integer(start + end) else "float64" assert result == expected + @pytest.mark.parametrize( + "start, end, expected", + [ + (np.int8(1), np.int8(10), np.dtype("int8")), + (np.int8(1), np.float16(10), np.dtype("float64")), + (np.float32(1), np.float32(10), np.dtype("float32")), + (1, 10, np.dtype("int64")), + (1, 10.0, np.dtype("float64")), + ], + ) + def test_interval_dtype(self, start, end, expected): + result = interval_range(start=start, end=end).dtype.subtype + assert result == expected + def test_interval_range_fractional_period(self): # float value for periods expected = interval_range(start=0, periods=10)