Dev #11 (Merged)

2 changes: 1 addition & 1 deletion README.md
@@ -240,7 +240,7 @@ for i in range(1000):
# your training code.
exp_manager.sync_log_file()
```
Then the data items will be sent to the `remote_data_root` of the main node. Since `SEND_LOG_FILE` is set to False on the main node, `exp_manager.sync_log_file()` will be skipped there.
Then the data items will be sent to the `remote_data_root` of the main node. Since `is_master_node` is set to True on the main node, `exp_manager.sync_log_file()` will be skipped there.
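For concreteness, a minimal sketch of the guard this describes (illustrative only; the class body below is not RLA's actual implementation, and only the `is_master_node` / `remote_data_root` names come from the text above):

```python
# Illustrative sketch, not RLA's actual implementation.
class ExpManagerSketch:
    def __init__(self, is_master_node: bool, remote_data_root: str):
        self.is_master_node = is_master_node
        self.remote_data_root = remote_data_root

    def sync_log_file(self):
        if self.is_master_node:
            # The main node already owns remote_data_root; nothing to send.
            return
        # On worker nodes, push the local log files to the main node here.
        ...
```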

PS:
1. You might encounter a "socket.error: [Errno 111] Connection refused" error during this process. The solution can be found [here](https://stackoverflow.com/a/70784201/6055868).
7 changes: 4 additions & 3 deletions RLA/easy_log/exp_loader.py
@@ -72,7 +72,7 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_t
else:
return argparse.Namespace(**exp_manager.hyper_param)

def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None):
def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None, verbose=True):
"""

:param var_prefix: the prefix of the name scope (for tf) to load. Set to '' to load all parameters.
@@ -81,8 +81,9 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
if verbose:
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
# load checkpoint
load_res = {}
if var_prefix is not None:
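A usage sketch of the new `verbose` flag (assuming an `exp_loader` instance of this class; only `load_from_record_date(..., verbose=...)` itself comes from this diff):

```python
# Illustrative only: load all parameters but suppress the attribute dump of the loaded tester.
exp_loader.load_from_record_date(var_prefix='', verbose=False)
```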
16 changes: 11 additions & 5 deletions RLA/easy_log/log_tools.py
@@ -106,10 +106,12 @@ def __init__(self, proj_root, task_table_name, regex, filter, *args, **kwargs):
self.small_timestep_regs = []
super(DeleteLogTool, self).__init__(*args, **kwargs)

def _delete_related_log(self, regex, show=False):
def _delete_related_log(self, regex, show=False, delete_log_types=None):
log_found = 0
for log_type in self.log_types:
print(f"--- search {log_type} ---")
if delete_log_types is not None and log_type not in delete_log_types:
continue
root_dir_regex = osp.join(self.proj_root, log_type, self.task_table_name, regex)
empty = True
for root_dir in glob.glob(root_dir_regex):
@@ -144,15 +146,15 @@ def _delete_related_log(self, regex, show=False):
if empty: print("empty regex {}".format(root_dir_regex))
return log_found

def delete_related_log(self, skip_ask=False):
self._delete_related_log(show=True, regex=self.regex)
def delete_related_log(self, skip_ask=False, delete_log_types=None):
self._delete_related_log(show=True, regex=self.regex, delete_log_types=delete_log_types)
if skip_ask:
s = 'y'
else:
s = input("delete these files? (y/n)")
if s == 'y':
print("do delete ...")
return self._delete_related_log(show=False, regex=self.regex)
return self._delete_related_log(show=False, regex=self.regex, delete_log_types=delete_log_types)
else:
return 0

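A usage sketch of the new `delete_log_types` filter (the constructor values and the log-type names below are placeholders; only the `delete_log_types` keyword itself comes from this diff):

```python
# Illustrative only: restrict deletion to a subset of log types.
tool = DeleteLogTool(proj_root='../', task_table_name='demo_task',
                     regex='2023/03/01/*', filter=None)
tool.delete_related_log(skip_ask=False, delete_log_types=['log', 'checkpoint'])
```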
@@ -209,7 +211,11 @@ def _archive_log(self, show=False):
if os.path.isdir(root_dir):
if not show:
# os.makedirs(archiving_target, exist_ok=True)
shutil.copytree(root_dir, archiving_target)
try:
shutil.copytree(root_dir, archiving_target)
except FileExistsError as e:
print(e)

print("copy dir {}, to {}".format(root_dir, archiving_target))
else:
if not show:
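A side note on the `FileExistsError` guard added above: on Python 3.8+ an alternative is to let `copytree` merge into an existing target directory (a design-choice sketch, not part of this PR):

```python
import shutil

# Python >= 3.8: do not fail when archiving_target already exists.
shutil.copytree(root_dir, archiving_target, dirs_exist_ok=True)
```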
25 changes: 19 additions & 6 deletions RLA/easy_log/tester.py
@@ -15,6 +15,7 @@
import os.path as osp
import pprint

import numpy as np
import tensorboardX

from RLA.easy_log.time_step import time_step_holder
@@ -222,8 +223,8 @@ def update_log_files_location(self, root:str):
task_table_name = getattr(self, 'task_name', None)
print("[WARN] you are using an old-version RLA. "
"Some attributes' name have been changed (task_name->task_table_name).")
else:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
if task_table_name is None:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl')
@@ -430,9 +431,15 @@ def log_file_finder(cls, record_date, task_table_name='train', file_root='../che
raise NotImplementedError
for search_item in search_list:
if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))):
split_dir = search_item.split(' ')
# self.__ipaddr = split_dir[1]
info = " ".join(split_dir[2:])
try:
split_dir = search_item.split('_')
assert len(split_dir) >= 2
info = " ".join(split_dir[2:])
except AssertionError as e:
split_dir = search_item.split(' ')
# self.__ipaddr = split_dir[1]
info = "_".join(split_dir[2:])
print("[WARN] We find an old-version experiment data.")
logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info))
file_found = search_item
break
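The fallback above reads as: new-style log directory names join timestamp, IP, and info with underscores, while old-style names used spaces. A small illustrative parse (the directory names below are hypothetical):

```python
# Illustrative only: hypothetical directory names in the two formats.
new_style = "12-30-45-123456_127.0.0.1_lr=0.001"
old_style = "12-30-45-123456 127.0.0.1 lr=0.001 seed=1"

for name in (new_style, old_style):
    parts = name.split('_')
    if len(parts) < 2:          # old-style fallback
        parts = name.split(' ')
    ts, ip, info = parts[0], parts[1], " ".join(parts[2:])
    print(ts, ip, info)
```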
@@ -615,11 +622,17 @@ def load_checkpoint(self, ckp_index=None):
return int(max_iter), None
elif self.dl_framework == FRAMEWORK.torch:
import torch
all_ckps = sorted(os.listdir(self.checkpoint_dir))
all_ckps = os.listdir(self.checkpoint_dir)
ites = []
for ckps in all_ckps:
ites.append(int(ckps.split('checkpoint-')[1].split('.pt')[0]))
idx = np.argsort(ites)
all_ckps = np.array(all_ckps)[idx]
print("all checkpoints:")
pprint.pprint(all_ckps)
if ckp_index is None:
ckp_index = all_ckps[-1].split('checkpoint-')[1].split('.pt')[0]
print("loaded checkpoints:", "checkpoint-{}.pt".format(ckp_index))
return ckp_index, torch.load(self.checkpoint_dir + "checkpoint-{}.pt".format(ckp_index))

def auto_parse_info(self):
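A note on the checkpoint-sorting change above: replacing `sorted(os.listdir(...))` with a numeric sort via `np.argsort` matters because lexicographic ordering mis-ranks checkpoint indices. For example:

```python
# Illustrative file names following the checkpoint-<iter>.pt pattern used above.
names = ["checkpoint-9.pt", "checkpoint-10.pt", "checkpoint-100.pt"]

print(sorted(names))
# ['checkpoint-10.pt', 'checkpoint-100.pt', 'checkpoint-9.pt'] -> [-1] would pick checkpoint-9

print(sorted(names, key=lambda n: int(n.split('checkpoint-')[1].split('.pt')[0])))
# ['checkpoint-9.pt', 'checkpoint-10.pt', 'checkpoint-100.pt'] -> correct numeric order
```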