From df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c Mon Sep 17 00:00:00 2001 From: liuyuliang Date: Wed, 13 Apr 2022 17:30:47 +0800 Subject: [PATCH 1/4] [CLI] add CLI launcher --- bin/colossal | 6 + colossalai/launcher/__init__.py | 0 colossalai/launcher/multinode_runner.py | 82 +++++++ colossalai/launcher/run.py | 286 ++++++++++++++++++++++++ setup.py | 3 + 5 files changed, 377 insertions(+) create mode 100644 bin/colossal create mode 100644 colossalai/launcher/__init__.py create mode 100644 colossalai/launcher/multinode_runner.py create mode 100644 colossalai/launcher/run.py diff --git a/bin/colossal b/bin/colossal new file mode 100644 index 000000000000..8e9fcb29b7d8 --- /dev/null +++ b/bin/colossal @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 + +from colossalai.launcher.run import main + +if __name__ == '__main__': + main() diff --git a/colossalai/launcher/__init__.py b/colossalai/launcher/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/colossalai/launcher/multinode_runner.py b/colossalai/launcher/multinode_runner.py new file mode 100644 index 000000000000..188d2c81fe14 --- /dev/null +++ b/colossalai/launcher/multinode_runner.py @@ -0,0 +1,82 @@ +import os +import sys +import shutil +import subprocess +import warnings +from shlex import quote +from abc import ABC, abstractmethod + +from colossalai.logging import get_dist_logger + +logger = get_dist_logger() + +class MultiNodeRunner(ABC): + def __init__(self, args): + self.args = args + self.user_arguments = self.args.user_args + self.user_script = args.user_script + self.exports = {} + + @abstractmethod + def backend_exists(self): + """Return whether the corresponding backend exists""" + + @abstractmethod + def get_cmd(self, environment, active_devices): + """Return the command to execute on node""" + + def add_export(self, key, var): + self.exports[key.strip()] = var.strip() + + @property + def name(self): + """Return the name of the backend""" + return self.__class__.__name__ + + +class PDSHRunner(MultiNodeRunner): + def __init__(self, args): + super().__init__(args) + + def backend_exists(self): + return shutil.which('pdsh') + + @property + def name(self): + return "pdsh" + + def parse_user_args(self): + return list( + map(lambda x: x if x.startswith("-") else f"'{x}'", + self.args.user_args)) + + def get_cmd(self, environment, active_devices, args): + environment['PDSH_RCMD_TYPE'] = 'ssh' + + active_workers = ",".join(active_devices.keys()) + logger.info("Running on the following workers: %s" % active_workers) + + pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers] + + exports = "" + for key, val in self.exports.items(): + exports += f"export {key}={quote(val)}; " + + # https://linux.die.net/man/1/pdsh + # %n will be replaced by pdsh command + colossal_launch = [ + exports, + f"cd {os.path.abspath('.')};", + sys.executable, "-m", + "torch.distributed.launch", + f"--nproc_per_node={args.num_gpus}", + f"--master_addr={args.master_addr}", + f"--master_port={args.master_port}" + ] + return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments + +class OpenMPIRunner(MultiNodeRunner): + pass + +class SLURMRunner(MultiNodeRunner): + pass \ No newline at end of file diff --git a/colossalai/launcher/run.py b/colossalai/launcher/run.py new file mode 100644 index 000000000000..a9c2f00a7871 --- /dev/null +++ b/colossalai/launcher/run.py @@ -0,0 +1,286 @@ +import argparse +from argparse import ArgumentParser, REMAINDER +import subprocess +import collections +import sys +import os +import 
torch
+from colossalai.logging import get_dist_logger
+from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner
+
+def build_args_parser() -> ArgumentParser:
+    """Helper function parsing the command line options."""
+
+    parser = ArgumentParser(description="colossal distributed training launcher")
+
+    parser.add_argument("-H",
+                        "--hostfile",
+                        type=str,
+                        default="",
+                        help="Hostfile path that defines the "
+                        "device pool available to the job (e.g., "
+                        "worker-name:number of slots)")
+
+    parser.add_argument("-i",
+                        "--include",
+                        type=str,
+                        default="",
+                        help="Specify computing devices to use during execution. "
+                        "String format is NODE_SPEC@NODE_SPEC, "
+                        "where NODE_SPEC=<hostname>:<comma-separated slots>")
+
+    parser.add_argument("-e",
+                        "--exclude",
+                        type=str,
+                        default="",
+                        help="Specify computing devices to NOT use during execution. "
+                        "Mutually exclusive with --include. Formatting "
+                        "is the same as --include.")
+
+    parser.add_argument("--num_nodes",
+                        type=int,
+                        default=-1,
+                        help="Total number of worker nodes to use.")
+
+    parser.add_argument("--num_gpus",
+                        type=int,
+                        default=-1,
+                        help="Number of GPUs to use on each node.")
+
+    parser.add_argument("--master_port",
+                        default=29500,
+                        type=int,
+                        help="(optional) Port used by PyTorch distributed for "
+                        "communication during distributed training.")
+
+    parser.add_argument("--master_addr",
+                        default="127.0.0.1",
+                        type=str,
+                        help="(optional) IP address of node 0, will be "
+                        "inferred via 'hostname -I' if not specified.")
+
+    parser.add_argument("--launcher",
+                        default="torch",
+                        type=str,
+                        help="(optional) choose launcher backend for multi-node "
+                        "training. Options currently include PDSH, OpenMPI, SLURM.")
+
+    parser.add_argument("--launcher_args",
+                        default="",
+                        type=str,
+                        help="(optional) pass launcher specific arguments as a "
+                        "single quoted argument.")
+
+    parser.add_argument("user_script",
+                        type=str,
+                        help="User script to launch, followed by any required "
+                        "arguments.")
+
+    parser.add_argument('user_args', nargs=argparse.REMAINDER)
+
+    return parser
+
+def fetch_hostfile(hostfile_path):
+    logger = get_dist_logger()
+    if not os.path.isfile(hostfile_path):
+        logger.warning("Unable to find hostfile, will proceed with training "
+                       "with local resources only")
+        return None
+
+    # e.g., worker-0:16
+    with open(hostfile_path, 'r') as fd:
+        device_pool = collections.OrderedDict()
+        for line in fd.readlines():
+            line = line.strip()
+            if line == '':
+                # skip empty lines
+                continue
+            try:
+                hostname, slot_count = line.split(":")
+                slot_count = int(slot_count)
+            except ValueError as err:
+                logger.error("Hostfile is not formatted correctly, unable to "
+                             "proceed with training.")
+                raise err
+            device_pool[hostname] = slot_count
+
+    return device_pool
+
+def _stable_remove_duplicates(data):
+    # Create a new list in the same order as original but with duplicates
+    # removed, should never be more than ~16 elements so simple is best
+    new_list = []
+    for x in data:
+        if x not in new_list:
+            new_list.append(x)
+    return new_list
+
+def parse_device_filter(host_info, include_str="", exclude_str=""):
+    '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
+
+    Examples:
+        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
+            slots [0, 2] on worker-1.
+        exclude_str="worker-1:0" will use all available devices except
+            slot 0 on worker-1.
+ ''' + + logger = get_dist_logger() + + # Constants that define our syntax + NODE_SEP = '@' + SLOT_LIST_START = ':' + SLOT_SEP = ',' + + # Ensure include/exclude are mutually exclusive + if (include_str != "") and (exclude_str != ""): + raise ValueError('include_str and exclude_str are mutually exclusive.') + + # no-op + if (include_str == "") and (exclude_str == ""): + return host_info + + # Either build from scratch or remove items + filtered_hosts = dict() + if include_str: + parse_str = include_str + if exclude_str != "": + filtered_hosts = deepcopy(host_info) + parse_str = exclude_str + + # foreach node in the list + for node_config in parse_str.split(NODE_SEP): + # Node can either be alone or node:slot,slot,slot + if SLOT_LIST_START in node_config: + hostname, slots = node_config.split(SLOT_LIST_START) + slots = [int(x) for x in slots.split(SLOT_SEP)] + + # sanity checks + if hostname not in host_info: + raise ValueError(f"Hostname '{hostname}' not found in hostfile") + for slot in slots: + if slot not in host_info[hostname]: + raise ValueError(f"No slot '{slot}' specified on host '{hostname}'") + + # If include string, build the list from here + if include_str: + filtered_hosts[hostname] = slots + elif exclude_str: + for slot in slots: + logger.info(f'removing {slot} from {hostname}') + filtered_hosts[hostname].remove(slot) + + # User just specified the whole node + else: + hostname = node_config + # sanity check hostname + if hostname not in host_info: + raise ValueError(f"Hostname '{hostname}' not found in hostfile") + + if include_str: + filtered_hosts[hostname] = host_info[hostname] + elif exclude_str: + filtered_hosts[hostname] = [] + + # Post-processing to remove duplicates and empty nodes + del_keys = [] + for hostname in filtered_hosts: + # Remove duplicates + filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname]) + # Remove empty hosts + if len(filtered_hosts[hostname]) == 0: + del_keys.append(hostname) + for name in del_keys: + del filtered_hosts[name] + + # Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure + # we map ranks to nodes correctly by maintaining host_info ordering. 
+ ordered_hosts = collections.OrderedDict() + for host in host_info: + if host in filtered_hosts: + ordered_hosts[host] = filtered_hosts[host] + + return ordered_hosts + +def parse_inclusion_exclusion(device_pool, inclusion, exclusion): + active_devices = collections.OrderedDict() + for hostname, slots in device_pool.items(): + active_devices[hostname] = list(range(slots)) + + return parse_device_filter(active_devices, + include_str=inclusion, + exclude_str=exclusion) + +def main(args=None): + logger = get_dist_logger() + parser = build_args_parser() + args = parser.parse_args(args) + + device_pool = fetch_hostfile(args.hostfile) + + active_devices = None + if device_pool: + active_devices = parse_inclusion_exclusion(device_pool, + args.include, + args.exclude) + if args.num_nodes > 0: + updated_active_devices = collections.OrderedDict() + for count, hostname in enumerate(active_devices.keys()): + if args.num_nodes == count: + break + updated_active_devices[hostname] = active_devices[hostname] + active_devices = updated_active_devices + + if args.num_gpus > 0: + updated_active_devices = collections.OrderedDict() + for hostname in active_devices.keys(): + updated_active_devices[hostname] = list(range(args.num_gpus)) + active_devices = updated_active_devices + + env = os.environ.copy() + + if not active_devices: + if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count(): + nproc_per_node = torch.cuda.device_count() + else: + nproc_per_node = args.num_gpus + if torch.__version__ <= "1.09": + cmd = [sys.executable, "-m", + "torch.distributed.launch", + f"--nproc_per_node={nproc_per_node}", + f"--master_addr={args.master_addr}", + f"--master_port={args.master_port}"] + [args.user_script] + args.user_args + else: + cmd = ["torchrun", + f"--nproc_per_node={nproc_per_node}", + f"--master_addr={args.master_addr}", + f"--master_port={args.master_port}"] + [args.user_script] + args.user_args + else: + if args.launcher == "torch": + runner = PDSHRunner(args) + elif args.launcher == "mpi": + runner = OpenMPIRunner(args, device_pool) + elif args.launcher == "slurm": + runner = SLURMRunner(args, device_pool) + else: + raise NotImplementedError(f"Unknown launcher {args.launcher}") + + if not runner.backend_exists(): + raise RuntimeError(f"launcher '{args.launcher}' not installed.") + + curr_path = os.path.abspath('.') + if 'PYTHONPATH' in env: + env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH'] + else: + env['PYTHONPATH'] = curr_path + + cmd = runner.get_cmd(env, active_devices, args) + + result = subprocess.Popen(cmd, env=env) + result.wait() + if result.returncode > 0: + sys.exit(result.returncode) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/setup.py b/setup.py index 1e4b0da1ccc5..a47aeb44d0a8 100644 --- a/setup.py +++ b/setup.py @@ -216,6 +216,9 @@ def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): extras_require={ 'zero': fetch_requirements('requirements/requirements-zero.txt'), }, + scripts=[ + 'bin/colossal', + ], python_requires='>=3.7', classifiers=[ 'Programming Language :: Python :: 3', From a25697ac3e32cf22058bf33e702e060359fd8656 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Tue, 19 Apr 2022 11:27:00 +0800 Subject: [PATCH 2/4] Revert "[CLI] add CLI launcher" This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c. 
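
For reference, the launcher removed here could be invoked either through the
installed bin/colossal script or programmatically, since main(args=None)
passes args straight to parser.parse_args(args). A minimal sketch of a
programmatic call (hosts.txt, train.py, and the flag values below are
hypothetical placeholders, not taken from this patch):

    from colossalai.launcher.run import main

    # hosts.txt holds one "hostname:slot_count" entry per line (see fetch_hostfile)
    main([
        "--hostfile", "hosts.txt",
        "--include", "worker-0@worker-1:0,2",   # filter syntax from parse_device_filter
        "--num_gpus", "4",
        "train.py",                             # positional user_script
        "--lr", "0.1",                          # remaining args are captured as user_args
    ])
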
--- bin/colossal | 6 - colossalai/launcher/__init__.py | 0 colossalai/launcher/multinode_runner.py | 82 ------- colossalai/launcher/run.py | 286 ------------------------ setup.py | 3 - 5 files changed, 377 deletions(-) delete mode 100644 bin/colossal delete mode 100644 colossalai/launcher/__init__.py delete mode 100644 colossalai/launcher/multinode_runner.py delete mode 100644 colossalai/launcher/run.py diff --git a/bin/colossal b/bin/colossal deleted file mode 100644 index 8e9fcb29b7d8..000000000000 --- a/bin/colossal +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python3 - -from colossalai.launcher.run import main - -if __name__ == '__main__': - main() diff --git a/colossalai/launcher/__init__.py b/colossalai/launcher/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/colossalai/launcher/multinode_runner.py b/colossalai/launcher/multinode_runner.py deleted file mode 100644 index 188d2c81fe14..000000000000 --- a/colossalai/launcher/multinode_runner.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -import sys -import shutil -import subprocess -import warnings -from shlex import quote -from abc import ABC, abstractmethod - -from colossalai.logging import get_dist_logger - -logger = get_dist_logger() - -class MultiNodeRunner(ABC): - def __init__(self, args): - self.args = args - self.user_arguments = self.args.user_args - self.user_script = args.user_script - self.exports = {} - - @abstractmethod - def backend_exists(self): - """Return whether the corresponding backend exists""" - - @abstractmethod - def get_cmd(self, environment, active_devices): - """Return the command to execute on node""" - - def add_export(self, key, var): - self.exports[key.strip()] = var.strip() - - @property - def name(self): - """Return the name of the backend""" - return self.__class__.__name__ - - -class PDSHRunner(MultiNodeRunner): - def __init__(self, args): - super().__init__(args) - - def backend_exists(self): - return shutil.which('pdsh') - - @property - def name(self): - return "pdsh" - - def parse_user_args(self): - return list( - map(lambda x: x if x.startswith("-") else f"'{x}'", - self.args.user_args)) - - def get_cmd(self, environment, active_devices, args): - environment['PDSH_RCMD_TYPE'] = 'ssh' - - active_workers = ",".join(active_devices.keys()) - logger.info("Running on the following workers: %s" % active_workers) - - pdsh_cmd_args = ['pdsh', '-f', str(1024), '-w', active_workers] - - exports = "" - for key, val in self.exports.items(): - exports += f"export {key}={quote(val)}; " - - # https://linux.die.net/man/1/pdsh - # %n will be replaced by pdsh command - colossal_launch = [ - exports, - f"cd {os.path.abspath('.')};", - sys.executable, "-m", - "torch.distributed.launch", - f"--nproc_per_node={args.num_gpus}", - f"--master_addr={args.master_addr}", - f"--master_port={args.master_port}" - ] - return pdsh_cmd_args + colossal_launch + [self.user_script] + self.user_arguments - -class OpenMPIRunner(MultiNodeRunner): - pass - -class SLURMRunner(MultiNodeRunner): - pass \ No newline at end of file diff --git a/colossalai/launcher/run.py b/colossalai/launcher/run.py deleted file mode 100644 index a9c2f00a7871..000000000000 --- a/colossalai/launcher/run.py +++ /dev/null @@ -1,286 +0,0 @@ -import argparse -from argparse import ArgumentParser, REMAINDER -import subprocess -import collections -import sys -import os -import torch -from colossalai.logging import get_dist_logger -from .multinode_runner import PDSHRunner, OpenMPIRunner, SLURMRunner - -def build_args_parser() -> 
ArgumentParser:
-    """Helper function parsing the command line options."""
-
-    parser = ArgumentParser(description="colossal distributed training launcher")
-
-    parser.add_argument("-H",
-                        "--hostfile",
-                        type=str,
-                        default="",
-                        help="Hostfile path that defines the "
-                        "device pool available to the job (e.g., "
-                        "worker-name:number of slots)")
-
-    parser.add_argument("-i",
-                        "--include",
-                        type=str,
-                        default="",
-                        help="Specify computing devices to use during execution. "
-                        "String format is NODE_SPEC@NODE_SPEC, "
-                        "where NODE_SPEC=<hostname>:<comma-separated slots>")
-
-    parser.add_argument("-e",
-                        "--exclude",
-                        type=str,
-                        default="",
-                        help="Specify computing devices to NOT use during execution. "
-                        "Mutually exclusive with --include. Formatting "
-                        "is the same as --include.")
-
-    parser.add_argument("--num_nodes",
-                        type=int,
-                        default=-1,
-                        help="Total number of worker nodes to use.")
-
-    parser.add_argument("--num_gpus",
-                        type=int,
-                        default=-1,
-                        help="Number of GPUs to use on each node.")
-
-    parser.add_argument("--master_port",
-                        default=29500,
-                        type=int,
-                        help="(optional) Port used by PyTorch distributed for "
-                        "communication during distributed training.")
-
-    parser.add_argument("--master_addr",
-                        default="127.0.0.1",
-                        type=str,
-                        help="(optional) IP address of node 0, will be "
-                        "inferred via 'hostname -I' if not specified.")
-
-    parser.add_argument("--launcher",
-                        default="torch",
-                        type=str,
-                        help="(optional) choose launcher backend for multi-node "
-                        "training. Options currently include PDSH, OpenMPI, SLURM.")
-
-    parser.add_argument("--launcher_args",
-                        default="",
-                        type=str,
-                        help="(optional) pass launcher specific arguments as a "
-                        "single quoted argument.")
-
-    parser.add_argument("user_script",
-                        type=str,
-                        help="User script to launch, followed by any required "
-                        "arguments.")
-
-    parser.add_argument('user_args', nargs=argparse.REMAINDER)
-
-    return parser
-
-def fetch_hostfile(hostfile_path):
-    logger = get_dist_logger()
-    if not os.path.isfile(hostfile_path):
-        logger.warning("Unable to find hostfile, will proceed with training "
-                       "with local resources only")
-        return None
-
-    # e.g., worker-0:16
-    with open(hostfile_path, 'r') as fd:
-        device_pool = collections.OrderedDict()
-        for line in fd.readlines():
-            line = line.strip()
-            if line == '':
-                # skip empty lines
-                continue
-            try:
-                hostname, slot_count = line.split(":")
-                slot_count = int(slot_count)
-            except ValueError as err:
-                logger.error("Hostfile is not formatted correctly, unable to "
-                             "proceed with training.")
-                raise err
-            device_pool[hostname] = slot_count
-
-    return device_pool
-
-def _stable_remove_duplicates(data):
-    # Create a new list in the same order as original but with duplicates
-    # removed, should never be more than ~16 elements so simple is best
-    new_list = []
-    for x in data:
-        if x not in new_list:
-            new_list.append(x)
-    return new_list
-
-def parse_device_filter(host_info, include_str="", exclude_str=""):
-    '''Parse an inclusion or exclusion string and filter a hostfile dictionary.
-
-    Examples:
-        include_str="worker-0@worker-1:0,2" will use all slots on worker-0 and
-            slots [0, 2] on worker-1.
-        exclude_str="worker-1:0" will use all available devices except
-            slot 0 on worker-1.
- ''' - - logger = get_dist_logger() - - # Constants that define our syntax - NODE_SEP = '@' - SLOT_LIST_START = ':' - SLOT_SEP = ',' - - # Ensure include/exclude are mutually exclusive - if (include_str != "") and (exclude_str != ""): - raise ValueError('include_str and exclude_str are mutually exclusive.') - - # no-op - if (include_str == "") and (exclude_str == ""): - return host_info - - # Either build from scratch or remove items - filtered_hosts = dict() - if include_str: - parse_str = include_str - if exclude_str != "": - filtered_hosts = deepcopy(host_info) - parse_str = exclude_str - - # foreach node in the list - for node_config in parse_str.split(NODE_SEP): - # Node can either be alone or node:slot,slot,slot - if SLOT_LIST_START in node_config: - hostname, slots = node_config.split(SLOT_LIST_START) - slots = [int(x) for x in slots.split(SLOT_SEP)] - - # sanity checks - if hostname not in host_info: - raise ValueError(f"Hostname '{hostname}' not found in hostfile") - for slot in slots: - if slot not in host_info[hostname]: - raise ValueError(f"No slot '{slot}' specified on host '{hostname}'") - - # If include string, build the list from here - if include_str: - filtered_hosts[hostname] = slots - elif exclude_str: - for slot in slots: - logger.info(f'removing {slot} from {hostname}') - filtered_hosts[hostname].remove(slot) - - # User just specified the whole node - else: - hostname = node_config - # sanity check hostname - if hostname not in host_info: - raise ValueError(f"Hostname '{hostname}' not found in hostfile") - - if include_str: - filtered_hosts[hostname] = host_info[hostname] - elif exclude_str: - filtered_hosts[hostname] = [] - - # Post-processing to remove duplicates and empty nodes - del_keys = [] - for hostname in filtered_hosts: - # Remove duplicates - filtered_hosts[hostname] = _stable_remove_duplicates(filtered_hosts[hostname]) - # Remove empty hosts - if len(filtered_hosts[hostname]) == 0: - del_keys.append(hostname) - for name in del_keys: - del filtered_hosts[name] - - # Lastly, go over filtered_hosts and convert to a OrderedDict() to ensure - # we map ranks to nodes correctly by maintaining host_info ordering. 
- ordered_hosts = collections.OrderedDict() - for host in host_info: - if host in filtered_hosts: - ordered_hosts[host] = filtered_hosts[host] - - return ordered_hosts - -def parse_inclusion_exclusion(device_pool, inclusion, exclusion): - active_devices = collections.OrderedDict() - for hostname, slots in device_pool.items(): - active_devices[hostname] = list(range(slots)) - - return parse_device_filter(active_devices, - include_str=inclusion, - exclude_str=exclusion) - -def main(args=None): - logger = get_dist_logger() - parser = build_args_parser() - args = parser.parse_args(args) - - device_pool = fetch_hostfile(args.hostfile) - - active_devices = None - if device_pool: - active_devices = parse_inclusion_exclusion(device_pool, - args.include, - args.exclude) - if args.num_nodes > 0: - updated_active_devices = collections.OrderedDict() - for count, hostname in enumerate(active_devices.keys()): - if args.num_nodes == count: - break - updated_active_devices[hostname] = active_devices[hostname] - active_devices = updated_active_devices - - if args.num_gpus > 0: - updated_active_devices = collections.OrderedDict() - for hostname in active_devices.keys(): - updated_active_devices[hostname] = list(range(args.num_gpus)) - active_devices = updated_active_devices - - env = os.environ.copy() - - if not active_devices: - if args.num_gpus == -1 or args.num_gpus > torch.cuda.device_count(): - nproc_per_node = torch.cuda.device_count() - else: - nproc_per_node = args.num_gpus - if torch.__version__ <= "1.09": - cmd = [sys.executable, "-m", - "torch.distributed.launch", - f"--nproc_per_node={nproc_per_node}", - f"--master_addr={args.master_addr}", - f"--master_port={args.master_port}"] + [args.user_script] + args.user_args - else: - cmd = ["torchrun", - f"--nproc_per_node={nproc_per_node}", - f"--master_addr={args.master_addr}", - f"--master_port={args.master_port}"] + [args.user_script] + args.user_args - else: - if args.launcher == "torch": - runner = PDSHRunner(args) - elif args.launcher == "mpi": - runner = OpenMPIRunner(args, device_pool) - elif args.launcher == "slurm": - runner = SLURMRunner(args, device_pool) - else: - raise NotImplementedError(f"Unknown launcher {args.launcher}") - - if not runner.backend_exists(): - raise RuntimeError(f"launcher '{args.launcher}' not installed.") - - curr_path = os.path.abspath('.') - if 'PYTHONPATH' in env: - env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH'] - else: - env['PYTHONPATH'] = curr_path - - cmd = runner.get_cmd(env, active_devices, args) - - result = subprocess.Popen(cmd, env=env) - result.wait() - if result.returncode > 0: - sys.exit(result.returncode) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/setup.py b/setup.py index f2232175b183..ec281e2e9dfb 100644 --- a/setup.py +++ b/setup.py @@ -213,9 +213,6 @@ def cuda_ext_helper(name, sources, extra_cuda_flags, extra_cxx_flags=[]): ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension} if ext_modules else {}, install_requires=fetch_requirements('requirements/requirements.txt'), - scripts=[ - 'bin/colossal', - ], python_requires='>=3.7', classifiers=[ 'Programming Language :: Python :: 3', From 64a0bc636664ed4b7ba9e9e8d2de95895ee612c1 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Tue, 19 Jul 2022 17:47:56 +0800 Subject: [PATCH 3/4] [fx] refactor tracer to trace complete graph --- colossalai/fx/proxy.py | 53 +++++++++++++--- colossalai/fx/tracer/_tracer_utils.py | 17 ++++++ .../meta_patch/patched_function/torch_ops.py | 5 ++ 
colossalai/fx/tracer/tracer.py | 60 ++++++++++++++++++- tests/test_fx/test_coloproxy.py | 35 +++++++++-- .../test_pipeline/test_hf_model/test_opt.py | 1 - .../test_timm_model/test_timm.py | 1 - .../test_tracer/test_hf_model/test_hf_opt.py | 1 - .../test_timm_model/test_timm_model.py | 1 - tests/test_fx/test_transform_mlp_pass.py | 9 ++- 10 files changed, 161 insertions(+), 22 deletions(-) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index e96971b36c17..0f93a890f8f0 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -2,6 +2,7 @@ import torch from torch.fx.proxy import Proxy, Attribute from typing import List, Union, Any +from colossalai.fx.tracer.meta_patch import meta_patched_function __all__ = ['ColoProxy'] @@ -45,6 +46,14 @@ def __len__(self): self._assert_has_meta_data() return len(self.meta_data) + def __int__(self): + self._assert_has_meta_data() + return int(self.meta_data) + + def __float__(self): + self._assert_has_meta_data() + return float(self.meta_data) + def __bool__(self): self._assert_has_meta_data() return self.meta_data @@ -53,9 +62,6 @@ def __getattr__(self, k): return ColoAttribute(self, k) - def __setitem__(self, indices, values): - return self.tracer.create_proxy("call_function", operator.setitem, (self, indices, values), {}) - def __contains__(self, key): if self.node.op == "placeholder": # this is used to handle like @@ -65,11 +71,26 @@ def __contains__(self, key): return super().__contains__(key) +def extract_meta(*args, **kwargs): + """ + This function is copied from _tracer_utils.py to avoid circular import issue. + """ + + def _convert(val): + if isinstance(val, ColoProxy): + return val.meta_data + elif isinstance(val, (list, tuple)): + return type(val)([_convert(ele) for ele in val]) + return val + + new_args = [_convert(val) for val in args] + new_kwargs = {k: _convert(v) for k, v in kwargs.items()} + return new_args, new_kwargs + + class ColoAttribute(ColoProxy): def __init__(self, root, attr: str): - # this class is copied from torch.fx.Attribute - # but inherits ColoProxy self.root = root self.attr = attr self.tracer = root.tracer @@ -78,8 +99,26 @@ def __init__(self, root, attr: str): @property def node(self): if self._node is None: - self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node + proxy = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}) + if not isinstance(proxy, ColoProxy): + meta_args, meta_kwargs = extract_meta(*(self.root, self.attr)) + meta_out = getattr(*meta_args, **meta_kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + self._node = proxy.node + return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) + proxy = self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs) + if not isinstance(proxy, ColoProxy): + meta_args, meta_kwargs = extract_meta(*((self.root,) + args), **kwargs) + method = getattr(meta_args[0].__class__, self.attr) + if meta_patched_function.has(method): + meta_target = meta_patched_function.get(method) + else: + meta_target = method + meta_out = meta_target(*meta_args, **meta_kwargs) + proxy = ColoProxy(proxy.node) + proxy.meta_data = meta_out + return proxy diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py index 300a822767dd..a8a85c3b85be 100644 --- a/colossalai/fx/tracer/_tracer_utils.py +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -1,5 +1,7 @@ 
from typing import List, Union, Any from ..proxy import ColoProxy, ColoAttribute +import torch +from .meta_patch import meta_patched_function, meta_patched_module __all__ = ['is_element_in_list', 'extract_meta'] @@ -29,3 +31,18 @@ def _convert(val): new_args = [_convert(val) for val in args] new_kwargs = {k: _convert(v) for k, v in kwargs.items()} return new_args, new_kwargs + + +def compute_meta_data_for_functions_proxy(target, args, kwargs): + args_metas, kwargs_metas = extract_meta(*args, **kwargs) + + # fetch patched function + if meta_patched_function.has(target): + meta_target = meta_patched_function.get(target) + else: + meta_target = target + meta_out = meta_target(*args_metas, **kwargs_metas) + if isinstance(meta_out, torch.Tensor): + meta_out = meta_out.to(device="meta") + + return meta_out diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py index e3342a6465e9..457740ecd1fa 100644 --- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py +++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py @@ -24,6 +24,11 @@ def torch_arange(*args, **kwargs): return torch.empty((end - start) // step, dtype=dtype, device="meta") +@meta_patched_function.register(torch.finfo) +def torch_finfo(*args, **kwargs): + return torch.finfo(*args) + + @meta_patched_function.register(torch.where) def torch_where(condition, x, y): # torch.where returns the broadcasted tensor of condition, x, and y, diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py index 5b7a1eced1bb..82fdecff1622 100644 --- a/colossalai/fx/tracer/tracer.py +++ b/colossalai/fx/tracer/tracer.py @@ -7,6 +7,7 @@ import enum import inspect import functools +import operator from colossalai.fx.tracer.meta_patch import meta_patched_module import torch import torch.nn as nn @@ -16,8 +17,9 @@ from torch.fx.proxy import Proxy, ParameterProxy from ..proxy import ColoProxy from typing import Optional, Dict, Any -from ._tracer_utils import is_element_in_list, extract_meta +from ._tracer_utils import is_element_in_list, extract_meta, compute_meta_data_for_functions_proxy from .meta_patch import meta_patched_function, meta_patched_module +from torch.fx.graph import magic_methods, reflectable_magic_methods __all__ = ['ColoTracer'] @@ -61,7 +63,7 @@ def __init__(self, *args, **kwargs): # Feature flag for proxying accesses to buffer values proxy_buffer_attributes: bool = True - _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor"] + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full", "full_like", "eye", "empty", "tensor", "finfo"] def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None) -> ColoProxy: """ @@ -344,11 +346,15 @@ def look_for_proxy(*args, **kwargs): for arg in args: if isinstance(arg, Proxy): return arg + if isinstance(arg, (tuple, list)): + return look_for_proxy(*arg) # find in keyword vars for k, v in kwargs.items(): if isinstance(v, Proxy): return v + if isinstance(v, (tuple, list)): + return look_for_proxy(*v) return None @functools.wraps(target) @@ -358,10 +364,58 @@ def wrapper(*args, **kwargs): if proxy is not None: # if the arg is a proxy, then need to record this function called on this proxy # e.g. 
torch.ones(size) where size is an input proxy
-            return proxy.tracer.create_proxy("call_function", target, args, kwargs)
+            colo_proxy = proxy.tracer.create_proxy("call_function", target, args, kwargs)
+            if not isinstance(colo_proxy, ColoProxy):
+                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
+                colo_proxy = ColoProxy(colo_proxy.node)
+                colo_proxy.meta_data = meta_out
+            return colo_proxy
         else:
             # this is called directly when the inputs do not contain proxy
             # e.g. torch.ones(4) where the input is static
             return target(*args, **kwargs)
 
     return wrapper, target
+
+
+for method in magic_methods:
+
+    def _scope(method):
+
+        def impl(*args, **kwargs):
+
+            tracer = args[0].tracer
+            target = getattr(operator, method)
+            proxy = tracer.create_proxy('call_function', target, args, kwargs)
+            if not isinstance(proxy, ColoProxy):
+                meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
+                proxy = ColoProxy(proxy.node)
+                proxy.meta_data = meta_out
+            return proxy
+
+        impl.__name__ = method
+        as_magic = f'__{method.strip("_")}__'
+        setattr(ColoProxy, as_magic, impl)
+
+    _scope(method)
+
+
+def _define_reflectable(orig_method_name):
+    method_name = f'__r{orig_method_name.strip("_")}__'
+
+    def impl(self, rhs):
+        target = getattr(operator, orig_method_name)
+        proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {})
+        if not isinstance(proxy, ColoProxy):
+            meta_out = compute_meta_data_for_functions_proxy(target, (rhs, self), {})
+            proxy = ColoProxy(proxy.node)
+            proxy.meta_data = meta_out
+        return proxy
+
+    impl.__name__ = method_name
+    impl.__qualname__ = method_name
+    setattr(ColoProxy, method_name, impl)
+
+
+for orig_method_name in reflectable_magic_methods:
+    _define_reflectable(orig_method_name)
diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py
index 82be9329d4c8..2bb6cf86466c 100644
--- a/tests/test_fx/test_coloproxy.py
+++ b/tests/test_fx/test_coloproxy.py
@@ -1,17 +1,40 @@
 import torch
+import torch.nn as nn
 from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from torch.fx import GraphModule
 import pytest
 
 
-@pytest.mark.skip('skip due to tracer')
+class Conv1D(nn.Module):
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        w = torch.empty(nx, nf)
+        nn.init.normal_(w, std=0.02)
+        self.weight = nn.Parameter(w)
+        self.bias = nn.Parameter(torch.zeros(nf))
+
+    def forward(self, x):
+        size_out = x.shape[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
+
+
 def test_coloproxy():
-    # create a dummy node only for testing purpose
-    model = torch.nn.Linear(10, 10)
-    gm = torch.fx.symbolic_trace(model)
+
+    tracer = ColoTracer()
+    model = Conv1D(3, 3)
+    input_sample = {'x': torch.rand(3, 3).to('meta')}
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
 
     node = list(gm.graph.nodes)[0]
-    # create proxy
-    proxy = ColoProxy(node=node)
+    proxy = ColoProxy(node=node, tracer=tracer)
     proxy.meta_data = torch.empty(4, 2, device='meta')
 
     assert len(proxy) == 4
diff --git a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
index bd1b2aa2cfc1..a55ea54feb59 100644
--- a/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
+++ b/tests/test_fx/test_pipeline/test_hf_model/test_opt.py
@@ -7,7 +7,6 @@
 SEQ_LENGHT = 16
 
 
-@pytest.mark.skip('skip due to tracer')
 def test_opt():
     MODEL_LIST = [
transformers.OPTModel, diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py index 81ff4536d973..7c3764f341b7 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py +++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py @@ -24,7 +24,6 @@ def test_timm_models_without_control_flow(): split_model_and_compare_output(model, data) -@pytest.mark.skip('skip due to tracer') def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py index 3206dc75ba33..5ac051887eb6 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py @@ -7,7 +7,6 @@ SEQ_LENGHT = 16 -@pytest.mark.skip('skip due to tracer') def test_opt(): MODEL_LIST = [ transformers.OPTModel, diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 38f5a3829d5a..2ee498b9edfd 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -54,7 +54,6 @@ def test_timm_models_without_control_flow(): trace_and_compare(model_cls, tracer, data) -@pytest.mark.skip('skip due to tracer') def test_timm_models_with_control_flow(): torch.backends.cudnn.deterministic = True diff --git a/tests/test_fx/test_transform_mlp_pass.py b/tests/test_fx/test_transform_mlp_pass.py index 202c8ce0e138..f0c9adbf5936 100644 --- a/tests/test_fx/test_transform_mlp_pass.py +++ b/tests/test_fx/test_transform_mlp_pass.py @@ -4,8 +4,10 @@ import colossalai from colossalai.fx import ColoTracer from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass + CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d'))) + class MLP(torch.nn.Module): def __init__(self, dim: int): @@ -24,6 +26,7 @@ def forward(self, x): x = torch.nn.functional.relu(self.linear4(x)) return x + def test_out_acc(): model = MLP(16).cuda() model.eval() @@ -36,6 +39,7 @@ def test_out_acc(): new_output = splitted_gm(input_tensor) assert output.equal(new_output) + def test_linear_acc(): input_tensor = torch.rand(2, 16).cuda() model = MLP(16).cuda() @@ -45,15 +49,16 @@ def test_linear_acc(): splitted_gm = transform_mlp_pass(gm) col_shard = True for node in splitted_gm.graph.nodes: - if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), torch.nn.Linear): + if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), + torch.nn.Linear): target_module = node.graph.owning_module.get_submodule(node.target) dim = 0 if col_shard else -1 assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs") col_shard = not col_shard + if __name__ == "__main__": torch.manual_seed(1) torch.cuda.manual_seed(1) - # colossalai.launch_from_torch(config=CONFIG) test_out_acc() test_linear_acc() From 6042054f784aee6e1f4ea7c077e1ae9e02ab2b60 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 Date: Wed, 20 Jul 2022 11:10:16 +0800 Subject: [PATCH 4/4] add comments and solve conflicts. 
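
With these changes the tracer is exercised as in the updated
tests/test_fx/test_coloproxy.py above: trace against meta tensors so every
recorded proxy carries meta_data, including proxies produced by magic
methods. A condensed sketch of that workflow (the toy module and shapes are
illustrative only, not part of this patch):

    import torch
    from torch.fx import GraphModule
    from colossalai.fx.tracer.tracer import ColoTracer

    class Scale(torch.nn.Module):    # hypothetical toy module
        def forward(self, x):
            return x * 2 - 1         # __mul__/__sub__ hit the patched magic methods

    model = Scale()
    tracer = ColoTracer()
    # meta tensors let shape propagation run without allocating real memory
    graph = tracer.trace(root=model, meta_args={'x': torch.rand(3, 3).to('meta')})
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
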
---
 colossalai/fx/proxy.py                       |  2 +
 colossalai/fx/tracer/_tracer_utils.py        |  2 +
 .../meta_patch/patched_function/torch_ops.py |  2 +-
 colossalai/fx/tracer/tracer.py               |  2 +
 tests/test_fx/test_transform_mlp_pass.py     | 64 -------------------
 5 files changed, 7 insertions(+), 65 deletions(-)
 delete mode 100644 tests/test_fx/test_transform_mlp_pass.py

diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py
index 0f93a890f8f0..464c3598f1f0 100644
--- a/colossalai/fx/proxy.py
+++ b/colossalai/fx/proxy.py
@@ -116,6 +116,8 @@ def __call__(self, *args, **kwargs):
         method = getattr(meta_args[0].__class__, self.attr)
         if meta_patched_function.has(method):
             meta_target = meta_patched_function.get(method)
+        elif meta_patched_function.has(method.__name__):
+            meta_target = meta_patched_function.get(method.__name__)
         else:
             meta_target = method
         meta_out = meta_target(*meta_args, **meta_kwargs)
diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py
index a8a85c3b85be..0ec49a90a133 100644
--- a/colossalai/fx/tracer/_tracer_utils.py
+++ b/colossalai/fx/tracer/_tracer_utils.py
@@ -39,6 +39,8 @@ def compute_meta_data_for_functions_proxy(target, args, kwargs):
     # fetch patched function
     if meta_patched_function.has(target):
         meta_target = meta_patched_function.get(target)
+    elif meta_patched_function.has(target.__name__):
+        meta_target = meta_patched_function.get(target.__name__)
     else:
         meta_target = target
     meta_out = meta_target(*args_metas, **kwargs_metas)
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
index 457740ecd1fa..4c5c7c2e3b76 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/torch_ops.py
@@ -25,7 +25,7 @@
 
 
 @meta_patched_function.register(torch.finfo)
-def torch_finfo(*args, **kwargs):
+def torch_finfo(*args):
     return torch.finfo(*args)
 
 
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 82fdecff1622..1415e2f9d838 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -378,6 +378,8 @@ def wrapper(*args, **kwargs):
     return wrapper, target
 
 
+# Patch magic methods on ColoProxy so the tracer records calls such as __sub__
+# and attaches the meta_data attribute to the created proxy.
for method in magic_methods: def _scope(method): diff --git a/tests/test_fx/test_transform_mlp_pass.py b/tests/test_fx/test_transform_mlp_pass.py deleted file mode 100644 index f0c9adbf5936..000000000000 --- a/tests/test_fx/test_transform_mlp_pass.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn -import pytest -import colossalai -from colossalai.fx import ColoTracer -from colossalai.fx.passes.shard_1d_pass import transform_mlp_pass - -CONFIG = dict(parallel=dict(tensor=dict(size=2, mode='1d'))) - - -class MLP(torch.nn.Module): - - def __init__(self, dim: int): - super().__init__() - self.linear1 = torch.nn.Linear(dim, dim) - self.linear2 = torch.nn.Linear(dim, dim) - self.linear3 = torch.nn.Linear(dim, dim) - self.linear4 = torch.nn.Linear(dim, dim) - self.dropout = torch.nn.Dropout() - self.relu = torch.nn.ReLU() - - def forward(self, x): - x = self.relu(self.linear1(x)) - x = self.dropout(self.relu(self.linear2(x))) - x = self.linear3(x) - x = torch.nn.functional.relu(self.linear4(x)) - return x - - -def test_out_acc(): - model = MLP(16).cuda() - model.eval() - input_tensor = torch.rand(2, 16).cuda() - output = model(input_tensor) - tracer = ColoTracer() - graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")}) - gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - splitted_gm = transform_mlp_pass(gm) - new_output = splitted_gm(input_tensor) - assert output.equal(new_output) - - -def test_linear_acc(): - input_tensor = torch.rand(2, 16).cuda() - model = MLP(16).cuda() - tracer = ColoTracer() - graph = tracer.trace(model, meta_args={'x': torch.randn((2, 16), device="meta")}) - gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) - splitted_gm = transform_mlp_pass(gm) - col_shard = True - for node in splitted_gm.graph.nodes: - if node.op == "call_module" and isinstance(node.graph.owning_module.get_submodule(node.target), - torch.nn.Linear): - target_module = node.graph.owning_module.get_submodule(node.target) - dim = 0 if col_shard else -1 - assert target_module.weight.fx_attr == (dim, "SHARD", "TP", "col_needs_many_outputs") - col_shard = not col_shard - - -if __name__ == "__main__": - torch.manual_seed(1) - torch.cuda.manual_seed(1) - test_out_acc() - test_linear_acc()
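
A closing note on the reflected magic methods added in tracer.py: for an
expression like 1 - x, Python falls back to x.__rsub__(1), which
_define_reflectable binds to operator.sub with the operands swapped. A
minimal sketch of the resulting behavior (toy module for illustration;
assumes the tracer patches from patches 3 and 4 above):

    import torch
    from colossalai.fx.tracer.tracer import ColoTracer

    class RSub(torch.nn.Module):
        def forward(self, x):
            return 1 - x    # dispatches to ColoProxy.__rsub__ while tracing

    tracer = ColoTracer()
    graph = tracer.trace(RSub(), meta_args={'x': torch.rand(4).to('meta')})
    # the recorded node is call_function(operator.sub, (1, x_proxy), {}) and
    # its proxy carries meta_data computed on the meta tensor, so downstream
    # passes still see concrete shapes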