Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RLA/auto_ftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import traceback
from RLA.const import *
from RLA.easy_log import logger

import pysftp


Expand All @@ -17,6 +16,7 @@ def ftp_factory(name, server, username, password, port, ignore=None):
else:
raise NotImplementedError


class FTPHandler(object):

def __init__(self, ftp_server, username, password, port, ignore=None):
Expand Down Expand Up @@ -139,6 +139,7 @@ def close(self):
self.ftp.quit()
self.ftp.close()


class SFTPHandler(FTPHandler):

def __init__(self, sftp_server, username, password, port, ignore=None):
Expand Down
2 changes: 1 addition & 1 deletion RLA/easy_log/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
OTHER_RESULTS = 'results'
ARCHIVED_TABLE = 'arc'
default_log_types = [LOG, CODE, CHECKPOINT, ARCHIVE_TESTER, OTHER_RESULTS]

HYPARAM = 'parameter'

class LoadTesterMode:
FORK_TO_NEW = 'fork'
Expand Down
8 changes: 5 additions & 3 deletions RLA/easy_log/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ def log_files_gen(self):
self.serialize_object_and_save()
self.__copy_source_code(self.run_file, code_dir)
self._feed_hyper_params_to_tb()
params = self.hyper_param
for param_dir in [self.code_dir, self.log_dir]:
with open(osp.join(param_dir, HYPARAM + '.json'), 'w') as f:
json.dump(params, f, sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
print("gen:", osp.join(param_dir, 'parameter.json'))
self.print_log_dir()

def update_log_files_location(self, root:str):
Expand Down Expand Up @@ -782,9 +787,6 @@ def print_args(self):
# formatted_log_name = self.log_name_formatter(self.get_task_table_name(), self.record_date)
params = exp_manager.hyper_param
# params['formatted_log_name'] = formatted_log_name
json.dump(params, open(osp.join(self.code_dir, 'parameter.json'), 'w'),
sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
print("gen:", osp.join(self.code_dir, 'parameter.json'))


def print_large_memory_variable(self):
Expand Down
13 changes: 7 additions & 6 deletions RLA/easy_plot/plot_func_v2.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Created by xionghuichen at 2022/8/10
# Email: chenxh@lamda.nju.edu.cn
import glob
import json
import os.path as osp
import os
import dill
import copy
import numpy as np
from typing import Dict, List, Tuple, Type, Union, Optional, Callable
import matplotlib.pyplot as plt

from RLA import logger
from RLA.const import DEFAULT_X_NAME
from RLA.query_tool import experiment_data_query, extract_valid_index

from RLA.easy_plot import plot_util
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS, HYPARAM



Expand All @@ -25,7 +24,6 @@ def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True):
else:
return task_split_key


def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
use_buf=False, verbose=True,
x_bound: Optional[int]=None,
Expand Down Expand Up @@ -97,7 +95,11 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
if verbose:
print("find log", v.dirname)
counter += 1
result.hyper_param = tester_dict[k].exp_manager.hyper_param
if os.path.exists(osp.join(v.dirname, HYPARAM + '.json')):
with open(osp.join(v.dirname, HYPARAM + '.json')) as f:
result.hyper_param = json.load(f)
else:
result.hyper_param = tester_dict[k].exp_manager.hyper_param
results.append(result)
reg_group[reg].append(result)
print("find log number", counter)
Expand Down Expand Up @@ -126,7 +128,6 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
split_by_metrics=split_by_metrics, regs2legends=regs2legends, *args, **kwargs)
print("--- complete process ---")
if save_name is not None:
import os
file_name = osp.join(data_root, OTHER_RESULTS, 'easy_plot', save_name)
os.makedirs(os.path.dirname(file_name), exist_ok=True)
if lgd is not None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"env_id": "Test-v1",
"info": "default exp info",
"input_size": 16,
"learning_rate": 0.001,
"loaded_date": true,
"loaded_task_name": "",
"seed": 88
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"input_size": 16,
"learning_rate": 0.0001
}
2 changes: 1 addition & 1 deletion test/test_plot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 16,
"id": "a55937ed",
"metadata": {
"scrolled": false
Expand Down
14 changes: 13 additions & 1 deletion test/test_proj/proj/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_log_tf(self):

exp_manager.new_saver(var_prefix='', max_to_keep=1)
# synthetic target function.
for i in range(0, 1000):
for i in range(0, 100):
exp_manager.time_step_holder.set_time(i)
x_input = np.random.normal(0, 3, [64, kwargs["input_size"]])
y = target_func(x_input)
Expand Down Expand Up @@ -143,10 +143,22 @@ def test_sent_to_master(self):
yaml = self._load_rla_config()
try:
from test.test_proj.proj import private_config
# try to import libs
except ImportError as e:
print("[WARN] for this test, you should config your username, password, and the remote root firstly.")
return
# raise RuntimeError
try:
if private_config.protocol == 'ftp':
import ftplib
elif private_config.protocol == 'sftp':
import pysftp
else:
raise NotImplementedError
except ImportError as e:
print(e)
print(f"[WARN] the select protocol {private_config.protocol} cannot be loaded. skip the unittest.")
return
yaml['DL_FRAMEWORK'] = 'torch'
yaml['SEND_LOG_FILE'] = True
yaml['REMOTE_SETTING']['ftp_server'] = '127.0.0.1'
Expand Down