diff --git a/README.md b/README.md index 8498ec9..5d8a04e 100644 --- a/README.md +++ b/README.md @@ -240,7 +240,7 @@ for i in range(1000): # your trianing code. exp_manager.sync_log_file() ``` -then the data items we be sent to the `remote_data_root` of the main node. Since `SEND_LOG_FILE` is set to False in the main node, the `exp_manager.sync_log_file()` will be skipped in the main node. +then the data items will be sent to the `remote_data_root` of the main node. Since `is_master_node` is set to True in the main node, the `exp_manager.sync_log_file()` will be skipped in the main node. PS: 1. You might meet "socket.error: [Errno 111] Connection refused" problem in this process. The solution can be found [here](https://stackoverflow.com/a/70784201/6055868). diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index 6183223..043e22f 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -72,7 +72,7 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_t else: return argparse.Namespace(**exp_manager.hyper_param) - def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None): + def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None, verbose=True): """ :param var_prefix: the prefix of namescope (for tf) to load. Set to '' to load all of the parameters. 
@@ -81,8 +81,9 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: """ if self.is_valid_config: loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) - print("attrs of the loaded tester") - pprint(loaded_tester.__dict__) + if verbose: + print("attrs of the loaded tester") + pprint(loaded_tester.__dict__) # load checkpoint load_res = {} if var_prefix is not None: diff --git a/RLA/easy_log/log_tools.py b/RLA/easy_log/log_tools.py index 5417f23..1f0df8e 100644 --- a/RLA/easy_log/log_tools.py +++ b/RLA/easy_log/log_tools.py @@ -106,10 +106,12 @@ def __init__(self, proj_root, task_table_name, regex, filter, *args, **kwargs): self.small_timestep_regs = [] super(DeleteLogTool, self).__init__(*args, **kwargs) - def _delete_related_log(self, regex, show=False): + def _delete_related_log(self, regex, show=False, delete_log_types=None): log_found = 0 for log_type in self.log_types: print(f"--- search {log_type} ---") + if delete_log_types is not None and log_type not in delete_log_types: + continue root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, regex) empty = True for root_dir in glob.glob(root_dir_regex): @@ -144,15 +146,15 @@ def _delete_related_log(self, regex, show=False): if empty: print("empty regex {}".format(root_dir_regex)) return log_found - def delete_related_log(self, skip_ask=False): - self._delete_related_log(show=True, regex=self.regex) + def delete_related_log(self, skip_ask=False, delete_log_types=None): + self._delete_related_log(show=True, regex=self.regex, delete_log_types=delete_log_types) if skip_ask: s = 'y' else: s = input("delete these files? 
(y/n)") if s == 'y': print("do delete ...") - return self._delete_related_log(show=False, regex=self.regex) + return self._delete_related_log(show=False, regex=self.regex, delete_log_types=delete_log_types) else: return 0 @@ -209,7 +211,11 @@ def _archive_log(self, show=False): if os.path.isdir(root_dir): if not show: # os.makedirs(archiving_target, exist_ok=True) - shutil.copytree(root_dir, archiving_target) + try: + shutil.copytree(root_dir, archiving_target) + except FileExistsError as e: + print(e) + print("copy dir {}, to {}".format(root_dir, archiving_target)) else: if not show: diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 6a04be0..46233bb 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -15,6 +15,7 @@ import os.path as osp import pprint +import numpy as np import tensorboardX from RLA.easy_log.time_step import time_step_holder @@ -222,8 +223,8 @@ def update_log_files_location(self, root:str): task_table_name = getattr(self, 'task_name', None) print("[WARN] you are using an old-version RLA. 
" "Some attributes' name have been changed (task_name->task_table_name).") - else: - raise RuntimeError("invalid ExpManager: task_table_name cannot be found", ) + if task_table_name is None: + raise RuntimeError("invalid ExpManager: task_table_name cannot be found", ) code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False) log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False) self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl') @@ -430,9 +431,15 @@ def log_file_finder(cls, record_date, task_table_name='train', file_root='../che raise NotImplementedError for search_item in search_list: if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))): - split_dir = search_item.split(' ') - # self.__ipaddr = split_dir[1] - info = " ".join(split_dir[2:]) + try: + split_dir = search_item.split('_') + assert len(split_dir) >= 2 + info = " ".join(split_dir[2:]) + except AssertionError as e: + split_dir = search_item.split(' ') + # self.__ipaddr = split_dir[1] + info = "_".join(split_dir[2:]) + print("[WARN] We find an old-version experiment data.") logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info)) file_found = search_item break @@ -615,11 +622,17 @@ def load_checkpoint(self, ckp_index=None): return int(max_iter), None elif self.dl_framework == FRAMEWORK.torch: import torch - all_ckps = sorted(os.listdir(self.checkpoint_dir)) + all_ckps = os.listdir(self.checkpoint_dir) + ites = [] + for ckps in all_ckps: + ites.append(int(ckps.split('checkpoint-')[1].split('.pt')[0])) + idx = np.argsort(ites) + all_ckps = np.array(all_ckps)[idx] print("all checkpoints:") pprint.pprint(all_ckps) if ckp_index is None: ckp_index = all_ckps[-1].split('checkpoint-')[1].split('.pt')[0] + print("loaded checkpoints:", "checkpoint-{}.pt".format(ckp_index)) return 
ckp_index, torch.load(self.checkpoint_dir + "checkpoint-{}.pt".format(ckp_index)) def auto_parse_info(self):