Skip to content
Merged

Dev #17

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
0d25abf
Merge pull request #19 from xionghuichen/dev
xionghuichen Jun 22, 2022
cf30b18
Update README.md
xionghuichen Jun 22, 2022
33e6aee
Dev (#20)
xionghuichen Jul 13, 2022
74c7712
fix: minor changes for version compatibility
xionghuichen Jul 14, 2022
e36a1bc
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 14, 2022
3bdbf3e
Dev (#21)
xionghuichen Jul 14, 2022
96ba639
fix: a bug of sorting in torch-version checkpoint loading
xionghuichen Jul 14, 2022
680c6be
Dev (#22)
xionghuichen Jul 14, 2022
44c2cff
refactor: robust multi-key plot implementation
xionghuichen Jul 21, 2022
ce25856
feat: supoort pretty plotter
xionghuichen Jul 23, 2022
efc4815
refactor(log plotter): record scores of the log plotter
xionghuichen Jul 23, 2022
3fca57c
fix(exp_loader): add parameter ckp_index
xionghuichen Jul 23, 2022
017efc2
refactor(rla_script): add start_server to start_pretty_plotter.py
xionghuichen Jul 23, 2022
7c0b7dc
update readme
xionghuichen Jul 23, 2022
4239bd6
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 23, 2022
29e4932
Dev (#23)
xionghuichen Jul 23, 2022
5a0c180
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 27, 2022
d3ff59d
rm unsolved merge
xionghuichen Jul 27, 2022
f67a10a
Dev (#24)
xionghuichen Jul 27, 2022
65d2859
feat: tf-v2 compatible
xionghuichen Jul 27, 2022
845f3ab
refactor: add timestep recorder. refactor on exp_loader
xionghuichen Aug 10, 2022
9f799c3
test: add test data
xionghuichen Aug 10, 2022
a277416
feat(plot): track the hyper-parameter from the exp_manager instead of…
xionghuichen Aug 10, 2022
e9b29a7
Dev (#25)
xionghuichen Aug 10, 2022
00d26d2
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 10, 2022
f71bdcd
test(plot): add user cases and documents
xionghuichen Aug 11, 2022
a5f88dd
test(plot): add user cases
xionghuichen Aug 11, 2022
23349cd
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Aug 11, 2022
26ee5d3
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 11, 2022
bf87b78
Dev (#26)
xionghuichen Aug 11, 2022
4efee65
Update README.md (#28)
xionghuichen Aug 11, 2022
6140b9f
simplify codes
xionghuichen Sep 12, 2022
e9fb6fd
refactor: more robust freq print implementation
xionghuichen Sep 14, 2022
bf3f65b
update readme
xionghuichen Sep 25, 2022
8fdc403
update readme
xionghuichen Oct 4, 2022
9bd402e
Dev (#29)
xionghuichen Oct 4, 2022
423034e
fix(exp_loader): fix a bug of loaded experiments information print
xionghuichen Oct 4, 2022
b52053e
fix param name
xionghuichen Oct 14, 2022
48dcbc5
feat(complex-data-recorder): add image recorder
xionghuichen Nov 3, 2022
622870d
feat(exp_manager): add hyper-parameter log in `code` dir
xionghuichen Nov 3, 2022
fdacbda
docs(readme): add gitignore document
xionghuichen Nov 11, 2022
771a99a
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Nov 11, 2022
708670e
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Nov 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ We build an example project for integrating RLA, which can be seen in ./example/

### Step1: Configuration.
1. We define the property of the database in `rla_config.yaml`. You can construct your YAML file based on the template in ./example/simplest_code/rla_config.yaml.
2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this.
2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this:
```python
from RLA import exp_manager
import os
kwargs = {'env_id': 'Hopper-v2', 'lr': 1e-3}
exp_manager.set_hyper_param(**kwargs) # kwargs are the hyper-parameters for your experiment
exp_manager.add_record_param(["env_id"]) # add parts of hyper-parameters to name the index of data items for better readability.
Expand All @@ -121,11 +122,17 @@ We build an example project for integrating RLA, which can be seen in ./example/
rla_data_root = get_package_path() # the place to store the data items.

rla_config = os.path.join(get_package_path(), 'rla_config.yaml')
exp_manager.configure(task_table_name=task_name, rla_config=rla_config, data_root=rla_data_root)

ignore_file_path=os.path.join(get_package_path(), '.gitignore')
exp_manager.configure(task_table_name=task_name, ignore_file_path=ignore_file_path,
rla_config=rla_config, data_root=rla_data_root)
exp_manager.log_files_gen() # initialize the data items.
exp_manager.print_args()
```
3. We add the generated data items to .gitignore to avoid pushing them into our git repo.
where ``ignore_file_path`` is a gitignore-style file, which is used to ignored files when backing up your project into ``code`` folder.
It is an optional parameter, and you can use your `.gitignore` file of your git repository directly.

4. We add the generated data items to .gitignore to avoid pushing them into our git repo.
```gitignore
**/archive_tester/**
**/checkpoint/**
Expand Down
3 changes: 2 additions & 1 deletion RLA/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from RLA.easy_log import logger
from RLA.easy_log.time_step import time_step_holder
from RLA.easy_plot.plot_func_v2 import plot_func
from RLA.easy_log.complex_data_recorder import MatplotlibRecorder
from RLA.easy_log.complex_data_recorder import MatplotlibRecorder, ImgRecorder
from RLA.easy_log.exp_loader import ExperimentLoader
29 changes: 21 additions & 8 deletions RLA/easy_log/complex_data_recorder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import os.path as osp

import numpy as np
import seaborn as sns
sns.set_style('darkgrid', {'legend.frameon': True})

Expand All @@ -10,18 +10,23 @@
from typing import Callable
# video recorder

def format_name(name, add_timestamp, cover):
save_path = osp.join(exp_manager.results_dir, name)
save_path_split = save_path.split('/')
if add_timestamp:
save_path = '/'.join(save_path_split[:-1]) + '/' + str(time_step_holder.get_time()) + "-" + str(save_path_split[-1])
if not osp.exists(save_path) or cover:
save_dir = '/'.join(save_path.split('/')[:-1])
os.makedirs(save_dir, exist_ok=True)
return save_path


# figure recorder
class MatplotlibRecorder:
@classmethod
def save(cls, name=None, fig=None, cover=False, add_timestamp=True, **kwargs):
save_path = osp.join(exp_manager.results_dir, name)
save_path_split = save_path.split('/')
if add_timestamp:
save_path = '/'.join(save_path_split[:-1]) + '/' + str(time_step_holder.get_time()) + "-" + str(save_path_split[-1])
save_path = format_name(name, add_timestamp, cover)
if not osp.exists(save_path) or cover:
save_dir = '/'.join(save_path.split('/')[:-1])
os.makedirs(save_dir, exist_ok=True)
if fig is not None:
fig.savefig(save_path, **kwargs)
else:
Expand Down Expand Up @@ -71,4 +76,12 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
cls.save(name, cover=cover, add_timestamp=add_timestamp, bbox_extra_artists=tuple([lgd]),
bbox_inches='tight', *args, **kwargs)
else:
cls.save(name, cover=cover, add_timestamp=add_timestamp, *args, **kwargs)
cls.save(name, cover=cover, add_timestamp=add_timestamp, *args, **kwargs)

class ImgRecorder:
@classmethod
def save(cls, name=None, img=None, cover=False, add_timestamp=True, **kwargs):
import cv2
save_path = format_name(name, add_timestamp, cover)
if not osp.exists(save_path) or cover:
cv2.imwrite(save_path, img.astype(np.uint8))
6 changes: 3 additions & 3 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class ExperimentLoader(object):
- resume an experiment:
0. config loaded_task_name and loaded_date to the task and timestamp of the target experiment to load respectively.
1. init your exp_manager;
2. call exp_loader.fork_tester_log_files to copy all of the log data of the target experiment to the current experiment.
2. call exp_loader.fork_log_files to copy all of the log data of the target experiment to the current experiment.
3. call exp_loader.load_from_record_date to resume the neural networks and intermediate variables.
4. start your process.
- resume an experiment with other settings.
Expand Down Expand Up @@ -63,8 +63,8 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_t
for v in hp_to_overwrite:
target_hp[v] = exp_manager.hyper_param[v]
args = argparse.Namespace(**target_hp)
args.load_date = self.load_date
args.load_task_name = self.task_name
args.loaded_date = self.load_date
args.loaded_task_name = self.task_name
if sync_timestep:
load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME)
exp_manager.time_step_holder.set_time(load_iter)
Expand Down
40 changes: 34 additions & 6 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import time
import os

import json
import datetime
import os.path as osp
import pprint
Expand Down Expand Up @@ -282,6 +283,12 @@ def load_tester(cls, record_date, task_table_name, log_root):
assert isinstance(load_tester, Tester)
logger.info("update log files' root")
load_tester.update_log_files_location(root=log_root)
logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(
str(load_tester.record_date.strftime("%Y/%m/%d")) + '/' + load_tester.record_date_to_str(
load_tester.record_date), load_tester.ipaddr, load_tester.info))



return load_tester

def add_record_param(self, keys):
Expand Down Expand Up @@ -427,15 +434,26 @@ def log_file_finder(cls, record_date, task_table_name='train', file_root='../che
if log_type == 'dir':
search_list = dirs
elif log_type =='files':
search_list =files
search_list = files
else:
raise NotImplementedError
for search_item in search_list:
if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))):
split_dir = search_item.split(' ')

# self.__ipaddr = split_dir[1]
info = " ".join(split_dir[2:])
logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info))
# if version_num is None:
# split_dir = search_item.split(' ')
# info = " ".join(split_dir[2:])
# logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info))
#
# elif version_num == LOG_NAME_FORMAT_VERSION.V1:
# split_dir = search_item.split('_')
# info = " ".join(split_dir[2:])
# logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info))
#
# else:
# raise RuntimeError("unknown version name", version_num)

file_found = search_item
break
return directory, file_found
Expand Down Expand Up @@ -501,12 +519,16 @@ def __copy_source_code(self, run_file, code_dir):
def record_date_to_str(self, record_date):
return str(record_date.strftime("%H-%M-%S-%f"))

def get_version_num(self):
version_num = getattr(self, 'log_name_format_version', None)
return version_num

def __create_file_directory(self, prefix, ext='', is_file=True, record_date=None):
if record_date is None:
record_date = self.record_date
directory = str(record_date.strftime("%Y/%m/%d"))
directory = osp.join(prefix, directory)
version_num = getattr(self, 'log_name_format_version', None)
version_num = self.get_version_num()

if version_num is None:
name_format = '{dir}/{timestep} {ip} {info}{ext}'
Expand Down Expand Up @@ -743,6 +765,13 @@ def print_args(self):
for key, value in sort_list:
# logger.info("key: %s, value: %s" % (key, value))
logger.backup("key: %s, value: %s" % (key, value))
# formatted_log_name = self.log_name_formatter(self.get_task_table_name(), self.record_date)
params = exp_manager.hyper_param
# params['formatted_log_name'] = formatted_log_name
json.dump(params, open(osp.join(self.code_dir, 'parameter.json'), 'w'),
sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
print("gen:", osp.join(self.code_dir, 'parameter.json'))


def print_large_memory_variable(self):
import sys
Expand All @@ -766,7 +795,6 @@ def sizeof_fmt(num, suffix='B'):
summary = self.dict_to_table_text_summary(large_mermory_dict, 'large_memory')
self.add_summary_to_logger(summary, 'large_memory')


def dict_to_table_text_summary(self, input_dict, name):
import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
Expand Down
7 changes: 7 additions & 0 deletions example/simplest_code/project/ignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

**/archive_tester/**
**/checkpoint/**
**/code/**
**/results/**
**/log/**
**/arc/**
9 changes: 9 additions & 0 deletions example/simplest_code/project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@ def get_param():

task_name = 'demo_task'
rla_data_root = '../'
<<<<<<< HEAD
<<<<<<< HEAD
exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root,
ignore_file_path='./ignore')
=======
exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root)
>>>>>>> 9bd402e505cc920aa4329f451ac34fb3b12f6347
=======
exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root)
>>>>>>> 29ab768949f26c307e4bdb07fd9d0dc15047a69d
exp_manager.log_files_gen()
exp_manager.print_args()

Expand Down