@@ -286,7 +286,7 @@ def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=(8.0, 6.0)
286286 axes_num : number
287287 expected number of axes. Unnecessary axes should be set to invisible.
288288 layout : tuple
289- expected layout
289+ expected layout, (expected number of rows , columns)
290290 figsize : tuple
291291 expected figsize. default is matplotlib default
292292 """
@@ -299,17 +299,22 @@ def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=(8.0, 6.0)
299299 self .assertTrue (len (ax .get_children ()) > 0 )
300300
301301 if layout is not None :
302- if isinstance (axes , list ):
303- self .assertEqual ((len (axes ), ), layout )
304- elif isinstance (axes , np .ndarray ):
305- self .assertEqual (axes .shape , layout )
306- else :
307- # in case of AxesSubplot
308- self .assertEqual ((1 , ), layout )
302+ result = self ._get_axes_layout (plotting ._flatten (axes ))
303+ self .assertEqual (result , layout )
309304
310305 self .assert_numpy_array_equal (np .round (visible_axes [0 ].figure .get_size_inches ()),
311306 np .array (figsize ))
312307
308+ def _get_axes_layout (self , axes ):
309+ x_set = set ()
310+ y_set = set ()
311+ for ax in axes :
312+ # check axes coordinates to estimate layout
313+ points = ax .get_position ().get_points ()
314+ x_set .add (points [0 ][0 ])
315+ y_set .add (points [0 ][1 ])
316+ return (len (y_set ), len (x_set ))
317+
313318 def _flatten_visible (self , axes ):
314319 """
315320 Flatten axes, and filter only visible
@@ -401,14 +406,14 @@ def test_plot(self):
401406
402407 # GH 6951
403408 ax = _check_plot_works (self .ts .plot , subplots = True )
404- self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , ))
409+ self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , 1 ))
405410
406411 @slow
407412 def test_plot_figsize_and_title (self ):
408413 # figsize and title
409414 ax = self .series .plot (title = 'Test' , figsize = (16 , 8 ))
410415 self ._check_text_labels (ax .title , 'Test' )
411- self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , ), figsize = (16 , 8 ))
416+ self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , 1 ), figsize = (16 , 8 ))
412417
413418 def test_ts_area_lim (self ):
414419 ax = self .ts .plot (kind = 'area' , stacked = False )
@@ -556,10 +561,10 @@ def test_hist_layout_with_by(self):
556561 df = self .hist_df
557562
558563 axes = _check_plot_works (df .height .hist , by = df .gender , layout = (2 , 1 ))
559- self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , ), figsize = (10 , 5 ))
564+ self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , 1 ), figsize = (10 , 5 ))
560565
561566 axes = _check_plot_works (df .height .hist , by = df .category , layout = (4 , 1 ))
562- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ), figsize = (10 , 5 ))
567+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ), figsize = (10 , 5 ))
563568
564569 axes = _check_plot_works (df .height .hist , by = df .classroom , layout = (2 , 2 ))
565570 self ._check_axes_shape (axes , axes_num = 3 , layout = (2 , 2 ), figsize = (10 , 5 ))
@@ -757,9 +762,9 @@ def test_plot(self):
757762 df = self .tdf
758763 _check_plot_works (df .plot , grid = False )
759764 axes = _check_plot_works (df .plot , subplots = True )
760- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ))
765+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ))
761766 _check_plot_works (df .plot , subplots = True , use_index = False )
762- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ))
767+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ))
763768
764769 df = DataFrame ({'x' : [1 , 2 ], 'y' : [3 , 4 ]})
765770 with tm .assertRaises (TypeError ):
@@ -774,7 +779,7 @@ def test_plot(self):
774779 _check_plot_works (df .plot , ylim = (- 100 , 100 ), xlim = (- 100 , 100 ))
775780
776781 axes = _check_plot_works (df .plot , subplots = True , title = 'blah' )
777- self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , ))
782+ self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , 1 ))
778783
779784 _check_plot_works (df .plot , title = 'blah' )
780785
@@ -804,7 +809,7 @@ def test_plot(self):
804809 # Test with single column
805810 df = DataFrame ({'x' : np .random .rand (10 )})
806811 axes = _check_plot_works (df .plot , kind = 'bar' , subplots = True )
807- self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , ))
812+ self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , 1 ))
808813
809814 def test_nonnumeric_exclude (self ):
810815 df = DataFrame ({'A' : ["x" , "y" , "z" ], 'B' : [1 , 2 , 3 ]})
@@ -846,7 +851,7 @@ def test_plot_xy(self):
846851 # figsize and title
847852 ax = df .plot (x = 1 , y = 2 , title = 'Test' , figsize = (16 , 8 ))
848853 self ._check_text_labels (ax .title , 'Test' )
849- self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , ), figsize = (16. , 8. ))
854+ self ._check_axes_shape (ax , axes_num = 1 , layout = (1 , 1 ), figsize = (16. , 8. ))
850855
851856 # columns.inferred_type == 'mixed'
852857 # TODO add MultiIndex test
@@ -913,7 +918,7 @@ def test_subplots(self):
913918
914919 for kind in ['bar' , 'barh' , 'line' , 'area' ]:
915920 axes = df .plot (kind = kind , subplots = True , sharex = True , legend = True )
916- self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , ))
921+ self ._check_axes_shape (axes , axes_num = 3 , layout = (3 , 1 ))
917922
918923 for ax , column in zip (axes , df .columns ):
919924 self ._check_legend_labels (ax , labels = [com .pprint_thing (column )])
@@ -1081,7 +1086,7 @@ def test_bar_linewidth(self):
10811086
10821087 # subplots
10831088 axes = df .plot (kind = 'bar' , linewidth = 2 , subplots = True )
1084- self ._check_axes_shape (axes , axes_num = 5 , layout = (5 , ))
1089+ self ._check_axes_shape (axes , axes_num = 5 , layout = (5 , 1 ))
10851090 for ax in axes :
10861091 for r in ax .patches :
10871092 self .assertEqual (r .get_linewidth (), 2 )
@@ -1179,7 +1184,7 @@ def test_plot_scatter(self):
11791184
11801185 # GH 6951
11811186 axes = df .plot (x = 'x' , y = 'y' , kind = 'scatter' , subplots = True )
1182- self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , ))
1187+ self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , 1 ))
11831188
11841189 @slow
11851190 def test_plot_bar (self ):
@@ -1486,7 +1491,7 @@ def test_kde(self):
14861491 self ._check_legend_labels (ax , labels = expected )
14871492
14881493 axes = _check_plot_works (df .plot , kind = 'kde' , subplots = True )
1489- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ))
1494+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ))
14901495
14911496 axes = df .plot (kind = 'kde' , logy = True , subplots = True )
14921497 self ._check_ax_scales (axes , yaxis = 'log' )
@@ -1949,8 +1954,7 @@ def test_hexbin_basic(self):
19491954 # hexbin should have 2 axes in the figure, 1 for plotting and another is colorbar
19501955 self .assertEqual (len (axes [0 ].figure .axes ), 2 )
19511956 # return value is single axes
1952- self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , ))
1953-
1957+ self ._check_axes_shape (axes , axes_num = 1 , layout = (1 , 1 ))
19541958
19551959 @slow
19561960 def test_hexbin_with_c (self ):
@@ -2193,31 +2197,31 @@ class TestDataFrameGroupByPlots(TestPlotBase):
21932197 def test_boxplot (self ):
21942198 grouped = self .hist_df .groupby (by = 'gender' )
21952199 box = _check_plot_works (grouped .boxplot , return_type = 'dict' )
2196- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 )
2200+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 , layout = ( 1 , 2 ) )
21972201
21982202 box = _check_plot_works (grouped .boxplot , subplots = False ,
21992203 return_type = 'dict' )
2200- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 )
2204+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 , layout = ( 1 , 2 ) )
22012205
22022206 tuples = lzip (string .ascii_letters [:10 ], range (10 ))
22032207 df = DataFrame (np .random .rand (10 , 3 ),
22042208 index = MultiIndex .from_tuples (tuples ))
22052209
22062210 grouped = df .groupby (level = 1 )
22072211 box = _check_plot_works (grouped .boxplot , return_type = 'dict' )
2208- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 )
2212+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 , layout = ( 4 , 3 ) )
22092213
22102214 box = _check_plot_works (grouped .boxplot , subplots = False ,
22112215 return_type = 'dict' )
2212- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 )
2216+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 10 , layout = ( 4 , 3 ) )
22132217
22142218 grouped = df .unstack (level = 1 ).groupby (level = 0 , axis = 1 )
22152219 box = _check_plot_works (grouped .boxplot , return_type = 'dict' )
2216- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2220+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
22172221
22182222 box = _check_plot_works (grouped .boxplot , subplots = False ,
22192223 return_type = 'dict' )
2220- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2224+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
22212225
22222226 def test_series_plot_color_kwargs (self ):
22232227 # GH1890
@@ -2327,35 +2331,35 @@ def test_grouped_box_layout(self):
23272331
23282332 box = _check_plot_works (df .groupby ('gender' ).boxplot , column = 'height' ,
23292333 return_type = 'dict' )
2330- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 )
2334+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 2 , layout = ( 1 , 2 ) )
23312335
23322336 box = _check_plot_works (df .groupby ('category' ).boxplot , column = 'height' ,
23332337 return_type = 'dict' )
2334- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 )
2338+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 , layout = ( 2 , 2 ) )
23352339
23362340 # GH 6769
23372341 box = _check_plot_works (df .groupby ('classroom' ).boxplot ,
23382342 column = 'height' , return_type = 'dict' )
2339- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2343+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
23402344
23412345 box = df .boxplot (column = ['height' , 'weight' , 'category' ], by = 'gender' )
2342- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2346+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
23432347
23442348 box = df .groupby ('classroom' ).boxplot (
23452349 column = ['height' , 'weight' , 'category' ], return_type = 'dict' )
2346- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2350+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 2 , 2 ) )
23472351
23482352 box = _check_plot_works (df .groupby ('category' ).boxplot , column = 'height' ,
23492353 layout = (3 , 2 ), return_type = 'dict' )
2350- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 )
2354+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 4 , layout = ( 3 , 2 ) )
23512355
23522356 box = df .boxplot (column = ['height' , 'weight' , 'category' ], by = 'gender' , layout = (4 , 1 ))
2353- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2357+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 4 , 1 ) )
23542358
23552359 box = df .groupby ('classroom' ).boxplot (
23562360 column = ['height' , 'weight' , 'category' ], layout = (1 , 4 ),
23572361 return_type = 'dict' )
2358- self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 )
2362+ self ._check_axes_shape (self .plt .gcf ().axes , axes_num = 3 , layout = ( 1 , 4 ) )
23592363
23602364 @slow
23612365 def test_grouped_hist_layout (self ):
@@ -2367,10 +2371,10 @@ def test_grouped_hist_layout(self):
23672371 layout = (1 , 3 ))
23682372
23692373 axes = _check_plot_works (df .hist , column = 'height' , by = df .gender , layout = (2 , 1 ))
2370- self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , ), figsize = (10 , 5 ))
2374+ self ._check_axes_shape (axes , axes_num = 2 , layout = (2 , 1 ), figsize = (10 , 5 ))
23712375
23722376 axes = _check_plot_works (df .hist , column = 'height' , by = df .category , layout = (4 , 1 ))
2373- self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , ), figsize = (10 , 5 ))
2377+ self ._check_axes_shape (axes , axes_num = 4 , layout = (4 , 1 ), figsize = (10 , 5 ))
23742378
23752379 axes = _check_plot_works (df .hist , column = 'height' , by = df .category ,
23762380 layout = (4 , 2 ), figsize = (12 , 8 ))
0 commit comments