@@ -308,21 +308,27 @@ def sum(self):
308308 except Exception :
309309 return self .aggregate (lambda x : np .sum (x , axis = self .axis ))
310310
311+ def ohlc (self ):
312+ """
313+ Compute the open, low, high and close ('ohlc') aggregate of each group
314+
315+ For multiple groupings, the result index will be a MultiIndex
316+ """
317+ return self ._cython_agg_general ('ohlc' )
318+
311319 def _cython_agg_general (self , how ):  # aggregate each numeric slice via self.grouper.aggregate(obj, how) and wrap the collected results
312320 output = {}
313321 for name , obj in self ._iterate_slices ():
314322 if not issubclass (obj .dtype .type , (np .number , np .bool_ )):  # only numeric and boolean slices are aggregated; others are skipped
315323 continue
316324
317- obj = com ._ensure_float64 (obj )
318- result , counts = self .grouper .aggregate (obj , how )
319- mask = counts > 0
320- output [name ] = result [mask ]
325+ result , names = self .grouper .aggregate (obj , how )  # `names` is rebound on every iteration; the value from the last aggregated slice is what gets passed on below
326+ output [name ] = result
321327
322328 if len (output ) == 0 :
323329 raise GroupByError ('No numeric types to aggregate' )
324330
325- return self ._wrap_aggregated_output (output )
331+ return self ._wrap_aggregated_output (output , names )
326332
327333 def _python_agg_general (self , func , * args , ** kwargs ):
328334 func = _intercept_function (func )
@@ -588,7 +594,13 @@ def get_group_levels(self):
588594 'std' : np .sqrt
589595 }
590596
597+ _name_functions = {  # maps a `how` key to a callable that returns the output column labels for that aggregation
598+ 'ohlc' : lambda * args : ['open' , 'low' , 'high' , 'close' ]  # NOTE(review): label order must match the column order written by the 'ohlc' aggregation routine - confirm 'low'/'high' are not swapped
599+ }
600+
591601 def aggregate (self , values , how ):
602+ values = com ._ensure_float64 (values )
603+
592604 comp_ids , _ , ngroups = self .group_info
593605 agg_func = self ._cython_functions [how ]
594606 if values .ndim == 1 :
@@ -608,10 +620,18 @@ def aggregate(self, values, how):
608620 agg_func (result , counts , values , comp_ids )
609621 result = trans_func (result )
610622
623+ result = lib .row_bool_subset (result , counts > 0 )
624+
611625 if squeeze :
612626 result = result .squeeze ()
613627
614- return result , counts
628+ if how in self ._name_functions :
629+ # TODO
630+ names = self ._name_functions [how ]()
631+ else :
632+ names = None
633+
634+ return result , names
615635
616636 def agg_series (self , obj , func ):
617637 try :
@@ -862,16 +882,18 @@ def agg_series(self, obj, func):
862882 }
863883
864884 def aggregate (self , values , how ):
885+ values = com ._ensure_float64 (values )
886+
865887 agg_func = self ._cython_functions [how ]
866888 arity = self ._cython_arity .get (how , 1 )
867889
868890 if values .ndim == 1 :
869891 squeeze = True
870892 values = values [:, None ]
871- out_shape = (self .ngroups , 1 )
893+ out_shape = (self .ngroups , arity )
872894 else :
873895 squeeze = False
874- out_shape = (self .ngroups , values .shape [1 ])
896+ out_shape = (self .ngroups , values .shape [1 ] * arity )
875897
876898 trans_func = self ._cython_transforms .get (how , lambda x : x )
877899
@@ -882,10 +904,18 @@ def aggregate(self, values, how):
882904 agg_func (result , counts , values , self .bins )
883905 result = trans_func (result )
884906
907+ result = lib .row_bool_subset (result , counts > 0 )
908+
885909 if squeeze :
886910 result = result .squeeze ()
887911
888- return result , counts
912+ if how in self ._name_functions :
913+ # TODO
914+ names = self ._name_functions [how ]()
915+ else :
916+ names = None
917+
918+ return result , names
889919
890920class Grouping (object ):
891921 """
@@ -1185,11 +1215,15 @@ def _aggregate_multiple_funcs(self, arg):
11851215
11861216 return DataFrame (results )
11871217
1188- def _wrap_aggregated_output (self , output ):
1218+ def _wrap_aggregated_output (self , output , names = None ):  # names: optional column labels; selects DataFrame vs Series wrapping below
11891219 # sort of a kludge: only the entry stored under self.name is used
11901220 output = output [self .name ]
11911221 index = self .grouper .result_index
1192- return Series (output , index = index , name = self .name )
1222+
1223+ if names is not None :  # labels supplied: wrap the result as a DataFrame with those columns
1224+ return DataFrame (output , index = index , columns = names )
1225+ else :
1226+ return Series (output , index = index , name = self .name )
11931227
11941228 def _wrap_applied_output (self , keys , values , not_indexed_same = False ):
11951229 if len (keys ) == 0 :
@@ -1320,11 +1354,7 @@ def _cython_agg_general(self, how):
13201354 continue
13211355
13221356 values = com ._ensure_float64 (values )
1323- result , counts = self .grouper .aggregate (values , how )
1324-
1325- mask = counts > 0
1326- if len (mask ) > 0 :
1327- result = result [mask ]
1357+ result , names = self .grouper .aggregate (values , how )
13281358 newb = make_block (result .T , block .items , block .ref_items )
13291359 new_blocks .append (newb )
13301360
@@ -1522,7 +1552,7 @@ def _aggregate_item_by_item(self, func, *args, **kwargs):
15221552
15231553 return DataFrame (result , columns = result_columns )
15241554
1525- def _wrap_aggregated_output (self , output ):
1555+ def _wrap_aggregated_output (self , output , names = None ):
15261556 agg_axis = 0 if self .axis == 1 else 1
15271557 agg_labels = self ._obj_with_exclusions ._get_axis (agg_axis )
15281558
@@ -1930,12 +1960,6 @@ def numpy_groupby(data, labels, axis=0):
19301960# Helper functions
19311961
19321962def translate_grouping (how ):
1933- if set (how ) == set ('ohlc' ):
1934- return {'open' : lambda arr : arr [0 ],
1935- 'low' : lambda arr : arr .min (),
1936- 'high' : lambda arr : arr .max (),
1937- 'close' : lambda arr : arr [- 1 ]}
1938-
19391963 if how in 'last' :
19401964 def picker (arr ):
19411965 return arr [- 1 ] if arr is not None and len (arr ) else np .nan
0 commit comments