18
18
from numpy import zeros_like # noqa
19
19
from numpy import around , broadcast_to # noqa
20
20
from numpy import concatenate as _concatenate
21
- from numpy import einsum , gradient , isclose , isin , isnan , isnat # noqa
22
- from numpy import stack as _stack
23
- from numpy import take , tensordot , transpose , unravel_index # noqa
24
- from numpy import where as _where
21
+ from numpy import ( # noqa
22
+ einsum ,
23
+ gradient ,
24
+ isclose ,
25
+ isin ,
26
+ isnat ,
27
+ take ,
28
+ tensordot ,
29
+ transpose ,
30
+ unravel_index ,
31
+ )
25
32
from numpy .lib .stride_tricks import sliding_window_view # noqa
26
33
27
34
from . import dask_array_ops , dtypes , nputils
36
43
dask_array = None # type: ignore
37
44
38
45
46
+ def get_array_namespace (x ):
47
+ if hasattr (x , "__array_namespace__" ):
48
+ return x .__array_namespace__ ()
49
+ else :
50
+ return np
51
+
52
+
39
53
def _dask_or_eager_func (
40
54
name ,
41
55
eager_module = np ,
@@ -108,7 +122,8 @@ def isnull(data):
108
122
return isnat (data )
109
123
elif issubclass (scalar_type , np .inexact ):
110
124
# float types use NaN for null
111
- return isnan (data )
125
+ xp = get_array_namespace (data )
126
+ return xp .isnan (data )
112
127
elif issubclass (scalar_type , (np .bool_ , np .integer , np .character , np .void )):
113
128
# these types cannot represent missing values
114
129
return zeros_like (data , dtype = bool )
@@ -164,28 +179,31 @@ def cumulative_trapezoid(y, x, axis):
164
179
165
180
166
181
def astype (data , dtype , ** kwargs ):
182
+ if hasattr (data , "__array_namespace__" ):
183
+ xp = get_array_namespace (data )
184
+ return xp .astype (data , dtype , ** kwargs )
167
185
return data .astype (dtype , ** kwargs )
168
186
169
187
170
188
def asarray (data , xp = np ):
171
189
return data if is_duck_array (data ) else xp .asarray (data )
172
190
173
191
174
- def as_shared_dtype (scalars_or_arrays ):
192
+ def as_shared_dtype (scalars_or_arrays , xp = np ):
175
193
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
176
194
177
195
if any (isinstance (x , cupy_array_type ) for x in scalars_or_arrays ):
178
196
import cupy as cp
179
197
180
198
arrays = [asarray (x , xp = cp ) for x in scalars_or_arrays ]
181
199
else :
182
- arrays = [asarray (x ) for x in scalars_or_arrays ]
200
+ arrays = [asarray (x , xp = xp ) for x in scalars_or_arrays ]
183
201
# Pass arrays directly instead of dtypes to result_type so scalars
184
202
# get handled properly.
185
203
# Note that result_type() safely gets the dtype from dask arrays without
186
204
# evaluating them.
187
205
out_type = dtypes .result_type (* arrays )
188
- return [x . astype (out_type , copy = False ) for x in arrays ]
206
+ return [astype (x , out_type , copy = False ) for x in arrays ]
189
207
190
208
191
209
def lazy_array_equiv (arr1 , arr2 ):
@@ -259,9 +277,20 @@ def count(data, axis=None):
259
277
return np .sum (np .logical_not (isnull (data )), axis = axis )
260
278
261
279
280
+ def sum_where (data , axis = None , dtype = None , where = None ):
281
+ xp = get_array_namespace (data )
282
+ if where is not None :
283
+ a = where_method (xp .zeros_like (data ), where , data )
284
+ else :
285
+ a = data
286
+ result = xp .sum (a , axis = axis , dtype = dtype )
287
+ return result
288
+
289
+
262
290
def where (condition , x , y ):
263
291
"""Three argument where() with better dtype promotion rules."""
264
- return _where (condition , * as_shared_dtype ([x , y ]))
292
+ xp = get_array_namespace (condition )
293
+ return xp .where (condition , * as_shared_dtype ([x , y ], xp = xp ))
265
294
266
295
267
296
def where_method (data , cond , other = dtypes .NA ):
@@ -284,7 +313,13 @@ def concatenate(arrays, axis=0):
284
313
285
314
def stack (arrays , axis = 0 ):
286
315
"""stack() with better dtype promotion rules."""
287
- return _stack (as_shared_dtype (arrays ), axis = axis )
316
+ xp = get_array_namespace (arrays [0 ])
317
+ return xp .stack (as_shared_dtype (arrays , xp = xp ), axis = axis )
318
+
319
+
320
+ def reshape (array , shape ):
321
+ xp = get_array_namespace (array )
322
+ return xp .reshape (array , shape )
288
323
289
324
290
325
@contextlib .contextmanager
@@ -323,11 +358,8 @@ def f(values, axis=None, skipna=None, **kwargs):
323
358
if name in ["sum" , "prod" ]:
324
359
kwargs .pop ("min_count" , None )
325
360
326
- if hasattr (values , "__array_namespace__" ):
327
- xp = values .__array_namespace__ ()
328
- func = getattr (xp , name )
329
- else :
330
- func = getattr (np , name )
361
+ xp = get_array_namespace (values )
362
+ func = getattr (xp , name )
331
363
332
364
try :
333
365
with warnings .catch_warnings ():
0 commit comments