From ee5a21a726dc8f18b50af2d3ee0321b0a9b34778 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 18 Aug 2023 15:48:35 +0800 Subject: [PATCH 1/6] [example] update gpt example --- examples/language/gpt/gemini/run_gemini.sh | 6 - examples/language/gpt/gemini/test_ci.sh | 22 +--- .../language/gpt/gemini/train_gpt_demo.py | 106 +----------------- 3 files changed, 10 insertions(+), 124 deletions(-) diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh index ad4e9419c1bd..57ce6ab64c5b 100644 --- a/examples/language/gpt/gemini/run_gemini.sh +++ b/examples/language/gpt/gemini/run_gemini.sh @@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} # The following options only valid when DISTPLAN="colossalai" export GPUNUM=${GPUNUM:-1} -export TPDEGREE=${TPDEGREE:-1} -export PLACEMENT=${PLACEMENT:-"cpu"} -export USE_SHARD_INIT=${USE_SHARD_INIT:-False} export BATCH_SIZE=${BATCH_SIZE:-16} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} export TRAIN_STEP=${TRAIN_STEP:-10} @@ -21,11 +18,8 @@ fi mkdir -p gemini_logs torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \ ---tp_degree=${TPDEGREE} \ --model_type=${MODEL_TYPE} \ --batch_size=${BATCH_SIZE} \ ---placement=${PLACEMENT} \ -${USE_SHARD_INIT} \ --distplan=${DISTPLAN} \ --train_step=${TRAIN_STEP} \ 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log diff --git a/examples/language/gpt/gemini/test_ci.sh b/examples/language/gpt/gemini/test_ci.sh index 0ddfd3a6211c..6fb08b975d7a 100644 --- a/examples/language/gpt/gemini/test_ci.sh +++ b/examples/language/gpt/gemini/test_ci.sh @@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do for DISTPLAN in "CAI_Gemini"; do for BATCH_SIZE in 2; do for GPUNUM in 1 4; do - for TPDEGREE in 1 2; do - if [ ${TPDEGREE} -gt ${GPUNUM} ]; then - continue - fi - for PLACEMENT in "cpu" "auto"; do - MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ - bash ./run_gemini.sh - done - done + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \ + bash ./run_gemini.sh done done done - for DISTPLAN in "zero1" "zero2"; do + for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do for BATCH_SIZE in 2; do for GPUNUM in 1 4; do - for TPDEGREE in 1; do - if [ ${TPDEGREE} -gt ${GPUNUM} ]; then - continue - fi - MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\ - bash ./run_gemini.sh - done + MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \ + bash ./run_gemini.sh done done done diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index 9e61779a1dbf..347251ca5631 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -1,4 +1,5 @@ import os +from contextlib import nullcontext from functools import partial from time import time @@ -13,11 +14,10 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext CAI_VERSION = colossalai.__version__ @@ -30,24 +30,6 @@ def parse_args(): default='CAI_Gemini', help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", ) - parser.add_argument( - "--tp_degree", - type=int, - default=1, - help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--placement", - type=str, - default='cpu', - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--shardinit", - action='store_true', - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", - ) parser.add_argument( "--batch_size", type=int, @@ -71,20 +53,6 @@ def parse_args(): return args -# Parameter Sharding Strategies for Tensor Parallelism -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) - - class GPTLMLoss(nn.Module): def __init__(self): @@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism(): print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") -# Tensor Parallel -def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): - """tensor_parallelize - Sharding the Model Parameters. - - Args: - model (torch.nn.Module): a torch module to be sharded - """ - for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - # NOTE() a param maybe shared by two modules - if hasattr(param, 'visited'): - continue - - # if shard init, then convert param to replica and use the dp-only ProcessGroup - param: ColoParameter = param - param.set_dist_spec(ReplicaSpec()) - param.set_process_group(pg) - - # shard it w.r.t tp pattern - if 'mlp.c_fc' in mn: - if 'weight' in pn or 'bias' in pn: - split_param_col_tp1d(param, pg) # column slice - # keep the shape of the output from c_fc - param.compute_spec.set_output_replicate(False) - else: - param.set_dist_spec(ReplicaSpec()) - elif 'mlp.c_proj' in mn: - if 'weight' in pn: - split_param_row_tp1d(param, pg) # row slice - else: - param.set_dist_spec(ReplicaSpec()) - elif 'wte' in mn or 'wpe' in mn: - split_param_col_tp1d(param, pg) # column slice - elif 'c_attn' in mn or 'c_proj' in mn: - split_param_col_tp1d(param, pg) # column slice - else: - param.set_dist_spec(ReplicaSpec()) - param.visited = True - - def main(): # version check # this example is supposed to work for versions greater than 0.2.0 @@ -213,30 +140,13 @@ def main(): # build criterion criterion = GPTLMLoss() - torch.manual_seed(123) if args.distplan.startswith("CAI"): - # all param must use the same process group. - world_size = torch.distributed.get_world_size() - shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None - default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None - - if args.shardinit and args.distplan != "CAI_Gemini": - raise RuntimeError("You can only use shardinit with CAI_Gemini") - + ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext() # build GPT model - with ColoInitContext(device=get_current_device(), - dtype=torch.half, - default_dist_spec=default_dist_spec, - default_pg=shard_pg): + with ctx: model = model_builder(args.model_type)(checkpoint=True) - tp_pg = ProcessGroup(tp_degree=args.tp_degree) - # Tensor Parallelism (TP) - # You should notice that v0.1.10 is not compatible with TP degree > 1 - if args.tp_degree > 1: - tensor_parallelize(model, tp_pg) - # assign running configurations if args.distplan == "CAI_ZeRO1": zero_stage = 1 @@ -254,13 +164,7 @@ def main(): overlap_communication=True, verbose=True) elif args.distplan == "CAI_Gemini": - plugin = GeminiPlugin(device=get_current_device(), - placement_policy=args.placement, - pin_memory=True, - strict_ddp_mode=args.tp_degree == 1, - search_range_m=128, - hidden_dim=model.config.n_embd, - gpu_margin_mem_ratio=0.) + plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd) else: raise RuntimeError From b5f0fb5662458d88fbc0bd5db045d24f443f07b5 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 18 Aug 2023 15:49:00 +0800 Subject: [PATCH 2/6] [example] update dreambooth example --- examples/images/dreambooth/test_ci.sh | 3 +- .../dreambooth/train_dreambooth_colossalai.py | 53 ++++++++++--------- .../train_dreambooth_colossalai_lora.py | 30 ++++++----- 3 files changed, 44 insertions(+), 42 deletions(-) diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh index 21f45adae2a0..84345f589bb5 100644 --- a/examples/images/dreambooth/test_ci.sh +++ b/examples/images/dreambooth/test_ci.sh @@ -20,6 +20,5 @@ for plugin in "gemini"; do --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --test_run=True \ - --num_class_images=200 \ - --placement="auto" # "cuda" + --num_class_images=200 done diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 888b28de8306..f60704650b7e 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -2,9 +2,9 @@ import hashlib import math import os +import shutil from pathlib import Path from typing import Optional -import shutil import torch import torch.nn.functional as F @@ -19,6 +19,8 @@ from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger @@ -26,8 +28,6 @@ from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext from colossalai.zero.gemini import get_static_torch_model -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin disable_existing_loggers() logger = get_dist_logger() @@ -138,10 +138,10 @@ def parse_args(input_args=None): " resolution"), ) parser.add_argument( - "--placement", - type=str, - default="cpu", - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", + "--offload_optim_frac", + type=float, + default=1.0, + help="Fraction of optimizer states to be offloaded. Valid when using colossalai as dist plan.", ) parser.add_argument( "--center_crop", @@ -461,18 +461,17 @@ def main(args): revision=args.revision, ) - if args.externel_unet_path is None: logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + subfolder="unet", + revision=args.revision, + low_cpu_mem_usage=False) else: logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, - revision=args.revision, - low_cpu_mem_usage=False) + revision=args.revision, + low_cpu_mem_usage=False) vae.requires_grad_(False) text_encoder.requires_grad_(False) @@ -491,30 +490,31 @@ def main(args): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) + plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': - plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) + plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + optimizer = HybridAdam(unet.parameters(), + lr=args.learning_rate, + initial_scale=2**5, + clipping_norm=args.max_grad_norm) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # prepare dataset logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) - train_dataset = DreamBoothDataset( - instance_data_root=args.instance_data_dir, - instance_prompt=args.instance_prompt, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_prompt=args.class_prompt, - tokenizer=tokenizer, - size=args.resolution, - center_crop=args.center_crop, - test=args.test_run - ) + train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + test=args.test_run) def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] @@ -690,6 +690,7 @@ def collate_fn(examples): if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + if __name__ == "__main__": args = parse_args() main(args) diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index dce65ff514b7..c98950fd795d 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -2,9 +2,9 @@ import hashlib import math import os +import shutil from pathlib import Path from typing import Optional -import shutil import torch import torch.nn.functional as F @@ -21,6 +21,8 @@ from transformers import AutoTokenizer, PretrainedConfig import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger @@ -28,8 +30,6 @@ from colossalai.utils import get_current_device from colossalai.zero import ColoInitContext, GeminiAdamOptimizer from colossalai.zero.gemini import get_static_torch_model -from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin disable_existing_loggers() logger = get_dist_logger() @@ -459,18 +459,17 @@ def main(args): revision=args.revision, ) - if args.externel_unet_path is None: logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, - subfolder="unet", - revision=args.revision, - low_cpu_mem_usage=False) + subfolder="unet", + revision=args.revision, + low_cpu_mem_usage=False) else: logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, - revision=args.revision, - low_cpu_mem_usage=False) + revision=args.revision, + low_cpu_mem_usage=False) unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, @@ -490,8 +489,7 @@ def main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim) + lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) unet.set_attn_processor(lora_attn_procs) lora_layers = AttnProcsLayers(unet.attn_processors) @@ -513,14 +511,17 @@ def main(args): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5) + plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': - plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) + plugin = LowLevelZeroPlugin(initial_scale=2**5) booster = Booster(plugin=plugin, **booster_kwargs) # config optimizer for colossalai zero - optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) + optimizer = HybridAdam(unet.parameters(), + lr=args.learning_rate, + initial_scale=2**5, + clipping_norm=args.max_grad_norm) # load noise_scheduler noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") @@ -711,6 +712,7 @@ def collate_fn(examples): if args.push_to_hub: repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) + if __name__ == "__main__": args = parse_args() main(args) From dffd138ae670ceb596cf62ed49ee3e38bc9a2058 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 18 Aug 2023 15:56:41 +0800 Subject: [PATCH 3/6] [example] update vit --- examples/images/vit/vit_train_demo.py | 64 ++++++++++++--------------- 1 file changed, 28 insertions(+), 36 deletions(-) diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 3a739f10b5d0..4dc0f67f40bf 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -1,20 +1,19 @@ import torch import torch.distributed as dist import transformers -from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor +from args import parse_demo_args +from data import BeansDataset, beans_collator from tqdm import tqdm +from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.utils import get_current_device from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator - -from args import parse_demo_args -from data import BeansDataset, beans_collator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -22,12 +21,12 @@ def move_to_cuda(batch, device): def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): - + torch.cuda.synchronize() model.train() with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: - + for batch in pbar: # Foward @@ -47,7 +46,7 @@ def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coor @torch.no_grad() def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): - + model.eval() accum_loss = torch.zeros(1, device=get_current_device()) total_num = torch.zeros(1, device=get_current_device()) @@ -76,9 +75,7 @@ def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): print(f"Evaluation result for epoch {epoch + 1}: \ average_loss={avg_loss}, \ accuracy={accuracy}.") - - - + def main(): @@ -102,14 +99,13 @@ def main(): train_dataset = BeansDataset(image_processor, split='train') eval_dataset = BeansDataset(image_processor, split='validation') - # Load pretrained ViT model config = ViTConfig.from_pretrained(args.model_name_or_path) config.num_labels = train_dataset.num_labels config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} - model = ViTForImageClassification.from_pretrained(args.model_name_or_path, - config=config, + model = ViTForImageClassification.from_pretrained(args.model_name_or_path, + config=config, ignore_mismatched_sizes=True) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) @@ -123,26 +119,22 @@ def main(): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - strict_ddp_mode=True, - initial_scale=2**5) + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare dataloader train_dataloader = plugin.prepare_dataloader(train_dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=beans_collator) + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=beans_collator) eval_dataloader = plugin.prepare_dataloader(eval_dataset, - batch_size=args.batch_size, - shuffle=True, - drop_last=True, - collate_fn=beans_collator) + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + collate_fn=beans_collator) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) @@ -156,11 +148,11 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) - + model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + dataloader=train_dataloader, + lr_scheduler=lr_scheduler) + # Finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): @@ -174,4 +166,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 009ed5dd26c4a27157d66795b9b0ed3906ac64cb Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 18 Aug 2023 15:56:56 +0800 Subject: [PATCH 4/6] [example] update opt --- examples/language/opt/opt_train_demo.py | 55 ++++++++++--------------- 1 file changed, 21 insertions(+), 34 deletions(-) diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py index fa7feca9c9a9..80063407ecd5 100644 --- a/examples/language/opt/opt_train_demo.py +++ b/examples/language/opt/opt_train_demo.py @@ -1,25 +1,20 @@ import time -import torch import datasets +import torch import transformers -from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer -from transformers import get_linear_schedule_with_warmup -from transformers.utils.versions import require_version +from args import parse_demo_args +from data import NetflixDataset, netflix_collator from tqdm import tqdm +from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup +from transformers.utils.versions import require_version import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator - -from args import parse_demo_args -from data import NetflixDataset, netflix_collator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt") require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") @@ -30,18 +25,18 @@ def move_to_cuda(batch, device): def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): - + torch.cuda.synchronize() model.train() with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: - + for batch in pbar: # Forward optimizer.zero_grad() batch = move_to_cuda(batch, torch.cuda.current_device()) - + outputs = model(use_cache=False, **batch) loss = outputs['loss'] @@ -72,7 +67,7 @@ def main(): else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() - + # Build OPT model config = AutoConfig.from_pretrained(args.model_name_or_path) model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) @@ -88,43 +83,35 @@ def main(): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - strict_ddp_mode=True, - initial_scale=2**5) + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare tokenizer and dataloader - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) dataset = NetflixDataset(tokenizer) dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=netflix_collator) - + # Set optimizer - optimizer = HybridAdam(model.parameters(), - lr=(args.learning_rate * world_size), - weight_decay=args.weight_decay) + optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) # Set lr scheduler total_steps = len(dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) - lr_scheduler = get_linear_schedule_with_warmup( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=len(dataloader) * args.num_epoch - ) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=len(dataloader) * args.num_epoch) # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=dataloader, + model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + dataloader=dataloader, lr_scheduler=lr_scheduler) # Start finetuning From 630d1b318bc3e8f2d667ea3765abe19ea644f5b4 Mon Sep 17 00:00:00 2001 From: ver217 Date: Fri, 18 Aug 2023 16:02:50 +0800 Subject: [PATCH 5/6] [example] update palm --- examples/language/palm/train.py | 78 ++++----------------------------- 1 file changed, 9 insertions(+), 69 deletions(-) diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 01862a02608b..526f791403ff 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -1,4 +1,5 @@ import gzip +from contextlib import nullcontext from functools import partial from time import time @@ -14,10 +15,10 @@ import colossalai from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.lazy import LazyInitContext from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn import HybridAdam -from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec -from colossalai.zero import ColoInitContext +from colossalai.utils import get_current_device # constants @@ -40,23 +41,10 @@ def parse_args(): help="The distributed plan [colossalai, pytorch].", ) parser.add_argument( - "--tp_degree", - type=int, - default=1, - help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--placement", - type=str, - default='cpu', - help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", - ) - parser.add_argument( - "--shardinit", - type=bool, - default=False, - help= - "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.", + "--offload_optim_frac", + type=float, + default=1.0, + help="Fraction of optimizer states to be offloaded. This is only used for gemini.", ) parser.add_argument('-p', '--plugin', @@ -107,49 +95,6 @@ def get_model_size(model: nn.Module): return total_numel -# Parameter Sharding Strategies for Tensor Parallelism -def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): - spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) - param.set_tensor_spec(*spec) - - -def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(0, param, pg) - - -def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup): - split_param_single_dim_tp1d(-1, param, pg) - - -# Tensor Parallel -def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup): - """tensor_parallelize - Sharding the Model Parameters. - Args: - model (torch.nn.Module): a torch module to be sharded - """ - for mn, module in model.named_modules(): - for pn, param in module.named_parameters(recurse=False): - if hasattr(param, 'visited'): - continue - param.set_dist_spec(ReplicaSpec()) - if 'net.0' in mn: - split_param_col_tp1d(param, pg) # column slice - elif 'to_q' in mn: - split_param_col_tp1d(param, pg) # column slice - elif 'to_kv' in mn: - split_param_row_tp1d(param, pg) # row slice - elif 'to_out' in mn: - split_param_row_tp1d(param, pg) # row slice - elif '1.1' in mn: - split_param_col_tp1d(param, pg) # column slice - elif '1.2' in mn: - split_param_row_tp1d(param, pg) # row slice - else: - param.set_dist_spec(ReplicaSpec()) - param.visited = True - - args = parse_args() if args.distplan not in ["colossalai", "pytorch"]: raise TypeError(f"{args.distplan} is error") @@ -206,23 +151,18 @@ def __len__(self): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2**5) + plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"plugin: {plugin}") booster = Booster(plugin=plugin, **booster_kwargs) - default_pg = ProcessGroup(tp_degree=args.tp_degree) - default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None - ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg) + ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext() with ctx: model = PaLM(num_tokens=50304, dim=4096, depth=64) model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN) - pg = default_pg - tensor_parallelize(model, pg) - # optimizer optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5) From 816192b7e93d9d5826c150d7f6db0b92d9794101 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 21 Aug 2023 15:10:34 +0800 Subject: [PATCH 6/6] [example] update vit and opt benchmark --- examples/images/vit/vit_benchmark.py | 52 +++++++++++++------------- examples/language/opt/opt_benchmark.py | 47 ++++++++++------------- 2 files changed, 46 insertions(+), 53 deletions(-) diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index 11d480bba65f..c2293b96ad73 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -1,19 +1,18 @@ import time import torch +import tqdm import transformers +from args import parse_benchmark_args from transformers import ViTConfig, ViTForImageClassification -import tqdm import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.utils import get_current_device from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam -from args import parse_benchmark_args def format_num(num: int, bytes=False): """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'""" @@ -26,8 +25,13 @@ def format_num(num: int, bytes=False): def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): - pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float) - labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64) + pixel_values = torch.randn(batch_size, + num_channels, + height, + width, + device=torch.cuda.current_device(), + dtype=torch.float) + labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64) return pixel_values, labels @@ -55,11 +59,11 @@ def main(): transformers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() - + # Whether to set limit on memory capacity if args.mem_cap > 0: colo_memory_cap(args.mem_cap) - + # Build ViT model config = ViTConfig.from_pretrained(args.model_name_or_path) model = ViTForImageClassification(config) @@ -75,11 +79,7 @@ def main(): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - strict_ddp_mode=True, - initial_scale=2**5) + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) @@ -90,16 +90,15 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) - # Start training. logger.info(f"Start testing", ranks=[0]) progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) - + torch.cuda.synchronize() model.train() start_time = time.time() - + for _ in range(args.max_train_steps): pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224) @@ -111,18 +110,19 @@ def main(): torch.cuda.synchronize() progress_bar.update(1) - - # Compute Statistics + + # Compute Statistics end_time = time.time() throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) - - logger.info(f"Testing finished, " - f"batch size per gpu: {args.batch_size}, " - f"plugin: {args.plugin}, " - f"throughput: {throughput}, " - f"maximum memory usage per gpu: {max_mem}.", - ranks=[0]) + + logger.info( + f"Testing finished, " + f"batch size per gpu: {args.batch_size}, " + f"plugin: {args.plugin}, " + f"throughput: {throughput}, " + f"maximum memory usage per gpu: {max_mem}.", + ranks=[0]) if __name__ == "__main__": diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py index 2d69036b50c6..90ed10ec7cca 100755 --- a/examples/language/opt/opt_benchmark.py +++ b/examples/language/opt/opt_benchmark.py @@ -1,22 +1,18 @@ import time import torch +import tqdm import transformers +from args import parse_benchmark_args from transformers import AutoConfig, OPTForCausalLM from transformers.utils.versions import require_version -import tqdm import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.tensor import ProcessGroup, ShardSpec -from colossalai.utils import get_current_device -from colossalai.zero import ColoInitContext from colossalai.booster import Booster from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator - -from args import parse_benchmark_args +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") @@ -61,11 +57,11 @@ def main(): transformers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() - + # Whether to set limit of memory capacity if args.mem_cap > 0: colo_memory_cap(args.mem_cap) - + # Build OPT model config = AutoConfig.from_pretrained(args.model_name_or_path) model = OPTForCausalLM(config=config) @@ -81,11 +77,7 @@ def main(): if args.plugin.startswith('torch_ddp'): plugin = TorchDDPPlugin() elif args.plugin == 'gemini': - plugin = GeminiPlugin(device=get_current_device(), - placement_policy='cpu', - pin_memory=True, - strict_ddp_mode=True, - initial_scale=2**5) + plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) @@ -96,18 +88,18 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _, _, _ = booster.boost(model, optimizer) - + SEQ_LEN = 1024 VOCAB_SIZE = 50257 # Start training. logger.info(f"Start testing", ranks=[0]) progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) - + torch.cuda.synchronize() model.train() start_time = time.time() - + for _ in range(args.max_train_steps): input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) @@ -119,18 +111,19 @@ def main(): torch.cuda.synchronize() progress_bar.update(1) - - # Compute Statistics + + # Compute Statistics end_time = time.time() throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) - - logger.info(f"Testing finished, " - f"batch size per gpu: {args.batch_size}, " - f"plugin: {args.plugin}, " - f"throughput: {throughput}, " - f"maximum memory usage per gpu: {max_mem}.", - ranks=[0]) + + logger.info( + f"Testing finished, " + f"batch size per gpu: {args.batch_size}, " + f"plugin: {args.plugin}, " + f"throughput: {throughput}, " + f"maximum memory usage per gpu: {max_mem}.", + ranks=[0]) if __name__ == "__main__":