Merged

61 commits
0f8710b
[shardformer] fix opt test hanging
flybird11111 Aug 27, 2023
31cd2f6
fix
flybird11111 Aug 27, 2023
2faec20
test
flybird11111 Aug 28, 2023
0ed1eda
test
flybird11111 Aug 29, 2023
1a5860a
Merge branch 'feature/shardformer' into feature/shardformer
flybird11111 Aug 29, 2023
8d7f059
test
flybird11111 Aug 29, 2023
73b0909
fix test
flybird11111 Aug 29, 2023
0da3a99
fix test
flybird11111 Aug 29, 2023
00383e0
remove print
flybird11111 Aug 30, 2023
2a23ac1
add fix
flybird11111 Aug 30, 2023
13f390c
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
c44787c
Merge branch 'feature/shardformer' into bert-finetune
flybird11111 Sep 1, 2023
bed0c84
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
5be7648
Merge branch 'bert-finetune' of https://github.com/flybird11111/Colos…
flybird11111 Sep 1, 2023
a73382b
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
1b03dbd
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
dced3e1
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
db74e72
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
275035a
[shardformer] fix epoch change
flybird11111 Sep 1, 2023
c81e09e
[shardformer] broadcast add pp group
flybird11111 Sep 1, 2023
56745dc
[shardformer] fix opt test hanging
flybird11111 Aug 27, 2023
b0867c5
fix
flybird11111 Aug 27, 2023
b8a29c8
test
flybird11111 Aug 28, 2023
2560d7f
test
flybird11111 Aug 29, 2023
f867ac8
[shardformer] zero1+pp and the corresponding tests (#4517)
CjhHa1 Aug 28, 2023
fc78baf
[shardformer/fix overlap bug] fix overlap bug, add overlap as an opti…
FoolPlayer Aug 28, 2023
4e9038a
[shardformer] fix emerged bugs after updating transformers (#4526)
Aug 29, 2023
f223ee8
test
flybird11111 Aug 29, 2023
ced65eb
fix test
flybird11111 Aug 29, 2023
412bd45
fix test
flybird11111 Aug 29, 2023
ef748f1
remove print
flybird11111 Aug 30, 2023
e9f9b3f
add fix
flybird11111 Aug 30, 2023
9c8ded6
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
20e0505
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
853ef03
[shardformer] Add overlap support for gpt2 (#4535)
FoolPlayer Aug 29, 2023
2a293d7
[shardformer] support pp+tp+zero1 tests (#4531)
flybird11111 Aug 30, 2023
9879034
[shardformer] fix submodule replacement bug when enabling pp (#4544)
Aug 31, 2023
c647bdb
[shardformer] support sharded optimizer checkpointIO of HybridParalle…
Aug 31, 2023
bf50c53
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
7c069b5
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
fd166e5
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
b980da4
[shardformer] add bert finetune example
flybird11111 Sep 1, 2023
f2ee523
[shardformer] fix epoch change
flybird11111 Sep 1, 2023
0b04238
[shardformer] broadcast add pp group
flybird11111 Sep 1, 2023
980f8df
Merge branch 'bert-finetune' of https://github.com/flybird11111/Colos…
flybird11111 Sep 3, 2023
940f740
Merge branch 'feature/shardformer' into bert-finetune
flybird11111 Sep 3, 2023
ff5fd27
rebase feature/shardformer
flybird11111 Sep 3, 2023
03302ee
update pipeline
flybird11111 Sep 3, 2023
eba9b81
[shardformer] fix
flybird11111 Sep 4, 2023
95329bc
[shardformer] fix
flybird11111 Sep 4, 2023
9ac6dc6
[shardformer] bert finetune fix
flybird11111 Sep 4, 2023
c9cfa8c
[shardformer] add all_reduce operation to loss
flybird11111 Sep 4, 2023
a679acd
[shardformer] make compatible with pytree.
flybird11111 Sep 4, 2023
b01febe
[shardformer] disable tp
flybird11111 Sep 4, 2023
5447d85
[shardformer] add 3d plugin to ci test
flybird11111 Sep 4, 2023
d52cf38
[shardformer] update num_microbatches to None
flybird11111 Sep 4, 2023
42fca4d
[shardformer] update microbatchsize
flybird11111 Sep 4, 2023
fb4ceab
[shardformer] update assert
flybird11111 Sep 4, 2023
0c596c4
update scheduler
flybird11111 Sep 4, 2023
2040e15
update scheduler
flybird11111 Sep 4, 2023
1d033bd
update scheduler
flybird11111 Sep 4, 2023
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -325,7 +325,7 @@ def __init__(self,
self.schedule = None
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
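With the relaxed assertion, a pipeline-parallel plugin can be configured with either knob. A minimal sketch (illustrative values; it assumes `colossalai.launch_from_torch` has already initialized the distributed environment):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Previously only num_microbatches satisfied the pp_size > 1 assertion;
# after this change, microbatch_size alone is also accepted.
plugin_a = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4, zero_stage=1)
plugin_b = HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=None,
                                microbatch_size=1, zero_stage=1)

booster = Booster(plugin=plugin_b)
```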
3 changes: 2 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
@@ -37,6 +37,7 @@ def __init__(self,
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self._use_microbatch_size = num_microbatches is None

def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@@ -51,7 +52,7 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
if self.num_microbatches is not None:
if not self._use_microbatch_size:
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
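The schedule now records which of the two knobs was supplied. A standalone sketch of the sizing rule (an illustrative helper, not the actual class):

```python
def resolve_microbatch_config(batch_size, num_microbatches=None, microbatch_size=None):
    """Derive (num_microbatches, microbatch_size) from whichever argument was given."""
    if num_microbatches is not None:
        assert batch_size % num_microbatches == 0, \
            "Batch size should be divisible by the number of microbatches"
        return num_microbatches, batch_size // num_microbatches
    assert microbatch_size is not None and batch_size % microbatch_size == 0, \
        "Batch size should be divisible by the microbatch size"
    return batch_size // microbatch_size, microbatch_size


# A global batch of 32 split either way:
print(resolve_microbatch_config(32, num_microbatches=4))   # (4, 8)
print(resolve_microbatch_config(32, microbatch_size=1))    # (32, 1)
```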
163 changes: 130 additions & 33 deletions examples/language/bert/finetune.py
@@ -1,12 +1,14 @@
import argparse
from typing import List, Union
from contextlib import nullcontext
from typing import Callable, List, Union

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
@@ -18,8 +20,9 @@

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

@@ -32,38 +35,93 @@
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1

output_transform_fn = lambda x: x
criterion = lambda x: x.loss


def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}


@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
eval_splits: List[str], coordinator: DistCoordinator):
def evaluate_model(
model: nn.Module,
optimizer,
criterion,
test_dataloader: Union[DataLoader, List[DataLoader]],
num_labels: int,
task_name: str,
eval_splits: List[str],
booster: Booster,
coordinator: DistCoordinator,
):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()

def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

labels = batch["labels"]

metric.add_batch(predictions=preds, references=labels)
batch_size = batch["input_ids"].shape[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
pg_mesh = booster.plugin.pg_mesh
pp_group = booster.plugin.pp_group
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)

if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"]

logits = outputs["outputs"]["logits"]

accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

dist.broadcast(preds, src=current_rank, group=pp_group)
dist.broadcast(val_loss, src=current_rank, group=pp_group)

metric.add_batch(predictions=preds, references=labels)
elif current_rank in current_pp_group_ranks:
val_loss = torch.empty((1,), device=get_current_device())
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())

dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)

accum_loss.add_(val_loss)
metric.add_batch(predictions=preds, references=labels)

else:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

metric.add_batch(predictions=preds, references=labels)

results = metric.compute()
dist.all_reduce(accum_loss.div_(len(dataloader)))
if coordinator.is_master():
if coordinator.is_master() and results is not None:
results['loss'] = accum_loss.item() / coordinator.world_size

return results

if isinstance(test_dataloader, DataLoader):
@@ -77,25 +77,43 @@ def evaluate_subset(dataloader: DataLoader):
return final_results


def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):

model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
is_pp_last_stage = hasattr(
booster.plugin,
"stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
with tqdm(train_dataloader,
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
for batch in pbar:
# Forward pass
batch = move_to_cuda(batch)
outputs = model(**batch)
loss = outputs[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
#TODO pass train_dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
# Backward and optimize
if booster.plugin.stage_manager.is_last_stage():
loss = outputs['loss']
pbar.set_postfix({'loss': loss.item()})
else:
outputs = model(**batch)
loss = _criterion(outputs, None)
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({'loss': loss.item()})

# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()

# Print log info
pbar.set_postfix({'loss': loss.item()})


def main():
# ==============================
@@ -107,7 +183,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
help="plugin to use")
parser.add_argument(
"--model_type",
@@ -116,6 +192,7 @@
help="bert or albert",
)
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()

if args.model_type == 'bert':
@@ -145,6 +222,17 @@ def main():
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':

# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision='fp16',
initial_scale=1)

booster = Booster(plugin=plugin, **booster_kwargs)

@@ -165,8 +253,9 @@
# bert pretrained model

cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)

if model_name == "bert-base-uncased":
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else:
@@ -196,19 +285,27 @@ def main():
num_training_steps=total_steps,
)

def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss

# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
optimizer,
criterion=_criterion,
lr_scheduler=lr_scheduler)

# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
data_builder.eval_splits, booster, coordinator)

if coordinator.is_master():
print(results)
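Condensed, the per-batch training path added to finetune.py branches on whether the plugin has a pipeline stage manager. A rough sketch using the names from the diff above (it assumes `booster.boost()` has already wrapped `model`, `optimizer` and `_criterion`):

```python
def train_step(batch, model, optimizer, _criterion, booster):
    use_pipeline = getattr(booster.plugin, "stage_manager", None) is not None
    if use_pipeline:
        # The 1F1B schedule runs forward and backward over the microbatches internally.
        outputs = booster.execute_pipeline(iter([batch]), model, _criterion, optimizer,
                                           return_loss=True, return_outputs=True)
        loss = outputs["loss"] if booster.plugin.stage_manager.is_last_stage() else None
    else:
        batch = {k: v.cuda() for k, v in batch.items()}
        loss = _criterion(model(**batch), None)
        booster.backward(loss, optimizer)
    optimizer.step()
    optimizer.zero_grad()
    return loss
```

On the pipeline path only the last stage observes the loss, which is why the tqdm postfix and the prediction/loss broadcasts in evaluate_model are gated on `is_last_stage()`.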
2 changes: 1 addition & 1 deletion examples/language/bert/test_ci.sh
@@ -3,6 +3,6 @@ set -xe

pip install -r requirements.txt

for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
done