
Commit b15ae85

ENH: integrate cython ohlc in groupby and test, close #152
1 parent 41ddb8b commit b15ae85

File tree (4 files changed, +95 −23 lines):

    pandas/core/groupby.py
    pandas/src/groupby.pyx
    pandas/src/sandbox.pyx
    pandas/tests/test_timeseries.py

Diff for: pandas/core/groupby.py (+47 −23)

@@ -308,21 +308,27 @@ def sum(self):
         except Exception:
             return self.aggregate(lambda x: np.sum(x, axis=self.axis))

+    def ohlc(self):
+        """
+        Compute sum of values, excluding missing values
+
+        For multiple groupings, the result index will be a MultiIndex
+        """
+        return self._cython_agg_general('ohlc')
+
     def _cython_agg_general(self, how):
         output = {}
         for name, obj in self._iterate_slices():
             if not issubclass(obj.dtype.type, (np.number, np.bool_)):
                 continue

-            obj = com._ensure_float64(obj)
-            result, counts = self.grouper.aggregate(obj, how)
-            mask = counts > 0
-            output[name] = result[mask]
+            result, names = self.grouper.aggregate(obj, how)
+            output[name] = result

         if len(output) == 0:
             raise GroupByError('No numeric types to aggregate')

-        return self._wrap_aggregated_output(output)
+        return self._wrap_aggregated_output(output, names)

     def _python_agg_general(self, func, *args, **kwargs):
         func = _intercept_function(func)
@@ -588,7 +594,13 @@ def get_group_levels(self):
         'std' : np.sqrt
     }

+    _name_functions = {
+        'ohlc' : lambda *args: ['open', 'low', 'high', 'close']
+    }
+
     def aggregate(self, values, how):
+        values = com._ensure_float64(values)
+
         comp_ids, _, ngroups = self.group_info
         agg_func = self._cython_functions[how]
         if values.ndim == 1:
@@ -608,10 +620,18 @@ def aggregate(self, values, how):
             agg_func(result, counts, values, comp_ids)
         result = trans_func(result)

+        result = lib.row_bool_subset(result, counts > 0)
+
         if squeeze:
             result = result.squeeze()

-        return result, counts
+        if how in self._name_functions:
+            # TODO
+            names = self._name_functions[how]()
+        else:
+            names = None
+
+        return result, names

     def agg_series(self, obj, func):
         try:
@@ -862,16 +882,18 @@ def agg_series(self, obj, func):
     }

     def aggregate(self, values, how):
+        values = com._ensure_float64(values)
+
         agg_func = self._cython_functions[how]
         arity = self._cython_arity.get(how, 1)

         if values.ndim == 1:
             squeeze = True
             values = values[:, None]
-            out_shape = (self.ngroups, 1)
+            out_shape = (self.ngroups, arity)
         else:
             squeeze = False
-            out_shape = (self.ngroups, values.shape[1])
+            out_shape = (self.ngroups, values.shape[1] * arity)

         trans_func = self._cython_transforms.get(how, lambda x: x)

@@ -882,10 +904,18 @@ def aggregate(self, values, how):
             agg_func(result, counts, values, self.bins)
         result = trans_func(result)

+        result = lib.row_bool_subset(result, counts > 0)
+
         if squeeze:
             result = result.squeeze()

-        return result, counts
+        if how in self._name_functions:
+            # TODO
+            names = self._name_functions[how]()
+        else:
+            names = None
+
+        return result, names

 class Grouping(object):
     """
@@ -1185,11 +1215,15 @@ def _aggregate_multiple_funcs(self, arg):

         return DataFrame(results)

-    def _wrap_aggregated_output(self, output):
+    def _wrap_aggregated_output(self, output, names=None):
         # sort of a kludge
         output = output[self.name]
         index = self.grouper.result_index
-        return Series(output, index=index, name=self.name)
+
+        if names is not None:
+            return DataFrame(output, index=index, columns=names)
+        else:
+            return Series(output, index=index, name=self.name)

     def _wrap_applied_output(self, keys, values, not_indexed_same=False):
         if len(keys) == 0:
@@ -1320,11 +1354,7 @@ def _cython_agg_general(self, how):
                 continue

            values = com._ensure_float64(values)
-            result, counts = self.grouper.aggregate(values, how)
-
-            mask = counts > 0
-            if len(mask) > 0:
-                result = result[mask]
+            result, names = self.grouper.aggregate(values, how)
             newb = make_block(result.T, block.items, block.ref_items)
             new_blocks.append(newb)

@@ -1522,7 +1552,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs):

         return DataFrame(result, columns=result_columns)

-    def _wrap_aggregated_output(self, output):
+    def _wrap_aggregated_output(self, output, names=None):
         agg_axis = 0 if self.axis == 1 else 1
         agg_labels = self._obj_with_exclusions._get_axis(agg_axis)

@@ -1930,12 +1960,6 @@ def numpy_groupby(data, labels, axis=0):
 # Helper functions

 def translate_grouping(how):
-    if set(how) == set('ohlc'):
-        return {'open' : lambda arr: arr[0],
-                'low' : lambda arr: arr.min(),
-                'high' : lambda arr: arr.max(),
-                'close' : lambda arr: arr[-1]}
-
     if how in 'last':
         def picker(arr):
             return arr[-1] if arr is not None and len(arr) else np.nan

Diff for: pandas/src/groupby.pyx (+25)

@@ -714,6 +714,31 @@ def group_var_bin(ndarray[float64_t, ndim=2] out,
                 out[i, j] = ((ct * sumxx[i, j] - sumx[i, j] * sumx[i, j]) /
                              (ct * ct - ct))

+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+def row_bool_subset(ndarray[float64_t, ndim=2] values,
+                    ndarray[uint8_t, cast=True] mask):
+    cdef:
+        Py_ssize_t i, j, n, k, pos = 0
+        ndarray[float64_t, ndim=2] out
+
+    n, k = (<object> values).shape
+    assert(n == len(mask))
+
+    out = np.empty((mask.sum(), k), dtype=np.float64)
+
+    for i in range(n):
+        if mask[i]:
+            for j in range(k):
+                out[pos, j] = values[i, j]
+            pos += 1
+
+    return out
+
+
+
 def group_count(ndarray[int32_t] values, Py_ssize_t size):
     cdef:
         Py_ssize_t i, n = len(values)
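Editor's note: row_bool_subset keeps only the rows of a 2-D float64 array whose mask entry is set; in the groupby code above it drops result rows for empty groups (counts == 0). Functionally it is boolean row indexing done in a typed loop. A NumPy sketch of the equivalent operation, for reference only; row_bool_subset_py is a hypothetical name, not part of the commit.

    import numpy as np

    def row_bool_subset_py(values, mask):
        # Pure-NumPy equivalent of the Cython row_bool_subset above
        values = np.asarray(values, dtype=np.float64)
        mask = np.asarray(mask, dtype=bool)
        assert values.shape[0] == len(mask)
        return values[mask]          # keep rows where mask is True

    result = np.arange(12, dtype=np.float64).reshape(4, 3)
    counts = np.array([2, 0, 1, 0])
    print(row_bool_subset_py(result, counts > 0))   # drops the empty groups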

Diff for: pandas/src/sandbox.pyx (+1)

@@ -476,3 +476,4 @@ def backfill_int64(ndarray[int64_t] old, ndarray[int64_t] new,
             cur = prev

     return indexer
+

Diff for: pandas/tests/test_timeseries.py (+22)

@@ -254,6 +254,28 @@ def test_pad_require_monotonicity(self):
         self.assertRaises(AssertionError, rng2.get_indexer, rng,
                           method='pad')

+
+    def test_ohlc_5min(self):
+        def _ohlc(group):
+            if isnull(group).all():
+                return np.repeat(np.nan, 4)
+            return [group[0], group.min(), group.max(), group[-1]]
+
+        rng = date_range('1/1/2000 00:00:00', '1/1/2000 5:59:50',
+                         freq='10s')
+        ts = Series(np.random.randn(len(rng)), index=rng)
+
+        converted = ts.convert('5min', how='ohlc')
+
+        self.assert_((converted.ix['1/1/2000 00:00'] == ts[0]).all())
+
+        exp = _ohlc(ts[1:31])
+        self.assert_((converted.ix['1/1/2000 00:05'] == exp).all())
+
+        exp = _ohlc(ts['1/1/2000 5:55:01':])
+        self.assert_((converted.ix['1/1/2000 6:00:00'] == exp).all())
+
+
 def _skip_if_no_pytz():
     try:
         import pytz
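Editor's note: the test drives the ohlc path through Series.convert, the frequency-conversion API pandas had at the time of this commit. Against a current pandas the same check would go through resample; a rough translation is sketched below, assuming resample('5min', closed='right', label='right').ohlc() reproduces the right-closed buckets the test asserts (the '00:05' bucket covering 00:00:10 through 00:05:00), and noting that modern ohlc columns come back in open/high/low/close order.

    import numpy as np
    import pandas as pd

    rng = pd.date_range('1/1/2000 00:00:00', '1/1/2000 05:59:50', freq='10s')
    ts = pd.Series(np.random.randn(len(rng)), index=rng)

    # Right-closed, right-labeled bins mirror the bucket edges asserted above
    converted = ts.resample('5min', closed='right', label='right').ohlc()

    bucket = ts.iloc[1:31]   # 00:00:10 through 00:05:00
    expected = [bucket.iloc[0], bucket.max(), bucket.min(), bucket.iloc[-1]]
    row = converted.loc[pd.Timestamp('2000-01-01 00:05:00')]
    assert (row == expected).all()   # columns: open, high, low, close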
