Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ pip install -e .
We build an example project to include most of the features of RLA, which can be seen in ./example/simplest_code. Now we summarize the steps to use it.

### Step 1: Configuration.
1. To configure the experiment "database", you need to create a YAML file rla_config.yaml. You can use the template provided in ./example/simplest_code/rla_config.yaml as a starting point.
1. To configure the experiment "database", you need to create a YAML file rla_config.yaml. You can use the template provided in ./example/rla_config.yaml as a starting point.
2. Before starting your experiment, you should configure the RLA.exp_manager object. Here's an example:

```python
Expand Down Expand Up @@ -256,7 +256,7 @@ from RLA import MatplotlibRecorder as mpr
def plot_func():
import matplotlib.pyplot as plt
plt.plot([1,1,1], [2,2,2])
mpr.pretty_plot_wrapper('func', plot_func, xlabel='x', ylabel='y', title='react test', )
mpr.pretty_plot_wrapper('func', plot_func, pretty_plot=True, xlabel='x', ylabel='y', title='react test')
```

This code plots a figure using Matplotlib and saves it in the "results" directory.
Expand Down
11 changes: 9 additions & 2 deletions RLA/const.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
DEFAULT_X_NAME = 'time-step'

DEFAULT_TIMESTAMP = 'timestamp'
class FRAMEWORK:
tensorflow = 'tensorflow'
torch = 'torch'
Expand All @@ -9,4 +9,11 @@ class FTP_PROTOCOL_NAME:
SFTP = 'sftp'

class LOG_NAME_FORMAT_VERSION:
V1 = 'v1'
V1 = 'v1'



class PLATFORM_TYPE:
    """String constants identifying the host operating-system platform."""
    WIN = 'win'
    LINUX = 'linux'
    OTHER = 'other'
17 changes: 10 additions & 7 deletions RLA/easy_log/complex_data_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def save(cls, name=None, fig=None, cover=False, add_timestamp=True, **kwargs):

@classmethod
def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
cover=False, legend_outside=False, xlabel='', ylabel='', title='',
cover=False, legend_outside=False, pretty_plot=False, xlabel='', ylabel='', title='',
add_timestamp=True, *args, **kwargs):
"""
Save the customized plot figure to the RLA database.
Expand All @@ -47,6 +47,8 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
:type cover: bool
:param legend_outside: let legend be outside of the figure.
:type legend_outside: bool
:param pretty_plot: use predefined configurations for plotting.
:type pretty_plot: bool
:param xlabel: name of xlabel
:type xlabel: str
:param ylabel: name of xlabel
Expand All @@ -66,12 +68,13 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
plot_func()
lgd = plt.legend(prop={'size': 15}, loc=2 if legend_outside else None,
bbox_to_anchor=(1, 1) if legend_outside else None)
plt.xlabel(xlabel, fontsize=18)
plt.ylabel(ylabel, fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.title(title, fontsize=13)
plt.grid(True)
if pretty_plot:
plt.xlabel(xlabel, fontsize=18)
plt.ylabel(ylabel, fontsize=18)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.title(title, fontsize=13)
plt.grid(True)
if lgd is not None:
cls.save(name, cover=cover, add_timestamp=add_timestamp, bbox_extra_artists=tuple([lgd]),
bbox_inches='tight', *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def is_valid_config(self):
if self.load_date is not None and self.task_name is not None and self.data_root is not None:
return True
else:
logger.warn("meet invalid loader config when use it")
logger.warn("meet invalid loader config when using it")
logger.warn("load_date", self.load_date)
logger.warn("task_name", self.task_name)
logger.warn("root", self.data_root)
Expand Down
15 changes: 9 additions & 6 deletions RLA/easy_log/log_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import json
from RLA.easy_log.tester import Tester

from RLA.utils.utils import get_dir_seperator
class Filter(object):
ALL = 'all'
SMALL_TIMESTEP = 'small_ts'
Expand Down Expand Up @@ -214,8 +214,9 @@ def _archive_log(self, show=False):
empty = False
if os.path.exists(root_dir):
# remove the overlapped path.
septor = get_dir_seperator()
archiving_target = osp.join(archive_root_dir, root_dir[prefix_len+1:])
archiving_target_dir = '/'.join(archiving_target.split('/')[:-1])
archiving_target_dir = septor.join(archiving_target.split(septor)[:-1])
os.makedirs(archiving_target_dir, exist_ok=True)
if os.path.isdir(root_dir):
if not show:
Expand Down Expand Up @@ -269,21 +270,23 @@ def _migrate_log(self, show=False):
if os.path.exists(root_dir):
# remove the overlapped path.
archiving_target = osp.join(target_root_dir, root_dir[prefix_len+1:])
archiving_target_dir = '/'.join(archiving_target.split('/')[:-1])

septor = get_dir_seperator()
archiving_target_dir = septor.join(archiving_target.split(septor)[:-1])
print("target dir", archiving_target_dir)
os.makedirs(archiving_target_dir, exist_ok=True)
if os.path.isdir(root_dir):
print("copy dir {}, to {}".format(root_dir, archiving_target))
if not show:
# os.makedirs(archiving_target, exist_ok=True)
try:
shutil.copytree(root_dir, archiving_target)
except FileExistsError as e:
print(e)

print("copy dir {}, to {}".format(root_dir, archiving_target))
else:
print("copy file {}, to {}".format(root_dir, archiving_target))
if not show:
shutil.copy(root_dir, archiving_target)
print("copy file {}, to {}".format(root_dir, archiving_target))
else:
print("no dir {}".format(root_dir))
if empty: print("empty regex {}".format(root_dir_regex))
Expand Down
8 changes: 7 additions & 1 deletion RLA/easy_log/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union
from contextlib import contextmanager
from RLA.const import DEFAULT_X_NAME, FRAMEWORK
from RLA.const import DEFAULT_X_NAME, FRAMEWORK, DEFAULT_TIMESTAMP

DEBUG = 10
INFO = 20
Expand Down Expand Up @@ -510,6 +510,12 @@ def dumpkvs():
print(e)
for fmt in Logger.CURRENT.output_formats:
print(fmt)
try:
get_current().logkv(DEFAULT_TIMESTAMP, time.time())
except NotImplementedError as e:
print(e)
for fmt in Logger.CURRENT.output_formats:
print(fmt)
return get_current().dumpkvs()

def getkvs():
Expand Down
18 changes: 12 additions & 6 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import shutil
import argparse
from typing import Dict, List, Tuple, Type, Union, Optional
from RLA.utils.utils import deprecated_alias, load_yaml
from RLA.utils.utils import deprecated_alias, load_yaml, get_dir_seperator
from RLA.const import DEFAULT_X_NAME, FRAMEWORK
import pathspec

Expand Down Expand Up @@ -150,10 +150,11 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo
logger.info("private_config: ")
self.dl_framework = self.private_config["DL_FRAMEWORK"]
self.is_master_node = is_master_node

self.septor = get_dir_seperator()
if code_root is None:
if isinstance(rla_config, str):
self.project_root = "/".join(rla_config.split("/")[:-1])

self.project_root = self.septor.join(rla_config.split(self.septor)[:-1])
else:
raise NotImplementedError("If you pass the rla_config dict directly, "
"you should define the root of your codebase (for backup) explicitly by pass the code_root.")
Expand Down Expand Up @@ -309,14 +310,19 @@ def load_tester(cls, record_date, task_table_name, log_root):
def add_record_param(self, keys):
for k in keys:
if '.' in k:
sub_k = None
try:
sub_k_list = k.split('.')
v = self.hyper_param[sub_k_list[0]]
sub_k = sub_k_list[0]
v = self.hyper_param[sub_k]
for sub_k in sub_k_list[1:]:
v = v[sub_k]
self.hyper_param_record.append(str(k) + '=' + str(v).replace('[', '{').replace(']', '}').replace('/', '_'))
except KeyError as e:
print("do not include dot ('.') in your hyperparemeter name")
print(f"the current key to parsed is: {k}. Can not find a matching key for {sub_k}."
"\n Hint: do not include dot ('.') in your hyperparemeter name."
"\n The recorded hyper parameters are")
self.print_args()
else:
self.hyper_param_record.append(str(k) + '=' + str(self.hyper_param[k]).replace('[', '{').replace(']', '}').replace('/', '_'))

Expand Down Expand Up @@ -620,7 +626,7 @@ def time_record_end(self, name:str):
end_time = time.time()
start_time = self._rc_start_time[name]
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
# logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
del self._rc_start_time[name]

# Saver manger.
Expand Down
14 changes: 2 additions & 12 deletions RLA/easy_log/time_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from RLA.easy_log import logger

import time

class TimeStepHolder(object):

Expand All @@ -8,20 +8,10 @@ def __init__(self, time, epoch, tf_log=False):
self.__outer_epoch = epoch
self.tf_log = tf_log

def config(self, time=0, epoch=0, tf_log=False):
def config(self, time=0, tf_log=False):
self.__timesteps = time
self.__outer_epoch = epoch
self.tf_log = tf_log

def set_outer_epoch(self, epoch):
self.__outer_epoch = epoch

def get_outer_epoch(self):
return self.__outer_epoch

def inc_outer_epoch(self):
self.__outer_epoch +=1

def set_time(self, time):
self.__timesteps = time
self.__update_tf_times()
Expand Down
2 changes: 1 addition & 1 deletion RLA/easy_log/time_used_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,5 @@ def time_record_end(name):
end_time = time.time()
start_time = rc_start_time[name]
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
# logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
del rc_start_time[name]
2 changes: 2 additions & 0 deletions RLA/easy_plot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from RLA.easy_plot.plot_saved_images import plot_saved_images
from RLA.easy_plot.plot_func_v2 import plot_func
103 changes: 58 additions & 45 deletions RLA/easy_plot/plot_func_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,50 @@
from RLA.query_tool import experiment_data_query, extract_valid_index
from RLA.easy_plot import plot_util
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS, HYPARAM


from RLA.easy_plot.utils import results_loader
from RLA.query_tool import LogQueryResult

def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True):
    """
    Build a legend string for a plot from selected keys of ``parse_dict``.

    Any key listed in ``split_keys`` that is missing from ``parse_dict`` is
    inserted into it (in place) with the placeholder value 'NF'.

    :param parse_dict: mapping whose values are formatted into the legend.
    :type parse_dict: Dict
    :param split_keys: keys looked up in ``parse_dict``.
    :type split_keys: List
    :param y_name: metric name optionally appended to the legend.
    :type y_name: str
    :param use_y_name: when True, append ``y_name`` to the legend.
    :type use_y_name: bool, default to True
    """
    parts = []
    for key in split_keys:
        # setdefault records the placeholder for missing keys, matching the
        # documented side effect on the caller's dict.
        parse_dict.setdefault(key, 'NF')
        parts.append(f'{key}={parse_dict[key]}')
    legend = '.'.join(parts)
    return legend + ' eval:' + y_name if use_y_name else legend

def meta_csv_data_loader_func(query_res, select_names, x_bound, use_buf):
    """
    Load the csv log referenced by ``query_res`` through plot_util.

    :param query_res: query result pointing at a log directory.
    :param select_names: column names to load from the csv log.
    :param x_bound: drop rows whose x value exceeds this bound.
    :param use_buf: reuse preloaded csv data instead of loading from scratch.
    :return: the single loaded result object, or None when nothing matched.
    """
    assert isinstance(query_res, LogQueryResult)
    loaded = plot_util.load_results(query_res.dirname, names=select_names,
                                    x_bound=x_bound, use_buf=use_buf)
    if not loaded:
        return None
    # Each query directory is expected to yield exactly one result.
    assert len(loaded) == 1
    return loaded[0]

def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
use_buf=False, verbose=True,
use_buf=False, verbose=False, summarize_res=True,
x_bound: Optional[int]=None,
xlabel: Optional[str] = DEFAULT_X_NAME, ylabel: Optional[Union[str, list]] = None,
scale_dict: Optional[dict] = None, regs2legends: Optional[list] = None,
hp_filter_dict: Optional[dict] = None,
key_to_legend_fn: Optional[Callable] = default_key_to_legend,
split_by_metrics=True,
save_name: Optional[str] = None, *args, **kwargs):
split_by_metrics=True, save_name: Optional[str]=None, *args, **kwargs):
"""
A high-level matplotlib plotter.
The function is to load your experiments and plot curves.
Expand All @@ -43,23 +69,28 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
The function also supports several configure to post-process your log data, including resample, smooth_step, scale_dict, key_to_legend_fn, etc.
The function also supports several configure to beautify the figure, see the parameters of plot_util.plot_results.

:param data_root:
:type data_root:
:param task_table_name:
:type task_table_name:
:param regs:
:type regs:
:param split_keys:
:type split_keys:
:param metrics:
:type metrics:
:param use_buf: use the preloaded csv data instead of loading from scratch
:type use_buf: bool
:param x_bound: drop the collected with time-step larger than x_bound.
:param xlabel: set the label of the y axes.
:param data_root: Root directory for the data.
:type data_root: str
:param task_table_name: Task table name.
:type task_table_name: str
:param regs: List of regular expressions used for matching files/directories.
:type regs: list
:param split_keys: List of keys to group experiments.
:type split_keys: list
:param metrics: List of metrics to be plotted.
:type metrics: list
:param use_buf: If True, uses preloaded csv data instead of loading from scratch.
:type use_buf: bool, default to False
:param verbose: If True, prints detailed log information during the process.
:type verbose: bool, default to True
:param x_bound: Drops the data collected with time-step larger than x_bound.
:type x_bound: Optional[int]
:type xlabel: Optional[str]
:param ylabel: set the label of the y axes.
:type ylabel: Optional[str,list]
:param hp_filter_dict: a dict to filter your log.
e.g., hp_filter_dict= {'learning_rate': [0.001, 0.01, 0.1]} will select the logs where the learning rate is 0.001, 0.01, or 0.1.
:type hp_filter_dict: Optional[dict]
:param scale_dict: a function dict, to map the value of the metrics through customize functions.
e.g.,set metrics = ['return'], scale_dict = {'return': lambda x: np.log(x)}, then we will plot a log-scale return.
:type scale_dict: Optional[dict]
Expand All @@ -77,32 +108,14 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
:return:
:rtype:
"""
results = []
reg_group = {}
for reg in regs:
reg_group[reg] = []
print("searching", reg)
tester_dict = experiment_data_query(data_root, task_table_name, reg, ARCHIVE_TESTER)
log_dict = experiment_data_query(data_root, task_table_name, reg, LOG)
counter = 0
for k, v in log_dict.items():
result = plot_util.load_results(v.dirname, names=metrics + [DEFAULT_X_NAME],
x_bound=[DEFAULT_X_NAME, x_bound], use_buf=use_buf)
if len(result) == 0:
continue
assert len(result) == 1
result = result[0]
if verbose:
print("find log", v.dirname)
counter += 1
if os.path.exists(osp.join(v.dirname, HYPARAM + '.json')):
with open(osp.join(v.dirname, HYPARAM + '.json')) as f:
result.hyper_param = json.load(f)
else:
result.hyper_param = tester_dict[k].exp_manager.hyper_param
results.append(result)
reg_group[reg].append(result)
print("find log number", counter)
csv_data_loader_func = lambda dirname: meta_csv_data_loader_func(dirname, select_names=metrics + [DEFAULT_X_NAME],
x_bound=[DEFAULT_X_NAME, x_bound], use_buf=use_buf)
results, reg_group = results_loader(data_root, task_table_name, regs, hp_filter_dict, csv_data_loader_func, verbose, data_type=LOG)
if summarize_res:
for k, v in reg_group.items():
print(f"for regex {k}, we have the following logs:")
for res in v:
print("find log", res.dirname, "\n [parsed key]", key_to_legend_fn(res.hyper_param, split_keys, '', False))
final_scale_dict = {}
for m in metrics:
final_scale_dict[m] = lambda x: x
Expand Down
Loading