Skip to content

Commit 9b9ea7a

Browse files
author
Matt Roeschke
committed
Merge branch 'feature/generalized_window_operations' into feature/rolling_apply_numba
2 parents 5f476d9 + f05c33a commit 9b9ea7a

File tree

7 files changed

+134
-57
lines changed

7 files changed

+134
-57
lines changed

asv_bench/asv.conf.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
"xlwt": [],
5353
"odfpy": [],
5454
"pytest": [],
55+
"numba": [],
5556
// If using Windows with python 2.7 and want to build using the
5657
// mingw toolchain (rather than MSVC), uncomment the following line.
5758
// "libpython": [],

pandas/core/window/aggregators/kernels.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
1-
from functools import partial
1+
"""
2+
Implementation of the rolling aggregations using jitclasses.
3+
4+
Some current difficulties as of numba 0.45.1:
5+
6+
1) jitclasses don't support inheritance, i.e. a base jitclass cannot be subclassed.
7+
8+
2) This implementation is not currently utilized because of
9+
inherent performance penalties.
10+
See https://github.com/numba/numba/issues/4522
11+
"""
12+
213
from typing import Optional
314

15+
import numba
416
import numpy as np
517

618
from pandas._typing import Scalar
@@ -56,13 +68,15 @@ class AggKernel:
5668
make_aggregator
5769
"""
5870

71+
def __init__(self):
72+
pass
73+
5974
def finalize(self):
6075
"""Return the final value of the aggregation."""
6176
raise NotImplementedError
6277

63-
@classmethod
6478
def make_aggregator(
65-
cls, values: np.ndarray, minimum_periods: int
79+
self, values: np.ndarray, minimum_periods: int
6680
) -> BaseAggregator:
6781
"""Return an aggregator that performs the aggregation calculation"""
6882
raise NotImplementedError
@@ -80,14 +94,30 @@ def invert(self, value) -> None:
8094
raise NotImplementedError
8195

8296

97+
agg_type = numba.deferred_type()
98+
99+
100+
base_aggregator_spec = (
101+
("values", numba.float64[:]),
102+
("min_periods", numba.uint64),
103+
("agg", agg_type),
104+
("previous_start", numba.int64),
105+
("previous_end", numba.int64),
106+
)
107+
108+
109+
@numba.jitclass(base_aggregator_spec)
83110
class SubtractableAggregator(BaseAggregator):
84111
"""
85112
Aggregator in which a current aggregated value
86113
is offset from a prior aggregated value.
87114
"""
88115

89116
def __init__(self, values: np.ndarray, min_periods: int, agg) -> None:
90-
super().__init__(values, min_periods)
117+
# Note: Numba doesn't like inheritance
118+
# super().__init__(values, min_periods)
119+
self.values = values
120+
self.min_periods = min_periods
91121
self.agg = agg
92122
self.previous_start = -1
93123
self.previous_end = -1
@@ -108,7 +138,8 @@ def query(self, start: int, stop: int) -> Optional[Scalar]:
108138
self.previous_end = stop
109139
if self.agg.count >= self.min_periods:
110140
return self.agg.finalize()
111-
return None
141+
# Numba wanted this to be None instead of None
142+
return np.nan
112143

113144

114145
class Sum(UnaryAggKernel):
@@ -140,32 +171,40 @@ def combine(self, other) -> None:
140171
self.total += other.total
141172
self.count += other.count
142173

143-
@classmethod
144-
def make_aggregator(cls, values: np.ndarray, min_periods: int) -> BaseAggregator:
145-
aggregator = SubtractableAggregator(values, min_periods, cls())
174+
def make_aggregator(self, values: np.ndarray, min_periods: int) -> BaseAggregator:
175+
aggregator = SubtractableAggregator(values, min_periods, self)
146176
return aggregator
147177

148178

179+
sum_spec = (("count", numba.uint64), ("total", numba.float64))
180+
181+
182+
@numba.jitclass(sum_spec)
149183
class Mean(Sum):
150184
def finalize(self) -> Optional[float]:
151185
if not self.count:
152186
return None
153187
return self.total / self.count
154188

155189

156-
def rolling_aggregation(
190+
agg_type.define(Mean.class_type.instance_type) # type: ignore
191+
192+
193+
aggregation_signature = (numba.float64[:], numba.int64[:], numba.int64[:], numba.int64)
194+
195+
196+
@numba.njit(aggregation_signature, nogil=True, parallel=True)
197+
def rolling_mean(
157198
values: np.ndarray,
158199
begin: np.ndarray,
159200
end: np.ndarray,
160201
minimum_periods: int,
161-
kernel_class,
202+
# kernel_class, Don't think I can define this in the signature in nopython mode
162203
) -> np.ndarray:
163204
"""Perform a generic rolling aggregation"""
164-
aggregator = kernel_class.make_aggregator(values, minimum_periods)
205+
aggregator = Mean().make_aggregator(values, minimum_periods)
206+
# aggregator = kernel_class().make_aggregator(values, minimum_periods)
165207
result = np.empty(len(begin))
166208
for i, (start, stop) in enumerate(zip(begin, end)):
167209
result[i] = aggregator.query(start, stop)
168210
return result
169-
170-
171-
rolling_mean = partial(rolling_aggregation, kernel_class=Mean)

pandas/core/window/aggregators/methods.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,39 @@
1+
"""
2+
Implementation of the rolling aggregations using njit methods.
3+
This implementation mimics what we currently do in cython except the
4+
calculation of window bounds is independent of the aggregation routine.
5+
"""
16
from typing import Callable
27

38
import numba
49
import numpy as np
510

611

12+
@numba.njit(nogil=True)
713
def rolling_mean(
814
values: np.ndarray, begin: np.ndarray, end: np.ndarray, minimum_periods: int
915
) -> np.ndarray:
16+
"""
17+
Compute a rolling mean over values.
18+
19+
Parameters
20+
----------
21+
values : ndarray[float64]
22+
values to roll over
23+
24+
begin : ndarray[int64]
25+
starting indexers
26+
27+
end : ndarray[int64]
28+
ending indexers
29+
30+
minimum_periods : ndarray[float64]
31+
minimum
32+
33+
Returns
34+
-------
35+
ndarray[float64]
36+
"""
1037
result = np.empty(len(begin))
1138
previous_start = -1
1239
previous_end = -1

pandas/core/window/indexers.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from typing import Optional, Sequence, Tuple, Union
22

3+
import numba
34
import numpy as np
45

56
from pandas.tseries.offsets import DateOffset
67

78
BeginEnd = Tuple[np.ndarray, np.ndarray]
89

10+
baseindexer_spec = (("index", numba.optional(numba.int64[:])),)
11+
912

1013
class BaseIndexer:
1114
"""Base class for window bounds calculations"""
@@ -22,16 +25,11 @@ def __init__(
2225
index : ndarray[int64], default None
2326
pandas index to reference in the window bound calculation
2427
25-
offset: str or DateOffset, default None
26-
Offset used to calcuate the window boundary
27-
28-
keys: np.ndarray, default None
29-
Additional columns needed to calculate the window bounds
30-
3128
"""
3229
self.index = index
33-
self.offset = offset
34-
self.keys = keys
30+
# TODO: How to effectively types these in Numba to run in nopython?
31+
# self.offset = offset
32+
# self.keys = keys
3533

3634
def get_window_bounds(
3735
self,
@@ -74,6 +72,7 @@ def get_window_bounds(
7472
raise NotImplementedError
7573

7674

75+
@numba.jitclass(baseindexer_spec)
7776
class FixedWindowIndexer(BaseIndexer):
7877
"""Calculate window boundaries that have a fixed window size"""
7978

@@ -97,16 +96,16 @@ def get_window_bounds(
9796
(array([0, 0, 1, 1, 2]), array([1, 2, 3, 4, 5]))
9897
"""
9998
start_s = np.zeros(window_size, dtype=np.int64)
100-
start_e = np.arange(1, num_values - window_size + 1, dtype=np.int64)
101-
start = np.concatenate([start_s, start_e])
99+
start_e = np.arange(1, num_values - window_size + 1)
100+
start = np.concatenate((start_s, start_e))
102101

103-
end = np.arange(1, num_values + 1, dtype=np.int64)
104-
if window_size > num_values:
105-
start = start[:num_values]
106-
end = end[:num_values]
102+
end = np.arange(1, num_values + 1)
103+
start = start[:num_values]
104+
end = end[:num_values]
107105
return start, end
108106

109107

108+
@numba.jitclass(baseindexer_spec)
110109
class VariableWindowIndexer(BaseIndexer):
111110
"""
112111
Calculate window boundaries with variable closed boundaries and index dependent

pandas/tests/window/test_timeseries_window.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
33

4+
import pandas.compat as compat
5+
46
from pandas import DataFrame, Index, Series, Timestamp, date_range, to_datetime
57
import pandas.util.testing as tm
68

@@ -577,7 +579,27 @@ def test_all_apply(self, raw):
577579
expected = er.apply(lambda x: 1, raw=raw)
578580
tm.assert_frame_equal(result, expected)
579581

580-
def test_all2(self):
582+
@pytest.mark.parametrize(
583+
"func",
584+
[
585+
"sum",
586+
pytest.param(
587+
"mean",
588+
marks=pytest.mark.skipif(
589+
compat.is_platform_32bit(), reason="Numba fails here for 32 bit"
590+
),
591+
),
592+
"count",
593+
"median",
594+
"std",
595+
"var",
596+
"kurt",
597+
"skew",
598+
"min",
599+
"max",
600+
],
601+
)
602+
def test_all2(self, func):
581603

582604
# more sophisticated comparison of integer vs.
583605
# time-based windowing
@@ -589,36 +611,21 @@ def test_all2(self):
589611

590612
r = dft.rolling(window="5H")
591613

592-
for f in [
593-
"sum",
594-
"mean",
595-
"count",
596-
"median",
597-
"std",
598-
"var",
599-
"kurt",
600-
"skew",
601-
"min",
602-
"max",
603-
]:
604-
605-
result = getattr(r, f)()
614+
result = getattr(r, func)()
606615

607-
# we need to roll the days separately
608-
# to compare with a time-based roll
609-
# finally groupby-apply will return a multi-index
610-
# so we need to drop the day
611-
def agg_by_day(x):
612-
x = x.between_time("09:00", "16:00")
613-
return getattr(x.rolling(5, min_periods=1), f)()
616+
# we need to roll the days separately
617+
# to compare with a time-based roll
618+
# finally groupby-apply will return a multi-index
619+
# so we need to drop the day
620+
def agg_by_day(x):
621+
x = x.between_time("09:00", "16:00")
622+
return getattr(x.rolling(5, min_periods=1), func)()
614623

615-
expected = (
616-
df.groupby(df.index.day)
617-
.apply(agg_by_day)
618-
.reset_index(level=0, drop=True)
619-
)
624+
expected = (
625+
df.groupby(df.index.day).apply(agg_by_day).reset_index(level=0, drop=True)
626+
)
620627

621-
tm.assert_frame_equal(result, expected)
628+
tm.assert_frame_equal(result, expected)
622629

623630
def test_groupby_monotonic(self):
624631

@@ -671,6 +678,9 @@ def test_non_monotonic(self):
671678
result = df2.groupby("A").rolling("4s", on="B").C.mean()
672679
tm.assert_series_equal(result, expected)
673680

681+
@pytest.mark.skipif(
682+
compat.is_platform_32bit(), reason="Numba fails here for 32 bit"
683+
)
674684
def test_rolling_cov_offset(self):
675685
# GH16058
676686

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ known_dtypes = pandas.core.dtypes
116116
known_post_core = pandas.tseries,pandas.io,pandas.plotting
117117
sections = FUTURE,STDLIB,THIRDPARTY,PRE_LIBS,PRE_CORE,DTYPES,FIRSTPARTY,POST_CORE,LOCALFOLDER
118118
known_first_party = pandas
119-
known_third_party = _pytest,announce,dateutil,docutils,flake8,git,hypothesis,jinja2,lxml,matplotlib,numpy,numpydoc,pkg_resources,pyarrow,pytest,pytz,requests,scipy,setuptools,sphinx,sqlalchemy,validate_docstrings,yaml
119+
known_third_party = _pytest,announce,dateutil,docutils,flake8,git,hypothesis,jinja2,lxml,matplotlib,numpy,numpydoc,pkg_resources,pyarrow,pytest,pytz,requests,scipy,setuptools,sphinx,sqlalchemy,validate_docstrings,yaml,numba
120120
multi_line_output = 3
121121
include_trailing_comma = True
122122
force_grid_wrap = 0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def is_platform_mac():
3939
"python-dateutil >= 2.6.1",
4040
"pytz >= 2017.2",
4141
"numpy >= {numpy_ver}".format(numpy_ver=min_numpy_ver),
42+
"numba >= 0.45.1"
4243
],
4344
"setup_requires": ["numpy >= {numpy_ver}".format(numpy_ver=min_numpy_ver)],
4445
"zip_safe": False,

0 commit comments

Comments
 (0)