Skip to content

Commit 41ddb8b

Browse files
committed
ENH: implement Cython OHLC function for groupby #152
1 parent 91d8453 commit 41ddb8b

File tree

5 files changed

+119
-6
lines changed

5 files changed

+119
-6
lines changed

Diff for: pandas/core/groupby.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -853,11 +853,18 @@ def agg_series(self, obj, func):
853853
'add' : lib.group_add_bin,
854854
'mean' : lib.group_mean_bin,
855855
'var' : lib.group_var_bin,
856-
'std' : lib.group_var_bin
856+
'std' : lib.group_var_bin,
857+
'ohlc' : lib.group_ohlc
858+
}
859+
860+
_cython_arity = {
861+
'ohlc' : 4, # OHLC
857862
}
858863

859864
def aggregate(self, values, how):
860865
agg_func = self._cython_functions[how]
866+
arity = self._cython_arity.get(how, 1)
867+
861868
if values.ndim == 1:
862869
squeeze = True
863870
values = values[:, None]

Diff for: pandas/src/engines.pyx

+7-1
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,15 @@ PyDateTime_IMPORT
2727
cdef extern from "Python.h":
2828
int PySlice_Check(object)
2929
int PyList_Check(object)
30-
30+
int PyTuple_Check(object)
3131

3232
cdef inline is_definitely_invalid_key(object val):
33+
if PyTuple_Check(val):
34+
try:
35+
hash(val)
36+
except TypeError:
37+
return True
38+
3339
return (PySlice_Check(val) or cnp.PyArray_Check(val)
3440
or PyList_Check(val))
3541

Diff for: pandas/src/groupby.pyx

+72
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,78 @@ def group_add_bin(ndarray[float64_t, ndim=2] out,
532532
else:
533533
out[i, j] = sumx[i, j]
534534

535+
536+
@cython.boundscheck(False)
537+
@cython.wraparound(False)
538+
def group_ohlc(ndarray[float64_t, ndim=2] out,
539+
ndarray[int32_t] counts,
540+
ndarray[float64_t, ndim=2] values,
541+
ndarray[int32_t] bins):
542+
'''
543+
Only aggregates on axis=0
544+
'''
545+
cdef:
546+
Py_ssize_t i, j, N, K, ngroups, b
547+
float64_t val, count
548+
float64_t vopen, vhigh, vlow, vclose, NA
549+
bint got_first = 0
550+
551+
ngroups = len(bins) + 1
552+
N, K = (<object> values).shape
553+
554+
if out.shape[1] != 4:
555+
raise ValueError('Output array must have 4 columns')
556+
557+
NA = np.nan
558+
559+
b = 0
560+
if K > 1:
561+
raise NotImplementedError
562+
else:
563+
for i in range(N):
564+
if b < ngroups - 1 and i >= bins[b]:
565+
if not got_first:
566+
out[b, 0] = NA
567+
out[b, 1] = NA
568+
out[b, 2] = NA
569+
out[b, 3] = NA
570+
else:
571+
out[b, 0] = vopen
572+
out[b, 1] = vlow
573+
out[b, 2] = vhigh
574+
out[b, 3] = vclose
575+
b += 1
576+
got_first = 0
577+
578+
counts[b] += 1
579+
val = values[i, 0]
580+
581+
# not nan
582+
if val == val:
583+
if not got_first:
584+
got_first = 1
585+
vopen = val
586+
vlow = val
587+
vhigh = val
588+
else:
589+
if val < vlow:
590+
vlow = val
591+
if val > vhigh:
592+
vhigh = val
593+
vclose = val
594+
595+
if not got_first:
596+
out[b, 0] = NA
597+
out[b, 1] = NA
598+
out[b, 2] = NA
599+
out[b, 3] = NA
600+
else:
601+
out[b, 0] = vopen
602+
out[b, 1] = vlow
603+
out[b, 2] = vhigh
604+
out[b, 3] = vclose
605+
606+
535607
@cython.boundscheck(False)
536608
@cython.wraparound(False)
537609
def group_mean_bin(ndarray[float64_t, ndim=2] out,

Diff for: pandas/tests/test_tseries.py

+29-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
import numpy as np
4-
from pandas import Index
4+
from pandas import Index, isnull
55
from pandas.util.testing import assert_almost_equal
66
import pandas.util.testing as common
77
import pandas._tseries as lib
@@ -317,7 +317,7 @@ def test_group_add_bin():
317317
# bin-based group_add
318318
bins = np.array([3, 6], dtype=np.int32)
319319
out = np.zeros((3, 1), np.float64)
320-
counts = np.empty(len(out), dtype=np.int32)
320+
counts = np.zeros(len(out), dtype=np.int32)
321321
lib.group_add_bin(out, counts, obj, bins)
322322

323323
assert_almost_equal(out, exp)
@@ -334,7 +334,7 @@ def test_group_mean_bin():
334334
# bin-based group_mean
335335
bins = np.array([3, 6], dtype=np.int32)
336336
out = np.zeros((3, 1), np.float64)
337-
counts = np.empty(len(out), dtype=np.int32)
337+
counts = np.zeros(len(out), dtype=np.int32)
338338
lib.group_mean_bin(out, counts, obj, bins)
339339

340340
assert_almost_equal(out, exp)
@@ -351,12 +351,37 @@ def test_group_var_bin():
351351
# bin-based group_var
352352
bins = np.array([3, 6], dtype=np.int32)
353353
out = np.zeros((3, 1), np.float64)
354-
counts = np.empty(len(out), dtype=np.int32)
354+
counts = np.zeros(len(out), dtype=np.int32)
355355

356356
lib.group_var_bin(out, counts, obj, bins)
357357

358358
assert_almost_equal(out, exp)
359359

360+
def test_group_ohlc():
361+
obj = np.random.randn(20)
362+
363+
bins = np.array([6, 12], dtype=np.int32)
364+
out = np.zeros((3, 4), np.float64)
365+
counts = np.zeros(len(out), dtype=np.int32)
366+
367+
lib.group_ohlc(out, counts, obj[:, None], bins)
368+
369+
def _ohlc(group):
370+
if isnull(group).all():
371+
return np.repeat(np.nan, 4)
372+
return [group[0], group.min(), group.max(), group[-1]]
373+
374+
expected = np.array([_ohlc(obj[:6]), _ohlc(obj[6:12]),
375+
_ohlc(obj[12:])])
376+
377+
assert_almost_equal(out, expected)
378+
assert_almost_equal(counts, [6, 6, 8])
379+
380+
obj[:6] = np.nan
381+
lib.group_ohlc(out, counts, obj[:, None], bins)
382+
expected[0] = np.nan
383+
assert_almost_equal(out, expected)
384+
360385
class TestTypeInference(unittest.TestCase):
361386

362387
def test_length_zero(self):

Diff for: ts_todo.txt

+3
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,6 @@ gp- get rid of Ts class, simplify timestamp creation
1515
- attach tz in DatetimeIndex.asobject
1616
- failing duplicate timestamp test
1717
- _tseries.pyd depends on datetime.pyx
18+
19+
20+
- BUG: time_rule DateRange tests

0 commit comments

Comments
 (0)