Skip to content

Commit 609d7b4

Browse files
jbrockmendel authored and im-vinicius committed
REF: split out dtype-finding in concat_compat (pandas-dev#53260)
* REF: split out dtype-finding in concat_compat
* mypy fixup
* fix annotation
* remove unused ignore
1 parent ffd210c commit 609d7b4

File tree

1 file changed

+44
-82
lines changed

1 file changed

+44
-82
lines changed

pandas/core/dtypes/concat.py

+44-82
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,17 @@
1818
common_dtype_categorical_compat,
1919
find_common_type,
2020
)
21-
from pandas.core.dtypes.dtypes import (
22-
CategoricalDtype,
23-
DatetimeTZDtype,
24-
ExtensionDtype,
25-
)
21+
from pandas.core.dtypes.dtypes import CategoricalDtype
2622
from pandas.core.dtypes.generic import (
2723
ABCCategoricalIndex,
28-
ABCExtensionArray,
2924
ABCSeries,
3025
)
3126

3227
if TYPE_CHECKING:
3328
from pandas._typing import (
3429
ArrayLike,
3530
AxisInt,
31+
DtypeObj,
3632
)
3733

3834
from pandas.core.arrays import (
@@ -100,45 +96,54 @@ def concat_compat(
10096
# Creating an empty array directly is tempting, but the winnings would be
10197
# marginal given that it would still require shape & dtype calculation and
10298
# np.concatenate which has them both implemented is compiled.
99+
orig = to_concat
103100
non_empties = [x for x in to_concat if _is_nonempty(x, axis)]
104101
if non_empties and axis == 0 and not ea_compat_axis:
105102
# ea_compat_axis see GH#39574
106103
to_concat = non_empties
107104

108-
dtypes = {obj.dtype for obj in to_concat}
109-
kinds = {obj.dtype.kind for obj in to_concat}
110-
contains_datetime = any(
111-
isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in "mM"
112-
for dtype in dtypes
113-
) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)
105+
any_ea, kinds, target_dtype = _get_result_dtype(to_concat, non_empties)
106+
107+
if len(to_concat) < len(orig):
108+
_, _, alt_dtype = _get_result_dtype(orig, non_empties)
109+
110+
if target_dtype is not None:
111+
to_concat = [astype_array(arr, target_dtype, copy=False) for arr in to_concat]
112+
113+
if not isinstance(to_concat[0], np.ndarray):
114+
# i.e. isinstance(to_concat[0], ExtensionArray)
115+
to_concat_eas = cast("Sequence[ExtensionArray]", to_concat)
116+
cls = type(to_concat[0])
117+
return cls._concat_same_type(to_concat_eas)
118+
else:
119+
to_concat_arrs = cast("Sequence[np.ndarray]", to_concat)
120+
result = np.concatenate(to_concat_arrs, axis=axis)
121+
122+
if not any_ea and "b" in kinds and result.dtype.kind in "iuf":
123+
# GH#39817 cast to object instead of casting bools to numeric
124+
result = result.astype(object, copy=False)
125+
return result
114126

115-
all_empty = not len(non_empties)
116-
single_dtype = len(dtypes) == 1
117-
any_ea = any(isinstance(x, ExtensionDtype) for x in dtypes)
118127

119-
if contains_datetime:
120-
return _concat_datetime(to_concat, axis=axis)
128+
def _get_result_dtype(
129+
to_concat: Sequence[ArrayLike], non_empties: Sequence[ArrayLike]
130+
) -> tuple[bool, set[str], DtypeObj | None]:
131+
target_dtype = None
121132

133+
dtypes = {obj.dtype for obj in to_concat}
134+
kinds = {obj.dtype.kind for obj in to_concat}
135+
136+
any_ea = any(not isinstance(x, np.ndarray) for x in to_concat)
122137
if any_ea:
138+
# i.e. any ExtensionArrays
139+
123140
# we ignore axis here, as internally concatting with EAs is always
124141
# for axis=0
125-
if not single_dtype:
142+
if len(dtypes) != 1:
126143
target_dtype = find_common_type([x.dtype for x in to_concat])
127144
target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)
128-
to_concat = [
129-
astype_array(arr, target_dtype, copy=False) for arr in to_concat
130-
]
131-
132-
if isinstance(to_concat[0], ABCExtensionArray):
133-
# TODO: what about EA-backed Index?
134-
to_concat_eas = cast("Sequence[ExtensionArray]", to_concat)
135-
cls = type(to_concat[0])
136-
return cls._concat_same_type(to_concat_eas)
137-
else:
138-
to_concat_arrs = cast("Sequence[np.ndarray]", to_concat)
139-
return np.concatenate(to_concat_arrs)
140145

141-
elif all_empty:
146+
elif not len(non_empties):
142147
# we have all empties, but may need to coerce the result dtype to
143148
# object if we have non-numeric type operands (numpy would otherwise
144149
# cast this to float)
@@ -148,17 +153,16 @@ def concat_compat(
148153
pass
149154
else:
150155
# coerce to object
151-
to_concat = [x.astype("object") for x in to_concat]
156+
target_dtype = np.dtype(object)
152157
kinds = {"o"}
158+
else:
159+
# Argument 1 to "list" has incompatible type "Set[Union[ExtensionDtype,
160+
# Any]]"; expected "Iterable[Union[dtype[Any], None, Type[Any],
161+
# _SupportsDType[dtype[Any]], str, Tuple[Any, Union[SupportsIndex,
162+
# Sequence[SupportsIndex]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
163+
target_dtype = np.find_common_type(list(dtypes), []) # type: ignore[arg-type]
153164

154-
# error: Argument 1 to "concatenate" has incompatible type
155-
# "Sequence[Union[ExtensionArray, ndarray[Any, Any]]]"; expected
156-
# "Union[_SupportsArray[dtype[Any]], _NestedSequence[_SupportsArray[dtype[Any]]]]"
157-
result: np.ndarray = np.concatenate(to_concat, axis=axis) # type: ignore[arg-type]
158-
if "b" in kinds and result.dtype.kind in "iuf":
159-
# GH#39817 cast to object instead of casting bools to numeric
160-
result = result.astype(object, copy=False)
161-
return result
165+
return any_ea, kinds, target_dtype
162166

163167

164168
def union_categoricals(
@@ -320,45 +324,3 @@ def _maybe_unwrap(x):
320324

321325
dtype = CategoricalDtype(categories=categories, ordered=ordered)
322326
return Categorical._simple_new(new_codes, dtype=dtype)
323-
324-
325-
def _concatenate_2d(to_concat: Sequence[np.ndarray], axis: AxisInt) -> np.ndarray:
326-
# coerce to 2d if needed & concatenate
327-
if axis == 1:
328-
to_concat = [np.atleast_2d(x) for x in to_concat]
329-
return np.concatenate(to_concat, axis=axis)
330-
331-
332-
def _concat_datetime(to_concat: Sequence[ArrayLike], axis: AxisInt = 0) -> ArrayLike:
333-
"""
334-
provide concatenation of an datetimelike array of arrays each of which is a
335-
single M8[ns], datetime64[ns, tz] or m8[ns] dtype
336-
337-
Parameters
338-
----------
339-
to_concat : sequence of arrays
340-
axis : axis to provide concatenation
341-
342-
Returns
343-
-------
344-
a single array, preserving the combined dtypes
345-
"""
346-
from pandas.core.construction import ensure_wrapped_if_datetimelike
347-
348-
to_concat = [ensure_wrapped_if_datetimelike(x) for x in to_concat]
349-
350-
single_dtype = lib.dtypes_all_equal([x.dtype for x in to_concat])
351-
352-
# multiple types, need to coerce to object
353-
if not single_dtype:
354-
# ensure_wrapped_if_datetimelike ensures that astype(object) wraps
355-
# in Timestamp/Timedelta
356-
return _concatenate_2d([x.astype(object) for x in to_concat], axis=axis)
357-
358-
# error: Unexpected keyword argument "axis" for "_concat_same_type" of
359-
# "ExtensionArray"
360-
to_concat_eas = cast("list[ExtensionArray]", to_concat)
361-
result = type(to_concat_eas[0])._concat_same_type( # type: ignore[call-arg]
362-
to_concat_eas, axis=axis
363-
)
364-
return result

0 commit comments

Comments
 (0)