From 091f389a412611ae8166d328b93735d0608ca52b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 10:53:08 +0800 Subject: [PATCH 1/8] update vit example for hybrid plugin --- colossalai/shardformer/modeling/gpt2.py | 1 + colossalai/shardformer/modeling/vit.py | 21 ++-- examples/images/vit/README.md | 4 +- examples/images/vit/args.py | 158 +++++++++--------------- examples/images/vit/data.py | 22 ++-- examples/images/vit/run_benchmark.sh | 11 +- examples/images/vit/run_demo.sh | 10 +- examples/images/vit/test_ci.sh | 7 +- examples/images/vit/vit_benchmark.py | 50 ++++++-- examples/images/vit/vit_train_demo.py | 150 +++++++++++++++------- 10 files changed, 242 insertions(+), 192 deletions(-) diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py index 8ed367b25349..9eb58df4d723 100644 --- a/colossalai/shardformer/modeling/gpt2.py +++ b/colossalai/shardformer/modeling/gpt2.py @@ -884,6 +884,7 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: + logger = logging.get_logger(__name__) logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py index 9fc0b7488803..2ce52163ac32 100644 --- a/colossalai/shardformer/modeling/vit.py +++ b/colossalai/shardformer/modeling/vit.py @@ -1,9 +1,9 @@ -import logging import math from typing import Dict, List, Optional, Set, Tuple, Union import torch from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder +from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -72,18 +72,17 @@ def pp_forward( bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = (output_hidden_states - if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if output_attentions is not None: - logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.') - output_attentions = None - if output_hidden_states is not None: - logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.') - output_hidden_states = None + logger = logging.get_logger(__name__) + + # Preprocess passed in arguments + if output_attentions: + logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + output_attentions = False + if output_hidden_states: + logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + output_hidden_states = False # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md index 7c4147b76457..33c6454ad92c 100644 --- a/examples/images/vit/README.md +++ b/examples/images/vit/README.md @@ -3,7 +3,7 @@ Vision Transformer is a class of Transformer model tailored for computer vision tasks. 
It was first proposed in paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time. In our example, we are using pretrained weights of ViT loaded from HuggingFace. -We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. +We adapt the ViT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin (DDP), LowLevelZeroPlugin (Zero1/Zero2), GeminiPlugin (Gemini) and HybridParallelPlugin (any combination of tensor/pipeline/data parallel). ## Run Demo @@ -25,4 +25,4 @@ You can run benchmark for ViT model by running the following script: ```bash bash run_benchmark.sh ``` -The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing. \ No newline at end of file +The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your own set of hyperparameters for testing. diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index e4a873a9eb52..e60aed5ffc39 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -1,124 +1,80 @@ from colossalai import get_default_parser + def parse_demo_args(): parser = get_default_parser() - parser.add_argument( - "--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to pretrained model or model identifier from huggingface.co/models." - ) - parser.add_argument( - "--output_path", - type=str, - default="./output_model.bin", - help="The path of your saved model after finetuning." - ) + parser.add_argument("--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to pretrained model or model identifier from huggingface.co/models.") + parser.add_argument("--output_path", + type=str, + default="./output_model", + help="The path of your saved model after finetuning.") parser.add_argument( "--plugin", type=str, default="gemini", - help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." - ) - parser.add_argument( - "--num_epoch", - type=int, - default=3, - help="Number of epochs." - ) - parser.add_argument( - "--batch_size", - type=int, - default=32, - help="Batch size (per dp group) for the training dataloader." - ) - parser.add_argument( - "--learning_rate", - type=float, - default=3e-4, - help="Initial learning rate (after the potential warmup period) to use." - ) - parser.add_argument( - "--warmup_ratio", - type=float, - default=0.3, - help="Ratio of warmup steps against total training steps." - ) - parser.add_argument( - "--weight_decay", - type=float, - default=0.1, - help="Weight decay to use." - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="A seed for reproducible training." - ) + help= + "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." 
+ ) + parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs.") + parser.add_argument("--batch_size", + type=int, + default=32, + help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--tp_size", + type=int, + default=2, + help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.") + parser.add_argument("--pp_size", + type=int, + default=2, + help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.") + parser.add_argument("--learning_rate", + type=float, + default=3e-4, + help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--warmup_ratio", + type=float, + default=0.3, + help="Ratio of warmup steps against total training steps.") + parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") args = parser.parse_args() return args + def parse_benchmark_args(): parser = get_default_parser() - parser.add_argument( - "--model_name_or_path", - type=str, - default="google/vit-base-patch16-224", - help="Path to a pretrained model or model identifier from huggingface.co/models." - ) + parser.add_argument("--model_name_or_path", + type=str, + default="google/vit-base-patch16-224", + help="Path to a pretrained model or model identifier from huggingface.co/models.") parser.add_argument( "--plugin", type=str, default="gemini", - help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'." - ) - parser.add_argument( - "--batch_size", - type=int, - default=8, - help="Batch size (per dp group) for the training dataloader." - ) - parser.add_argument( - "--num_labels", - type=int, - default=10, - help="Number of labels for classification." - ) - parser.add_argument( - "--learning_rate", - type=float, - default=5e-5, - help="Initial learning rate (after the potential warmup period) to use." - ) - parser.add_argument( - "--weight_decay", - type=float, - default=0.0, - help="Weight decay to use." - ) - parser.add_argument( - "--max_train_steps", - type=int, - default=20, - help="Total number of training steps to perform." - ) - parser.add_argument( - "--seed", - type=int, - default=42, - help="A seed for reproducible training." - ) - parser.add_argument( - "--mem_cap", - type=int, - default=0, - help="Limit on the usage of space for each GPU (in GB)." - ) + help= + "Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero', 'hybrid_parallel'." 
+ ) + parser.add_argument("--batch_size", + type=int, + default=8, + help="Batch size (per dp group) for the training dataloader.") + parser.add_argument("--num_labels", type=int, default=10, help="Number of labels for classification.") + parser.add_argument("--learning_rate", + type=float, + default=5e-5, + help="Initial learning rate (after the potential warmup period) to use.") + parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).") args = parser.parse_args() - return args \ No newline at end of file + return args diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py index 00fde707b173..321d1c4c0f4f 100644 --- a/examples/images/vit/data.py +++ b/examples/images/vit/data.py @@ -1,32 +1,38 @@ import torch -from torch.utils.data import Dataset from datasets import load_dataset +from torch.utils.data import Dataset + class BeansDataset(Dataset): - - def __init__(self, image_processor, split='train'): + + def __init__(self, image_processor, tp_size, split='train'): super().__init__() self.image_processor = image_processor self.ds = load_dataset('beans')[split] self.label_names = self.ds.features['labels'].names + while len(self.label_names) % tp_size != 0: + # ensure that the number of labels is multiple of tp_size + self.label_names.append(f"pad_label_{len(self.label_names)}") self.num_labels = len(self.label_names) self.inputs = [] for example in self.ds: self.inputs.append(self.process_example(example)) - + def __len__(self): return len(self.inputs) def __getitem__(self, idx): return self.inputs[idx] - + def process_example(self, example): input = self.image_processor(example['image'], return_tensors='pt') input['labels'] = example['labels'] return input - + def beans_collator(batch): - return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), - 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)} + return { + 'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0), + 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64) + } diff --git a/examples/images/vit/run_benchmark.sh b/examples/images/vit/run_benchmark.sh index 2487bf81ee2b..41eab9c5a188 100644 --- a/examples/images/vit/run_benchmark.sh +++ b/examples/images/vit/run_benchmark.sh @@ -5,23 +5,20 @@ export BS=8 export MEMCAP=0 export GPUNUM=1 -for BS in 8 32 128 +for BS in 8 32 do -for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" -do -for GPUNUM in 1 4 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel" do MODEL_PATH="google/vit-base-patch16-224" torchrun \ --standalone \ - --nproc_per_node ${GPUNUM} \ + --nproc_per_node 4 \ vit_benchmark.py \ --model_name_or_path ${MODEL_PATH} \ --mem_cap ${MEMCAP} \ --plugin ${PLUGIN} \ --batch_size ${BS} - -done + done done diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh index 2d140dd6e423..dd8724048b09 100644 --- a/examples/images/vit/run_demo.sh +++ b/examples/images/vit/run_demo.sh @@ -5,12 +5,16 @@ pip install -r requirements.txt MODEL="google/vit-base-patch16-224" # path for saving model -OUTPUT_PATH="./output_model.bin" 
+OUTPUT_PATH="./output_model" # plugin(training strategy) -# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini" +# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel" PLUGIN="gemini" +# configuration of hybrid parallel, only used when setting PLUGIN to "hybrid_parallel" +TP_SIZE=2 +PP_SIZE=2 + # number of gpus to use GPUNUM=4 @@ -38,6 +42,8 @@ torchrun \ --output_path ${OUTPUT_PATH} \ --plugin ${PLUGIN} \ --batch_size ${BS} \ + --tp_size ${TP_SIZE} \ + --pp_size ${PP_SIZE} \ --num_epoch ${EPOCH} \ --learning_rate ${LR} \ --weight_decay ${WEIGHT_DECAY} \ diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh index 8606015c0397..570147606636 100644 --- a/examples/images/vit/test_ci.sh +++ b/examples/images/vit/test_ci.sh @@ -2,18 +2,15 @@ set -xe pip install -r requirements.txt BS=8 -for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" -do -for GPUNUM in 1 4 +for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini" "hybrid_parallel" do torchrun \ --standalone \ - --nproc_per_node ${GPUNUM} \ + --nproc_per_node 4 \ vit_benchmark.py \ --model_name_or_path "google/vit-base-patch16-224" \ --plugin ${PLUGIN} \ --batch_size ${BS} done -done diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index c2293b96ad73..98864a2821f0 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -8,7 +8,7 @@ import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam @@ -24,7 +24,7 @@ def format_num(num: int, bytes=False): num /= factor -def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): +def get_data_batch(batch_size, num_labels, num_channels=3, height=224, width=224): pixel_values = torch.randn(batch_size, num_channels, height, @@ -32,7 +32,7 @@ def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): device=torch.cuda.current_device(), dtype=torch.float) labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64) - return pixel_values, labels + return dict(pixel_values=pixel_values, labels=labels) def colo_memory_cap(size_in_GB): @@ -69,9 +69,6 @@ def main(): model = ViTForImageClassification(config) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == 'torch_ddp_fp16': @@ -82,14 +79,29 @@ def main(): plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + plugin = HybridParallelPlugin(tp_size=2, + pp_size=2, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision='fp16', + initial_scale=1) logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size)) + # Set criterion (loss function) + criterion = lambda x: x.loss + + def _criterion(outputs, inputs): + 
loss = criterion(outputs) + return loss + # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, _, _ = booster.boost(model, optimizer) + model, optimizer, _criterion, _, _ = booster.boost(model, optimizer, criterion=_criterion) # Start training. logger.info(f"Start testing", ranks=[0]) @@ -100,12 +112,24 @@ def main(): start_time = time.time() for _ in range(args.max_train_steps): - - pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224) optimizer.zero_grad() - outputs = model(pixel_values=pixel_values, labels=labels) - loss = outputs['loss'] - booster.backward(loss, optimizer) + batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) + + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + # run pipeline forward backward + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + _criterion, + optimizer, + return_loss=True, + return_outputs=True) + else: + outputs = model(**batch) + loss = _criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + optimizer.step() torch.cuda.synchronize() @@ -124,6 +148,8 @@ def main(): f"maximum memory usage per gpu: {max_mem}.", ranks=[0]) + torch.cuda.empty_cache() + if __name__ == "__main__": main() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 4dc0f67f40bf..52decad97f16 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -1,14 +1,20 @@ +from typing import Any, Callable, Iterator + import torch import torch.distributed as dist +import torch.nn as nn +import tqdm import transformers from args import parse_demo_args from data import BeansDataset, beans_collator -from tqdm import tqdm +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin +from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.cluster import DistCoordinator from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR @@ -20,51 +26,91 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): +def run_forward_backward(model: nn.Module, + optimizer: Optimizer, + criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, + booster: Booster, + forward_only: bool = False): + + optimizer.zero_grad() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # run pipeline forward backward when enabling pp in hybrid parallel plugin + output_dict = booster.execute_pipeline(data_iter, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + loss, outputs = output_dict['loss'], output_dict['outputs'] + else: + batch = next(data_iter) + batch = move_to_cuda(batch, torch.cuda.current_device()) + outputs = model(**batch) + loss = criterion(outputs, None) + if not forward_only: + booster.backward(loss, optimizer) - torch.cuda.synchronize() - model.train() + return loss, outputs - with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', 
disable=not coordinator.is_master()) as pbar: - for batch in pbar: +def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], + lr_scheduler: LRScheduler, dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): - # Foward - optimizer.zero_grad() - batch = move_to_cuda(batch, torch.cuda.current_device()) - outputs = model(**batch) - loss = outputs['loss'] + torch.cuda.synchronize() - # Backward - booster.backward(loss, optimizer) - optimizer.step() - lr_scheduler.step() + num_steps = len(dataloader) + enable_pbar = coordinator.is_master() + if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: + # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar + tp_rank = dist.get_rank(booster.plugin.tp_group) + dp_rank = dist.get_rank(booster.plugin.dp_group) + enable_pbar = tp_rank == 0 and dp_rank == 0 \ + and booster.plugin.stage_manager.is_last_stage() - # Print batch loss - pbar.set_postfix({'loss': loss.item()}) + progress_bar = tqdm.tqdm(total=num_steps, desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) + model.train() + + for _ in range(num_steps): + loss, _ = run_forward_backward(model, optimizer, criterion, iter(dataloader), booster, forward_only=False) + optimizer.step() + lr_scheduler.step() + + # Print batch loss + if enable_pbar: + progress_bar.set_postfix({'loss': loss.item()}) + progress_bar.update(1) @torch.no_grad() -def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): +def evaluate_model(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], + eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): + torch.cuda.synchronize() model.eval() - accum_loss = torch.zeros(1, device=get_current_device()) - total_num = torch.zeros(1, device=get_current_device()) - accum_correct = torch.zeros(1, device=get_current_device()) + accum_loss = torch.zeros(1, device=torch.cuda.current_device()) + total_num = torch.zeros(1, device=torch.cuda.current_device()) + accum_correct = torch.zeros(1, device=torch.cuda.current_device()) for batch in eval_dataloader: batch = move_to_cuda(batch, torch.cuda.current_device()) - outputs = model(**batch) - val_loss, logits = outputs[:2] - accum_loss += (val_loss / len(eval_dataloader)) - if num_labels > 1: + loss, outputs = run_forward_backward(model, optimizer, criterion, iter([batch]), booster, forward_only=True) + + to_accum = True + if isinstance(booster.plugin, HybridParallelPlugin): + # when using hybrid parallel, loss is only collected from last stage of pipeline with tp_rank == 0 + to_accum = to_accum and (dist.get_rank(booster.plugin.tp_group) == 0) + if booster.plugin.pp_size > 1: + to_accum = to_accum and booster.plugin.stage_manager.is_last_stage() + + if to_accum: + accum_loss += (loss / len(eval_dataloader)) + logits = outputs["logits"] preds = torch.argmax(logits, dim=1) - elif num_labels == 1: - preds = logits.squeeze() - labels = batch["labels"] - total_num += batch["labels"].shape[0] - accum_correct += (torch.sum(preds == labels)) + labels = batch["labels"] + total_num += batch["labels"].shape[0] + accum_correct += (torch.sum(preds == labels)) dist.all_reduce(accum_loss) dist.all_reduce(total_num) @@ -96,12 +142,13 @@ def main(): # Prepare Dataset image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) - train_dataset = BeansDataset(image_processor, split='train') - eval_dataset = 
BeansDataset(image_processor, split='validation') + train_dataset = BeansDataset(image_processor, args.tp_size, split='train') + eval_dataset = BeansDataset(image_processor, args.tp_size, split='validation') + num_labels = train_dataset.num_labels # Load pretrained ViT model config = ViTConfig.from_pretrained(args.model_name_or_path) - config.num_labels = train_dataset.num_labels + config.num_labels = num_labels config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} model = ViTForImageClassification.from_pretrained(args.model_name_or_path, @@ -109,9 +156,6 @@ def main(): ignore_mismatched_sizes=True) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) - # Enable gradient checkpointing - model.gradient_checkpointing_enable() - # Set plugin booster_kwargs = {} if args.plugin == 'torch_ddp_fp16': @@ -122,6 +166,16 @@ def main(): plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5) elif args.plugin == 'low_level_zero': plugin = LowLevelZeroPlugin(initial_scale=2**5) + elif args.plugin == 'hybrid_parallel': + plugin = HybridParallelPlugin(tp_size=args.tp_size, + pp_size=args.pp_size, + num_microbatches=None, + microbatch_size=1, + enable_all_optimization=True, + precision='fp16', + initial_scale=1) + else: + raise ValueError(f"Plugin with name {args.plugin} is not supported!") logger.info(f"Set plugin as {args.plugin}", ranks=[0]) # Prepare dataloader @@ -139,6 +193,13 @@ def main(): # Set optimizer optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) + # Set criterion (loss function) + criterion = lambda x: x.loss + + def _criterion(outputs, inputs): + loss = criterion(outputs) + return loss + # Set lr scheduler total_steps = len(train_dataloader) * args.num_epoch num_warmup_steps = int(args.warmup_ratio * total_steps) @@ -148,20 +209,21 @@ def main(): # Set booster booster = Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model, - optimizer=optimizer, - dataloader=train_dataloader, - lr_scheduler=lr_scheduler) + model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model, + optimizer=optimizer, + criterion=_criterion, + dataloader=train_dataloader, + lr_scheduler=lr_scheduler) # Finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): - train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) - evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator) + train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) + evaluate_model(epoch, model, optimizer, _criterion, eval_dataloader, booster, coordinator) logger.info(f"Finish finetuning", ranks=[0]) # Save the finetuned model - booster.save_model(model, args.output_path) + booster.save_model(model, args.output_path, shard=True) logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0]) From c496c0f270049d314c4e8802b8d23dcacb1dd9b3 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 11:01:27 +0800 Subject: [PATCH 2/8] reset tp/pp size --- examples/images/vit/args.py | 4 ++-- examples/images/vit/vit_train_demo.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index 
e60aed5ffc39..e042c463908f 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -26,11 +26,11 @@ def parse_demo_args(): help="Batch size (per dp group) for the training dataloader.") parser.add_argument("--tp_size", type=int, - default=2, + default=1, help="The size along tensor parallel dimension, only be used when enabling hybrid parallel.") parser.add_argument("--pp_size", type=int, - default=2, + default=1, help="The size along pipeline parallel dimension, only be used when enabling hybrid parallel.") parser.add_argument("--learning_rate", type=float, diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 52decad97f16..d3fa411ab184 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -140,6 +140,11 @@ def main(): else: transformers.utils.logging.set_verbosity_error() + # Reset tp_size and pp_size to 1 if not using hybrid parallel. + if args.plugin != 'hybrid_parallel': + args.tp_size = 1 + args.pp_size = 1 + # Prepare Dataset image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path) train_dataset = BeansDataset(image_processor, args.tp_size, split='train') From 4b0ca7825ec3266b5542bd392e197a2ec84be126 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 11:23:55 +0800 Subject: [PATCH 3/8] fix dataloader iteration bug --- examples/images/vit/data.py | 2 +- examples/images/vit/run_demo.sh | 6 +++--- examples/images/vit/vit_train_demo.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py index 321d1c4c0f4f..77a8ad525056 100644 --- a/examples/images/vit/data.py +++ b/examples/images/vit/data.py @@ -5,7 +5,7 @@ class BeansDataset(Dataset): - def __init__(self, image_processor, tp_size, split='train'): + def __init__(self, image_processor, tp_size=1, split='train'): super().__init__() self.image_processor = image_processor diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh index dd8724048b09..b34eef1d0c13 100644 --- a/examples/images/vit/run_demo.sh +++ b/examples/images/vit/run_demo.sh @@ -9,16 +9,16 @@ OUTPUT_PATH="./output_model" # plugin(training strategy) # can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel" -PLUGIN="gemini" +PLUGIN="hybrid_parallel" -# configuration of hybrid parallel, only used when setting PLUGIN to "hybrid_parallel" +# configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel" TP_SIZE=2 PP_SIZE=2 # number of gpus to use GPUNUM=4 -# batch size per gpu +# batch size per data parallel group BS=16 # learning rate diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index d3fa411ab184..6e97c175ec66 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -60,6 +60,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C torch.cuda.synchronize() num_steps = len(dataloader) + data_iter = iter(dataloader) enable_pbar = coordinator.is_master() if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: # when using pp, only the last stage of master pipeline (dp_rank and tp_rank are both zero) shows pbar @@ -72,7 +73,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C model.train() for _ in range(num_steps): - loss, _ = run_forward_backward(model, optimizer, criterion, iter(dataloader), booster, 
forward_only=False) + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster, forward_only=False) optimizer.step() lr_scheduler.step() From a4e2afac835a0b4d110abbd111bdd1903ce3a13a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 11:47:12 +0800 Subject: [PATCH 4/8] update optimizer passing in evaluation/add grad_accum --- examples/images/vit/args.py | 2 ++ examples/images/vit/run_demo.sh | 3 ++- examples/images/vit/vit_benchmark.py | 4 ++++ examples/images/vit/vit_train_demo.py | 26 +++++++++++++------------- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index e042c463908f..ddb9e490b3f8 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -41,6 +41,7 @@ def parse_demo_args(): default=0.3, help="Ratio of warmup steps against total training steps.") parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") + parser.add_argument("--grad_accum", type=bool, default=True, help="Whether to use Gradient Accumulation.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") args = parser.parse_args() @@ -72,6 +73,7 @@ def parse_benchmark_args(): default=5e-5, help="Initial learning rate (after the potential warmup period) to use.") parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") + parser.add_argument("--grad_accum", type=bool, default=True, help="Whether to use Gradient Accumulation.") parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).") diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh index b34eef1d0c13..9efe1475956d 100644 --- a/examples/images/vit/run_demo.sh +++ b/examples/images/vit/run_demo.sh @@ -9,7 +9,8 @@ OUTPUT_PATH="./output_model" # plugin(training strategy) # can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"/"hybrid_parallel" -PLUGIN="hybrid_parallel" +PLUGIN="gemini" +#PLUGIN="hybrid_parallel" # configuration of parallel group sizes, only used when setting PLUGIN to "hybrid_parallel" TP_SIZE=2 diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index 98864a2821f0..146c449c472a 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -69,6 +69,10 @@ def main(): model = ViTForImageClassification(config) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + # Enable gradient checkpointing + if args.grad_accum: + model.gradient_checkpointing_enable() + # Set plugin booster_kwargs = {} if args.plugin == 'torch_ddp_fp16': diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 6e97c175ec66..bdbadc718f59 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -26,14 +26,10 @@ def move_to_cuda(batch, device): return {k: v.to(device) for k, v in batch.items()} -def run_forward_backward(model: nn.Module, - optimizer: Optimizer, - criterion: Callable[[Any, Any], torch.Tensor], - data_iter: Iterator, - booster: Booster, - forward_only: bool = False): - - optimizer.zero_grad() +def run_forward_backward(model: nn.Module, optimizer: Optimizer, 
criterion: Callable[[Any, Any], torch.Tensor], + data_iter: Iterator, booster: Booster): + if optimizer is not None: + optimizer.zero_grad() if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1: # run pipeline forward backward when enabling pp in hybrid parallel plugin output_dict = booster.execute_pipeline(data_iter, @@ -48,7 +44,7 @@ def run_forward_backward(model: nn.Module, batch = move_to_cuda(batch, torch.cuda.current_device()) outputs = model(**batch) loss = criterion(outputs, None) - if not forward_only: + if optimizer is not None: booster.backward(loss, optimizer) return loss, outputs @@ -73,7 +69,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C model.train() for _ in range(num_steps): - loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster, forward_only=False) + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) optimizer.step() lr_scheduler.step() @@ -84,7 +80,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C @torch.no_grad() -def evaluate_model(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: Callable[[Any, Any], torch.Tensor], +def evaluate_model(epoch: int, model: nn.Module, criterion: Callable[[Any, Any], torch.Tensor], eval_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator): torch.cuda.synchronize() @@ -95,7 +91,7 @@ def evaluate_model(epoch: int, model: nn.Module, optimizer: Optimizer, criterion for batch in eval_dataloader: batch = move_to_cuda(batch, torch.cuda.current_device()) - loss, outputs = run_forward_backward(model, optimizer, criterion, iter([batch]), booster, forward_only=True) + loss, outputs = run_forward_backward(model, None, criterion, iter([batch]), booster) to_accum = True if isinstance(booster.plugin, HybridParallelPlugin): @@ -162,6 +158,10 @@ def main(): ignore_mismatched_sizes=True) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) + # Enable gradient checkpointing + if args.grad_accum: + model.gradient_checkpointing_enable() + # Set plugin booster_kwargs = {} if args.plugin == 'torch_ddp_fp16': @@ -225,7 +225,7 @@ def _criterion(outputs, inputs): logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - evaluate_model(epoch, model, optimizer, _criterion, eval_dataloader, booster, coordinator) + evaluate_model(epoch, model, _criterion, eval_dataloader, booster, coordinator) logger.info(f"Finish finetuning", ranks=[0]) # Save the finetuned model From 0f982e1290dc355f79d585675b9e57a76581cc53 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 12:13:07 +0800 Subject: [PATCH 5/8] change criterion --- examples/images/vit/vit_benchmark.py | 13 +++++-------- examples/images/vit/vit_train_demo.py | 14 +++++--------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index 146c449c472a..d3c1f95de7d7 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -97,15 +97,12 @@ def main(): optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size)) # Set criterion (loss function) - criterion = lambda x: x.loss - - def _criterion(outputs, inputs): - loss = criterion(outputs) - return loss + def criterion(outputs, inputs): + return outputs.loss # Set booster booster 
= Booster(plugin=plugin, **booster_kwargs) - model, optimizer, _criterion, _, _ = booster.boost(model, optimizer, criterion=_criterion) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion) # Start training. logger.info(f"Start testing", ranks=[0]) @@ -124,13 +121,13 @@ def _criterion(outputs, inputs): batch = iter([batch]) outputs = booster.execute_pipeline(batch, model, - _criterion, + criterion, optimizer, return_loss=True, return_outputs=True) else: outputs = model(**batch) - loss = _criterion(outputs, None) + loss = criterion(outputs, None) # Backward booster.backward(loss, optimizer) diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index bdbadc718f59..03a6297e3984 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -19,7 +19,6 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import HybridAdam -from colossalai.utils import get_current_device def move_to_cuda(batch, device): @@ -200,11 +199,8 @@ def main(): optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) # Set criterion (loss function) - criterion = lambda x: x.loss - - def _criterion(outputs, inputs): - loss = criterion(outputs) - return loss + def criterion(outputs, inputs): + return outputs.loss # Set lr scheduler total_steps = len(train_dataloader) * args.num_epoch @@ -217,15 +213,15 @@ def _criterion(outputs, inputs): booster = Booster(plugin=plugin, **booster_kwargs) model, optimizer, _criterion, train_dataloader, lr_scheduler = booster.boost(model=model, optimizer=optimizer, - criterion=_criterion, + criterion=criterion, dataloader=train_dataloader, lr_scheduler=lr_scheduler) # Finetuning logger.info(f"Start finetuning", ranks=[0]) for epoch in range(args.num_epoch): - train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator) - evaluate_model(epoch, model, _criterion, eval_dataloader, booster, coordinator) + train_epoch(epoch, model, optimizer, criterion, lr_scheduler, train_dataloader, booster, coordinator) + evaluate_model(epoch, model, criterion, eval_dataloader, booster, coordinator) logger.info(f"Finish finetuning", ranks=[0]) # Save the finetuned model From 221aedfe119da99c7396330ae5b5803169e5f5a8 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 12:38:54 +0800 Subject: [PATCH 6/8] wrap tqdm --- examples/images/vit/vit_benchmark.py | 48 +++++++++++++-------------- examples/images/vit/vit_train_demo.py | 18 +++++----- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index d3c1f95de7d7..c4d4257c7bf2 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -106,35 +106,35 @@ def criterion(outputs, inputs): # Start training. 
logger.info(f"Start testing", ranks=[0]) - progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) torch.cuda.synchronize() model.train() start_time = time.time() - for _ in range(args.max_train_steps): - optimizer.zero_grad() - batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) - - if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: - # run pipeline forward backward - batch = iter([batch]) - outputs = booster.execute_pipeline(batch, - model, - criterion, - optimizer, - return_loss=True, - return_outputs=True) - else: - outputs = model(**batch) - loss = criterion(outputs, None) - # Backward - booster.backward(loss, optimizer) - - optimizer.step() - - torch.cuda.synchronize() - progress_bar.update(1) + with tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) as pbar: + for _ in range(args.max_train_steps): + optimizer.zero_grad() + batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) + + if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None: + # run pipeline forward backward + batch = iter([batch]) + outputs = booster.execute_pipeline(batch, + model, + criterion, + optimizer, + return_loss=True, + return_outputs=True) + else: + outputs = model(**batch) + loss = criterion(outputs, None) + # Backward + booster.backward(loss, optimizer) + + optimizer.step() + + torch.cuda.synchronize() + pbar.update(1) # Compute Statistics end_time = time.time() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 03a6297e3984..98a3c802234b 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -64,18 +64,18 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C enable_pbar = tp_rank == 0 and dp_rank == 0 \ and booster.plugin.stage_manager.is_last_stage() - progress_bar = tqdm.tqdm(total=num_steps, desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) model.train() - for _ in range(num_steps): - loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) - optimizer.step() - lr_scheduler.step() + with tqdm.tqdm(total=num_steps, desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: + for _ in range(num_steps): + loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) + optimizer.step() + lr_scheduler.step() - # Print batch loss - if enable_pbar: - progress_bar.set_postfix({'loss': loss.item()}) - progress_bar.update(1) + # Print batch loss + if enable_pbar: + pbar.set_postfix({'loss': loss.item()}) + pbar.update(1) @torch.no_grad() From 52144f9ad41e29f9ca94ab03b339c81171acd8ac Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 15:39:26 +0800 Subject: [PATCH 7/8] change grad_accum to grad_checkpoint --- examples/images/vit/args.py | 4 ++-- examples/images/vit/vit_benchmark.py | 6 +++--- examples/images/vit/vit_train_demo.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py index ddb9e490b3f8..e6c52c4e97fd 100644 --- a/examples/images/vit/args.py +++ b/examples/images/vit/args.py @@ -41,7 +41,7 @@ def parse_demo_args(): default=0.3, help="Ratio of warmup steps against total training steps.") parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay to use.") - parser.add_argument("--grad_accum", type=bool, default=True, 
help="Whether to use Gradient Accumulation.") + parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") args = parser.parse_args() @@ -73,7 +73,7 @@ def parse_benchmark_args(): default=5e-5, help="Initial learning rate (after the potential warmup period) to use.") parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") - parser.add_argument("--grad_accum", type=bool, default=True, help="Whether to use Gradient Accumulation.") + parser.add_argument("--grad_checkpoint", type=bool, default=True, help="Whether to use gradient checkpointing.") parser.add_argument("--max_train_steps", type=int, default=20, help="Total number of training steps to perform.") parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") parser.add_argument("--mem_cap", type=int, default=0, help="Limit on the usage of space for each GPU (in GB).") diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index c4d4257c7bf2..a9a30058be05 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -1,9 +1,9 @@ import time import torch -import tqdm import transformers from args import parse_benchmark_args +from tqdm import tqdm from transformers import ViTConfig, ViTForImageClassification import colossalai @@ -70,7 +70,7 @@ def main(): logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) # Enable gradient checkpointing - if args.grad_accum: + if args.grad_checkpoint: model.gradient_checkpointing_enable() # Set plugin @@ -111,7 +111,7 @@ def criterion(outputs, inputs): model.train() start_time = time.time() - with tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) as pbar: + with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar: for _ in range(args.max_train_steps): optimizer.zero_grad() batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 98a3c802234b..93d928aa379e 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -3,13 +3,13 @@ import torch import torch.distributed as dist import torch.nn as nn -import tqdm import transformers from args import parse_demo_args from data import BeansDataset, beans_collator from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from tqdm import tqdm from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor import colossalai @@ -66,7 +66,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C model.train() - with tqdm.tqdm(total=num_steps, desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: + with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: for _ in range(num_steps): loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) optimizer.step() @@ -158,7 +158,7 @@ def main(): logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) # Enable gradient checkpointing - if args.grad_accum: + if args.grad_checkpoint: model.gradient_checkpointing_enable() # Set plugin From 
5a38d1c7d229df198bcfca3f0e9ed3a99ea68c85 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 7 Sep 2023 16:47:38 +0800 Subject: [PATCH 8/8] fix pbar --- examples/images/vit/vit_benchmark.py | 3 +-- examples/images/vit/vit_train_demo.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py index a9a30058be05..d822fe23ecf0 100644 --- a/examples/images/vit/vit_benchmark.py +++ b/examples/images/vit/vit_benchmark.py @@ -112,7 +112,7 @@ def criterion(outputs, inputs): start_time = time.time() with tqdm(range(args.max_train_steps), desc="Training Step", disable=not coordinator.is_master()) as pbar: - for _ in range(args.max_train_steps): + for _ in pbar: optimizer.zero_grad() batch = get_data_batch(args.batch_size, args.num_labels, 3, 224, 224) @@ -134,7 +134,6 @@ def criterion(outputs, inputs): optimizer.step() torch.cuda.synchronize() - pbar.update(1) # Compute Statistics end_time = time.time() diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py index 93d928aa379e..206d8694b8f5 100644 --- a/examples/images/vit/vit_train_demo.py +++ b/examples/images/vit/vit_train_demo.py @@ -67,7 +67,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C model.train() with tqdm(range(num_steps), desc=f'Epoch [{epoch + 1}]', disable=not enable_pbar) as pbar: - for _ in range(num_steps): + for _ in pbar: loss, _ = run_forward_backward(model, optimizer, criterion, data_iter, booster) optimizer.step() lr_scheduler.step() @@ -75,7 +75,6 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, criterion: C # Print batch loss if enable_pbar: pbar.set_postfix({'loss': loss.item()}) - pbar.update(1) @torch.no_grad()
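
The training step introduced across these patches branches on whether pipeline parallelism is active. A minimal sketch of that control flow, condensed from the code added to `vit_train_demo.py` and `vit_benchmark.py`, is shown below. The ColossalAI calls (`Booster`, `HybridParallelPlugin`, `execute_pipeline`, `backward`) are taken directly from the diffs; the surrounding function and variable names are placeholders, and the model, optimizer, and data iterator are assumed to have already been prepared with `booster.boost()`.

```python
# Illustrative condensation of the training-step logic added in this series.
# Only the ColossalAI calls shown in the diffs are assumed to exist with these
# signatures; everything else here is a placeholder for the example's sake.
from typing import Any, Iterator

import torch
import torch.nn as nn
from torch.optim import Optimizer

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def criterion(outputs: Any, inputs: Any) -> torch.Tensor:
    # HuggingFace ViTForImageClassification returns the loss on its output object
    return outputs.loss


def train_step(booster: Booster, model: nn.Module, optimizer: Optimizer,
               data_iter: Iterator) -> torch.Tensor:
    optimizer.zero_grad()
    if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
        # With pipeline parallelism the schedule consumes the iterator and runs
        # forward and backward internally; the loss comes back in a dict.
        out = booster.execute_pipeline(data_iter, model, criterion, optimizer,
                                       return_loss=True, return_outputs=True)
        loss = out['loss']
    else:
        # Plain DDP / ZeRO / Gemini path: one batch, explicit backward through the booster.
        batch = {k: v.to(torch.cuda.current_device()) for k, v in next(data_iter).items()}
        outputs = model(**batch)
        loss = criterion(outputs, None)
        booster.backward(loss, optimizer)
    optimizer.step()
    return loss
```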
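
The `data.py` change pads the label set so its size is a multiple of `tp_size`. A small self-contained restatement of that rule follows; the suggested motivation, that an evenly divisible label count allows the classification head to be sharded across tensor-parallel ranks, is an inference from the diff's comment rather than something the patch states explicitly.

```python
# Standalone sketch of the label-padding rule added to BeansDataset in data.py.
def pad_label_names(label_names: list, tp_size: int) -> list:
    padded = list(label_names)
    while len(padded) % tp_size != 0:
        # dummy labels keep num_labels divisible by the tensor parallel size
        padded.append(f"pad_label_{len(padded)}")
    return padded


if __name__ == "__main__":
    # e.g. a 3-class dataset with tp_size=2 gains one dummy label
    print(pad_label_names(["class_0", "class_1", "class_2"], tp_size=2))
    # -> ['class_0', 'class_1', 'class_2', 'pad_label_3']
```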