4 changes: 2 additions & 2 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -303,7 +303,7 @@ def __init__(self,
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'

if enable_sequence_parallelism:
assert tp_size > 1, 'Sequence parallelism must be enabled when using tensor parallelism'
assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'

self.tp_size = tp_size
self.pp_size = pp_size
@@ -414,7 +414,7 @@ def configure(
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.dp_size > 1, "Data parallel size should be greater than 1 when using Zero."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
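Taken together, the two reworded assertions pin down how the plugin can be configured: sequence parallelism is only valid when tensor parallelism is enabled (tp_size > 1), and the ZeRO optimizer path expects both dp_size > 1 and a half-precision setting. A hedged sketch of configurations that satisfy them, using only constructor arguments that appear elsewhere in this diff (the concrete values are illustrative, not part of the PR):

# Sequence parallelism requires tensor parallelism, so tp_size must be > 1:
plugin_sp = HybridParallelPlugin(tp_size=2,
                                 pp_size=1,
                                 enable_sequence_parallelism=True,
                                 precision='fp16',
                                 initial_scale=1)

# The ZeRO branch additionally expects dp_size > 1 and fp16/bf16 precision.
# On 8 ranks, tp_size=2 * pp_size=2 leaves dp_size = 8 / 4 = 2:
plugin_zero = HybridParallelPlugin(tp_size=2,
                                   pp_size=2,
                                   num_microbatches=2,
                                   zero_stage=1,
                                   precision='bf16',
                                   initial_scale=1)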
105 changes: 85 additions & 20 deletions examples/language/bert/finetune.py
@@ -1,10 +1,13 @@
import argparse
from typing import List, Union
import copy
from contextlib import nullcontext

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -18,15 +21,17 @@

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin, HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from torch.optim import Adam
from colossalai.utils import get_current_device
from colossalai.lazy import LazyInitContext

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
NUM_EPOCHS = 1
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
@@ -77,24 +82,62 @@ def evaluate_subset(dataloader: DataLoader):
return final_results


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
def train_epoch(epoch: int, org_model: nn.Module, sharded_model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
model.train()
sharded_model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar:
# Forward pass
batch = move_to_cuda(batch)
outputs = model(**batch)
loss = outputs[0]

# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()

# different forward and backward logic when pipeline parallelism is enabled
if booster.plugin.pp_size > 1:

output_transform_fn = lambda x: x
criterion = lambda x: x.loss

def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
data = batch
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)

data_iter = iter([data])
outputs = booster.execute_pipeline(data_iter,
sharded_model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
loss = outputs['loss']

# Run the unsharded reference model forward/backward for comparison
org_model.train()
data = {k: v.cuda() for k, v in data.items()}
org_output = org_model(**data)

org_loss = criterion(org_output)
org_loss.backward()

else:
outputs = org_model(**batch)
loss = outputs[0]

# Backward and optimize
booster.backward(loss, optimizer)

optimizer.step()
optimizer.zero_grad()

lr_scheduler.step()

# Print log info
pbar.set_postfix({'loss': loss.item()})
#pbar.set_postfix({'loss': loss.item()})


def main():
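The in-line repeat loop in the pipeline branch above exists because the schedule splits each batch along dim 0 into num_microbatches chunks, so the leading dimension must be large enough to divide evenly; the hard-coded factor of 4 is specific to this test setup. A minimal sketch of that step pulled out into a helper (the function name and the factor default are illustrative, not part of the PR):

def replicate_for_microbatches(batch: dict, factor: int = 4) -> dict:
    # Repeat every tensor in the batch along dim 0 so the pipeline
    # schedule can split it into equally sized microbatches.
    out = {}
    for k, v in batch.items():
        if torch.is_tensor(v):
            shape = [1] * v.dim()
            shape[0] = factor
            out[k] = v.to('cuda').repeat(*shape)
        else:
            out[k] = v
    return out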
@@ -107,7 +150,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
help="plugin to use")
parser.add_argument(
"--model_type",
@@ -116,6 +159,7 @@
help="bert or albert",
)
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()

if args.model_type == 'bert':
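One caveat with the --use_lazy_init flag added above: argparse's type=bool converts via bool(str), so any non-empty value on the command line, including --use_lazy_init False, parses as True. If that matters for the test cases, a store_true flag is unambiguous; a hedged alternative (not part of this PR):

parser.add_argument('--use_lazy_init',
                    action='store_true',
                    default=False,
                    help="whether to wrap model creation in a lazy init context")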
@@ -145,6 +189,17 @@ def main():
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':

# Modify the plugin parameters as needed for the fine-tuning test cases
plugin = HybridParallelPlugin(tp_size=1,
pp_size=2,
num_microbatches=2,
enable_all_optimization=False,
#enable_sequence_parallelism=True,
zero_stage=1,
precision='fp16',
initial_scale=1)
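With tp_size=1 and pp_size=2, the ranks left over form the data-parallel group, and that is the group ZeRO stage 1 shards optimizer states across. The divisibility check in the first file of this diff implies dp_size = world_size / (tp_size * pp_size); a small sketch of that arithmetic (the helper name is made up for illustration):

import torch.distributed as dist

def infer_dp_size(tp_size: int, pp_size: int) -> int:
    # Mirrors the check in HybridParallelPlugin.__init__ shown above:
    # world size must be divisible by tp_size * pp_size, and the
    # quotient is the data-parallel size used by the ZeRO optimizer.
    world_size = dist.get_world_size()
    assert world_size % (tp_size * pp_size) == 0, (
        f'world size {world_size} is not divisible by '
        f'tp_size {tp_size} * pp_size {pp_size}')
    return world_size // (tp_size * pp_size)

# e.g. 4 GPUs with tp_size=1, pp_size=2  ->  dp_size == 2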

booster = Booster(plugin=plugin, **booster_kwargs)

@@ -166,12 +221,12 @@

cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
if model_name == "bert-base-uncased":
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else:
raise RuntimeError

# optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
@@ -184,8 +239,18 @@
"weight_decay": 0.0,
},
]

optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

# lazy_init
use_lazy_init = args.use_lazy_init
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
org_model = model
sharded_model = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)

optimizer = Adam(model.parameters(), lr=lr, eps=1e-8)
sharded_optimizer = Adam(sharded_model.parameters(), lr=lr, eps=1e-8)
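The block above keeps two complete setups side by side: the unsharded reference model with its own Adam, and a deep copy plus a second Adam that will be handed to the booster. The copy has to be taken before boosting so the reference stays untouched, and under lazy init only the reference is materialized eagerly. A commented restatement of that pattern (grounded in the calls already used here; the remark about the sharded copy being materialized later is an assumption):

ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
    org_model = model                          # unsharded reference
    sharded_model = copy.deepcopy(org_model)   # copied before boost, so the
                                               # reference is never sharded
if use_lazy_init:
    ctx.materialize(org_model)                 # reference needs real tensors now;
                                               # the sharded copy is assumed to be
                                               # materialized when it is boosted

optimizer = Adam(org_model.parameters(), lr=lr, eps=1e-8)               # for the reference
sharded_optimizer = Adam(sharded_model.parameters(), lr=lr, eps=1e-8)   # passed to booster.boost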

# lr scheduler
total_steps = len(train_dataloader) * NUM_EPOCHS
@@ -195,17 +260,17 @@
num_warmup_steps=num_warmup_steps,
num_training_steps=total_steps,
)

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
sharded_model, optimizer, _, _, lr_scheduler = booster.boost(sharded_model, sharded_optimizer, lr_scheduler=lr_scheduler)
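Only the sharded copy and its optimizer are passed through booster.boost; judging from the unpacking above it hands back five slots (model, optimizer, criterion, dataloader, lr_scheduler), and the untouched positions are discarded with _. A hedged sketch of the same call with everything wrapped at once (the keyword names are an assumption inferred from that unpacking):

criterion = lambda outputs: outputs.loss   # same criterion train_epoch builds internally

sharded_model, sharded_optimizer, criterion, train_dataloader, lr_scheduler = booster.boost(
    sharded_model,
    sharded_optimizer,
    criterion=criterion,
    dataloader=train_dataloader,
    lr_scheduler=lr_scheduler,
)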

# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
train_epoch(epoch, model, sharded_model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)

results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
1 change: 1 addition & 0 deletions examples/language/bert/requirements.txt
@@ -7,3 +7,4 @@ transformers
scipy
scikit-learn
ptflops
xformers