diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 277843b66568..f74fa0d1922c 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -303,7 +303,7 @@ def __init__(self,
         ) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
 
         if enable_sequence_parallelism:
-            assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
+            assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'
 
         self.tp_size = tp_size
         self.pp_size = pp_size
@@ -414,7 +414,7 @@ def configure(
                                                    use_pipeline=self.enable_pipeline_parallelism,
                                                    param_info=param_info)
             else:
-                assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
+                assert self.dp_size > 1, "Data parallel size should be greater than 1 when using ZeRO."
                 assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
                 optimizer = HybridParallelZeroOptimizer(optimizer,
                                                         model,
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
index b209ffde85a4..8ef06c669511 100644
--- a/examples/language/bert/finetune.py
+++ b/examples/language/bert/finetune.py
@@ -1,10 +1,13 @@
 import argparse
+import copy
+from contextlib import nullcontext
 from typing import List, Union
 
 import evaluate
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from data import GLUEDataBuilder
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -18,15 +21,17 @@
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from torch.optim import Adam
 from colossalai.utils import get_current_device
 
 # ==============================
 # Prepare Hyperparameters
 # ==============================
-NUM_EPOCHS = 3
+NUM_EPOCHS = 1
 BATCH_SIZE = 32
 LEARNING_RATE = 2.4e-5
 WEIGHT_DECAY = 0.01
@@ -77,24 +82,65 @@ def evaluate_subset(dataloader: DataLoader):
     return final_results
 
 
-def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
-                booster: Booster, coordinator: DistCoordinator):
-    model.train()
+def train_epoch(epoch: int, org_model: nn.Module, sharded_model: nn.Module, optimizer: Optimizer, lr_scheduler,
+                train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
+    sharded_model.train()
     with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
         for batch in pbar:
             # Forward pass
             batch = move_to_cuda(batch)
-            outputs = model(**batch)
-            loss = outputs[0]
-
-            # Backward and optimize
-            booster.backward(loss, optimizer)
-            optimizer.step()
-            optimizer.zero_grad()
+
+            # different forward and backward logic when pipeline parallelism is enabled
+            if booster.plugin.pp_size > 1:
+
+                output_transform_fn = lambda x: x
+                criterion = lambda x: x.loss
+
+                def _criterion(outputs, inputs):
+                    outputs = output_transform_fn(outputs)
+                    loss = criterion(outputs)
+                    return loss
+                data = batch
+                for k, v in data.items():
+                    if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+                        # replicate each tensor 4x along the batch dimension
+                        new_shape = [1] * v.dim()
+                        new_shape[0] = 4
+                        data[k] = v.to('cuda').repeat(*new_shape)
+
+                data_iter = iter([data])
+                outputs = booster.execute_pipeline(data_iter,
+                                                   sharded_model,
+                                                   _criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+                loss = outputs['loss']
+
+                # run the original (unsharded) model on the same data for comparison
+                org_model.train()
+                data = {k: v.cuda() for k, v in data.items()}
+                org_output = org_model(**data)
+
+                org_loss = criterion(org_output)
+                org_loss.backward()
+
+            else:
+                outputs = sharded_model(**batch)
+                loss = outputs[0]
+
+                # Backward and optimize
+                booster.backward(loss, optimizer)
+
+            optimizer.step()
+            optimizer.zero_grad()
+
             lr_scheduler.step()
 
             # Print log info
-            pbar.set_postfix({'loss': loss.item()})
+            # with pipeline parallelism, the loss is only returned on the last stage
+            if loss is not None:
+                pbar.set_postfix({'loss': loss.item()})
 
 
 def main():
@@ -107,7 +153,7 @@ def main():
         '--plugin',
         type=str,
         default='torch_ddp',
-        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+        choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
         help="plugin to use")
     parser.add_argument(
         "--model_type",
@@ -116,6 +162,7 @@ def main():
         help="bert or albert",
     )
     parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+    parser.add_argument('--use_lazy_init', action='store_true', default=False, help="whether to build the model under LazyInitContext")
     args = parser.parse_args()
 
     if args.model_type == 'bert':
@@ -145,6 +192,16 @@ def main():
         plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        # adjust the parallelism settings below to fit your finetuning test cases
+        plugin = HybridParallelPlugin(tp_size=1,
+                                      pp_size=2,
+                                      num_microbatches=2,
+                                      enable_all_optimization=False,
+                                      # enable_sequence_parallelism=True,
+                                      zero_stage=1,
+                                      precision='fp16',
+                                      initial_scale=1)
 
     booster = Booster(plugin=plugin, **booster_kwargs)
 
@@ -166,12 +223,12 @@ def main():
     cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
 
     if model_name == "bert-base-uncased":
-        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
+        model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
     elif model_name == "albert-xxlarge-v2":
         model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
     else:
         raise RuntimeError
-
+
     # optimizer
     no_decay = ["bias", "LayerNorm.weight"]
     optimizer_grouped_parameters = [
@@ -184,8 +241,18 @@ def main():
             "weight_decay": 0.0,
         },
     ]
-
-    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
+
+    # optionally build the original model and its sharded copy under the lazy init context
+    use_lazy_init = args.use_lazy_init
+    ctx = LazyInitContext() if use_lazy_init else nullcontext()
+    with ctx:
+        org_model = model
+        sharded_model = copy.deepcopy(org_model)
+    if use_lazy_init:
+        ctx.materialize(org_model)
+
+    optimizer = Adam(model.parameters(), lr=lr, eps=1e-8)
+    sharded_optimizer = Adam(sharded_model.parameters(), lr=lr, eps=1e-8)
 
     # lr scheduler
     total_steps = len(train_dataloader) * NUM_EPOCHS
@@ -195,17 +262,17 @@ def main():
                                                   num_warmup_steps=num_warmup_steps,
                                                   num_training_steps=total_steps,
     )
-
+
     # ==============================
     # Boost with ColossalAI
     # ==============================
-    model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
+    sharded_model, optimizer, _, _, lr_scheduler = booster.boost(sharded_model, sharded_optimizer, lr_scheduler=lr_scheduler)
 
     # ==============================
     # Train model
     # ==============================
     for epoch in range(NUM_EPOCHS):
-        train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+        train_epoch(epoch, org_model, sharded_model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
 
     results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
                              coordinator)
diff --git a/examples/language/bert/requirements.txt b/examples/language/bert/requirements.txt
index 377422c260ad..eb4f6a9f4391 100644
--- a/examples/language/bert/requirements.txt
+++ b/examples/language/bert/requirements.txt
@@ -7,3 +7,4 @@ transformers
 scipy
 scikit-learn
 ptflops
+xformers
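
Note (not part of the patch): a minimal sketch of the pipeline-parallel training step that finetune.py now performs when pp_size > 1. The helper name pipeline_step is illustrative; it assumes model and optimizer were already boosted together with a HybridParallelPlugin configured with pp_size > 1, and that the model's forward returns a HuggingFace-style output carrying a .loss attribute, matching the calling convention used in the diff above.

def pipeline_step(booster, model, optimizer, batch):
    # execute_pipeline invokes the criterion as criterion(outputs, inputs)
    criterion = lambda outputs, inputs: outputs.loss
    # feed one whole batch; the plugin splits it into num_microbatches
    data_iter = iter([batch])
    out = booster.execute_pipeline(data_iter, model, criterion, optimizer, return_loss=True)
    optimizer.step()
    optimizer.zero_grad()
    # the loss is only materialized on the last pipeline stage; elsewhere it is None
    return out['loss']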