diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index d2d02811ac9d..e594d1d1d055 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -156,7 +156,8 @@ def _arg_dict_to_list(arg_dict): torch_version = version.parse(torch.__version__) assert torch_version.major >= 1 - if torch_version.minor < 9: + if torch_version.major == 1 and torch_version.minor < 9: + # torch distributed launch cmd with torch < 1.9 cmd = [ sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", @@ -172,12 +173,14 @@ def _arg_dict_to_list(arg_dict): value = extra_launch_args.pop(key) default_torchrun_rdzv_args[key] = value - if torch_version.minor < 10: + if torch_version.major == 1 and torch_version.minor == 9: + # torch distributed launch cmd with torch == 1.9 cmd = [ sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" ] else: + # torch distributed launch cmd with torch > 1.9 cmd = [ "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" ]