From c968927f9716d9309fc251e21acb8f314d050805 Mon Sep 17 00:00:00 2001 From: "alex.hang.zhao@gmail.com" Date: Wed, 15 Mar 2023 16:46:19 +0800 Subject: [PATCH 1/2] Update plot_util.py Adjust the plot_results function so that both the size and position of the subplot can be adjusted according to the number of legends. --- RLA/easy_plot/plot_util.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/RLA/easy_plot/plot_util.py b/RLA/easy_plot/plot_util.py index fbfd475..e844756 100644 --- a/RLA/easy_plot/plot_util.py +++ b/RLA/easy_plot/plot_util.py @@ -411,6 +411,11 @@ def plot_results( nrows = largest_divisor ncols = int(math.ceil(N / nrows)) figsize = figsize or (7 * ncols, 6 * nrows) + + legend_height = 0.12 + subfigsize = (1 / ncols, figsize[1] / (figsize[1] + len(allresults) * legend_height) / nrows) + figsize = (figsize[0], figsize[1] + len(allresults) * legend_height) + # if legend_outside: # figsize = list(figsize) # figsize[0] += 4 @@ -614,8 +619,8 @@ def allequal(qs): loc=2 if legend_outside else None, bbox_to_anchor=(1,1) if legend_outside else None, fontsize=15 if pretty else None) - if legend_outside: - f.subplots_adjust(bottom=0.12, left=0.12) + # if legend_outside: + # f.subplots_adjust(bottom=0.12, left=0.12) else: lgd = f.legend( legend_lines, @@ -626,9 +631,9 @@ def allequal(qs): ncol=legend_ncol, fancybox=True, borderaxespad=0.0) - if legend_outside: - leg_rows = np.ceil(len(legend_lines) / legend_ncol) - f.subplots_adjust(bottom=(0.2 + 0.05 * (leg_rows - 1)) / nrows) + # if legend_outside: + # leg_rows = np.ceil(len(legend_lines) / legend_ncol) + # f.subplots_adjust(bottom=(0.2 + 0.05 * (leg_rows - 1)) / nrows) else: lgd = None @@ -665,6 +670,21 @@ def allequal(qs): plt.ylim(ylim) if xlim is not None: plt.xlim(xlim) + + if legend_outside: + if len(sk2r.keys()) == 1: + plt.gcf().subplots_adjust(bottom=0.12, left=0.12) + else: + for isplit in range(ncols * nrows): + idx_row = isplit // ncols + idx_col = isplit % ncols + col_scale = 0.8 + row_scale = 0.8 + ax = axarr[idx_row][idx_col] + ax.set_position([idx_col * subfigsize[0] + (1 - col_scale) / (ncols * 2), + 1 - (idx_row + 1) * subfigsize[1] + (1 - row_scale) / (nrows * 2), + col_scale * subfigsize[0], row_scale * subfigsize[1]]) + # plt.gcf().subplots_adjust(bottom=0.12, left=0.12) # if pretty: # from matplotlib.ticker import ScalarFormatter # xfmt = ScalarFormatter(useMathText=True) From b1b27bb1d552175bc32b26ca3e3360d7dd11cd43 Mon Sep 17 00:00:00 2001 From: "alex.hang.zhao@gmail.com" Date: Tue, 21 Mar 2023 10:09:54 +0800 Subject: [PATCH 2/2] fix the bug that legend covers figure --- RLA/easy_plot/plot_util.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/RLA/easy_plot/plot_util.py b/RLA/easy_plot/plot_util.py index e844756..abfc8e1 100644 --- a/RLA/easy_plot/plot_util.py +++ b/RLA/easy_plot/plot_util.py @@ -412,9 +412,10 @@ def plot_results( ncols = int(math.ceil(N / nrows)) figsize = figsize or (7 * ncols, 6 * nrows) - legend_height = 0.12 - subfigsize = (1 / ncols, figsize[1] / (figsize[1] + len(allresults) * legend_height) / nrows) - figsize = (figsize[0], figsize[1] + len(allresults) * legend_height) + legend_height = 0.11 + figpercentage = (1, figsize[1] / (figsize[1] + len(allresults) * legend_height)) + subfigsize = (figpercentage[0] / ncols, figpercentage[1] / nrows) + figsize = (figsize[0], figsize[1] + len(allresults) * legend_height) # if legend_outside: # figsize = list(figsize) @@ -678,11 +679,11 @@ def allequal(qs): for isplit in range(ncols * nrows): idx_row = isplit // ncols idx_col = isplit % ncols - col_scale = 0.8 - row_scale = 0.8 + col_scale = 0.78 + row_scale = 0.78 ax = axarr[idx_row][idx_col] ax.set_position([idx_col * subfigsize[0] + (1 - col_scale) / (ncols * 2), - 1 - (idx_row + 1) * subfigsize[1] + (1 - row_scale) / (nrows * 2), + 1 - (idx_row + 1) * subfigsize[1] + figpercentage[1] * (1 - row_scale) / (nrows * 2), col_scale * subfigsize[0], row_scale * subfigsize[1]]) # plt.gcf().subplots_adjust(bottom=0.12, left=0.12) # if pretty: