diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 6411b4302e95..4bb749f9d293 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -154,7 +154,7 @@ def _arg_dict_to_list(arg_dict): extra_launch_args = dict() torch_version = version.parse(torch.__version__) - assert torch_version.major == 1 + assert torch_version.major >= 1 if torch_version.minor < 9: cmd = [