@@ -308,21 +308,27 @@ def sum(self):
308
308
except Exception :
309
309
return self .aggregate (lambda x : np .sum (x , axis = self .axis ))
310
310
311
+ def ohlc (self ):
312
+ """
313
+ Compute sum of values, excluding missing values
314
+
315
+ For multiple groupings, the result index will be a MultiIndex
316
+ """
317
+ return self ._cython_agg_general ('ohlc' )
318
+
311
319
def _cython_agg_general (self , how ):
312
320
output = {}
313
321
for name , obj in self ._iterate_slices ():
314
322
if not issubclass (obj .dtype .type , (np .number , np .bool_ )):
315
323
continue
316
324
317
- obj = com ._ensure_float64 (obj )
318
- result , counts = self .grouper .aggregate (obj , how )
319
- mask = counts > 0
320
- output [name ] = result [mask ]
325
+ result , names = self .grouper .aggregate (obj , how )
326
+ output [name ] = result
321
327
322
328
if len (output ) == 0 :
323
329
raise GroupByError ('No numeric types to aggregate' )
324
330
325
- return self ._wrap_aggregated_output (output )
331
+ return self ._wrap_aggregated_output (output , names )
326
332
327
333
def _python_agg_general (self , func , * args , ** kwargs ):
328
334
func = _intercept_function (func )
@@ -588,7 +594,13 @@ def get_group_levels(self):
588
594
'std' : np .sqrt
589
595
}
590
596
597
+ _name_functions = {
598
+ 'ohlc' : lambda * args : ['open' , 'low' , 'high' , 'close' ]
599
+ }
600
+
591
601
def aggregate (self , values , how ):
602
+ values = com ._ensure_float64 (values )
603
+
592
604
comp_ids , _ , ngroups = self .group_info
593
605
agg_func = self ._cython_functions [how ]
594
606
if values .ndim == 1 :
@@ -608,10 +620,18 @@ def aggregate(self, values, how):
608
620
agg_func (result , counts , values , comp_ids )
609
621
result = trans_func (result )
610
622
623
+ result = lib .row_bool_subset (result , counts > 0 )
624
+
611
625
if squeeze :
612
626
result = result .squeeze ()
613
627
614
- return result , counts
628
+ if how in self ._name_functions :
629
+ # TODO
630
+ names = self ._name_functions [how ]()
631
+ else :
632
+ names = None
633
+
634
+ return result , names
615
635
616
636
def agg_series (self , obj , func ):
617
637
try :
@@ -862,16 +882,18 @@ def agg_series(self, obj, func):
862
882
}
863
883
864
884
def aggregate (self , values , how ):
885
+ values = com ._ensure_float64 (values )
886
+
865
887
agg_func = self ._cython_functions [how ]
866
888
arity = self ._cython_arity .get (how , 1 )
867
889
868
890
if values .ndim == 1 :
869
891
squeeze = True
870
892
values = values [:, None ]
871
- out_shape = (self .ngroups , 1 )
893
+ out_shape = (self .ngroups , arity )
872
894
else :
873
895
squeeze = False
874
- out_shape = (self .ngroups , values .shape [1 ])
896
+ out_shape = (self .ngroups , values .shape [1 ] * arity )
875
897
876
898
trans_func = self ._cython_transforms .get (how , lambda x : x )
877
899
@@ -882,10 +904,18 @@ def aggregate(self, values, how):
882
904
agg_func (result , counts , values , self .bins )
883
905
result = trans_func (result )
884
906
907
+ result = lib .row_bool_subset (result , counts > 0 )
908
+
885
909
if squeeze :
886
910
result = result .squeeze ()
887
911
888
- return result , counts
912
+ if how in self ._name_functions :
913
+ # TODO
914
+ names = self ._name_functions [how ]()
915
+ else :
916
+ names = None
917
+
918
+ return result , names
889
919
890
920
class Grouping (object ):
891
921
"""
@@ -1185,11 +1215,15 @@ def _aggregate_multiple_funcs(self, arg):
1185
1215
1186
1216
return DataFrame (results )
1187
1217
1188
- def _wrap_aggregated_output (self , output ):
1218
+ def _wrap_aggregated_output (self , output , names = None ):
1189
1219
# sort of a kludge
1190
1220
output = output [self .name ]
1191
1221
index = self .grouper .result_index
1192
- return Series (output , index = index , name = self .name )
1222
+
1223
+ if names is not None :
1224
+ return DataFrame (output , index = index , columns = names )
1225
+ else :
1226
+ return Series (output , index = index , name = self .name )
1193
1227
1194
1228
def _wrap_applied_output (self , keys , values , not_indexed_same = False ):
1195
1229
if len (keys ) == 0 :
@@ -1320,11 +1354,7 @@ def _cython_agg_general(self, how):
1320
1354
continue
1321
1355
1322
1356
values = com ._ensure_float64 (values )
1323
- result , counts = self .grouper .aggregate (values , how )
1324
-
1325
- mask = counts > 0
1326
- if len (mask ) > 0 :
1327
- result = result [mask ]
1357
+ result , names = self .grouper .aggregate (values , how )
1328
1358
newb = make_block (result .T , block .items , block .ref_items )
1329
1359
new_blocks .append (newb )
1330
1360
@@ -1522,7 +1552,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs):
1522
1552
1523
1553
return DataFrame (result , columns = result_columns )
1524
1554
1525
- def _wrap_aggregated_output (self , output ):
1555
+ def _wrap_aggregated_output (self , output , names = None ):
1526
1556
agg_axis = 0 if self .axis == 1 else 1
1527
1557
agg_labels = self ._obj_with_exclusions ._get_axis (agg_axis )
1528
1558
@@ -1930,12 +1960,6 @@ def numpy_groupby(data, labels, axis=0):
1930
1960
# Helper functions
1931
1961
1932
1962
def translate_grouping (how ):
1933
- if set (how ) == set ('ohlc' ):
1934
- return {'open' : lambda arr : arr [0 ],
1935
- 'low' : lambda arr : arr .min (),
1936
- 'high' : lambda arr : arr .max (),
1937
- 'close' : lambda arr : arr [- 1 ]}
1938
-
1939
1963
if how in 'last' :
1940
1964
def picker (arr ):
1941
1965
return arr [- 1 ] if arr is not None and len (arr ) else np .nan
0 commit comments