1
1
from __future__ import annotations
2
2
3
- from ...common import _aliases
3
+ from typing import Callable
4
+
5
+ from ...common import _aliases , array_namespace
4
6
5
7
from ..._internal import get_xp
6
8
29
31
)
30
32
31
33
from typing import TYPE_CHECKING
34
+
32
35
if TYPE_CHECKING :
33
36
from typing import Optional , Union
34
37
35
- from ...common ._typing import Device , Dtype , Array , NestedSequence , SupportsBufferProtocol
38
+ from ...common ._typing import (
39
+ Device ,
40
+ Dtype ,
41
+ Array ,
42
+ NestedSequence ,
43
+ SupportsBufferProtocol ,
44
+ )
36
45
37
46
import dask .array as da
38
47
39
48
isdtype = get_xp (np )(_aliases .isdtype )
40
49
unstack = get_xp (da )(_aliases .unstack )
41
50
51
+
42
52
# da.astype doesn't respect copy=True
43
53
def astype (
44
54
x : Array ,
45
55
dtype : Dtype ,
46
56
/ ,
47
57
* ,
48
58
copy : bool = True ,
49
- device : Optional [Device ] = None
59
+ device : Optional [Device ] = None ,
50
60
) -> Array :
51
61
"""
52
62
Array API compatibility wrapper for astype().
@@ -61,8 +71,10 @@ def astype(
61
71
x = x .astype (dtype )
62
72
return x .copy () if copy else x
63
73
74
+
64
75
# Common aliases
65
76
77
+
66
78
# This arange func is modified from the common one to
67
79
# not pass stop/step as keyword arguments, which will cause
68
80
# an error with dask
@@ -189,6 +201,7 @@ def asarray(
189
201
concatenate as concat ,
190
202
)
191
203
204
+
192
205
# dask.array.clip does not work unless all three arguments are provided.
193
206
# Furthermore, the masking workaround in common._aliases.clip cannot work with
194
207
# dask (meaning uint64 promoting to float64 is going to just be unfixed for
@@ -205,8 +218,10 @@ def clip(
205
218
See the corresponding documentation in the array library and/or the array API
206
219
specification for more details.
207
220
"""
221
+
208
222
def _isscalar (a ):
209
223
return isinstance (a , (int , float , type (None )))
224
+
210
225
min_shape = () if _isscalar (min ) else min .shape
211
226
max_shape = () if _isscalar (max ) else max .shape
212
227
@@ -228,12 +243,99 @@ def _isscalar(a):
228
243
229
244
return astype (da .minimum (da .maximum (x , min ), max ), x .dtype )
230
245
231
- # exclude these from all since dask.array has no sorting functions
232
- _da_unsupported = ['sort' , 'argsort' ]
233
246
234
- _common_aliases = [alias for alias in _aliases .__all__ if alias not in _da_unsupported ]
247
def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], Array]]:
    """
    Make sure that Array is not broken into multiple chunks along axis.

    Parameters
    ----------
    x : Array
        The array to inspect and, if needed, rechunk.
    axis : int
        Axis that must end up in a single chunk. Negative values count
        backwards from the last axis.

    Returns
    -------
    x : Array
        The input Array with a single chunk along axis.
    restore : Callable[[Array], Array]
        function to apply to the output to rechunk it back into reasonable chunks

    Raises
    ------
    IndexError
        If axis is out of bounds for the number of dimensions of x.
    """
    ndim = x.ndim
    # Reject out-of-range axes up front. Without this check a too-negative
    # axis stays negative after normalisation, silently indexes numblocks
    # from the end and never matches a key of the rechunk mapping below.
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    if axis < 0:
        axis += ndim

    # Already a single chunk along axis: nothing to do, nothing to undo.
    if x.numblocks[axis] < 2:
        return x, lambda x: x

    # Break chunks on other axes in an attempt to keep chunk size low
    x = x.rechunk({i: -1 if i == axis else "auto" for i in range(ndim)})

    # Rather than reconstructing the original chunks, which can be a
    # very expensive affair, just break down oversized chunks without
    # incurring any transfers over the network.
    # This has the downside of a risk of overchunking if the array is
    # then used in operations against other arrays that match the
    # original chunking pattern.
    return x, lambda x: x.rechunk()
273
+
274
+
275
def sort(
    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
    """
    Array API compatibility layer around the lack of sort() in Dask.

    The array is first brought to a single chunk along `axis`; every block
    is then sorted with the `sort` function of the namespace of the array's
    meta, and the result is finally rechunked.

    Warnings
    --------
    This function temporarily rechunks the array along `axis` to a single chunk.
    This can be extremely inefficient and can lead to out-of-memory errors.

    See the corresponding documentation in the array library and/or the array API
    specification for more details.
    """
    single, rechunk_back = _ensure_single_chunk(x, axis)

    # The per-block sort comes from the namespace of the meta array.
    template = single._meta
    backend = array_namespace(template)

    sorted_blocks = da.map_blocks(
        backend.sort,
        single,
        axis=axis,
        meta=template,
        dtype=single.dtype,
        descending=descending,
        stable=stable,
    )
    return rechunk_back(sorted_blocks)
235
303
236
- __all__ = _common_aliases + ['__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
304
+
305
def argsort(
    x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
) -> Array:
    """
    Array API compatibility layer around the lack of argsort() in Dask.

    The array is first brought to a single chunk along `axis`; every block
    is then passed to the `argsort` function of the namespace of the array's
    meta, and the result is finally rechunked.

    See the corresponding documentation in the array library and/or the array API
    specification for more details.

    Warnings
    --------
    This function temporarily rechunks the array along `axis` into a single chunk.
    This can be extremely inefficient and can lead to out-of-memory errors.
    """
    single, rechunk_back = _ensure_single_chunk(x, axis)

    template = single._meta
    backend = array_namespace(template)

    # Ask the backend which dtype argsort yields by running it on the meta
    # array, then build a meta of that dtype for the output.
    index_dtype = backend.argsort(template).dtype
    index_meta = backend.astype(template, index_dtype)

    indices = da.map_blocks(
        backend.argsort,
        single,
        axis=axis,
        meta=index_meta,
        dtype=index_dtype,
        descending=descending,
        stable=stable,
    )
    return rechunk_back(indices)
335
+
336
+
337
+ __all__ = _aliases .__all__ + [
338
+ '__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
237
339
'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
238
340
'atanh' , 'bitwise_left_shift' , 'bitwise_invert' ,
239
341
'bitwise_right_shift' , 'concat' , 'pow' , 'iinfo' , 'finfo' , 'can_cast' ,
@@ -242,4 +344,4 @@ def _isscalar(a):
242
344
'complex64' , 'complex128' , 'iinfo' , 'finfo' ,
243
345
'can_cast' , 'result_type' ]
244
346
245
- _all_ignore = ["get_xp" , "da" , "np" ]
347
# Names imported or used as helpers in this module; each entry must match
# the identifier exactly (no surrounding whitespace).
# NOTE(review): presumably consumed by the package's __all__ handling - confirm.
_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"]
0 commit comments