11
11
12
12
import numpy as np
13
13
import pandas as pd
14
+ from numpy import all as array_all # noqa
15
+ from numpy import any as array_any # noqa
16
+ from numpy import zeros_like # noqa
17
+ from numpy import around , broadcast_to # noqa
18
+ from numpy import concatenate as _concatenate
19
+ from numpy import einsum , isclose , isin , isnan , isnat , pad # noqa
20
+ from numpy import stack as _stack
21
+ from numpy import take , tensordot , transpose , unravel_index # noqa
22
+ from numpy import where as _where
14
23
15
24
from . import dask_array_compat , dask_array_ops , dtypes , npcompat , nputils
16
25
from .nputils import nanfirst , nanlast
@@ -34,31 +43,15 @@ def _dask_or_eager_func(
34
43
name ,
35
44
eager_module = np ,
36
45
dask_module = dask_array ,
37
- list_of_args = False ,
38
- array_args = slice (1 ),
39
- requires_dask = None ,
40
46
):
41
47
"""Create a function that dispatches to dask for dask array inputs."""
42
- if dask_module is not None :
43
-
44
- def f (* args , ** kwargs ):
45
- if list_of_args :
46
- dispatch_args = args [0 ]
47
- else :
48
- dispatch_args = args [array_args ]
49
- if any (is_duck_dask_array (a ) for a in dispatch_args ):
50
- try :
51
- wrapped = getattr (dask_module , name )
52
- except AttributeError as e :
53
- raise AttributeError (f"{ e } : requires dask >={ requires_dask } " )
54
- else :
55
- wrapped = getattr (eager_module , name )
56
- return wrapped (* args , ** kwargs )
57
48
58
- else :
59
-
60
- def f (* args , ** kwargs ):
61
- return getattr (eager_module , name )(* args , ** kwargs )
49
+ def f (* args , ** kwargs ):
50
+ if any (is_duck_dask_array (a ) for a in args ):
51
+ wrapped = getattr (dask_module , name )
52
+ else :
53
+ wrapped = getattr (eager_module , name )
54
+ return wrapped (* args , ** kwargs )
62
55
63
56
return f
64
57
@@ -72,16 +65,40 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
72
65
raise NotImplementedError (msg % func_name )
73
66
74
67
75
- around = _dask_or_eager_func ("around" )
76
- isclose = _dask_or_eager_func ("isclose" )
77
-
68
+ # Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
69
+ pandas_isnull = _dask_or_eager_func ("isnull" , eager_module = pd , dask_module = dask_array )
78
70
79
- isnat = np .isnat
80
- isnan = _dask_or_eager_func ("isnan" )
81
- zeros_like = _dask_or_eager_func ("zeros_like" )
82
-
83
-
84
- pandas_isnull = _dask_or_eager_func ("isnull" , eager_module = pd )
71
+ # np.around has failing doctests, overwrite it so they pass:
72
+ # https://github.com/numpy/numpy/issues/19759
73
# Patch the example outputs embedded in ``around``'s docstring.
#
# Each pair is (text to find, text to substitute). The text to find uses two
# spaces after every comma, while the substitute uses one; the two strings
# must differ, otherwise the replacement is a no-op and the docstring is left
# untouched.
# NOTE(review): the double-spaced form is taken to be what numpy's docstring
# contains (see numpy/numpy#19759) -- confirm against the numpy versions
# supported by this project.
#
# ``str.replace`` substitutes every occurrence, so one pass per pattern is
# enough. ``around.__doc__ or ""`` guards against a missing docstring (for
# example when Python runs with ``-OO``).
for _stale, _fresh in (
    ("array([0.,  2.])", "array([0., 2.])"),
    ("array([0.4,  1.6])", "array([0.4, 1.6])"),
    ("array([0.,  2.,  2.,  4.,  4.])", "array([0., 2., 2., 4., 4.])"),
):
    around.__doc__ = (around.__doc__ or "").replace(_stale, _fresh)
# Do not leave the loop variables behind as module attributes.
del _stale, _fresh
93
+ around .__doc__ = str .replace (
94
+ around .__doc__ or "" ,
95
+ (
96
+ ' .. [2] "How Futile are Mindless Assessments of\n '
97
+ ' Roundoff in Floating-Point Computation?", William Kahan,\n '
98
+ " https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n "
99
+ ),
100
+ "" ,
101
+ )
85
102
86
103
87
104
def isnull (data ):
@@ -114,21 +131,10 @@ def notnull(data):
114
131
return ~ isnull (data )
115
132
116
133
117
- transpose = _dask_or_eager_func ("transpose" )
118
- _where = _dask_or_eager_func ("where" , array_args = slice (3 ))
119
- isin = _dask_or_eager_func ("isin" , array_args = slice (2 ))
120
- take = _dask_or_eager_func ("take" )
121
- broadcast_to = _dask_or_eager_func ("broadcast_to" )
122
- pad = _dask_or_eager_func ("pad" , dask_module = dask_array_compat )
123
-
124
- _concatenate = _dask_or_eager_func ("concatenate" , list_of_args = True )
125
- _stack = _dask_or_eager_func ("stack" , list_of_args = True )
126
-
127
- array_all = _dask_or_eager_func ("all" )
128
- array_any = _dask_or_eager_func ("any" )
129
-
130
- tensordot = _dask_or_eager_func ("tensordot" , array_args = slice (2 ))
131
- einsum = _dask_or_eager_func ("einsum" , array_args = slice (1 , None ))
134
+ # TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
135
+ masked_invalid = _dask_or_eager_func (
136
+ "masked_invalid" , eager_module = np .ma , dask_module = getattr (dask_array , "ma" , None )
137
+ )
132
138
133
139
134
140
def gradient (x , coord , axis , edge_order ):
@@ -166,11 +172,6 @@ def cumulative_trapezoid(y, x, axis):
166
172
return cumsum (integrand , axis = axis , skipna = False )
167
173
168
174
169
- masked_invalid = _dask_or_eager_func (
170
- "masked_invalid" , eager_module = np .ma , dask_module = getattr (dask_array , "ma" , None )
171
- )
172
-
173
-
174
175
def astype (data , dtype , ** kwargs ):
175
176
if (
176
177
isinstance (data , sparse_array_type )
@@ -317,9 +318,7 @@ def _ignore_warnings_if(condition):
317
318
yield
318
319
319
320
320
- def _create_nan_agg_method (
321
- name , dask_module = dask_array , coerce_strings = False , invariant_0d = False
322
- ):
321
+ def _create_nan_agg_method (name , coerce_strings = False , invariant_0d = False ):
323
322
from . import nanops
324
323
325
324
def f (values , axis = None , skipna = None , ** kwargs ):
@@ -344,7 +343,8 @@ def f(values, axis=None, skipna=None, **kwargs):
344
343
else :
345
344
if name in ["sum" , "prod" ]:
346
345
kwargs .pop ("min_count" , None )
347
- func = _dask_or_eager_func (name , dask_module = dask_module )
346
+
347
+ func = getattr (np , name )
348
348
349
349
try :
350
350
with warnings .catch_warnings ():
@@ -378,9 +378,7 @@ def f(values, axis=None, skipna=None, **kwargs):
378
378
std .numeric_only = True
379
379
var = _create_nan_agg_method ("var" )
380
380
var .numeric_only = True
381
- median = _create_nan_agg_method (
382
- "median" , dask_module = dask_array_compat , invariant_0d = True
383
- )
381
+ median = _create_nan_agg_method ("median" , invariant_0d = True )
384
382
median .numeric_only = True
385
383
prod = _create_nan_agg_method ("prod" , invariant_0d = True )
386
384
prod .numeric_only = True
@@ -389,7 +387,6 @@ def f(values, axis=None, skipna=None, **kwargs):
389
387
cumprod_1d .numeric_only = True
390
388
cumsum_1d = _create_nan_agg_method ("cumsum" , invariant_0d = True )
391
389
cumsum_1d .numeric_only = True
392
- unravel_index = _dask_or_eager_func ("unravel_index" )
393
390
394
391
395
392
_mean = _create_nan_agg_method ("mean" , invariant_0d = True )
0 commit comments