diff --git a/HelloDeepSpeed/README.md b/HelloDeepSpeed/README.md new file mode 100644 index 000000000..9bc7adac0 --- /dev/null +++ b/HelloDeepSpeed/README.md @@ -0,0 +1,300 @@ +# Training a Masked Language Model with PyTorch and DeepSpeed + +In this tutorial, we will create and train a Transformer encoder on the Masked Language Modeling (MLM) task. Then we will show the changes necessary to integrate DeepSpeed, and show some of the advantages of doing so. + +Table of contents +================= + + + * [(1) Training a Transformer Encoder (BERT / Roberta) model for MLM](#1-training-a-transformer-encoder-bert--roberta-model-for-mlm) + * [1.0 Some Good Practices](#10-some-good-practices) + * [1.1 The Masked Language Modeling Task](#11-the-masked-language-modeling-task) + * [1.2 Creating a Transformer model](#12-creating-a-transformer-model) + * [1.3 Training the Model](#13-training-the-model) + * [(2) Integrating DeepSpeed For More Efficient Training](#2-integrating-deepspeed-for-more-efficient-training) + * [2.0 Core DeepSpeed Code Changes](#20-core-deepspeed-code-changes) + * [2.1 Launching Training](#21-launching-training) + * [2.2 Mixed Precision Training (fp16)](#22-mixed-precision-training-fp16) + * [2.3 Zero Redundancy Optimizer (ZeRO)](#23-zero-redundancy-optimizer-zero) + * [References](#references) + + +## 1. Training a Transformer Encoder (BERT / Roberta) model for MLM + +### 1.0 Some Good Practices + +### Version Control and Reproducibility + +One of the most important parts of training ML models is for the experiments to be reproducible (either by someone else, or by you 3 months later). Some steps that help with this are: + +* Use some form of version control (eg: `git`). Additionally, make sure to save the `gitdiff` and `githash`, so that anyone trying to replicate the experiment can easily do so + +* Save all the hyperparameters associated with the experiment (be it taken from a config or parsed from the command line) + +* Seed your random generators + +* Specify all the packages and their versions. This can be a `requirements.txt` file, a conda `env.yaml` file or a `pyproject.toml` file. If you want complete reproducibility, you can also include a `Dockerfile` to specify the environment to run the experiment in. + +In this example, the checkpoint directory has the following format: + +```bash +{exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} + `-experiment-name + |- hparams.json + |- githash.log + |- gitdiff.log + `- tb_dir/ +``` +This ensures that if you revisit this experiment, it is easy to understand what the experiment was about, when was it run, with what hyperparameters and what was the status of the code when the run was executed. + +### Writing Unit Tests + +Unit tests can help catch bugs early, and also set up a fast feedback loop. Since some experiments can take days or even weeks to run, both of these things are quite invaluable. + +In this tutorial, two primary parts are data creating and model training. Hence, we test both these parts (see [here](./tests/test_train_bert.py)): + +1. Testing Data creation: Data creation for the MLM task involves randomly masking words to generate model inputs. To test for correctness, we test if the fraction of masked tokens matches what we expect from our `DataLoader`. For more details, please take a look at + +```python +def test_masking_stats(tol: float = 1e-3): + """Test to check that the masking probabilities + match what we expect them to be. + """ + ... +``` + +2. Model training and checkpointing: Since pretraining experiments are usually quite expensive (take days / weeks to complete), chances are that you might run into some hardware failure before the training completes. Thus, it is crucial that the checkpointing logic is correct to allow a model to resume training. One way to do this is to train a small model for a few iterations and see if the model can resume training and if the checkpoints are loaded correctly. See `test_model_checkpointing` for an example test. + +--- + +💡 **_Tip:_** While saving checkpoints, make sure to also save the optimizer states along with the model parameters ! + +--- + +### 1.1 The Masked Language Modeling Task + +The main idea behind the MLM task is to get the model to fill in the blanks based on contextual clues present **both before and after** the blank. Consider, for example, the following sentence: + +> In the beautiful season of ____ the ____ shed their leaves. + +Given the left context `season` and the right context `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both the left and right context to fill in the blanks. + +In order to do that, we do the following: + +1. Tokenize a sentence into word(pieces) +2. Randomly select some words to mask, and replace them with a special \ token + * Of the masked tokens, it is common to replace a fraction of them with a random token, and leave a fraction of them unchanged. +3. Collect the actual words that were masked, and use them as targets for the model to predict against: + * From the model's perspective, this is a simple `CrossEntropy` loss over the vocabulary of the model. + +In this tutorial, we use the [wikitext-2-v1](https://huggingface.co/datasets/wikitext) dataset from [HuggingFace datasets](https://github.com/huggingface/datasets). To see how this is done in code, take a look at `masking_function` in [train_bert.py](./train_bert.py). + + +### 1.2 Creating a Transformer model + +A Transformer model repeatedly applies a (Multi-Headed) Self-Attention block and a FeedForward layer to generate contextual representations for each token. Thus, the key hyperparameters for a Transformer model usually are + +1. The number of Self-Attention + FeedForward blocks (depth) +2. The size of the hidden representation +3. The number of Self Attention Heads +4. The size of the intermediate representation between the FeedForward block + +Check out the `create_model` function in [train_bert.py](./train_bert.py) to see how this is done in code. In this example, we create a Roberta model[3](#3) + +--- +📌 **Note:** You can check out [[1](#1), [2](#2)] as a starting point for better understanding Transformers. Additionally, there are a number of blogs that do a nice deep dive into the workings of these models (eg: [this](https://nlp.seas.harvard.edu/2018/04/03/attention.html), [this](https://jalammar.github.io/illustrated-bert/) and [this](https://jalammar.github.io/illustrated-transformer/)). + +--- + +### 1.3 Training the Model + +In order to train the model, you can run the following command + +```bash +python train_bert.py --checkpoint_dir ./experiments +``` +This will create a model with the default parameters (as specified by the arguments to the `train` function), and train it on the wikitext dataset. Other parameters can be configured from the command line as: + +```bash +python train_bert.py --checkpoint_dir ./experiments \ + --mask_prob ${mask_prob} \ + --random_replace_prob ${random_replace_prob} \ + --unmask_replace_prob ${unmask_replace_prob} \ + --max_seq_length ${max_seq_length} \ + --tokenizer ${tokenizer} \ + --num_layers ${num_layers} \ + --num_heads ${num_heads} \ + --ff_dim ${ff_dim} \ + --h_dim ${h_dim} \ + --dropout ${dropout} \ + --batch_size ${batch_size} \ + --num_iterations ${num_iterations} \ + --checkpoint_every ${checkpoint_every} \ + --log_every ${log_every} \ + --local_rank ${local_rank} + +``` + +The parameters are explained in more details in the docstring of `train`. + +--- +💡 **_Tip:_** If you have a GPU available, you can considerably speedup your training by running it on the GPU. Simply set the `local_rank` to the GPU you want to run it on. Eg: for a single GPU machine, this would look like +```bash +--local_rank 0 +``` + +--- + +## 2. Integrating DeepSpeed For More Efficient Training + +In this next section we'll add DeepSpeed to the model presented in Section 1 and turn on several features. + +## 2.0 Core DeepSpeed Code Changes + +Please see the [Writing DeepSpeed Models](https://www.deepspeed.ai/getting-started/#writing-deepspeed-models) instructions written on modifying an existing model to use DeepSpeed. Also we will heavily rely on the [DeepSpeed API documentation](https://deepspeed.readthedocs.io/en/latest/) and [config JSON documentation](https://www.deepspeed.ai/docs/config-json/) going forward. + +Please install DeepSpeed via `pip install deepspeed` if you haven't already done so, after installing you can check if your current version and other information via `ds_report`. For this tutorial we assume a DeepSpeed version of >= 0.5.4 and a torch version >= 1.6. Please upgrade via `pip install --upgrade deepspeed` if you are running an older version of DeepSpeed. + +### Add deepspeed.initialize + config + +Our first task is to identify where to add `deepspeed.initialize()` to the existing code in order to use the DeepSpeed training engine. Please see the [deepspeed.initialize API documentation](https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization) for more details. This needs to be done after the model has been created and before the training loop has started. Most of our edits will be inside the `train` function inside [train_bert.py](./train_bert.py). + +After the model is created and before the optimizer is created we want to add the following lines: + +```python +ds_config = { + "train_micro_batch_size_per_gpu": batch_size, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, +} +model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) +``` + +This will create the DeepSpeed training engine based on the previously instantiated model and the new `ds_config` dictionary. We can now also remove the previous lines of code that created an Adam optimizer, this will now be done via the DeepSpeed engine. It should be noted, you can optionally created your own optimizer and pass it into `deepspeed.initialize` however DeepSpeed is able to make further performance optimizations by instantiating its own optimizers. + +### Update the training-loop + +Next we want to update our training-loop to use the new model engine with the following changes: + +* `model.to(device)` can be removed + * DeepSpeed will be careful on when to move the model to GPU to reduce GPU memory usage (e.g., converts to half on CPU then moves to GPU) +* `optimizer.zero_grad()` can be removed + * DeepSpeed will do this for you at the right time. +* Replace `loss.backward()` with `model.backward(loss)` + * There are several cases where the engine will properly scale the loss when using certain features (e.g., fp16, gradient-accumulation). +* Replace `optimizer.step()` with `model.step()` + * The optimizer step is handled by the engine now and is responsible for dispatching to the right optimizer depending on certain features. + +### Update checkpoint save and load + +Immediately after our new `deepspeed.initialize` you will see a checkpoint load and in the training-loop you will see a few checkpoint save calls. DeepSpeed handles the complexities of checkpoint saving for you so we can simplify these codepaths in the following way. Please refer to the [model checkpoint API documentation](https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html) for more details. + +__Checkpoint saving__: DeepSpeed will construct and save the state_dict for you, we can replace the *two* checkpoint saving snippets (i.e., `state_dict` construction and `torch.save`) and replace them with the snippet below. The `client_state` being passed in here is an example of state outside the view of DeepSpeed that will be saved with the checkpoint. + +```python +model.save_checkpoint(save_dir=exp_dir, client_state={'checkpoint_step': step}) +``` + +__Checkpoint loading__: The checkpoint loading is happening right before the training-loop starts. It invokes the `load_model_checkpoint` function which consists of around 30 lines of code. We can replace the `load_model_checkpoint(load_checkpoint_dir, model, optimizer)` call with the following snippet: + +```python +_, client_state = model.load_checkpoint(load_dir=load_checkpoint_dir) +checkpoint_step = client_state['checkpoint_step'] +``` + +## 2.1 Launching Training + +We are now ready to launch our training! As a convenience, DeepSpeed provides its own launcher that is seamlessly compatible with clusters that provide a `/job/hostfile` containing all available machines in your job. You can now try running your model on your available GPU(s) with the command below. By default this will attempt to run distributed data-parallel (DDP) training across all available GPUs on the current machine + any external machines listed in your `/job/hostfile`. Please read [more details about the DeepSpeed launcher](https://www.deepspeed.ai/getting-started/#launching-deepspeed-training) and its assumptions on our website. + +```bash +deepspeed train_bert.py --checkpoint_dir . +``` + +--- +📌 **Note:** If using the deepspeed launcher you should not pass the `--local_rank` explicitly. This will be done by the launcher in the same way as if you launched with `torch.distributed.launch` from PyTorch. + +--- + +## 2.2 Mixed Precision Training (fp16) + +Now that we are setup to use the DeepSpeed engine with our model we can start trying out a few different features of DeepSpeed. One feature is mixed precision training that utilizes half precision (floating-point 16 or fp16) data types. If you want to learn more about how mixed precision training works please refer to the Mixed Precision Training paper [[3]](https://arxiv.org/pdf/1710.03740v3.pdf) from Baidu and NVIDIA on the topic. + +To enable this mode in DeepSpeed we need to update our `ds_config` before the engine is created. Please see [fp16 training options](https://www.deepspeed.ai/docs/config-json/#fp16-training-options) in the config documentation for more information. In our case let's simple enable it by adding the following to our `ds_config` dictionary: + +```python + "fp16": { + "enabled": True + } +``` + +The memory reduction by switching from fp32 to fp16 results in the *model parameters* using half the amount of GPU memory, however the overall GPU memory reduction is not as simple. Since fp16 has half the available bits as fp32 it is not able to represent the same expressiveness as fp32, which can result in numeric instabilities during training. We are able to get around these instabilities in most cases by keeping some states in fp16 and others remain in fp32 (see Section 3 in [[3]](https://arxiv.org/pdf/1710.03740v3.pdf) if you'd like to learn more). + +The primary reason to utilize fp16 training is due to *Tensor Cores*. If you are training with NVIDIA V100 or A100 GPUs they include Tensor Cores which in some cases can accelerate computation by as much as 8x if certain conditions are met. One of the most important conditions is that your model parameters are stored as fp16. For more details on other conditions and tips to better utilize these cores please see this guide from NVIDIA on [Tips for Optimizing GPU Performance Using Tensor Cores](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). + +--- +📌 **Note:** At the start of training you will probably see several log messages about loss scaling and overflows, this is normal. In order for fp16 training to be numerically stable we utilize a common technique called "loss scaling" (similar to Section 3.2 in [[3]](https://arxiv.org/pdf/1710.03740v3.pdf)). This attempts to find a scaling value to mitigate gradient over/under-flows during training. + +--- + +## 2.3 Zero Redundancy Optimizer (ZeRO) + +ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. There are 3 different stages of ZeRO, Stage 1: optimizer state partitioning, Stage 2: optimizer state + gradient partitioning, and Stage 3: optimizer state + gradient + weight partitioning. We will focus on two features of ZeRO here, ZeRO Stage 1 and ZeRO-Offload. For further information, please refer to our [tutorial deep diving ZeRO](https://www.deepspeed.ai/tutorials/zero/) and our [tutorial deep diving ZeRO Offload](https://www.deepspeed.ai/tutorials/zero-offload/) on our website. To deep dive into ZeRO, please see our three papers [[4](https://arxiv.org/pdf/1910.02054.pdf), [5](https://www.usenix.org/system/files/atc21-ren-jie.pdf), [6](https://arxiv.org/abs/2104.07857)] that explore different optimizations in this space. + +* ZeRO Stage 1: The optimizer states (e.g., for the Adam optimizer, 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. +* ZeRO-Offload: Supports efficiently offloading optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be trained on a single GPU. + +To enable ZeRO Stage 1 in DeepSpeed we need to again update our `ds_config` before the engine is created. Please see [ZeRO optimizations](https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training) in the DeepSpeed config documentation for more information. In our case let's simply enable stage 1 it by adding the following to our `ds_config` dictionary: + +```python + "zero_optimization": { + "stage": 1 + } +``` + +We can re-run our training now with ZeRO stage 1 enabled and will see a per-GPU memory reduction as we scale up the total number of GPUs. Typically you can now use this extra GPU memory to either scale up your model size or scale up your per-GPU training batch size. However, if we only have 1 GPU available we probably want to enable ZeRO-Offload to allow us to train larger model sizes. Please update your `ds_config` to include the following: + +```python + "zero_optimization": { + "stage": 1, + "offload_optimizer": { + "device": "cpu" + } + } +``` + +This config will now allow us to train a much larger model than we were previously able to do. For example on a single P40 GPU with 24GB of memory we are unable to train a 2 billion parameter model (i.e., `--num_layers 24 --h_dim 4096`), however with ZeRO-Offload we now can! + +```bash +deepspeed train_bert.py --checkpoint_dir . --num_layers 24 --h_dim 4096 +``` + +--- +📌 **Note:** Earlier on when we setup `deepspeed.initialize` we chose not to explicitly pass an optimizer and instead let the DeepSpeed engine instantiate one for us. This is especially useful now that we are using ZeRO-Offload. DeepSpeed includes a highly optimized version of Adam that executes purely on CPU. This means that DeepSpeed will detect if you are using ZeRO-Offload w. Adam and switch to optimized CPUAdam variant. + +--- + +## References +> [1] +[Vaswani et al. Attention is all you need. +In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)](https://arxiv.org/pdf/1706.03762.pdf) +> +> [2] +[J. Devlin et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)](https://aclanthology.org/N19-1423.pdf) +> +> [3] +[P. Micikevicius et al. Mixed Precision Training (ICLR'18)](https://arxiv.org/pdf/1710.03740v3.pdf) +> +> [4]> +[S. Rajbhandari, J. Rasley, O. Ruwase, Y. He. ZeRO: memory optimizations toward training trillion parameter models. (SC‘20)](https://arxiv.org/pdf/1910.02054.pdf) +> +> [5] +[J. Ren, S. Rajbhandari, R. Aminabadi, O. Ruwase, S. Yang, M. Zhang, D. Li, Y. He. ZeRO-Offload: Democratizing Billion-Scale Model Training. (ATC'21)](https://www.usenix.org/system/files/atc21-ren-jie.pdf) +> +> [6] +[S. Rajbhandari, O. Ruwase, J. Rasley, S. Smith, Y. He. ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning (SC'21)](https://arxiv.org/abs/2104.07857) diff --git a/HelloDeepSpeed/requirements.txt b/HelloDeepSpeed/requirements.txt new file mode 100644 index 000000000..3471b7061 --- /dev/null +++ b/HelloDeepSpeed/requirements.txt @@ -0,0 +1,8 @@ +datasets==1.13.3 +transformers==4.5.1 +fire==0.4.0 +pytz==2021.1 +loguru==0.5.3 +sh==1.14.2 +pytest==6.2.5 +tqdm==4.62.3 \ No newline at end of file diff --git a/HelloDeepSpeed/tests/__init__.py b/HelloDeepSpeed/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/HelloDeepSpeed/tests/test_train_bert.py b/HelloDeepSpeed/tests/test_train_bert.py new file mode 100644 index 000000000..307fb735f --- /dev/null +++ b/HelloDeepSpeed/tests/test_train_bert.py @@ -0,0 +1,108 @@ +import tempfile + +import numpy as np +import pytest +import torch +import tqdm +from transformers import AutoTokenizer + +from train_bert import create_data_iterator, create_model, load_model_checkpoint, train + + +@pytest.fixture(scope="function") +def checkpoint_dir() -> str: + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + + +def test_masking_stats(tol: float = 1e-3): + """Test to check that the masking probabilities + match what we expect them to be. + """ + kwargs = { + "mask_prob": 0.15, + "random_replace_prob": 0.1, + "unmask_replace_prob": 0.1, + "batch_size": 8, + } + tokenizer = AutoTokenizer.from_pretrained("roberta-base") + dataloader = create_data_iterator(**kwargs) + num_samples = 10000 + total_tokens = 0 + masked_tokens = 0 + random_replace_tokens = 0 + unmasked_replace_tokens = 0 + for ix, batch in tqdm.tqdm(enumerate(dataloader, start=1), total=num_samples): + # Since we don't mask the BOS / EOS tokens, we subtract them from the total tokens + total_tokens += batch["attention_mask"].sum().item() - ( + 2 * batch["attention_mask"].size(0) + ) + masked_tokens += (batch["tgt_tokens"] != tokenizer.pad_token_id).sum().item() + random_or_unmasked = ( + batch["tgt_tokens"] != tokenizer.pad_token_id + ).logical_and(batch["src_tokens"] != tokenizer.mask_token_id) + unmasked = random_or_unmasked.logical_and( + batch["src_tokens"] == batch["tgt_tokens"] + ) + unmasked_replace_tokens += unmasked.sum().item() + random_replace_tokens += random_or_unmasked.sum().item() - unmasked.sum().item() + if ix == num_samples: + break + estimated_mask_prob = masked_tokens / total_tokens + estimated_random_tokens = random_replace_tokens / total_tokens + estimated_unmasked_tokens = unmasked_replace_tokens / total_tokens + assert np.isclose(estimated_mask_prob, kwargs["mask_prob"], atol=tol) + assert np.isclose( + estimated_random_tokens, + kwargs["random_replace_prob"] * kwargs["mask_prob"], + atol=tol, + ) + assert np.isclose( + estimated_unmasked_tokens, + kwargs["unmask_replace_prob"] * kwargs["mask_prob"], + atol=tol, + ) + + +def test_model_checkpointing(checkpoint_dir: str): + """Training a small model, and ensuring + that both checkpointing and resuming from + a checkpoint work. + """ + # First train a tiny model for 5 iterations + train_params = { + "checkpoint_dir": checkpoint_dir, + "checkpoint_every": 2, + "num_layers": 2, + "num_heads": 4, + "ff_dim": 64, + "h_dim": 64, + "num_iterations": 5, + } + exp_dir = train(**train_params) + # now check that we have 3 checkpoints + assert len(list(exp_dir.glob("*.pt"))) == 3 + model = create_model( + num_layers=train_params["num_layers"], + num_heads=train_params["num_heads"], + ff_dim=train_params["ff_dim"], + h_dim=train_params["h_dim"], + dropout=0.1, + ) + optimizer = torch.optim.Adam(model.parameters()) + step, model, optimizer = load_model_checkpoint(exp_dir, model, optimizer) + assert step == 5 + model_state_dict = model.state_dict() + # the saved checkpoint would be for iteration 5 + correct_state_dict = torch.load(exp_dir / "checkpoint.iter_5.pt") + correct_model_state_dict = correct_state_dict["model"] + assert set(model_state_dict.keys()) == set(correct_model_state_dict.keys()) + assert all( + torch.allclose(model_state_dict[key], correct_model_state_dict[key]) + for key in model_state_dict.keys() + ) + # Finally, try training with the checkpoint + train_params.pop("checkpoint_dir") + train_params["load_checkpoint_dir"] = str(exp_dir) + train_params["num_iterations"] = 10 + train(**train_params) diff --git a/HelloDeepSpeed/train_bert.py b/HelloDeepSpeed/train_bert.py new file mode 100644 index 000000000..14d61f00c --- /dev/null +++ b/HelloDeepSpeed/train_bert.py @@ -0,0 +1,791 @@ +import datetime +import json +import pathlib +import re +import string +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union + +import datasets +import fire +import loguru +import numpy as np +import pytz +import sh +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.models.roberta import RobertaConfig, RobertaModel +from transformers.models.roberta.modeling_roberta import ( + RobertaLMHead, + RobertaPreTrainedModel, +) + +logger = loguru.logger + +###################################################################### +############### Dataset Creation Related Functions ################### +###################################################################### + + +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +def collate_function( + batch: List[Tuple[List[int], List[int]]], pad_token_id: int +) -> Dict[str, torch.Tensor]: + """Collect a list of masked token indices, and labels, and + batch them, padding to max length in the batch. + """ + max_length = max(len(token_ids) for token_ids, _ in batch) + padded_token_ids = [ + token_ids + [pad_token_id for _ in range(0, max_length - len(token_ids))] + for token_ids, _ in batch + ] + padded_labels = [ + labels + [pad_token_id for _ in range(0, max_length - len(labels))] + for _, labels in batch + ] + src_tokens = torch.LongTensor(padded_token_ids) + tgt_tokens = torch.LongTensor(padded_labels) + attention_mask = src_tokens.ne(pad_token_id).type_as(src_tokens) + return { + "src_tokens": src_tokens, + "tgt_tokens": tgt_tokens, + "attention_mask": attention_mask, + } + + +def masking_function( + text: str, + tokenizer: TokenizerType, + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + max_length: int, +) -> Tuple[List[int], List[int]]: + """Given a text string, randomly mask wordpieces for Bert MLM + training. + + Args: + text (str): + The input text + tokenizer (TokenizerType): + The tokenizer for tokenization + mask_prob (float): + What fraction of tokens to mask + random_replace_prob (float): + Of the masked tokens, how many should be replaced with + random tokens (improves performance) + unmask_replace_prob (float): + Of the masked tokens, how many should be replaced with + the original token (improves performance) + max_length (int): + The maximum sequence length to consider. Note that for + Bert style models, this is a function of the number of + positional embeddings you learn + + Returns: + Tuple[List[int], List[int]]: + The masked token ids (based on the tokenizer passed), + and the output labels (padded with `tokenizer.pad_token_id`) + """ + # Note: By default, encode does add the BOS and EOS token + # Disabling that behaviour to make this more clear + tokenized_ids = ( + [tokenizer.bos_token_id] + + tokenizer.encode( + text, add_special_tokens=False, truncation=True, max_length=max_length - 2 + ) + + [tokenizer.eos_token_id] + ) + seq_len = len(tokenized_ids) + tokenized_ids = np.array(tokenized_ids) + subword_mask = np.full(len(tokenized_ids), False) + + # Masking the BOS and EOS token leads to slightly worse performance + low = 1 + high = len(subword_mask) - 1 + mask_choices = np.arange(low, high) + num_subwords_to_mask = max(int((mask_prob * (high - low)) + np.random.rand()), 1) + subword_mask[ + np.random.choice(mask_choices, num_subwords_to_mask, replace=False) + ] = True + + # Create the labels first + labels = np.full(seq_len, tokenizer.pad_token_id) + labels[subword_mask] = tokenized_ids[subword_mask] + + tokenized_ids[subword_mask] = tokenizer.mask_token_id + + # Now of the masked tokens, choose how many to replace with random and how many to unmask + rand_or_unmask_prob = random_replace_prob + unmask_replace_prob + if rand_or_unmask_prob > 0: + rand_or_unmask = subword_mask & ( + np.random.rand(len(tokenized_ids)) < rand_or_unmask_prob + ) + if random_replace_prob == 0: + unmask = rand_or_unmask + rand_mask = None + elif unmask_replace_prob == 0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = unmask_replace_prob / rand_or_unmask_prob + decision = np.random.rand(len(tokenized_ids)) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + if unmask is not None: + tokenized_ids[unmask] = labels[unmask] + if rand_mask is not None: + weights = np.ones(tokenizer.vocab_size) + weights[tokenizer.all_special_ids] = 0 + probs = weights / weights.sum() + num_rand = rand_mask.sum() + tokenized_ids[rand_mask] = np.random.choice( + tokenizer.vocab_size, num_rand, p=probs + ) + return tokenized_ids.tolist(), labels.tolist() + + +class WikiTextMLMDataset(Dataset): + """A [Map style dataset](https://pytorch.org/docs/stable/data.html) + for iterating over the wikitext dataset. Note that this assumes + the dataset can fit in memory. For larger datasets + you'd want to shard them and use an iterable dataset (eg: see + [Infinibatch](https://github.com/microsoft/infinibatch)) + + Args: + Dataset (datasets.arrow_dataset.Dataset): + The wikitext dataset + masking_function (Callable[[str], Tuple[List[int], List[int]]]) + The masking function. To generate one training instance, + the masking function is applied to the `text` of a dataset + record + + """ + + def __init__( + self, + dataset: datasets.arrow_dataset.Dataset, + masking_function: Callable[[str], Tuple[List[int], List[int]]], + ) -> None: + self.dataset = dataset + self.masking_function = masking_function + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]: + tokens, labels = self.masking_function(self.dataset[idx]["text"]) + return (tokens, labels) + + +T = TypeVar("T") + + +class InfiniteIterator(object): + def __init__(self, iterable: Iterable[T]) -> None: + self._iterable = iterable + self._iterator = iter(self._iterable) + + def __iter__(self): + return self + + def __next__(self) -> T: + next_item = None + try: + next_item = next(self._iterator) + except StopIteration: + self._iterator = iter(self._iterable) + next_item = next(self._iterator) + return next_item + + +def create_data_iterator( + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + batch_size: int, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", +) -> InfiniteIterator: + """Create the dataloader. + + Args: + mask_prob (float): + Fraction of tokens to mask + random_replace_prob (float): + Fraction of masked tokens to replace with random token + unmask_replace_prob (float): + Fraction of masked tokens to replace with the actual token + batch_size (int): + The batch size of the generated tensors + max_seq_length (int, optional): + The maximum sequence length for the MLM task. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + + Returns: + InfiniteIterator: + The torch DataLoader, wrapped in an InfiniteIterator class, to + be able to continuously generate samples + + """ + wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-v1", split="train") + wikitext_dataset = wikitext_dataset.filter(lambda record: record["text"] != "").map( + lambda record: {"text": record["text"].rstrip("\n")} + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + masking_function_partial = partial( + masking_function, + tokenizer=tokenizer, + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + max_length=max_seq_length, + ) + dataset = WikiTextMLMDataset(wikitext_dataset, masking_function_partial) + collate_fn_partial = partial(collate_function, pad_token_id=tokenizer.pad_token_id) + dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_partial + ) + + return InfiniteIterator(dataloader) + + +###################################################################### +############### Model Creation Related Functions ##################### +###################################################################### + + +class RobertaLMHeadWithMaskedPredict(RobertaLMHead): + def __init__( + self, config: RobertaConfig, embedding_weight: Optional[torch.Tensor] = None + ) -> None: + super(RobertaLMHeadWithMaskedPredict, self).__init__(config) + if embedding_weight is not None: + self.decoder.weight = embedding_weight + + def forward( # pylint: disable=arguments-differ + self, + features: torch.Tensor, + masked_token_indices: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """The current `transformers` library does not provide support + for masked_token_indices. This function provides the support, by + running the final forward pass only for the masked indices. This saves + memory + + Args: + features (torch.Tensor): + The features to select from. Shape (batch, seq_len, h_dim) + masked_token_indices (torch.Tensor, optional): + The indices of masked tokens for index select. Defaults to None. + Shape: (num_masked_tokens,) + + Returns: + torch.Tensor: + The index selected features. Shape (num_masked_tokens, h_dim) + + """ + if masked_token_indices is not None: + features = torch.index_select( + features.view(-1, features.shape[-1]), 0, masked_token_indices + ) + return super().forward(features) + + +class RobertaMLMModel(RobertaPreTrainedModel): + def __init__(self, config: RobertaConfig, encoder: RobertaModel) -> None: + super().__init__(config) + self.encoder = encoder + self.lm_head = RobertaLMHeadWithMaskedPredict( + config, self.encoder.embeddings.word_embeddings.weight + ) + self.lm_head.apply(self._init_weights) + + def forward( + self, + src_tokens: torch.Tensor, + attention_mask: torch.Tensor, + tgt_tokens: torch.Tensor, + ) -> torch.Tensor: + """The forward pass for the MLM task + + Args: + src_tokens (torch.Tensor): + The masked token indices. Shape: (batch, seq_len) + attention_mask (torch.Tensor): + The attention mask, since the batches are padded + to the largest sequence. Shape: (batch, seq_len) + tgt_tokens (torch.Tensor): + The output tokens (padded with `config.pad_token_id`) + + Returns: + torch.Tensor: + The MLM loss + """ + # shape: (batch, seq_len, h_dim) + sequence_output, *_ = self.encoder( + input_ids=src_tokens, attention_mask=attention_mask, return_dict=False + ) + + pad_token_id = self.config.pad_token_id + # (labels have also been padded with pad_token_id) + # filter out all masked labels + # shape: (num_masked_tokens,) + masked_token_indexes = torch.nonzero( + (tgt_tokens != pad_token_id).view(-1) + ).view(-1) + # shape: (num_masked_tokens, vocab_size) + prediction_scores = self.lm_head(sequence_output, masked_token_indexes) + # shape: (num_masked_tokens,) + target = torch.index_select(tgt_tokens.view(-1), 0, masked_token_indexes) + + loss_fct = nn.CrossEntropyLoss(ignore_index=-1) + + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), target + ) + return masked_lm_loss + + +def create_model( + num_layers: int, num_heads: int, ff_dim: int, h_dim: int, dropout: float +) -> RobertaMLMModel: + """Create a Bert model with the specified `num_heads`, `ff_dim`, + `h_dim` and `dropout` + + Args: + num_layers (int): + The number of layers + num_heads (int): + The number of attention heads + ff_dim (int): + The intermediate hidden size of + the feed forward block of the + transformer + h_dim (int): + The hidden dim of the intermediate + representations of the transformer + dropout (float): + The value of dropout to be used. + Note that we apply the same dropout + to both the attention layers and the + FF layers + + Returns: + RobertaMLMModel: + A Roberta model for MLM task + + """ + roberta_config_dict = { + "attention_probs_dropout_prob": dropout, + "bos_token_id": 0, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout_prob": dropout, + "hidden_size": h_dim, + "initializer_range": 0.02, + "intermediate_size": ff_dim, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 514, + "model_type": "roberta", + "num_attention_heads": num_heads, + "num_hidden_layers": num_layers, + "pad_token_id": 1, + "type_vocab_size": 1, + "vocab_size": 50265, + } + roberta_config = RobertaConfig.from_dict(roberta_config_dict) + roberta_encoder = RobertaModel(roberta_config) + roberta_model = RobertaMLMModel(roberta_config, roberta_encoder) + return roberta_model + + +###################################################################### +########### Experiment Management Related Functions ################## +###################################################################### + + +def get_unique_identifier(length: int = 8) -> str: + """Create a unique identifier by choosing `length` + random characters from list of ascii characters and numbers + """ + alphabet = string.ascii_lowercase + string.digits + uuid = "".join(alphabet[ix] for ix in np.random.choice(len(alphabet), length)) + return uuid + + +def create_experiment_dir( + checkpoint_dir: pathlib.Path, all_arguments: Dict[str, Any] +) -> pathlib.Path: + """Create an experiment directory and save all arguments in it. + Additionally, also store the githash and gitdiff. Finally create + a directory for `Tensorboard` logs. The structure would look something + like + checkpoint_dir + `-experiment-name + |- hparams.json + |- githash.log + |- gitdiff.log + `- tb_dir/ + + Args: + checkpoint_dir (pathlib.Path): + The base checkpoint directory + all_arguments (Dict[str, Any]): + The arguments to save + + Returns: + pathlib.Path: The experiment directory + """ + # experiment name follows the following convention + # {exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} + current_time = datetime.datetime.now(pytz.timezone("US/Pacific")) + expname = "bert_pretrain.{0}.{1}.{2}.{3}.{4}.{5}.{6}".format( + current_time.year, + current_time.month, + current_time.day, + current_time.hour, + current_time.minute, + current_time.second, + get_unique_identifier(), + ) + exp_dir = checkpoint_dir / expname + exp_dir.mkdir(exist_ok=False) + hparams_file = exp_dir / "hparams.json" + with hparams_file.open("w") as handle: + json.dump(obj=all_arguments, fp=handle, indent=2) + # Save the git hash + try: + gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) + with (exp_dir / "githash.log").open("w") as handle: + handle.write(gitlog.stdout.decode("utf-8")) + except sh.ErrorReturnCode_128: + logger.info( + "Seems like the code is not running from" + " within a git repo, so hash will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) + # And the git diff + try: + gitdiff = sh.git.diff(_fg=False, _tty_out=False) + with (exp_dir / "gitdiff.log").open("w") as handle: + handle.write(gitdiff.stdout.decode("utf-8")) + except sh.ErrorReturnCode_129: + logger.info( + "Seems like the code is not running from" + " within a git repo, so diff will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) + # Finally create the Tensorboard Dir + tb_dir = exp_dir / "tb_dir" + tb_dir.mkdir() + return exp_dir + + +###################################################################### +################ Checkpoint Related Functions ######################## +###################################################################### + + +def load_model_checkpoint( + load_checkpoint_dir: pathlib.Path, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, +) -> Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + """Loads the optimizer state dict and model state dict from the load_checkpoint_dir + into the passed model and optimizer. Searches for the most recent checkpoint to + load from + + Args: + load_checkpoint_dir (pathlib.Path): + The base checkpoint directory to load from + model (torch.nn.Module): + The model to load the checkpoint weights into + optimizer (torch.optim.Optimizer): + The optimizer to load the checkpoint weigths into + + Returns: + Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + The checkpoint step, model with state_dict loaded and + optimizer with state_dict loaded + + """ + logger.info(f"Loading model and optimizer checkpoint from {load_checkpoint_dir}") + checkpoint_files = list( + filter( + lambda path: re.search(r"iter_(?P\d+)\.pt", path.name) is not None, + load_checkpoint_dir.glob("*.pt"), + ) + ) + assert len(checkpoint_files) > 0, "No checkpoints found in directory" + checkpoint_files = sorted( + checkpoint_files, + key=lambda path: int( + re.search(r"iter_(?P\d+)\.pt", path.name).group("iter_no") + ), + ) + latest_checkpoint_path = checkpoint_files[-1] + checkpoint_step = int( + re.search(r"iter_(?P\d+)\.pt", latest_checkpoint_path.name).group( + "iter_no" + ) + ) + + state_dict = torch.load(latest_checkpoint_path) + model.load_state_dict(state_dict["model"], strict=True) + optimizer.load_state_dict(state_dict["optimizer"]) + logger.info( + f"Loading model and optimizer checkpoints done. Loaded from {latest_checkpoint_path}" + ) + return checkpoint_step, model, optimizer + + +###################################################################### +######################## Driver Functions ############################ +###################################################################### + + +def train( + checkpoint_dir: str = None, + load_checkpoint_dir: str = None, + # Dataset Parameters + mask_prob: float = 0.15, + random_replace_prob: float = 0.1, + unmask_replace_prob: float = 0.1, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", + # Model Parameters + num_layers: int = 6, + num_heads: int = 8, + ff_dim: int = 512, + h_dim: int = 256, + dropout: float = 0.1, + # Training Parameters + batch_size: int = 8, + num_iterations: int = 10000, + checkpoint_every: int = 1000, + log_every: int = 10, + local_rank: int = -1, +) -> pathlib.Path: + """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf) + (transformer encoder only) model for MLM Task + + Args: + checkpoint_dir (str): + The base experiment directory to save experiments to + mask_prob (float, optional): + The fraction of tokens to mask. Defaults to 0.15. + random_replace_prob (float, optional): + The fraction of masked tokens to replace with random token. + Defaults to 0.1. + unmask_replace_prob (float, optional): + The fraction of masked tokens to leave unchanged. + Defaults to 0.1. + max_seq_length (int, optional): + The maximum sequence length of the examples. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + num_layers (int, optional): + The number of layers in the Bert model. Defaults to 6. + num_heads (int, optional): + Number of attention heads to use. Defaults to 8. + ff_dim (int, optional): + Size of the intermediate dimension in the FF layer. + Defaults to 512. + h_dim (int, optional): + Size of intermediate representations. + Defaults to 256. + dropout (float, optional): + Amout of Dropout to use. Defaults to 0.1. + batch_size (int, optional): + The minibatch size. Defaults to 8. + num_iterations (int, optional): + Total number of iterations to run the model for. + Defaults to 10000. + checkpoint_every (int, optional): + Save checkpoint after these many steps. + + ..note :: + + You want this to be frequent enough that you can + resume training in case it crashes, but not so much + that you fill up your entire storage ! + + Defaults to 1000. + log_every (int, optional): + Print logs after these many steps. Defaults to 10. + local_rank (int, optional): + Which GPU to run on (-1 for CPU). Defaults to -1. + + Returns: + pathlib.Path: The final experiment directory + + """ + device = ( + torch.device("cuda", local_rank) + if (local_rank > -1) and torch.cuda.is_available() + else torch.device("cpu") + ) + ################################ + ###### Create Exp. Dir ######### + ################################ + if checkpoint_dir is None and load_checkpoint_dir is None: + logger.error("Need to specify one of checkpoint_dir" " or load_checkpoint_dir") + return + if checkpoint_dir is not None and load_checkpoint_dir is not None: + logger.error("Cannot specify both checkpoint_dir" " and load_checkpoint_dir") + return + if checkpoint_dir: + logger.info("Creating Experiment Directory") + checkpoint_dir = pathlib.Path(checkpoint_dir) + checkpoint_dir.mkdir(exist_ok=True) + all_arguments = { + # Dataset Params + "mask_prob": mask_prob, + "random_replace_prob": random_replace_prob, + "unmask_replace_prob": unmask_replace_prob, + "max_seq_length": max_seq_length, + "tokenizer": tokenizer, + # Model Params + "num_layers": num_layers, + "num_heads": num_heads, + "ff_dim": ff_dim, + "h_dim": h_dim, + "dropout": dropout, + # Training Params + "batch_size": batch_size, + "num_iterations": num_iterations, + "checkpoint_every": checkpoint_every, + } + exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) + logger.info(f"Experiment Directory created at {exp_dir}") + else: + logger.info("Loading from Experiment Directory") + load_checkpoint_dir = pathlib.Path(load_checkpoint_dir) + assert load_checkpoint_dir.exists() + with (load_checkpoint_dir / "hparams.json").open("r") as handle: + hparams = json.load(handle) + # Set the hparams + # Dataset Params + mask_prob = hparams.get("mask_prob", mask_prob) + tokenizer = hparams.get("tokenizer", tokenizer) + random_replace_prob = hparams.get("random_replace_prob", random_replace_prob) + unmask_replace_prob = hparams.get("unmask_replace_prob", unmask_replace_prob) + max_seq_length = hparams.get("max_seq_length", max_seq_length) + # Model Params + ff_dim = hparams.get("ff_dim", ff_dim) + h_dim = hparams.get("h_dim", h_dim) + dropout = hparams.get("dropout", dropout) + num_layers = hparams.get("num_layers", num_layers) + num_heads = hparams.get("num_heads", num_heads) + # Training Params + batch_size = hparams.get("batch_size", batch_size) + _num_iterations = hparams.get("num_iterations", num_iterations) + num_iterations = max(num_iterations, _num_iterations) + checkpoint_every = hparams.get("checkpoint_every", checkpoint_every) + exp_dir = load_checkpoint_dir + # Tensorboard writer + tb_dir = exp_dir / "tb_dir" + assert tb_dir.exists() + summary_writer = SummaryWriter(log_dir=tb_dir) + ################################ + ###### Create Datasets ######### + ################################ + logger.info("Creating Datasets") + data_iterator = create_data_iterator( + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + batch_size=batch_size, + ) + logger.info("Dataset Creation Done") + ################################ + ###### Create Model ############ + ################################ + logger.info("Creating Model") + model = create_model( + num_layers=num_layers, + num_heads=num_heads, + ff_dim=ff_dim, + h_dim=h_dim, + dropout=dropout, + ) + model = model.to(device) + logger.info("Model Creation Done") + ################################ + ###### Create Optimizer ####### + ################################ + logger.info("Creating Optimizer") + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + logger.info("Optimizer Creation Done") + ################################ + #### Load Model checkpoint ##### + ################################ + start_step = 1 + if load_checkpoint_dir is not None: + checkpoint_step, model, optimizer = load_model_checkpoint( + load_checkpoint_dir, model, optimizer + ) + start_step = checkpoint_step + 1 + + ################################ + ####### The Training Loop ###### + ################################ + logger.info(f"Total number of model parameters: {sum([p.numel() for p in model.parameters()]):,d}") + model.train() + losses = [] + for step, batch in enumerate(data_iterator, start=start_step): + if step >= num_iterations: + break + optimizer.zero_grad() + # Move the tensors to device + for key, value in batch.items(): + batch[key] = value.to(device) + # Forward pass + loss = model(**batch) + # Backward pass + loss.backward() + # Optimizer Step + optimizer.step() + losses.append(loss.item()) + if step % log_every == 0: + logger.info("Loss: {0:.4f}".format(np.mean(losses))) + summary_writer.add_scalar(f"Train/loss", np.mean(losses), step) + if step % checkpoint_every == 0: + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + torch.save(obj=state_dict, f=str(exp_dir / f"checkpoint.iter_{step}.pt")) + logger.info( + "Saved model to {0}".format((exp_dir / f"checkpoint.iter_{step}.pt")) + ) + # Save the last checkpoint if not saved yet + if step % checkpoint_every != 0: + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + torch.save(obj=state_dict, f=str(exp_dir / f"checkpoint.iter_{step}.pt")) + logger.info( + "Saved model to {0}".format((exp_dir / f"checkpoint.iter_{step}.pt")) + ) + + return exp_dir + + +if __name__ == "__main__": + fire.Fire(train) diff --git a/HelloDeepSpeed/train_bert_ds.py b/HelloDeepSpeed/train_bert_ds.py new file mode 100644 index 000000000..421d03daf --- /dev/null +++ b/HelloDeepSpeed/train_bert_ds.py @@ -0,0 +1,805 @@ +""" +Modified version of train_bert.py that adds DeepSpeed +""" + +import datetime +import json +import pathlib +import re +import string +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union + +import datasets +import fire +import loguru +import numpy as np +import pytz +import sh +import torch +import torch.nn as nn +import deepspeed +from torch.utils.data import DataLoader, Dataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.models.roberta import RobertaConfig, RobertaModel +from transformers.models.roberta.modeling_roberta import ( + RobertaLMHead, + RobertaPreTrainedModel, +) + +logger = loguru.logger + +###################################################################### +############### Dataset Creation Related Functions ################### +###################################################################### + + +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +def collate_function( + batch: List[Tuple[List[int], List[int]]], pad_token_id: int +) -> Dict[str, torch.Tensor]: + """Collect a list of masked token indices, and labels, and + batch them, padding to max length in the batch. + """ + max_length = max(len(token_ids) for token_ids, _ in batch) + padded_token_ids = [ + token_ids + [pad_token_id for _ in range(0, max_length - len(token_ids))] + for token_ids, _ in batch + ] + padded_labels = [ + labels + [pad_token_id for _ in range(0, max_length - len(labels))] + for _, labels in batch + ] + src_tokens = torch.LongTensor(padded_token_ids) + tgt_tokens = torch.LongTensor(padded_labels) + attention_mask = src_tokens.ne(pad_token_id).type_as(src_tokens) + return { + "src_tokens": src_tokens, + "tgt_tokens": tgt_tokens, + "attention_mask": attention_mask, + } + + +def masking_function( + text: str, + tokenizer: TokenizerType, + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + max_length: int, +) -> Tuple[List[int], List[int]]: + """Given a text string, randomly mask wordpieces for Bert MLM + training. + + Args: + text (str): + The input text + tokenizer (TokenizerType): + The tokenizer for tokenization + mask_prob (float): + What fraction of tokens to mask + random_replace_prob (float): + Of the masked tokens, how many should be replaced with + random tokens (improves performance) + unmask_replace_prob (float): + Of the masked tokens, how many should be replaced with + the original token (improves performance) + max_length (int): + The maximum sequence length to consider. Note that for + Bert style models, this is a function of the number of + positional embeddings you learn + + Returns: + Tuple[List[int], List[int]]: + The masked token ids (based on the tokenizer passed), + and the output labels (padded with `tokenizer.pad_token_id`) + """ + # Note: By default, encode does add the BOS and EOS token + # Disabling that behaviour to make this more clear + tokenized_ids = ( + [tokenizer.bos_token_id] + + tokenizer.encode( + text, add_special_tokens=False, truncation=True, max_length=max_length - 2 + ) + + [tokenizer.eos_token_id] + ) + seq_len = len(tokenized_ids) + tokenized_ids = np.array(tokenized_ids) + subword_mask = np.full(len(tokenized_ids), False) + + # Masking the BOS and EOS token leads to slightly worse performance + low = 1 + high = len(subword_mask) - 1 + mask_choices = np.arange(low, high) + num_subwords_to_mask = max(int((mask_prob * (high - low)) + np.random.rand()), 1) + subword_mask[ + np.random.choice(mask_choices, num_subwords_to_mask, replace=False) + ] = True + + # Create the labels first + labels = np.full(seq_len, tokenizer.pad_token_id) + labels[subword_mask] = tokenized_ids[subword_mask] + + tokenized_ids[subword_mask] = tokenizer.mask_token_id + + # Now of the masked tokens, choose how many to replace with random and how many to unmask + rand_or_unmask_prob = random_replace_prob + unmask_replace_prob + if rand_or_unmask_prob > 0: + rand_or_unmask = subword_mask & ( + np.random.rand(len(tokenized_ids)) < rand_or_unmask_prob + ) + if random_replace_prob == 0: + unmask = rand_or_unmask + rand_mask = None + elif unmask_replace_prob == 0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = unmask_replace_prob / rand_or_unmask_prob + decision = np.random.rand(len(tokenized_ids)) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + if unmask is not None: + tokenized_ids[unmask] = labels[unmask] + if rand_mask is not None: + weights = np.ones(tokenizer.vocab_size) + weights[tokenizer.all_special_ids] = 0 + probs = weights / weights.sum() + num_rand = rand_mask.sum() + tokenized_ids[rand_mask] = np.random.choice( + tokenizer.vocab_size, num_rand, p=probs + ) + return tokenized_ids.tolist(), labels.tolist() + + +class WikiTextMLMDataset(Dataset): + """A [Map style dataset](https://pytorch.org/docs/stable/data.html) + for iterating over the wikitext dataset. Note that this assumes + the dataset can fit in memory. For larger datasets + you'd want to shard them and use an iterable dataset (eg: see + [Infinibatch](https://github.com/microsoft/infinibatch)) + + Args: + Dataset (datasets.arrow_dataset.Dataset): + The wikitext dataset + masking_function (Callable[[str], Tuple[List[int], List[int]]]) + The masking function. To generate one training instance, + the masking function is applied to the `text` of a dataset + record + + """ + + def __init__( + self, + dataset: datasets.arrow_dataset.Dataset, + masking_function: Callable[[str], Tuple[List[int], List[int]]], + ) -> None: + self.dataset = dataset + self.masking_function = masking_function + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]: + tokens, labels = self.masking_function(self.dataset[idx]["text"]) + return (tokens, labels) + + +T = TypeVar("T") + + +class InfiniteIterator(object): + def __init__(self, iterable: Iterable[T]) -> None: + self._iterable = iterable + self._iterator = iter(self._iterable) + + def __iter__(self): + return self + + def __next__(self) -> T: + next_item = None + try: + next_item = next(self._iterator) + except StopIteration: + self._iterator = iter(self._iterable) + next_item = next(self._iterator) + return next_item + + +def create_data_iterator( + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + batch_size: int, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", +) -> InfiniteIterator: + """Create the dataloader. + + Args: + mask_prob (float): + Fraction of tokens to mask + random_replace_prob (float): + Fraction of masked tokens to replace with random token + unmask_replace_prob (float): + Fraction of masked tokens to replace with the actual token + batch_size (int): + The batch size of the generated tensors + max_seq_length (int, optional): + The maximum sequence length for the MLM task. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + + Returns: + InfiniteIterator: + The torch DataLoader, wrapped in an InfiniteIterator class, to + be able to continuously generate samples + + """ + wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-v1", split="train") + wikitext_dataset = wikitext_dataset.filter(lambda record: record["text"] != "").map( + lambda record: {"text": record["text"].rstrip("\n")} + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + masking_function_partial = partial( + masking_function, + tokenizer=tokenizer, + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + max_length=max_seq_length, + ) + dataset = WikiTextMLMDataset(wikitext_dataset, masking_function_partial) + collate_fn_partial = partial(collate_function, pad_token_id=tokenizer.pad_token_id) + dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_partial + ) + + return InfiniteIterator(dataloader) + + +###################################################################### +############### Model Creation Related Functions ##################### +###################################################################### + + +class RobertaLMHeadWithMaskedPredict(RobertaLMHead): + def __init__( + self, config: RobertaConfig, embedding_weight: Optional[torch.Tensor] = None + ) -> None: + super(RobertaLMHeadWithMaskedPredict, self).__init__(config) + if embedding_weight is not None: + self.decoder.weight = embedding_weight + + def forward( # pylint: disable=arguments-differ + self, + features: torch.Tensor, + masked_token_indices: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """The current `transformers` library does not provide support + for masked_token_indices. This function provides the support, by + running the final forward pass only for the masked indices. This saves + memory + + Args: + features (torch.Tensor): + The features to select from. Shape (batch, seq_len, h_dim) + masked_token_indices (torch.Tensor, optional): + The indices of masked tokens for index select. Defaults to None. + Shape: (num_masked_tokens,) + + Returns: + torch.Tensor: + The index selected features. Shape (num_masked_tokens, h_dim) + + """ + if masked_token_indices is not None: + features = torch.index_select( + features.view(-1, features.shape[-1]), 0, masked_token_indices + ) + return super().forward(features) + + +class RobertaMLMModel(RobertaPreTrainedModel): + def __init__(self, config: RobertaConfig, encoder: RobertaModel) -> None: + super().__init__(config) + self.encoder = encoder + self.lm_head = RobertaLMHeadWithMaskedPredict( + config, self.encoder.embeddings.word_embeddings.weight + ) + self.lm_head.apply(self._init_weights) + + def forward( + self, + src_tokens: torch.Tensor, + attention_mask: torch.Tensor, + tgt_tokens: torch.Tensor, + ) -> torch.Tensor: + """The forward pass for the MLM task + + Args: + src_tokens (torch.Tensor): + The masked token indices. Shape: (batch, seq_len) + attention_mask (torch.Tensor): + The attention mask, since the batches are padded + to the largest sequence. Shape: (batch, seq_len) + tgt_tokens (torch.Tensor): + The output tokens (padded with `config.pad_token_id`) + + Returns: + torch.Tensor: + The MLM loss + """ + # shape: (batch, seq_len, h_dim) + sequence_output, *_ = self.encoder( + input_ids=src_tokens, attention_mask=attention_mask, return_dict=False + ) + + pad_token_id = self.config.pad_token_id + # (labels have also been padded with pad_token_id) + # filter out all masked labels + # shape: (num_masked_tokens,) + masked_token_indexes = torch.nonzero( + (tgt_tokens != pad_token_id).view(-1) + ).view(-1) + # shape: (num_masked_tokens, vocab_size) + prediction_scores = self.lm_head(sequence_output, masked_token_indexes) + # shape: (num_masked_tokens,) + target = torch.index_select(tgt_tokens.view(-1), 0, masked_token_indexes) + + loss_fct = nn.CrossEntropyLoss(ignore_index=-1) + + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), target + ) + return masked_lm_loss + + +def create_model( + num_layers: int, num_heads: int, ff_dim: int, h_dim: int, dropout: float +) -> RobertaMLMModel: + """Create a Bert model with the specified `num_heads`, `ff_dim`, + `h_dim` and `dropout` + + Args: + num_layers (int): + The number of layers + num_heads (int): + The number of attention heads + ff_dim (int): + The intermediate hidden size of + the feed forward block of the + transformer + h_dim (int): + The hidden dim of the intermediate + representations of the transformer + dropout (float): + The value of dropout to be used. + Note that we apply the same dropout + to both the attention layers and the + FF layers + + Returns: + RobertaMLMModel: + A Roberta model for MLM task + + """ + roberta_config_dict = { + "attention_probs_dropout_prob": dropout, + "bos_token_id": 0, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout_prob": dropout, + "hidden_size": h_dim, + "initializer_range": 0.02, + "intermediate_size": ff_dim, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 514, + "model_type": "roberta", + "num_attention_heads": num_heads, + "num_hidden_layers": num_layers, + "pad_token_id": 1, + "type_vocab_size": 1, + "vocab_size": 50265, + } + roberta_config = RobertaConfig.from_dict(roberta_config_dict) + roberta_encoder = RobertaModel(roberta_config) + roberta_model = RobertaMLMModel(roberta_config, roberta_encoder) + return roberta_model + + +###################################################################### +########### Experiment Management Related Functions ################## +###################################################################### + + +def get_unique_identifier(length: int = 8) -> str: + """Create a unique identifier by choosing `length` + random characters from list of ascii characters and numbers + """ + alphabet = string.ascii_lowercase + string.digits + uuid = "".join(alphabet[ix] for ix in np.random.choice(len(alphabet), length)) + return uuid + + +def create_experiment_dir( + checkpoint_dir: pathlib.Path, all_arguments: Dict[str, Any] +) -> pathlib.Path: + """Create an experiment directory and save all arguments in it. + Additionally, also store the githash and gitdiff. Finally create + a directory for `Tensorboard` logs. The structure would look something + like + checkpoint_dir + `-experiment-name + |- hparams.json + |- githash.log + |- gitdiff.log + `- tb_dir/ + + Args: + checkpoint_dir (pathlib.Path): + The base checkpoint directory + all_arguments (Dict[str, Any]): + The arguments to save + + Returns: + pathlib.Path: The experiment directory + """ + # experiment name follows the following convention + # {exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} + current_time = datetime.datetime.now(pytz.timezone("US/Pacific")) + expname = "bert_pretrain.{0}.{1}.{2}.{3}.{4}.{5}.{6}".format( + current_time.year, + current_time.month, + current_time.day, + current_time.hour, + current_time.minute, + current_time.second, + get_unique_identifier(), + ) + exp_dir = checkpoint_dir / expname + exp_dir.mkdir(exist_ok=False) + hparams_file = exp_dir / "hparams.json" + with hparams_file.open("w") as handle: + json.dump(obj=all_arguments, fp=handle, indent=2) + # Save the git hash + try: + gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) + with (exp_dir / "githash.log").open("w") as handle: + handle.write(gitlog.stdout.decode("utf-8")) + except sh.ErrorReturnCode_128: + logger.info( + "Seems like the code is not running from" + " within a git repo, so hash will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) + # And the git diff + try: + gitdiff = sh.git.diff(_fg=False, _tty_out=False) + with (exp_dir / "gitdiff.log").open("w") as handle: + handle.write(gitdiff.stdout.decode("utf-8")) + except sh.ErrorReturnCode_129: + logger.info( + "Seems like the code is not running from" + " within a git repo, so diff will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) + # Finally create the Tensorboard Dir + tb_dir = exp_dir / "tb_dir" + tb_dir.mkdir() + return exp_dir + + +###################################################################### +################ Checkpoint Related Functions ######################## +###################################################################### + + +def load_model_checkpoint( + load_checkpoint_dir: pathlib.Path, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, +) -> Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + """Loads the optimizer state dict and model state dict from the load_checkpoint_dir + into the passed model and optimizer. Searches for the most recent checkpoint to + load from + + Args: + load_checkpoint_dir (pathlib.Path): + The base checkpoint directory to load from + model (torch.nn.Module): + The model to load the checkpoint weights into + optimizer (torch.optim.Optimizer): + The optimizer to load the checkpoint weigths into + + Returns: + Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + The checkpoint step, model with state_dict loaded and + optimizer with state_dict loaded + + """ + logger.info(f"Loading model and optimizer checkpoint from {load_checkpoint_dir}") + checkpoint_files = list( + filter( + lambda path: re.search(r"iter_(?P\d+)\.pt", path.name) is not None, + load_checkpoint_dir.glob("*.pt"), + ) + ) + assert len(checkpoint_files) > 0, "No checkpoints found in directory" + checkpoint_files = sorted( + checkpoint_files, + key=lambda path: int( + re.search(r"iter_(?P\d+)\.pt", path.name).group("iter_no") + ), + ) + latest_checkpoint_path = checkpoint_files[-1] + checkpoint_step = int( + re.search(r"iter_(?P\d+)\.pt", latest_checkpoint_path.name).group( + "iter_no" + ) + ) + + state_dict = torch.load(latest_checkpoint_path) + model.load_state_dict(state_dict["model"], strict=True) + optimizer.load_state_dict(state_dict["optimizer"]) + logger.info( + f"Loading model and optimizer checkpoints done. Loaded from {latest_checkpoint_path}" + ) + return checkpoint_step, model, optimizer + + +###################################################################### +######################## Driver Functions ############################ +###################################################################### + + +def train( + checkpoint_dir: str = None, + load_checkpoint_dir: str = None, + # Dataset Parameters + mask_prob: float = 0.15, + random_replace_prob: float = 0.1, + unmask_replace_prob: float = 0.1, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", + # Model Parameters + num_layers: int = 6, + num_heads: int = 8, + ff_dim: int = 512, + h_dim: int = 256, + dropout: float = 0.1, + # Training Parameters + batch_size: int = 8, + num_iterations: int = 10000, + checkpoint_every: int = 1000, + log_every: int = 10, + local_rank: int = -1, +) -> pathlib.Path: + """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf) + (transformer encoder only) model for MLM Task + + Args: + checkpoint_dir (str): + The base experiment directory to save experiments to + mask_prob (float, optional): + The fraction of tokens to mask. Defaults to 0.15. + random_replace_prob (float, optional): + The fraction of masked tokens to replace with random token. + Defaults to 0.1. + unmask_replace_prob (float, optional): + The fraction of masked tokens to leave unchanged. + Defaults to 0.1. + max_seq_length (int, optional): + The maximum sequence length of the examples. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + num_layers (int, optional): + The number of layers in the Bert model. Defaults to 6. + num_heads (int, optional): + Number of attention heads to use. Defaults to 8. + ff_dim (int, optional): + Size of the intermediate dimension in the FF layer. + Defaults to 512. + h_dim (int, optional): + Size of intermediate representations. + Defaults to 256. + dropout (float, optional): + Amout of Dropout to use. Defaults to 0.1. + batch_size (int, optional): + The minibatch size. Defaults to 8. + num_iterations (int, optional): + Total number of iterations to run the model for. + Defaults to 10000. + checkpoint_every (int, optional): + Save checkpoint after these many steps. + + ..note :: + + You want this to be frequent enough that you can + resume training in case it crashes, but not so much + that you fill up your entire storage ! + + Defaults to 1000. + log_every (int, optional): + Print logs after these many steps. Defaults to 10. + local_rank (int, optional): + Which GPU to run on (-1 for CPU). Defaults to -1. + + Returns: + pathlib.Path: The final experiment directory + + """ + device = ( + torch.device("cuda", local_rank) + if (local_rank > -1) and torch.cuda.is_available() + else torch.device("cpu") + ) + ################################ + ###### Create Exp. Dir ######### + ################################ + if checkpoint_dir is None and load_checkpoint_dir is None: + logger.error("Need to specify one of checkpoint_dir" " or load_checkpoint_dir") + return + if checkpoint_dir is not None and load_checkpoint_dir is not None: + logger.error("Cannot specify both checkpoint_dir" " and load_checkpoint_dir") + return + if checkpoint_dir: + logger.info("Creating Experiment Directory") + checkpoint_dir = pathlib.Path(checkpoint_dir) + checkpoint_dir.mkdir(exist_ok=True) + all_arguments = { + # Dataset Params + "mask_prob": mask_prob, + "random_replace_prob": random_replace_prob, + "unmask_replace_prob": unmask_replace_prob, + "max_seq_length": max_seq_length, + "tokenizer": tokenizer, + # Model Params + "num_layers": num_layers, + "num_heads": num_heads, + "ff_dim": ff_dim, + "h_dim": h_dim, + "dropout": dropout, + # Training Params + "batch_size": batch_size, + "num_iterations": num_iterations, + "checkpoint_every": checkpoint_every, + } + exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) + logger.info(f"Experiment Directory created at {exp_dir}") + else: + logger.info("Loading from Experiment Directory") + load_checkpoint_dir = pathlib.Path(load_checkpoint_dir) + assert load_checkpoint_dir.exists() + with (load_checkpoint_dir / "hparams.json").open("r") as handle: + hparams = json.load(handle) + # Set the hparams + # Dataset Params + mask_prob = hparams.get("mask_prob", mask_prob) + tokenizer = hparams.get("tokenizer", tokenizer) + random_replace_prob = hparams.get("random_replace_prob", random_replace_prob) + unmask_replace_prob = hparams.get("unmask_replace_prob", unmask_replace_prob) + max_seq_length = hparams.get("max_seq_length", max_seq_length) + # Model Params + ff_dim = hparams.get("ff_dim", ff_dim) + h_dim = hparams.get("h_dim", h_dim) + dropout = hparams.get("dropout", dropout) + num_layers = hparams.get("num_layers", num_layers) + num_heads = hparams.get("num_heads", num_heads) + # Training Params + batch_size = hparams.get("batch_size", batch_size) + _num_iterations = hparams.get("num_iterations", num_iterations) + num_iterations = max(num_iterations, _num_iterations) + checkpoint_every = hparams.get("checkpoint_every", checkpoint_every) + exp_dir = load_checkpoint_dir + # Tensorboard writer + tb_dir = exp_dir / "tb_dir" + assert tb_dir.exists() + summary_writer = SummaryWriter(log_dir=tb_dir) + ################################ + ###### Create Datasets ######### + ################################ + logger.info("Creating Datasets") + data_iterator = create_data_iterator( + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + batch_size=batch_size, + ) + logger.info("Dataset Creation Done") + ################################ + ###### Create Model ############ + ################################ + logger.info("Creating Model") + model = create_model( + num_layers=num_layers, + num_heads=num_heads, + ff_dim=ff_dim, + h_dim=h_dim, + dropout=dropout, + ) + logger.info("Model Creation Done") + ################################ + ###### DeepSpeed engine ######## + ################################ + logger.info("Creating DeepSpeed engine") + ds_config = { + "train_micro_batch_size_per_gpu": batch_size, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": 1, + "offload_optimizer": { + "device": "cpu" + } + } + } + model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) + logger.info("DeepSpeed engine created") + ################################ + #### Load Model checkpoint ##### + ################################ + start_step = 1 + if load_checkpoint_dir is not None: + _, client_state = model.load_checkpoint(load_dir=load_checkpoint_dir) + checkpoint_step = client_state['checkpoint_step'] + start_step = checkpoint_step + 1 + + ################################ + ####### The Training Loop ###### + ################################ + logger.info(f"Total number of model parameters: {sum([p.numel() for p in model.parameters()]):,d}") + model.train() + losses = [] + for step, batch in enumerate(data_iterator, start=start_step): + if step >= num_iterations: + break + # Move the tensors to device + for key, value in batch.items(): + batch[key] = value.to(device) + # Forward pass + loss = model(**batch) + # Backward pass + model.backward(loss) + # Optimizer Step + model.step() + losses.append(loss.item()) + if step % log_every == 0: + logger.info("Loss: {0:.4f}".format(np.mean(losses))) + summary_writer.add_scalar(f"Train/loss", np.mean(losses), step) + if step % checkpoint_every == 0: + model.save_checkpoint(save_dir=exp_dir, client_state={'checkpoint_step': step}) + logger.info( + "Saved model to {0}".format(exp_dir) + ) + # Save the last checkpoint if not saved yet + if step % checkpoint_every != 0: + model.save_checkpoint(save_dir=exp_dir, client_state={'checkpoint_step': step}) + logger.info( + "Saved model to {0}".format(exp_dir) + ) + + return exp_dir + + +if __name__ == "__main__": + fire.Fire(train) diff --git a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md new file mode 100644 index 000000000..a80e3510c --- /dev/null +++ b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md @@ -0,0 +1 @@ +This is an example of how to use DeepSpeed's curriculum learning (CL) feature which provides faster and more stable language model pre-training. Currently it is only integrated for GPT pre-training. Note that there are two curriculum learning examples in two different repos for Megatron-LM GPT-2 pre-training. Both of them have some unique features and limitations. See details in our [tutorial](https://www.deepspeed.ai/tutorials/curriculum-learning/). For technical details please refer to our [paper](https://arxiv.org/abs/2108.06084). \ No newline at end of file diff --git a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh index 959af6813..338b93f42 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh +++ b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh @@ -11,7 +11,7 @@ SEED=$8 SAVE_INTERVAL=$9 NUM_ITER=${10} NUM_TOKEN=${11} -LR_DECAY_ITER=${12} +LR_DECAY_TOKEN=${12} LR_WARMUP_ITER=${13} CONFIG_TEMPLATE=${14} CURRICULUM_STEP=${15} @@ -74,7 +74,7 @@ else config_json="$script_dir/ds_zero_stage_${stage}_config_${CONFIG}.json" fi -JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_ITER}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}" +JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_TOKEN}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}" LOG_NAME="${JOB_NAME}_${host}_${current_time}" #Actication Checkpointing and Contigious Memory @@ -102,7 +102,7 @@ gpt_options=" \ --batch-size $BATCHSIZE \ --train-iters $NUM_ITER \ --train-tokens $NUM_TOKEN \ - --lr-decay-iters $LR_DECAY_ITER \ + --lr-decay-tokens $LR_DECAY_TOKEN \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ --data-path $DATA_PATH \ diff --git a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh index ff7e7e58b..aac11ab03 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh +++ b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh @@ -8,10 +8,9 @@ # MP_SIZE=1 # SEED=1234 # SAVE_INTERVAL=5000 - # NUM_ITER=600000 # NUM_TOKEN=157286400000 -# LR_DECAY_ITER=300000 +# LR_DECAY_TOKEN=157286400000 # LR_WARMUP_ITER=3000 # CONFIG_TEMPLATE=false # CURRICULUM_STEP=0 @@ -26,15 +25,13 @@ SEQ_LEN=1024 MP_SIZE=1 SEED=1234 SAVE_INTERVAL=1000 - NUM_ITER=75000 NUM_TOKEN=157286400000 +LR_DECAY_TOKEN=157286400000 LR_WARMUP_ITER=3000 CONFIG_TEMPLATE=true -CURRICULUM_STEP=15000 +CURRICULUM_STEP=45000 CURRICULUM_MIN=64 - -LR_DECAY_ITER=$((37500 + ${CURRICULUM_STEP} / 2)) TAG="${CONFIG}_s${CURRICULUM_MIN}to${SEQ_LEN}_step${CURRICULUM_STEP}" -bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_ITER $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN +bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_TOKEN $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py index bb1cb6779..f95020af5 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py @@ -296,6 +296,8 @@ def _add_learning_rate_args(parser): group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' ' If None defaults to `--train-iters`') + group.add_argument('--lr-decay-tokens', type=int, default=None, + help='Learning rate decay tokens.') group.add_argument('--min-lr', type=float, default=0.0, help='Minumum value for learning rate. The scheduler' 'clip values below this threshold.') diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py index 123a805e1..fd8eb6cb1 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py @@ -268,7 +268,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): else: try: iteration = state_dict['iteration'] - args.tokens = state_dict['tokens'] + if 'tokens' in state_dict: + args.tokens = state_dict['tokens'] except KeyError: try: # Backward compatible with older checkpoints iteration = state_dict['total_iters'] diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py index afc5a8d77..19be32b8c 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py @@ -17,7 +17,7 @@ import math -from megatron import print_rank_0 +from megatron import print_rank_0, get_args class AnnealingLR(object): @@ -28,7 +28,7 @@ def __init__(self, optimizer, start_lr, decay_style, last_iter, min_lr=0.0, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): - + args = get_args() # Class values. self.optimizer = optimizer self.start_lr = start_lr @@ -37,6 +37,9 @@ def __init__(self, optimizer, start_lr, self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 + self.lr_decay_tokens = args.lr_decay_tokens + self.num_tokens = 0 + self.warmup_tokens = 0 self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler @@ -44,7 +47,7 @@ def __init__(self, optimizer, start_lr, assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate - self.step(self.num_iters) + self.step(self.num_iters, self.num_tokens) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) @@ -53,16 +56,26 @@ def get_lr(self): https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" # Warmup. - if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: - return float(self.start_lr) * self.num_iters / self.warmup_iter - - # For any iterations larger than `self.end_iter`, use `self.min_lr`. - if self.num_iters > self.end_iter: - return self.min_lr - # If we are done with the warmup period, use the decay style. - current_iter = self.num_iters - self.warmup_iter - decay_iter = self.end_iter - self.warmup_iter - decay_ratio = float(current_iter) / float(decay_iter) + if self.warmup_iter > 0: + if self.num_iters == self.warmup_iter and self.lr_decay_tokens is not None: + self.warmup_tokens = self.num_tokens + if self.num_iters <= self.warmup_iter: + return float(self.start_lr) * self.num_iters / self.warmup_iter + + if self.lr_decay_tokens is None: + # For any iterations larger than `self.end_iter`, use `self.min_lr`. + if self.num_iters > self.end_iter: + return self.min_lr + # If we are done with the warmup period, use the decay style. + current_iter = self.num_iters - self.warmup_iter + decay_iter = self.end_iter - self.warmup_iter + decay_ratio = float(current_iter) / float(decay_iter) + else: + if self.num_tokens > self.lr_decay_tokens: + return self.min_lr + current_tokens = self.num_tokens - self.warmup_tokens + decay_tokens = self.lr_decay_tokens - self.warmup_tokens + decay_ratio = float(current_tokens) / float(decay_tokens) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 @@ -78,11 +91,15 @@ def get_lr(self): lr = self.start_lr return max(lr, self.min_lr) - def step(self, step_num=None): + def step(self, step_num=None, token_num=None): """Set lr for all parameters groups.""" + args = get_args() if step_num is None: step_num = self.num_iters + 1 + if token_num is None: + token_num = args.tokens self.num_iters = step_num + self.num_tokens = token_num new_lr = self.get_lr() for group in self.optimizer.param_groups: group['lr'] = new_lr @@ -92,6 +109,8 @@ def state_dict(self): 'start_lr': self.start_lr, 'warmup_iter': self.warmup_iter, 'num_iters': self.num_iters, + 'warmup_tokens': self.warmup_tokens, + 'num_tokens': self.num_tokens, 'decay_style': self.decay_style, 'end_iter': self.end_iter, 'min_lr': self.min_lr @@ -128,4 +147,8 @@ def load_state_dict(self, sd): 'decay style') self.num_iters = sd['num_iters'] - self.step(self.num_iters) + if 'warmup_tokens' in sd: + self.warmup_tokens = sd['warmup_tokens'] + if 'num_tokens' in sd: + self.num_tokens = sd['num_tokens'] + self.step(self.num_iters, self.num_tokens) diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py index be4da2202..0671f393d 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py @@ -55,8 +55,8 @@ def __init__(self, num_tokentypes=0, parallel_output=True): def forward(self, input_ids, position_ids, attention_mask, labels=None, tokentype_ids=None, layer_past=None, get_key_value=False, forward_method_parallel_output=None, curriculum_seqlen=None): + args = get_args() if curriculum_seqlen is not None: - args = get_args() args.curriculum_seqlen = curriculum_seqlen if curriculum_seqlen < input_ids.size()[1]: # seqlen-based curriculum learning @@ -67,6 +67,10 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None, # attention_mask has size [1, 1, seqlen, seqlen] attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous() + else: + if args.curriculum_learning: + # If got a None input, need to reset curriculum_seqlen on user side + args.curriculum_seqlen = args.seq_length # Language model. lm_output = self.language_model(input_ids, diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py index 8fc8791ee..5cb17d1a0 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py @@ -327,7 +327,7 @@ def train_step(forward_step_func, data_iterator, #see_memory_usage(f'before forward {model.global_steps}', force=True) # Forward model for one step. timers('forward').start() - loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning) + loss, loss_reduced = forward_step_func(data_iterator, model) timers('forward').stop() #see_memory_usage(f'before backward {model.global_steps}', force=True) diff --git a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py index 026702d95..86aac0c9a 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py +++ b/Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py @@ -86,7 +86,7 @@ def get_batch(data_iterator): return tokens, labels, loss_mask, attention_mask, position_ids -def forward_step(data_iterator, model, curriculum_learning=False): +def forward_step(data_iterator, model): """Forward step.""" args = get_args() timers = get_timers() @@ -98,7 +98,7 @@ def forward_step(data_iterator, model, curriculum_learning=False): timers('batch generator').stop() # Forward model. losses = model(tokens, position_ids, attention_mask, labels=labels) - if curriculum_learning and args.curriculum_seqlen < args.seq_length: + if args.curriculum_learning and args.curriculum_seqlen < args.seq_length: loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() loss_mask = loss_mask.view(-1) loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() diff --git a/MoQ/huggingface-transformers/tests/fixtures/tests_samples/GermEval/dev.txt b/MoQ/huggingface-transformers/tests/fixtures/tests_samples/GermEval/dev.txt index de0015823..1aba64f7a 100644 --- a/MoQ/huggingface-transformers/tests/fixtures/tests_samples/GermEval/dev.txt +++ b/MoQ/huggingface-transformers/tests/fixtures/tests_samples/GermEval/dev.txt @@ -10,7 +10,7 @@ homo I-OTH " O in O enger O -Auseinandersetzung O +Ause inandersetzung O mit O diesem O Bild O diff --git a/README.md b/README.md index f39f55128..1da997098 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ This repo contains example models that use [DeepSpeed](https://github.com/micros # Note on Megatron examples +NOTE: We are in the process of deprecating the 3 Megatron-LM snapshots in this repo. Our current and future features with Megatron-LM will use the [Megatron-DeepSpeed fork](https://github.com/microsoft/Megatron-DeepSpeed). Currently the Megatron-DeepSpeed fork supports 3D parallelism + ZeRO Stage 1 and Curriculum Learning. Please see this new fork for further updates in the process. + Megatron-LM : This is a fairly old snapshot of Megatron-LM , and we have been using it show case the earlier features of DeepSpeed. This does not contain ZeRO-3 or 3D parallelism. Megatron-LM-v1.1.5-3D_parallelism: This is a relatively new Megatron (Oct 2020), but before Megatron started supporting 3D parallelism. We ported this version to showcase how to use 3D parallelism inside DeepSpeed with Megatron. diff --git a/autotuning/.gitignore b/autotuning/.gitignore new file mode 100644 index 000000000..82319e4a0 --- /dev/null +++ b/autotuning/.gitignore @@ -0,0 +1,4 @@ +autotuning_results* +autotuning_exps* +output* +mnli diff --git a/autotuning/README.md b/autotuning/README.md new file mode 100644 index 000000000..d028a945e --- /dev/null +++ b/autotuning/README.md @@ -0,0 +1,3 @@ +# Autotuning Examples + +This showcases the [autotuning](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning) feature in DeepSpeed (DS). diff --git a/autotuning/hf/README.md b/autotuning/hf/README.md new file mode 100644 index 000000000..567deda04 --- /dev/null +++ b/autotuning/hf/README.md @@ -0,0 +1,62 @@ +# Autotuning Hugging Face Examples + +This showcases the [autotuning](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/autotuning) feature in DeepSpeed (DS) with Hugging Face (HF) models. + +## List of Models + +- [DistilBERT](distilbert) +- [BERT-base](bert-base) +- [BERT-large](bert-large) +- [GPT2](gpt2) +- [GPT2-medium](gpt2-medium) +- [GPT2-large](gpt2-large) +- [GPT2-xl](gpt2-xl) +- [DeBERTa](deberta) + +Each model folder has a `test_tune.sh` script: + +- `./test_tune.sh tune` tunes the model training and then runs it using the selected tuned DeepSpeed configuration. +- `./test_tune.sh 0` runs the model using HF without DeepSpeed. +- `./test_tune.sh z0` runs the model using HF + DS with ZeRO optimization disabled. +- `./test_tune.sh z1` runs the model using HF + DS with ZeRO optimization stage 1. +- `./test_tune.sh z2` runs the model using HF + DS with ZeRO optimization stage 2. +- `./test_tune.sh z3` runs the model using HF + DS with ZeRO optimization stage 3. + + +## Testing Environment + +The training runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | num_params | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | throughput improvement over baseline | autotuning time (mins) | number of experiments | +| :----------: | :--------: | :---------------------------: | :----------------------------------: | :----------------------------: | :----------------------------------: | :--------------------: | :-------------------: | +| DistilBERT | 66M | 5161.902 (gas = 1, mbs = 256) | 5305.067 (z = 0, gas = 1 mbs = 256) | 5305.067 (z0_gas1_tmbspg256) | 1.03x | 11 | 11 | +| BERT-base | 0.11B | 2502.236 (gas = 1,mbs = 128) | 2523.684 (z = 0, gas = 1, mbs = 128) | 2736.561 (z0_gas1_tmbspg235) | 1.09x | 35 | 34 | +| BERT-large | 0.34B | 742.692 (gas = 1,mbs = 64) | 766.929 (z = 1, gas = 1, mbs = 64) | 808.168 (z1_gas1_tmbspg93) | 1.09x | 36 | 22 | +| GPT2 | 0.12B | 284.142 (gas = 1,mbs = 8) | 397.827 (z = 1, gas = 1, mbs = 8) | 431.586 (z1_gas1_tmbspg14) | 1.52x | 25 | 17 | +| GPT2-medium | 0.35B | 71.61 (gas = 1, mbs = 2) | 142.211 (z = 1, gas = 1, mbs = 4) | 163.3 (z1_gas1_tmbspg6) | 2.28 | 15 | 25 | +| GPT2-large | 0.77B | 27.874 (gas = 1, mbs = 1) | 56.797 (z = 1, gas = 1, mbs = 2) | 69.061 (z = 1, mbs = 3) | 2.48x | 27 | 13 | +| GPT2-xl | 1.5B | Not runnable | 27.462 (gas = 1, mbs = 1) | 27.497 (z1_gas1_tmbspg1) | inf | 21 | 9 | +| DeBERTa | 1.5B | Not runnable | 140.587 (z = 1, gas = 1 mbs = 8) | 162.395 (z1_gas1_tmbspg11) | inf | 40 | 12 | diff --git a/autotuning/hf/bert-base/README.md b/autotuning/hf/bert-base/README.md new file mode 100644 index 000000000..02450fdd3 --- /dev/null +++ b/autotuning/hf/bert-base/README.md @@ -0,0 +1,58 @@ +# [bert-base-cased](https://huggingface.co/bert-base-cased) + +This model has the following configuration: + +- 12-layer +- 768 hidden dimension +- 12 attention heads +- 110M parameters. + +## Environment + +The training use fp32 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is set to `4096`. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS handtuned | HF + DS autotuning | +| ---------- | ----------------------------- | ------------------------------------ | ---------------------------- | +| BERT-base | 2502.236 (gas = 1, mbs = 128) | 2523.684 (z = 0, gas = 1, mbs = 128) | 2736.561 (z0_gas1_tmbspg235) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 35 mins +- Number of experiments: 34 +- Throughput Improvement over baseline: 1.09x + + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :---------------- | +| z0 | 9 | 2930.18 | z0_gas1_tmbspg235 | +| z1 | 7 | 2930.17 | z1_gas1_tmbspg235 | +| z2 | 8 | 2744.16 | z2_gas1_tmbspg235 | +| z3 | 10 | 2479.47 | z3_gas1_tmbspg238 | +| global | 34 | 2930.18 | z0_gas1_tmbspg235 | + +Tuning completed in 0:34:41.842250. Total number of experiments: 34. diff --git a/autotuning/hf/bert-base/ds_config_tune.json b/autotuning/hf/bert-base/ds_config_tune.json new file mode 100644 index 000000000..23a48ddf9 --- /dev/null +++ b/autotuning/hf/bert-base/ds_config_tune.json @@ -0,0 +1,12 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "max_train_batch_size": 4096, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/autotuning/hf/bert-base/test_tune.sh b/autotuning/hf/bert-base/test_tune.sh new file mode 100755 index 000000000..532efc902 --- /dev/null +++ b/autotuning/hf/bert-base/test_tune.sh @@ -0,0 +1,114 @@ +TASK_NAME=mnli +MODEL_NAME=bert-base-cased +HF_PATH=~/projects +PER_DEVICE_TRAIN_BATCH_SIZE=64 +MAX_TRAIN_BATCH_SIZE=4096 +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./${TASK_NAME}/output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z0.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z0 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z1.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z1 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z2.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z2 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z3.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z3 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_tune.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/bert-large/README.md b/autotuning/hf/bert-large/README.md new file mode 100644 index 000000000..157dba0c1 --- /dev/null +++ b/autotuning/hf/bert-large/README.md @@ -0,0 +1,55 @@ +# [bert-large-uncased](https://huggingface.co/bert-large-uncased) + +This model has the following configuration: + +- 24-layer +- 1024 hidden dimension +- 16 attention heads +- 336M parameters + +The training use fp32 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS handtuned | HF + DS autotuning | +| ---------- | --------------------------- | --------------------------------- | -------------------------- | +| BERT-large | 742.692 (gas = 1, mbs = 64) | 766.929 (z = 1, gas =1, mbs = 64) | 808.168 (z1_gas1_tmbspg93) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 36 mins +- Number of experiments: 22 +- Throughput Improvement over baseline: 1.09x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :--------------- | +| z0 | 6 | 835.244 | z0_gas1_tmbspg93 | +| z1 | 6 | 842.243 | z1_gas1_tmbspg93 | +| z2 | 9 | 764.524 | z2_gas1_tmbspg94 | +| z3 | 1 | 0 | z3_gas1_tmbspg94 | +| global | 22 | 842.243 | z1_gas1_tmbspg93 | + +Tuning completed in 0:36:16.261417. Total number of experiments: 23. diff --git a/autotuning/hf/bert-large/ds_config_tune.json b/autotuning/hf/bert-large/ds_config_tune.json new file mode 100644 index 000000000..e79f9c450 --- /dev/null +++ b/autotuning/hf/bert-large/ds_config_tune.json @@ -0,0 +1,11 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/autotuning/hf/bert-large/test_tune.sh b/autotuning/hf/bert-large/test_tune.sh new file mode 100755 index 000000000..e63f917b8 --- /dev/null +++ b/autotuning/hf/bert-large/test_tune.sh @@ -0,0 +1,114 @@ +TASK_NAME=mnli +MODEL_NAME=bert-large-uncased +HF_PATH=~/projects +PER_DEVICE_TRAIN_BATCH_SIZE=64 +MAX_TRAIN_BATCH_SIZE=4096 +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./${TASK_NAME}/output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z0.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z0 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z1.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z1 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z2.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z2 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z3.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_z3 \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_tune.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --save_steps 0 \ + --overwrite_output_dir \ + --max_steps $MAX_STEPS +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_b${PER_DEVICE_TRAIN_BATCH_SIZE}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/deberta/README.md b/autotuning/hf/deberta/README.md new file mode 100644 index 000000000..9144376cd --- /dev/null +++ b/autotuning/hf/deberta/README.md @@ -0,0 +1,72 @@ +# [deberta-v2-xxlarge-mnli](https://huggingface.co/microsoft/deberta-v2-xxlarge) + +This model has the following configuration: + +- 48-layer +- 1536 hidden dimension +- 1.5B parameters. + +Refer to [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://github.com/microsoft/DeBERTa). +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg), reduce_bucket_size (rbs), allgather_bucket_size (abs). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | --------------------------------- | ------------------------------ | +| DeBERTa | Not runnable | 140.587 (z = 1, gas = 1 mbs = 8), | 162.395 (z1_gas1_tmbspg11) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. +### Fast-mode +- Autotuning time: 40 mins +- Number of experiments: 12 +- Throughput Improvement over baseline: Inf + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :--------------- | +| z0 | 1 | 0 | z0_gas1_tmbspg1 | +| z1 | 6 | 177.843 | z1_gas1_tmbspg11 | +| z2 | 4 | 154.002 | z2_gas1_tmbspg14 | +| z3 | 1 | 0 | z3_gas1_tmbspg14 | +| global | 12 | 177.843 | z1_gas1_tmbspg11 | + +Tuning completed in 0:39:25.253998. Total number of experiments: 12. + +### Full-mode ("fast" set to false) +- Autotuning time: 1 hr 2 mins +- Number of experiments: 24 +- Throughput Improvement over baseline: Inf + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :---------------- | --------------: | --------------: | :------------------------------------- | +| z0 | 1 | 0 | z0_gas1_tmbspg1 | +| z1 | 6 | 177.843 | z1_gas1_tmbspg11 | +| z1_rbs_abs_tmbspg | 12 | 193.577 | z1_rbs5.0e+07_abs1.0e+09_gas1_tmbspg11 | +| z2 | 4 | 154.002 | z2_gas1_tmbspg14 | +| z3 | 1 | 0 | z3_gas1_tmbspg14 | +| global | 24 | 193.577 | z1_rbs5.0e+07_abs1.0e+09_gas1_tmbspg11 | + +Tuning completed in 1:02:32.759424. Total number of experiments: 24. diff --git a/autotuning/hf/deberta/ds_config_fp16_tune.json b/autotuning/hf/deberta/ds_config_fp16_tune.json new file mode 100644 index 000000000..b405929bb --- /dev/null +++ b/autotuning/hf/deberta/ds_config_fp16_tune.json @@ -0,0 +1,16 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "fp16": { + "enabled": true, + "initial_scale_power": 12 + }, + "autotuning": { + "enabled": true, + "overwrite": false, + "fast": true, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} \ No newline at end of file diff --git a/autotuning/hf/deberta/test_tune.sh b/autotuning/hf/deberta/test_tune.sh new file mode 100755 index 000000000..d4de499ee --- /dev/null +++ b/autotuning/hf/deberta/test_tune.sh @@ -0,0 +1,127 @@ +MODEL_NAME=microsoft/deberta-v2-xxlarge +TASK_NAME=mnli +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --fp16 \ + --max_seq_length 256 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 3e-6 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/distilbert/README.md b/autotuning/hf/distilbert/README.md new file mode 100644 index 000000000..dce99207c --- /dev/null +++ b/autotuning/hf/distilbert/README.md @@ -0,0 +1,69 @@ +# [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased) + +This model has the following configuration: + +- 12-layer +- 768 hidden dimension +- 12 attention heads +- 66M parameters. + +## Environment + +The training uses 1 node with 16 Nvidia V100 GPUs, fp32, max_train_batch_size = 4096. The autotuning uses the same hardware resource as the training. `"max_train_batch_size"` is set to `4096`. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | ----------------------------- | ------------------------------------ | ------------------------------ | +| DistilBERT | 5161.902 (gas = 1, mbs = 256) | 5305.067 (z = 0, gas = 1 mbs = 256), | 5305.067 (z0_gas1_tmbspg256) | + +3700.296 + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 11 mins +- Number of experiments: 11 +- Throughput Improvement: 1.03x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :---------------- | +| z0 | 5 | 5759.96 | z0_gas1_tmbspg256 | +| z1 | 2 | 5667.06 | z1_gas1_tmbspg256 | +| z2 | 2 | 5366.97 | z2_gas1_tmbspg256 | +| z3 | 2 | 4892.49 | z3_gas1_tmbspg256 | +| global | 11 | 5759.96 | z0_gas1_tmbspg256 | + +Tuning completed in 0:10:45.085016. Total number of experiments: 11. + + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :----------------- | +| z0 | 7 | 5759.98 | z0_gas22_tmbspg179 | +| z1 | 2 | 5543.49 | z1_gas1_tmbspg269 | +| z2 | 2 | 5044.88 | z2_gas15_tmbspg269 | +| z3 | 2 | 4627.63 | z3_gas1_tmbspg269 | +| global | 13 | 5759.98 | z0_gas22_tmbspg179 | + +Tuning completed in 0:25:44.502148. Total number of experiments: 13. diff --git a/autotuning/hf/distilbert/ds_config_tune.json b/autotuning/hf/distilbert/ds_config_tune.json new file mode 100644 index 000000000..23a48ddf9 --- /dev/null +++ b/autotuning/hf/distilbert/ds_config_tune.json @@ -0,0 +1,12 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "max_train_batch_size": 4096, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/autotuning/hf/distilbert/test_tune.sh b/autotuning/hf/distilbert/test_tune.sh new file mode 100755 index 000000000..08b92d56e --- /dev/null +++ b/autotuning/hf/distilbert/test_tune.sh @@ -0,0 +1,119 @@ +TASK_NAME=mnli +MODEL_NAME=distilbert-base-uncased +HF_PATH=~/projects +PER_DEVICE_TRAIN_BATCH_SIZE=64 +MAX_TRAIN_BATCH_SIZE=4096 +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./${TASK_NAME}/output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z0.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z1.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z2.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ../dsconfigs/ds_config_z3.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py --deepspeed ./ds_config_tune.json \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path $MODEL_NAME \ + --task_name $TASK_NAME \ + --do_train \ + --max_seq_length 128 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/dsconfigs/ds_config_fp16_tune.json b/autotuning/hf/dsconfigs/ds_config_fp16_tune.json new file mode 100644 index 000000000..7ae31168b --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_fp16_tune.json @@ -0,0 +1,15 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "fp16": { + "enabled": true + }, + "autotuning": { + "enabled": true, + "overwrite": false, + "fast": true, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_fp16_z0.json b/autotuning/hf/dsconfigs/ds_config_fp16_z0.json new file mode 100644 index 000000000..ff375bb3e --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_fp16_z0.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": true + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_fp16_z1.json b/autotuning/hf/dsconfigs/ds_config_fp16_z1.json new file mode 100644 index 000000000..209706d24 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_fp16_z1.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 1 + }, + "fp16": { + "enabled": true + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_fp16_z2.json b/autotuning/hf/dsconfigs/ds_config_fp16_z2.json new file mode 100644 index 000000000..d3782ab14 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_fp16_z2.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 2 + }, + "fp16": { + "enabled": true + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_fp16_z3.json b/autotuning/hf/dsconfigs/ds_config_fp16_z3.json new file mode 100644 index 000000000..d0affd293 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_fp16_z3.json @@ -0,0 +1,9 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 3 + }, + "fp16": { + "enabled": true + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_tune.json b/autotuning/hf/dsconfigs/ds_config_tune.json new file mode 100644 index 000000000..413e19630 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_tune.json @@ -0,0 +1,12 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "autotuning": { + "enabled": true, + "overwrite": false, + "fast": true, + "arg_mappings": { + "train_micro_batch_size_per_gpu": "--per_device_train_batch_size", + "gradient_accumulation_steps ": "--gradient_accumulation_steps" + } + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_z0.json b/autotuning/hf/dsconfigs/ds_config_z0.json new file mode 100644 index 000000000..6247e56c4 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_z0.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 0 + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_z1.json b/autotuning/hf/dsconfigs/ds_config_z1.json new file mode 100644 index 000000000..fd39970a4 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_z1.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 1 + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_z2.json b/autotuning/hf/dsconfigs/ds_config_z2.json new file mode 100644 index 000000000..b898aee82 --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_z2.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 2 + } +} diff --git a/autotuning/hf/dsconfigs/ds_config_z3.json b/autotuning/hf/dsconfigs/ds_config_z3.json new file mode 100644 index 000000000..5b118864e --- /dev/null +++ b/autotuning/hf/dsconfigs/ds_config_z3.json @@ -0,0 +1,6 @@ +{ + "train_micro_batch_size_per_gpu": "auto", + "zero_optimization": { + "stage": 3 + } +} diff --git a/autotuning/hf/gpt2-large/README.md b/autotuning/hf/gpt2-large/README.md new file mode 100644 index 000000000..a736db485 --- /dev/null +++ b/autotuning/hf/gpt2-large/README.md @@ -0,0 +1,59 @@ +# [gpt2-large](https://huggingface.co/gpt2-large) + +This model has the following configuration: + +- 36-layer +- 1280 hidden dimension +- 20 attention heads +- 774M parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0)datasets (1.11.0) + +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | ------------------------ | ------------------------------ | +| GPT2-large | 27.874 (mbs = 1) | 56.797 (z = 1, mbs = 2), | 69.061 (z = 1, mbs = 3) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 27 mins +- Number of experiments: 13 +- Throughput Improvement over baseline: 2.48x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :-------------- | +| z0 | 4 | 59.0229 | z0_gas1_tmbspg2 | +| z1 | 5 | 87.3017 | z1_gas1_tmbspg3 | +| z2 | 3 | 77.8338 | z2_gas1_tmbspg3 | +| z3 | 1 | 0 | z3_gas1_tmbspg3 | +| global | 13 | 87.3017 | z1_gas1_tmbspg3 | + +Tuning completed in 0:27:33.988447. Total number of experiments: 13. diff --git a/autotuning/hf/gpt2-large/test_tune.sh b/autotuning/hf/gpt2-large/test_tune.sh new file mode 100755 index 000000000..c5fa9b608 --- /dev/null +++ b/autotuning/hf/gpt2-large/test_tune.sh @@ -0,0 +1,132 @@ +MODEL_NAME=gpt2-large +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/gpt2-medium/README.md b/autotuning/hf/gpt2-medium/README.md new file mode 100644 index 000000000..e97a1f9b3 --- /dev/null +++ b/autotuning/hf/gpt2-medium/README.md @@ -0,0 +1,57 @@ +# [gpt2-medium](https://huggingface.co/gpt2-medium) + +This model has the following configuration: +- 24-layer +- 1024 hidden dimension +- 16 attention heads +- 345M parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ----------- | ------------------------ | --------------------------------- | ------------------------------ | +| GPT2-medium | 71.61 (gas = 1, mbs = 2) | 142.211 (z = 1, gas = 1, mbs = 4) | 163.3 (z1_gas1_tmbspg6) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 25 mins +- Number of experiments: 15 +- Throughput Improvement over baseline: 2.28x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :-------------- | +| z0 | 6 | 167.688 | z0_gas1_tmbspg5 | +| z1 | 5 | 175.46 | z1_gas1_tmbspg6 | +| z2 | 3 | 161.619 | z2_gas1_tmbspg6 | +| z3 | 1 | 0 | z3_gas1_tmbspg6 | +| global | 15 | 175.46 | z1_gas1_tmbspg6 | + +Tuning completed in 0:25:18.653731. Total number of experiments: 15. diff --git a/autotuning/hf/gpt2-medium/test_tune.sh b/autotuning/hf/gpt2-medium/test_tune.sh new file mode 100755 index 000000000..567deb4ff --- /dev/null +++ b/autotuning/hf/gpt2-medium/test_tune.sh @@ -0,0 +1,142 @@ +MODEL_NAME=gpt2-medium +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --block_size 512 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune_test" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune_test.json \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune_test \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/gpt2-xl/README.md b/autotuning/hf/gpt2-xl/README.md new file mode 100644 index 000000000..f6d81b264 --- /dev/null +++ b/autotuning/hf/gpt2-xl/README.md @@ -0,0 +1,56 @@ +# [gpt2-xl](https://huggingface.co/gpt2-xl) + +This model has the following configuration: +- 48-layer +- 1600 hidden dimension +- 25 attention heads +- 1.5B parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | --------------------------------- | -------------------------------- | +| GPT2-xl | Not runnable | Zero1 (27.462, gas = 1, mbs = 1), | Zero1 (27.497, gas = 1, mbs = 1) | + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 21 mins +- Number of experiments: 9 +- Throughput Improvement over baseline: Inf + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :-------------- | +| z1 | 3 | 40.1749 | z1_gas1_tmbspg1 | +| z2 | 3 | 33.0472 | z2_gas1_tmbspg1 | +| z3 | 3 | 12.8604 | z3_gas1_tmbspg1 | +| global | 9 | 40.1749 | z1_gas1_tmbspg1 | + +Tuning completed in 0:20:55.156000. Total number of experiments: 9. diff --git a/autotuning/hf/gpt2-xl/test_tune.sh b/autotuning/hf/gpt2-xl/test_tune.sh new file mode 100755 index 000000000..3c144635e --- /dev/null +++ b/autotuning/hf/gpt2-xl/test_tune.sh @@ -0,0 +1,142 @@ +MODEL_NAME=gpt2-xl +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=50 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --block_size 512 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune_test" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune_test.json \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune_test \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi diff --git a/autotuning/hf/gpt2/README.md b/autotuning/hf/gpt2/README.md new file mode 100644 index 000000000..bb426910c --- /dev/null +++ b/autotuning/hf/gpt2/README.md @@ -0,0 +1,59 @@ +# [gpt2](https://huggingface.co/gpt2) + +This model has the following configuration: + +- 12-layer +- 768 hidden dimension +- 12 attention heads +- 117M parameters. + +Refer to [GPT-2/GPT and causal language modeling](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling) + +## Environment + +The training use fp16 and runs on 1 node with 16 Nvidia V100 GPUs. The autotuning uses the same hardware resource as the training. `max_train_batch_size` is not defined. +The HF packages below are used. + +HF examples require installing the `transformers` package from source: +```bash + git clone https://github.com/huggingface/transformers.git + cd transformers + pip install . +``` +The `datasets` package can be installed by `pip install datasets` + +Below are the versions used in this test. + +- transformers (4.12.0) +- datasets (1.11.0) +## Throughput Comparison + +The table below shows the throughput (samples per second) comparison. The corresponding train micro-batch size per GPU (mbs or tmbspg) and ZeRO stage used to achieve the throughput value is also shown in the parentheses. Assume the strategy users would use in the handtuning process is to start from `mbs = 1` and increase mbs by 2 each time until running out of GPU memory. + - `baseline` is the vanila HF without DeepSpeed (DS) and mbs is hand-tuned. + - `HF + DS hand-tuned` is HF with DS, and mbs is hand-tuned while other DS configuration uses default values. + - `HF + DS autotuning` is HF with DS, and the DS configuration is selected from autotuning. + +Notation: Hugging Face (HF), DeepSpeed (DS), ZeRO stage (z), gradient accumulation steps (gas), train micro-batch size per GPU (mbs or tmbspg). + +| Model name | baseline (vanila HF) | HF + DS hand-tuned | HF + DS autotuning (fast-mode) | +| ---------- | -------------------- | ------------------------ | ------------------------------ | +| GPT2 | 284.142 (mbs = 8) | 397.827 (z = 1, mbs = 8) | 431.586 (z1_gas1_tmbspg15) | + + +## Detailed `HF + DS autotuning` Result Summary + +Note that the performance metric used in autotuning is calculated using the timings captured within DeepSpeed forward, backward, and step functions. The sum of these timings is less than the actual training step latency, thus the throughput metric values used by autotuning would be higher than the end-to-end throughput in training. + +- Fast-mode Autotuning time: 25 mins +- Number of experiments: 17 +- Throughput Improvement over baseline: 1.52x + +| tuning_space | num_experiments | best_metric_val | best_exp_name | +| :----------- | --------------: | --------------: | :--------------- | +| z0 | 9 | 441.693 | z0_gas1_tmbspg11 | +| z1 | 6 | 452.004 | z1_gas1_tmbspg15 | +| z2 | 1 | 0 | z2_gas1_tmbspg15 | +| z3 | 1 | 0 | z3_gas1_tmbspg15 | +| global | 17 | 452.004 | z1_gas1_tmbspg15 | + +Tuning completed in 0:24:19.976427. Total number of experiments: 17. diff --git a/autotuning/hf/gpt2/test_tune.sh b/autotuning/hf/gpt2/test_tune.sh new file mode 100755 index 000000000..b570c455c --- /dev/null +++ b/autotuning/hf/gpt2/test_tune.sh @@ -0,0 +1,133 @@ +MODEL_NAME=gpt2 +PER_DEVICE_TRAIN_BATCH_SIZE=1 +HF_PATH=~/projects +NEPOCHS=1 +NGPUS=16 +NNODES=1 +MAX_STEPS=200 +OUTPUT_DIR=./output_b${PER_DEVICE_TRAIN_BATCH_SIZE}_g${NGPUS}_$MAX_STEPS + +TEST=$1 + + +if [ ${TEST} == "0" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z0" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z0.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z0 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z1" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z1.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z1 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z2" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z2.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z2 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "z3" ] +then + deepspeed --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_z3.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_z3 \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "tune" ] +then + deepspeed --autotuning run --num_nodes=$NNODES --num_gpus=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py --deepspeed ../dsconfigs/ds_config_fp16_tune.json\ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_tune \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" +elif [ ${TEST} == "fs" ] +then + python -m torch.distributed.launch --nproc_per_node=$NGPUS $HF_PATH/transformers/examples/pytorch/language-modeling/run_clm.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name wikitext \ + --dataset_config_name wikitext-2-raw-v1 \ + --do_train \ + --do_eval \ + --fp16 \ + --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \ + --learning_rate 2e-5 \ + --num_train_epochs $NEPOCHS \ + --output_dir ${OUTPUT_DIR}_fs \ + --overwrite_output_dir \ + --save_steps 0 \ + --max_steps $MAX_STEPS \ + --save_strategy "no" + --sharded_ddp zero_dp_2 +fi