Skip to content

Commit 8a87810

Browse files
committed
revert get_array_namespace
1 parent 6aa6b80 commit 8a87810

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

xarray/core/duck_array_ops.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,21 @@
3434
from numpy.lib.stride_tricks import sliding_window_view # noqa
3535

3636
from xarray.core import dask_array_ops, dtypes, nputils
37-
from xarray.core.utils import module_available
38-
from xarray.namedarray._array_api import _get_data_namespace
3937
from xarray.namedarray._typing import _arrayfunction_or_api
4038
from xarray.namedarray.parallelcompat import get_chunked_array_type, is_chunked_array
4139
from xarray.namedarray.pycompat import array_type
42-
from xarray.namedarray.utils import is_duck_dask_array
40+
from xarray.namedarray.utils import is_duck_dask_array, module_available
4341

4442
dask_available = module_available("dask")
4543

4644

45+
def get_array_namespace(x):
46+
if hasattr(x, "__array_namespace__"):
47+
return x.__array_namespace__()
48+
else:
49+
return np
50+
51+
4752
def _dask_or_eager_func(
4853
name,
4954
eager_module=np,
@@ -121,7 +126,7 @@ def isnull(data):
121126
return isnat(data)
122127
elif issubclass(scalar_type, np.inexact):
123128
# float types use NaN for null
124-
xp = _get_data_namespace(data)
129+
xp = get_array_namespace(data)
125130
return xp.isnan(data)
126131
elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
127132
# these types cannot represent missing values
@@ -179,7 +184,7 @@ def cumulative_trapezoid(y, x, axis):
179184

180185
def astype(data, dtype, **kwargs):
181186
if hasattr(data, "__array_namespace__"):
182-
xp = _get_data_namespace(data)
187+
xp = get_array_namespace(data)
183188
if xp == np:
184189
# numpy currently doesn't have a astype:
185190
return data.astype(dtype, **kwargs)
@@ -211,7 +216,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
211216

212217

213218
def broadcast_to(array, shape):
214-
xp = _get_data_namespace(array)
219+
xp = get_array_namespace(array)
215220
return xp.broadcast_to(array, shape)
216221

217222

@@ -289,7 +294,7 @@ def count(data, axis=None):
289294

290295

291296
def sum_where(data, axis=None, dtype=None, where=None):
292-
xp = _get_data_namespace(data)
297+
xp = get_array_namespace(data)
293298
if where is not None:
294299
a = where_method(xp.zeros_like(data), where, data)
295300
else:
@@ -300,7 +305,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
300305

301306
def where(condition, x, y):
302307
"""Three argument where() with better dtype promotion rules."""
303-
xp = _get_data_namespace(condition)
308+
xp = get_array_namespace(condition)
304309
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
305310

306311

@@ -320,19 +325,19 @@ def fillna(data, other):
320325
def concatenate(arrays, axis=0):
321326
"""concatenate() with better dtype promotion rules."""
322327
if hasattr(arrays[0], "__array_namespace__"):
323-
xp = _get_data_namespace(arrays[0])
328+
xp = get_array_namespace(arrays[0])
324329
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis)
325330
return _concatenate(as_shared_dtype(arrays), axis=axis)
326331

327332

328333
def stack(arrays, axis=0):
329334
"""stack() with better dtype promotion rules."""
330-
xp = _get_data_namespace(arrays[0])
335+
xp = get_array_namespace(arrays[0])
331336
return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
332337

333338

334339
def reshape(array, shape):
335-
xp = _get_data_namespace(array)
340+
xp = get_array_namespace(array)
336341
return xp.reshape(array, shape)
337342

338343

@@ -376,7 +381,7 @@ def f(values, axis=None, skipna=None, **kwargs):
376381
if name in ["sum", "prod"]:
377382
kwargs.pop("min_count", None)
378383

379-
xp = _get_data_namespace(values)
384+
xp = get_array_namespace(values)
380385
func = getattr(xp, name)
381386

382387
try:

0 commit comments

Comments
 (0)