From 60539b143417fb261004b4a29f205d321a8e12eb Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Thu, 14 Sep 2023 15:29:22 +0800 Subject: [PATCH 1/2] Fix the version check bug in colossalai run when generating the cmd. --- colossalai/cli/launcher/run.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index d2d02811ac9d..39b6b32c2f28 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -156,7 +156,8 @@ def _arg_dict_to_list(arg_dict): torch_version = version.parse(torch.__version__) assert torch_version.major >= 1 - if torch_version.minor < 9: + if torch_version.major == 1 and torch_version.minor < 9: + # torch distributed launch cmd with torch < 1.9 cmd = [ sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={nproc_per_node}", f"--master_addr={master_addr}", f"--master_port={master_port}", f"--nnodes={num_nodes}", @@ -172,12 +173,14 @@ def _arg_dict_to_list(arg_dict): value = extra_launch_args.pop(key) default_torchrun_rdzv_args[key] = value - if torch_version.minor < 10: + if torch_version.major == 1 and torch_version.minor < 10: + # torch distributed launch cmd with torch == 1.9 cmd = [ sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" ] else: + # torch distributed launch cmd with torch > 1.9 cmd = [ "torchrun", f"--nproc_per_node={nproc_per_node}", f"--nnodes={num_nodes}", f"--node_rank={node_rank}" ] From 5b098f9d7878b0f13698277c07ec0486f327615e Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Mon, 18 Sep 2023 12:54:59 +0800 Subject: [PATCH 2/2] polish code --- colossalai/cli/launcher/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 39b6b32c2f28..e594d1d1d055 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -173,7 +173,7 @@ def _arg_dict_to_list(arg_dict): value = extra_launch_args.pop(key) default_torchrun_rdzv_args[key] = value - if torch_version.major == 1 and torch_version.minor < 10: + if torch_version.major == 1 and torch_version.minor == 9: # torch distributed launch cmd with torch == 1.9 cmd = [ sys.executable, "-m", "torch.distributed.run", f"--nproc_per_node={nproc_per_node}",