diff --git a/README.md b/README.md index 4c73f48..ee394c6 100644 --- a/README.md +++ b/README.md @@ -242,7 +242,9 @@ PS: 2. An alternative way is building your own NFS for your physical machines and locate data_root to the NFS. # TODO -- [ ] video visualization. +- [ ] support sftp-based sync. +- [ ] support custom data structure saving and loading. +- [ ] support video visualization. - [ ] add comments and documents to the functions. - [ ] add an auto integration script. - [ ] download / upload experiment logs through timestamp. diff --git a/RLA/easy_log/exp_loader.py b/RLA/easy_log/exp_loader.py index ac4806b..a8ee0a6 100644 --- a/RLA/easy_log/exp_loader.py +++ b/RLA/easy_log/exp_loader.py @@ -32,8 +32,7 @@ class ExperimentLoader(object): def __init__(self): self.task_name = exp_manager.hyper_param.get('loaded_task_name', None) self.load_date = exp_manager.hyper_param.get('loaded_date', None) - self.root = getattr(exp_manager, 'root', None) - self.data_root = None + self.data_root = getattr(exp_manager, 'root', None) pass def config(self, task_name, record_date, root): @@ -49,12 +48,12 @@ def is_valid_config(self): logger.warn("meet invalid loader config when use it") logger.warn("load_date", self.load_date) logger.warn("task_name", self.task_name) - logger.warn("root", self.root) + logger.warn("root", self.data_root) return False def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None): if self.is_valid_config: - load_tester = Tester.load_tester(self.load_date, self.task_name, self.root) + load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) target_hp = copy.deepcopy(exp_manager.hyper_param) target_hp.update(load_tester.hyper_param) if hp_to_overwrite is not None: @@ -75,7 +74,7 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: :return: """ if self.is_valid_config: - loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root) + loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) # load checkpoint load_res = {} if var_prefix is not None: @@ -100,7 +99,7 @@ def fork_log_files(self): if self.is_valid_config: global exp_manager assert isinstance(exp_manager, Tester) - loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.root) + loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root) # copy log file exp_manager.log_file_copy(loaded_tester) # copy attribute @@ -109,4 +108,4 @@ def fork_log_files(self): exp_manager.private_config = loaded_tester.private_config -exp_loader = experimental_loader = ExperimentLoader() \ No newline at end of file +exp_loader = experimental_loader = ExperimentLoader() diff --git a/RLA/easy_log/logger.py b/RLA/easy_log/logger.py index 6ac17c2..7f6ed4d 100644 --- a/RLA/easy_log/logger.py +++ b/RLA/easy_log/logger.py @@ -688,7 +688,7 @@ def configure(dir=None, format_strs=None, comm=None, framework='tensorflow'): if format_strs is None: format_strs = os.getenv('OPENAI_LOG_FORMAT', 'stdout,log,csv').split(',') format_strs = filter(None, format_strs) - output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] + output_formats = [make_output_format(f, dir, log_suffix, framework) for f in format_strs] warn_output_formats = make_output_format('warn', dir, log_suffix, framework) backup_output_formats = make_output_format('backup', dir, log_suffix, framework) diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 9fce214..e7ba846 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -187,7 +187,7 @@ def _init_logger(self): self.writer = None # logger configure logger.info("store file %s" % self.pkl_file) - logger.configure(self.log_dir, self.private_config["LOG_USED"]) + logger.configure(self.log_dir, self.private_config["LOG_USED"], framework=self.private_config["DL_FRAMEWORK"]) for fmt in logger.Logger.CURRENT.output_formats: if isinstance(fmt, logger.TensorBoardOutputFormat): self.writer = fmt.writer