ENH: Add numba engine for rolling apply #30151
Merged
Changes from 6 commits
Commits
56 commits
3b9bff8
Add numba to import_optional_dependencies
9a302bf
Start adding keywords
0e9a600
Modify apply for numba and cython
36a77ed
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
dbb2a9b
Add numba as optional dependency
f0e9a4d
Add premil tests
1250aee
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
4e7fd1a
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
cb976cf
Add numba to requirements-dev, type and reorder signature in apply
45420bb
Move numba routines to its own file
17851cf
Adjust signature in top level function as well
20767ca
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
9619f8d
Generate requirements-dev.txt using script
66fa69c
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
b8908ea
Add skip test decorator, add numba to a few builds
135f2ad
black
34a5687
don't rejit a user's jitted function
6da8199
Add numba/cython comparison test
123f77e
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
54e74d1
Remove typing for now
04d3530
Remove sub description for doc failures?
4bbf587
Fix test function
f849bc7
test user predefined jit function, clarify docstring
0c30e48
Apply engine kwargs to function as well
c4c952e
Clairfy documentation
8645976
Clarify what engine_kwargs applies to
987c916
Start section for numba rolling apply
b775684
Lint
2e04e60
clarify note
9b20ff5
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
0c14033
Add apply function cache to save compiled numba functions
c7106dc
Add performance example
1640085
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
2846faf
Remove whitespace
5a645c0
Address lint errors and separate apply tests
6bac000
Add whatsnew note
6f1c73f
Skip apply tests for numba not installed, lint
a890337
Add typing
0a9071c
Add more typing
9d8d40b
Formatting cleanups
84c3491
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
a429206
Address Jeff's comments
5826ad9
Black
cf7571b
Add clarification
4bc9787
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
18eed60
Move function to module level
f715b55
move cache check higher up
6a765bf
Address Will's comments
af3fe50
Type Callable in generate_numba_apply_func
eb7b5e1
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
f7dfcf4
use ellipsis, cannot specify np.ndarray as well
a42a960
Remove trailing whitespace in apply docstring
d019830
Address Will's and Brock's comments
29d145f
Fix typing
248149c
Merge remote-tracking branch 'upstream/master' into numba_rolling_apply
a3da51e
Address followup comments
@@ -27,6 +27,7 @@
     "xlrd": "1.1.0",
     "xlwt": "1.2.0",
     "xlsxwriter": "0.9.8",
+    "numba": "0.46.0",
 }
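The hunk above registers a minimum numba version of 0.46.0 in the optional-dependency table. A minimal sketch of how such a table could gate an import follows. `MIN_VERSIONS`, `parse_version`, and `import_optional` are hypothetical names used only for illustration; this is not pandas' actual `import_optional_dependency` implementation, and the version parser is deliberately naive (dotted integers only).

```python
import importlib
import sys
import types

# Hypothetical minimum-version table, mirroring the shape of the dict in the diff.
MIN_VERSIONS = {"numba": "0.46.0"}


def parse_version(text):
    """Naive parser: '0.46.0' -> (0, 46, 0). Assumes dotted integers only."""
    return tuple(int(part) for part in text.split("."))


def import_optional(name, min_versions=MIN_VERSIONS):
    """Import `name`, raising ImportError if it is missing or too old."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        raise ImportError(f"Missing optional dependency '{name}'.") from None

    minimum = min_versions.get(name)
    found = getattr(module, "__version__", None)
    if minimum and found and parse_version(found) < parse_version(minimum):
        raise ImportError(
            f"'{name}' version {minimum} or newer is required (found {found})."
        )
    return module


# Demonstration with a stand-in module so the sketch runs without numba installed.
fake = types.ModuleType("fake_numba")
fake.__version__ = "0.45.1"
sys.modules["fake_numba"] = fake

try:
    import_optional("fake_numba", {"fake_numba": "0.46.0"})
    outcome = "imported"
except ImportError as err:
    outcome = str(err)
```

Failing at import time with a message that names the required version keeps the error close to the call that asked for the numba engine, rather than surfacing later inside compiled code.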
@@ -1246,9 +1246,21 @@ def count(self):
           objects instead.
           If you are just applying a NumPy reduction function this will
           achieve much better performance.

-    *args, **kwargs
-        Arguments and keyword arguments to be passed into func.
+    args : tuple, default None
+        Positional arguments to be passed into func
+    kwargs : dict, default None
+        Keyword arguments to be passed into func
+    engine : str, default 'cython'
+        Execution engine for the applied function.
+        * ``'cython'`` : Runs rolling apply through C-extensions from cython.
+        * ``'numba'`` : Runs rolling apply through JIT compiled code from numba.
+          Only available when ``raw`` is set to ``True``.
+    engine_kwargs : dict, default None
+        Arguments to specify for the execution engine.
+        * For ``'cython'`` engine, there are no accepted ``engine_kwargs``
+        * For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
+          and ``parallel``. The default ``engine_kwargs`` for the ``'numba'`` engine is
+          ``{'nopython': True, 'nogil': False, 'parallel': False}``

     Returns
     -------
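The docstring above documents the default `engine_kwargs` for the numba engine. The sketch below shows one way a partially specified dict could be resolved against those defaults, in the spirit of the `engine_kwargs.get(...)` calls in the diff. `NUMBA_DEFAULTS` and `resolve_numba_kwargs` are hypothetical names introduced here, not part of pandas.

```python
# Defaults as documented in the docstring for the 'numba' engine.
NUMBA_DEFAULTS = {"nopython": True, "nogil": False, "parallel": False}


def resolve_numba_kwargs(engine_kwargs=None):
    """Hypothetical helper: fill in any option the caller left out."""
    if engine_kwargs is None:
        engine_kwargs = {}
    return {
        key: engine_kwargs.get(key, default)
        for key, default in NUMBA_DEFAULTS.items()
    }


# No options given: every documented default applies.
resolved_default = resolve_numba_kwargs()

# Only `parallel` given: the other two options fall back to their defaults.
resolved_partial = resolve_numba_kwargs({"parallel": True})
```

Resolving each key separately means a caller who passes only `{"parallel": True}` still gets `nopython=True`, which matches the per-key lookups in `_generate_numba_apply_func`.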
@@ -1262,16 +1274,48 @@ def count(self):
     """
     )

-    def apply(self, func, raw=False, args=(), kwargs={}):
-        from pandas import Series
-
+    def apply(
+        self,
+        func,
+        raw=False,
+        args=None,
+        kwargs=None,
+        engine="cython",
+        engine_kwargs=None,
+    ):
+        if args is None:
+            args = ()
+        if kwargs is None:
+            kwargs = {}
         kwargs.pop("_level", None)
         kwargs.pop("floor", None)
         window = self._get_window()
         offset = _offset(window, self.center)
         if not is_bool(raw):
             raise ValueError("raw parameter must be `True` or `False`")

+        if engine == "cython":
+            if engine_kwargs is not None:
+                raise ValueError("cython engine does not accept engine_kwargs")
+            apply_func = self._generate_cython_apply_func(
+                args, kwargs, raw, offset, func
+            )
+        elif engine == "numba":
+            if raw is False:
+                raise ValueError("raw must be `True` when using the numba engine")
+            apply_func = self._generate_numba_apply_func(
+                args, kwargs, func, engine_kwargs
+            )
+        else:
+            raise ValueError("engine must be either 'numba' or 'cython'")
+
+        # TODO: Why do we always pass center=False?
+        # name=func for WindowGroupByMixin._apply
+        return self._apply(apply_func, center=False, floor=0, name=func)
+
+    def _generate_cython_apply_func(self, args, kwargs, raw, offset, func):
+        from pandas import Series
+
         window_func = partial(
             self._get_cython_func_type("roll_generic"),
             args=args,

Review comment on `engine="cython",`: type new arguments
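The engine checks added to `apply` reduce to three rules: the cython engine rejects any `engine_kwargs`, the numba engine requires `raw=True`, and any other engine name is an error. A stand-alone sketch of that validation, using the same error messages as the diff, is shown below. `select_engine` and `error_message` are hypothetical helpers written for illustration; the real method goes on to build and run the apply function.

```python
def select_engine(engine="cython", raw=False, engine_kwargs=None):
    """Hypothetical stand-alone version of the checks in `apply`."""
    if engine == "cython":
        if engine_kwargs is not None:
            raise ValueError("cython engine does not accept engine_kwargs")
        return "cython"
    if engine == "numba":
        if raw is False:
            raise ValueError("raw must be `True` when using the numba engine")
        return "numba"
    raise ValueError("engine must be either 'numba' or 'cython'")


def error_message(**options):
    """Return the ValueError text for a combination, or None if it is accepted."""
    try:
        select_engine(**options)
    except ValueError as err:
        return str(err)
    return None


# A few combinations and the message each one produces.
bad_name = error_message(engine="foo")
bad_cython = error_message(engine="cython", engine_kwargs={"nopython": False})
bad_numba = error_message(engine="numba", raw=False)
good_numba = error_message(engine="numba", raw=True)
```

Validating before any compilation happens means an invalid combination fails immediately and cheaply, which is also what the `TestEngine` cases in this PR exercise.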
@@ -1286,9 +1330,76 @@ def apply_func(values, begin, end, min_periods, raw=raw):
                 values = Series(values, index=self.obj.index)
             return window_func(values, begin, end, min_periods)

-        # TODO: Why do we always pass center=False?
-        # name=func for WindowGroupByMixin._apply
-        return self._apply(apply_func, center=False, floor=0, name=func)
+        return apply_func
+
+    def _generate_numba_apply_func(self, args, kwargs, func, engine_kwargs):
+        numba = import_optional_dependency("numba")
+
+        if engine_kwargs is None:
+            engine_kwargs = {"nopython": True, "nogil": False, "parallel": False}
+
+        nopython = engine_kwargs.get("nopython", True)
+        nogil = engine_kwargs.get("nogil", False)
+        parallel = engine_kwargs.get("parallel", False)
+
+        if kwargs and nopython:
+            raise ValueError(
+                "numba does not support kwargs with nopython=True: "
+                "https://github.com/numba/numba/issues/2916"
+            )
+
+        if parallel:
+            loop_range = numba.prange
+        else:
+            loop_range = range
+
+        def make_rolling_apply(func):
+            """
+            1. jit the user's function
+            2. Return a rolling apply function with the jitted function inline
+
+            Configurations specified in engine_kwargs apply to both the user's
+            function _AND_ the rolling apply function.
+            """
+
+            @numba.generated_jit(nopython=nopython)
+            def numba_func(window, *_args):
+                if getattr(np, func.__name__, False) is func:
+
+                    def impl(window, *_args):
+                        return func(window, *_args)
+
+                    return impl
+                else:
+                    jf = numba.jit(func, nopython=nopython)
+
+                    def impl(window, *_args):
+                        return jf(window, *_args)
+
+                    return impl
+
+            @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
+            def roll_apply(
+                values: np.ndarray,
+                begin: np.ndarray,
+                end: np.ndarray,
+                minimum_periods: int,
+            ):
+                result = np.empty(len(begin))
+                for i in loop_range(len(result)):
+                    start = begin[i]
+                    stop = end[i]
+                    window = values[start:stop]
+                    count_nan = np.sum(np.isnan(window))
+                    if len(window) - count_nan >= minimum_periods:
+                        result[i] = numba_func(window, *args)
+                    else:
+                        result[i] = np.nan
+                return result
+
+            return roll_apply
+
+        return make_rolling_apply(func)

     def sum(self, *args, **kwargs):
         nv.validate_window_func("sum", args, kwargs)
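The core of `roll_apply` is a loop over precomputed window bounds that calls the user function only when a window holds at least `minimum_periods` non-NaN values, and writes NaN otherwise. A pure-Python sketch of that control flow, without numba or NumPy, might look like the following. Passing `func` as a parameter and using lists instead of arrays are simplifications made here, and the window bounds are written out by hand rather than computed by pandas.

```python
import math


def roll_apply(values, begin, end, minimum_periods, func):
    """Pure-Python stand-in for the jitted loop; `func` reduces a list to a float."""
    result = []
    for start, stop in zip(begin, end):
        window = values[start:stop]
        # Count missing observations so they do not satisfy minimum_periods.
        count_nan = sum(1 for v in window if math.isnan(v))
        if len(window) - count_nan >= minimum_periods:
            result.append(func(window))
        else:
            result.append(math.nan)
    return result


# Trailing windows of length 2 over four observations (bounds chosen by hand).
values = [1.0, 2.0, math.nan, 4.0]
begin = [0, 0, 1, 2]
end = [1, 2, 3, 4]

# With minimum_periods=2, only the window [1.0, 2.0] has enough non-NaN values.
out = roll_apply(values, begin, end, 2, lambda w: sum(w))
```

Note that a window passing the `minimum_periods` check is handed to `func` unchanged, NaN values included, so the user function still has to decide how to treat them.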
@@ -1934,8 +2045,23 @@ def count(self):

     @Substitution(name="rolling")
     @Appender(_shared_docs["apply"])
-    def apply(self, func, raw=False, args=(), kwargs={}):
-        return super().apply(func, raw=raw, args=args, kwargs=kwargs)
+    def apply(
+        self,
+        func,
+        raw=False,
+        args=None,
+        kwargs=None,
+        engine="cython",
+        engine_kwargs=None,
+    ):
+        return super().apply(
+            func,
+            raw=raw,
+            args=args,
+            kwargs=kwargs,
+            engine=engine,
+            engine_kwargs=engine_kwargs,
+        )

     @Substitution(name="rolling")
     @Appender(_shared_docs["sum"])
@@ -342,3 +342,31 @@ def test_multiple_agg_funcs(self, func, window_size, expected_vals):
         )

         tm.assert_frame_equal(result, expected)
+
+
+class TestEngine:
+    def test_invalid_engine(self):
+        with pytest.raises(
+            ValueError, match="engine must be either 'numba' or 'cython'"
+        ):
+            Series(range(1)).rolling(1).apply(lambda x: x, engine="foo")
+
+    def test_invalid_engine_kwargs_cython(self):
+        with pytest.raises(
+            ValueError, match="cython engine does not accept engine_kwargs"
+        ):
+            Series(range(1)).rolling(1).apply(
+                lambda x: x, engine="cython", engine_kwargs={"nopython": False}
+            )
+
+    def test_invalid_raw_numba(self):
+        with pytest.raises(
+            ValueError, match="raw must be `True` when using the numba engine"
+        ):
+            Series(range(1)).rolling(1).apply(lambda x: x, raw=False, engine="numba")
+
+    def test_invalid_kwargs_nopython(self):
+        with pytest.raises(ValueError, match="numba does not support kwargs with"):
+            Series(range(1)).rolling(1).apply(
+                lambda x: x, kwargs={"a": 1}, engine="numba", raw=True
+            )

Review comment on `class TestEngine:`: I think this class would be more logically placed in
Review comment: args, kwargs should be at the end