Skip to content

Commit aa9644c

Browse files
author
Matt Roeschke
committed
Cache apply function
1 parent 9b9ea7a commit aa9644c

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

pandas/core/window/rolling.py

Lines changed: 36 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -95,6 +95,7 @@ def __init__(
9595
self.win_freq = None
9696
self.axis = obj._get_axis_number(axis) if axis is not None else None
9797
self.validate()
98+
self._apply_func_cache = dict()
9899

99100
@property
100101
def _constructor(self):
@@ -493,7 +494,7 @@ def _apply(
493494
minimum_periods = _check_min_periods(
494495
self.min_periods or 1, self.min_periods, len(values) + offset
495496
)
496-
func = partial( # type: ignore
497+
func_partial = partial( # type: ignore
497498
func, begin=start, end=end, minimum_periods=minimum_periods
498499
)
499500

@@ -511,7 +512,7 @@ def _apply(
511512
cfunc, check_minp, index_as_array, **kwargs
512513
)
513514

514-
func = partial( # type: ignore
515+
func_partial = partial( # type: ignore
515516
func,
516517
window=window,
517518
min_periods=self.min_periods,
@@ -521,12 +522,12 @@ def _apply(
521522
if additional_nans is not None:
522523

523524
def calc(x):
524-
return func(np.concatenate((x, additional_nans)))
525+
return func_partial(np.concatenate((x, additional_nans)))
525526

526527
else:
527528

528529
def calc(x):
529-
return func(x)
530+
return func_partial(x)
530531

531532
with np.errstate(all="ignore"):
532533
if values.ndim > 1:
@@ -535,6 +536,9 @@ def calc(x):
535536
result = calc(values)
536537
result = np.asarray(result)
537538

539+
if use_numba:
540+
self._apply_func_cache[name] = func
541+
538542
if center:
539543
result = self._center_window(result, window)
540544

@@ -1147,8 +1151,34 @@ def f(arg, window, min_periods, closed):
11471151

11481152
# Numba doesn't support kwargs in nopython mode
11491153
# https://github.com/numba/numba/issues/2916
1150-
numba_func = numba.njit(func)
1151-
rolling_apply = partial(methods.rolling_apply, numba_func=numba_func, args=args)
1154+
if func not in self._apply_func_cache:
1155+
1156+
def make_rolling_apply(func):
1157+
1158+
numba_func = numba.njit(func)
1159+
1160+
@numba.njit
1161+
def roll_apply(
1162+
values: np.ndarray,
1163+
begin: np.ndarray,
1164+
end: np.ndarray,
1165+
minimum_periods: int,
1166+
):
1167+
result = np.empty(len(begin))
1168+
for i, (start, stop) in enumerate(zip(begin, end)):
1169+
window = values[start:stop]
1170+
count_nan = np.sum(np.isnan(window))
1171+
if len(window) - count_nan >= minimum_periods:
1172+
result[i] = numba_func(window, *args)
1173+
else:
1174+
result[i] = np.nan
1175+
return result
1176+
1177+
return roll_apply
1178+
1179+
rolling_apply = make_rolling_apply(func)
1180+
else:
1181+
rolling_apply = self._apply_func_cache[func]
11521182

11531183
return self._apply(
11541184
rolling_apply,

0 commit comments

Comments
 (0)