Skip to content

Commit 0f8ff5c

Browse files
slevangtien-vokeewisdcherian
authored
Allow wrapping np.ndarray subclasses (#9760)
* Allow wrapping astropy.units.Quantity * allow all np.ndarray subclasses * whats new * test np.matrix * fix comment --------- Co-authored-by: tvo <[email protected]> Co-authored-by: Justus Magin <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 5a9ff0b commit 0f8ff5c

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ New Features
2929
- Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
3030
(:issue:`2852`, :issue:`757`).
3131
By `Deepak Cherian <https://github.com/dcherian>`_.
32+
- Allow wrapping ``np.ndarray`` subclasses, e.g. ``astropy.units.Quantity`` (:issue:`9704`, :pull:`9760`).
33+
By `Sam Levang <https://github.com/slevang>`_ and `Tien Vo <https://github.com/tien-vo>`_.
3234
- Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with
3335
arrays with more than two dimensions.
3436
(:issue:`5629`). By `Deepak Cherian <https://github.com/dcherian>`_.

xarray/core/variable.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -321,14 +321,18 @@ def convert_non_numpy_type(data):
321321
else:
322322
data = np.asarray(data)
323323

324-
# immediately return array-like types except `numpy.ndarray` subclasses and `numpy` scalars
325-
if not isinstance(data, np.ndarray | np.generic) and (
324+
if isinstance(data, np.matrix):
325+
data = np.asarray(data)
326+
327+
# immediately return array-like types except `numpy.ndarray` and `numpy` scalars
328+
# compare types with `is` instead of `isinstance` to allow `numpy.ndarray` subclasses
329+
is_numpy = type(data) is np.ndarray or isinstance(data, np.generic)
330+
if not is_numpy and (
326331
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
327332
):
328333
return cast("T_DuckArray", data)
329334

330-
# validate whether the data is valid data types. Also, explicitly cast `numpy`
331-
# subclasses and `numpy` scalars to `numpy.ndarray`
335+
# anything left will be converted to `numpy.ndarray`, including `numpy` scalars
332336
data = np.asarray(data)
333337

334338
if data.dtype.kind in "OMm":

xarray/tests/test_variable.py

+20
Original file line numberDiff line numberDiff line change
@@ -2746,6 +2746,26 @@ def test_ones_like(self) -> None:
27462746
assert_identical(ones_like(orig), full_like(orig, 1))
27472747
assert_identical(ones_like(orig, dtype=int), full_like(orig, 1, dtype=int))
27482748

2749+
def test_numpy_ndarray_subclass(self):
2750+
class SubclassedArray(np.ndarray):
2751+
def __new__(cls, array, foo):
2752+
obj = np.asarray(array).view(cls)
2753+
obj.foo = foo
2754+
return obj
2755+
2756+
data = SubclassedArray([1, 2, 3], foo="bar")
2757+
actual = as_compatible_data(data)
2758+
assert isinstance(actual, SubclassedArray)
2759+
assert actual.foo == "bar"
2760+
assert_array_equal(data, actual)
2761+
2762+
def test_numpy_matrix(self):
2763+
with pytest.warns(PendingDeprecationWarning):
2764+
data = np.matrix([[1, 2], [3, 4]])
2765+
actual = as_compatible_data(data)
2766+
assert isinstance(actual, np.ndarray)
2767+
assert_array_equal(data, actual)
2768+
27492769
def test_unsupported_type(self):
27502770
# Non indexable type
27512771
class CustomArray(NDArrayMixin):

0 commit comments

Comments
 (0)