Skip to content
Merged

Dev #12

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0d25abf
Merge pull request #19 from xionghuichen/dev
xionghuichen Jun 22, 2022
cf30b18
Update README.md
xionghuichen Jun 22, 2022
33e6aee
Dev (#20)
xionghuichen Jul 13, 2022
74c7712
fix: minor changes for version compatibility
xionghuichen Jul 14, 2022
e36a1bc
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 14, 2022
3bdbf3e
Dev (#21)
xionghuichen Jul 14, 2022
96ba639
fix: a bug of sorting in torch-version checkpoint loading
xionghuichen Jul 14, 2022
680c6be
Dev (#22)
xionghuichen Jul 14, 2022
44c2cff
refactor: robust multi-key plot implementation
xionghuichen Jul 21, 2022
ce25856
feat: support pretty plotter
xionghuichen Jul 23, 2022
efc4815
refactor(log plotter): record scores of the log plotter
xionghuichen Jul 23, 2022
3fca57c
fix(exp_loader): add parameter ckp_index
xionghuichen Jul 23, 2022
017efc2
refactor(rla_script): add start_server to start_pretty_plotter.py
xionghuichen Jul 23, 2022
7c0b7dc
update readme
xionghuichen Jul 23, 2022
4239bd6
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 23, 2022
29e4932
Dev (#23)
xionghuichen Jul 23, 2022
5a0c180
Merge branch 'main' of github.com:xionghuichen/RLAssistant into dev
xionghuichen Jul 27, 2022
d3ff59d
rm unresolved merge
xionghuichen Jul 27, 2022
65d2859
feat: tf-v2 compatible
xionghuichen Jul 27, 2022
845f3ab
refactor: add timestep recorder. refactor on exp_loader
xionghuichen Aug 10, 2022
9f799c3
test: add test data
xionghuichen Aug 10, 2022
a277416
feat(plot): track the hyper-parameter from the exp_manager instead of…
xionghuichen Aug 10, 2022
00d26d2
Merge branch 'main' of github.com:polixir/RLAssistant into dev
xionghuichen Aug 10, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,5 @@ PS:
- [ ] download / upload experiment logs through timestamp.
- [ ] add a document to the plot function.
- [ ] allow sync LOG only or ALL TYPE LOGS.
- [ ] support aim and smarter logger.
- [x] support aim and smarter logger.
- [ ] add unit_test to ckp loader.
3 changes: 3 additions & 0 deletions RLA/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from RLA.easy_log.tester import exp_manager
from RLA.easy_log import logger
from RLA.easy_plot.plot_func_v2 import plot_func
13 changes: 11 additions & 2 deletions RLA/easy_log/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,14 @@
ARCHIVED_TABLE = 'arc'
default_log_types = [LOG, CODE, CHECKPOINT, ARCHIVE_TESTER, OTHER_RESULTS]

class LoadTesterMode:
    """Modes for loading an archived experiment (tester)."""
    FORK_TO_NEW = 'fork'  # fork the loaded experiment into a new record


# Deprecated spelling kept for backward compatibility with older call sites;
# the original defined an identical duplicate class instead of an alias.
LOAD_TESTER_MODE = LoadTesterMode


# option: 'stdout', 'log', 'tensorboard', 'csv'
class LogDataType:
    """String identifiers for the supported log data formats."""

    STDOUT = 'stdout'
    TXT = 'log'
    CSV = 'csv'
    TB = 'tensorboard'
8 changes: 4 additions & 4 deletions RLA/easy_log/exp_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def import_hyper_parameters(self, hp_to_overwrite: Optional[list] = None, sync_t
else:
return argparse.Namespace(**exp_manager.hyper_param)

def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None, verbose=True):
def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list: Optional[list]=None, verbose=True,
ckp_index: Optional[int]=None):
"""

:param var_prefix: the prefix of namescope (for tf) to load. Set to '' to load all of the parameters.
Expand All @@ -81,9 +82,8 @@ def load_from_record_date(self, var_prefix: Optional[str] = None, variable_list:
"""
if self.is_valid_config:
loaded_tester = Tester.load_tester(self.load_date, self.task_name, self.data_root)
if verbose:
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
print("attrs of the loaded tester")
pprint(loaded_tester.__dict__)
# load checkpoint
load_res = {}
if var_prefix is not None:
Expand Down
66 changes: 57 additions & 9 deletions RLA/easy_log/log_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
import yaml
import pandas as pd
import csv
import dill

import json
from RLA.easy_log.tester import Tester

class Filter(object):
ALL = 'all'
Expand All @@ -27,18 +31,22 @@ def __init__(self, optional_log_type=None):
self.log_types = default_log_types.copy()
if optional_log_type is not None:
self.log_types.extend(optional_log_type)


def is_valid_index(self, regex):
    """
    Extract the RLA record-date index (``YYYY/MM/DD/HH-MM-SS-ffffff``) from a
    path-like string.

    :param regex: string (typically a directory path) to scan.
    :return: the matched index substring, or None when no index is present.
    """
    # search once and reuse the match (the original ran the same re.search twice)
    match = re.search(r'\d{4}/\d{2}/\d{2}/\d{2}-\d{2}-\d{2}-\d{6}', regex)
    return match.group(0) if match else None

def _find_small_timestep_log(self, proj_root, task_table_name, regex, timstep_upper_bound=np.inf, timestep_lower_bound=0):
small_timestep_regs = []
root_dir_regex = osp.join(proj_root, LOG, task_table_name, regex)
for root_dir in glob.glob(root_dir_regex):
print("searching dirs", root_dir)
if os.path.exists(root_dir):
for file_list in os.walk(root_dir):
if re.search(r'\d{4}/\d{2}/\d{2}/\d{2}-\d{2}-\d{2}-\d{6}', file_list[0]):
target_reg = re.search(r'\d{4}/\d{2}/\d{2}/\d{2}-\d{2}-\d{2}-\d{6}', file_list[0]).group(0)
else:
target_reg = None
target_reg = self.is_valid_index(file_list[0])
if target_reg is not None:
if LOG in root_dir_regex:
try:
Expand Down Expand Up @@ -236,6 +244,7 @@ def archive_log(self, skip_ask=False):
print("do archive ...")
self._archive_log(show=False)


class ViewLogTool(BasicLogTool):
def __init__(self, proj_root, task_table_name, regex, *args, **kwargs):
self.proj_root = proj_root
Expand All @@ -248,10 +257,7 @@ def _view_log(self, regex):
for root_dir in glob.glob(root_dir_regex):
if os.path.exists(root_dir):
for file_list in os.walk(root_dir):
if re.search(r'\d{4}/\d{2}/\d{2}/\d{2}-\d{2}-\d{2}-\d{6}', file_list[0]):
target_reg = re.search(r'\d{4}/\d{2}/\d{2}/\d{2}-\d{2}-\d{2}-\d{6}', file_list[0]).group(0)
else:
target_reg = None
target_reg = self.is_valid_index(file_list[0])
if target_reg is not None:
backup_file = file_list[0] + '/backup.txt'
if file_list[1] == ['tb'] or os.path.exists(backup_file): # in root of logdir
Expand All @@ -268,3 +274,45 @@ def view_log(self, skip_ask=False):
s = input("press y to view \n ")
if s == 'y':
self._view_log(regex=res[0] + '*')


class PrettyPlotterTool(BasicLogTool):
    """Generate ``parameter.json`` summaries for archived experiments so the
    pretty plotter can read their hyper-parameters."""

    def __init__(self, proj_root, task_table_name, regex, *args, **kwargs):
        # proj_root: root of the RLA data directory.
        # task_table_name / regex: select which archived testers to process.
        self.proj_root = proj_root
        self.task_table_name = task_table_name
        self.regex = regex
        super(PrettyPlotterTool, self).__init__(*args, **kwargs)

    def json_dump(self, location):
        """Load the pickled Tester at *location* and dump its hyper-parameters
        (plus the formatted log name) to ``parameter.json`` in the matching log
        directory. Missing or truncated pickle files are reported and skipped."""
        target_index = self.is_valid_index(location)
        if target_index is not None:
            json_location = None
            try:
                # use a context manager so the pickle handle is closed
                # deterministically (the original left it open)
                with open(location, 'rb') as pkl_f:
                    exp_manager = dill.load(pkl_f)
                assert isinstance(exp_manager, Tester)
                formatted_log_name = exp_manager.log_name_formatter(exp_manager.get_task_table_name(),
                                                                    exp_manager.record_date)
                params = exp_manager.hyper_param
                params['formatted_log_name'] = formatted_log_name

                json_location = exp_manager.log_name_formatter(
                    osp.join(self.proj_root, LOG, exp_manager.get_task_table_name()), exp_manager.record_date) + '/'
                with open(osp.join(json_location, 'parameter.json'), 'w') as json_f:
                    json.dump(params, json_f)
                print("gen:", osp.join(json_location, 'parameter.json'))
            except FileNotFoundError:
                # report the path we were actually working on; json_location is
                # still None when loading failed before it was computed
                print("log file cannot found", json_location or location)
            except EOFError:
                print("log file broken", json_location or location)

    def gen_json(self, regex):
        """Walk every archived tester matched by *regex* under the archive
        table and run :meth:`json_dump` on each pickle file found."""
        root_dir_regex = osp.join(self.proj_root, ARCHIVE_TESTER, self.task_table_name, regex)
        for root_dir in glob.glob(root_dir_regex):
            if os.path.exists(root_dir):
                if osp.isdir(root_dir):
                    for file_list in os.walk(root_dir):
                        for file in file_list[2]:
                            location = osp.join(file_list[0], file)
                            self.json_dump(location)
                else:
                    # regex matched a single pickle file directly
                    self.json_dump(root_dir)
53 changes: 38 additions & 15 deletions RLA/easy_log/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,21 @@ def __init__(self, dir, framework):
path = osp.join(osp.abspath(dir), prefix)
if self.framework == FRAMEWORK.tensorflow:
import tensorflow as tf
self.tb_writer = tf.summary.FileWriter(path)
from tensorflow.python import pywrap_tensorflow

self.tbx_writer = None
from tensorflow.python import pywrap_tensorflow
from tensorflow.core.util import event_pb2
from tensorflow.python.util import compat
try:
self.tb_writer = tf.summary.FileWriter(path)
self.tf = tf
except AttributeError as e:
# tf.compat.v1.disable_eager_execution()
# from tensorflow.python.client import _pywrap_events_writer
self.tb_writer = tf.summary.create_file_writer(path) # tf.compat.v1.summary.FileWriter(path)
# tf.compat.v1.disable_eager_execution()
# tf = tf.compat.v1
self.tf = tf
self.event_pb2 = event_pb2
self.pywrap_tensorflow = pywrap_tensorflow
Expand Down Expand Up @@ -284,10 +294,16 @@ def writer(self):
def add_hyper_params_to_tb(self, hyper_param, metric_dict=None):
if self.framework == FRAMEWORK.tensorflow:
import tensorflow as tf
with tf.Session(graph=tf.Graph()) as sess:
hyperparameters = [tf.convert_to_tensor([k, str(v)]) for k, v in hyper_param.items()]
summary = sess.run(tf.summary.text('hyperparameters', tf.stack(hyperparameters)))
self.tb_writer.add_summary(summary, self.step)
try:
with tf.Session(graph=tf.Graph()) as sess:
hyperparameters = [tf.convert_to_tensor([k, str(v)]) for k, v in hyper_param.items()]
summary = sess.run(tf.summary.text('hyperparameters', tf.stack(hyperparameters)))
self.tb_writer.add_summary(summary, self.step)
except AttributeError as e:
tf.compat.v1.enable_eager_execution()
with self.tb_writer.as_default():
hyperparameters = [tf.convert_to_tensor([k, str(v)]) for k, v in hyper_param.items()]
tf.summary.text('hyperparameters', tf.stack(hyperparameters), step=self.step)
elif self.framework == FRAMEWORK.torch:
import pprint
if metric_dict is None:
Expand All @@ -300,14 +316,20 @@ def add_hyper_params_to_tb(self, hyper_param, metric_dict=None):

def writekvs(self, kvs):
if self.framework == FRAMEWORK.tensorflow:
def summary_val(k, v):
kwargs = {'tag': k, 'simple_value': float(v)}
return self.tf.Summary.Value(**kwargs)
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = self.step # is there any reason why you'd want to specify the step?
self.writer.add_event(event)
self.writer.flush()
try:
def summary_val(k, v):
kwargs = {'tag': k, 'simple_value': float(v)}
return self.tf.Summary.Value(**kwargs)
summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = self.step # is there any reason why you'd want to specify the step?
self.writer.add_event(event)
self.writer.flush()
except AttributeError as e:
self.tf.compat.v1.enable_eager_execution()
with self.tb_writer.as_default():
for k, v in kvs.items():
self.tf.summary.scalar(k, v, step=self.step)
elif self.framework == FRAMEWORK.torch:
def summary_val(k, v):
kwargs = {'tag': k, 'scalar_value': float(v), 'global_step': self.step}
Expand Down Expand Up @@ -395,7 +417,7 @@ def ma_record_tabular(key, val, record_len, ignore_nan=False, exclude:Optional[U
if len(ma_dict[key]) == record_len:
record_tabular(key, np.mean(ma_dict[key]), exclude)

def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None):
def logkv(key, val, exclude:Optional[Union[str, Tuple[str, ...]]]=None, freq:Optional[int]=None):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration

    :param key: (Any) save to log this key
    :param val: (Any) save to log this value
    :param exclude: passed through to the underlying logger's ``logkv``;
        presumably names output formats to skip for this key — TODO confirm
    :param freq: (int, optional) when given, only log if the current
        ``timestep()`` is a multiple of ``freq``; when None, log every call
    """
    if freq is None or timestep() % freq == 0:
        get_current().logkv(key, val, exclude)


def log_from_tf_summary(summary):
Expand Down
32 changes: 14 additions & 18 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import os.path as osp
import pprint

import numpy as np
import tensorboardX

from RLA.easy_log.time_step import time_step_holder
Expand Down Expand Up @@ -162,6 +161,15 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo
for k, v in self.private_config.items():
logger.info("k: {}, v: {}".format(k, v))

def get_task_table_name(self):
    """Return the experiment's task table name.

    Falls back to the legacy ``task_name`` attribute (with a warning) for
    data written by older RLA versions; raises RuntimeError when neither
    attribute is present.
    """
    name = getattr(self, 'task_table_name', None)
    if name is None:
        name = getattr(self, 'task_name', None)
        print("[WARN] you are using an old-version RLA. "
              "Some attributes' name have been changed (task_name->task_table_name).")
    if name is None:
        raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
    return name

def set_hyper_param(self, **argkw):
"""
Expand Down Expand Up @@ -218,13 +226,7 @@ def update_log_files_location(self, root:str):
"""
self.data_root = root

task_table_name = getattr(self, 'task_table_name', None)
if task_table_name is None:
task_table_name = getattr(self, 'task_name', None)
print("[WARN] you are using an old-version RLA. "
"Some attributes' name have been changed (task_name->task_table_name).")
if task_table_name is None:
raise RuntimeError("invalid ExpManager: task_table_name cannot be found", )
task_table_name = self.get_task_table_name()
code_dir, _ = self.__create_file_directory(osp.join(self.data_root, CODE, task_table_name), '', is_file=False)
log_dir, _ = self.__create_file_directory(osp.join(self.data_root, LOG, task_table_name), '', is_file=False)
self.pkl_dir, self.pkl_file = self.__create_file_directory(osp.join(self.data_root, ARCHIVE_TESTER, task_table_name), '.pkl')
Expand Down Expand Up @@ -431,15 +433,9 @@ def log_file_finder(cls, record_date, task_table_name='train', file_root='../che
raise NotImplementedError
for search_item in search_list:
if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))):
try:
split_dir = search_item.split('_')
assert len(split_dir) >= 2
info = " ".join(split_dir[2:])
except AssertionError as e:
split_dir = search_item.split(' ')
# self.__ipaddr = split_dir[1]
info = "_".join(split_dir[2:])
print("[WARN] We find an old-version experiment data.")
split_dir = search_item.split(' ')
# self.__ipaddr = split_dir[1]
info = " ".join(split_dir[2:])
logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(split_dir[0], split_dir[1], info))
file_found = search_item
break
Expand Down Expand Up @@ -625,14 +621,14 @@ def load_checkpoint(self, ckp_index=None):
all_ckps = os.listdir(self.checkpoint_dir)
ites = []
for ckps in all_ckps:
print("ckps", ckps)
ites.append(int(ckps.split('checkpoint-')[1].split('.pt')[0]))
idx = np.argsort(ites)
all_ckps = np.array(all_ckps)[idx]
print("all checkpoints:")
pprint.pprint(all_ckps)
if ckp_index is None:
ckp_index = all_ckps[-1].split('checkpoint-')[1].split('.pt')[0]
print("loaded checkpoints:", "checkpoint-{}.pt".format(ckp_index))
return ckp_index, torch.load(self.checkpoint_dir + "checkpoint-{}.pt".format(ckp_index))

def auto_parse_info(self):
Expand Down
38 changes: 38 additions & 0 deletions RLA/easy_log/time_used_recorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Created by xionghuichen at 2022/7/29
# Email: chenxh@lamda.nju.edu.cn
from RLA.easy_log import logger
import time


rc_start_time = {}  # name -> wall-clock start time of each active recorder


def time_record(name):
    """
    record the consumed time of your code snippet. call this function to start a recorder.
    "name" is an identifier to distinguish different recorders and record different snippets at the same time.
    call time_record_end to end a recorder.
    :param name: identifier of your code snippet.
    :type name: str
    :return: None
    :rtype: None
    """
    # a plain `assert` would be stripped under `python -O`; raise explicitly
    if name in rc_start_time:
        raise ValueError("a time recorder named '{}' is already running".format(name))
    rc_start_time[name] = time.time()


def time_record_end(name):
    """
    Stop the recorder started by ``time_record`` and log the elapsed time.
    The duration is recorded in the tabular log under ``time_used/<name>`` and
    echoed via ``logger.info``; the recorder entry is then removed so the same
    name can be reused.
    :param name: identifier of your code snippet.
    :type name: str
    :return:
    :rtype:
    """
    elapsed = time.time() - rc_start_time[name]
    logger.record_tabular("time_used/{}".format(name), elapsed)
    logger.info("[test] func {0} time used {1:.2f}".format(name, elapsed))
    del rc_start_time[name]
Loading