Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ Result visualization:
We can view the results in tensorboard by: `tensorboard --logdir ${data_root}/log/${task_name}`.
For example, lanuch tensorboard by `tensorboard --logdir ./example/simplest_code/log/demo_task/2022/03`. We can see results:
![img.png](resource/demo-tb-res.png)
2. Easy_plot toolkit: The intermediate scalar variables are saved in a CSV file in `${data_root}/log/${task_name}/${index_name}/progress.csv`.
We develop high-level APIs to load the CSV files from multiple experiments and group the lines by custom keys. We give an example to use easy_plot toolkit in https://github.com/xionghuichen/RLAssistant/blob/main/example/plot_res.ipynb and more user cases in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_plot.py
2. Easy_plot toolkit:
**We recommend the users maintain their research projects via some jupyter notebooks. You can record your ideas, surmises, related empirical evidence, and benchmark results in a notebook together.** We develop high-level APIs to load the CSV files of the experiments (stored in `${data_root}/log/${task_name}/${index_name}/progress.csv`) and group the curves by custom keys. We give common user cases of the plotter in https://github.com/xionghuichen/RLAssistant/blob/main/test/test_plot.ipynb
The result will be something like this:
![img.png](resource/demo-easy-to-plot-res.png)
3. View data in "results" directory directly: other type of data are stored in `${data_root}/results/${task_name}/${index_name}`
Expand Down
1 change: 1 addition & 0 deletions RLA/easy_log/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def writekvs(self, kvs):
lines = self.file.readlines()
self.file = open(self.filename, 'w+t')
self.file.seek(0)

for (i, key) in enumerate(self.keys):
if i > 0:
self.file.write(',')
Expand Down
20 changes: 14 additions & 6 deletions RLA/easy_plot/plot_func_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@



def default_key_to_legend(parse_dict, split_keys, y_name):
def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True):
task_split_key = '.'.join(f'{k}={parse_dict[k]}' for k in split_keys)
return task_split_key + ' eval:' + y_name
if use_y_name:
return task_split_key + ' eval:' + y_name
else:
return task_split_key


def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
Expand All @@ -29,6 +32,7 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[str] = None,
scale_dict: Optional[dict] = None, regs2legends: Optional[list] = None,
key_to_legend_fn: Optional[Callable] = default_key_to_legend,
split_by_metrics=True,
save_name: Optional[str] = None, *args, **kwargs):
"""
A high-level matplotlib plotter.
Expand Down Expand Up @@ -68,8 +72,11 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
:param key_to_legend_fn: we give a default function to stringify the k-v pairs. you can customize your own function in key_to_legend_fn.
See default_key_to_legend for the detault way and test/test_plot/test_customize_legend_name_mode for details.
:type key_to_legend_fn: Optional[Callable] = default_key_to_legend
:param split_by_metrics: you can plot figure with multiple metrics together.
By default, we will split the curves with the metric and merge them into a group figure.
If you would like to print multiple metrics in single figure, please set the parameter to False.
:type split_by_metrics: Optional[bool]
:param args/kwargs: send other parameters to plot_util.plot_results

:return:
:rtype:
"""
Expand Down Expand Up @@ -112,10 +119,12 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
group_fn = lambda r: split_by_reg(taskpath=r, reg_group=reg_group, y_names=y_names)
else:
group_fn = lambda r: picture_split(taskpath=r, split_keys=split_keys, y_names=y_names,
key_to_legend_fn=key_to_legend_fn)
key_to_legend_fn=lambda parse_dict, split_keys, y_name:
key_to_legend_fn(parse_dict, split_keys, y_name, not split_by_metrics))
_, _, lgd, texts, g2lf, score_results = \
plot_util.plot_results(results, xy_fn= lambda r, y_names: csv_to_xy(r, DEFAULT_X_NAME, y_names, final_scale_dict),
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, regs2legends=regs2legends, *args, **kwargs)
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, metrics=metrics,
split_by_metrics=split_by_metrics, regs2legends=regs2legends, *args, **kwargs)
print("--- complete process ---")
if save_name is not None:
import os
Expand All @@ -141,7 +150,6 @@ def split_by_reg(taskpath, reg_group, y_names):
assert len(y_names) == 1
return task_split_key, y_names


def split_by_task(taskpath, split_keys, y_names, key_to_legend_fn):
pair_delimiter = '&'
kv_delimiter = '='
Expand Down
186 changes: 108 additions & 78 deletions RLA/easy_plot/plot_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,22 +289,22 @@ def default_split_fn(r):
def plot_results(
allresults, *,
xy_fn=default_xy_fn,
metrics=None,
split_fn=None, # default_split_fn,
group_fn=None, # default_split_fn,
average_group=False,
shaded_std=False,
shaded_err=True,
shaded_range=True,
figsize=None,
legend_outside=False,
legend_outside=True,
resample=0,
vary_len_plot=False,
smooth_step=1.0,
tiling='vertical',
tiling='symmetric',
xlabel=None,
ylabel=None,
title=None,
replace_legend_keys=None,
regs2legends=None,
pretty=False,
bound_line=None,
Expand All @@ -314,6 +314,7 @@ def plot_results(
xlim=None,
show_number=True,
skip_legend=False,
split_by_metrics=False,
rescale_idx=None):
'''
Plot multiple Results objects
Expand Down Expand Up @@ -364,12 +365,20 @@ def plot_results(
colors = PRETTY_COLORS
else:
colors = COLORS

if pretty:
assert not split_by_metrics or len(metrics) == 1, "pretty mode cannot support the multiply metric plotting. Please use only one metric for plotting."
split_by_metrics = split_by_metrics and len(metrics) != 1
if split_fn is None: split_fn = lambda _ : ''
if group_fn is None: group_fn = lambda _ : ''
sk2r = defaultdict(list) # splitkey2results
for result in allresults:
splitkey = split_fn(result)
sk2r[splitkey].append(result)
if split_by_metrics:
for y in metrics:
sk2r[y].extend(allresults)
else:
for result in allresults:
splitkey = split_fn(result)
sk2r[splitkey].append(result)
assert len(sk2r) > 0
assert isinstance(resample, int), "0: don't resample. <integer>: that many samples"
if tiling == 'vertical' or tiling is None:
Expand All @@ -392,7 +401,7 @@ def plot_results(
# figsize = list(figsize)
# figsize[0] += 4
# figsize = tuple(figsize)
f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize)
f, axarr = plt.subplots(nrows, ncols, sharex=False, squeeze=False, figsize=figsize, dpi=90 * ncols)
groups = []
for results in allresults:
groups.extend(group_fn(results)[0])
Expand All @@ -415,6 +424,8 @@ def plot_results(
ax = axarr[idx_row][idx_col]
for result in sresults:
result_groups, y_names = group_fn(result)
if split_by_metrics:
y_names = [sk]
for group, y_name in zip(result_groups, y_names):
g2c[group] += 1
res = xy_fn(result, y_name)
Expand Down Expand Up @@ -449,7 +460,7 @@ def plot_results(
def allequal(qs):
return all((q==qs[0]).all() for q in qs[1:])
if resample:
low = max(x[0] for x in origxs)
low = max(x[0] for x in origxs)
high = min(x[-1] for x in origxs)
usex = np.linspace(low, high, resample)
ys = []
Expand Down Expand Up @@ -487,84 +498,103 @@ def allequal(qs):
if shaded_range:
g2lf[group + '-sr'] = [ax.fill_between(usex, ymin, ymax, color=color, alpha=.1), ymin, ymax]

ax.set_title(sk)
if split_by_metrics:
ax.set_ylabel(sk)
if log:
ax.set_yscale('log')


# https://matplotlib.org/users/legend_guide.html
if not pretty:
plt.tight_layout()
if any(g2l.keys()):
if show_number:
legend_keys = np.array(['%s (%i)'%(g, g2c[g]) for g in g2l] if average_group else g2l.keys())
else:
legend_keys = np.array(['%s'%(g) for g in g2l] if average_group else g2l.keys())

legend_lines = np.array(list(g2l.values()))

sorted_index = np.argsort(legend_keys)
legend_keys = legend_keys[sorted_index]
legend_lines = legend_lines[sorted_index]
if replace_legend_keys is not None:
legend_keys = np.array(replace_legend_keys)
if regs2legends is not None:
legend_keys = np.array(regs2legends)
# if replace_legend_sort is not None:
# sorted_index = replace_legend_sort
# else:
# sorted_index = np.argsort(legend_keys)
# assert legend_keys.shape[0] == legend_lines.shape[0], \
# "The number of lines is not consistent with the keys"
# legend_keys = legend_keys[sorted_index]
# legend_lines = legend_lines[sorted_index]
if pretty:
for index, l in enumerate(legend_lines):
l.update(props={"color": colors[index % len(colors)]})
original_legend_keys = np.array(['%s' % (g) for g in g2l] if average_group else g2l.keys())
original_legend_keys = original_legend_keys[sorted_index]
if shaded_err:
res = g2lf[original_legend_keys[index] + '-se']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-err : ({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-err'] = [res[1][-1], res[2][-1]]
if shaded_std:
res = g2lf[original_legend_keys[index] + '-ss']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-std :({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-std'] = [res[1][-1], res[2][-1]]
if shaded_range:
res = g2lf[original_legend_keys[index] + '-sr']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-range : ({:.3f}, {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-range'] = [res[1][-1], res[2][-1]]

if bound_line is not None:
for bl in bound_line:
y = np.ones(x.shape) * bl[0]
l, = ax.plot(x, y, bl[2], color=bl[1])
legend_lines = np.append(legend_lines, l)
legend_keys = np.append(legend_keys, bl[3])
if not skip_legend:
# https://matplotlib.org/users/legend_guide.html
# if not pretty:
# plt.tight_layout()
if any(g2l.keys()):
if show_number:
legend_keys = np.array(['%s (%i)' % (g, g2c[g]) for g in g2l] if average_group else g2l.keys())
else:
legend_keys = np.array(['%s' % (g) for g in g2l] if average_group else g2l.keys())

legend_lines = np.array(list(g2l.values()))
sorted_index = np.argsort(legend_keys)
legend_keys = legend_keys[sorted_index]
legend_lines = legend_lines[sorted_index]
if regs2legends is not None:
legend_keys = np.array(regs2legends)
# if replace_legend_sort is not None:
# sorted_index = replace_legend_sort
# else:
# sorted_index = np.argsort(legend_keys)
# assert legend_keys.shape[0] == legend_lines.shape[0], \
# "The number of lines is not consistent with the keys"
# legend_keys = legend_keys[sorted_index]
# legend_lines = legend_lines[sorted_index]
if pretty:
for index, l in enumerate(legend_lines):
l.update(props={"color": colors[index % len(colors)]})
original_legend_keys = np.array(['%s' % (g) for g in g2l] if average_group else g2l.keys())
original_legend_keys = original_legend_keys[sorted_index]
if shaded_err:
res = g2lf[original_legend_keys[index] + '-se']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-err : ({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-err'] = [res[1][-1], res[2][-1]]
if shaded_std:
res = g2lf[original_legend_keys[index] + '-ss']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-std :({:.3f} $\pm$ {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-std'] = [res[1][-1], res[2][-1]]
if shaded_range:
res = g2lf[original_legend_keys[index] + '-sr']
res[0].update(props={"color": colors[index % len(colors)]})
print("{}-range : ({:.3f}, {:.3f})".format(legend_keys[index], res[1][-1], res[2][-1]))
score_results[legend_keys[index]+'-range'] = [res[1][-1], res[2][-1]]

if bound_line is not None:
for bl in bound_line:
y = np.ones(x.shape) * bl[0]
l, = ax.plot(x, y, bl[2], color=bl[1])
legend_lines = np.append(legend_lines, l)
legend_keys = np.append(legend_keys, bl[3])

if not skip_legend:
# print(nrows)
if len(sk2r.keys()) == 1:
lgd = ax.legend(
legend_lines,
legend_keys,
loc=2 if legend_outside else None,
bbox_to_anchor=(1,1) if legend_outside else None,
fontsize=15 if pretty else None)
else:
lgd = None
ax.set_title(sk)
if log:
ax.set_yscale('log')
lgd = f.legend(
legend_lines,
legend_keys,
loc='lower center' if legend_outside else None,
bbox_to_anchor=(0.5, 0.0) if legend_outside else None,
fontsize=15 if pretty else None,
borderaxespad=0)
else:
lgd = None


# add xlabels, but only to the bottom row
if xlabel is not None:
for ax in axarr[-1]:
plt.sca(ax)
# add xlabels, but only to the bottom row
if xlabel is not None:
for ax in axarr[-1]:
plt.sca(ax)
if pretty:
plt.xlabel(xlabel, fontsize=20)
else:
plt.xlabel(xlabel)
# add ylabels, but only to left column
if ylabel is not None:
for ax in axarr[:,0]:
plt.sca(ax)
plt.ylabel(ylabel)
# add ylabels, but only to left column
if ylabel is not None:
# for ax in axarr[:,0]:
# plt.sca(axarr[0, 0])
# plt.ylabel(ylabel)
# plt.sca(f)
if pretty:
f.supylabel(ylabel, fontsize=20, horizontalalignment='center')
else:
if not split_by_metrics or len(metrics) == 1:
f.supylabel(ylabel, horizontalalignment='center')
if title is not None:
plt.title(title)
plt.grid(True)
Expand All @@ -581,8 +611,8 @@ def allequal(qs):
plt.gca().yaxis.offsetText.set_fontsize(15)
plt.gca().xaxis.set_major_formatter(xfmt)
plt.gca().xaxis.offsetText.set_fontsize(15)
plt.xlabel(xlabel, fontsize=20)
plt.ylabel(ylabel, fontsize=20)
# plt.xlabel(xlabel, fontsize=20)
# plt.ylabel(ylabel, fontsize=20)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.title(title, fontsize=18)
Expand Down
Loading