|
16 | 16 |
|
17 | 17 | import numpy as np
|
18 | 18 |
|
| 19 | +import pandas |
| 20 | + |
19 | 21 | from pandas._libs import lib
|
20 | 22 | from pandas.util._decorators import set_module
|
21 | 23 | from pandas.util._exceptions import find_stack_level
|
|
47 | 49 | )
|
48 | 50 | from pandas.core.internals import concatenate_managers
|
49 | 51 |
|
| 52 | +from pandas.core.dtypes.common import is_extension_array_dtype |
| 53 | + |
| 54 | +from pandas.core.construction import array |
| 55 | + |
50 | 56 | if TYPE_CHECKING:
|
51 | 57 | from collections.abc import (
|
52 | 58 | Callable,
|
@@ -819,7 +825,20 @@ def _get_sample_object(
|
819 | 825 |
|
820 | 826 |
|
821 | 827 | def _concat_indexes(indexes) -> Index:
|
822 |
| - return indexes[0].append(indexes[1:]) |
| 828 | + # try to preserve extension types such as timestamp[pyarrow] |
| 829 | + values = [] |
| 830 | + for idx in indexes: |
| 831 | + values.extend(idx._values if hasattr(idx, "_values") else idx) |
| 832 | + |
| 833 | + # use the first index as a sample to infer the desired dtype |
| 834 | + sample = indexes[0] |
| 835 | + try: |
| 836 | + # this helps preserve extension types like timestamp[pyarrow] |
| 837 | + arr = array(values, dtype=sample.dtype) |
| 838 | + except Exception: |
| 839 | + arr = array(values) # fallback |
| 840 | + |
| 841 | + return Index(arr) |
823 | 842 |
|
824 | 843 |
|
825 | 844 | def validate_unique_levels(levels: list[Index]) -> None:
|
@@ -876,14 +895,33 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
|
876 | 895 |
|
877 | 896 | concat_index = _concat_indexes(indexes)
|
878 | 897 |
|
879 |
| - # these go at the end |
880 | 898 | if isinstance(concat_index, MultiIndex):
|
881 | 899 | levels.extend(concat_index.levels)
|
882 | 900 | codes_list.extend(concat_index.codes)
|
883 | 901 | else:
|
884 |
| - codes, categories = factorize_from_iterable(concat_index) |
885 |
| - levels.append(categories) |
886 |
| - codes_list.append(codes) |
| 902 | + # handle the case where the resulting index is a flat Index |
| 903 | + # but contains tuples (i.e., a collapsed MultiIndex) |
| 904 | + if isinstance(concat_index[0], tuple): |
| 905 | + # retrieve the original dtypes |
| 906 | + original_dtypes = [lvl.dtype for lvl in indexes[0].levels] |
| 907 | + |
| 908 | + unzipped = list(zip(*concat_index)) |
| 909 | + for i, level_values in enumerate(unzipped): |
| 910 | + # reconstruct each level using original dtype |
| 911 | + arr = array(level_values, dtype=original_dtypes[i]) |
| 912 | + level_codes, _ = factorize_from_iterable(arr) |
| 913 | + levels.append(ensure_index(arr)) |
| 914 | + codes_list.append(level_codes) |
| 915 | + else: |
| 916 | + # simple indexes factorize directly |
| 917 | + codes, categories = factorize_from_iterable(concat_index) |
| 918 | + values = getattr(concat_index, "_values", concat_index) |
| 919 | + if is_extension_array_dtype(values): |
| 920 | + levels.append(values) |
| 921 | + else: |
| 922 | + levels.append(categories) |
| 923 | + codes_list.append(codes) |
| 924 | + |
887 | 925 |
|
888 | 926 | if len(names) == len(levels):
|
889 | 927 | names = list(names)
|
|
0 commit comments