Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ We manage the items in the database via toolkits in rla_scripts. Currently, the
3. Send to remote [TODO]
4. Download from remote [TODO]

We can use the above tools after copying the rla_scripts to our research project and modifying the DATA_ROOT in config.py to locate the root of the RLA database. We give several user cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_scripts.py
We can use the above tools after copying the rla_scripts to our research project and modifying the DATA_ROOT in config.py to locate the root of the RLA database. We give several user cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_scripts.py.
We will also introduce our practices of using RLA to manage our projects .

## Practices of managing the experiments via RLA scripts


## Distributed training & centralized logs
Expand Down
11 changes: 5 additions & 6 deletions RLA/easy_plot/plot_func_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True):
def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
use_buf=False, verbose=True,
x_bound: Optional[int]=None,
xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[str] = None,
xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[Union[str, list]] = None,
scale_dict: Optional[dict] = None, regs2legends: Optional[list] = None,
key_to_legend_fn: Optional[Callable] = default_key_to_legend,
split_by_metrics=True,
Expand Down Expand Up @@ -61,7 +61,7 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
:param xlabel: set the label of the y axes.
:type xlabel: Optional[str]
:param ylabel: set the label of the y axes.
:type ylabel: Optional[str]
:type ylabel: Optional[str,list]
:param scale_dict: a function dict, to map the value of the metrics through customize functions.
e.g.,set metrics = ['return'], scale_dict = {'return': lambda x: np.log(x)}, then we will plot a log-scale return.
:type scale_dict: Optional[dict]
Expand Down Expand Up @@ -112,15 +112,15 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
ylabel = metrics

if regs2legends is not None:
assert len(regs2legends) == len(regs) and len(metrics) == 1, \
assert len(regs2legends) == len(regs), \
"In manual legend-key mode, the number of keys should be one-to-one matched with regs"
# if len(regs2legends) == len(regs):
group_fn = lambda r: split_by_reg(taskpath=r, reg_group=reg_group, y_names=y_names)
else:
group_fn = lambda r: picture_split(taskpath=r, split_keys=split_keys, y_names=y_names,
key_to_legend_fn=lambda parse_dict, split_keys, y_name:
key_to_legend_fn(parse_dict, split_keys, y_name, not split_by_metrics))
_, _, lgd, texts, g2lf, score_results = \
fig, _, lgd, texts, g2lf, score_results = \
plot_util.plot_results(results, xy_fn= lambda r, y_names: csv_to_xy(r, DEFAULT_X_NAME, y_names, final_scale_dict),
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, metrics=metrics,
split_by_metrics=split_by_metrics, regs2legends=regs2legends, *args, **kwargs)
Expand All @@ -129,7 +129,6 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
import os
file_name = osp.join(data_root, OTHER_RESULTS, 'easy_plot', save_name)
os.makedirs(os.path.dirname(file_name), exist_ok=True)

if lgd is not None:
plt.savefig(file_name, bbox_extra_artists=tuple([lgd] + texts), bbox_inches='tight')
else:
Expand All @@ -146,7 +145,7 @@ def split_by_reg(taskpath, reg_group, y_names):
if taskpath.dirname == result.dirname:
assert task_split_key == "None", "one experiment should belong to only one reg_group"
task_split_key = str(i)
assert len(y_names) == 1
# assert len(y_names) == 1
return task_split_key, y_names

def split_by_task(taskpath, split_keys, y_names, key_to_legend_fn):
Expand Down
133 changes: 94 additions & 39 deletions RLA/easy_plot/plot_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def load_results(root_dir_or_dirs, names, x_bound, enable_progress=True, use_buf
'brown', 'lightblue', 'lime', 'lavender', 'turquoise',
'tan', 'salmon', 'darkred', 'darkblue', 'gold']
PRETTY_COLORS = ['orangered', 'royalblue', 'forestgreen', 'orange', 'deeppink', 'deepskyblue']
PRETTY_MARKERS = ['o', 'v', 'd', 'P', 'p', ]

def default_xy_fn(r, y_name):
x = np.cumsum(r.monitor.l)
Expand Down Expand Up @@ -315,6 +316,9 @@ def plot_results(
show_number=True,
skip_legend=False,
split_by_metrics=False,
use_marker=False,
markers=None,
legend_ncol=1,
rescale_idx=None):
'''
Plot multiple Results objects
Expand All @@ -333,7 +337,7 @@ def plot_results(
(if average_group is True). The default value is the same as default value for split_fn

average_group: bool - if True, will average the curves in the same group and plot the mean. Enables resampling
(if resample = 0, will use 512 steps)
(if resample = 0, will use all steps)

shaded_std: bool - if True (default), the shaded region corresponding to standard deviation of the group of curves will be
shown (only applicable if average_group = True)
Expand All @@ -356,7 +360,14 @@ def plot_results(
smooth_step: float - when resampling (i.e. when resample > 0 or average_group is True), use this EMA decay parameter (in units of the new grid step).
See docstrings for decay_steps in symmetric_ema or one_sided_ema functions.

use_marker: bool - if True, will use marker to the legend.

markers: list - customize your marker map through a list. PRETTY_MARKER by default.


colors: list - customize your corlor map through a list. COLORS/PRETTY_COLORS by default.

legend_ncol: int - The number of columns that the legend has.
'''
score_results = {}
if vary_len_plot:
Expand All @@ -366,9 +377,11 @@ def plot_results(
colors = PRETTY_COLORS
else:
colors = COLORS
if markers is None:
markers = PRETTY_MARKERS

if pretty:
assert not split_by_metrics or len(metrics) == 1, "pretty mode cannot support the multiply metric plotting. Please use only one metric for plotting."
# if pretty:
# assert not split_by_metrics or len(metrics) == 1, "pretty mode cannot support the multiply metric plotting. Please use only one metric for plotting."
split_by_metrics = split_by_metrics and len(metrics) != 1
if split_fn is None: split_fn = lambda _ : ''
if group_fn is None: group_fn = lambda _ : ''
Expand Down Expand Up @@ -414,18 +427,20 @@ def plot_results(
# if average_group:
# resample = resample or default_samples
lgd = None

texts = []
if split_by_metrics:
keys = metrics
else:
keys = sorted(sk2r.keys())
for (isplit, sk) in enumerate(keys):
for (kid, sk) in enumerate(keys):
g2l = {}
g2lf = {}
g2c = defaultdict(int)
sresults = sk2r[sk]
gresults = defaultdict(list)
idx_row = isplit // ncols
idx_col = isplit % ncols
idx_row = kid // ncols
idx_col = kid % ncols
ax = axarr[idx_row][idx_col]
for result in sresults:
result_groups, y_names = group_fn(result)
Expand Down Expand Up @@ -459,6 +474,7 @@ def plot_results(
if not any(xys):
continue
color = colors[groups.index(group) % len(colors)]
marker = markers[groups.index(group) % len(markers)]
origxs = [xy[0] for xy in xys]
maxlen = max(map(len, origxs))
minxlen = min(map(len, origxs))
Expand Down Expand Up @@ -494,18 +510,51 @@ def allequal(qs):
ymin = np.nanmin(ys, axis=0)
ymax = np.nanmax(ys, axis=0)
ystderr = ystd / np.sqrt(len(ys))
l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
if use_marker:
if resample:
marker_every = resample // 10
else:
marker_every = len(ymean) // 10
if marker_every == 0:
marker_every = 1
l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color, marker=marker, markevery=marker_every,
markersize=10)
else:
l, = axarr[idx_row][idx_col].plot(usex, ymean, color=color)
g2l[group] = l
if shaded_err:
g2lf[group + '-se'] = [ax.fill_between(usex, ymean - ystderr, ymean + ystderr, color=color, alpha=.2), ymean, ystderr]
if shaded_std:
g2lf[group + '-ss'] = [ax.fill_between(usex, ymean - ystd, ymean + ystd, color=color, alpha=.2), ymean, ystd]
if shaded_range:
g2lf[group + '-sr'] = [ax.fill_between(usex, ymin, ymax, color=color, alpha=.1), ymin, ymax]

ax.set_title(sk)
if not pretty:
ax.set_title(sk)
if split_by_metrics:
ax.set_ylabel(sk)
if ylabel is None:
select_label = sk
elif type(ylabel) == str:
select_label = ylabel
elif type(ylabel) == list:
assert split_by_metrics, "list-wise ylabel can only support in split_by_metrics mode."
select_label = ylabel[kid]
else:
raise NotImplementedError
if pretty:
ax.set_ylabel(select_label, fontsize=18)
else:
ax.set_ylabel(select_label)
if pretty:
from matplotlib.ticker import ScalarFormatter
xfmt = ScalarFormatter(useMathText=True)
xfmt.set_powerlimits((-4, 4)) # Or whatever your limits are . . .
ax.yaxis.set_major_formatter(xfmt)
ax.yaxis.offsetText.set_fontsize(15)
ax.xaxis.set_major_formatter(xfmt)
ax.xaxis.offsetText.set_fontsize(15)
ax.yaxis.set_tick_params(labelsize=16)
ax.xaxis.set_tick_params(labelsize=16)

if log:
ax.set_yscale('log')

Expand All @@ -524,17 +573,12 @@ def allequal(qs):
legend_lines = legend_lines[sorted_index]
if regs2legends is not None:
legend_keys = np.array(regs2legends)
# if replace_legend_sort is not None:
# sorted_index = replace_legend_sort
# else:
# sorted_index = np.argsort(legend_keys)
# assert legend_keys.shape[0] == legend_lines.shape[0], \
# "The number of lines is not consistent with the keys"
# legend_keys = legend_keys[sorted_index]
# legend_lines = legend_lines[sorted_index]
if pretty:
for index, l in enumerate(legend_lines):
l.update(props={"color": colors[index % len(colors)]})
if use_marker:
l.update(props={"color": colors[index % len(colors)]})
else:
l.update(props={"color": colors[index % len(colors)]})
original_legend_keys = np.array(['%s' % (g) for g in g2l] if average_group else g2l.keys())
original_legend_keys = original_legend_keys[sorted_index]
if shaded_err:
Expand All @@ -554,6 +598,7 @@ def allequal(qs):
score_results[legend_keys[index]+'-range'] = [res[1][-1], res[2][-1]]

if bound_line is not None:
assert len(sk2r.keys()) == 1, "bound line can only support the single-metric mode."
for bl in bound_line:
y = np.ones(x.shape) * bl[0]
l, = ax.plot(x, y, bl[2], color=bl[1])
Expand All @@ -569,14 +614,21 @@ def allequal(qs):
loc=2 if legend_outside else None,
bbox_to_anchor=(1,1) if legend_outside else None,
fontsize=15 if pretty else None)
if legend_outside:
f.subplots_adjust(bottom=0.12, left=0.12)
else:
lgd = f.legend(
legend_lines,
legend_keys,
loc='lower center' if legend_outside else None,
bbox_to_anchor=(0.5, 0.0) if legend_outside else None,
fontsize=15 if pretty else None,
borderaxespad=0)
ncol=legend_ncol,
fancybox=True,
borderaxespad=0.0)
if legend_outside:
leg_rows = np.ceil(len(legend_lines) / legend_ncol)
f.subplots_adjust(bottom=(0.2 + 0.05 * (leg_rows - 1)) / nrows)
else:
lgd = None

Expand All @@ -586,7 +638,7 @@ def allequal(qs):
for ax in axarr[-1]:
plt.sca(ax)
if pretty:
plt.xlabel(xlabel, fontsize=20)
texts.append(plt.xlabel(xlabel, fontsize=20))
else:
plt.xlabel(xlabel)
# add ylabels, but only to left column
Expand All @@ -596,33 +648,36 @@ def allequal(qs):
# plt.ylabel(ylabel)
# plt.sca(f)
if pretty:
f.supylabel(ylabel, fontsize=20, horizontalalignment='center')
if not split_by_metrics or len(metrics) == 1:
texts.append(f.supylabel(ylabel, fontsize=20, horizontalalignment='center'))
else:
pass
else:
if not split_by_metrics or len(metrics) == 1:
f.supylabel(ylabel, horizontalalignment='center')
if title is not None:
plt.title(title)
if pretty:
texts.append(plt.title(title, fontsize=18))
else:
texts.append(plt.title(title))
plt.grid(True)
if ylim is not None:
plt.ylim(ylim)
if xlim is not None:
plt.xlim(xlim)
texts = []
if pretty:
from matplotlib.ticker import ScalarFormatter
xfmt = ScalarFormatter(useMathText=True)
xfmt.set_powerlimits((-4, 4)) # Or whatever your limits are . . .
plt.gca().yaxis.set_major_formatter(xfmt)
plt.gca().yaxis.offsetText.set_fontsize(15)
plt.gca().xaxis.set_major_formatter(xfmt)
plt.gca().xaxis.offsetText.set_fontsize(15)
# plt.xlabel(xlabel, fontsize=20)
# plt.ylabel(ylabel, fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.title(title, fontsize=18)
else:
plt.gcf().subplots_adjust(bottom=0.12, left=0.12)
# if pretty:
# from matplotlib.ticker import ScalarFormatter
# xfmt = ScalarFormatter(useMathText=True)
# xfmt.set_powerlimits((-4, 4)) # Or whatever your limits are . . .
# plt.gca().yaxis.set_major_formatter(xfmt)
# plt.gca().yaxis.offsetText.set_fontsize(15)
# plt.gca().xaxis.set_major_formatter(xfmt)
# plt.gca().xaxis.offsetText.set_fontsize(15)
# texts.extend(plt.xticks(fontsize=16)[1])
# texts.extend(plt.yticks(fontsize=16)[1])
# texts.append(plt.title(title, fontsize=18))
# else:
# plt.gcf().subplots_adjust(bottom=0.12, left=0.12)
return f, axarr, lgd, texts, g2lf, score_results

def regression_analysis(df):
Expand Down

This file was deleted.

Loading