@@ -774,7 +774,12 @@ class MPLPlot(object):
774774 data :
775775
776776 """
777- _kind = 'base'
777+
778+ @property
779+ def _kind (self ):
780+ """Specify kind str. Must be overridden in child class"""
781+ raise NotImplementedError
782+
778783 _layout_type = 'vertical'
779784 _default_rot = 0
780785 orientation = None
@@ -938,7 +943,10 @@ def generate(self):
938943 self ._make_plot ()
939944 self ._add_table ()
940945 self ._make_legend ()
941- self ._post_plot_logic ()
946+
947+ for ax in self .axes :
948+ self ._post_plot_logic_common (ax , self .data )
949+ self ._post_plot_logic (ax , self .data )
942950 self ._adorn_subplots ()
943951
944952 def _args_adjust (self ):
@@ -1055,12 +1063,34 @@ def _add_table(self):
10551063 ax = self ._get_ax (0 )
10561064 table (ax , data )
10571065
1058- def _post_plot_logic (self ):
1066+ def _post_plot_logic_common (self , ax , data ):
1067+ """Common post process for each axes"""
1068+ labels = [com .pprint_thing (key ) for key in data .index ]
1069+ labels = dict (zip (range (len (data .index )), labels ))
1070+
1071+ if self .orientation == 'vertical' or self .orientation is None :
1072+ if self ._need_to_set_index :
1073+ xticklabels = [labels .get (x , '' ) for x in ax .get_xticks ()]
1074+ ax .set_xticklabels (xticklabels )
1075+ self ._apply_axis_properties (ax .xaxis , rot = self .rot ,
1076+ fontsize = self .fontsize )
1077+ self ._apply_axis_properties (ax .yaxis , fontsize = self .fontsize )
1078+ elif self .orientation == 'horizontal' :
1079+ if self ._need_to_set_index :
1080+ yticklabels = [labels .get (y , '' ) for y in ax .get_yticks ()]
1081+ ax .set_yticklabels (yticklabels )
1082+ self ._apply_axis_properties (ax .yaxis , rot = self .rot ,
1083+ fontsize = self .fontsize )
1084+ self ._apply_axis_properties (ax .xaxis , fontsize = self .fontsize )
1085+ else : # pragma no cover
1086+ raise ValueError
1087+
1088+ def _post_plot_logic (self , ax , data ):
1089+ """Post process for each axes. Overridden in child classes"""
10591090 pass
10601091
10611092 def _adorn_subplots (self ):
1062- to_adorn = self .axes
1063-
1093+ """Common post process unrelated to data"""
10641094 if len (self .axes ) > 0 :
10651095 all_axes = self ._get_axes ()
10661096 nrows , ncols = self ._get_axes_layout ()
@@ -1069,7 +1099,7 @@ def _adorn_subplots(self):
10691099 ncols = ncols , sharex = self .sharex ,
10701100 sharey = self .sharey )
10711101
1072- for ax in to_adorn :
1102+ for ax in self . axes :
10731103 if self .yticks is not None :
10741104 ax .set_yticks (self .yticks )
10751105
@@ -1090,25 +1120,6 @@ def _adorn_subplots(self):
10901120 else :
10911121 self .axes [0 ].set_title (self .title )
10921122
1093- labels = [com .pprint_thing (key ) for key in self .data .index ]
1094- labels = dict (zip (range (len (self .data .index )), labels ))
1095-
1096- for ax in self .axes :
1097- if self .orientation == 'vertical' or self .orientation is None :
1098- if self ._need_to_set_index :
1099- xticklabels = [labels .get (x , '' ) for x in ax .get_xticks ()]
1100- ax .set_xticklabels (xticklabels )
1101- self ._apply_axis_properties (ax .xaxis , rot = self .rot ,
1102- fontsize = self .fontsize )
1103- self ._apply_axis_properties (ax .yaxis , fontsize = self .fontsize )
1104- elif self .orientation == 'horizontal' :
1105- if self ._need_to_set_index :
1106- yticklabels = [labels .get (y , '' ) for y in ax .get_yticks ()]
1107- ax .set_yticklabels (yticklabels )
1108- self ._apply_axis_properties (ax .yaxis , rot = self .rot ,
1109- fontsize = self .fontsize )
1110- self ._apply_axis_properties (ax .xaxis , fontsize = self .fontsize )
1111-
11121123 def _apply_axis_properties (self , axis , rot = None , fontsize = None ):
11131124 labels = axis .get_majorticklabels () + axis .get_minorticklabels ()
11141125 for label in labels :
@@ -1419,34 +1430,48 @@ def _get_axes_layout(self):
14191430 y_set .add (points [0 ][1 ])
14201431 return (len (y_set ), len (x_set ))
14211432
1422- class ScatterPlot (MPLPlot ):
1423- _kind = 'scatter'
1433+
1434+ class PlanePlot (MPLPlot ):
1435+ """
1436+ Abstract class for plotting on plane, currently scatter and hexbin.
1437+ """
1438+
14241439 _layout_type = 'single'
14251440
1426- def __init__ (self , data , x , y , c = None , ** kwargs ):
1441+ def __init__ (self , data , x , y , ** kwargs ):
14271442 MPLPlot .__init__ (self , data , ** kwargs )
14281443 if x is None or y is None :
1429- raise ValueError ( 'scatter requires and x and y column' )
1444+ raise ValueError (self . _kind + ' requires and x and y column' )
14301445 if com .is_integer (x ) and not self .data .columns .holds_integer ():
14311446 x = self .data .columns [x ]
14321447 if com .is_integer (y ) and not self .data .columns .holds_integer ():
14331448 y = self .data .columns [y ]
1434- if com .is_integer (c ) and not self .data .columns .holds_integer ():
1435- c = self .data .columns [c ]
14361449 self .x = x
14371450 self .y = y
1438- self .c = c
14391451
14401452 @property
14411453 def nseries (self ):
14421454 return 1
14431455
1456+ def _post_plot_logic (self , ax , data ):
1457+ x , y = self .x , self .y
1458+ ax .set_ylabel (com .pprint_thing (y ))
1459+ ax .set_xlabel (com .pprint_thing (x ))
1460+
1461+
1462+ class ScatterPlot (PlanePlot ):
1463+ _kind = 'scatter'
1464+
1465+ def __init__ (self , data , x , y , c = None , ** kwargs ):
1466+ super (ScatterPlot , self ).__init__ (data , x , y , ** kwargs )
1467+ if com .is_integer (c ) and not self .data .columns .holds_integer ():
1468+ c = self .data .columns [c ]
1469+ self .c = c
1470+
14441471 def _make_plot (self ):
14451472 import matplotlib as mpl
14461473 mpl_ge_1_3_1 = str (mpl .__version__ ) >= LooseVersion ('1.3.1' )
14471474
1448- import matplotlib .pyplot as plt
1449-
14501475 x , y , c , data = self .x , self .y , self .c , self .data
14511476 ax = self .axes [0 ]
14521477
@@ -1457,7 +1482,7 @@ def _make_plot(self):
14571482
14581483 # pandas uses colormap, matplotlib uses cmap.
14591484 cmap = self .colormap or 'Greys'
1460- cmap = plt .cm .get_cmap (cmap )
1485+ cmap = self . plt .cm .get_cmap (cmap )
14611486
14621487 if c is None :
14631488 c_values = self .plt .rcParams ['patch.facecolor' ]
@@ -1491,46 +1516,22 @@ def _make_plot(self):
14911516 err_kwds ['ecolor' ] = scatter .get_facecolor ()[0 ]
14921517 ax .errorbar (data [x ].values , data [y ].values , linestyle = 'none' , ** err_kwds )
14931518
1494- def _post_plot_logic (self ):
1495- ax = self .axes [0 ]
1496- x , y = self .x , self .y
1497- ax .set_ylabel (com .pprint_thing (y ))
1498- ax .set_xlabel (com .pprint_thing (x ))
1499-
15001519
1501- class HexBinPlot (MPLPlot ):
1520+ class HexBinPlot (PlanePlot ):
15021521 _kind = 'hexbin'
1503- _layout_type = 'single'
15041522
15051523 def __init__ (self , data , x , y , C = None , ** kwargs ):
1506- MPLPlot .__init__ (self , data , ** kwargs )
1507-
1508- if x is None or y is None :
1509- raise ValueError ('hexbin requires and x and y column' )
1510- if com .is_integer (x ) and not self .data .columns .holds_integer ():
1511- x = self .data .columns [x ]
1512- if com .is_integer (y ) and not self .data .columns .holds_integer ():
1513- y = self .data .columns [y ]
1514-
1524+ super (HexBinPlot , self ).__init__ (data , x , y , ** kwargs )
15151525 if com .is_integer (C ) and not self .data .columns .holds_integer ():
15161526 C = self .data .columns [C ]
1517-
1518- self .x = x
1519- self .y = y
15201527 self .C = C
15211528
1522- @property
1523- def nseries (self ):
1524- return 1
1525-
15261529 def _make_plot (self ):
1527- import matplotlib .pyplot as plt
1528-
15291530 x , y , data , C = self .x , self .y , self .data , self .C
15301531 ax = self .axes [0 ]
15311532 # pandas uses colormap, matplotlib uses cmap.
15321533 cmap = self .colormap or 'BuGn'
1533- cmap = plt .cm .get_cmap (cmap )
1534+ cmap = self . plt .cm .get_cmap (cmap )
15341535 cb = self .kwds .pop ('colorbar' , True )
15351536
15361537 if C is None :
@@ -1547,12 +1548,6 @@ def _make_plot(self):
15471548 def _make_legend (self ):
15481549 pass
15491550
1550- def _post_plot_logic (self ):
1551- ax = self .axes [0 ]
1552- x , y = self .x , self .y
1553- ax .set_ylabel (com .pprint_thing (y ))
1554- ax .set_xlabel (com .pprint_thing (x ))
1555-
15561551
15571552class LinePlot (MPLPlot ):
15581553 _kind = 'line'
@@ -1685,26 +1680,23 @@ def _update_stacker(cls, ax, stacking_id, values):
16851680 elif (values <= 0 ).all ():
16861681 ax ._stacker_neg_prior [stacking_id ] += values
16871682
1688- def _post_plot_logic (self ):
1689- df = self .data
1690-
1683+ def _post_plot_logic (self , ax , data ):
16911684 condition = (not self ._use_dynamic_x ()
1692- and df .index .is_all_dates
1685+ and data .index .is_all_dates
16931686 and not self .subplots
16941687 or (self .subplots and self .sharex ))
16951688
16961689 index_name = self ._get_index_name ()
16971690
1698- for ax in self .axes :
1699- if condition :
1700- # irregular TS rotated 30 deg. by default
1701- # probably a better place to check / set this.
1702- if not self ._rot_set :
1703- self .rot = 30
1704- format_date_labels (ax , rot = self .rot )
1691+ if condition :
1692+ # irregular TS rotated 30 deg. by default
1693+ # probably a better place to check / set this.
1694+ if not self ._rot_set :
1695+ self .rot = 30
1696+ format_date_labels (ax , rot = self .rot )
17051697
1706- if index_name is not None and self .use_index :
1707- ax .set_xlabel (index_name )
1698+ if index_name is not None and self .use_index :
1699+ ax .set_xlabel (index_name )
17081700
17091701
17101702class AreaPlot (LinePlot ):
@@ -1758,16 +1750,14 @@ def _add_legend_handle(self, handle, label, index=None):
17581750 handle = Rectangle ((0 , 0 ), 1 , 1 , fc = handle .get_color (), alpha = alpha )
17591751 LinePlot ._add_legend_handle (self , handle , label , index = index )
17601752
1761- def _post_plot_logic (self ):
1762- LinePlot ._post_plot_logic (self )
1753+ def _post_plot_logic (self , ax , data ):
1754+ LinePlot ._post_plot_logic (self , ax , data )
17631755
17641756 if self .ylim is None :
1765- if (self .data >= 0 ).all ().all ():
1766- for ax in self .axes :
1767- ax .set_ylim (0 , None )
1768- elif (self .data <= 0 ).all ().all ():
1769- for ax in self .axes :
1770- ax .set_ylim (None , 0 )
1757+ if (data >= 0 ).all ().all ():
1758+ ax .set_ylim (0 , None )
1759+ elif (data <= 0 ).all ().all ():
1760+ ax .set_ylim (None , 0 )
17711761
17721762
17731763class BarPlot (MPLPlot ):
@@ -1865,19 +1855,17 @@ def _make_plot(self):
18651855 start = start , label = label , log = self .log , ** kwds )
18661856 self ._add_legend_handle (rect , label , index = i )
18671857
1868- def _post_plot_logic (self ):
1869- for ax in self .axes :
1870- if self .use_index :
1871- str_index = [com .pprint_thing (key ) for key in self .data .index ]
1872- else :
1873- str_index = [com .pprint_thing (key ) for key in
1874- range (self .data .shape [0 ])]
1875- name = self ._get_index_name ()
1858+ def _post_plot_logic (self , ax , data ):
1859+ if self .use_index :
1860+ str_index = [com .pprint_thing (key ) for key in data .index ]
1861+ else :
1862+ str_index = [com .pprint_thing (key ) for key in range (data .shape [0 ])]
1863+ name = self ._get_index_name ()
18761864
1877- s_edge = self .ax_pos [0 ] - 0.25 + self .lim_offset
1878- e_edge = self .ax_pos [- 1 ] + 0.25 + self .bar_width + self .lim_offset
1865+ s_edge = self .ax_pos [0 ] - 0.25 + self .lim_offset
1866+ e_edge = self .ax_pos [- 1 ] + 0.25 + self .bar_width + self .lim_offset
18791867
1880- self ._decorate_ticks (ax , name , str_index , s_edge , e_edge )
1868+ self ._decorate_ticks (ax , name , str_index , s_edge , e_edge )
18811869
18821870 def _decorate_ticks (self , ax , name , ticklabels , start_edge , end_edge ):
18831871 ax .set_xlim ((start_edge , end_edge ))
@@ -1975,13 +1963,11 @@ def _make_plot_keywords(self, kwds, y):
19751963 kwds ['bins' ] = self .bins
19761964 return kwds
19771965
1978- def _post_plot_logic (self ):
1966+ def _post_plot_logic (self , ax , data ):
19791967 if self .orientation == 'horizontal' :
1980- for ax in self .axes :
1981- ax .set_xlabel ('Frequency' )
1968+ ax .set_xlabel ('Frequency' )
19821969 else :
1983- for ax in self .axes :
1984- ax .set_ylabel ('Frequency' )
1970+ ax .set_ylabel ('Frequency' )
19851971
19861972 @property
19871973 def orientation (self ):
@@ -2038,9 +2024,8 @@ def _make_plot_keywords(self, kwds, y):
20382024 kwds ['ind' ] = self ._get_ind (y )
20392025 return kwds
20402026
2041- def _post_plot_logic (self ):
2042- for ax in self .axes :
2043- ax .set_ylabel ('Density' )
2027+ def _post_plot_logic (self , ax , data ):
2028+ ax .set_ylabel ('Density' )
20442029
20452030
20462031class PiePlot (MPLPlot ):
@@ -2242,7 +2227,7 @@ def _set_ticklabels(self, ax, labels):
22422227 def _make_legend (self ):
22432228 pass
22442229
2245- def _post_plot_logic (self ):
2230+ def _post_plot_logic (self , ax , data ):
22462231 pass
22472232
22482233 @property
0 commit comments