Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,11 @@ PS:
2. An alternative is to build your own NFS for your physical machines and point data_root at the NFS.

# TODO
- [ ] support sftp-based sync.
- [ ] support custom data structure saving and loading.
- [ ] support video visualization.
- [ ] add comments and documentation to the functions.
- [ ] add an auto integration script.
- [ ] download / upload experiment logs by timestamp.
- [ ] add documentation for the plot function.
- [ ] allow syncing either logs only or all log types.
- [ ] support aim and smarter logger.
3 changes: 3 additions & 0 deletions RLA/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ class FRAMEWORK:
class FTP_PROTOCOL_NAME:
    # String identifiers for the supported file-transfer protocols used
    # when syncing experiment logs to a remote host.
    FTP = 'ftp'
    SFTP = 'sftp'

class LOG_NAME_FORMAT_VERSION:
    # Version tags for the on-disk log-name format. V1 joins the
    # timestep, ip and info fields with underscores when building
    # log/checkpoint directory and file names (see
    # Tester.__create_file_directory, which selects the format by this tag).
    V1 = 'v1'
21 changes: 15 additions & 6 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import argparse
from typing import Optional, OrderedDict, Union, Dict, Any
from RLA.const import DEFAULT_X_NAME
from pprint import pprint

class ExperimentLoader(object):
"""
Expand Down Expand Up @@ -32,7 +33,9 @@ class ExperimentLoader(object):
def __init__(self):
self.task_name = exp_manager.hyper_param.get('loaded_task_name', None)
self.load_date = exp_manager.hyper_param.get('loaded_date', None)
self.data_root = getattr(exp_manager, 'root', None)
self.data_root = getattr(exp_manager, 'data_root', None)
if self.data_root is None:
self.data_root = getattr(exp_manager, 'root', None)
pass

def config(self, task_name, record_date, root):
Expand All @@ -51,17 +54,20 @@ def is_valid_config(self):
logger.warn("root", self.data_root)
return False

def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None):
def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_timestep=False):
if self.is_valid_config:
load_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
target_hp = copy.deepcopy(exp_manager.hyper_param)
target_hp.update(load_tester.hyper_param)
target_hp.update(loaded_tester.hyper_param)
if hp_to_overwrite is not None:
for v in hp_to_overwrite:
target_hp[v] = exp_manager.hyper_param[v]
args = argparse.Namespace(**target_hp)
args.load_date = self.load_date
args.load_task_name = self.task_name
if sync_timestep:
load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME)
exp_manager.time_step_holder.set_time(load_iter)
return args
else:
return argparse.Namespace(**exp_manager.hyper_param)
Expand All @@ -75,18 +81,21 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
# load checkpoint
load_res = {}
if var_prefix is not None:
loaded_tester.new_saver(var_prefix=var_prefix, max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
exp_manager.print_log_dir()
else:
loaded_tester.new_saver(max_to_keep=1)
_, load_res = loaded_tester.load_checkpoint()
hist_variables = {}
if variable_list is not None:
for v in variable_list:
hist_variables[v] = loaded_tester.get_custom_data(v)
load_iter = loaded_tester.get_custom_data(DEFAULT_X_NAME)
exp_manager.time_step_holder.set_time(load_iter)
return load_iter, load_res, hist_variables
else:
return 0, {}, {}
Expand Down
4 changes: 4 additions & 0 deletions RLA/easy_log/log_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def delete_small_timestep_log(self, skip_ask=False):
for res in self.small_timestep_regs:
print("[delete small-timestep log] reg: ", res[1])
self._delete_related_log(show=True, regex=res[0] + '*')
print("summarize:")
for count, res in enumerate(self.small_timestep_regs):
print(f"[delete small-timestep log] {count} reg: {res[1]}")

if skip_ask:
s = 'y'
else:
Expand Down
65 changes: 49 additions & 16 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

import tensorboardX

from RLA.easy_log.const import *
from RLA.easy_log.time_step import time_step_holder
from RLA.easy_log import logger
from RLA.easy_log.const import *
from RLA.const import *
import yaml
import shutil
import argparse
Expand Down Expand Up @@ -107,6 +107,8 @@ def __init__(self):
self.code_dir = None
self.saver = None
self.dl_framework = None
self.checkpoint_keep_list = None
self.log_name_format_version = LOG_NAME_FORMAT_VERSION.V1

@deprecated_alias(task_name='task_table_name', private_config_path='rla_config', log_root='data_root')
def configure(self, task_table_name: str, rla_config: Union[str, dict], data_root: str,
Expand Down Expand Up @@ -205,13 +207,28 @@ def log_files_gen(self):
self._feed_hyper_params_to_tb()
self.print_log_dir()

def update_log_files_location(self, root):
def update_log_files_location(self, root:str):
"""
This function is designed for the requirement of using copied/moved experiment logs to other databases for downstream task.
The location of the experiment logs might have changed compared with their original location.
The function automatically update the attributes related to the data_root to the current location.
:param root: current data_root
:type root: str
"""
self.data_root = root
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, self.task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, self.task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, self.task_table_name), '.pkl')
self.checkpoint_dir, _ = self.__create_file_directory(osp.join(self.data_root, CHECKPOINT, self.task_table_name), is_file=False)
self.results_dir, _ = self.__create_file_directory(osp.join(self.data_root, OTHER_RESULTS, self.task_table_name), is_file=False)

task_table_name = getattr(self, 'task_table_name', None)
if task_table_name is None:
task_table_name = getattr(self, 'task_name', None)
print("[WARN] you are using an old-version RLA. "
"Some attributes' name have been changed (task_name->task_table_name).")
else:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl')
self.checkpoint_dir, _ = self.__create_file_directory(osp.join(self.data_root, CHECKPOINT, task_table_name), is_file=False)
self.results_dir, _ = self.__create_file_directory(osp.join(self.data_root, OTHER_RESULTS, task_table_name), is_file=False)
self.log_dir = log_dir
self.code_dir = code_dir
self.print_log_dir()
Expand Down Expand Up @@ -487,15 +504,23 @@ def __create_file_directory(self, prefix, ext='', is_file=True, record_date=None
record_date = self.record_date
directory = str(record_date.strftime("%Y/%m/%d"))
directory = osp.join(prefix, directory)
version_num = getattr(self, 'log_name_format_version', None)

if version_num is None:
name_format = '{dir}/{timestep} {ip} {info}{ext}'
elif version_num == LOG_NAME_FORMAT_VERSION.V1:
name_format = '{dir}/{timestep}_{ip}_{info}{ext}'
else:
raise RuntimeError("unknown version name", version_num)

if is_file:
os.makedirs(directory, exist_ok=True)
file_name = '{dir}/{timestep}_{ip}_{info}{ext}'.format(dir=directory,
timestep=self.record_date_to_str(record_date),
file_name = name_format.format(dir=directory, timestep=self.record_date_to_str(record_date),
ip=str(self.ipaddr),
info=self.info,
ext=ext)
else:
directory = '{dir}/{timestep}_{ip}_{info}{ext}/'.format(dir=directory,
directory = (name_format + '/').format(dir=directory,
timestep=self.record_date_to_str(record_date),
ip=str(self.ipaddr),
info=self.info,
Expand Down Expand Up @@ -545,7 +570,6 @@ def new_saver(self, max_to_keep, var_prefix=None):
self.saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep, filename=self.checkpoint_dir, save_relative_paths=True)
elif self.dl_framework == FRAMEWORK.torch:
self.max_to_keep = max_to_keep
self.checkpoint_keep_list = []
else:
raise NotImplementedError

Expand All @@ -558,6 +582,8 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
elif self.dl_framework == FRAMEWORK.torch:
import torch
if self.checkpoint_keep_list is None:
self.checkpoint_keep_list = []
iter = self.time_step_holder.get_time()
torch.save(model_dict, f=tester.checkpoint_dir + "checkpoint-{}.pt".format(iter))
self.checkpoint_keep_list.append(iter)
Expand All @@ -574,20 +600,27 @@ def save_checkpoint(self, model_dict: Optional[dict]=None, related_variable: Opt
self.add_custom_data(k, v, type(v), mode='replace')
self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int, mode='replace')

def load_checkpoint(self):
def load_checkpoint(self, ckp_index=None):
if self.dl_framework == FRAMEWORK.tensorflow:
# TODO: load with variable scope.
import tensorflow as tf
cpt_name = osp.join(self.checkpoint_dir)
logger.info("load checkpoint {}".format(cpt_name))
ckpt_path = tf.train.latest_checkpoint(cpt_name)
if ckp_index is None:
ckpt_path = tf.train.latest_checkpoint(cpt_name)
else:
ckpt_path = tf.train.latest_checkpoint(cpt_name, ckp_index)
self.saver.restore(tf.get_default_session(), ckpt_path)
max_iter = ckpt_path.split('-')[-1]
self.time_step_holder.set_time(max_iter)
return int(max_iter), None
elif self.dl_framework == FRAMEWORK.torch:
import torch
return self.checkpoint_keep_list[-1], torch.load(tester.checkpoint_dir + "checkpoint-{}.pt".format(self.checkpoint_keep_list[-1]))
all_ckps = sorted(os.listdir(self.checkpoint_dir))
print("all checkpoints:")
pprint.pprint(all_ckps)
if ckp_index is None:
ckp_index = all_ckps[-1].split('checkpoint-')[1].split('.pt')[0]
return ckp_index, torch.load(self.checkpoint_dir + "checkpoint-{}.pt".format(ckp_index))

def auto_parse_info(self):
return '&'.join(self.hyper_param_record)
Expand Down Expand Up @@ -648,7 +681,7 @@ def serialize_object_and_save(self):
saver = self.saver
self.saver = None
with open(self.pkl_file, 'wb') as f:
dill.dump(self, f)
dill.dump(self, f, recurse=True)
self.writer = writer
self.saver = saver

Expand Down
19 changes: 19 additions & 0 deletions RLA/rla_argparser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import argparse


def boolean_flag(parser: argparse.ArgumentParser, name, default=False, help=None):
    """Register a complementary pair of command-line switches on *parser*.

    ``--<name>`` stores ``True`` and ``--no-<name>`` stores ``False`` into
    the same destination attribute; dashes in *name* become underscores in
    the attribute name.

    Parameters
    ----------
    parser: argparse.ArgumentParser
        parser to add the flag to
    name: str
        --<name> will enable the flag, while --no-<name> will disable it
    default: bool or None
        default value of the flag
    help: str
        help string for the flag
    """
    dest_attr = name.replace('-', '_')
    # Positive switch carries the default and the help text.
    parser.add_argument("--{}".format(name), action="store_true",
                        default=default, dest=dest_attr, help=help)
    # Negative switch only overrides the shared destination.
    parser.add_argument("--no-{}".format(name), action="store_false", dest=dest_attr)


def arg_parser_postprocess(parser: argparse.ArgumentParser):
parser.add_argument('--loaded_task_name', default='', type=str)
parser.add_argument('--info', default='default exp info', type=str)
Expand Down
1 change: 1 addition & 0 deletions example/simplest_code/project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def get_param():
parser.add_argument('--env_id', help='environment ID', default='Test-v1')
parser.add_argument('--learning_rate', help='a hyperparameter', default=1e-3, type=float)
parser.add_argument('--input_size', help='a hyperparameter', default=16, type=int)
# NOTE: add some recommended hyper-parameters for RLA.
parser = arg_parser_postprocess(parser)
args = parser.parse_args()
kwargs = vars(args)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='RLA',
version="0.5.3",
version="0.6.0-pre",
description=(
'RL assistant'
),
Expand Down