Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions colossalai/cli/launcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@
"This will be converted to --arg1=1 --arg2=2 during execution",
)
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str)
@click.option("-m", type=str, default=None, help="run library module as a script (terminates option list)")
@click.argument("user_script", type=str, required=False, default=None)
@click.argument("user_args", nargs=-1)
def run(
host: str,
Expand All @@ -77,8 +78,9 @@ def run(
master_port: int,
extra_launch_args: str,
ssh_port: int,
m: str,
user_script: str,
user_args: str,
user_args: tuple,
) -> None:
"""
To launch multiple processes on a single node or multiple nodes via command line.
Expand All @@ -102,9 +104,24 @@ def run(
# run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
"""
if not user_script.endswith(".py"):
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help")
exit()
if m is not None:
if m.endswith(".py"):
click.echo(f"Error: invalid Python module {m}. Did you use a wrong option? Try colossalai run --help")
exit()
if user_script is not None:
user_args = (user_script,) + user_args
user_script = m
m = True
else:
if user_script is None:
click.echo("Error: missing script argument. Did you use a wrong option? Try colossalai run --help")
exit()
if not user_script.endswith(".py"):
click.echo(
f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help"
)
exit()
m = False

args_dict = locals()
args = Config(args_dict)
Expand Down
9 changes: 8 additions & 1 deletion colossalai/cli/launcher/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def get_launch_command(
user_args: List[str],
node_rank: int,
num_nodes: int,
run_as_module: bool,
extra_launch_args: str = None,
) -> str:
"""
Expand Down Expand Up @@ -155,6 +156,8 @@ def _arg_dict_to_list(arg_dict):

torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1
if torch_version.major < 2 and run_as_module:
raise ValueError("Torch version < 2.0 does not support running as module")

if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9
Expand Down Expand Up @@ -198,7 +201,10 @@ def _arg_dict_to_list(arg_dict):
]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args)

cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args
cmd += _arg_dict_to_list(extra_launch_args)
if run_as_module:
cmd.append("-m")
cmd += [user_script] + user_args
cmd = " ".join(cmd)
return cmd

Expand Down Expand Up @@ -294,6 +300,7 @@ def launch_multi_processes(args: Config) -> None:
user_args=args.user_args,
node_rank=node_id,
num_nodes=len(active_device_pool),
run_as_module=args.m,
extra_launch_args=args.extra_launch_args,
)
runner.send(hostinfo=hostinfo, cmd=cmd)
Expand Down