34
34
from numpy .lib .stride_tricks import sliding_window_view # noqa
35
35
36
36
from xarray .core import dask_array_ops , dtypes , nputils
37
- from xarray .core .utils import module_available
38
- from xarray .namedarray ._array_api import _get_data_namespace
39
37
from xarray .namedarray ._typing import _arrayfunction_or_api
40
38
from xarray .namedarray .parallelcompat import get_chunked_array_type , is_chunked_array
41
39
from xarray .namedarray .pycompat import array_type
42
- from xarray .namedarray .utils import is_duck_dask_array
40
+ from xarray .namedarray .utils import is_duck_dask_array , module_available
43
41
44
42
dask_available = module_available ("dask" )
45
43
46
44
45
+ def get_array_namespace (x ):
46
+ if hasattr (x , "__array_namespace__" ):
47
+ return x .__array_namespace__ ()
48
+ else :
49
+ return np
50
+
51
+
47
52
def _dask_or_eager_func (
48
53
name ,
49
54
eager_module = np ,
@@ -121,7 +126,7 @@ def isnull(data):
121
126
return isnat (data )
122
127
elif issubclass (scalar_type , np .inexact ):
123
128
# float types use NaN for null
124
- xp = _get_data_namespace (data )
129
+ xp = get_array_namespace (data )
125
130
return xp .isnan (data )
126
131
elif issubclass (scalar_type , (np .bool_ , np .integer , np .character , np .void )):
127
132
# these types cannot represent missing values
@@ -179,7 +184,7 @@ def cumulative_trapezoid(y, x, axis):
179
184
180
185
def astype (data , dtype , ** kwargs ):
181
186
if hasattr (data , "__array_namespace__" ):
182
- xp = _get_data_namespace (data )
187
+ xp = get_array_namespace (data )
183
188
if xp == np :
184
189
# numpy currently doesn't have a astype:
185
190
return data .astype (dtype , ** kwargs )
@@ -211,7 +216,7 @@ def as_shared_dtype(scalars_or_arrays, xp=np):
211
216
212
217
213
218
def broadcast_to (array , shape ):
214
- xp = _get_data_namespace (array )
219
+ xp = get_array_namespace (array )
215
220
return xp .broadcast_to (array , shape )
216
221
217
222
@@ -289,7 +294,7 @@ def count(data, axis=None):
289
294
290
295
291
296
def sum_where (data , axis = None , dtype = None , where = None ):
292
- xp = _get_data_namespace (data )
297
+ xp = get_array_namespace (data )
293
298
if where is not None :
294
299
a = where_method (xp .zeros_like (data ), where , data )
295
300
else :
@@ -300,7 +305,7 @@ def sum_where(data, axis=None, dtype=None, where=None):
300
305
301
306
def where (condition , x , y ):
302
307
"""Three argument where() with better dtype promotion rules."""
303
- xp = _get_data_namespace (condition )
308
+ xp = get_array_namespace (condition )
304
309
return xp .where (condition , * as_shared_dtype ([x , y ], xp = xp ))
305
310
306
311
@@ -320,19 +325,19 @@ def fillna(data, other):
320
325
def concatenate (arrays , axis = 0 ):
321
326
"""concatenate() with better dtype promotion rules."""
322
327
if hasattr (arrays [0 ], "__array_namespace__" ):
323
- xp = _get_data_namespace (arrays [0 ])
328
+ xp = get_array_namespace (arrays [0 ])
324
329
return xp .concat (as_shared_dtype (arrays , xp = xp ), axis = axis )
325
330
return _concatenate (as_shared_dtype (arrays ), axis = axis )
326
331
327
332
328
333
def stack (arrays , axis = 0 ):
329
334
"""stack() with better dtype promotion rules."""
330
- xp = _get_data_namespace (arrays [0 ])
335
+ xp = get_array_namespace (arrays [0 ])
331
336
return xp .stack (as_shared_dtype (arrays , xp = xp ), axis = axis )
332
337
333
338
334
339
def reshape (array , shape ):
335
- xp = _get_data_namespace (array )
340
+ xp = get_array_namespace (array )
336
341
return xp .reshape (array , shape )
337
342
338
343
@@ -376,7 +381,7 @@ def f(values, axis=None, skipna=None, **kwargs):
376
381
if name in ["sum" , "prod" ]:
377
382
kwargs .pop ("min_count" , None )
378
383
379
- xp = _get_data_namespace (values )
384
+ xp = get_array_namespace (values )
380
385
func = getattr (xp , name )
381
386
382
387
try :
0 commit comments