@@ -95,6 +95,7 @@ def __init__(
95
95
self .win_freq = None
96
96
self .axis = obj ._get_axis_number (axis ) if axis is not None else None
97
97
self .validate ()
98
+ self ._apply_func_cache = dict ()
98
99
99
100
@property
100
101
def _constructor (self ):
@@ -493,7 +494,7 @@ def _apply(
493
494
minimum_periods = _check_min_periods (
494
495
self .min_periods or 1 , self .min_periods , len (values ) + offset
495
496
)
496
- func = partial ( # type: ignore
497
+ func_partial = partial ( # type: ignore
497
498
func , begin = start , end = end , minimum_periods = minimum_periods
498
499
)
499
500
@@ -511,7 +512,7 @@ def _apply(
511
512
cfunc , check_minp , index_as_array , ** kwargs
512
513
)
513
514
514
- func = partial ( # type: ignore
515
+ func_partial = partial ( # type: ignore
515
516
func ,
516
517
window = window ,
517
518
min_periods = self .min_periods ,
@@ -521,12 +522,12 @@ def _apply(
521
522
if additional_nans is not None :
522
523
523
524
def calc (x ):
524
- return func (np .concatenate ((x , additional_nans )))
525
+ return func_partial (np .concatenate ((x , additional_nans )))
525
526
526
527
else :
527
528
528
529
def calc (x ):
529
- return func (x )
530
+ return func_partial (x )
530
531
531
532
with np .errstate (all = "ignore" ):
532
533
if values .ndim > 1 :
@@ -535,6 +536,9 @@ def calc(x):
535
536
result = calc (values )
536
537
result = np .asarray (result )
537
538
539
+ if use_numba :
540
+ self ._apply_func_cache [name ] = func
541
+
538
542
if center :
539
543
result = self ._center_window (result , window )
540
544
@@ -1147,8 +1151,34 @@ def f(arg, window, min_periods, closed):
1147
1151
1148
1152
# Numba doesn't support kwargs in nopython mode
1149
1153
# https://github.com/numba/numba/issues/2916
1150
- numba_func = numba .njit (func )
1151
- rolling_apply = partial (methods .rolling_apply , numba_func = numba_func , args = args )
1154
+ if func not in self ._apply_func_cache :
1155
+
1156
+ def make_rolling_apply (func ):
1157
+
1158
+ numba_func = numba .njit (func )
1159
+
1160
+ @numba .njit
1161
+ def roll_apply (
1162
+ values : np .ndarray ,
1163
+ begin : np .ndarray ,
1164
+ end : np .ndarray ,
1165
+ minimum_periods : int ,
1166
+ ):
1167
+ result = np .empty (len (begin ))
1168
+ for i , (start , stop ) in enumerate (zip (begin , end )):
1169
+ window = values [start :stop ]
1170
+ count_nan = np .sum (np .isnan (window ))
1171
+ if len (window ) - count_nan >= minimum_periods :
1172
+ result [i ] = numba_func (window , * args )
1173
+ else :
1174
+ result [i ] = np .nan
1175
+ return result
1176
+
1177
+ return roll_apply
1178
+
1179
+ rolling_apply = make_rolling_apply (func )
1180
+ else :
1181
+ rolling_apply = self ._apply_func_cache [func ]
1152
1182
1153
1183
return self ._apply (
1154
1184
rolling_apply ,
0 commit comments