Skip to content
Merged

Dev #13

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0d25abf
Merge pull request #19 from xionghuichen/dev
xionghuichen Jun 22, 2022
cf30b18
Update README.md
xionghuichen Jun 22, 2022
33e6aee
Dev (#20)
xionghuichen Jul 13, 2022
74c7712
fix: minor changes for version compatibility
xionghuichen Jul 14, 2022
e36a1bc
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 14, 2022
3bdbf3e
Dev (#21)
xionghuichen Jul 14, 2022
96ba639
fix: a bug of sorting in torch-version checkpoint loading
xionghuichen Jul 14, 2022
680c6be
Dev (#22)
xionghuichen Jul 14, 2022
44c2cff
refactor: robust multi-key plot implementation
xionghuichen Jul 21, 2022
ce25856
feat: supoort pretty plotter
xionghuichen Jul 23, 2022
efc4815
refactor(log plotter): record scores of the log plotter
xionghuichen Jul 23, 2022
3fca57c
fix(exp_loader): add parameter ckp_index
xionghuichen Jul 23, 2022
017efc2
refactor(rla_script): add start_server to start_pretty_plotter.py
xionghuichen Jul 23, 2022
7c0b7dc
update readme
xionghuichen Jul 23, 2022
4239bd6
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 23, 2022
29e4932
Dev (#23)
xionghuichen Jul 23, 2022
5a0c180
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 27, 2022
d3ff59d
rm unsolved merge
xionghuichen Jul 27, 2022
f67a10a
Dev (#24)
xionghuichen Jul 27, 2022
65d2859
feat: tf-v2 compatible
xionghuichen Jul 27, 2022
845f3ab
refactor: add timestep recorder. refactor on exp_loader
xionghuichen Aug 10, 2022
9f799c3
test: add test data
xionghuichen Aug 10, 2022
a277416
feat(plot): track the hyper-parameter from the exp_manager instead of…
xionghuichen Aug 10, 2022
e9b29a7
Dev (#25)
xionghuichen Aug 10, 2022
00d26d2
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 10, 2022
f71bdcd
test(plot): add user cases and documents
xionghuichen Aug 11, 2022
a5f88dd
test(plot): add user cases
xionghuichen Aug 11, 2022
23349cd
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Aug 11, 2022
26ee5d3
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,10 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
load_res = {}
if var_prefix is not None:
loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
_, load_res = loaded_tester.load_checkpoint(ckp_index)
else:
loaded_tester.new_saver(max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
_, load_res = loaded_tester.load_checkpoint(ckp_index)
hist_variables = {}
if variable_list is not None:
for v in variable_list:
Expand Down
63 changes: 54 additions & 9 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,11 +543,31 @@ def update_fph(self, cum_epochs):
# self.last_record_fph_time = cur_time
logger.dump_tabular()

def time_record(self, name: str):
    """Start a named wall-clock recorder.

    [deprecated] see RLA.easy_log.time_used_recorder.
    Pair this call with ``time_record_end(name)`` to log the elapsed time of a
    code snippet; ``name`` distinguishes recorders running at the same time.

    :param name: identifier of the code snippet being timed.
    :type name: str
    """
    # a recorder with the same name must not already be running
    assert name not in self._rc_start_time
    self._rc_start_time[name] = time.time()

def time_record_end(self, name):
def time_record_end(self, name:str):
"""
[deprecated] see RLA.easy_log.time_used_recorder
record the consumed time of your code snippet. call this function to start a recorder.
"name" is identifier to distinguish different recorder and record different snippets at the same time.
call time_record_end to end a recorder.
:param name: identifier of your code snippet.
:type name: str
:return:
:rtype:
"""
end_time = time.time()
start_time = self._rc_start_time[name]
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
Expand All @@ -566,23 +586,46 @@ def new_saver(self, max_to_keep, var_prefix=None):
import tensorflow as tf
if var_prefix is None:
var_prefix = ''
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
logger.info("save variable :")
for v in var_list:
logger.info(v)
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
try:
var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
logger.info("save variable :")
for v in var_list:
logger.info(v)
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir,
save_relative_paths=True)

except AttributeError as e:
self.max_to_keep = max_to_keep
# tf.compat.v1.disable_eager_execution()
# tf = tf.compat.v1
# var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
elif self.dl_framework == FRAMEWORK.torch:
self.max_to_keep = max_to_keep
else:
raise NotImplementedError

def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Optional[dict]=None):
def save_checkpoint(self, model_dict: Optional[dict] = None, related_variable: Optional[dict] = None):
if self.dl_framework == FRAMEWORK.tensorflow:
import tensorflow as tf
iter = self.time_step_holder.get_time()
cpt_name = osp.join(self.checkpoint_dir, 'checkpoint')
logger.info("save checkpoint to ", cpt_name, iter)
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
try:
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
except AttributeError as e:
if model_dict is None:
logger.warn("call save_checkpoints without passing a model_dict")
return
if self.checkpoint_keep_list is None:
self.checkpoint_keep_list = []
iter = self.time_step_holder.get_time()
# tf.compat.v1.disable_eager_execution()
# tf = tf.compat.v1
# self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)

tf.train.Checkpoint(**model_dict).save(tester.checkpoint_dir + "checkpoint-{}".format(iter))
self.checkpoint_keep_list.append(iter)
self.checkpoint_keep_list = self.checkpoint_keep_list[-1 * self.max_to_keep:]
elif self.dl_framework == FRAMEWORK.torch:
import torch
if self.checkpoint_keep_list is None:
Expand All @@ -602,6 +645,7 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
for k, v in related_variable.items():
self.add_custom_data(k, v, type(v), mode='replace')
self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')
self.serialize_object_and_save()

def load_checkpoint(self, ckp_index=None):
if self.dl_framework == FRAMEWORK.tensorflow:
Expand All @@ -613,6 +657,7 @@ def load_checkpoint(self, ckp_index=None):
ckpt_path = tf.train.latest_checkpoint(cpt_name)
else:
ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index)
logger.info("load ckpt_path {}".format(ckpt_path))
self.saver.restore(tf.get_default_session(), ckpt_path)
max_iter = ckpt_path.split('-')[-1]
return int(max_iter), None
Expand Down
44 changes: 27 additions & 17 deletions RLA/easy_plot/plot_func_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,31 @@

from RLA import logger
from RLA.const import DEFAULT_X_NAME
from RLA.query_tool import experiment_data_query
from RLA.query_tool import experiment_data_query, extract_valid_index

from RLA.easy_plot import plot_util
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS



def default_key_to_legend(parse_dict, split_keys, y_name):
    """Build the default legend label: 'k1=v1.k2=v2 eval:<metric>'.

    :param parse_dict: mapping from split key to its (stringified) value.
    :param split_keys: ordered keys to include in the label.
    :param y_name: name of the plotted metric, appended after ' eval:'.
    :return: the legend string for one curve.
    """
    key_value_pairs = ['{}={}'.format(key, parse_dict[key]) for key in split_keys]
    return '.'.join(key_value_pairs) + ' eval:' + y_name


def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
use_buf=False, verbose=True,
xlim: Optional[tuple] = None,
xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[str] = None,
scale_dict: Optional[dict] = None, replace_legend_keys: Optional[list] = None,
scale_dict: Optional[dict] = None, regs2legends: Optional[list] = None,
key_to_legend_fn: Optional[Callable] = default_key_to_legend,
save_name: Optional[str] = None, *args, **kwargs):
"""
A high-level matplotlib plotter.
The function is to load your experiments and plot curves.
You can group several experiments into a single figure through this function.
It is completed by loading experiments satisfying [data_root, task_table_name, regs] pattern,
grouping by "split_keys" or by the "regs" terms (see replace_legend_keys), and plotting the customized "metrics".
grouping by "split_keys" or by the "regs" terms (see regs2legends), and plotting the customized "metrics".

The function supports several configurations to customize the figure, including xlim, xlabel, ylabel, key_to_legend_fn, etc.
The function also supports several configurations to post-process your log data, including resample, smooth_step, scale_dict, key_to_legend_fn, etc.
Expand All @@ -61,7 +62,13 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
:param scale_dict: a function dict, to map the value of the metrics through customize functions.
e.g.,set metrics = ['return'], scale_dict = {'return': lambda x: np.log(x)}, then we will plot a log-scale return.
:type scale_dict: Optional[dict]
:param args: set the label of the y axes.
:param regs2legends: use regex-to-legend mode to plot the figure. Each item in regs will be grouped into a curve.
In this reg2legend_map mode, you should define the legend name for each curve. See test/test_plot/test_reg_map_mode for details.
:type regs2legends: Optional[list] = None
:param key_to_legend_fn: we give a default function to stringify the k-v pairs. You can customize your own function in key_to_legend_fn.
See default_key_to_legend for the default way and test/test_plot/test_customize_legend_name_mode for details.
:type key_to_legend_fn: Optional[Callable] = default_key_to_legend
:param args/kwargs: send other parameters to plot_util.plot_results

:return:
:rtype:
Expand Down Expand Up @@ -98,17 +105,17 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
if ylabel is None:
ylabel = metrics

if replace_legend_keys is not None:
assert len(replace_legend_keys) == len(regs) and len(metrics) == 1, \
if regs2legends is not None:
assert len(regs2legends) == len(regs) and len(metrics) == 1, \
"In manual legend-key mode, the number of keys should be one-to-one matched with regs"
# if len(replace_legend_keys) == len(regs):
# if len(regs2legends) == len(regs):
group_fn = lambda r: split_by_reg(taskpath=r, reg_group=reg_group, y_names=y_names)
else:
group_fn = lambda r: picture_split(taskpath=r, split_keys=split_keys, y_names=y_names,
key_to_legend_fn=key_to_legend_fn)
_, _, lgd, texts, g2lf, score_results = \
plot_util.plot_results(results, xy_fn= lambda r, y_names: csv_to_xy(r, DEFAULT_X_NAME, y_names, final_scale_dict),
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, replace_legend_keys=replace_legend_keys, *args, **kwargs)
group_fn=group_fn, average_group=True, ylabel=ylabel, xlabel=xlabel, regs2legends=regs2legends, *args, **kwargs)
print("--- complete process ---")
if save_name is not None:
import os
Expand All @@ -127,25 +134,28 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
def split_by_reg(taskpath, reg_group, y_names):
    """Map one experiment result to the index of the regex group containing it.

    :param taskpath: a result object exposing ``dirname``.
    :param reg_group: mapping from regex key to the list of results it matched.
    :param y_names: the metric names (exactly one is required in this mode).
    :return: (stringified group index, or "None" if unmatched; y_names)
    """
    matched_index = "None"
    for group_idx, reg_key in enumerate(reg_group.keys()):
        for candidate in reg_group[reg_key]:
            if taskpath.dirname == candidate.dirname:
                # a result may appear in at most one group
                assert matched_index == "None", "one experiment should belong to only one reg_group"
                matched_index = str(group_idx)
    assert len(y_names) == 1
    return matched_index, y_names


def split_by_task(taskpath, split_keys, y_names, key_to_legend_fn):
    """Build the legend keys for one experiment from its hyper-parameters.

    Collects the value of every split key from ``taskpath.hyper_param`` and
    stringifies the pairs through ``key_to_legend_fn``, once per metric.

    :param taskpath: result object exposing ``hyper_param`` (dict of recorded hyper-parameters).
    :param split_keys: hyper-parameter names used to group curves.
    :param y_names: metric names to plot.
    :param key_to_legend_fn: callable ``(parse_dict, split_keys, y_name) -> str`` producing a legend label.
    :return: (list of legend keys, one per metric; y_names unchanged)
    """
    # removed unused locals (pair_delimiter, kv_delimiter) and commented-out dead code
    parse_dict = {}
    for split_key in split_keys:
        # a hyper-parameter missing from this run is marked 'NF' (not found) instead of failing
        if split_key in taskpath.hyper_param:
            parse_dict[split_key] = str(taskpath.hyper_param[split_key])
        else:
            parse_dict[split_key] = 'NF'
    param_keys = [key_to_legend_fn(parse_dict, split_keys, y_name) for y_name in y_names]
    return param_keys, y_names


Expand Down
4 changes: 3 additions & 1 deletion RLA/easy_plot/plot_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def plot_results(
ylabel=None,
title=None,
replace_legend_keys=None,
replace_legend_sort=None,
regs2legends=None,
pretty=False,
bound_line=None,
colors=None,
Expand Down Expand Up @@ -505,6 +505,8 @@ def allequal(qs):
legend_lines = legend_lines[sorted_index]
if replace_legend_keys is not None:
legend_keys = np.array(replace_legend_keys)
if regs2legends is not None:
legend_keys = np.array(regs2legends)
# if replace_legend_sort is not None:
# sorted_index = replace_legend_sort
# else:
Expand Down
65 changes: 62 additions & 3 deletions test/test_plot.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,85 @@
# Created by xionghuichen at 2022/8/10
# Email: chenxh@lamda.nju.edu.cn
from test._base import BaseTest
import numpy as np
from RLA.easy_log.log_tools import DeleteLogTool, Filter
from RLA.easy_log.log_tools import ArchiveLogTool, ViewLogTool
from RLA.easy_log.tester import exp_manager

from RLA import plot_func
import os

class ScriptTest(BaseTest):
    """Usage examples for ``RLA.plot_func``; each test demonstrates one plotting mode."""

    def get_basic_info(self):
        # location of the demo experiment data shared by all tests
        data_root = 'test_data_root'
        task = 'demo_task'
        return data_root, task

    def test_plot_basic(self):
        data_root, task = self.get_basic_info()
        regs = [
            '2022/03/01/21-[12]*'
        ]
        # minimal call: group curves by 'learning_rate' and plot a single metric
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'])
        # customize the figure
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      ylim=(0, 0.1))
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio')

    def test_pretty_plot(self):
        data_root, task = self.get_basic_info()
        regs = [
            '2022/03/01/21-[12]*'
        ]
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio',
                      shaded_range=False, show_number=False, pretty=True)
        # save image
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      ylim=(0, 0.1), xlabel='epochs', ylabel='reward ratio',
                      shaded_range=False, pretty=True, save_name='saved_image.png')

    def test_reg_map_mode(self):
        # reg-map mode: each regex pattern becomes one curve with an explicit legend
        data_root, task = self.get_basic_info()
        regs = [
            '2022/03/01/21-[12]*learning_rate=0.01*',
            '2022/03/01/21-[12]*learning_rate=0.00*',
        ]
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      regs2legends=['lr=0.01', 'lr<=0.001'],
                      shaded_range=False, pretty=True)

    def test_customize_legend_name_mode(self):
        data_root, task = self.get_basic_info()
        regs = [
            '2022/03/01/21-[12]*'
        ]

        def my_key_to_legend(parse_dict, split_keys, y_name):
            # customize the legend: render 'learning_rate' as the symbol α
            task_split_key = '.'.join(f'{k}={parse_dict[k]}' for k in split_keys)
            return task_split_key.replace('learning_rate', 'α')

        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      key_to_legend_fn=my_key_to_legend,
                      shaded_range=False, pretty=True, show_number=False)

    def test_post_process(self):
        data_root, task = self.get_basic_info()
        regs = [
            '2022/03/01/21-[12]*'
        ]
        # map metric values through a custom function (here: log-scale the mse)
        _ = plot_func(data_root=data_root, task_table_name=task, regs=regs,
                      split_keys=['learning_rate'], metrics=['perf/mse'],
                      scale_dict={'perf/mse': lambda x: np.log(x)},
                      ylabel='RMSE',
                      shaded_range=False, pretty=True, show_number=False)