From 1db86e8ec69cefd009807742c9335ce6bdebcf40 Mon Sep 17 00:00:00 2001 From: Richard Berg Date: Fri, 9 May 2025 03:02:57 -0500 Subject: [PATCH] Improve support for pandas Extension Arrays (#10301) --- xarray/core/dtypes.py | 66 +++++--- xarray/core/duck_array_ops.py | 34 ++-- xarray/core/extension_array.py | 251 +++++++++++++++++++++++----- xarray/tests/test_dataarray.py | 69 +++++++- xarray/tests/test_duck_array_ops.py | 62 +++++++ 5 files changed, 398 insertions(+), 84 deletions(-) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index c959a7f2536..83feba040f7 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -63,7 +63,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: # N.B. these casting rules should match pandas dtype_: np.typing.DTypeLike fill_value: Any - if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): + if is_extension_array_dtype(dtype): + return dtype, dtype.na_value + elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()): # for now, we always promote string dtypes to object for consistency with existing behavior # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes dtype_ = object @@ -222,19 +224,51 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool: return xp.isdtype(dtype, kind) -def preprocess_types(t): - if isinstance(t, str | bytes): - return type(t) - elif isinstance(dtype := getattr(t, "dtype", t), np.dtype) and ( - np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_) - ): +def maybe_promote_to_variable_width( + array_or_dtype: np.typing.ArrayLike | np.typing.DTypeLike, +) -> np.typing.ArrayLike | np.typing.DTypeLike: + if isinstance(array_or_dtype, str | bytes): + return type(array_or_dtype) + elif isinstance( + dtype := getattr(array_or_dtype, "dtype", array_or_dtype), np.dtype + ) and (np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_)): # drop the length from numpy's fixed-width string dtypes, it is better to # recalculate # TODO(keewis): remove once the minimum version of `numpy.result_type` does this # for us return dtype.type else: - return t + return array_or_dtype + + +def should_promote_to_object( + arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, xp +) -> bool: + """ + Test whether the given arrays_and_dtypes, when evaluated individually, match the + type promotion rules found in PROMOTE_TO_OBJECT. + """ + np_result_types = set() + for arr_or_dtype in arrays_and_dtypes: + try: + result_type = array_api_compat.result_type( + maybe_promote_to_variable_width(arr_or_dtype), xp=xp + ) + if isinstance(result_type, np.dtype): + np_result_types.add(result_type) + except TypeError: + # passing individual objects to xp.result_type means NEP-18 implementations won't have + # a chance to intercept special values (such as NA) that numpy core cannot handle + pass + + if np_result_types: + for left, right in PROMOTE_TO_OBJECT: + if any(np.issubdtype(t, left) for t in np_result_types) and any( + np.issubdtype(t, right) for t in np_result_types + ): + return True + + return False def result_type( @@ -263,19 +297,9 @@ def result_type( if xp is None: xp = get_array_namespace(arrays_and_dtypes) - types = { - array_api_compat.result_type(preprocess_types(t), xp=xp) - for t in arrays_and_dtypes - } - if any(isinstance(t, np.dtype) for t in types): - # only check if there's numpy dtypes – the array API does not - # define the types we're checking for - for left, right in PROMOTE_TO_OBJECT: - if any(np.issubdtype(t, left) for t in types) and any( - np.issubdtype(t, right) for t in types - ): - return np.dtype(object) + if should_promote_to_object(arrays_and_dtypes, xp): + return np.dtype(object) return array_api_compat.result_type( - *map(preprocess_types, arrays_and_dtypes), xp=xp + *map(maybe_promote_to_variable_width, arrays_and_dtypes), xp=xp ) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 96330a64b68..dfdd63263a3 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -27,6 +27,11 @@ from xarray.compat import dask_array_compat, dask_array_ops from xarray.compat.array_api_compat import get_array_namespace from xarray.core import dtypes, nputils +from xarray.core.extension_array import ( + PandasExtensionArray, + as_extension_array, + is_scalar, +) from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -239,7 +244,14 @@ def astype(data, dtype, *, xp=None, **kwargs): def asarray(data, xp=np, dtype=None): - converted = data if is_duck_array(data) else xp.asarray(data) + if is_duck_array(data): + converted = data + elif is_extension_array_dtype(dtype): + # data may or may not be an ExtensionArray, so we can't rely on + # np.asarray to call our NEP-18 handler; gotta hook it ourselves + converted = PandasExtensionArray(as_extension_array(data, dtype)) + else: + converted = xp.asarray(data, dtype=dtype) if dtype is None or converted.dtype == dtype: return converted @@ -252,19 +264,6 @@ def asarray(data, xp=np, dtype=None): def as_shared_dtype(scalars_or_arrays, xp=None): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - if any(is_extension_array_dtype(x) for x in scalars_or_arrays): - extension_array_types = [ - x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) - ] - if len(extension_array_types) == len(scalars_or_arrays) and all( - isinstance(x, type(extension_array_types[0])) for x in extension_array_types - ): - return scalars_or_arrays - raise ValueError( - "Cannot cast arrays to shared type, found" - f" array types {[x.dtype for x in scalars_or_arrays]}" - ) - # Avoid calling array_type("cupy") repeatidely in the any check array_type_cupy = array_type("cupy") if any(isinstance(x, array_type_cupy) for x in scalars_or_arrays): @@ -384,7 +383,12 @@ def where(condition, x, y): else: condition = astype(condition, dtype=dtype, xp=xp) - return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) + promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp) + + # pd.where won't broadcast 0-dim arrays across a series; scalar y's must be preserved + maybe_promoted_y = y if is_extension_array_dtype(x) and is_scalar(y) else promoted_y + + return xp.where(condition, promoted_x, maybe_promoted_y) def where_method(data, cond, other=dtypes.NA): diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index e8006a4c8c3..a5f29c3f45a 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -1,22 +1,47 @@ from __future__ import annotations +import functools from collections.abc import Callable, Sequence -from typing import Generic, cast +from typing import TYPE_CHECKING, Generic, cast import numpy as np import pandas as pd +from pandas.api.extensions import ExtensionArray, ExtensionDtype from pandas.api.types import is_extension_array_dtype +from pandas.api.types import is_scalar as pd_is_scalar +from pandas.core.dtypes.astype import astype_array_safe +from pandas.core.dtypes.cast import find_result_type +from pandas.core.dtypes.concat import concat_compat from xarray.core.types import DTypeLikeSave, T_ExtensionArray HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} -def implements(numpy_function): - """Register an __array_function__ implementation for MyArray objects.""" +if TYPE_CHECKING: + from pandas._typing import DtypeObj, Scalar + + +def is_scalar(value: object) -> bool: + """Workaround: pandas is_scalar doesn't recognize Categorical nulls for some reason.""" + return value is pd.CategoricalDtype.na_value or pd_is_scalar(value) + + +def implements(numpy_function_or_name: Callable | str) -> Callable: + """Register an __array_function__ implementation. + + Pass a function directly if it's guaranteed to exist in all supported numpy versions, or a + string to first check for its existence. + """ def decorator(func): - HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + if isinstance(numpy_function_or_name, str): + numpy_function = getattr(np, numpy_function_or_name, None) + else: + numpy_function = numpy_function_or_name + + if numpy_function: + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func return func return decorator @@ -29,6 +54,97 @@ def __extension_duck_array__issubdtype( return False # never want a function to think a pandas extension dtype is a subtype of numpy +@implements("astype") # np.astype was added in 2.1.0, but we only require >=1.24 +def __extension_duck_array__astype( + array_or_scalar: np.typing.ArrayLike, + dtype: DTypeLikeSave, + order: str = "K", + casting: str = "unsafe", + subok: bool = True, + copy: bool = True, + device: str = None, +) -> T_ExtensionArray: + if ( + not ( + is_extension_array_dtype(array_or_scalar) or is_extension_array_dtype(dtype) + ) + or casting != "unsafe" + or not subok + or order != "K" + ): + return NotImplemented + + return as_extension_array(array_or_scalar, dtype, copy=copy) + + +@implements(np.asarray) +def __extension_duck_array__asarray( + array_or_scalar: np.typing.ArrayLike, dtype: DTypeLikeSave = None +) -> T_ExtensionArray: + if not is_extension_array_dtype(dtype): + return NotImplemented + + return as_extension_array(array_or_scalar, dtype) + + +def as_extension_array( + array_or_scalar: np.typing.ArrayLike, dtype: ExtensionDtype, copy: bool = False +) -> T_ExtensionArray: + if is_scalar(array_or_scalar): + return dtype.construct_array_type()._from_sequence( + [array_or_scalar], dtype=dtype + ) + else: + return astype_array_safe(array_or_scalar, dtype, copy=copy) + + +@implements(np.result_type) +def __extension_duck_array__result_type( + *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike, +) -> DtypeObj: + extension_arrays_and_dtypes = [ + x for x in arrays_and_dtypes if is_extension_array_dtype(x) + ] + if not extension_arrays_and_dtypes: + return NotImplemented + + ea_dtypes: list[ExtensionDtype] = [ + getattr(x, "dtype", x) for x in extension_arrays_and_dtypes + ] + scalars: list[Scalar] = [x for x in arrays_and_dtypes if is_scalar(x)] + # other_stuff could include: + # - arrays such as pd.ABCSeries, np.ndarray, or other array-api duck arrays + # - dtypes such as pd.DtypeObj, np.dtype, or other array-api duck dtypes + other_stuff = [ + x + for x in arrays_and_dtypes + if not is_extension_array_dtype(x) and not is_scalar(x) + ] + + # We implement one special case: when possible, preserve Categoricals (avoid promoting + # to object) by merging the categories of all given Categoricals + scalars + NA. + # Ideally this could be upstreamed into pandas find_result_type / find_common_type. + if not other_stuff and all( + isinstance(x, pd.CategoricalDtype) and not x.ordered for x in ea_dtypes + ): + return union_unordered_categorical_and_scalar(ea_dtypes, scalars) + + # In all other cases, we defer to pandas find_result_type, which is the only Pandas API + # permissive enough to handle scalars + other_stuff. + # Note that unlike find_common_type or np.result_type, it operates in pairs, where + # the left side must be a DtypeObj. + return functools.reduce(find_result_type, arrays_and_dtypes, ea_dtypes[0]) + + +def union_unordered_categorical_and_scalar( + categorical_dtypes: list[pd.CategoricalDtype], scalars: list[Scalar] +) -> pd.CategoricalDtype: + scalars = [x for x in scalars if x is not pd.CategoricalDtype.na_value] + all_categories = set().union(*(x.categories for x in categorical_dtypes)) + all_categories = all_categories.union(scalars) + return pd.CategoricalDtype(categories=all_categories) + + @implements(np.broadcast_to) def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): if shape[0] == len(arr) and len(shape) == 1: @@ -45,21 +161,36 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] + return concat_compat(arrays, ea_compat_axis=True) @implements(np.where) def __extension_duck_array__where( - condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray + condition: T_ExtensionArray | np.ArrayLike, + x: T_ExtensionArray, + y: T_ExtensionArray | np.ArrayLike, ) -> T_ExtensionArray: - if ( - isinstance(x, pd.Categorical) - and isinstance(y, pd.Categorical) - and x.dtype != y.dtype - ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] - y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] - return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) + return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) + + +def _replace_duck(args, replacer: Callable[[PandasExtensionArray]]) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, PandasExtensionArray): + args_as_list[index] = replacer(value) + elif isinstance(value, tuple): # should handle more than just tuple? iterable? + args_as_list[index] = tuple(_replace_duck(value, replacer)) + elif isinstance(value, list): + args_as_list[index] = _replace_duck(value, replacer) + return args_as_list + + +def replace_duck_with_extension_array(args) -> tuple: + return tuple(_replace_duck(args, lambda duck: duck.array)) + + +def replace_duck_with_series(args) -> tuple: + return tuple(_replace_duck(args, lambda duck: pd.Series(duck.array))) class PandasExtensionArray(Generic[T_ExtensionArray]): @@ -74,36 +205,80 @@ def __init__(self, array: T_ExtensionArray): The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. ``` """ - if not isinstance(array, pd.api.extensions.ExtensionArray): + if not isinstance(array, ExtensionArray): raise TypeError(f"{array} is not an pandas ExtensionArray.") self.array = array + self._add_ops_dunders() + + def _add_ops_dunders(self): + """Delegate all operators to pd.Series""" + + def create_dunder(name: str) -> Callable: + def binary_dunder(self, other): + self, other = replace_duck_with_series((self, other)) + res = getattr(pd.Series, name)(self, other) + if isinstance(res, pd.Series): + res = PandasExtensionArray(res.array) + return res + + return binary_dunder + + # see pandas.core.arraylike.OpsMixin + binary_operators = [ + "__eq__", + "__ne__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__and__", + "__rand__", + "__or__", + "__ror__", + "__xor__", + "__rxor__", + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__mod__", + "__rmod__", + "__divmod__", + "__rdivmod__", + "__pow__", + "__rpow__", + ] + for method_name in binary_operators: + setattr(self.__class__, method_name, create_dunder(method_name)) + def __array_function__(self, func, types, args, kwargs): - def replace_duck_with_extension_array(args) -> list: - args_as_list = list(args) - for index, value in enumerate(args_as_list): - if isinstance(value, PandasExtensionArray): - args_as_list[index] = value.array - elif isinstance( - value, tuple - ): # should handle more than just tuple? iterable? - args_as_list[index] = tuple( - replace_duck_with_extension_array(value) - ) - elif isinstance(value, list): - args_as_list[index] = replace_duck_with_extension_array(value) - return args_as_list - - args = tuple(replace_duck_with_extension_array(args)) + args = replace_duck_with_extension_array(args) if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: return func(*args, **kwargs) res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) - if is_extension_array_dtype(res): + if isinstance(res, ExtensionArray): return type(self)[type(res)](res) return res def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): - return ufunc(*inputs, **kwargs) + if first_ea := next( + (x for x in inputs if isinstance(x, PandasExtensionArray)), None + ): + inputs = replace_duck_with_series(inputs) + res = first_ea.__array_ufunc__(ufunc, method, *inputs, **kwargs) + if isinstance(res, pd.Series): + arr = res.array + return type(self)[type(arr)](arr) + return res + + return getattr(ufunc, method)(*inputs, **kwargs) def __repr__(self): return f"PandasExtensionArray(array={self.array!r})" @@ -115,20 +290,12 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: item = self.array[key] if is_extension_array_dtype(item): return type(self)(item) - if np.isscalar(item): + if is_scalar(item): return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed return item def __setitem__(self, key, val): self.array[key] = val - def __eq__(self, other): - if isinstance(other, PandasExtensionArray): - return self.array == other.array - return self.array == other - - def __ne__(self, other): - return ~(self == other) - def __len__(self): return len(self.array) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8d0d5011026..2f53ea70735 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -34,6 +34,7 @@ from xarray.core import dtypes from xarray.core.common import full_like from xarray.core.coordinates import Coordinates +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords from xarray.core.types import QueryEngineOptions, QueryParserOptions from xarray.core.utils import is_scalar @@ -1792,12 +1793,37 @@ def test_reindex_empty_array_dtype(self) -> None: x = xr.DataArray([], dims=("x",), coords={"x": []}).astype("float32") y = x.reindex(x=[1.0, 2.0]) - assert x.dtype == y.dtype, ( - "Dtype of reindexed DataArray should match dtype of the original DataArray" - ) - assert y.dtype == np.float32, ( - "Dtype of reindexed DataArray should remain float32" - ) + assert ( + x.dtype == y.dtype + ), "Dtype of reindexed DataArray should match dtype of the original DataArray" + assert ( + y.dtype == np.float32 + ), "Dtype of reindexed DataArray should remain float32" + + def test_reindex_extension_array(self) -> None: + index1 = np.array([1, 2, 3]) + index2 = np.array([1, 2, 4]) + srs = pd.Series(index=index1, data=1).convert_dtypes() + x = srs.to_xarray() + y = x.reindex(index=index2) # used to fail (GH #10301) + assert_array_equal(x, pd.array([1, 1, 1])) + assert_array_equal(y, pd.array([1, 1, pd.NA])) + assert x.dtype == y.dtype == pd.Int64Dtype() + assert x.index.dtype == y.index.dtype == np.dtype("int64") + + def test_reindex_categorical(self) -> None: + index1 = pd.Categorical(["a", "b", "c"]) + index2 = pd.Categorical(["a", "b", "d"]) + srs = pd.Series(index=index1, data=1).convert_dtypes() + x = srs.to_xarray() + y = x.reindex(index=index2) + assert_array_equal(x, pd.array([1, 1, 1])) + assert_array_equal(y, pd.array([1, 1, pd.NA])) + assert x.dtype == y.dtype == pd.Int64Dtype() + assert isinstance(x.index.dtype, pd.CategoricalDtype) + assert isinstance(y.index.dtype, pd.CategoricalDtype) + assert_array_equal(x.index.dtype.categories, np.array(["a", "b", "c"])) + assert_array_equal(y.index.dtype.categories, np.array(["a", "b", "d"])) def test_rename(self) -> None: da = xr.DataArray( @@ -7255,3 +7281,34 @@ def test_unstack_index_var() -> None: name="x", ) assert_identical(actual, expected) + + +def test_from_series_regression() -> None: + # all of these examples used to fail + # see GH:issue:10301 + srs = pd.Series(index=[1, 2, 3], data=pd.array([1, 1, pd.NA])) + arr = srs.to_xarray() + + # binary operator + res = arr * 5 + assert_array_equal(res, np.array([5, 5, np.nan])) + assert res.dtype == pd.Int64Dtype() + assert isinstance(res, xr.DataArray) + + # NEP-13 ufunc + res = np.add(3, arr) + assert_array_equal(np.add(2, arr), np.array([3, 3, np.nan])) + assert res.dtype == pd.Int64Dtype() + assert isinstance(res, xr.DataArray) + + # NEP-18 array_function + res = np.astype(arr.data, pd.Int32Dtype()) + assert_array_equal(res, arr) + assert res.dtype == pd.Int32Dtype() + assert isinstance(res, PandasExtensionArray) + + # xarray ufunc + res = arr.fillna(0) + assert_array_equal(res, np.array([1, 1, 0])) + assert res.dtype == pd.Int64Dtype() + assert isinstance(res, xr.DataArray) diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index dcf8349aba4..31e92033acf 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -1099,3 +1099,65 @@ def test_extension_array_repr(int1): def test_extension_array_attr(int1): int_duck_array = PandasExtensionArray(int1) assert (~int_duck_array.fillna(10)).all() + + +def test_extension_array_result_type_numeric(int1, int2): + assert pd.Int64Dtype() == np.result_type( + PandasExtensionArray(int1), PandasExtensionArray(int2) + ) + assert pd.Int64Dtype() == np.result_type( + 100, -100, PandasExtensionArray(int1), pd.NA + ) + assert pd.Int64Dtype() == np.result_type( + PandasExtensionArray(pd.array([1, 2, 3], dtype=pd.Int8Dtype())), + np.array([4]), + ) + assert pd.Float64Dtype() == np.result_type( + np.array([1.0]), + PandasExtensionArray(int1), + ) + + +def test_extension_array_result_type_categorical(categorical1, categorical2): + res = np.result_type( + PandasExtensionArray(categorical1), PandasExtensionArray(categorical2) + ) + assert isinstance(res, pd.CategoricalDtype) + assert set(res.categories) == set(categorical1.categories) | set( + categorical2.categories + ) + assert not res.ordered + + assert categorical1.dtype == np.result_type( + PandasExtensionArray(categorical1), pd.CategoricalDtype.na_value + ) + + +def test_extension_array_result_type_mixed(int1, categorical1): + assert np.dtype("object") == np.result_type( + PandasExtensionArray(int1), PandasExtensionArray(categorical1) + ) + assert np.dtype("object") == np.result_type( + np.array([1, 2, 3]), PandasExtensionArray(categorical1) + ) + assert np.dtype("object") == np.result_type( + PandasExtensionArray(int1), dt.datetime.now() + ) + + +def test_extension_array_astype(int1): + res = np.astype(PandasExtensionArray(int1), float) + assert res.dtype == np.dtype("float64") + assert_array_equal(res, np.array([np.nan, 2, 3, np.nan, np.nan], dtype="float32")) + + res = np.astype(PandasExtensionArray(int1), pd.Float64Dtype()) + assert res.dtype == pd.Float64Dtype() + assert_array_equal( + res, pd.array([pd.NA, np.float64(2), np.float64(3), pd.NA, pd.NA]) + ) + + res = np.astype( + PandasExtensionArray(pd.array([1, 2], dtype="int8")), pd.Int16Dtype() + ) + assert res.dtype == pd.Int16Dtype() + assert_array_equal(res, pd.array([1, 2], dtype=pd.Int16Dtype()))