diff --git a/RLA/easy_plot/plot_util.py b/RLA/easy_plot/plot_util.py index fbfd475..e844756 100644 --- a/RLA/easy_plot/plot_util.py +++ b/RLA/easy_plot/plot_util.py @@ -411,6 +411,11 @@ def plot_results( nrows = largest_divisor ncols = int(math.ceil(N / nrows)) figsize = figsize or (7 * ncols, 6 * nrows) + + legend_height = 0.12 + subfigsize = (1 / ncols, figsize[1] / (figsize[1] + len(allresults) * legend_height) / nrows) + figsize = (figsize[0], figsize[1] + len(allresults) * legend_height) + # if legend_outside: # figsize = list(figsize) # figsize[0] += 4 @@ -614,8 +619,8 @@ def allequal(qs): loc=2 if legend_outside else None, bbox_to_anchor=(1,1) if legend_outside else None, fontsize=15 if pretty else None) - if legend_outside: - f.subplots_adjust(bottom=0.12, left=0.12) + # if legend_outside: + # f.subplots_adjust(bottom=0.12, left=0.12) else: lgd = f.legend( legend_lines, @@ -626,9 +631,9 @@ def allequal(qs): ncol=legend_ncol, fancybox=True, borderaxespad=0.0) - if legend_outside: - leg_rows = np.ceil(len(legend_lines) / legend_ncol) - f.subplots_adjust(bottom=(0.2 + 0.05 * (leg_rows - 1)) / nrows) + # if legend_outside: + # leg_rows = np.ceil(len(legend_lines) / legend_ncol) + # f.subplots_adjust(bottom=(0.2 + 0.05 * (leg_rows - 1)) / nrows) else: lgd = None @@ -665,6 +670,21 @@ def allequal(qs): plt.ylim(ylim) if xlim is not None: plt.xlim(xlim) + + if legend_outside: + if len(sk2r.keys()) == 1: + plt.gcf().subplots_adjust(bottom=0.12, left=0.12) + else: + for isplit in range(ncols * nrows): + idx_row = isplit // ncols + idx_col = isplit % ncols + col_scale = 0.8 + row_scale = 0.8 + ax = axarr[idx_row][idx_col] + ax.set_position([idx_col * subfigsize[0] + (1 - col_scale) / (ncols * 2), + 1 - (idx_row + 1) * subfigsize[1] + (1 - row_scale) / (nrows * 2), + col_scale * subfigsize[0], row_scale * subfigsize[1]]) + # plt.gcf().subplots_adjust(bottom=0.12, left=0.12) # if pretty: # from matplotlib.ticker import ScalarFormatter # xfmt = ScalarFormatter(useMathText=True)