From 6978b8b9f90e39b11bbd997b973ae5569b4a0d8e Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 12 Sep 2023 11:13:05 +0800 Subject: [PATCH 1/5] [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 31 +++- .../examples/convergence_benchmark.py | 154 ------------------ .../examples/convergence_benchmark.sh | 9 - colossalai/shardformer/examples/data.py | 146 ----------------- .../examples/performance_benchmark.py | 6 +- 5 files changed, 30 insertions(+), 316 deletions(-) delete mode 100644 colossalai/shardformer/examples/convergence_benchmark.py delete mode 100644 colossalai/shardformer/examples/convergence_benchmark.sh delete mode 100644 colossalai/shardformer/examples/data.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 2e48a79dc1d7..4683e35f7c9f 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -30,27 +30,48 @@ ### Quick Start -The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization, It requires that the sequence length be a multiple of 8.): +The sample API usage is given below(If you enable the use of flash attention, please install `flash_attn`. In addition, xformers's `cutlass_op` provide a supplementary optimization): ```python -from colossalai.shardformer import ShardConfig, Shard +from colossalai.shardformer import ShardConfig, ShardFormer from transformers import BertForMaskedLM +import colossalai # launch colossalai -colossalai.launch_from_torch() +colossalai.launch_from_torch(config={}) # create model config = BertConfig.from_pretrained('bert-base-uncased') model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config) # create huggingface model as normal -shard_config = ShardConfig() +shard_config = ShardConfig(tensor_parallel_process_group=tp_group, + pipeline_stage_manager=stage_manager, + enable_tensor_parallelism=True, + enable_fused_normalization=True, + enable_flash_attention=True, + enable_jit_fused=True, + enable_sequence_parallelism=True, + enable_sequence_overlap=True) + shard_former = ShardFormer(shard_config=shard_config) -sharded_model = shard_former.optimize(model).to('cuda') +sharded_model, shared_params = shard_former.optimize(model).to('cuda') # do everything like normal ... ``` +shardformer configuration + +`tensor_parallel_process_group`: the process group of tensor parallelism, it's necessary when using tensor parallel. +`pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. +{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }} +`enable_tensor_parallelism`: using tensor parallel +`enable_fused_normalization`: using apex fused layernorm +`enable_flash_attention`: using flash attention +`enable_jit_fused`: using jit fused operators +`enable_sequence_parallelism`: using sequence parallelism, partition these non-tensor parallel regions along the sequence dimension. +`enable_sequence_overlap`: overlap the computation and communication in the sequence parallelism, it's used with `enable_sequence_parallelism`. 
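The `tp_group` and `stage_manager` objects passed to `ShardConfig` in the quick-start snippet are created by the caller. As a minimal sketch of a tensor-parallel-only configuration (no pipeline stage manager), assuming every rank launched via `torchrun` belongs to a single tensor-parallel group, the pieces fit together roughly like this:

```python
import torch.distributed as dist
from transformers import BertConfig, BertForMaskedLM

import colossalai
from colossalai.shardformer import ShardConfig, ShardFormer

# initialize the distributed environment (process launched with torchrun)
colossalai.launch_from_torch(config={})

# a single NCCL group spanning all ranks serves as the tensor-parallel group here;
# a multi-dimensional setup would instead carve out one sub-group per model replica
tp_group = dist.new_group(backend='nccl')

# create the huggingface model as normal
config = BertConfig.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)

shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_tensor_parallelism=True,
                           enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# optimize() returns the sharded model together with the shared parameters
sharded_model, shared_params = shard_former.optimize(model)
sharded_model = sharded_model.cuda()
```

Enabling pipeline parallelism additionally requires constructing a `PipelineStageManager` (see the autodoc reference above) and passing it as `pipeline_stage_manager`.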
+ ### Write your own policy diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py deleted file mode 100644 index de82305b2547..000000000000 --- a/colossalai/shardformer/examples/convergence_benchmark.py +++ /dev/null @@ -1,154 +0,0 @@ -import argparse -import math -from typing import Any, List, Union - -import evaluate -import torch -import torch.distributed as dist -from data import GLUEDataBuilder -from torch import nn -from torch.optim import Adam, AdamW, Optimizer -from torch.utils._pytree import tree_map -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup - -import colossalai -from colossalai.cluster import DistCoordinator -from colossalai.nn.optimizer import HybridAdam -from colossalai.shardformer import ShardConfig, ShardFormer - - -def to_device(x: Any, device: torch.device) -> Any: - - def _to(t: Any): - if isinstance(t, torch.Tensor): - return t.to(device) - return t - - return tree_map(_to, x) - - -def train(args): - colossalai.launch_from_torch(config={}, seed=42) - coordinator = DistCoordinator() - - # prepare for data and dataset - data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain, - task_name=args.task, - train_batch_size=args.batch_size, - eval_batch_size=args.batch_size) - train_dataloader = data_builder.train_dataloader() - test_dataloader = data_builder.test_dataloader() - - if args.model == "bert": - cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels) - model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg) - - model.to(torch.cuda.current_device()) - - # if multiple GPUs, shard the model - if dist.get_world_size() > 1: - shard_config = ShardConfig(enable_fused_normalization=args.fused_layernorm) - shard_former = ShardFormer(shard_config=shard_config) - model = shard_former.optimize(model) - - optim = Adam(model.parameters(), lr=args.lr) - num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps - max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) - lr_scheduler = get_linear_schedule_with_warmup( - optim, - num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), - num_training_steps=max_steps, - ) - fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size, - coordinator) - results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, - coordinator) - if coordinator.is_master(): - print(results) - if args.target_f1 is not None and 'f1' in results: - assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' - - -def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size, - coordinator): - step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs), - desc=f'steps', - disable=not coordinator.is_master()) - total_loss = 0 - for epoch in range(max_epochs): - model.train() - for batch_id, batch in enumerate(train_dataloader): - batch = to_device(batch, torch.cuda.current_device()) - outputs = model(**batch) - loss = outputs.loss - loss = loss / accumulation_steps - loss.backward() - total_loss += loss.item() - if (batch_id + 1) % accumulation_steps == 0: - optimizer.step() - scheduler.step() - optimizer.zero_grad() - step_bar.set_postfix({ - 'epoch': 
epoch, - 'loss': total_loss / batch_size, - 'lr': scheduler.get_last_lr()[0] - }) - total_loss = 0 - step_bar.update() - - -# evaluate -@torch.no_grad() -def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, - task_name: str, eval_splits: List[str], coordinator: DistCoordinator): - metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) - model.eval() - - def evaluate_subset(dataloader: DataLoader): - accum_loss = torch.zeros(1, device=torch.cuda.current_device()) - for batch in dataloader: - batch = to_device(batch, torch.cuda.current_device()) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss.add_(val_loss) - - if num_labels > 1: - preds = torch.argmax(logits, axis=1) - elif num_labels == 1: - preds = logits.squeeze() - - labels = batch["labels"] - metric.add_batch(predictions=preds, references=labels) - - results = metric.compute() - if coordinator.is_master(): - results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) - return results - - if isinstance(test_dataloader, DataLoader): - return evaluate_subset(test_dataloader) - else: - assert len(test_dataloader) == len(eval_splits) - final_results = {} - for split, sub_loader in zip(eval_splits, test_dataloader): - results = evaluate_subset(sub_loader) - final_results.update({f'{k}_{split}': v for k, v in results.items()}) - return final_results - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") - parser.add_argument('--model', type=str, default="bert") - parser.add_argument('--pretrain', type=str, default="bert-base-uncased") - parser.add_argument('--max_epochs', type=int, default=1) - parser.add_argument('--batch_size', type=int, default=4) - parser.add_argument('--lr', type=float, default=2.4e-5) - parser.add_argument('--fused_layernorm', type=bool, default=False) - parser.add_argument('--accumulation_steps', type=int, default=8) - parser.add_argument('--warmup_fraction', type=float, default=0.03) - parser.add_argument('--target_f1', type=float, default=None) - args = parser.parse_args() - train(args) diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh deleted file mode 100644 index 1c281abcda6d..000000000000 --- a/colossalai/shardformer/examples/convergence_benchmark.sh +++ /dev/null @@ -1,9 +0,0 @@ -torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ - --model "bert" \ - --pretrain "bert-base-uncased" \ - --max_epochs 1 \ - --batch_size 2 \ - --lr 2.4e-5 \ - --fused_layernorm False \ - --accumulation_steps 8 \ - --warmup_fraction 0.03 diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py deleted file mode 100644 index 6296d4be4eb0..000000000000 --- a/colossalai/shardformer/examples/data.py +++ /dev/null @@ -1,146 +0,0 @@ -import datasets -from torch.utils.data import DataLoader -from transformers import AutoTokenizer, PreTrainedTokenizer - -from colossalai.booster.plugin.dp_plugin_base import DPPluginBase - - -class GLUEDataBuilder: - - task_text_field_map = { - "cola": ["sentence"], - "sst2": ["sentence"], - "mrpc": ["sentence1", "sentence2"], - "qqp": ["question1", "question2"], - "stsb": ["sentence1", "sentence2"], - "mnli": ["premise", "hypothesis"], - "qnli": ["question", "sentence"], - "rte": ["sentence1", "sentence2"], - "wnli": ["sentence1", 
"sentence2"], - "ax": ["premise", "hypothesis"], - } - - glue_task_num_labels = { - "cola": 2, - "sst2": 2, - "mrpc": 2, - "qqp": 2, - "stsb": 1, - "mnli": 3, - "qnli": 2, - "rte": 2, - "wnli": 2, - "ax": 3, - } - - loader_columns = [ - "datasets_idx", - "input_ids", - "token_type_ids", - "attention_mask", - "start_positions", - "end_positions", - "labels", - ] - - def __init__( - self, - model_name_or_path: str, - plugin: DPPluginBase = None, - task_name: str = "mrpc", - max_seq_length: int = 128, - train_batch_size: int = 32, - eval_batch_size: int = 32, - **kwargs, - ): - super().__init__() - self.model_name_or_path = model_name_or_path - self.task_name = task_name - self.max_seq_length = max_seq_length - self.train_batch_size = train_batch_size - self.eval_batch_size = eval_batch_size - self.plugin = plugin - - self.text_fields = self.task_text_field_map[task_name] - self.num_labels = self.glue_task_num_labels[task_name] - self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) - self.setup() - - def setup(self): - self.dataset = datasets.load_dataset("glue", self.task_name) - - for split in self.dataset.keys(): - self.dataset[split] = self.dataset[split].map( - self.convert_to_features, - batched=True, - remove_columns=["label"], - ) - self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] - self.dataset[split].set_format(type="torch", columns=self.columns) - - self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] - - def prepare_data(self): - datasets.load_dataset("glue", self.task_name) - AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) - - def train_dataloader(self): - if self.plugin == None: - return self.native_prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) - return self.plugin.prepare_dataloader(self.dataset["train"], - batch_size=self.train_batch_size, - shuffle=True, - drop_last=True) - - def val_dataloader(self): - if self.plugin == None: - return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) - if len(self.eval_splits) == 1: - return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) - elif len(self.eval_splits) > 1: - return [ - self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) - for x in self.eval_splits - ] - - def test_dataloader(self): - if self.plugin == None: - return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.train_batch_size) - if len(self.eval_splits) == 1: - return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) - elif len(self.eval_splits) > 1: - return [ - self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) - for x in self.eval_splits - ] - - def convert_to_features(self, example_batch): - - # Either encode single sentence or sentence pairs - if len(self.text_fields) > 1: - texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) - else: - texts_or_text_pairs = example_batch[self.text_fields[0]] - - # Tokenize the text/text pairs - features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, - max_length=self.max_seq_length, - padding='max_length', - truncation=True) - - # Rename label to labels to make it easier to pass to model forward - features["labels"] = example_batch["label"] - - return features - - def 
native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): - - return DataLoader(dataset, - batch_size=batch_size, - sampler=None, - shuffle=shuffle, - drop_last=drop_last, - pin_memory=pin_memory) diff --git a/colossalai/shardformer/examples/performance_benchmark.py b/colossalai/shardformer/examples/performance_benchmark.py index 9c7b76bcf0a6..2f186709d946 100644 --- a/colossalai/shardformer/examples/performance_benchmark.py +++ b/colossalai/shardformer/examples/performance_benchmark.py @@ -29,7 +29,8 @@ def data_gen_for_sequence_classification(batch_size, seq_length): intermediate_size=256, num_attention_heads=4, max_position_embeddings=128, - num_labels=16) + num_labels=16, + pad_token_id=2) BATCH, N_HEADS, N_CTX, D_HEAD = 4, 8, 4096, 64 model_func = lambda: transformers.LlamaForSequenceClassification(MODEL_CONFIG) @@ -73,7 +74,8 @@ def bench_shardformer(BATCH, N_CTX, provider, model_func, dtype=torch.float32, d if provider == "shard_model": shard_config = ShardConfig(enable_fused_normalization=True, enable_tensor_parallelism=True) shard_former = ShardFormer(shard_config=shard_config) - sharded_model = shard_former.optimize(model).cuda() + sharded_model, _ = shard_former.optimize(model) + sharded_model = sharded_model.cuda() fn = lambda: train(sharded_model, data) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms From a8a096e38f96f96a73d383129672db3eeed52f55 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 12 Sep 2023 11:29:11 +0800 Subject: [PATCH 2/5] [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 55 ++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 4683e35f7c9f..0ba014851120 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -65,7 +65,7 @@ shardformer configuration `tensor_parallel_process_group`: the process group of tensor parallelism, it's necessary when using tensor parallel. `pipeline_stage_manager`: If using pipeline parallelism, it's necessary to specify a pipeline stage manager for inter-process communication in pipeline parallelism. 
{{ autodoc:colossalai.pipeline.stage_manager.PipelineStageManager }} -`enable_tensor_parallelism`: using tensor parallel +`enable_tensor_parallelism`: using tensor parallel, partition the model along the columns or along the rows `enable_fused_normalization`: using apex fused layernorm `enable_flash_attention`: using flash attention `enable_jit_fused`: using jit fused operators @@ -111,12 +111,13 @@ We will follow this roadmap to develop Shardformer: - [x] GPT2 - [x] OPT - [x] BLOOM - - [ ] GLM + - [x] GLM - [ ] RoBERTa - [ ] ALBERT - [ ] ERNIE - [ ] GPT Neo - [ ] GPT-J + - [ ] Qwen - [ ] CV - [x] ViT - [ ] BEiT @@ -135,12 +136,23 @@ We will follow this roadmap to develop Shardformer: - [x] GPT2 - [x] OPT - [x] BLOOM - - [ ] GLM + - [x] GLM - [ ] RoBERTa - [ ] ALBERT - [ ] ERNIE - [ ] GPT Neo - [ ] GPT-J + - [ ] Qwen + - [ ] CV + - [x] ViT + - [ ] BEiT + - [ ] SwinTransformer + - [ ] SwinTransformer V2 + - [ ] Audio + - [x] Whisper + - [ ] Multi-modal + - [x] SAM + - [x] BLIP-2 ## 💡 API Design @@ -307,41 +319,36 @@ class ShardFormer: Example: + org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') + shard_config = ShardConfig() shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - model = shard_former.optimize(model, policy=policy) - dataloader = shard_former.shard_dataset(dataset) + model, shared_params = shard_former.optimize(org_model) """ def __init__(self, shard_config: ShardConfig): """ Do two things: - 1. Create a colossalai.cluster.process_group_manager to manage process groups for dp, tp and pp + 1. Create a distribute coordinator 2. serve as a store for shard config """ self.shard_config = shard_config - self.pg_manager = None + self.coordinator = DistCoordinator() - def init_distributed(self) -> colossalai.cluster.ProcessGroupManager: - """ - Initialize the distributed process group according to the - """ - pg_manager = ... - self.pg_manager = pg_manager - return pg_manager + def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]: + r""" + This method will optimize the model based on the given policy. - def shard_model(self, model: torch.nn.Module,policy: Policy) -> torch.nn.Module: - """ - Shard model for TP and PP - """ - ... + Args: + model (`torch.nn.Model`): the origin huggingface model + shard_config (`ShardConfig`): the config for distribute information + policy (`Policy`): the custom policy for sharding - def shard_dataset(self, dataset: Dataset) -> Dataloader: + Returns: the sharded model and the shared parameters """ - Shard dataset for DP - """ - ... 
+ sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) + shared_params = sharder.shard() + return model, shared_params ``` ## ⌨️ Development Notes From e96f8c066dce140fadfada7b0048d6158dcaa49f Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 12 Sep 2023 14:25:37 +0800 Subject: [PATCH 3/5] [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 91 +++++----- .../examples/convergence_benchmark.py | 157 ++++++++++++++++++ .../examples/convergence_benchmark.sh | 9 + colossalai/shardformer/examples/data.py | 146 ++++++++++++++++ 4 files changed, 350 insertions(+), 53 deletions(-) create mode 100644 colossalai/shardformer/examples/convergence_benchmark.py create mode 100644 colossalai/shardformer/examples/convergence_benchmark.sh create mode 100644 colossalai/shardformer/examples/data.py diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index 0ba014851120..e49c94b45595 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -103,56 +103,30 @@ We will follow this roadmap to develop Shardformer: - [x] API Implementation - [x] Unit Testing - [ ] Policy Implementation - - [ ] Hugging Face - - [ ] NLP - - [x] BERT - - [x] T5 - - [x] LlaMa - - [x] GPT2 - - [x] OPT - - [x] BLOOM - - [x] GLM - - [ ] RoBERTa - - [ ] ALBERT - - [ ] ERNIE - - [ ] GPT Neo - - [ ] GPT-J - - [ ] Qwen - - [ ] CV - - [x] ViT - - [ ] BEiT - - [ ] SwinTransformer - - [ ] SwinTransformer V2 - - [ ] Audio - - [x] Whisper - - [ ] Multi-modal - - [x] SAM - - [x] BLIP-2 -- [ ] Flash Attention Support - - [ ] NLP - - [x] BERT - - [x] T5 - - [x] LlaMa - - [x] GPT2 - - [x] OPT - - [x] BLOOM - - [x] GLM - - [ ] RoBERTa - - [ ] ALBERT - - [ ] ERNIE - - [ ] GPT Neo - - [ ] GPT-J - - [ ] Qwen - - [ ] CV - - [x] ViT - - [ ] BEiT - - [ ] SwinTransformer - - [ ] SwinTransformer V2 - - [ ] Audio - - [x] Whisper - - [ ] Multi-modal - - [x] SAM - - [x] BLIP-2 + +| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap | +| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: | +| bert | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| t5 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| llama V1/V2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| gpt2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| opt | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| bloom | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| chatglm2 | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [x] | +| vit | [x] | [x] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| whisper | [x] | [x] | [x] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| sam | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| blip2 | [x] | [ ] | [ ] | [x] | [x] | [x] | [x] | [ ] | [ ] | +| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | +| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ 
] | [ ] | + ## 💡 API Design @@ -457,13 +431,24 @@ As shown in the figures above, when the sequence length is around 1000 or greate ### Convergence -To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. +To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. The example that utilizes Shardformer simultaneously with Pipeline Parallelism and Data Parallelism (Zero1). We then compared the accuracy, loss, and F1 score of the training results. + +the configurations are as follows: +```python +batch_size = 2 +epoch = 3 +lr = 2.4e-5 +accumulation_steps = 8 +warmup_fraction = 0.03 +``` + | accuracy | f1 | loss | GPU number | model sharded | | :------: | :-----: | :-----: | :--------: | :---------: | -| 0.84589 | 0.88613 | 0.43414 | 4 | True | -| 0.83594 | 0.88064 | 0.43298 | 1 | False | +| 0.82594 | 0.87441 | 0.09913 | 4 | True | +| 0.81884 | 0.87299 | 0.10120 | 2 | True | +| 0.81855 | 0.87124 | 0.10357 | 1 | False | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py new file mode 100644 index 000000000000..7d5830a3e890 --- /dev/null +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -0,0 +1,157 @@ +import argparse +import math +from typing import Any, List, Union + +import evaluate +import torch +import torch.distributed as dist +from data import GLUEDataBuilder +from torch import nn +from torch.optim import Adam, AdamW, Optimizer +from torch.utils._pytree import tree_map +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup + +import colossalai +from colossalai.cluster import DistCoordinator +from colossalai.nn.optimizer import HybridAdam +from colossalai.shardformer import ShardConfig, ShardFormer + + +def to_device(x: Any, device: torch.device) -> Any: + + def _to(t: Any): + if isinstance(t, torch.Tensor): + return t.to(device) + return t + + return tree_map(_to, x) + + +def train(args): + colossalai.launch_from_torch(config={}, seed=42) + coordinator = DistCoordinator() + + # prepare for data and dataset + data_builder = GLUEDataBuilder(model_name_or_path=args.pretrain, + task_name=args.task, + train_batch_size=args.batch_size, + eval_batch_size=args.batch_size) + train_dataloader = data_builder.train_dataloader() + test_dataloader = data_builder.test_dataloader() + + if args.model == "bert": + cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels) + model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg) + + model.to(torch.cuda.current_device()) + + # if multiple GPUs, shard the model + if dist.get_world_size() > 1: + tp_group = dist.new_group(backend='nccl') + shard_config = ShardConfig(tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=True, + enable_all_optimization=False) + shard_former = 
ShardFormer(shard_config=shard_config) + model, _ = shard_former.optimize(model) + + optim = Adam(model.parameters(), lr=args.lr) + num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps + max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch) + lr_scheduler = get_linear_schedule_with_warmup( + optim, + num_warmup_steps=math.ceil(max_steps * args.warmup_fraction), + num_training_steps=max_steps, + ) + fit(model, optim, lr_scheduler, train_dataloader, args.max_epochs, args.accumulation_steps, args.batch_size, + coordinator) + results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, + coordinator) + if coordinator.is_master(): + print(results) + if args.target_f1 is not None and 'f1' in results: + assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}' + + +def fit(model: nn.Module, optimizer: Optimizer, scheduler, train_dataloader, max_epochs, accumulation_steps, batch_size, + coordinator): + step_bar = tqdm(range(len(train_dataloader) // accumulation_steps * max_epochs), + desc=f'steps', + disable=not coordinator.is_master()) + total_loss = 0 + for epoch in range(max_epochs): + model.train() + for batch_id, batch in enumerate(train_dataloader): + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = outputs.loss + loss = loss / accumulation_steps + loss.backward() + total_loss += loss.item() + if (batch_id + 1) % accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + step_bar.set_postfix({ + 'epoch': epoch, + 'loss': total_loss / batch_size, + 'lr': scheduler.get_last_lr()[0] + }) + total_loss = 0 + step_bar.update() + + +# evaluate +@torch.no_grad() +def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, + task_name: str, eval_splits: List[str], coordinator: DistCoordinator): + metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) + model.eval() + + def evaluate_subset(dataloader: DataLoader): + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + for batch in dataloader: + batch = to_device(batch, torch.cuda.current_device()) + outputs = model(**batch) + val_loss, logits = outputs[:2] + accum_loss.add_(val_loss) + + if num_labels > 1: + preds = torch.argmax(logits, axis=1) + elif num_labels == 1: + preds = logits.squeeze() + + labels = batch["labels"] + metric.add_batch(predictions=preds, references=labels) + + results = metric.compute() + if coordinator.is_master(): + results['loss'] = accum_loss.item() / (len(dataloader) * dataloader.batch_size) + return results + + if isinstance(test_dataloader, DataLoader): + return evaluate_subset(test_dataloader) + else: + assert len(test_dataloader) == len(eval_splits) + final_results = {} + for split, sub_loader in zip(eval_splits, test_dataloader): + results = evaluate_subset(sub_loader) + final_results.update({f'{k}_{split}': v for k, v in results.items()}) + return final_results + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run") + parser.add_argument('--model', type=str, default="bert") + parser.add_argument('--pretrain', type=str, default="bert-base-uncased") + parser.add_argument('--max_epochs', type=int, default=1) + parser.add_argument('--batch_size', type=int, default=4) + parser.add_argument('--lr', type=float, 
default=2.4e-5) + parser.add_argument('--fused_layernorm', type=bool, default=False) + parser.add_argument('--accumulation_steps', type=int, default=8) + parser.add_argument('--warmup_fraction', type=float, default=0.03) + parser.add_argument('--target_f1', type=float, default=None) + args = parser.parse_args() + train(args) diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh new file mode 100644 index 000000000000..22f13a7cf827 --- /dev/null +++ b/colossalai/shardformer/examples/convergence_benchmark.sh @@ -0,0 +1,9 @@ +torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ + --model "bert" \ + --pretrain "bert-base-uncased" \ + --max_epochs 3 \ + --batch_size 2 \ + --lr 2.4e-5 \ + --fused_layernorm False \ + --accumulation_steps 8 \ + --warmup_fraction 0.03 diff --git a/colossalai/shardformer/examples/data.py b/colossalai/shardformer/examples/data.py new file mode 100644 index 000000000000..6296d4be4eb0 --- /dev/null +++ b/colossalai/shardformer/examples/data.py @@ -0,0 +1,146 @@ +import datasets +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, PreTrainedTokenizer + +from colossalai.booster.plugin.dp_plugin_base import DPPluginBase + + +class GLUEDataBuilder: + + task_text_field_map = { + "cola": ["sentence"], + "sst2": ["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], + } + + glue_task_num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, + } + + loader_columns = [ + "datasets_idx", + "input_ids", + "token_type_ids", + "attention_mask", + "start_positions", + "end_positions", + "labels", + ] + + def __init__( + self, + model_name_or_path: str, + plugin: DPPluginBase = None, + task_name: str = "mrpc", + max_seq_length: int = 128, + train_batch_size: int = 32, + eval_batch_size: int = 32, + **kwargs, + ): + super().__init__() + self.model_name_or_path = model_name_or_path + self.task_name = task_name + self.max_seq_length = max_seq_length + self.train_batch_size = train_batch_size + self.eval_batch_size = eval_batch_size + self.plugin = plugin + + self.text_fields = self.task_text_field_map[task_name] + self.num_labels = self.glue_task_num_labels[task_name] + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + self.setup() + + def setup(self): + self.dataset = datasets.load_dataset("glue", self.task_name) + + for split in self.dataset.keys(): + self.dataset[split] = self.dataset[split].map( + self.convert_to_features, + batched=True, + remove_columns=["label"], + ) + self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns] + self.dataset[split].set_format(type="torch", columns=self.columns) + + self.eval_splits = [x for x in self.dataset.keys() if "validation" in x] + + def prepare_data(self): + datasets.load_dataset("glue", self.task_name) + AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True) + + def train_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + return 
self.plugin.prepare_dataloader(self.dataset["train"], + batch_size=self.train_batch_size, + shuffle=True, + drop_last=True) + + def val_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def test_dataloader(self): + if self.plugin == None: + return self.native_prepare_dataloader(self.dataset["test"], batch_size=self.train_batch_size) + if len(self.eval_splits) == 1: + return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size) + elif len(self.eval_splits) > 1: + return [ + self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size) + for x in self.eval_splits + ] + + def convert_to_features(self, example_batch): + + # Either encode single sentence or sentence pairs + if len(self.text_fields) > 1: + texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]])) + else: + texts_or_text_pairs = example_batch[self.text_fields[0]] + + # Tokenize the text/text pairs + features = self.tokenizer.batch_encode_plus(texts_or_text_pairs, + max_length=self.max_seq_length, + padding='max_length', + truncation=True) + + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] + + return features + + def native_prepare_dataloader(self, dataset, batch_size, shuffle=False, drop_last=False, pin_memory=False): + + return DataLoader(dataset, + batch_size=batch_size, + sampler=None, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory) From 191f6573959d02d108ae01520965d8d63360c7b0 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 12 Sep 2023 14:53:56 +0800 Subject: [PATCH 4/5] [shardformer] update shardformer readme --- colossalai/shardformer/README.md | 6 +++--- colossalai/shardformer/examples/convergence_benchmark.py | 2 +- colossalai/shardformer/examples/convergence_benchmark.sh | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index e49c94b45595..559f9a56f61e 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -446,9 +446,9 @@ warmup_fraction = 0.03 | accuracy | f1 | loss | GPU number | model sharded | | :------: | :-----: | :-----: | :--------: | :---------: | -| 0.82594 | 0.87441 | 0.09913 | 4 | True | -| 0.81884 | 0.87299 | 0.10120 | 2 | True | -| 0.81855 | 0.87124 | 0.10357 | 1 | False | +| 0.82971 | 0.87713 | 0.23194 | 4 | True | +| 0.83797 | 0.88006 | 0.22683 | 2 | True | +| 0.84521 | 0.88700 | 0.21822 | 1 | False | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. 
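As a quick sanity check on these numbers, the relative F1 gap between the sharded runs and the single-GPU baseline can be computed directly from the table above. The sketch below is illustrative only, with a hypothetical threshold playing the same role as the `--target_f1` flag in `convergence_benchmark.py`; it shows the gap staying around one percent:

```python
# F1 scores taken from the convergence table above
baseline_f1 = 0.88700                   # 1 GPU, model not sharded
sharded_f1 = {4: 0.87713, 2: 0.88006}   # GPU count -> F1 with shardformer

for gpus, f1 in sharded_f1.items():
    rel_gap = (baseline_f1 - f1) / baseline_f1
    print(f"{gpus} GPUs (sharded): f1={f1:.5f}, gap vs. baseline {rel_gap:.2%}")
    # gaps are roughly 1.1% (4 GPUs) and 0.8% (2 GPUs); a CI-style check analogous
    # to --target_f1 could simply assert the sharded score stays above a floor
    assert f1 >= 0.85  # example threshold, not a value from the benchmark itself
```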
diff --git a/colossalai/shardformer/examples/convergence_benchmark.py b/colossalai/shardformer/examples/convergence_benchmark.py index 7d5830a3e890..81be2017855c 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.py +++ b/colossalai/shardformer/examples/convergence_benchmark.py @@ -52,7 +52,7 @@ def train(args): tp_group = dist.new_group(backend='nccl') shard_config = ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True, - enable_all_optimization=False) + enable_all_optimization=True) shard_former = ShardFormer(shard_config=shard_config) model, _ = shard_former.optimize(model) diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh index 22f13a7cf827..6751326d7b00 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.sh +++ b/colossalai/shardformer/examples/convergence_benchmark.sh @@ -1,4 +1,4 @@ -torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ +torchrun --standalone --nproc_per_node=1 convergence_benchmark.py \ --model "bert" \ --pretrain "bert-base-uncased" \ --max_epochs 3 \ From 8d499aeb5ea6b7aae551e47a3be28c4f7cad6a34 Mon Sep 17 00:00:00 2001 From: Mingyan Jiang <1829166702@qq.com> Date: Tue, 12 Sep 2023 14:54:15 +0800 Subject: [PATCH 5/5] [shardformer] update shardformer readme --- colossalai/shardformer/examples/convergence_benchmark.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/examples/convergence_benchmark.sh b/colossalai/shardformer/examples/convergence_benchmark.sh index 6751326d7b00..22f13a7cf827 100644 --- a/colossalai/shardformer/examples/convergence_benchmark.sh +++ b/colossalai/shardformer/examples/convergence_benchmark.sh @@ -1,4 +1,4 @@ -torchrun --standalone --nproc_per_node=1 convergence_benchmark.py \ +torchrun --standalone --nproc_per_node=4 convergence_benchmark.py \ --model "bert" \ --pretrain "bert-base-uncased" \ --max_epochs 3 \