From 7141043d20abe8eb8c6686bad22f2542d804dd1d Mon Sep 17 00:00:00 2001 From: unknown <774005423@qq.com> Date: Tue, 11 Apr 2023 17:59:47 +0800 Subject: [PATCH] refactor: add a function to get the file-path seperator for cross OSs compatiblity --- RLA/const.py | 9 ++++++++- RLA/easy_log/log_tools.py | 15 +++++++++------ RLA/easy_log/tester.py | 7 ++++--- RLA/utils/utils.py | 26 ++++++++++++++++++++++++++ 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/RLA/const.py b/RLA/const.py index 20d3a89..299be24 100644 --- a/RLA/const.py +++ b/RLA/const.py @@ -9,4 +9,11 @@ class FTP_PROTOCOL_NAME: SFTP = 'sftp' class LOG_NAME_FORMAT_VERSION: - V1 = 'v1' \ No newline at end of file + V1 = 'v1' + + + +class PLATFORM_TYPE: + WIN = 'win' + LINUX = 'linux' + OTHER = 'other' diff --git a/RLA/easy_log/log_tools.py b/RLA/easy_log/log_tools.py index 0e851b3..5bacc3c 100644 --- a/RLA/easy_log/log_tools.py +++ b/RLA/easy_log/log_tools.py @@ -16,7 +16,7 @@ import json from RLA.easy_log.tester import Tester - +from RLA.utils.utils import get_dir_seperator class Filter(object): ALL = 'all' SMALL_TIMESTEP = 'small_ts' @@ -214,8 +214,9 @@ def _archive_log(self, show=False): empty = False if os.path.exists(root_dir): # remove the overlapped path. + septor = get_dir_seperator() archiving_target = osp.join(archive_root_dir, root_dir[prefix_len+1:]) - archiving_target_dir = '/'.join(archiving_target.split('/')[:-1]) + archiving_target_dir = septor.join(archiving_target.split(septor)[:-1]) os.makedirs(archiving_target_dir, exist_ok=True) if os.path.isdir(root_dir): if not show: @@ -269,21 +270,23 @@ def _migrate_log(self, show=False): if os.path.exists(root_dir): # remove the overlapped path. archiving_target = osp.join(target_root_dir, root_dir[prefix_len+1:]) - archiving_target_dir = '/'.join(archiving_target.split('/')[:-1]) + + septor = get_dir_seperator() + archiving_target_dir = septor.join(archiving_target.split(septor)[:-1]) + print("target dir", archiving_target_dir) os.makedirs(archiving_target_dir, exist_ok=True) if os.path.isdir(root_dir): + print("copy dir {}, to {}".format(root_dir, archiving_target)) if not show: # os.makedirs(archiving_target, exist_ok=True) try: shutil.copytree(root_dir, archiving_target) except FileExistsError as e: print(e) - - print("copy dir {}, to {}".format(root_dir, archiving_target)) else: + print("copy file {}, to {}".format(root_dir, archiving_target)) if not show: shutil.copy(root_dir, archiving_target) - print("copy file {}, to {}".format(root_dir, archiving_target)) else: print("no dir {}".format(root_dir)) if empty: print("empty regex {}".format(root_dir_regex)) diff --git a/RLA/easy_log/tester.py b/RLA/easy_log/tester.py index 2ef85fd..5b751a4 100644 --- a/RLA/easy_log/tester.py +++ b/RLA/easy_log/tester.py @@ -26,7 +26,7 @@ import shutil import argparse from typing import Dict, List, Tuple, Type, Union, Optional -from RLA.utils.utils import deprecated_alias, load_yaml +from RLA.utils.utils import deprecated_alias, load_yaml, get_dir_seperator from RLA.const import DEFAULT_X_NAME, FRAMEWORK import pathspec @@ -150,10 +150,11 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo logger.info("private_config: ") self.dl_framework = self.private_config["DL_FRAMEWORK"] self.is_master_node = is_master_node - + self.septor = get_dir_seperator() if code_root is None: if isinstance(rla_config, str): - self.project_root = "/".join(rla_config.split("/")[:-1]) + + self.project_root = self.septor.join(rla_config.split(self.septor)[:-1]) else: raise NotImplementedError("If you pass the rla_config dict directly, " "you should define the root of your codebase (for backup) explicitly by pass the code_root.") diff --git a/RLA/utils/utils.py b/RLA/utils/utils.py index 9138586..2416043 100644 --- a/RLA/utils/utils.py +++ b/RLA/utils/utils.py @@ -4,6 +4,32 @@ import functools import warnings +import platform +from RLA.const import * + + +def get_sys_type(): + systype = platform.system() + if systype.find('Windows') != -1: + return PLATFORM_TYPE.WIN + elif systype.find('Linux') != -1: + return PLATFORM_TYPE.LINUX + else: + return PLATFORM_TYPE.OTHER + +def get_dir_seperator(): + sys_flag = get_sys_type() + if sys_flag == PLATFORM_TYPE.WIN: + return '\\' + elif sys_flag == PLATFORM_TYPE.LINUX: + return '/' + elif sys_flag == PLATFORM_TYPE.OTHER: + print("[WARN] unrecognizable system type: ", sys_flag, "use default dir seperator") + return '/' + else: + raise NotImplementedError("[ERROR] undefined system flag", sys_flag) + + def deprecated_alias(**aliases): def deco(f): @functools.wraps(f)