Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion RLA/easy_log/complex_data_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
:param title: title of the plotted figure
:type title: str
:param add_timestamp: add the timestamp (recorded by the timestep holder) to the name.
:type add_timestamp: str
:type add_timestamp: bool
:param args: other parameters to plt.savefig
:type args:
:param kwargs: other parameters to plt.savefig
Expand Down
17 changes: 16 additions & 1 deletion RLA/easy_log/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,16 @@ def timestep():
ma_dict = {}


def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
def ma_record_tabular(key, val, record_len:[int], ignore_nan=False, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
"""
Moving Averaged log recorder
:param key: save to log this key
:param val: save to log this value
:param record_len: sliding window size for averaged value computation
:param ignore_nan: ignore the nan value or not
:param exclude: exclude to save the log to some types of logger (e.g., 'stdout', 'log', 'json', 'csv' or 'tensorboard')
:param freq: the log will be dumped only after the timestep gap (holden by the time_step_holder) of recording is larger than freq.
"""
if key not in ma_dict:
ma_dict[key] = deque(maxlen=record_len)
if ignore_nan:
Expand All @@ -428,6 +437,9 @@ def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Opt

:param key: (Any) save to log this key
:param val: (Any) save to log this value
:param exclude: exclude to save the log to some types of logger (e.g., 'stdout', 'log', 'json', 'csv' or 'tensorboard')
:param freq: the log will be dumped only after the timestep gap (holden by the time_step_holder) of recording is larger than freq.

"""
if key not in lst_print_dict:
lst_print_dict[key] = -np.inf
Expand Down Expand Up @@ -629,6 +641,9 @@ def logkv(self, key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None):

:param key: (Any) save to log this key
:param val: (Any) save to log this value
:param exclude: exclude to save the log to some types of logger (e.g., 'stdout', 'log', 'json', 'csv' or 'tensorboard')
:param freq: the log will be dumped only after the timestep gap (holden by the time_step_holder) of recording is larger than freq.

"""
self.name2val[key] = val
self.exclude_name[key] = exclude
Expand Down
21 changes: 14 additions & 7 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self):

@deprecated_alias(task_name='task_table_name', private_config_path='rla_config', log_root='data_root')
def configure(self, task_table_name: str, rla_config: Union[str, dict], data_root: str,
ignore_file_path: Optional[str] = None, run_file: Optional[str] = None,
ignore_file_path: Optional[str] = None, run_file: Union[str, List[str]] = None,
is_master_node: bool = False, code_root: Optional[str] = None):
"""
The function to configure your exp_manager, which should be run before your experiments.
Expand All @@ -131,7 +131,7 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo
:type ignore_file_path: str
:param run_file: If you have extra files out of your codebase (e.g., some scripts to run the code), you can pass it to the run_file.
Then we will backup the run_file too.
:type run_file: str
:type run_file: str or list
:param is_master_node: In "distributed training & centralized logs" mode (By set SEND_LOG_FILE in rla_config.yaml to True),
you should mark the master node (is_master_node=True) to collect logs of the slave nodes (is_master_node=False).
:type is_master_node: bool
Expand Down Expand Up @@ -502,18 +502,25 @@ def get_ignore_files(self, src, names):

def __copy_source_code(self, run_file, code_dir):
import shutil
def _copy_run_file(run_file, code_dir):
if type(run_file) == list:
for file_name in run_file:
shutil.copy(file_name, code_dir)
else:
shutil.copy(run_file, code_dir)
if self.private_config["PROJECT_TYPE"]["backup_code_by"] == 'lib':
assert os.listdir(code_dir) == []
os.removedirs(code_dir)
shutil.copytree(osp.join(self.project_root, self.private_config["BACKUP_CONFIG"]["lib_dir"]), code_dir)
assert run_file is not None, "you should define the run_file in lib backup mode."
shutil.copy(run_file, code_dir)
_copy_run_file(run_file, code_dir)
elif self.private_config["PROJECT_TYPE"]["backup_code_by"] == 'source':
for dir_name in self.private_config["BACKUP_CONFIG"]["backup_code_dir"]:
shutil.copytree(osp.join(self.project_root, dir_name), osp.join(code_dir, dir_name),
ignore=self.get_ignore_files)
if self.private_config["BACKUP_CONFIG"].get("backup_code_dir"):
for dir_name in self.private_config["BACKUP_CONFIG"]["backup_code_dir"]:
shutil.copytree(osp.join(self.project_root, dir_name), osp.join(code_dir, dir_name),
ignore=self.get_ignore_files)
if run_file is not None:
shutil.copy(run_file, code_dir)
_copy_run_file(run_file, code_dir)
else:
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions test/test_proj/proj/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_log_torch(self):
logger.record_tabular("y_out-long", np.mean(y), freq=25)
def plot_func():
import matplotlib.pyplot as plt
# plt.switch_backend('agg')
testX = np.repeat(np.expand_dims(np.arange(-10, 10, 0.1), axis=-1), repeats=kwargs["input_size"], axis=-1)
testX = testX.astype(np.float32)
testY = target_func(testX)
Expand Down