diff --git a/RLA/easy_log/complex_data_recorder.py b/RLA/easy_log/complex_data_recorder.py index 900d835..1bf53ab 100644 --- a/RLA/easy_log/complex_data_recorder.py +++ b/RLA/easy_log/complex_data_recorder.py @@ -54,7 +54,7 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable, :param title: title of the plotted figure :type title: str :param add_timestamp: add the timestamp (recorded by the timestep holder) to the name. - :type add_timestamp: str + :type add_timestamp: bool :param args: other parameters to plt.savefig :type args: :param kwargs: other parameters to plt.savefig diff --git a/RLA/easy_log/logger.py b/RLA/easy_log/logger.py index d43c0a2..4160823 100644 --- a/RLA/easy_log/logger.py +++ b/RLA/easy_log/logger.py @@ -406,7 +406,16 @@ def timestep(): ma_dict = {} -def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None): +def ma_record_tabular(key, val, record_len:[int], ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None): + """ + Moving Averaged log recorder + :param key: save to log this key + :param val: save to log this value + :param record_len: sliding window size for averaged value computation + :param ignore_nan: ignore the nan value or not + :param exclude: exclude to save the log to some types of logger (e.g., 'stdout', 'log', 'json', 'csv' or 'tensorboard') + :param freq: the log will be dumped only after the timestep gap (holden by the time_step_holder) of recording is larger than freq. + """ if key not in ma_dict: ma_dict[key] = deque(maxlen=record_len) if ignore_nan: @@ -428,6 +437,9 @@ def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Opt :param key: (Any) save to log this key :param val: (Any) save to log this value + :param exclude: exclude to save the log to some types of logger (e.g., 'stdout', 'log', 'json', 'csv' or 'tensorboard') + :param freq: the log will be dumped only after the timestep gap (holden by the time_step_holder) of recording is larger than freq. + """ if key not in lst_print_dict: lst_print_dict[key] = -np.inf @@ -629,6 +641,9 @@ def logkv(self, key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None): :param key: (Any) save to log this key :param val: (Any) save to log this value + :param exclude: exclude to save the log to some types of logger (e.g., 'stdout', 'log', 'json', 'csv' or 'tensorboard') + :param freq: the log will be dumped only after the timestep gap (holden by the time_step_holder) of recording is larger than freq. + """ self.name2val[key] = val self.exclude_name[key] = exclude diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 66f7807..ac36bc4 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -113,7 +113,7 @@ def __init__(self): @deprecated_alias(task_name='task_table_name', private_config_path='rla_config', log_root='data_root') def configure(self, task_table_name: str, rla_config: Union[str, dict], data_root: str, - ignore_file_path: Optional[str] = None, run_file: Optional[str] = None, + ignore_file_path: Optional[str] = None, run_file: Union[str, List[str]] = None, is_master_node: bool = False, code_root: Optional[str] = None): """ The function to configure your exp_manager, which should be run before your experiments. @@ -131,7 +131,7 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo :type ignore_file_path: str :param run_file: If you have extra files out of your codebase (e.g., some scripts to run the code), you can pass it to the run_file. Then we will backup the run_file too. - :type run_file: str + :type run_file: str or list :param is_master_node: In "distributed training & centralized logs" mode (By set SEND_LOG_FILE in rla_config.yaml to True), you should mark the master node (is_master_node=True) to collect logs of the slave nodes (is_master_node=False). :type is_master_node: bool @@ -502,18 +502,25 @@ def get_ignore_files(self, src, names): def __copy_source_code(self, run_file, code_dir): import shutil + def _copy_run_file(run_file, code_dir): + if type(run_file) == list: + for file_name in run_file: + shutil.copy(file_name, code_dir) + else: + shutil.copy(run_file, code_dir) if self.private_config["PROJECT_TYPE"]["backup_code_by"] == 'lib': assert os.listdir(code_dir) == [] os.removedirs(code_dir) shutil.copytree(osp.join(self.project_root, self.private_config["BACKUP_CONFIG"]["lib_dir"]), code_dir) assert run_file is not None, "you should define the run_file in lib backup mode." - shutil.copy(run_file, code_dir) + _copy_run_file(run_file, code_dir) elif self.private_config["PROJECT_TYPE"]["backup_code_by"] == 'source': - for dir_name in self.private_config["BACKUP_CONFIG"]["backup_code_dir"]: - shutil.copytree(osp.join(self.project_root, dir_name), osp.join(code_dir, dir_name), - ignore=self.get_ignore_files) + if self.private_config["BACKUP_CONFIG"].get("backup_code_dir"): + for dir_name in self.private_config["BACKUP_CONFIG"]["backup_code_dir"]: + shutil.copytree(osp.join(self.project_root, dir_name), osp.join(code_dir, dir_name), + ignore=self.get_ignore_files) if run_file is not None: - shutil.copy(run_file, code_dir) + _copy_run_file(run_file, code_dir) else: raise NotImplementedError diff --git a/test/test_proj/proj/test_manager.py b/test/test_proj/proj/test_manager.py index a156f53..1414baf 100644 --- a/test/test_proj/proj/test_manager.py +++ b/test/test_proj/proj/test_manager.py @@ -115,6 +115,7 @@ def test_log_torch(self): logger.record_tabular("y_out-long", np.mean(y), freq=25) def plot_func(): import matplotlib.pyplot as plt + # plt.switch_backend('agg') testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1) testX = testX.astype(np.float32) testY = target_func(testX)