diff --git a/RLA/auto_ftp.py b/RLA/auto_ftp.py index 4a177b4..d125163 100644 --- a/RLA/auto_ftp.py +++ b/RLA/auto_ftp.py @@ -5,7 +5,6 @@ import traceback from RLA.const import * from RLA.easy_log import logger - import pysftp @@ -17,6 +16,7 @@ def ftp_factory(name, server, username, password, port, ignore=None): else: raise NotImplementedError + class FTPHandler(object): def __init__(self, ftp_server, username, password, port, ignore=None): @@ -139,6 +139,7 @@ def close(self): self.ftp.quit() self.ftp.close() + class SFTPHandler(FTPHandler): def __init__(self, sftp_server, username, password, port, ignore=None): diff --git a/RLA/easy_log/const.py b/RLA/easy_log/const.py index 4eaa5d9..2c432a1 100644 --- a/RLA/easy_log/const.py +++ b/RLA/easy_log/const.py @@ -5,7 +5,7 @@ OTHER_RESULTS = 'results' ARCHIVED_TABLE = 'arc' default_log_types = [LOG, CODE, CHECKPOINT, ARCHIVE_TESTER, OTHER_RESULTS] - +HYPARAM = 'parameter' class LoadTesterMode: FORK_TO_NEW = 'fork' diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index ba7b2ed..af2b2bf 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -215,6 +215,11 @@ def log_files_gen(self): self.serialize_object_and_save() self.__copy_source_code(self.run_file, code_dir) self._feed_hyper_params_to_tb() + params = self.hyper_param + for param_dir in [self.code_dir, self.log_dir]: + with open(osp.join(param_dir, HYPARAM + '.json'), 'w') as f: + json.dump(params, f, sort_keys=True, indent=4, allow_nan=True, default=lambda o: '') + print("gen:", osp.join(param_dir, 'parameter.json')) self.print_log_dir() def update_log_files_location(self, root:str): @@ -782,9 +787,6 @@ def print_args(self): # formatted_log_name = self.log_name_formatter(self.get_task_table_name(), self.record_date) params = exp_manager.hyper_param # params['formatted_log_name'] = formatted_log_name - json.dump(params, open(osp.join(self.code_dir, 'parameter.json'), 'w'), - sort_keys=True, indent=4, allow_nan=True, default=lambda o: '') - print("gen:", osp.join(self.code_dir, 'parameter.json')) def print_large_memory_variable(self): diff --git a/RLA/easy_plot/plot_func_v2.py b/RLA/easy_plot/plot_func_v2.py index d1ed652..f309d1e 100644 --- a/RLA/easy_plot/plot_func_v2.py +++ b/RLA/easy_plot/plot_func_v2.py @@ -1,6 +1,7 @@ # Created by xionghuichen at 2022/8/10 # Email: chenxh@lamda.nju.edu.cn import glob +import json import os.path as osp import os import dill @@ -8,13 +9,11 @@ import numpy as np from typing import Dict, List, Tuple, Type, Union, Optional, Callable import matplotlib.pyplot as plt - from RLA import logger from RLA.const import DEFAULT_X_NAME from RLA.query_tool import experiment_data_query, extract_valid_index - from RLA.easy_plot import plot_util -from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS +from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS, HYPARAM @@ -25,7 +24,6 @@ def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True): else: return task_split_key - def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list, use_buf=False, verbose=True, x_bound: Optional[int]=None, @@ -97,7 +95,11 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me if verbose: print("find log", v.dirname) counter += 1 - result.hyper_param = tester_dict[k].exp_manager.hyper_param + if os.path.exists(osp.join(v.dirname, HYPARAM + '.json')): + with open(osp.join(v.dirname, HYPARAM + '.json')) as f: + result.hyper_param = json.load(f) + else: + result.hyper_param = tester_dict[k].exp_manager.hyper_param results.append(result) reg_group[reg].append(result) print("find log number", counter) @@ -126,7 +128,6 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me split_by_metrics=split_by_metrics, regs2legends=regs2legends, *args, **kwargs) print("--- complete process ---") if save_name is not None: - import os file_name = osp.join(data_root, OTHER_RESULTS, 'easy_plot', save_name) os.makedirs(os.path.dirname(file_name), exist_ok=True) if lgd is not None: diff --git a/test/test_data_root/log/demo_task/2022/11/24/18-46-30-917022_172.16.0.147_&env_id=Test-v1&learning_rate=0.001&seed=88/parameter.json b/test/test_data_root/log/demo_task/2022/11/24/18-46-30-917022_172.16.0.147_&env_id=Test-v1&learning_rate=0.001&seed=88/parameter.json new file mode 100644 index 0000000..f184d95 --- /dev/null +++ b/test/test_data_root/log/demo_task/2022/11/24/18-46-30-917022_172.16.0.147_&env_id=Test-v1&learning_rate=0.001&seed=88/parameter.json @@ -0,0 +1,9 @@ +{ + "env_id": "Test-v1", + "info": "default exp info", + "input_size": 16, + "learning_rate": 0.001, + "loaded_date": true, + "loaded_task_name": "", + "seed": 88 +} \ No newline at end of file diff --git a/test/test_data_root/log/demo_task/2023/02/02/13-27-13-132216_11.0.91.89_&input_size=16/parameter.json b/test/test_data_root/log/demo_task/2023/02/02/13-27-13-132216_11.0.91.89_&input_size=16/parameter.json new file mode 100644 index 0000000..8757b29 --- /dev/null +++ b/test/test_data_root/log/demo_task/2023/02/02/13-27-13-132216_11.0.91.89_&input_size=16/parameter.json @@ -0,0 +1,4 @@ +{ + "input_size": 16, + "learning_rate": 0.0001 +} \ No newline at end of file diff --git a/test/test_plot.ipynb b/test/test_plot.ipynb index f3c5280..320d924 100644 --- a/test/test_plot.ipynb +++ b/test/test_plot.ipynb @@ -731,7 +731,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 16, "id": "a55937ed", "metadata": { "scrolled": false diff --git a/test/test_proj/proj/test_manager.py b/test/test_proj/proj/test_manager.py index b6c21f9..8d4ea30 100644 --- a/test/test_proj/proj/test_manager.py +++ b/test/test_proj/proj/test_manager.py @@ -55,7 +55,7 @@ def test_log_tf(self): exp_manager.new_saver(var_prefix='', max_to_keep=1) # synthetic target function. - for i in range(0, 1000): + for i in range(0, 100): exp_manager.time_step_holder.set_time(i) x_input = np.random.normal(0, 3, [64, kwargs["input_size"]]) y = target_func(x_input) @@ -143,10 +143,22 @@ def test_sent_to_master(self): yaml = self._load_rla_config() try: from test.test_proj.proj import private_config + # try to import libs except ImportError as e: print("[WARN] for this test, you should config your username, password, and the remote root firstly.") return # raise RuntimeError + try: + if private_config.protocol == 'ftp': + import ftplib + elif private_config.protocol == 'sftp': + import pysftp + else: + raise NotImplementedError + except ImportError as e: + print(e) + print(f"[WARN] the select protocol {private_config.protocol} cannot be loaded. skip the unittest.") + return yaml['DL_FRAMEWORK'] = 'torch' yaml['SEND_LOG_FILE'] = True yaml['REMOTE_SETTING']['ftp_server'] = '127.0.0.1'