From bb2720b72164723c2a1a1fb36bb428fd00f2a8f5 Mon Sep 17 00:00:00 2001 From: Xiong-Hui Chen Date: Wed, 13 Jul 2022 12:28:31 +0800 Subject: [PATCH 1/2] fix: fix bugs of torch-version ckp loader --- README.md | 3 +- RLA/const.py | 3 ++ RLA/easy_log/exp_loader.py | 18 +++++--- RLA/easy_log/log_tools.py | 4 ++ RLA/easy_log/tester.py | 65 ++++++++++++++++++++------- RLA/rla_argparser.py | 19 ++++++++ example/simplest_code/project/main.py | 1 + 7 files changed, 91 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index e6e3154..8498ec9 100644 --- a/README.md +++ b/README.md @@ -248,10 +248,11 @@ PS: 2. An alternative way is building your own NFS for your physical machines and locate data_root to the NFS. # TODO -- [ ] support sftp-based sync. - [ ] support custom data structure saving and loading. - [ ] support video visualization. - [ ] add comments and documents to the functions. - [ ] add an auto integration script. - [ ] download / upload experiment logs through timestamp. - [ ] add a document to the plot function. +- [ ] allow sync LOG only or ALL TYPE LOGS. +- [ ] support aim and smarter logger. diff --git a/RLA/const.py b/RLA/const.py index a502cf6..20d3a89 100644 --- a/RLA/const.py +++ b/RLA/const.py @@ -7,3 +7,6 @@ class FRAMEWORK: class FTP_PROTOCOL_NAME: FTP = 'ftp' SFTP = 'sftp' + +class LOG_NAME_FORMAT_VERSION: + V1 = 'v1' \ No newline at end of file diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index a8ee0a6..5f96369 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -4,6 +4,7 @@ import argparse from typing import Optional, OrderedDict, Union, Dict, Any from RLA.const import DEFAULT_X_NAME +from pprint import pprint class ExperimentLoader(object): """ @@ -32,7 +33,9 @@ class ExperimentLoader(object): def __init__(self): self.task_name = exp_manager.hyper_param.get('loaded_task_name', None) self.load_date = exp_manager.hyper_param.get('loaded_date', None) - self.data_root = getattr(exp_manager, 'root', None) + self.data_root = getattr(exp_manager, 'data_root', None) + if self.data_root is None: + self.data_root = getattr(exp_manager, 'root', None) pass def config(self, task_name, record_date, root): @@ -53,15 +56,17 @@ def is_valid_config(self): def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None): if self.is_valid_config: - load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) + loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) target_hp = copy.deepcopy(exp_manager.hyper_param) - target_hp.update(load_tester.hyper_param) + target_hp.update(loaded_tester.hyper_param) if hp_to_overwrite is not None: for v in hp_to_overwrite: target_hp[v] = exp_manager.hyper_param[v] args = argparse.Namespace(**target_hp) args.load_date = self.load_date args.load_task_name = self.task_name + load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME) + exp_manager.time_step_holder.set_time(load_iter) return args else: return argparse.Namespace(**exp_manager.hyper_param) @@ -75,18 +80,21 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: """ if self.is_valid_config: loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) + print("attrs of the loaded tester") + pprint(loaded_tester.__dict__) # load checkpoint load_res = {} if var_prefix is not None: loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1) _, load_res = loaded_tester.load_checkpoint() - exp_manager.print_log_dir() + else: + loaded_tester.new_saver(max_to_keep=1) + _, load_res = loaded_tester.load_checkpoint() hist_variables = {} if variable_list is not None: for v in variable_list: hist_variables[v] = loaded_tester.get_custom_data(v) load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME) - exp_manager.time_step_holder.set_time(load_iter) return load_iter, load_res, hist_variables else: return 0, {}, {} diff --git a/RLA/easy_log/log_tools.py b/RLA/easy_log/log_tools.py index 76bcdd4..5417f23 100644 --- a/RLA/easy_log/log_tools.py +++ b/RLA/easy_log/log_tools.py @@ -169,6 +169,10 @@ def delete_small_timestep_log(self, skip_ask=False): for res in self.small_timestep_regs: print("[delete small-timestep log] reg: ", res[1]) self._delete_related_log(show=True, regex=res[0] + '*') + print("summarize:") + for count, res in enumerate(self.small_timestep_regs): + print(f"[delete small-timestep log] {count} reg: {res[1]}") + if skip_ask: s = 'y' else: diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 6ec2720..6a04be0 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -17,10 +17,10 @@ import tensorboardX -from RLA.easy_log.const import * from RLA.easy_log.time_step import time_step_holder from RLA.easy_log import logger from RLA.easy_log.const import * +from RLA.const import * import yaml import shutil import argparse @@ -107,6 +107,8 @@ def __init__(self): self.code_dir = None self.saver = None self.dl_framework = None + self.checkpoint_keep_list = None + self.log_name_format_version = LOG_NAME_FORMAT_VERSION.V1 @deprecated_alias(task_name='task_table_name', private_config_path='rla_config', log_root='data_root') def configure(self, task_table_name: str, rla_config: Union[str, dict], data_root: str, @@ -205,13 +207,28 @@ def log_files_gen(self): self._feed_hyper_params_to_tb() self.print_log_dir() - def update_log_files_location(self, root): + def update_log_files_location(self, root:str): + """ + This function is designed for the requirement of using copied/moved experiment logs to other databases for downstream task. + The location of the experiment logs might have changed compared with their original location. + The function automatically update the attributes related to the data_root to the current location. + :param root: current data_root + :type root: str + """ self.data_root = root - code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, self.task_table_name), '', is_file=False) - log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, self.task_table_name), '', is_file=False) - self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, self.task_table_name), '.pkl') - self.checkpoint_dir, _ = self.__create_file_directory(osp.join(self.data_root, CHECKPOINT, self.task_table_name), is_file=False) - self.results_dir, _ = self.__create_file_directory(osp.join(self.data_root, OTHER_RESULTS, self.task_table_name), is_file=False) + + task_table_name = getattr(self, 'task_table_name', None) + if task_table_name is None: + task_table_name = getattr(self, 'task_name', None) + print("[WARN] you are using an old-version RLA. " + "Some attributes' name have been changed (task_name->task_table_name).") + else: + raise RuntimeError("invalid ExpManager: task_table_name cannot be found", ) + code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False) + log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False) + self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl') + self.checkpoint_dir, _ = self.__create_file_directory(osp.join(self.data_root, CHECKPOINT, task_table_name), is_file=False) + self.results_dir, _ = self.__create_file_directory(osp.join(self.data_root, OTHER_RESULTS, task_table_name), is_file=False) self.log_dir = log_dir self.code_dir = code_dir self.print_log_dir() @@ -487,15 +504,23 @@ def __create_file_directory(self, prefix, ext='', is_file=True, record_date=None record_date = self.record_date directory = str(record_date.strftime("%Y/%m/%d")) directory = osp.join(prefix, directory) + version_num = getattr(self, 'log_name_format_version', None) + + if version_num is None: + name_format = '{dir}/{timestep} {ip} {info}{ext}' + elif version_num == LOG_NAME_FORMAT_VERSION.V1: + name_format = '{dir}/{timestep}_{ip}_{info}{ext}' + else: + raise RuntimeError("unknown version name", version_num) + if is_file: os.makedirs(directory, exist_ok=True) - file_name = '{dir}/{timestep}_{ip}_{info}{ext}'.format(dir=directory, - timestep=self.record_date_to_str(record_date), + file_name = name_format.format(dir=directory, timestep=self.record_date_to_str(record_date), ip=str(self.ipaddr), info=self.info, ext=ext) else: - directory = '{dir}/{timestep}_{ip}_{info}{ext}/'.format(dir=directory, + directory = (name_format + '/').format(dir=directory, timestep=self.record_date_to_str(record_date), ip=str(self.ipaddr), info=self.info, @@ -545,7 +570,6 @@ def new_saver(self, max_to_keep, var_prefix=None): self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True) elif self.dl_framework == FRAMEWORK.torch: self.max_to_keep = max_to_keep - self.checkpoint_keep_list = [] else: raise NotImplementedError @@ -558,6 +582,8 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt self.saver.save(tf.get_default_session(), cpt_name, global_step=iter) elif self.dl_framework == FRAMEWORK.torch: import torch + if self.checkpoint_keep_list is None: + self.checkpoint_keep_list = [] iter = self.time_step_holder.get_time() torch.save(model_dict, f=tester.checkpoint_dir + "checkpoint-{}.pt".format(iter)) self.checkpoint_keep_list.append(iter) @@ -574,20 +600,27 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt self.add_custom_data(k, v, type(v), mode='replace') self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace') - def load_checkpoint(self): + def load_checkpoint(self, ckp_index=None): if self.dl_framework == FRAMEWORK.tensorflow: # TODO: load with variable scope. import tensorflow as tf cpt_name = osp.join(self.checkpoint_dir) logger.info("load checkpoint {}".format(cpt_name)) - ckpt_path = tf.train.latest_checkpoint(cpt_name) + if ckp_index is None: + ckpt_path = tf.train.latest_checkpoint(cpt_name) + else: + ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index) self.saver.restore(tf.get_default_session(), ckpt_path) max_iter = ckpt_path.split('-')[-1] - self.time_step_holder.set_time(max_iter) return int(max_iter), None elif self.dl_framework == FRAMEWORK.torch: import torch - return self.checkpoint_keep_list[-1], torch.load(tester.checkpoint_dir + "checkpoint-{}.pt".format(self.checkpoint_keep_list[-1])) + all_ckps = sorted(os.listdir(self.checkpoint_dir)) + print("all checkpoints:") + pprint.pprint(all_ckps) + if ckp_index is None: + ckp_index = all_ckps[-1].split('checkpoint-')[1].split('.pt')[0] + return ckp_index, torch.load(self.checkpoint_dir + "checkpoint-{}.pt".format(ckp_index)) def auto_parse_info(self): return '&'.join(self.hyper_param_record) @@ -648,7 +681,7 @@ def serialize_object_and_save(self): saver = self.saver self.saver = None with open(self.pkl_file, 'wb') as f: - dill.dump(self, f) + dill.dump(self, f, recurse=True) self.writer = writer self.saver = saver diff --git a/RLA/rla_argparser.py b/RLA/rla_argparser.py index a0b3dc1..43970eb 100644 --- a/RLA/rla_argparser.py +++ b/RLA/rla_argparser.py @@ -1,6 +1,25 @@ import argparse +def boolean_flag(parser: argparse.ArgumentParser, name, default=False, help=None): + """Add a boolean flag to argparse parser. + + Parameters + ---------- + parser: argparse.Parser + parser to add the flag to + name: str + -- will enable the flag, while --no- will disable it + default: bool or None + default value of the flag + help: str + help string for the flag + """ + dest = name.replace('-', '_') + parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help) + parser.add_argument("--no-" + name, action="store_false", dest=dest) + + def arg_parser_postprocess(parser: argparse.ArgumentParser): parser.add_argument('--loaded_task_name', default='', type=str) parser.add_argument('--info', default='default exp info', type=str) diff --git a/example/simplest_code/project/main.py b/example/simplest_code/project/main.py index 8ff68e8..f9934dd 100644 --- a/example/simplest_code/project/main.py +++ b/example/simplest_code/project/main.py @@ -15,6 +15,7 @@ def get_param(): parser.add_argument('--env_id', help='environment ID', default='Test-v1') parser.add_argument('--learning_rate', help='a hyperparameter', default=1e-3, type=float) parser.add_argument('--input_size', help='a hyperparameter', default=16, type=int) + # NOTE: add some recommended hyper-parameters for RLA. parser = arg_parser_postprocess(parser) args = parser.parse_args() kwargs = vars(args) From 62dc62df51d28c5a9e9ddd75abf38a29d72f0c9b Mon Sep 17 00:00:00 2001 From: Xiong-Hui Chen Date: Wed, 13 Jul 2022 17:42:06 +0800 Subject: [PATCH 2/2] refactor: add sync_timestep for hp loader --- RLA/easy_log/exp_loader.py | 7 ++++--- setup.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index 5f96369..6183223 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -54,7 +54,7 @@ def is_valid_config(self): logger.warn("root", self.data_root) return False - def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None): + def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_timestep=False): if self.is_valid_config: loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) target_hp = copy.deepcopy(exp_manager.hyper_param) @@ -65,8 +65,9 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None): args = argparse.Namespace(**target_hp) args.load_date = self.load_date args.load_task_name = self.task_name - load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME) - exp_manager.time_step_holder.set_time(load_iter) + if sync_timestep: + load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME) + exp_manager.time_step_holder.set_time(load_iter) return args else: return argparse.Namespace(**exp_manager.hyper_param) diff --git a/setup.py b/setup.py index f6d9e8b..ebafbfe 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='RLA', - version="0.5.3", + version="0.6.0-pre", description=( 'RL assistant' ),