Skip to content

Commit d34b96c

Browse files
authored
Merge pull request #29 from twosigma/feature/rolling_apply_numba
Add Numba to rolling.apply
2 parents c06d296 + e955d47 commit d34b96c

File tree

6 files changed

+128
-56
lines changed

6 files changed

+128
-56
lines changed

asv_bench/benchmarks/rolling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ class Apply:
3030
["DataFrame", "Series"],
3131
[10, 1000],
3232
["int", "float"],
33-
[sum, np.sum, lambda x: np.sum(x) + 5],
33+
# TODO: numba doesn't support builtin.sum
34+
[np.sum, lambda x: np.sum(x) + 5],
3435
[True, False],
3536
)
3637
param_names = ["contructor", "window", "dtype", "function", "raw"]

pandas/core/window/aggregators/methods.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
This implementation mimics what we currently do in cython except the
44
calculation of window bounds is independent of the aggregation routine.
55
"""
6-
76
import numba
87
import numpy as np
98

pandas/core/window/rolling.py

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from datetime import timedelta
66
from functools import partial
77
from textwrap import dedent
8-
from typing import Callable, List, Optional, Set, Union
8+
from typing import Callable, Dict, List, Optional, Set, Union
99
import warnings
1010

11+
import numba
1112
import numpy as np
1213

1314
import pandas._libs.window as libwindow
@@ -94,6 +95,7 @@ def __init__(
9495
self.win_freq = None
9596
self.axis = obj._get_axis_number(axis) if axis is not None else None
9697
self.validate()
98+
self._apply_func_cache = dict() # type: Dict
9799

98100
@property
99101
def _constructor(self):
@@ -431,7 +433,13 @@ def _apply(
431433
-------
432434
y : type of input
433435
"""
434-
use_numba = kwargs.pop("use_numba", None)
436+
use_numba = kwargs.pop("use_numba", False)
437+
floor = kwargs.pop("floor", None)
438+
if not use_numba:
439+
# apply stores use_numba and floor in kwargs[kwargs]
440+
extra_kwargs = kwargs.pop("kwargs", {})
441+
use_numba = extra_kwargs.get("use_numba", False)
442+
floor = extra_kwargs.get("floor", None)
435443

436444
if center is None:
437445
center = self.center
@@ -487,12 +495,16 @@ def _apply(
487495
window,
488496
_use_window(self.min_periods, window),
489497
len(values) + offset,
498+
floor,
490499
)
491500
else:
492501
minimum_periods = _check_min_periods(
493-
self.min_periods or 1, self.min_periods, len(values) + offset
502+
self.min_periods or 1,
503+
self.min_periods,
504+
len(values) + offset,
505+
floor,
494506
)
495-
func = partial( # type: ignore
507+
func_partial = partial( # type: ignore
496508
func, begin=start, end=end, minimum_periods=minimum_periods
497509
)
498510

@@ -510,7 +522,7 @@ def _apply(
510522
cfunc, check_minp, index_as_array, **kwargs
511523
)
512524

513-
func = partial( # type: ignore
525+
func_partial = partial( # type: ignore
514526
func,
515527
window=window,
516528
min_periods=self.min_periods,
@@ -520,12 +532,12 @@ def _apply(
520532
if additional_nans is not None:
521533

522534
def calc(x):
523-
return func(np.concatenate((x, additional_nans)))
535+
return func_partial(np.concatenate((x, additional_nans)))
524536

525537
else:
526538

527539
def calc(x):
528-
return func(x)
540+
return func_partial(x)
529541

530542
with np.errstate(all="ignore"):
531543
if values.ndim > 1:
@@ -534,6 +546,9 @@ def calc(x):
534546
result = calc(values)
535547
result = np.asarray(result)
536548

549+
if use_numba:
550+
self._apply_func_cache[name] = func
551+
537552
if center:
538553
result = self._center_window(result, window)
539554

@@ -1106,12 +1121,8 @@ def count(self):
11061121
)
11071122

11081123
def apply(self, func, raw=None, args=(), kwargs={}):
1109-
from pandas import Series
11101124

11111125
kwargs.pop("_level", None)
1112-
window = self._get_window()
1113-
offset = _offset(window, self.center)
1114-
index_as_array = self._get_index()
11151126

11161127
# TODO: default is for backward compat
11171128
# change to False in the future
@@ -1127,24 +1138,54 @@ def apply(self, func, raw=None, args=(), kwargs={}):
11271138
)
11281139
raw = True
11291140

1130-
def f(arg, window, min_periods, closed):
1131-
minp = _use_window(min_periods, window)
1132-
if not raw:
1133-
arg = Series(arg, index=self.obj.index)
1134-
return libwindow.roll_generic(
1135-
arg,
1136-
window,
1137-
minp,
1138-
index_as_array,
1139-
closed,
1140-
offset,
1141-
func,
1142-
raw,
1143-
args,
1144-
kwargs,
1145-
)
1146-
1147-
return self._apply(f, func, args=args, kwargs=kwargs, center=False, raw=raw)
1141+
# Numba doesn't support kwargs in nopython mode
1142+
# https://github.com/numba/numba/issues/2916
1143+
if func not in self._apply_func_cache:
1144+
1145+
def make_rolling_apply(func):
1146+
@numba.generated_jit(nopython=True)
1147+
def numba_func(window, *_args):
1148+
if getattr(np, func.__name__, False) is func:
1149+
1150+
def impl(window, *_args):
1151+
return func(window, *_args)
1152+
1153+
return impl
1154+
else:
1155+
jf = numba.njit(func)
1156+
1157+
def impl(window, *_args):
1158+
return jf(window, *_args)
1159+
1160+
return impl
1161+
1162+
@numba.njit
1163+
def roll_apply(
1164+
values: np.ndarray,
1165+
begin: np.ndarray,
1166+
end: np.ndarray,
1167+
minimum_periods: int,
1168+
):
1169+
result = np.empty(len(begin))
1170+
for i, (start, stop) in enumerate(zip(begin, end)):
1171+
window = values[start:stop]
1172+
count_nan = np.sum(np.isnan(window))
1173+
if len(window) - count_nan >= minimum_periods:
1174+
result[i] = numba_func(window, *args)
1175+
else:
1176+
result[i] = np.nan
1177+
return result
1178+
1179+
return roll_apply
1180+
1181+
rolling_apply = make_rolling_apply(func)
1182+
else:
1183+
rolling_apply = self._apply_func_cache[func]
1184+
kwargs["use_numba"] = True
1185+
kwargs["floor"] = 0
1186+
return self._apply(
1187+
rolling_apply, func, args=args, kwargs=kwargs, center=False, raw=raw
1188+
)
11481189

11491190
def sum(self, *args, **kwargs):
11501191
nv.validate_window_func("sum", args, kwargs)

pandas/tests/window/test_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def test_agg(self):
132132
expected.columns = pd.MultiIndex.from_tuples(exp_cols)
133133
tm.assert_frame_equal(result, expected, check_like=True)
134134

135+
@pytest.mark.xfail(reason="TypingError: numba doesn't support kwarg for std")
135136
def test_agg_apply(self, raw):
136137

137138
# passed lambda

pandas/tests/window/test_moments.py

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,10 @@ def test_rolling_quantile_param(self):
628628
with pytest.raises(TypeError):
629629
ser.rolling(3).quantile("foo")
630630

631+
@pytest.mark.xfail(
632+
reason="unsupported controlflow due to return/raise statements "
633+
"inside with block"
634+
)
631635
def test_rolling_apply(self, raw):
632636
# suppress warnings about empty slices, as we are deliberately testing
633637
# with a 0-length Series
@@ -679,6 +683,10 @@ def test_rolling_apply_out_of_bounds(self, raw):
679683
expected = pd.Series([1, 3, 6, 10], dtype=float)
680684
tm.assert_almost_equal(result, expected)
681685

686+
@pytest.mark.xfail(
687+
reason="Untyped global name 'df': "
688+
"cannot determine Numba type of <class 'pandas.core.frame.DataFrame'>"
689+
)
682690
@pytest.mark.parametrize("window", [2, "2s"])
683691
def test_rolling_apply_with_pandas_objects(self, window):
684692
# 5071
@@ -1629,6 +1637,10 @@ def _ewma(s, com, min_periods, adjust, ignore_na):
16291637
),
16301638
)
16311639

1640+
@pytest.mark.xfail(
1641+
reason="Untyped global name 'Series': cannot determine "
1642+
"Numba type of <class 'type'>"
1643+
)
16321644
@pytest.mark.slow
16331645
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])
16341646
def test_expanding_consistency(self, min_periods):
@@ -1701,6 +1713,10 @@ def test_expanding_consistency(self, min_periods):
17011713
if name in ["sum", "prod"]:
17021714
tm.assert_equal(expanding_f_result, expanding_apply_f_result)
17031715

1716+
@pytest.mark.xfail(
1717+
reason="Untyped global name 'Series': cannot determine Numba type of "
1718+
"<class 'type'>"
1719+
)
17041720
@pytest.mark.slow
17051721
@pytest.mark.parametrize(
17061722
"window,min_periods,center", list(_rolling_consistency_cases())
@@ -1977,6 +1993,7 @@ def func(A, B, com, **kwargs):
19771993
with pytest.raises(Exception, match=msg):
19781994
func(A, randn(50), 20, min_periods=5)
19791995

1996+
@pytest.mark.xfail(reason="Use of unsupported opcode (SETUP_EXCEPT) found")
19801997
def test_expanding_apply_args_kwargs(self, raw):
19811998
def mean_w_arg(x, const):
19821999
return np.mean(x) + const
@@ -2118,8 +2135,18 @@ def test_rolling_corr_diff_length(self):
21182135
lambda x: x.rolling(window=10, min_periods=5).kurt(),
21192136
lambda x: x.rolling(window=10, min_periods=5).quantile(quantile=0.5),
21202137
lambda x: x.rolling(window=10, min_periods=5).median(),
2121-
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False),
2122-
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True),
2138+
pytest.param(
2139+
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False),
2140+
marks=pytest.mark.xfail(
2141+
reason="https://github.com/numba/numba/issues/4587"
2142+
),
2143+
),
2144+
pytest.param(
2145+
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True),
2146+
marks=pytest.mark.xfail(
2147+
reason="https://github.com/numba/numba/issues/4587"
2148+
),
2149+
),
21232150
lambda x: x.rolling(win_type="boxcar", window=10, min_periods=5).mean(),
21242151
],
21252152
)
@@ -2164,17 +2191,9 @@ def test_rolling_functions_window_non_shrinkage_binary(self):
21642191
df_result = f(df)
21652192
tm.assert_frame_equal(df_result, df_expected)
21662193

2167-
def test_moment_functions_zero_length(self):
2168-
# GH 8056
2169-
s = Series()
2170-
s_expected = s
2171-
df1 = DataFrame()
2172-
df1_expected = df1
2173-
df2 = DataFrame(columns=["a"])
2174-
df2["a"] = df2["a"].astype("float64")
2175-
df2_expected = df2
2176-
2177-
functions = [
2194+
@pytest.mark.parametrize(
2195+
"f",
2196+
[
21782197
lambda x: x.expanding().count(),
21792198
lambda x: x.expanding(min_periods=5).cov(x, pairwise=False),
21802199
lambda x: x.expanding(min_periods=5).corr(x, pairwise=False),
@@ -2206,21 +2225,31 @@ def test_moment_functions_zero_length(self):
22062225
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False),
22072226
lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True),
22082227
lambda x: x.rolling(win_type="boxcar", window=10, min_periods=5).mean(),
2209-
]
2210-
for f in functions:
2211-
try:
2212-
s_result = f(s)
2213-
tm.assert_series_equal(s_result, s_expected)
2228+
],
2229+
)
2230+
def test_moment_functions_zero_length(self, f):
2231+
# GH 8056
2232+
s = Series()
2233+
s_expected = s
2234+
df1 = DataFrame()
2235+
df1_expected = df1
2236+
df2 = DataFrame(columns=["a"])
2237+
df2["a"] = df2["a"].astype("float64")
2238+
df2_expected = df2
22142239

2215-
df1_result = f(df1)
2216-
tm.assert_frame_equal(df1_result, df1_expected)
2240+
try:
2241+
s_result = f(s)
2242+
tm.assert_series_equal(s_result, s_expected)
22172243

2218-
df2_result = f(df2)
2219-
tm.assert_frame_equal(df2_result, df2_expected)
2220-
except (ImportError):
2244+
df1_result = f(df1)
2245+
tm.assert_frame_equal(df1_result, df1_expected)
22212246

2222-
# scipy needed for rolling_window
2223-
continue
2247+
df2_result = f(df2)
2248+
tm.assert_frame_equal(df2_result, df2_expected)
2249+
except (ImportError):
2250+
2251+
# scipy needed for rolling_window
2252+
pass
22242253

22252254
def test_moment_functions_zero_length_pairwise(self):
22262255

pandas/tests/window/test_rolling.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def test_constructor_with_timedelta_window(self, window):
7979
expected = df.rolling("3D").sum()
8080
tm.assert_frame_equal(result, expected)
8181

82+
@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/4587")
8283
@pytest.mark.parametrize("window", [timedelta(days=3), pd.Timedelta(days=3), "3D"])
8384
def test_constructor_timedelta_window_and_minperiods(self, window, raw):
8485
# GH 15305

0 commit comments

Comments
 (0)