Skip to content

Commit 2adacb4

Browse files
Fix pandas-dev#58421: Index[timestamp[pyarrow]].union with itself return object type
1 parent 56847c5 commit 2adacb4

File tree

2 files changed

+90
-5
lines changed

2 files changed

+90
-5
lines changed

pandas/core/reshape/concat.py

+43-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import numpy as np
1818

19+
import pandas
20+
1921
from pandas._libs import lib
2022
from pandas.util._decorators import set_module
2123
from pandas.util._exceptions import find_stack_level
@@ -47,6 +49,10 @@
4749
)
4850
from pandas.core.internals import concatenate_managers
4951

52+
from pandas.core.dtypes.common import is_extension_array_dtype
53+
54+
from pandas.core.construction import array
55+
5056
if TYPE_CHECKING:
5157
from collections.abc import (
5258
Callable,
@@ -819,7 +825,20 @@ def _get_sample_object(
819825

820826

821827
def _concat_indexes(indexes) -> Index:
822-
return indexes[0].append(indexes[1:])
828+
# try to preserve extension types such as timestamp[pyarrow]
829+
values = []
830+
for idx in indexes:
831+
values.extend(idx._values if hasattr(idx, "_values") else idx)
832+
833+
# use the first index as a sample to infer the desired dtype
834+
sample = indexes[0]
835+
try:
836+
# this helps preserve extension types like timestamp[pyarrow]
837+
arr = array(values, dtype=sample.dtype)
838+
except Exception:
839+
arr = array(values) # fallback
840+
841+
return Index(arr)
823842

824843

825844
def validate_unique_levels(levels: list[Index]) -> None:
@@ -876,14 +895,33 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
876895

877896
concat_index = _concat_indexes(indexes)
878897

879-
# these go at the end
880898
if isinstance(concat_index, MultiIndex):
881899
levels.extend(concat_index.levels)
882900
codes_list.extend(concat_index.codes)
883901
else:
884-
codes, categories = factorize_from_iterable(concat_index)
885-
levels.append(categories)
886-
codes_list.append(codes)
902+
# handle the case where the resulting index is a flat Index
903+
# but contains tuples (i.e., a collapsed MultiIndex)
904+
if isinstance(concat_index[0], tuple):
905+
# retrieve the original dtypes
906+
original_dtypes = [lvl.dtype for lvl in indexes[0].levels]
907+
908+
unzipped = list(zip(*concat_index))
909+
for i, level_values in enumerate(unzipped):
910+
# reconstruct each level using original dtype
911+
arr = array(level_values, dtype=original_dtypes[i])
912+
level_codes, _ = factorize_from_iterable(arr)
913+
levels.append(ensure_index(arr))
914+
codes_list.append(level_codes)
915+
else:
916+
# simple indexes factorize directly
917+
codes, categories = factorize_from_iterable(concat_index)
918+
values = getattr(concat_index, "_values", concat_index)
919+
if is_extension_array_dtype(values):
920+
levels.append(values)
921+
else:
922+
levels.append(categories)
923+
codes_list.append(codes)
924+
887925

888926
if len(names) == len(levels):
889927
names = list(names)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pandas as pd
2+
import pytest
3+
4+
5+
schema = {
6+
"id": "int64[pyarrow]",
7+
"time": "timestamp[s][pyarrow]",
8+
"value": "float[pyarrow]",
9+
}
10+
11+
@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
def test_concat_preserves_pyarrow_timestamp(dtype):
    """
    Check that ``pd.concat`` with ``keys`` keeps a pyarrow-backed index level.

    Two frames indexed by ``("id", "time")`` are concatenated under an extra
    ``"run"`` level; the ``"time"`` level of the resulting MultiIndex must
    still have the parametrized pyarrow dtype.
    """
    # Skip (instead of erroring) when the optional pyarrow dependency is absent.
    pytest.importorskip("pyarrow")

    # Let the parametrized dtype drive the "time" column so that adding more
    # dtypes to the parametrization actually exercises them.
    frame_schema = {**schema, "time": dtype}
    column_names = list(frame_schema)

    def make_frame(rows):
        # Build a frame with pyarrow-backed columns and an (id, time) index.
        return (
            pd.DataFrame(rows, columns=column_names)
            .astype(frame_schema)
            .set_index(["id", "time"])
        )

    dfA = make_frame(
        [
            (0, "2021-01-01 00:00:00", 5.3),
            (1, "2021-01-01 00:01:00", 5.4),
            (2, "2021-01-01 00:01:00", 5.4),
            (3, "2021-01-01 00:02:00", 5.5),
        ]
    )
    dfB = make_frame(
        [
            (1, "2022-01-01 08:00:00", 6.3),
            (2, "2022-01-01 08:01:00", 6.4),
            (3, "2022-01-01 08:02:00", 6.5),
        ]
    )

    df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])

    # Check that df.index is a MultiIndex
    assert isinstance(df.index, pd.MultiIndex), (
        f"Expected MultiIndex, but received {type(df.index)}"
    )

    # Level 0 is "run", level 1 is "id", level 2 is "time": verify that the
    # pyarrow timestamp dtype stays intact after concat.
    time_level = df.index.levels[2]
    assert time_level.dtype == dtype, (
        f"Expected {dtype}, but received {time_level.dtype}"
    )
47+

0 commit comments

Comments
 (0)