diff --git a/README.md b/README.md index 9150f6b..055d82a 100644 --- a/README.md +++ b/README.md @@ -107,9 +107,10 @@ We build an example project for integrating RLA, which can be seen in ./example/ ### Step1: Configuration. 1. We define the property of the database in `rla_config.yaml`. You can construct your YAML file based on the template in ./example/simplest_code/rla_config.yaml. -2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this. +2. We define the property of the table in exp_manager.config. Before starting your experiment, you should configure the global object RLA.easy_log.tester.exp_manager like this: ```python from RLA import exp_manager + import os kwargs = {'env_id': 'Hopper-v2', 'lr': 1e-3} exp_manager.set_hyper_param(**kwargs) # kwargs are the hyper-parameters for your experiment exp_manager.add_record_param(["env_id"]) # add parts of hyper-parameters to name the index of data items for better readability. @@ -121,11 +122,17 @@ We build an example project for integrating RLA, which can be seen in ./example/ rla_data_root = get_package_path() # the place to store the data items. rla_config = os.path.join(get_package_path(), 'rla_config.yaml') - exp_manager.configure(task_table_name=task_name, rla_config=rla_config, data_root=rla_data_root) + + ignore_file_path=os.path.join(get_package_path(), '.gitignore') + exp_manager.configure(task_table_name=task_name, ignore_file_path=ignore_file_path, + rla_config=rla_config, data_root=rla_data_root) exp_manager.log_files_gen() # initialize the data items. exp_manager.print_args() ``` -3. We add the generated data items to .gitignore to avoid pushing them into our git repo. + where ``ignore_file_path`` is a gitignore-style file, which is used to ignored files when backing up your project into ``code`` folder. + It is an optional parameter, and you can use your `.gitignore` file of your git repository directly. + +4. We add the generated data items to .gitignore to avoid pushing them into our git repo. ```gitignore **/archive_tester/** **/checkpoint/** diff --git a/RLA/__init__.py b/RLA/__init__.py index 0a4478b..f014863 100644 --- a/RLA/__init__.py +++ b/RLA/__init__.py @@ -2,4 +2,5 @@ from RLA.easy_log import logger from RLA.easy_log.time_step import time_step_holder from RLA.easy_plot.plot_func_v2 import plot_func -from RLA.easy_log.complex_data_recorder import MatplotlibRecorder \ No newline at end of file +from RLA.easy_log.complex_data_recorder import MatplotlibRecorder, ImgRecorder +from RLA.easy_log.exp_loader import ExperimentLoader diff --git a/RLA/easy_log/complex_data_recorder.py b/RLA/easy_log/complex_data_recorder.py index c4897ce..900d835 100644 --- a/RLA/easy_log/complex_data_recorder.py +++ b/RLA/easy_log/complex_data_recorder.py @@ -1,6 +1,6 @@ import os import os.path as osp - +import numpy as np import seaborn as sns sns.set_style('darkgrid', {'legend.frameon': True}) @@ -10,18 +10,23 @@ from typing import Callable # video recorder +def format_name(name, add_timestamp, cover): + save_path = osp.join(exp_manager.results_dir, name) + save_path_split = save_path.split('/') + if add_timestamp: + save_path = '/'.join(save_path_split[:-1]) + '/' + str(time_step_holder.get_time()) + "-" + str(save_path_split[-1]) + if not osp.exists(save_path) or cover: + save_dir = '/'.join(save_path.split('/')[:-1]) + os.makedirs(save_dir, exist_ok=True) + return save_path + # figure recorder class MatplotlibRecorder: @classmethod def save(cls, name=None, fig=None, cover=False, add_timestamp=True, **kwargs): - save_path = osp.join(exp_manager.results_dir, name) - save_path_split = save_path.split('/') - if add_timestamp: - save_path = '/'.join(save_path_split[:-1]) + '/' + str(time_step_holder.get_time()) + "-" + str(save_path_split[-1]) + save_path = format_name(name, add_timestamp, cover) if not osp.exists(save_path) or cover: - save_dir = '/'.join(save_path.split('/')[:-1]) - os.makedirs(save_dir, exist_ok=True) if fig is not None: fig.savefig(save_path, **kwargs) else: @@ -71,4 +76,12 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable, cls.save(name, cover=cover, add_timestamp=add_timestamp, bbox_extra_artists=tuple([lgd]), bbox_inches='tight', *args, **kwargs) else: - cls.save(name, cover=cover, add_timestamp=add_timestamp, *args, **kwargs) \ No newline at end of file + cls.save(name, cover=cover, add_timestamp=add_timestamp, *args, **kwargs) + +class ImgRecorder: + @classmethod + def save(cls, name=None, img=None, cover=False, add_timestamp=True, **kwargs): + import cv2 + save_path = format_name(name, add_timestamp, cover) + if not osp.exists(save_path) or cover: + cv2.imwrite(save_path, img.astype(np.uint8)) \ No newline at end of file diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index 6d98c0c..56f7c4a 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -19,7 +19,7 @@ class ExperimentLoader(object): - resume an experiment: 0. config loaded_task_name and loaded_date to the task and timestamp of the target experiment to load respectively. 1. init your exp_manager; - 2. call exp_loader.fork_tester_log_files to copy all of the log data of the target experiment to the current experiment. + 2. call exp_loader.fork_log_files to copy all of the log data of the target experiment to the current experiment. 3. call exp_loader.load_from_record_date to resume the neural networks and intermediate variables. 4. start your process. - resume an experiment with other settings. @@ -63,8 +63,8 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_t for v in hp_to_overwrite: target_hp[v] = exp_manager.hyper_param[v] args = argparse.Namespace(**target_hp) - args.load_date = self.load_date - args.load_task_name = self.task_name + args.loaded_date = self.load_date + args.loaded_task_name = self.task_name if sync_timestep: load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME) exp_manager.time_step_holder.set_time(load_iter) diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index cc2dc1c..31fa723 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -11,6 +11,7 @@ import time import os +import json import datetime import os.path as osp import pprint @@ -282,6 +283,12 @@ def load_tester(cls, record_date, task_table_name, log_root): assert isinstance(load_tester, Tester) logger.info("update log files' root") load_tester.update_log_files_location(root=log_root) + logger.info("load data: \n ts {}, \n ip {}, \n info {}".format( + str(load_tester.record_date.strftime("%Y/%m/%d")) + '/' + load_tester.record_date_to_str( + load_tester.record_date), load_tester.ipaddr, load_tester.info)) + + + return load_tester def add_record_param(self, keys): @@ -427,15 +434,26 @@ def log_file_finder(cls, record_date, task_table_name='train', file_root='../che if log_type == 'dir': search_list = dirs elif log_type =='files': - search_list =files + search_list = files else: raise NotImplementedError for search_item in search_list: if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))): - split_dir = search_item.split(' ') + # self.__ipaddr = split_dir[1] - info = " ".join(split_dir[2:]) - logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info)) + # if version_num is None: + # split_dir = search_item.split(' ') + # info = " ".join(split_dir[2:]) + # logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info)) + # + # elif version_num == LOG_NAME_FORMAT_VERSION.V1: + # split_dir = search_item.split('_') + # info = " ".join(split_dir[2:]) + # logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info)) + # + # else: + # raise RuntimeError("unknown version name", version_num) + file_found = search_item break return directory, file_found @@ -501,12 +519,16 @@ def __copy_source_code(self, run_file, code_dir): def record_date_to_str(self, record_date): return str(record_date.strftime("%H-%M-%S-%f")) + def get_version_num(self): + version_num = getattr(self, 'log_name_format_version', None) + return version_num + def __create_file_directory(self, prefix, ext='', is_file=True, record_date=None): if record_date is None: record_date = self.record_date directory = str(record_date.strftime("%Y/%m/%d")) directory = osp.join(prefix, directory) - version_num = getattr(self, 'log_name_format_version', None) + version_num = self.get_version_num() if version_num is None: name_format = '{dir}/{timestep} {ip} {info}{ext}' @@ -743,6 +765,13 @@ def print_args(self): for key, value in sort_list: # logger.info("key: %s, value: %s" % (key, value)) logger.backup("key: %s, value: %s" % (key, value)) + # formatted_log_name = self.log_name_formatter(self.get_task_table_name(), self.record_date) + params = exp_manager.hyper_param + # params['formatted_log_name'] = formatted_log_name + json.dump(params, open(osp.join(self.code_dir, 'parameter.json'), 'w'), + sort_keys=True, indent=4, allow_nan=True, default=lambda o: '') + print("gen:", osp.join(self.code_dir, 'parameter.json')) + def print_large_memory_variable(self): import sys @@ -766,7 +795,6 @@ def sizeof_fmt(num, suffix='B'): summary = self.dict_to_table_text_summary(large_mermory_dict, 'large_memory') self.add_summary_to_logger(summary, 'large_memory') - def dict_to_table_text_summary(self, input_dict, name): import tensorflow as tf with tf.Session(graph=tf.Graph()) as sess: diff --git a/example/simplest_code/project/ignore b/example/simplest_code/project/ignore new file mode 100644 index 0000000..461affd --- /dev/null +++ b/example/simplest_code/project/ignore @@ -0,0 +1,7 @@ + +**/archive_tester/** +**/checkpoint/** +**/code/** +**/results/** +**/log/** +**/arc/** \ No newline at end of file diff --git a/example/simplest_code/project/main.py b/example/simplest_code/project/main.py index c0fad91..84ac737 100644 --- a/example/simplest_code/project/main.py +++ b/example/simplest_code/project/main.py @@ -28,7 +28,16 @@ def get_param(): task_name = 'demo_task' rla_data_root = '../' +<<<<<<< HEAD +<<<<<<< HEAD +exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root, + ignore_file_path='./ignore') +======= exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root) +>>>>>>> 9bd402e505cc920aa4329f451ac34fb3b12f6347 +======= +exp_manager.configure(task_name, rla_config='../rla_config.yaml', data_root=rla_data_root) +>>>>>>> 29ab768949f26c307e4bdb07fd9d0dc15047a69d exp_manager.log_files_gen() exp_manager.print_args()