@@ -224,82 +224,106 @@ def test_pivot_with_tz(self):
224224 tm .assert_frame_equal (pv , expected )
225225
226226 def test_margins (self ):
227- def _check_output (res , col , index = ['A' , 'B' ], columns = ['C' ]):
228- cmarg = res ['All' ][:- 1 ]
229- exp = self .data .groupby (index )[col ].mean ()
230- tm .assert_series_equal (cmarg , exp , check_names = False )
231- self .assertEqual (cmarg .name , 'All' )
232-
233- res = res .sortlevel ()
234- rmarg = res .xs (('All' , '' ))[:- 1 ]
235- exp = self .data .groupby (columns )[col ].mean ()
236- tm .assert_series_equal (rmarg , exp , check_names = False )
237- self .assertEqual (rmarg .name , ('All' , '' ))
238-
239- gmarg = res ['All' ]['All' , '' ]
240- exp = self .data [col ].mean ()
241- self .assertEqual (gmarg , exp )
227+ def _check_output (result , values_col , index = ['A' , 'B' ],
228+ columns = ['C' ],
229+ margins_col = 'All' ):
230+ col_margins = result .ix [:- 1 , margins_col ]
231+ expected_col_margins = self .data .groupby (index )[values_col ].mean ()
232+ tm .assert_series_equal (col_margins , expected_col_margins ,
233+ check_names = False )
234+ self .assertEqual (col_margins .name , margins_col )
235+
236+ result = result .sortlevel ()
237+ index_margins = result .ix [(margins_col , '' )].iloc [:- 1 ]
238+ expected_ix_margins = self .data .groupby (columns )[values_col ].mean ()
239+ tm .assert_series_equal (index_margins , expected_ix_margins ,
240+ check_names = False )
241+ self .assertEqual (index_margins .name , (margins_col , '' ))
242+
243+ grand_total_margins = result .loc [(margins_col , '' ), margins_col ]
244+ expected_total_margins = self .data [values_col ].mean ()
245+ self .assertEqual (grand_total_margins , expected_total_margins )
242246
243247 # column specified
244- table = self .data .pivot_table ('D' , index = ['A' , 'B' ], columns = 'C' ,
245- margins = True , aggfunc = np .mean )
246- _check_output (table , 'D' )
248+ result = self .data .pivot_table (values = 'D' , index = ['A' , 'B' ],
249+ columns = 'C' ,
250+ margins = True , aggfunc = np .mean )
251+ _check_output (result , 'D' )
252+
253+ # Set a different margins_name (not 'All')
254+ result = self .data .pivot_table (values = 'D' , index = ['A' , 'B' ],
255+ columns = 'C' ,
256+ margins = True , aggfunc = np .mean ,
257+ margins_name = 'Totals' )
258+ _check_output (result , 'D' , margins_col = 'Totals' )
247259
248260 # no column specified
249261 table = self .data .pivot_table (index = ['A' , 'B' ], columns = 'C' ,
250262 margins = True , aggfunc = np .mean )
251- for valcol in table .columns .levels [0 ]:
252- _check_output (table [valcol ], valcol )
263+ for value_col in table .columns .levels [0 ]:
264+ _check_output (table [value_col ], value_col )
253265
254266 # no col
255267
256268 # to help with a buglet
257269 self .data .columns = [k * 2 for k in self .data .columns ]
258270 table = self .data .pivot_table (index = ['AA' , 'BB' ], margins = True ,
259271 aggfunc = np .mean )
260- for valcol in table .columns :
261- gmarg = table [valcol ]['All' , '' ]
262- self .assertEqual (gmarg , self .data [valcol ].mean ())
263-
264- # this is OK
265- table = self .data .pivot_table (index = ['AA' , 'BB' ], margins = True ,
266- aggfunc = 'mean' )
272+ for value_col in table .columns :
273+ totals = table .loc [('All' , '' ), value_col ]
274+ self .assertEqual (totals , self .data [value_col ].mean ())
267275
268276 # no rows
269277 rtable = self .data .pivot_table (columns = ['AA' , 'BB' ], margins = True ,
270278 aggfunc = np .mean )
271279 tm .assertIsInstance (rtable , Series )
280+
281+ table = self .data .pivot_table (index = ['AA' , 'BB' ], margins = True ,
282+ aggfunc = 'mean' )
272283 for item in ['DD' , 'EE' , 'FF' ]:
273- gmarg = table [ item ][ 'All' , '' ]
274- self .assertEqual (gmarg , self .data [item ].mean ())
284+ totals = table . loc [( 'All' , '' ), item ]
285+ self .assertEqual (totals , self .data [item ].mean ())
275286
276287 # issue number #8349: pivot_table with margins and dictionary aggfunc
288+ data = [
289+ {'JOB' : 'Worker' , 'NAME' : 'Bob' , 'YEAR' : 2013 ,
290+ 'MONTH' : 12 , 'DAYS' : 3 , 'SALARY' : 17 },
291+ {'JOB' : 'Employ' , 'NAME' :
292+ 'Mary' , 'YEAR' : 2013 , 'MONTH' : 12 , 'DAYS' : 5 , 'SALARY' : 23 },
293+ {'JOB' : 'Worker' , 'NAME' : 'Bob' , 'YEAR' : 2014 ,
294+ 'MONTH' : 1 , 'DAYS' : 10 , 'SALARY' : 100 },
295+ {'JOB' : 'Worker' , 'NAME' : 'Bob' , 'YEAR' : 2014 ,
296+ 'MONTH' : 1 , 'DAYS' : 11 , 'SALARY' : 110 },
297+ {'JOB' : 'Employ' , 'NAME' : 'Mary' , 'YEAR' : 2014 ,
298+ 'MONTH' : 1 , 'DAYS' : 15 , 'SALARY' : 200 },
299+ {'JOB' : 'Worker' , 'NAME' : 'Bob' , 'YEAR' : 2014 ,
300+ 'MONTH' : 2 , 'DAYS' : 8 , 'SALARY' : 80 },
301+ {'JOB' : 'Employ' , 'NAME' : 'Mary' , 'YEAR' : 2014 ,
302+ 'MONTH' : 2 , 'DAYS' : 5 , 'SALARY' : 190 },
303+ ]
277304
278- df = DataFrame ([ {'JOB' :'Worker' ,'NAME' :'Bob' ,'YEAR' :2013 ,'MONTH' :12 ,'DAYS' : 3 ,'SALARY' : 17 },
279- {'JOB' :'Employ' ,'NAME' :'Mary' ,'YEAR' :2013 ,'MONTH' :12 ,'DAYS' : 5 ,'SALARY' : 23 },
280- {'JOB' :'Worker' ,'NAME' :'Bob' ,'YEAR' :2014 ,'MONTH' : 1 ,'DAYS' :10 ,'SALARY' :100 },
281- {'JOB' :'Worker' ,'NAME' :'Bob' ,'YEAR' :2014 ,'MONTH' : 1 ,'DAYS' :11 ,'SALARY' :110 },
282- {'JOB' :'Employ' ,'NAME' :'Mary' ,'YEAR' :2014 ,'MONTH' : 1 ,'DAYS' :15 ,'SALARY' :200 },
283- {'JOB' :'Worker' ,'NAME' :'Bob' ,'YEAR' :2014 ,'MONTH' : 2 ,'DAYS' : 8 ,'SALARY' : 80 },
284- {'JOB' :'Employ' ,'NAME' :'Mary' ,'YEAR' :2014 ,'MONTH' : 2 ,'DAYS' : 5 ,'SALARY' :190 } ])
285-
286- df = df .set_index (['JOB' ,'NAME' ,'YEAR' ,'MONTH' ],drop = False ,append = False )
287-
288- rs = df .pivot_table ( index = ['JOB' ,'NAME' ],
289- columns = ['YEAR' ,'MONTH' ],
290- values = ['DAYS' ,'SALARY' ],
291- aggfunc = {'DAYS' :'mean' ,'SALARY' :'sum' },
292- margins = True )
305+ df = DataFrame (data )
293306
294- ex = df .pivot_table (index = ['JOB' ,'NAME' ],columns = ['YEAR' ,'MONTH' ],values = ['DAYS' ],aggfunc = 'mean' ,margins = True )
307+ df = df .set_index (['JOB' , 'NAME' , 'YEAR' , 'MONTH' ], drop = False ,
308+ append = False )
295309
296- tm .assert_frame_equal (rs ['DAYS' ], ex ['DAYS' ])
310+ result = df .pivot_table (index = ['JOB' , 'NAME' ],
311+ columns = ['YEAR' , 'MONTH' ],
312+ values = ['DAYS' , 'SALARY' ],
313+ aggfunc = {'DAYS' : 'mean' , 'SALARY' : 'sum' },
314+ margins = True )
297315
298- ex = df .pivot_table (index = ['JOB' ,'NAME' ],columns = ['YEAR' ,'MONTH' ],values = ['SALARY' ],aggfunc = 'sum' ,margins = True )
316+ expected = df .pivot_table (index = ['JOB' , 'NAME' ],
317+ columns = ['YEAR' , 'MONTH' ], values = ['DAYS' ],
318+ aggfunc = 'mean' , margins = True )
299319
300- tm .assert_frame_equal (rs [ 'SALARY ' ], ex [ 'SALARY ' ])
320+ tm .assert_frame_equal (result [ 'DAYS ' ], expected [ 'DAYS ' ])
301321
322+ expected = df .pivot_table (index = ['JOB' , 'NAME' ],
323+ columns = ['YEAR' , 'MONTH' ], values = ['SALARY' ],
324+ aggfunc = 'sum' , margins = True )
302325
326+ tm .assert_frame_equal (result ['SALARY' ], expected ['SALARY' ])
303327
304328 def test_pivot_integer_columns (self ):
305329 # caused by upstream bug in unstack
@@ -402,6 +426,25 @@ def test_margins_no_values_two_row_two_cols(self):
402426 result = self .data [['A' , 'B' , 'C' , 'D' ]].pivot_table (index = ['A' , 'B' ], columns = ['C' , 'D' ], aggfunc = len , margins = True )
403427 self .assertEqual (result .All .tolist (), [3.0 , 1.0 , 4.0 , 3.0 , 11.0 ])
404428
429+ def test_pivot_table_with_margins_set_margin_name (self ):
430+ # GH 3335
431+ for margin_name in ['foo' , 'one' , 666 , None , ['a' , 'b' ]]:
432+ with self .assertRaises (ValueError ):
433+ # multi-index index
434+ pivot_table (self .data , values = 'D' , index = ['A' , 'B' ],
435+ columns = ['C' ], margins = True ,
436+ margins_name = margin_name )
437+ with self .assertRaises (ValueError ):
438+ # multi-index column
439+ pivot_table (self .data , values = 'D' , index = ['C' ],
440+ columns = ['A' , 'B' ], margins = True ,
441+ margins_name = margin_name )
442+ with self .assertRaises (ValueError ):
443+ # non-multi-index index/column
444+ pivot_table (self .data , values = 'D' , index = ['A' ],
445+ columns = ['B' ], margins = True ,
446+ margins_name = margin_name )
447+
405448 def test_pivot_timegrouper (self ):
406449 df = DataFrame ({
407450 'Branch' : 'A A A A A A A B' .split (),
0 commit comments