From 46503c35dd9342f943308ee451b62751f36bc961 Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Thu, 1 Jun 2023 14:30:51 +0800 Subject: [PATCH 1/6] Modify torch version requirement to adapt torch 2.0 --- colossalai/cli/launcher/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 6411b4302e95..4bb749f9d293 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -154,7 +154,7 @@ def _arg_dict_to_list(arg_dict): extra_launch_args = dict() torch_version = version.parse(torch.__version__) - assert torch_version.major == 1 + assert torch_version.major >= 1 if torch_version.minor < 9: cmd = [ From fcda67cdd255518dc317ce98bc79a558f061fa7f Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Tue, 6 Jun 2023 16:57:03 +0800 Subject: [PATCH 2/6] modify palm example using new booster API --- examples/language/palm/README.md | 3 ++ examples/language/palm/run.sh | 8 ++-- examples/language/palm/test_ci.sh | 2 +- examples/language/palm/train.py | 70 +++++++++++++++++++++---------- 4 files changed, 56 insertions(+), 27 deletions(-) diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md index 486bf240f89c..3ff3939d63d4 100644 --- a/examples/language/palm/README.md +++ b/examples/language/palm/README.md @@ -43,6 +43,9 @@ palm = PaLM( ) ``` +## New API +We have modified our previous implementation of PaLM with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find the new API in train.py. We also offer a shell script test_ci.sh for you to go through all our plugins for the booster. For more information about the booster API you can refer to https://colossalai.org/docs/basics/booster_api/. 
+ ## Test on Enwik8 ```bash diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh index 7a533509e009..2a846e81a9a7 100644 --- a/examples/language/palm/run.sh +++ b/examples/language/palm/run.sh @@ -3,9 +3,11 @@ export DISTPAN="colossalai" # The following options only valid when DISTPAN="colossalai" export TPDEGREE=1 -export GPUNUM=1 +export GPUNUM=4 export PLACEMENT='cpu' export USE_SHARD_INIT=False -export BATCH_SIZE=4 +export BATCH_SIZE=1 -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \ +--dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \ +--placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh index f21095578077..001004543958 100644 --- a/examples/language/palm/test_ci.sh +++ b/examples/language/palm/test_ci.sh @@ -4,6 +4,6 @@ for BATCH_SIZE in 2 do for GPUNUM in 1 4 do -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 25641 train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log done done diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index b16da1c7744a..22ecbdf87f44 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -9,6 +9,8 @@ import torch.optim as optim import tqdm from packaging import version + +from colossalai.nn import HybridAdam from palm_pytorch import PaLM from 
palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from torch.utils.data import DataLoader, Dataset @@ -18,6 +20,8 @@ from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import MultiTimer, get_current_device from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin # constants @@ -58,6 +62,12 @@ def parse_args(): help= "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", ) + parser.add_argument('-p', + '--plugin', + type=str, + default='torch_ddp', + choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'], + help="plugin to use") parser.add_argument( "--batch_size", type=int, @@ -102,27 +112,27 @@ def get_model_size(model: nn.Module): # Gemini + ZeRO DDP -def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): - cai_version = colossalai.__version__ - if version.parse(cai_version) > version.parse("0.1.10"): - from colossalai.nn.parallel import GeminiDDP - model = GeminiDDP(model, - device=get_current_device(), - placement_policy=placement_policy, - pin_memory=True, - search_range_mb=32) - elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): - from colossalai.gemini import ChunkManager, GeminiManager - chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - chunk_manager = ChunkManager(chunk_size, - pg, - enable_distributed_storage=True, - init_device=GeminiManager.get_default_device(placement_policy)) - model = ZeroDDP(model, gemini_manager) - else: - raise NotImplemented(f"CAI version {cai_version} is not supported") - return model +# def 
gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): +# cai_version = colossalai.__version__ +# if version.parse(cai_version) > version.parse("0.1.10"): +# from colossalai.zero.gemini.gemini_ddp import GeminiDDP +# model = GeminiDDP(model, +# device=get_current_device(), +# placement_policy=placement_policy, +# pin_memory=True, +# search_range_mb=32) +# elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): +# from colossalai.gemini import ChunkManager, GeminiManager +# chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) +# gemini_manager = GeminiManager(placement_policy, chunk_manager) +# chunk_manager = ChunkManager(chunk_size, +# pg, +# enable_distributed_storage=True, +# init_device=GeminiManager.get_default_device(placement_policy)) +# model = ZeroDDP(model, gemini_manager) +# else: +# raise NotImplemented(f"CAI version {cai_version} is not supported") +# return model # Parameter Sharding Strategies for Tensor Parallelism @@ -218,6 +228,18 @@ def __len__(self): if args.distplan == "colossalai": # instantiate GPT-like decoder model + booster_kwargs = {} + if args.plugin == 'torch_ddp_fp16': + booster_kwargs['mixed_precision'] = 'fp16' + if args.plugin.startswith('torch_ddp'): + plugin = TorchDDPPlugin() + elif args.plugin == 'gemini': + plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) + elif args.plugin == 'low_level_zero': + plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) + logger.info(f"plugin: {plugin}") + booster = Booster(plugin=plugin, **booster_kwargs) + default_pg = ProcessGroup(tp_degree=args.tp_degree) default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) @@ -228,12 +250,14 @@ def __len__(self): pg = default_pg tensor_parallelize(model, pg) - model = 
gemini_zero_dpp(model, pg, args.placement) + # model = gemini_zero_dpp(model, pg, args.placement) # optimizer #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) - optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) + optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5) + model, optimizer, _, _, _ = booster.boost(model, optimizer) + else: model = PaLM(num_tokens=256, dim=512, depth=8) model = AutoregressiveWrapper(model, max_seq_len=2048) From a461ba044b9a1a61fa5db93dc12bad8453bc323d Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Tue, 6 Jun 2023 16:58:00 +0800 Subject: [PATCH 3/6] roll back --- colossalai/cli/launcher/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py index 4bb749f9d293..6411b4302e95 100644 --- a/colossalai/cli/launcher/run.py +++ b/colossalai/cli/launcher/run.py @@ -154,7 +154,7 @@ def _arg_dict_to_list(arg_dict): extra_launch_args = dict() torch_version = version.parse(torch.__version__) - assert torch_version.major >= 1 + assert torch_version.major == 1 if torch_version.minor < 9: cmd = [ From f57512e103aa988b62f86af446ca484e7725a8ea Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Tue, 6 Jun 2023 16:59:02 +0800 Subject: [PATCH 4/6] fix port --- examples/language/palm/test_ci.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh index 001004543958..4de6a44e5bf7 100644 --- a/examples/language/palm/test_ci.sh +++ b/examples/language/palm/test_ci.sh @@ -4,6 +4,6 @@ for BATCH_SIZE in 2 do for GPUNUM in 1 4 do -env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 25641 train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log +env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --standalone train.py --dummy_data=True 
--batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log done done From fbb89f5a21217f645273ac3826eb51465729d4d8 Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Tue, 6 Jun 2023 16:59:47 +0800 Subject: [PATCH 5/6] polish --- examples/language/palm/train.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 22ecbdf87f44..0d3adce2a8e1 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -111,28 +111,6 @@ def get_model_size(model: nn.Module): return total_numel -# Gemini + ZeRO DDP -# def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"): -# cai_version = colossalai.__version__ -# if version.parse(cai_version) > version.parse("0.1.10"): -# from colossalai.zero.gemini.gemini_ddp import GeminiDDP -# model = GeminiDDP(model, -# device=get_current_device(), -# placement_policy=placement_policy, -# pin_memory=True, -# search_range_mb=32) -# elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"): -# from colossalai.gemini import ChunkManager, GeminiManager -# chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32) -# gemini_manager = GeminiManager(placement_policy, chunk_manager) -# chunk_manager = ChunkManager(chunk_size, -# pg, -# enable_distributed_storage=True, -# init_device=GeminiManager.get_default_device(placement_policy)) -# model = ZeroDDP(model, gemini_manager) -# else: -# raise NotImplemented(f"CAI version {cai_version} is not supported") -# return model # Parameter Sharding Strategies for Tensor Parallelism From f8bf1339bc648aeb8dfd3ac1b4831f60e3691320 Mon Sep 17 00:00:00 2001 From: Maruyama_Aya Date: Tue, 6 Jun 2023 17:01:52 +0800 Subject: [PATCH 6/6] polish --- examples/language/palm/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 
0d3adce2a8e1..62062e8bd272 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -228,11 +228,9 @@ def __len__(self): pg = default_pg tensor_parallelize(model, pg) - # model = gemini_zero_dpp(model, pg, args.placement) # optimizer - #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5) model, optimizer, _, _, _ = booster.boost(model, optimizer)