13 changes: 13 additions & 0 deletions colossalai/shardformer/README.md
@@ -15,6 +15,7 @@
- [Policy](#policy)
- [Model Sharder](#model-sharder)
- [User-facing API](#user-facing-api)
- [Shardformer Convergence](#shardformer-convergence)


## 🔗 Introduction
@@ -324,3 +325,15 @@ class ShardFormer:
"""
...
```

### Shardformer Convergence

To validate that training a model with Shardformer does not impact its convergence, we [fine-tuned a BERT model](./examples/shardformer_benchmark.py) both with and without Shardformer and compared the accuracy, F1 score, and loss of the training results.

| Accuracy | F1 | Loss | Number of GPUs | Model Sharded |
| :-----: | :----: | :----: | :----: | :----: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |

Overall, the results demonstrate that training with Shardformer does not affect convergence.
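
For reference, below is a minimal sketch of the shard-or-not switch the benchmark uses, condensed from [`examples/shardformer_benchmark.py`](./examples/shardformer_benchmark.py); the optimizer, scheduler, data loading, and training loop are omitted:

```python
import torch
import torch.distributed as dist
from transformers import BertForSequenceClassification

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

colossalai.launch_from_torch(config={}, seed=42)

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.to(torch.cuda.current_device())

# Shard the model only when more than one GPU is available; the single-GPU
# run trains the unsharded baseline used in the comparison above.
if dist.get_world_size() > 1:
    shard_config = ShardConfig(enable_fused_normalization=False)
    model = ShardFormer(shard_config=shard_config).shard_model(model)
```

The script is launched with `torchrun` (see [`examples/shardformer_benchmark.sh`](./examples/shardformer_benchmark.sh)); varying `--nproc_per_node` switches between the sharded and unsharded settings compared above.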
146 changes: 146 additions & 0 deletions colossalai/shardformer/examples/data.py
@@ -0,0 +1,146 @@
import datasets
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer

from colossalai.booster.plugin.dp_plugin_base import DPPluginBase


class GLUEDataBuilder:

task_text_field_map = {
"cola": ["sentence"],
"sst2": ["sentence"],
"mrpc": ["sentence1", "sentence2"],
"qqp": ["question1", "question2"],
"stsb": ["sentence1", "sentence2"],
"mnli": ["premise", "hypothesis"],
"qnli": ["question", "sentence"],
"rte": ["sentence1", "sentence2"],
"wnli": ["sentence1", "sentence2"],
"ax": ["premise", "hypothesis"],
}

glue_task_num_labels = {
"cola": 2,
"sst2": 2,
"mrpc": 2,
"qqp": 2,
"stsb": 1,
"mnli": 3,
"qnli": 2,
"rte": 2,
"wnli": 2,
"ax": 3,
}

loader_columns = [
"datasets_idx",
"input_ids",
"token_type_ids",
"attention_mask",
"start_positions",
"end_positions",
"labels",
]

def __init__(
self,
model_name_or_path: str,
plugin: DPPluginBase = None,
task_name: str = "mrpc",
max_seq_length: int = 128,
train_batch_size: int = 32,
eval_batch_size: int = 32,
**kwargs,
):
super().__init__()
self.model_name_or_path = model_name_or_path
self.task_name = task_name
self.max_seq_length = max_seq_length
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.plugin = plugin

self.text_fields = self.task_text_field_map[task_name]
self.num_labels = self.glue_task_num_labels[task_name]
self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
self.setup()

def setup(self):
self.dataset = datasets.load_dataset("glue", self.task_name)

for split in self.dataset.keys():
self.dataset[split] = self.dataset[split].map(
self.convert_to_features,
batched=True,
remove_columns=["label"],
)
self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
self.dataset[split].set_format(type="torch", columns=self.columns)

self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]

def prepare_data(self):
datasets.load_dataset("glue", self.task_name)
AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

def train_dataloader(self):
        if self.plugin is None:
return self.native_prepare_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)
return self.plugin.prepare_dataloader(self.dataset["train"],
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True)

def val_dataloader(self):
        if self.plugin is None:
return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]

def test_dataloader(self):
        if self.plugin is None:
            return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
if len(self.eval_splits) == 1:
return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
elif len(self.eval_splits) > 1:
return [
self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
for x in self.eval_splits
]

def convert_to_features(self, example_batch):

# Either encode single sentence or sentence pairs
if len(self.text_fields) > 1:
texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
else:
texts_or_text_pairs = example_batch[self.text_fields[0]]

# Tokenize the text/text pairs
features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
max_length=self.max_seq_length,
padding='max_length',
truncation=True)

# Rename label to labels to make it easier to pass to model forward
features["labels"] = example_batch["label"]

return features

def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False):

return DataLoader(dataset,
batch_size=batch_size,
sampler=None,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=pin_memory)
154 changes: 154 additions & 0 deletions colossalai/shardformer/examples/shardformer_benchmark.py
@@ -0,0 +1,154 @@
import argparse
import math
from typing import Any, List, Union

import evaluate
import torch
import torch.distributed as dist
from data import GLUEDataBuilder
from torch import nn
from torch.optim import Adam, AdamW, Optimizer
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup

import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer import ShardConfig, ShardFormer


def to_device(x: Any, device: torch.device) -> Any:

def _to(t: Any):
if isinstance(t, torch.Tensor):
return t.to(device)
return t

return tree_map(_to, x)


def train(args):
colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator()

# prepare for data and dataset
data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain,
task_name=args.task,
train_batch_size=args.batch_size,
eval_batch_size=args.batch_size)
train_dataloader = data_builder.train_dataloader()
test_dataloader = data_builder.test_dataloader()

if args.model == "bert":
cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels)
model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg)

model.to(torch.cuda.current_device())

# if multiple GPUs, shard the model
if dist.get_world_size() > 1:
shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm)
shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.shard_model(model)

optim = Adam(model.parameters(), lr=args.lr)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
lr_scheduler = get_linear_schedule_with_warmup(
optim,
num_warmup_steps=math.ceil(max_steps * args.warmup_fraction),
num_training_steps=max_steps,
)
fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size,
coordinator)
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
if coordinator.is_master():
print(results)
if args.target_f1 is not None and 'f1' in results:
assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'


def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size,
coordinator):
step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs),
desc=f'steps',
disable=not coordinator.is_master())
total_loss = 0
for epoch in range(max_epochs):
model.train()
for batch_id, batch in enumerate(train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
outputs = model(**batch)
loss = outputs.loss
loss = loss / accumulation_steps
loss.backward()
total_loss += loss.item()
if (batch_id + 1) % accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
step_bar.set_postfix({
'epoch': epoch,
'loss': total_loss / batch_size,
'lr': scheduler.get_last_lr()[0]
})
total_loss = 0
step_bar.update()


# evaluate
@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()

def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=torch.cuda.current_device())
for batch in dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)

if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()

labels = batch["labels"]
metric.add_batch(predictions=preds, references=labels)

results = metric.compute()
if coordinator.is_master():
results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size)
return results

if isinstance(test_dataloader, DataLoader):
return evaluate_subset(test_dataloader)
else:
assert len(test_dataloader) == len(eval_splits)
final_results = {}
for split, sub_loader in zip(eval_splits, test_dataloader):
results = evaluate_subset(sub_loader)
final_results.update({f'{k}_{split}': v for k, v in results.items()})
return final_results


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
parser.add_argument('--model', type=str, default="bert")
parser.add_argument('--pretrain', type=str, default="bert-base-uncased")
parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lr', type=float, default=2.4e-5)
    # NOTE: argparse's `type=bool` would treat any non-empty string (e.g. "False") as True,
    # so parse the flag value explicitly.
    parser.add_argument('--fused_layernorm', type=lambda x: x.lower() in ('true', '1'), default=False)
parser.add_argument('--accumulation_steps', type=int, default=8)
parser.add_argument('--warmup_fraction', type=float, default=0.03)
parser.add_argument('--target_f1', type=float, default=None)
args = parser.parse_args()
train(args)
9 changes: 9 additions & 0 deletions colossalai/shardformer/examples/shardformer_benchmark.sh
@@ -0,0 +1,9 @@
torchrun --standalone --nproc_per_node=4 shardformer_benchmark.py \
--model "bert" \
--pretrain "bert-base-uncased" \
--max_epochs 1 \
--batch_size 2 \
--lr 2.4e-5 \
--fused_layernorm False \
--accumulation_steps 8 \
--warmup_fraction 0.03
2 changes: 1 addition & 1 deletion colossalai/shardformer/shard/shard_config.py
@@ -17,7 +17,7 @@ class ShardConfig:
tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
"""
-    tensor_parallel_process_group: int = None
+    tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False
enable_all_optimization: bool = False
