Skip to content

Commit 70c4ee7

Browse files
fix NamedArray.imag and NamedArray.real typing info (#8369)
Co-authored-by: Illviljan <[email protected]>
1 parent ccc8f99 commit 70c4ee7

File tree

4 files changed

+71
-10
lines changed

4 files changed

+71
-10
lines changed

xarray/core/variable.py

+22
Original file line numberDiff line numberDiff line change
@@ -2365,6 +2365,28 @@ def notnull(self, keep_attrs: bool | None = None):
23652365
keep_attrs=keep_attrs,
23662366
)
23672367

2368+
@property
2369+
def imag(self) -> Variable:
2370+
"""
2371+
The imaginary part of the variable.
2372+
2373+
See Also
2374+
--------
2375+
numpy.ndarray.imag
2376+
"""
2377+
return self._new(data=self.data.imag)
2378+
2379+
@property
2380+
def real(self) -> Variable:
2381+
"""
2382+
The real part of the variable.
2383+
2384+
See Also
2385+
--------
2386+
numpy.ndarray.real
2387+
"""
2388+
return self._new(data=self.data.real)
2389+
23682390
def __array_wrap__(self, obj, context=None):
23692391
return Variable(self.dims, obj)
23702392

xarray/namedarray/_typing.py

+8
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,14 @@ def __array_function__(
138138
) -> Any:
139139
...
140140

141+
@property
142+
def imag(self) -> _arrayfunction[_ShapeType_co, Any]:
143+
...
144+
145+
@property
146+
def real(self) -> _arrayfunction[_ShapeType_co, Any]:
147+
...
148+
141149

142150
# Corresponds to np.typing.NDArray:
143151
_ArrayFunction = _arrayfunction[Any, np.dtype[_ScalarType_co]]

xarray/namedarray/core.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@
2222
from xarray.core import dtypes, formatting, formatting_html
2323
from xarray.namedarray._aggregations import NamedArrayAggregations
2424
from xarray.namedarray._typing import (
25+
_arrayapi,
2526
_arrayfunction_or_api,
2627
_chunkedarray,
28+
_dtype,
2729
_DType_co,
2830
_ScalarType_co,
2931
_ShapeType_co,
32+
_SupportsImag,
33+
_SupportsReal,
3034
)
3135
from xarray.namedarray.utils import _default, is_duck_dask_array, to_0d_object_array
3236

@@ -513,26 +517,39 @@ def data(self, data: duckarray[Any, _DType_co]) -> None:
513517
self._data = data
514518

515519
@property
516-
def imag(self) -> Self:
520+
def imag(
521+
self: NamedArray[_ShapeType, np.dtype[_SupportsImag[_ScalarType]]], # type: ignore[type-var]
522+
) -> NamedArray[_ShapeType, _dtype[_ScalarType]]:
517523
"""
518524
The imaginary part of the array.
519525
520526
See Also
521527
--------
522528
numpy.ndarray.imag
523529
"""
524-
return self._replace(data=self.data.imag) # type: ignore
530+
if isinstance(self._data, _arrayapi):
531+
from xarray.namedarray._array_api import imag
532+
533+
return imag(self)
534+
535+
return self._new(data=self._data.imag)
525536

526537
@property
527-
def real(self) -> Self:
538+
def real(
539+
self: NamedArray[_ShapeType, np.dtype[_SupportsReal[_ScalarType]]], # type: ignore[type-var]
540+
) -> NamedArray[_ShapeType, _dtype[_ScalarType]]:
528541
"""
529542
The real part of the array.
530543
531544
See Also
532545
--------
533546
numpy.ndarray.real
534547
"""
535-
return self._replace(data=self.data.real) # type: ignore
548+
if isinstance(self._data, _arrayapi):
549+
from xarray.namedarray._array_api import real
550+
551+
return real(self)
552+
return self._new(data=self._data.real)
536553

537554
def __dask_tokenize__(self) -> Hashable:
538555
# Use v.data, instead of v._data, in order to cope with the wrappers

xarray/tests/test_namedarray.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,25 @@ def test_data(random_inputs: np.ndarray[Any, Any]) -> None:
168168

169169

170170
def test_real_and_imag() -> None:
171-
named_array: NamedArray[Any, Any]
172-
named_array = NamedArray(["x"], np.arange(3) - 1j * np.arange(3))
173-
expected_real = np.arange(3)
174-
assert np.array_equal(named_array.real.data, expected_real)
171+
expected_real: np.ndarray[Any, np.dtype[np.float64]]
172+
expected_real = np.arange(3, dtype=np.float64)
173+
174+
expected_imag: np.ndarray[Any, np.dtype[np.float64]]
175+
expected_imag = -np.arange(3, dtype=np.float64)
176+
177+
arr: np.ndarray[Any, np.dtype[np.complex128]]
178+
arr = expected_real + 1j * expected_imag
179+
180+
named_array: NamedArray[Any, np.dtype[np.complex128]]
181+
named_array = NamedArray(["x"], arr)
182+
183+
actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data
184+
assert np.array_equal(actual_real, expected_real)
185+
assert actual_real.dtype == expected_real.dtype
175186

176-
expected_imag = -np.arange(3)
177-
assert np.array_equal(named_array.imag.data, expected_imag)
187+
actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data
188+
assert np.array_equal(actual_imag, expected_imag)
189+
assert actual_imag.dtype == expected_imag.dtype
178190

179191

180192
# Additional tests as per your original class-based code
@@ -347,7 +359,9 @@ def _new(
347359

348360
def test_replace_namedarray() -> None:
349361
dtype_float = np.dtype(np.float32)
362+
np_val: np.ndarray[Any, np.dtype[np.float32]]
350363
np_val = np.array([1.5, 3.2], dtype=dtype_float)
364+
np_val2: np.ndarray[Any, np.dtype[np.float32]]
351365
np_val2 = 2 * np_val
352366

353367
narr_float: NamedArray[Any, np.dtype[np.float32]]

0 commit comments

Comments
 (0)