From cbd3d7c8d3d206ad4f22e5c86d0c4fb1a7e2686b Mon Sep 17 00:00:00 2001 From: bapatra Date: Mon, 18 Oct 2021 03:57:05 -0700 Subject: [PATCH 01/18] tutorial work in progress --- Intro101/README.md | 0 Intro101/requirements.txt | 6 + Intro101/train_bert.py | 428 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 434 insertions(+) create mode 100644 Intro101/README.md create mode 100644 Intro101/requirements.txt create mode 100644 Intro101/train_bert.py diff --git a/Intro101/README.md b/Intro101/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/Intro101/requirements.txt b/Intro101/requirements.txt new file mode 100644 index 000000000..66ec19547 --- /dev/null +++ b/Intro101/requirements.txt @@ -0,0 +1,6 @@ +datasets==1.13.3 +transformers==4.5.1 +fire==0.4.0 +pytz==2021.1 +loguru==0.5.3 +sh==1.14.2 \ No newline at end of file diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py new file mode 100644 index 000000000..dc22e9ac4 --- /dev/null +++ b/Intro101/train_bert.py @@ -0,0 +1,428 @@ +from typing import Iterable, Dict, Any, Callable, Tuple, List, Union, Optional +import fire +import pathlib +import uuid +import datetime +import pytz +import json +import numpy as np +from functools import partial +import loguru +import sh + +import torch +import torch.nn as nn +from torch.utils.data import Dataset, DataLoader +from torch.utils.tensorboard import SummaryWriter + +import datasets +from transformers import AutoTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.models.roberta import RobertaConfig, RobertaModel +from transformers.models.roberta.modeling_roberta import RobertaLMHead, RobertaPreTrainedModel + +logger = loguru.logger + +###################################################################### +############### Dataset Creation Related Functions ################### +###################################################################### + + +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + +def collate_function(batch: List[Tuple[List[int], List[int]]], pad_token_id: int) -> Dict[str, torch.Tensor]: + max_length = max( + len(token_ids) + for token_ids, _ in batch + ) + padded_token_ids = [ + token_ids + [pad_token_id for _ in range(0, max_length - len(token_ids))] + for token_ids, _ in batch + ] + padded_labels = [ + labels + [pad_token_id for _ in range(0, max_length - len(labels))] + for _, labels in batch + ] + src_tokens = torch.LongTensor(padded_token_ids) + tgt_tokens = torch.LongTensor(padded_labels) + attention_mask = src_tokens.ne(pad_token_id).type_as(src_tokens) + return { + "src_tokens": src_tokens, + "tgt_tokens": tgt_tokens, + "attention_mask": attention_mask + } + +def masking_function(text: str, + tokenizer: TokenizerType, + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + max_length: int) -> Tuple[List[int], List[int]]: + # Note: By default, encode does add the BOS and EOS token + # Disabling that behaviour to make this more clear + tokenized_ids = [tokenizer.bos_token_id] + \ + tokenizer.encode(text, + add_special_tokens=False, + truncation=True, + max_length=max_length - 2) + \ + [tokenizer.eos_token_id] + seq_len = len(tokenized_ids) + tokenized_ids = np.array(tokenized_ids) + subword_mask = np.full(len(tokenized_ids), False) + + # Masking the BOS and EOS token leads to slightly worse performance + low = 1 + high = len(subword_mask) - 1 + mask_choices = np.arange(low, high) + num_subwords_to_mask = max(int((mask_prob * (high - low)) + np.random.rand()), 1) + subword_mask[np.random.choice(mask_choices, num_subwords_to_mask, replace=False)] = True + + # Create the labels first + labels = np.full(seq_len, tokenizer.pad_token_id) + labels[subword_mask] = tokenized_ids[subword_mask] + + tokenized_ids[subword_mask] = tokenizer.mask_token_id + + # Now of the masked tokens, choose how many to replace with random and how many to unmask + rand_or_unmask_prob = random_replace_prob + unmask_replace_prob + if rand_or_unmask_prob > 0: + rand_or_unmask = subword_mask & (np.random.rand(len(tokenized_ids)) < rand_or_unmask_prob) + if random_replace_prob == 0: + unmask = rand_or_unmask + rand_mask = None + elif unmask_replace_prob == 0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = unmask_replace_prob / rand_or_unmask_prob + decision = np.random.rand(len(tokenized_ids)) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + if unmask is not None: + tokenized_ids[unmask] = labels[unmask] + if rand_mask is not None: + weights = np.ones(tokenizer.vocab_size) + weights[tokenizer.all_special_ids] = 0 + probs = weights / weights.sum() + num_rand = rand_mask.sum() + tokenized_ids[rand_mask] = np.random.choice( + tokenizer.vocab_size, + num_rand, + p=probs + ) + return tokenized_ids.tolist(), labels.tolist() + +class WikiTextMLMDataset(Dataset): + def __init__(self, + dataset: datasets.arrow_dataset.Dataset, + masking_function: Callable[[str], Tuple[List[int], List[int]]]) -> None: + self.dataset = dataset + self.masking_function = masking_function + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]: + tokens, labels = self.masking_function(self.dataset[idx]["text"]) + return (tokens, labels) + +def create_data_iterator(mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + batch_size: int, + max_seq_length: int = 512, + tokenizer: str = "roberta-base") -> DataLoader: + wikitext_dataset = datasets.load_dataset( + "wikitext", + "wikitext-2-v1", + split="train" + ) + wikitext_dataset = wikitext_dataset.filter( + lambda record: record["text"] != "" + ).map( + lambda record: {"text": record["text"].rstrip("\n")} + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + masking_function_partial = partial( + masking_function, + tokenizer=tokenizer, + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + max_length=max_seq_length + ) + dataset = WikiTextMLMDataset(wikitext_dataset, masking_function_partial) + collate_fn_partial = partial( + collate_function, + pad_token_id=tokenizer.pad_token_id + ) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn_partial + ) + return dataloader + + +###################################################################### +############### Model Creation Related Functions ##################### +###################################################################### + +class RobertaLMHeadWithMaskedPredict(RobertaLMHead): + def __init__(self, + config, + embedding_weight: Optional[torch.Tensor] = None) -> None: + super(RobertaLMHeadWithMaskedPredict, self).__init__(config) + if embedding_weight is not None: + self.decoder.weight = embedding_weight + + def forward( # pylint: disable=arguments-differ + self, + features: torch.Tensor, + masked_token_indices: Optional[ + torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """The current ``Transformers'' library does not provide support + for masked_token_indices. This function provides the support + + Args: + masked_token_indices (torch.Tensor, optional): + The indices of masked tokens for index select. Defaults to None. + + Returns: + torch.Tensor: The output logits + + """ + if masked_token_indices is not None: + features = torch.index_select( + features.view(-1, features.shape[-1]), 0, masked_token_indices + ) + return super().forward(features) + +class RobertaMLMModel(RobertaPreTrainedModel): + def __init__(self, + config: RobertaConfig, + encoder: RobertaModel) -> None: + super().__init__(config) + self.encoder = encoder + self.lm_head = RobertaLMHeadWithMaskedPredict( + config, self.encoder.embeddings.word_embeddings.weight + ) + self.lm_head.apply(self._init_weights) + + def forward(self, + src_tokens, + attention_mask, + tgt_tokens) -> torch.Tensor: + sequence_output, *_ = self.encoder(input_ids=src_tokens, + attention_mask=attention_mask, return_dict=False) + + pad_token_id = self.config.pad_token_id + # (labels have also been padded with pad_token_id) + # filter out all masked labels + masked_token_indexes = torch.nonzero( + (tgt_tokens != pad_token_id).view(-1)).view(-1) + + prediction_scores = self.lm_head(sequence_output, + masked_token_indexes) + + target = torch.index_select(tgt_tokens.view(-1), 0, + masked_token_indexes) + + loss_fct = nn.CrossEntropyLoss(ignore_index=-1) + + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), target) + return masked_lm_loss + +def create_model(num_layers: int, + num_heads: int, + ff_dim: int, + h_dim: int, + dropout: float) -> RobertaModel: + roberta_config_dict = { + "attention_probs_dropout_prob": dropout, + "bos_token_id": 0, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout_prob": dropout, + "hidden_size": h_dim, + "initializer_range": 0.02, + "intermediate_size": ff_dim, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 514, + "model_type": "roberta", + "num_attention_heads": num_heads, + "num_hidden_layers": num_layers, + "pad_token_id": 1, + "type_vocab_size": 1, + "vocab_size": 50265 + } + roberta_config = RobertaConfig.from_dict(roberta_config_dict) + roberta_encoder = RobertaModel(roberta_config) + roberta_model = RobertaMLMModel(roberta_config, roberta_encoder) + return roberta_model + + +###################################################################### +########### Experiment Management Related Functions ################## +###################################################################### + +def create_experiment_dir(checkpoint_dir: pathlib.Path, + all_arguments: Dict[str, Any]): + # experiment name follows the following convention + # {exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} + current_time = datetime.datetime.now(pytz.timezone("US/Pacific")) + expname = "bert_pretrain.{0}.{1}.{2}.{3}.{4}.{5}.{6}".format( + current_time.year, + current_time.month, + current_time.day, + current_time.hour, + current_time.minute, + current_time.second, + str(uuid.uuid4()) + ) + exp_dir = (checkpoint_dir / expname) + exp_dir.mkdir(exist_ok=False) + hparams_file = exp_dir / "hparams.json" + with hparams_file.open("w") as handle: + json.dump(obj=all_arguments, fp=handle, indent=2) + # Save the git hash + gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) + with (exp_dir / "githash.log").open("w") as handle: + handle.write(gitlog.stdout.decode("utf-8")) + # And the git diff + gitdiff = sh.git.diff(_fg=False, _tty_out=False) + with (exp_dir / "gitdiff.log").open("w") as handle: + handle.write(gitdiff.stdout.decode("utf-8")) + # Finally create the Tensorboard Dir + tb_dir = exp_dir / "tb_dir" + tb_dir.mkdir() + return exp_dir + +###################################################################### +####################### Driver Functions ############################# +###################################################################### + +def train( + checkpoint_dir: str, + # Dataset Parameters + mask_prob: float = 0.15, + random_replace_prob: float = 0.1, + unmask_replace_prob: float = 0.1, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", + # Model Parameters + num_layers: int = 6, + num_heads: int = 8, + ff_dim: int = 512, + h_dim: int = 256, + dropout: float = 0.1, + # Training Parameters + batch_size: int = 8, + num_iterations: int = 10000, + checkpoint_every: int = 1000, + log_every: int = 10, + device: int = -1 +) -> None: + device = torch.device("cuda", device) \ + if (device > 0) and torch.cuda.is_available() \ + else torch.device("cpu") + ################################ + ###### Create Datasets ######### + ################################ + logger.info("Creating Datasets") + data_iterator = create_data_iterator( + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + batch_size=batch_size + ) + logger.info("Dataset Creation Done") + ################################ + ###### Create Model ############ + ################################ + logger.info("Creating Model") + model = create_model( + num_layers=num_layers, + num_heads=num_heads, + ff_dim=ff_dim, + h_dim=h_dim, + dropout=dropout + ) + model = model.to(device) + logger.info("Model Creation Done") + ################################ + ###### Create Exp. Dir ######### + ################################ + logger.info("Creating Experiment Directory") + checkpoint_dir = pathlib.Path(checkpoint_dir) + checkpoint_dir.mkdir(exist_ok=True) + all_arguments = { + # Dataset Params + "mask_prob": mask_prob, + "random_replace_prob": random_replace_prob, + "unmask_replace_prob": unmask_replace_prob, + "max_seq_length": max_seq_length, + "tokenizer": tokenizer, + # Model Params + "num_layers": num_layers, + "num_heads": num_heads, + "ff_dim": ff_dim, + "h_dim": h_dim, + "dropout": dropout, + # Training Params + "batch_size": batch_size, + "num_iterations": num_iterations, + "checkpoint_every": checkpoint_every, + } + exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) + tb_dir = exp_dir / "tb_dir" + assert tb_dir.exists() + summary_writer = SummaryWriter(log_dir=tb_dir) + logger.info(f"Experiment Directory created at {exp_dir}") + ################################ + ###### Create Optimizer ####### + ################################ + logger.info("Creating Optimizer") + optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) + logger.info("Optimizer Creation Done") + ################################ + ####### The Training Loop ###### + ################################ + losses = [] + for step, batch in enumerate(data_iterator, start=1): + optimizer.zero_grad() + # Move the tensors to device + for key, value in batch.items(): + batch[key] = value.to(device) + # Forward pass + loss = model(**batch) + # Backward pass + loss.backward() + # Optimizer Step + optimizer.step() + losses.append(loss.item()) + if step % log_every == 0: + logger.info("Loss: {0:.4f}".format(np.mean(losses))) + summary_writer.add_scalar(f"Train/loss", np.mean(losses), step) + if step % checkpoint_every == 0: + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict() + } + torch.save(obj=state_dict, f=str(exp_dir / f"checkpoint.iter_{step}.pt")) + if step == num_iterations: + break + + + +if __name__ == "__main__": + fire.Fire({ + "train": train, + "data": create_data_iterator, + "model": create_model + }) \ No newline at end of file From 349e944300a8071a7b222ace73bea0a6d5b811e3 Mon Sep 17 00:00:00 2001 From: bapatra Date: Mon, 18 Oct 2021 21:16:24 -0700 Subject: [PATCH 02/18] adding some error catches --- Intro101/train_bert.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index dc22e9ac4..bd85fb076 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -30,7 +30,13 @@ TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] -def collate_function(batch: List[Tuple[List[int], List[int]]], pad_token_id: int) -> Dict[str, torch.Tensor]: +def collate_function( + batch: List[Tuple[List[int], List[int]]], + pad_token_id: int +) -> Dict[str, torch.Tensor]: + """Collect a list of masked token indices, and labels, and + batch them, padding to max length in the batch. + """ max_length = max( len(token_ids) for token_ids, _ in batch @@ -289,13 +295,28 @@ def create_experiment_dir(checkpoint_dir: pathlib.Path, with hparams_file.open("w") as handle: json.dump(obj=all_arguments, fp=handle, indent=2) # Save the git hash - gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) - with (exp_dir / "githash.log").open("w") as handle: - handle.write(gitlog.stdout.decode("utf-8")) + try: + + gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) + with (exp_dir / "githash.log").open("w") as handle: + handle.write(gitlog.stdout.decode("utf-8")) + except sh.ErrorReturnCode_128: + logger.info("Seems like the code is not running from" + " within a git repo, so hash will" + " not be stored. However, it" + " is strongly advised to use" + " version control.") # And the git diff - gitdiff = sh.git.diff(_fg=False, _tty_out=False) - with (exp_dir / "gitdiff.log").open("w") as handle: - handle.write(gitdiff.stdout.decode("utf-8")) + try: + gitdiff = sh.git.diff(_fg=False, _tty_out=False) + with (exp_dir / "gitdiff.log").open("w") as handle: + handle.write(gitdiff.stdout.decode("utf-8")) + except sh.ErrorReturnCode_129: + logger.info("Seems like the code is not running from" + " within a git repo, so diff will" + " not be stored. However, it" + " is strongly advised to use" + " version control.") # Finally create the Tensorboard Dir tb_dir = exp_dir / "tb_dir" tb_dir.mkdir() @@ -327,7 +348,7 @@ def train( device: int = -1 ) -> None: device = torch.device("cuda", device) \ - if (device > 0) and torch.cuda.is_available() \ + if (device > -1) and torch.cuda.is_available() \ else torch.device("cpu") ################################ ###### Create Datasets ######### @@ -393,6 +414,7 @@ def train( ################################ ####### The Training Loop ###### ################################ + model.train() losses = [] for step, batch in enumerate(data_iterator, start=1): optimizer.zero_grad() From 150187649990dc7fd30d3c126aea72add6300d57 Mon Sep 17 00:00:00 2001 From: bapatra Date: Mon, 18 Oct 2021 22:24:03 -0700 Subject: [PATCH 03/18] docstrings + black + isort --- Intro101/train_bert.py | 451 +++++++++++++++++++++++++++++------------ 1 file changed, 320 insertions(+), 131 deletions(-) diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index bd85fb076..8eeec9dd1 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -1,25 +1,26 @@ -from typing import Iterable, Dict, Any, Callable, Tuple, List, Union, Optional -import fire -import pathlib -import uuid import datetime -import pytz import json -import numpy as np +import pathlib +import uuid from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import datasets +import fire import loguru +import numpy as np +import pytz import sh - import torch import torch.nn as nn -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset from torch.utils.tensorboard import SummaryWriter - -import datasets -from transformers import AutoTokenizer -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.models.roberta import RobertaConfig, RobertaModel -from transformers.models.roberta.modeling_roberta import RobertaLMHead, RobertaPreTrainedModel +from transformers.models.roberta.modeling_roberta import ( + RobertaLMHead, + RobertaPreTrainedModel, +) logger = loguru.logger @@ -30,17 +31,14 @@ TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + def collate_function( - batch: List[Tuple[List[int], List[int]]], - pad_token_id: int + batch: List[Tuple[List[int], List[int]]], pad_token_id: int ) -> Dict[str, torch.Tensor]: """Collect a list of masked token indices, and labels, and batch them, padding to max length in the batch. """ - max_length = max( - len(token_ids) - for token_ids, _ in batch - ) + max_length = max(len(token_ids) for token_ids, _ in batch) padded_token_ids = [ token_ids + [pad_token_id for _ in range(0, max_length - len(token_ids))] for token_ids, _ in batch @@ -55,23 +53,53 @@ def collate_function( return { "src_tokens": src_tokens, "tgt_tokens": tgt_tokens, - "attention_mask": attention_mask + "attention_mask": attention_mask, } -def masking_function(text: str, - tokenizer: TokenizerType, - mask_prob: float, - random_replace_prob: float, - unmask_replace_prob: float, - max_length: int) -> Tuple[List[int], List[int]]: + +def masking_function( + text: str, + tokenizer: TokenizerType, + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + max_length: int, +) -> Tuple[List[int], List[int]]: + """Given a text string, randomly mask wordpieces for Bert MLM + training. + + Args: + text (str): + The input text + tokenizer (TokenizerType): + The tokenizer for tokenization + mask_prob (float): + What fraction of tokens to mask + random_replace_prob (float): + Of the masked tokens, how many should be replaced with + random tokens (improves performance) + unmask_replace_prob (float): + Of the masked tokens, how many should be replaced with + the original token (improves performance) + max_length (int): + The maximum sequence length to consider. Note that for + Bert style models, this is a function of the number of + positional embeddings you learn + + Returns: + Tuple[List[int], List[int]]: + The masked token ids (based on the tokenizer passed), + and the output labels (padded with `tokenizer.pad_token_id`) + """ # Note: By default, encode does add the BOS and EOS token # Disabling that behaviour to make this more clear - tokenized_ids = [tokenizer.bos_token_id] + \ - tokenizer.encode(text, - add_special_tokens=False, - truncation=True, - max_length=max_length - 2) + \ - [tokenizer.eos_token_id] + tokenized_ids = ( + [tokenizer.bos_token_id] + + tokenizer.encode( + text, add_special_tokens=False, truncation=True, max_length=max_length - 2 + ) + + [tokenizer.eos_token_id] + ) seq_len = len(tokenized_ids) tokenized_ids = np.array(tokenized_ids) subword_mask = np.full(len(tokenized_ids), False) @@ -81,7 +109,9 @@ def masking_function(text: str, high = len(subword_mask) - 1 mask_choices = np.arange(low, high) num_subwords_to_mask = max(int((mask_prob * (high - low)) + np.random.rand()), 1) - subword_mask[np.random.choice(mask_choices, num_subwords_to_mask, replace=False)] = True + subword_mask[ + np.random.choice(mask_choices, num_subwords_to_mask, replace=False) + ] = True # Create the labels first labels = np.full(seq_len, tokenizer.pad_token_id) @@ -92,7 +122,9 @@ def masking_function(text: str, # Now of the masked tokens, choose how many to replace with random and how many to unmask rand_or_unmask_prob = random_replace_prob + unmask_replace_prob if rand_or_unmask_prob > 0: - rand_or_unmask = subword_mask & (np.random.rand(len(tokenized_ids)) < rand_or_unmask_prob) + rand_or_unmask = subword_mask & ( + np.random.rand(len(tokenized_ids)) < rand_or_unmask_prob + ) if random_replace_prob == 0: unmask = rand_or_unmask rand_mask = None @@ -112,40 +144,76 @@ def masking_function(text: str, probs = weights / weights.sum() num_rand = rand_mask.sum() tokenized_ids[rand_mask] = np.random.choice( - tokenizer.vocab_size, - num_rand, - p=probs + tokenizer.vocab_size, num_rand, p=probs ) return tokenized_ids.tolist(), labels.tolist() + class WikiTextMLMDataset(Dataset): - def __init__(self, - dataset: datasets.arrow_dataset.Dataset, - masking_function: Callable[[str], Tuple[List[int], List[int]]]) -> None: + """A [Map style dataset](https://pytorch.org/docs/stable/data.html) + for iterating over the wikitext dataset. Note that this assumes + the dataset can fit in memory. For larger datasets + you'd want to shard them and use an iterable dataset (eg: see + [Infinibatch](https://github.com/microsoft/infinibatch)) + + Args: + Dataset (datasets.arrow_dataset.Dataset): + The wikitext dataset + masking_function (Callable[[str], Tuple[List[int], List[int]]]) + The masking function. To generate one training instance, + the masking function is applied to the `text` of a dataset + record + + """ + + def __init__( + self, + dataset: datasets.arrow_dataset.Dataset, + masking_function: Callable[[str], Tuple[List[int], List[int]]], + ) -> None: self.dataset = dataset self.masking_function = masking_function - + def __len__(self) -> int: return len(self.dataset) - + def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]: tokens, labels = self.masking_function(self.dataset[idx]["text"]) return (tokens, labels) -def create_data_iterator(mask_prob: float, - random_replace_prob: float, - unmask_replace_prob: float, - batch_size: int, - max_seq_length: int = 512, - tokenizer: str = "roberta-base") -> DataLoader: - wikitext_dataset = datasets.load_dataset( - "wikitext", - "wikitext-2-v1", - split="train" - ) - wikitext_dataset = wikitext_dataset.filter( - lambda record: record["text"] != "" - ).map( + +def create_data_iterator( + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + batch_size: int, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", +) -> DataLoader: + """Create the dataloader. + + Args: + mask_prob (float): + Fraction of tokens to mask + random_replace_prob (float): + Fraction of masked tokens to replace with random token + unmask_replace_prob (float): + Fraction of masked tokens to replace with the actual token + batch_size (int): + The batch size of the generated tensors + max_seq_length (int, optional): + The maximum sequence length for the MLM task. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + + Returns: + DataLoader: + The torch DataLoader. Note that the dataloader is iterable, + but is not an iterator. + + """ + wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-v1", split="train") + wikitext_dataset = wikitext_dataset.filter(lambda record: record["text"] != "").map( lambda record: {"text": record["text"].rstrip("\n")} ) tokenizer = AutoTokenizer.from_pretrained(tokenizer) @@ -155,19 +223,14 @@ def create_data_iterator(mask_prob: float, mask_prob=mask_prob, random_replace_prob=random_replace_prob, unmask_replace_prob=unmask_replace_prob, - max_length=max_seq_length + max_length=max_seq_length, ) dataset = WikiTextMLMDataset(wikitext_dataset, masking_function_partial) - collate_fn_partial = partial( - collate_function, - pad_token_id=tokenizer.pad_token_id - ) + collate_fn_partial = partial(collate_function, pad_token_id=tokenizer.pad_token_id) dataloader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, - collate_fn=collate_fn_partial + dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_partial ) + return dataloader @@ -175,29 +238,36 @@ def create_data_iterator(mask_prob: float, ############### Model Creation Related Functions ##################### ###################################################################### + class RobertaLMHeadWithMaskedPredict(RobertaLMHead): - def __init__(self, - config, - embedding_weight: Optional[torch.Tensor] = None) -> None: + def __init__( + self, config: RobertaConfig, embedding_weight: Optional[torch.Tensor] = None + ) -> None: super(RobertaLMHeadWithMaskedPredict, self).__init__(config) if embedding_weight is not None: self.decoder.weight = embedding_weight def forward( # pylint: disable=arguments-differ - self, - features: torch.Tensor, - masked_token_indices: Optional[ - torch.Tensor] = None, - **kwargs) -> torch.Tensor: - """The current ``Transformers'' library does not provide support - for masked_token_indices. This function provides the support + self, + features: torch.Tensor, + masked_token_indices: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """The current `transformers` library does not provide support + for masked_token_indices. This function provides the support, by + running the final forward pass only for the masked indices. This saves + memory Args: + features (torch.Tensor): + The features to select from. Shape (batch, seq_len, h_dim) masked_token_indices (torch.Tensor, optional): - The indices of masked tokens for index select. Defaults to None. + The indices of masked tokens for index select. Defaults to None. + Shape: (num_masked_tokens,) Returns: - torch.Tensor: The output logits + torch.Tensor: + The index selected features. Shape (num_masked_tokens, h_dim) """ if masked_token_indices is not None: @@ -206,47 +276,91 @@ def forward( # pylint: disable=arguments-differ ) return super().forward(features) + class RobertaMLMModel(RobertaPreTrainedModel): - def __init__(self, - config: RobertaConfig, - encoder: RobertaModel) -> None: + def __init__(self, config: RobertaConfig, encoder: RobertaModel) -> None: super().__init__(config) self.encoder = encoder self.lm_head = RobertaLMHeadWithMaskedPredict( config, self.encoder.embeddings.word_embeddings.weight ) self.lm_head.apply(self._init_weights) - - def forward(self, - src_tokens, - attention_mask, - tgt_tokens) -> torch.Tensor: - sequence_output, *_ = self.encoder(input_ids=src_tokens, - attention_mask=attention_mask, return_dict=False) - + + def forward( + self, + src_tokens: torch.Tensor, + attention_mask: torch.Tensor, + tgt_tokens: torch.Tensor, + ) -> torch.Tensor: + """The forward pass for the MLM task + + Args: + src_tokens (torch.Tensor): + The masked token indices. Shape: (batch, seq_len) + attention_mask (torch.Tensor): + The attention mask, since the batches are padded + to the largest sequence. Shape: (batch, seq_len) + tgt_tokens (torch.Tensor): + The output tokens (padded with `config.pad_token_id`) + + Returns: + torch.Tensor: + The MLM loss + """ + # shape: (batch, seq_len, h_dim) + sequence_output, *_ = self.encoder( + input_ids=src_tokens, attention_mask=attention_mask, return_dict=False + ) + pad_token_id = self.config.pad_token_id # (labels have also been padded with pad_token_id) # filter out all masked labels + # shape: (num_masked_tokens,) masked_token_indexes = torch.nonzero( - (tgt_tokens != pad_token_id).view(-1)).view(-1) - - prediction_scores = self.lm_head(sequence_output, - masked_token_indexes) - - target = torch.index_select(tgt_tokens.view(-1), 0, - masked_token_indexes) + (tgt_tokens != pad_token_id).view(-1) + ).view(-1) + # shape: (num_masked_tokens, vocab_size) + prediction_scores = self.lm_head(sequence_output, masked_token_indexes) + # shape: (num_masked_tokens,) + target = torch.index_select(tgt_tokens.view(-1), 0, masked_token_indexes) loss_fct = nn.CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), target) + prediction_scores.view(-1, self.config.vocab_size), target + ) return masked_lm_loss -def create_model(num_layers: int, - num_heads: int, - ff_dim: int, - h_dim: int, - dropout: float) -> RobertaModel: + +def create_model( + num_layers: int, num_heads: int, ff_dim: int, h_dim: int, dropout: float +) -> RobertaMLMModel: + """Create a Bert model with the specified `num_heads`, `ff_dim`, + `h_dim` and `dropout` + + Args: + num_layers (int): + The number of layers + num_heads (int): + The number of attention heads + ff_dim (int): + The intermediate hidden size of + the feed forward block of the + transformer + h_dim (int): + The hidden dim of the intermediate + representations of the transformer + dropout (float): + The value of dropout to be used. + Note that we apply the same dropout + to both the attention layers and the + FF layers + + Returns: + RobertaMLMModel: + A Roberta model for MLM task + + """ roberta_config_dict = { "attention_probs_dropout_prob": dropout, "bos_token_id": 0, @@ -263,7 +377,7 @@ def create_model(num_layers: int, "num_hidden_layers": num_layers, "pad_token_id": 1, "type_vocab_size": 1, - "vocab_size": 50265 + "vocab_size": 50265, } roberta_config = RobertaConfig.from_dict(roberta_config_dict) roberta_encoder = RobertaModel(roberta_config) @@ -275,8 +389,30 @@ def create_model(num_layers: int, ########### Experiment Management Related Functions ################## ###################################################################### -def create_experiment_dir(checkpoint_dir: pathlib.Path, - all_arguments: Dict[str, Any]): + +def create_experiment_dir( + checkpoint_dir: pathlib.Path, all_arguments: Dict[str, Any] +) -> pathlib.Path: + """Create an experiment directory and save all arguments in it. + Additionally, also store the githash and gitdiff. Finally create + a directory for `Tensorboard` logs. The structure would look something + like + checkpoint_dir + `-experiment-name + |- hparams.json + |- githash.log + |- gitdiff.log + `- tb_dir/ + + Args: + checkpoint_dir (pathlib.Path): + The base checkpoint directory + all_arguments (Dict[str, Any]): + The arguments to save + + Returns: + pathlib.Path: The experiment directory + """ # experiment name follows the following convention # {exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} current_time = datetime.datetime.now(pytz.timezone("US/Pacific")) @@ -287,45 +423,50 @@ def create_experiment_dir(checkpoint_dir: pathlib.Path, current_time.hour, current_time.minute, current_time.second, - str(uuid.uuid4()) + str(uuid.uuid4()), ) - exp_dir = (checkpoint_dir / expname) + exp_dir = checkpoint_dir / expname exp_dir.mkdir(exist_ok=False) hparams_file = exp_dir / "hparams.json" with hparams_file.open("w") as handle: json.dump(obj=all_arguments, fp=handle, indent=2) # Save the git hash try: - gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) with (exp_dir / "githash.log").open("w") as handle: handle.write(gitlog.stdout.decode("utf-8")) except sh.ErrorReturnCode_128: - logger.info("Seems like the code is not running from" - " within a git repo, so hash will" - " not be stored. However, it" - " is strongly advised to use" - " version control.") + logger.info( + "Seems like the code is not running from" + " within a git repo, so hash will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) # And the git diff try: gitdiff = sh.git.diff(_fg=False, _tty_out=False) with (exp_dir / "gitdiff.log").open("w") as handle: handle.write(gitdiff.stdout.decode("utf-8")) except sh.ErrorReturnCode_129: - logger.info("Seems like the code is not running from" - " within a git repo, so diff will" - " not be stored. However, it" - " is strongly advised to use" - " version control.") + logger.info( + "Seems like the code is not running from" + " within a git repo, so diff will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) # Finally create the Tensorboard Dir tb_dir = exp_dir / "tb_dir" tb_dir.mkdir() return exp_dir + ###################################################################### ####################### Driver Functions ############################# ###################################################################### + def train( checkpoint_dir: str, # Dataset Parameters @@ -345,11 +486,64 @@ def train( num_iterations: int = 10000, checkpoint_every: int = 1000, log_every: int = 10, - device: int = -1 -) -> None: - device = torch.device("cuda", device) \ - if (device > -1) and torch.cuda.is_available() \ - else torch.device("cpu") + device: int = -1, +) -> None: + """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf) + (transformer encoder only) model for MLM Task + + Args: + checkpoint_dir (str): + The base experiment directory to save experiments to + mask_prob (float, optional): + The fraction of tokens to mask. Defaults to 0.15. + random_replace_prob (float, optional): + The fraction of masked tokens to replace with random token. + Defaults to 0.1. + unmask_replace_prob (float, optional): + The fraction of masked tokens to leave unchanged. + Defaults to 0.1. + max_seq_length (int, optional): + The maximum sequence length of the examples. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + num_layers (int, optional): + The number of layers in the Bert model. Defaults to 6. + num_heads (int, optional): + Number of attention heads to use. Defaults to 8. + ff_dim (int, optional): + Size of the intermediate dimension in the FF layer. + Defaults to 512. + h_dim (int, optional): + Size of intermediate representations. + Defaults to 256. + dropout (float, optional): + Amout of Dropout to use. Defaults to 0.1. + batch_size (int, optional): + The minibatch size. Defaults to 8. + num_iterations (int, optional): + Total number of iterations to run the model for. + Defaults to 10000. + checkpoint_every (int, optional): + Save checkpoint after these many steps. + + ..note :: + + You want this to be frequent enough that you can + resume training in case it crashes, but not so much + that you fill up your entire storage ! + + Defaults to 1000. + log_every (int, optional): + Print logs after these many steps. Defaults to 10. + device (int, optional): + Which GPU to run on (-1 for CPU). Defaults to -1. + + """ + device = ( + torch.device("cuda", device) + if (device > -1) and torch.cuda.is_available() + else torch.device("cpu") + ) ################################ ###### Create Datasets ######### ################################ @@ -360,7 +554,7 @@ def train( unmask_replace_prob=unmask_replace_prob, tokenizer=tokenizer, max_seq_length=max_seq_length, - batch_size=batch_size + batch_size=batch_size, ) logger.info("Dataset Creation Done") ################################ @@ -372,7 +566,7 @@ def train( num_heads=num_heads, ff_dim=ff_dim, h_dim=h_dim, - dropout=dropout + dropout=dropout, ) model = model.to(device) logger.info("Model Creation Done") @@ -434,17 +628,12 @@ def train( if step % checkpoint_every == 0: state_dict = { "model": model.state_dict(), - "optimizer": optimizer.state_dict() + "optimizer": optimizer.state_dict(), } torch.save(obj=state_dict, f=str(exp_dir / f"checkpoint.iter_{step}.pt")) if step == num_iterations: break - if __name__ == "__main__": - fire.Fire({ - "train": train, - "data": create_data_iterator, - "model": create_model - }) \ No newline at end of file + fire.Fire({"train": train, "data": create_data_iterator, "model": create_model}) From b0fc5850a4db1db76e70f9eb4fb86f6199d69033 Mon Sep 17 00:00:00 2001 From: bapatra Date: Mon, 18 Oct 2021 23:22:29 -0700 Subject: [PATCH 04/18] checkpointing --- Intro101/train_bert.py | 149 ++++++++++++++++++++++++++++++++--------- 1 file changed, 116 insertions(+), 33 deletions(-) diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index 8eeec9dd1..2407a4521 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -1,7 +1,8 @@ import datetime import json import pathlib -import uuid +import re +import string from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -390,6 +391,15 @@ def create_model( ###################################################################### +def get_unique_identifier(length: int = 8) -> str: + """Create a unique identifier by choosing `length` + random characters from list of ascii characters and numbers + """ + alphabet = string.ascii_lowercase + string.digits + uuid = "".join(alphabet[ix] for ix in np.random.choice(len(alphabet), length)) + return uuid + + def create_experiment_dir( checkpoint_dir: pathlib.Path, all_arguments: Dict[str, Any] ) -> pathlib.Path: @@ -423,7 +433,7 @@ def create_experiment_dir( current_time.hour, current_time.minute, current_time.second, - str(uuid.uuid4()), + get_unique_identifier(), ) exp_dir = checkpoint_dir / expname exp_dir.mkdir(exist_ok=False) @@ -468,7 +478,8 @@ def create_experiment_dir( def train( - checkpoint_dir: str, + checkpoint_dir: str = None, + load_checkpoint_dir: str = None, # Dataset Parameters mask_prob: float = 0.15, random_replace_prob: float = 0.1, @@ -544,6 +555,67 @@ def train( if (device > -1) and torch.cuda.is_available() else torch.device("cpu") ) + ################################ + ###### Create Exp. Dir ######### + ################################ + if checkpoint_dir is None and load_checkpoint_dir is None: + logger.error("Need to specify one of checkpoint_dir" " or load_checkpoint_dir") + return + if checkpoint_dir is not None and load_checkpoint_dir is not None: + logger.error("Cannot specify both checkpoint_dir" " and load_checkpoint_dir") + return + if checkpoint_dir: + logger.info("Creating Experiment Directory") + checkpoint_dir = pathlib.Path(checkpoint_dir) + checkpoint_dir.mkdir(exist_ok=True) + all_arguments = { + # Dataset Params + "mask_prob": mask_prob, + "random_replace_prob": random_replace_prob, + "unmask_replace_prob": unmask_replace_prob, + "max_seq_length": max_seq_length, + "tokenizer": tokenizer, + # Model Params + "num_layers": num_layers, + "num_heads": num_heads, + "ff_dim": ff_dim, + "h_dim": h_dim, + "dropout": dropout, + # Training Params + "batch_size": batch_size, + "num_iterations": num_iterations, + "checkpoint_every": checkpoint_every, + } + exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) + tb_dir = exp_dir / "tb_dir" + assert tb_dir.exists() + summary_writer = SummaryWriter(log_dir=tb_dir) + logger.info(f"Experiment Directory created at {exp_dir}") + else: + logger.info("Loading from Experiment Directory") + load_checkpoint_dir = pathlib.Path(load_checkpoint_dir) + assert load_checkpoint_dir.exists() + with (load_checkpoint_dir / "hparams.json").open("r") as handle: + hparams = json.load(handle) + # Set the hparams + # Dataset Params + mask_prob = hparams.get("mask_prob", mask_prob) + tokenizer = hparams.get("tokenizer", tokenizer) + random_replace_prob = hparams.get("random_replace_prob", random_replace_prob) + unmask_replace_prob = hparams.get("unmask_replace_prob", unmask_replace_prob) + max_seq_length = hparams.get("max_seq_length", max_seq_length) + # Model Params + ff_dim = hparams.get("ff_dim", ff_dim) + h_dim = hparams.get("h_dim", h_dim) + dropout = hparams.get("dropout", dropout) + num_layers = hparams.get("num_layers", num_layers) + num_heads = hparams.get("num_heads", num_heads) + # Training Params + batch_size = hparams.get("batch_size", batch_size) + num_iterations = hparams.get("num_iterations", num_iterations) + checkpoint_every = hparams.get("checkpoint_every", checkpoint_every) + exp_dir = load_checkpoint_dir + ################################ ###### Create Datasets ######### ################################ @@ -571,46 +643,54 @@ def train( model = model.to(device) logger.info("Model Creation Done") ################################ - ###### Create Exp. Dir ######### - ################################ - logger.info("Creating Experiment Directory") - checkpoint_dir = pathlib.Path(checkpoint_dir) - checkpoint_dir.mkdir(exist_ok=True) - all_arguments = { - # Dataset Params - "mask_prob": mask_prob, - "random_replace_prob": random_replace_prob, - "unmask_replace_prob": unmask_replace_prob, - "max_seq_length": max_seq_length, - "tokenizer": tokenizer, - # Model Params - "num_layers": num_layers, - "num_heads": num_heads, - "ff_dim": ff_dim, - "h_dim": h_dim, - "dropout": dropout, - # Training Params - "batch_size": batch_size, - "num_iterations": num_iterations, - "checkpoint_every": checkpoint_every, - } - exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) - tb_dir = exp_dir / "tb_dir" - assert tb_dir.exists() - summary_writer = SummaryWriter(log_dir=tb_dir) - logger.info(f"Experiment Directory created at {exp_dir}") - ################################ ###### Create Optimizer ####### ################################ logger.info("Creating Optimizer") optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) logger.info("Optimizer Creation Done") ################################ + #### Load Model checkpoint ##### + ################################ + start_step = 1 + if load_checkpoint_dir is not None: + logger.info( + f"Loading model and optimizer checkpoint from {load_checkpoint_dir}" + ) + checkpoint_files = list( + filter( + lambda path: re.search(r"iter_(?P\d+)\.pt", path.name) + is not None, + load_checkpoint_dir.glob("*.pt"), + ) + ) + assert len(checkpoint_files) > 0, "No checkpoints found in directory" + checkpoint_files = sorted( + checkpoint_files, + key=lambda path: int( + re.search(r"iter_(?P\d+)\.pt", path.name).group("iter_no") + ), + ) + latest_checkpoint_path = checkpoint_files[-1] + start_step = ( + int( + re.search( + r"iter_(?P\d+)\.pt", latest_checkpoint_path.name + ).group("iter_no") + ) + + 1 + ) + state_dict = torch.load(latest_checkpoint_path) + model.load_state_dict(state_dict["model"], strict=True) + optimizer.load_state_dict(state_dict["optimizer"]) + logger.info( + f"Loading model and optimizer checkpoints done. Loaded from {latest_checkpoint_path}" + ) + ################################ ####### The Training Loop ###### ################################ model.train() losses = [] - for step, batch in enumerate(data_iterator, start=1): + for step, batch in enumerate(data_iterator, start=start_step): optimizer.zero_grad() # Move the tensors to device for key, value in batch.items(): @@ -631,6 +711,9 @@ def train( "optimizer": optimizer.state_dict(), } torch.save(obj=state_dict, f=str(exp_dir / f"checkpoint.iter_{step}.pt")) + logger.info( + "Saved model to {0}".format((exp_dir / f"checkpoint.iter_{step}.pt")) + ) if step == num_iterations: break From 2c6ff0a3999540e6b544043610fc2a3d2ab4d3e1 Mon Sep 17 00:00:00 2001 From: bapatra Date: Tue, 19 Oct 2021 01:18:08 -0700 Subject: [PATCH 05/18] added test case for train_bert --- Intro101/requirements.txt | 4 +- Intro101/tests/__init__.py | 0 Intro101/tests/test_train_bert.py | 108 ++++++++++++++++++++ Intro101/train_bert.py | 160 +++++++++++++++++++++--------- 4 files changed, 225 insertions(+), 47 deletions(-) create mode 100644 Intro101/tests/__init__.py create mode 100644 Intro101/tests/test_train_bert.py diff --git a/Intro101/requirements.txt b/Intro101/requirements.txt index 66ec19547..3471b7061 100644 --- a/Intro101/requirements.txt +++ b/Intro101/requirements.txt @@ -3,4 +3,6 @@ transformers==4.5.1 fire==0.4.0 pytz==2021.1 loguru==0.5.3 -sh==1.14.2 \ No newline at end of file +sh==1.14.2 +pytest==6.2.5 +tqdm==4.62.3 \ No newline at end of file diff --git a/Intro101/tests/__init__.py b/Intro101/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/Intro101/tests/test_train_bert.py b/Intro101/tests/test_train_bert.py new file mode 100644 index 000000000..307fb735f --- /dev/null +++ b/Intro101/tests/test_train_bert.py @@ -0,0 +1,108 @@ +import tempfile + +import numpy as np +import pytest +import torch +import tqdm +from transformers import AutoTokenizer + +from train_bert import create_data_iterator, create_model, load_model_checkpoint, train + + +@pytest.fixture(scope="function") +def checkpoint_dir() -> str: + with tempfile.TemporaryDirectory() as tmpdirname: + yield tmpdirname + + +def test_masking_stats(tol: float = 1e-3): + """Test to check that the masking probabilities + match what we expect them to be. + """ + kwargs = { + "mask_prob": 0.15, + "random_replace_prob": 0.1, + "unmask_replace_prob": 0.1, + "batch_size": 8, + } + tokenizer = AutoTokenizer.from_pretrained("roberta-base") + dataloader = create_data_iterator(**kwargs) + num_samples = 10000 + total_tokens = 0 + masked_tokens = 0 + random_replace_tokens = 0 + unmasked_replace_tokens = 0 + for ix, batch in tqdm.tqdm(enumerate(dataloader, start=1), total=num_samples): + # Since we don't mask the BOS / EOS tokens, we subtract them from the total tokens + total_tokens += batch["attention_mask"].sum().item() - ( + 2 * batch["attention_mask"].size(0) + ) + masked_tokens += (batch["tgt_tokens"] != tokenizer.pad_token_id).sum().item() + random_or_unmasked = ( + batch["tgt_tokens"] != tokenizer.pad_token_id + ).logical_and(batch["src_tokens"] != tokenizer.mask_token_id) + unmasked = random_or_unmasked.logical_and( + batch["src_tokens"] == batch["tgt_tokens"] + ) + unmasked_replace_tokens += unmasked.sum().item() + random_replace_tokens += random_or_unmasked.sum().item() - unmasked.sum().item() + if ix == num_samples: + break + estimated_mask_prob = masked_tokens / total_tokens + estimated_random_tokens = random_replace_tokens / total_tokens + estimated_unmasked_tokens = unmasked_replace_tokens / total_tokens + assert np.isclose(estimated_mask_prob, kwargs["mask_prob"], atol=tol) + assert np.isclose( + estimated_random_tokens, + kwargs["random_replace_prob"] * kwargs["mask_prob"], + atol=tol, + ) + assert np.isclose( + estimated_unmasked_tokens, + kwargs["unmask_replace_prob"] * kwargs["mask_prob"], + atol=tol, + ) + + +def test_model_checkpointing(checkpoint_dir: str): + """Training a small model, and ensuring + that both checkpointing and resuming from + a checkpoint work. + """ + # First train a tiny model for 5 iterations + train_params = { + "checkpoint_dir": checkpoint_dir, + "checkpoint_every": 2, + "num_layers": 2, + "num_heads": 4, + "ff_dim": 64, + "h_dim": 64, + "num_iterations": 5, + } + exp_dir = train(**train_params) + # now check that we have 3 checkpoints + assert len(list(exp_dir.glob("*.pt"))) == 3 + model = create_model( + num_layers=train_params["num_layers"], + num_heads=train_params["num_heads"], + ff_dim=train_params["ff_dim"], + h_dim=train_params["h_dim"], + dropout=0.1, + ) + optimizer = torch.optim.Adam(model.parameters()) + step, model, optimizer = load_model_checkpoint(exp_dir, model, optimizer) + assert step == 5 + model_state_dict = model.state_dict() + # the saved checkpoint would be for iteration 5 + correct_state_dict = torch.load(exp_dir / "checkpoint.iter_5.pt") + correct_model_state_dict = correct_state_dict["model"] + assert set(model_state_dict.keys()) == set(correct_model_state_dict.keys()) + assert all( + torch.allclose(model_state_dict[key], correct_model_state_dict[key]) + for key in model_state_dict.keys() + ) + # Finally, try training with the checkpoint + train_params.pop("checkpoint_dir") + train_params["load_checkpoint_dir"] = str(exp_dir) + train_params["num_iterations"] = 10 + train(**train_params) diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index 2407a4521..349a6ad4d 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -4,7 +4,7 @@ import re import string from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union import datasets import fire @@ -183,6 +183,27 @@ def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]: return (tokens, labels) +T = TypeVar("T") + + +class InfiniteIterator(object): + def __init__(self, iterable: Iterable[T]) -> None: + self._iterable = iterable + self._iterator = iter(self._iterable) + + def __iter__(self): + return self + + def __next__(self) -> T: + next_item = None + try: + next_item = next(self._iterator) + except StopIteration: + self._iterator = iter(self._iterable) + next_item = next(self._iterator) + return next_item + + def create_data_iterator( mask_prob: float, random_replace_prob: float, @@ -190,7 +211,7 @@ def create_data_iterator( batch_size: int, max_seq_length: int = 512, tokenizer: str = "roberta-base", -) -> DataLoader: +) -> InfiniteIterator: """Create the dataloader. Args: @@ -208,9 +229,9 @@ def create_data_iterator( The tokenizer to use. Defaults to "roberta-base". Returns: - DataLoader: - The torch DataLoader. Note that the dataloader is iterable, - but is not an iterator. + InfiniteIterator: + The torch DataLoader, wrapped in an InfiniteIterator class, to + be able to continuously generate samples """ wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-v1", split="train") @@ -232,7 +253,7 @@ def create_data_iterator( dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_partial ) - return dataloader + return InfiniteIterator(dataloader) ###################################################################### @@ -473,7 +494,65 @@ def create_experiment_dir( ###################################################################### -####################### Driver Functions ############################# +################ Checkpoint Related Functions ######################## +###################################################################### + + +def load_model_checkpoint( + load_checkpoint_dir: pathlib.Path, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, +) -> Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + """Loads the optimizer state dict and model state dict from the load_checkpoint_dir + into the passed model and optimizer. Searches for the most recent checkpoint to + load from + + Args: + load_checkpoint_dir (pathlib.Path): + The base checkpoint directory to load from + model (torch.nn.Module): + The model to load the checkpoint weights into + optimizer (torch.optim.Optimizer): + The optimizer to load the checkpoint weigths into + + Returns: + Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + The checkpoint step, model with state_dict loaded and + optimizer with state_dict loaded + + """ + logger.info(f"Loading model and optimizer checkpoint from {load_checkpoint_dir}") + checkpoint_files = list( + filter( + lambda path: re.search(r"iter_(?P\d+)\.pt", path.name) is not None, + load_checkpoint_dir.glob("*.pt"), + ) + ) + assert len(checkpoint_files) > 0, "No checkpoints found in directory" + checkpoint_files = sorted( + checkpoint_files, + key=lambda path: int( + re.search(r"iter_(?P\d+)\.pt", path.name).group("iter_no") + ), + ) + latest_checkpoint_path = checkpoint_files[-1] + checkpoint_step = int( + re.search(r"iter_(?P\d+)\.pt", latest_checkpoint_path.name).group( + "iter_no" + ) + ) + + state_dict = torch.load(latest_checkpoint_path) + model.load_state_dict(state_dict["model"], strict=True) + optimizer.load_state_dict(state_dict["optimizer"]) + logger.info( + f"Loading model and optimizer checkpoints done. Loaded from {latest_checkpoint_path}" + ) + return checkpoint_step, model, optimizer + + +###################################################################### +######################## Driver Functions ############################ ###################################################################### @@ -498,7 +577,7 @@ def train( checkpoint_every: int = 1000, log_every: int = 10, device: int = -1, -) -> None: +) -> pathlib.Path: """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf) (transformer encoder only) model for MLM Task @@ -549,6 +628,9 @@ def train( device (int, optional): Which GPU to run on (-1 for CPU). Defaults to -1. + Returns: + pathlib.Path: The final experiment directory + """ device = ( torch.device("cuda", device) @@ -587,9 +669,6 @@ def train( "checkpoint_every": checkpoint_every, } exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) - tb_dir = exp_dir / "tb_dir" - assert tb_dir.exists() - summary_writer = SummaryWriter(log_dir=tb_dir) logger.info(f"Experiment Directory created at {exp_dir}") else: logger.info("Loading from Experiment Directory") @@ -612,10 +691,14 @@ def train( num_heads = hparams.get("num_heads", num_heads) # Training Params batch_size = hparams.get("batch_size", batch_size) - num_iterations = hparams.get("num_iterations", num_iterations) + _num_iterations = hparams.get("num_iterations", num_iterations) + num_iterations = max(num_iterations, _num_iterations) checkpoint_every = hparams.get("checkpoint_every", checkpoint_every) exp_dir = load_checkpoint_dir - + # Tensorboard writer + tb_dir = exp_dir / "tb_dir" + assert tb_dir.exists() + summary_writer = SummaryWriter(log_dir=tb_dir) ################################ ###### Create Datasets ######### ################################ @@ -653,44 +736,19 @@ def train( ################################ start_step = 1 if load_checkpoint_dir is not None: - logger.info( - f"Loading model and optimizer checkpoint from {load_checkpoint_dir}" - ) - checkpoint_files = list( - filter( - lambda path: re.search(r"iter_(?P\d+)\.pt", path.name) - is not None, - load_checkpoint_dir.glob("*.pt"), - ) - ) - assert len(checkpoint_files) > 0, "No checkpoints found in directory" - checkpoint_files = sorted( - checkpoint_files, - key=lambda path: int( - re.search(r"iter_(?P\d+)\.pt", path.name).group("iter_no") - ), - ) - latest_checkpoint_path = checkpoint_files[-1] - start_step = ( - int( - re.search( - r"iter_(?P\d+)\.pt", latest_checkpoint_path.name - ).group("iter_no") - ) - + 1 - ) - state_dict = torch.load(latest_checkpoint_path) - model.load_state_dict(state_dict["model"], strict=True) - optimizer.load_state_dict(state_dict["optimizer"]) - logger.info( - f"Loading model and optimizer checkpoints done. Loaded from {latest_checkpoint_path}" + checkpoint_step, model, optimizer = load_model_checkpoint( + load_checkpoint_dir, model, optimizer ) + start_step = checkpoint_step + 1 + ################################ ####### The Training Loop ###### ################################ model.train() losses = [] for step, batch in enumerate(data_iterator, start=start_step): + if step >= num_iterations: + break optimizer.zero_grad() # Move the tensors to device for key, value in batch.items(): @@ -714,8 +772,18 @@ def train( logger.info( "Saved model to {0}".format((exp_dir / f"checkpoint.iter_{step}.pt")) ) - if step == num_iterations: - break + # Save the last checkpoint if not saved yet + if step % checkpoint_every != 0: + state_dict = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + torch.save(obj=state_dict, f=str(exp_dir / f"checkpoint.iter_{step}.pt")) + logger.info( + "Saved model to {0}".format((exp_dir / f"checkpoint.iter_{step}.pt")) + ) + + return exp_dir if __name__ == "__main__": From 5b7fcdc71f2c254eb09ea1bedaa931d0997331cd Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 19 Oct 2021 19:55:20 +0000 Subject: [PATCH 06/18] change device to local_rank and simplify fire --- Intro101/train_bert.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index 349a6ad4d..67c07c60b 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -576,7 +576,7 @@ def train( num_iterations: int = 10000, checkpoint_every: int = 1000, log_every: int = 10, - device: int = -1, + local_rank: int = -1, ) -> pathlib.Path: """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf) (transformer encoder only) model for MLM Task @@ -625,7 +625,7 @@ def train( Defaults to 1000. log_every (int, optional): Print logs after these many steps. Defaults to 10. - device (int, optional): + local_rank (int, optional): Which GPU to run on (-1 for CPU). Defaults to -1. Returns: @@ -633,8 +633,8 @@ def train( """ device = ( - torch.device("cuda", device) - if (device > -1) and torch.cuda.is_available() + torch.device("cuda", local_rank) + if (local_rank > -1) and torch.cuda.is_available() else torch.device("cpu") ) ################################ @@ -787,4 +787,4 @@ def train( if __name__ == "__main__": - fire.Fire({"train": train, "data": create_data_iterator, "model": create_model}) + fire.Fire(train) From c31e28b966ec977d29409e33f92e4fcf2af701ee Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 19 Oct 2021 15:56:20 -0700 Subject: [PATCH 07/18] Update README.md --- Intro101/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/Intro101/README.md b/Intro101/README.md index e69de29bb..17a931d19 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -0,0 +1,32 @@ + +good practices + * experiment directory saving training metadata + * pytests + +data + * what is masked LM + * what does the code do (link to code) + +model + * core params for transformer model (e.g., #layers, attn) + +how to run + * launching on CPU (slow) launch on single GPU (fast) + * different train params + +------ + +deepspeed additions + * deepspeed.init, training loop, ckpt changes + +launching across multiple GPUs + +fp16 + * how to enable + * show memory reduction when enabled via nvidia-smi + * brief overview of how fp16 training works (e.g., loss scaling) + +zero + * introduce how zero reduces memory + * introduce zero offload + * update config to use z1 + offload to showcase a model that can only run with offload enabled From d5088ccc25d10dc2e2f28eaa68d2da2c49620693 Mon Sep 17 00:00:00 2001 From: bapatra Date: Wed, 20 Oct 2021 01:14:13 -0700 Subject: [PATCH 08/18] updated README --- Intro101/README.md | 143 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/Intro101/README.md b/Intro101/README.md index 17a931d19..28575ff95 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -1,4 +1,147 @@ +# Training an Masked Language Model with PyTorch and Deepspeed +In this tutorial, we will create and train a Transformer encoder on the Masked Language Modeling (MLM) task. Then we will show the changes necessary to integrate Deepspeed, and show some of the advantages of doing so. + +# 1. Training a Transformer Encoder (BERT / Roberta) model for MLM + +## 1.0 Some Good Practices + +### Version Control and Reproducibility + +One of the most important parts of training ML models is for the experiments to be reproducible (either by someone else, or by you 3 months later). Some steps that help with this are: + +* Use some form of version control (eg: `git`). Additionally, make sure to save the `gitdiff` and `githash`, so that anyone trying to replicate the experiment can easily do so + +* Save all the hyperparameters associated with the experiment (be it taken from a config or parsed from the command line) + +* Seed your random generators + +* Specify all the packages and their versions. This can be a `requirements.txt` file, a conda `env.yaml` file or a `pyproject.toml` file. If you want complete reproducibility, you can also include a `Dockerfile` to specify the environemnt to run the experiments in. + +In this example, the checkpoint directory we create is of the following format: + +```bash +{exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} + `-experiment-name + |- hparams.json + |- githash.log + |- gitdiff.log + `- tb_dir/ +``` +So that if you revisit this experiment, it is easier to understand what the experiment was about, when was it run, with what hyperparameters and what was the status of the code when the run was executed. + +### Writing Unit Tests + +Unit tests can help catch bugs early as well as set up a fast feedback loop. Since some experiments can take days or even weeks to run, both of these things is quite invaluable. + +In this tutorial, two primary ingredients are data creating and model training. Hence, we try and test both these parts out (see [here](./tests/test_train_bert.py)): + +1. Testing Data creating: Our data creating for the MLM task involves randomly masking words to generate model inputs. So, we test if the fraction of masked tokens matches what we expect from our `DataLoader` (for more details, please take a look at) + +```python +def test_masking_stats(tol: float = 1e-3): + """Test to check that the masking probabilities + match what we expect them to be. + """ + ... +``` + +2. Model training and checkpointing: Since pretraining experiments are usually quite expensive (take days / weeks to complete), chances are that you might run into some hardware failure before the training completes. Thus, it is crucial that the checkpointing logic is correct to allow a model to resume training. One way to do so is to train a small model for few iterations and see if the model can resume training and if the checkpoints are loaded correctly. See `test_model_checkpointing` for an example test. + +--- + +💡 **_Tip:_** Make sure to also save the optimizer state_dict along with the model parameters ! + +--- + +## 1.1 The Masked Language Modeling Task + +The main idea behind the MLM task is to get the model to fill in the blanks based on contextual clues present **both before and after** the blank. Consider, for example, the following sentence: + +> In the beautiful season of ___ the ___ shed their leaves. + +Given the left context of `season` and the right context of `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both left and right context to be able to fill in blanks. + +In order to do that, we carry out the following steps + +1. Tokenize a sentence into word(pieces) +2. Randomly select some words to mask, and replace them with a special \ token + * Of the masked tokens, it is common to replace a fraction of them with a random token, and leave a fraction of them unchanged. +3. Collect the actual words that were masked, and use that as targets for the model to predict against: + * From the model's perspective, this is a simple `CrossEntropy` loss over the vocabulary of the model. + +In this tutorial, we use the [`wikitext-2-v1`](https://huggingface.co/datasets/wikitext) dataset from [HuggingFace datasets](https://github.com/huggingface/datasets). To see how this is done in code, take a look at `masking_function` in [train_bert.py](./train_bert.py). + + +## 1.2 Creating a Transformer model + +A Transformer model repeatedly applies a (Multi-Headed) Self-Attention block and a FeedForward layer to generate contextual representations for each token. Thus, the key hyperparameters for a Transformer model usually are + +1. The number of Self-Attention + FeedForward blocks (depth) +2. The size of the hidden representation +3. The number of Self Attention Heads +4. The size of the intermediate representation between in the FeedForward block + +Check out the `create_model` function in [train_bert.py](./train_bert.py) to see how this is done in code. + +--- +📌 **Note:** You can check out [[1](#1), [2](#2)] as a starting point for better understanding Transformers. Additionally, there are a number of blogs that do nice deep dive into the workings of these models (eg: [this](https://nlp.seas.harvard.edu/2018/04/03/attention.html), [this](https://jalammar.github.io/illustrated-bert/) and [this](https://jalammar.github.io/illustrated-transformer/)). + +--- + +## 1.3 Training the Model + +In order to train the model, you can run the following + +```bash +python train_bert.py --checkpoint_dir ./experiments +``` +This will create a model with the default parameters (as specified by the arguments to the `train` function), and train it on the wikitext dataset. Other parameters can be configured from the command line as: + +```bash +python train_bert.py --checkpoint_dir ./experiments \ + --mask_prob ${mask_prob} \ + --random_replace_prob ${random_replace_prob} \ + --unmask_replace_prob ${unmask_replace_prob} \ + --max_seq_length ${max_seq_length} \ + --tokenizer ${tokenizer} \ + --num_layers ${num_layers} \ + --num_heads ${num_heads} \ + --ff_dim ${ff_dim} \ + --h_dim ${h_dim} \ + --dropout ${dropout} \ + --batch_size ${batch_size} \ + --num_iterations ${num_iterations} \ + --checkpoint_every ${checkpoint_every} \ + --log_every ${log_every} \ + --local_rank ${local_rank} + +``` + +The parameters are explained in more details in the doctstring of `train`. + +--- +💡 **_Tip:_** If you have a GPU available, you can use it to considerably speedup your training. Simply set the `local_rank` to the GPU you want to run it on. Eg: for a single GPU machine this would look like +```bash +--local_rank 0 +``` + +--- + +## 2. Integrating Deepspeed For More Efficient Training + + +## References +> [1] +[Vaswani et. al. Attention is all you need. +In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)](https://arxiv.org/pdf/1706.03762.pdf) + +> [2] +[Devlin, Jacob et. al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)](https://aclanthology.org/N19-1423.pdf) + +--------- + +## Scratch pad (TODO Remove) good practices * experiment directory saving training metadata * pytests From 288ff21a3baff7dc64f36b050a9371198f6d4ad3 Mon Sep 17 00:00:00 2001 From: bapatra Date: Wed, 20 Oct 2021 14:30:18 -0700 Subject: [PATCH 09/18] spelling mistakes and typographic edits --- Intro101/README.md | 56 +++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/Intro101/README.md b/Intro101/README.md index 28575ff95..dc7b85d67 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -2,9 +2,9 @@ In this tutorial, we will create and train a Transformer encoder on the Masked Language Modeling (MLM) task. Then we will show the changes necessary to integrate Deepspeed, and show some of the advantages of doing so. -# 1. Training a Transformer Encoder (BERT / Roberta) model for MLM +## 1. Training a Transformer Encoder (BERT / Roberta) model for MLM -## 1.0 Some Good Practices +### 1.0 Some Good Practices ### Version Control and Reproducibility @@ -16,9 +16,9 @@ One of the most important parts of training ML models is for the experiments to * Seed your random generators -* Specify all the packages and their versions. This can be a `requirements.txt` file, a conda `env.yaml` file or a `pyproject.toml` file. If you want complete reproducibility, you can also include a `Dockerfile` to specify the environemnt to run the experiments in. +* Specify all the packages and their versions. This can be a `requirements.txt` file, a conda `env.yaml` file or a `pyproject.toml` file. If you want complete reproducibility, you can also include a `Dockerfile` to specify the environment to run the experiment in. -In this example, the checkpoint directory we create is of the following format: +In this example, the checkpoint directory has the following format: ```bash {exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} @@ -28,15 +28,15 @@ In this example, the checkpoint directory we create is of the following format: |- gitdiff.log `- tb_dir/ ``` -So that if you revisit this experiment, it is easier to understand what the experiment was about, when was it run, with what hyperparameters and what was the status of the code when the run was executed. +This ensures that if you revisit this experiment, it is easy to understand what the experiment was about, when was it run, with what hyperparameters and what was the status of the code when the run was executed. ### Writing Unit Tests -Unit tests can help catch bugs early as well as set up a fast feedback loop. Since some experiments can take days or even weeks to run, both of these things is quite invaluable. +Unit tests can help catch bugs early, and also set up a fast feedback loop. Since some experiments can take days or even weeks to run, both of these things are quite invaluable. -In this tutorial, two primary ingredients are data creating and model training. Hence, we try and test both these parts out (see [here](./tests/test_train_bert.py)): +In this tutorial, two primary parts are data creating and model training. Hence, we test both these parts (see [here](./tests/test_train_bert.py)): -1. Testing Data creating: Our data creating for the MLM task involves randomly masking words to generate model inputs. So, we test if the fraction of masked tokens matches what we expect from our `DataLoader` (for more details, please take a look at) +1. Testing Data creation: Data creation for the MLM task involves randomly masking words to generate model inputs. To test for correctness, we test if the fraction of masked tokens matches what we expect from our `DataLoader`. For more details, please take a look at ```python def test_masking_stats(tol: float = 1e-3): @@ -46,52 +46,52 @@ def test_masking_stats(tol: float = 1e-3): ... ``` -2. Model training and checkpointing: Since pretraining experiments are usually quite expensive (take days / weeks to complete), chances are that you might run into some hardware failure before the training completes. Thus, it is crucial that the checkpointing logic is correct to allow a model to resume training. One way to do so is to train a small model for few iterations and see if the model can resume training and if the checkpoints are loaded correctly. See `test_model_checkpointing` for an example test. +2. Model training and checkpointing: Since pretraining experiments are usually quite expensive (take days / weeks to complete), chances are that you might run into some hardware failure before the training completes. Thus, it is crucial that the checkpointing logic is correct to allow a model to resume training. One way to do this is to train a small model for a few iterations and see if the model can resume training and if the checkpoints are loaded correctly. See `test_model_checkpointing` for an example test. --- -💡 **_Tip:_** Make sure to also save the optimizer state_dict along with the model parameters ! +💡 **_Tip:_** While saving checkpoints, make sure to also save the optimizer states along with the model parameters ! --- -## 1.1 The Masked Language Modeling Task +### 1.1 The Masked Language Modeling Task The main idea behind the MLM task is to get the model to fill in the blanks based on contextual clues present **both before and after** the blank. Consider, for example, the following sentence: -> In the beautiful season of ___ the ___ shed their leaves. +> In the beautiful season of ____ the ____ shed their leaves. -Given the left context of `season` and the right context of `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both left and right context to be able to fill in blanks. +Given the left context `season` and the right context `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both the left and right context to be to fill in blanks. -In order to do that, we carry out the following steps +In order to do that, we do the following: 1. Tokenize a sentence into word(pieces) 2. Randomly select some words to mask, and replace them with a special \ token * Of the masked tokens, it is common to replace a fraction of them with a random token, and leave a fraction of them unchanged. -3. Collect the actual words that were masked, and use that as targets for the model to predict against: +3. Collect the actual words that were masked, and use them as targets for the model to predict against: * From the model's perspective, this is a simple `CrossEntropy` loss over the vocabulary of the model. -In this tutorial, we use the [`wikitext-2-v1`](https://huggingface.co/datasets/wikitext) dataset from [HuggingFace datasets](https://github.com/huggingface/datasets). To see how this is done in code, take a look at `masking_function` in [train_bert.py](./train_bert.py). +In this tutorial, we use the [wikitext-2-v1](https://huggingface.co/datasets/wikitext) dataset from [HuggingFace datasets](https://github.com/huggingface/datasets). To see how this is done in code, take a look at `masking_function` in [train_bert.py](./train_bert.py). -## 1.2 Creating a Transformer model +### 1.2 Creating a Transformer model A Transformer model repeatedly applies a (Multi-Headed) Self-Attention block and a FeedForward layer to generate contextual representations for each token. Thus, the key hyperparameters for a Transformer model usually are 1. The number of Self-Attention + FeedForward blocks (depth) 2. The size of the hidden representation 3. The number of Self Attention Heads -4. The size of the intermediate representation between in the FeedForward block +4. The size of the intermediate representation between the FeedForward block -Check out the `create_model` function in [train_bert.py](./train_bert.py) to see how this is done in code. +Check out the `create_model` function in [train_bert.py](./train_bert.py) to see how this is done in code. In this example, we create a Roberta model[3](#3) --- 📌 **Note:** You can check out [[1](#1), [2](#2)] as a starting point for better understanding Transformers. Additionally, there are a number of blogs that do nice deep dive into the workings of these models (eg: [this](https://nlp.seas.harvard.edu/2018/04/03/attention.html), [this](https://jalammar.github.io/illustrated-bert/) and [this](https://jalammar.github.io/illustrated-transformer/)). --- -## 1.3 Training the Model +### 1.3 Training the Model -In order to train the model, you can run the following +In order to train the model, you can run the following command ```bash python train_bert.py --checkpoint_dir ./experiments @@ -118,10 +118,10 @@ python train_bert.py --checkpoint_dir ./experiments \ ``` -The parameters are explained in more details in the doctstring of `train`. +The parameters are explained in more details in the docstring of `train`. --- -💡 **_Tip:_** If you have a GPU available, you can use it to considerably speedup your training. Simply set the `local_rank` to the GPU you want to run it on. Eg: for a single GPU machine this would look like +💡 **_Tip:_** If you have a GPU available, you can considerably speedup your training by running it on the GPU. Simply set the `local_rank` to the GPU you want to run it on. Eg: for a single GPU machine, this would look like ```bash --local_rank 0 ``` @@ -133,14 +133,18 @@ The parameters are explained in more details in the doctstring of `train`. ## References > [1] -[Vaswani et. al. Attention is all you need. +[Vaswani, Asish et al. Attention is all you need. In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)](https://arxiv.org/pdf/1706.03762.pdf) - +> > [2] -[Devlin, Jacob et. al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)](https://aclanthology.org/N19-1423.pdf) +[Devlin, Jacob et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)](https://aclanthology.org/N19-1423.pdf) +> +> [3] +[Liu, Yinhan et al. RoBERTa: A Robustly Optimized BERT Pretraining Approach. ArXiv abs/1907.11692 (2019)](https://arxiv.org/pdf/1907.11692.pdf) --------- + ## Scratch pad (TODO Remove) good practices * experiment directory saving training metadata From 3b703779ee501ead405ec04c8fa63f79ae416ee1 Mon Sep 17 00:00:00 2001 From: bapatra Date: Wed, 20 Oct 2021 14:31:40 -0700 Subject: [PATCH 10/18] spelling mistakes and typographic edits --- Intro101/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Intro101/README.md b/Intro101/README.md index dc7b85d67..d986d6f70 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -60,7 +60,7 @@ The main idea behind the MLM task is to get the model to fill in the blanks base > In the beautiful season of ____ the ____ shed their leaves. -Given the left context `season` and the right context `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both the left and right context to be to fill in blanks. +Given the left context `season` and the right context `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both the left and right context to fill in blanks. In order to do that, we do the following: From 2cabe80ddb8fd1f9d9438d728aef41cf70f37349 Mon Sep 17 00:00:00 2001 From: bapatra Date: Wed, 20 Oct 2021 14:32:10 -0700 Subject: [PATCH 11/18] spelling mistakes and typographic edits --- Intro101/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Intro101/README.md b/Intro101/README.md index d986d6f70..96313b797 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -60,7 +60,7 @@ The main idea behind the MLM task is to get the model to fill in the blanks base > In the beautiful season of ____ the ____ shed their leaves. -Given the left context `season` and the right context `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both the left and right context to fill in blanks. +Given the left context `season` and the right context `shed their leaves`, one can guess that the blanks are `Autumn` and `trees` respectively. This is exactly what we want the model to do: utilize both the left and right context to fill in the blanks. In order to do that, we do the following: From c7faed8e03e06a4a02bbd066a4a338a2fda8d7e3 Mon Sep 17 00:00:00 2001 From: bapatra Date: Wed, 20 Oct 2021 14:33:04 -0700 Subject: [PATCH 12/18] spelling mistakes and typographic edits --- Intro101/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Intro101/README.md b/Intro101/README.md index 96313b797..c889e410d 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -85,7 +85,7 @@ A Transformer model repeatedly applies a (Multi-Headed) Self-Attention block and Check out the `create_model` function in [train_bert.py](./train_bert.py) to see how this is done in code. In this example, we create a Roberta model[3](#3) --- -📌 **Note:** You can check out [[1](#1), [2](#2)] as a starting point for better understanding Transformers. Additionally, there are a number of blogs that do nice deep dive into the workings of these models (eg: [this](https://nlp.seas.harvard.edu/2018/04/03/attention.html), [this](https://jalammar.github.io/illustrated-bert/) and [this](https://jalammar.github.io/illustrated-transformer/)). +📌 **Note:** You can check out [[1](#1), [2](#2)] as a starting point for better understanding Transformers. Additionally, there are a number of blogs that do a nice deep dive into the workings of these models (eg: [this](https://nlp.seas.harvard.edu/2018/04/03/attention.html), [this](https://jalammar.github.io/illustrated-bert/) and [this](https://jalammar.github.io/illustrated-transformer/)). --- From 34c0bab33388bb4b7d57776d82c2701be677458e Mon Sep 17 00:00:00 2001 From: bapatra Date: Wed, 20 Oct 2021 14:38:01 -0700 Subject: [PATCH 13/18] spelling mistakes and typographic edits --- Intro101/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Intro101/README.md b/Intro101/README.md index c889e410d..5ed8b2aa5 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -134,13 +134,13 @@ The parameters are explained in more details in the docstring of `train`. ## References > [1] [Vaswani, Asish et al. Attention is all you need. -In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)](https://arxiv.org/pdf/1706.03762.pdf) +_In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)_](https://arxiv.org/pdf/1706.03762.pdf) > > [2] -[Devlin, Jacob et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)](https://aclanthology.org/N19-1423.pdf) +[Devlin, Jacob et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. _In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)_](https://aclanthology.org/N19-1423.pdf) > > [3] -[Liu, Yinhan et al. RoBERTa: A Robustly Optimized BERT Pretraining Approach. ArXiv abs/1907.11692 (2019)](https://arxiv.org/pdf/1907.11692.pdf) +[Liu, Yinhan et al. RoBERTa: A Robustly Optimized BERT Pretraining Approach. _ArXiv abs/1907.11692 (2019)_](https://arxiv.org/pdf/1907.11692.pdf) --------- From 20b31f61e3e31e7e53f6b9ad6bce306b8dce87cd Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 20 Oct 2021 22:02:38 -0700 Subject: [PATCH 14/18] Update README.md --- Intro101/README.md | 174 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 138 insertions(+), 36 deletions(-) diff --git a/Intro101/README.md b/Intro101/README.md index 5ed8b2aa5..35dea05c5 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -128,52 +128,154 @@ The parameters are explained in more details in the docstring of `train`. --- -## 2. Integrating Deepspeed For More Efficient Training +## 2. Integrating DeepSpeed For More Efficient Training +In this next section we'll add DeepSpeed to the model presented in Section 1 and turn on several features. -## References -> [1] -[Vaswani, Asish et al. Attention is all you need. -_In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)_](https://arxiv.org/pdf/1706.03762.pdf) -> -> [2] -[Devlin, Jacob et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. _In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)_](https://aclanthology.org/N19-1423.pdf) -> -> [3] -[Liu, Yinhan et al. RoBERTa: A Robustly Optimized BERT Pretraining Approach. _ArXiv abs/1907.11692 (2019)_](https://arxiv.org/pdf/1907.11692.pdf) +## 2.0 Core DeepSpeed Code Changes + +Please see the [Writing DeepSpeed Models](https://www.deepspeed.ai/getting-started/#writing-deepspeed-models) instructions written on modifying an existing model to use DeepSpeed. Also we will heavily rely on the [DeepSpeed API documentation](https://deepspeed.readthedocs.io/en/latest/) and [config JSON documentation](https://www.deepspeed.ai/docs/config-json/) going forward. ---------- +Please install DeepSpeed via `pip install deepspeed` if you haven't already done so, after installing you can check if your current version and other information via `ds_report`. For this tutorial we assume a DeepSpeed version of >= 0.5.4 and a torch version >= 1.6. Please upgrade via `pip install --upgrade deepspeed` if you are running an older version of DeepSpeed. +### Add deepspeed.initialize + config -## Scratch pad (TODO Remove) -good practices - * experiment directory saving training metadata - * pytests +Our first task is to identify where to add `deepspeed.initialize()` to the existing code in order to use the DeepSpeed training engine. Please see the [deepspeed.initialize API documentation](https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization) for more details. This needs to be done after the model has been created and before the training loop has started. Most of our edits will be inside the `train` function inside [train_bert.py](./train_bert.py). + +After the model is created and before the optimizer is created we want to add the following lines: + +```python +ds_config = { + "train_micro_batch_size_per_gpu": batch_size, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, +} +model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) +``` -data - * what is masked LM - * what does the code do (link to code) +This will create the DeepSpeed training engine based on the previously instantiated model and the new `ds_config` dictionary. We can now also remove the previous lines of code that created an Adam optimizer, this will now be done via the DeepSpeed engine. It should be noted, you can optionally created your own optimizer and pass it into `deepspeed.initialize` however DeepSpeed is able to make further performance optimizations by instantiating its own optimizers. -model - * core params for transformer model (e.g., #layers, attn) +### Update the training-loop -how to run - * launching on CPU (slow) launch on single GPU (fast) - * different train params +Next we want to update our training-loop to use the new model engine with the following changes: ------- +* `optimizer.zero_grad()` can be removed + * The DeepSpeed engine will do this for you at the right time. +* Replace `loss.backward()` with `model.backward(loss)` + * There are several cases where the engine will properly scale the loss when using certain features (e.g., fp16, gradient-accumulation). +* Replace `optimizer.step()` with `model.step()` + * The optimizer step is handled by the engine now and is responsible for dispatching to the right optimizer depending on certain features. -deepspeed additions - * deepspeed.init, training loop, ckpt changes +### Update checkpoint save and load + +Immediately after our new `deepspeed.initialize` you will see a checkpoint load and in the training-loop you will see a few checkpoint save calls. DeepSpeed handles the complexities of checkpoint saving for you so we can simplify these codepaths in the following way. Please refer to the [model checkpoint API documentation](https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html) for more details. + +__Checkpoint saving__: DeepSpeed will construct and save the state_dict for you, we can replace the *two* checkpoint saving snippets (i.e., `state_dict` construction and `torch.save`) and replace them with the snippet below. The `client_state` being passed in here is an example of state outside the view of DeepSpeed that will be saved with the checkpoint. + +```python +model.save_checkpoint(save_dir=exp_dir, client_state={'checkpoint_step': step}) +``` + +__Checkpoint loading__: The checkpoint loading is happening right before the training-loop starts. It invokes the `load_model_checkpoint` function which consists of around 30 lines of code. We can replace the `load_model_checkpoint(load_checkpoint_dir, model, optimizer)` call with the following snippet: + +```python +_, client_state = model.load_checkpoint(load_dir=load_checkpoint_dir) +checkpoint_step = client_state['checkpoint_step'] +``` + +## 2.1 Launching training + +We are now ready to launch our training! As a convenience, DeepSpeed provides its own launcher that is seamlessly compatible with internal clusters at MSFT (e.g., ITP). You can now try running your model on your available GPU(s) with the command below. By default this will attempt to run data-parallel training across all available GPUs on the current machine + any external machines listed in your `/job/hostfile`. Please read [more details about the DeepSpeed launcher](https://www.deepspeed.ai/getting-started/#launching-deepspeed-training) on our website. + +```bash +deepspeed train_bert.py --checkpoint_dir . +``` + +--- +📌 **Note:** If using the deepspeed launcher you should not pass the `--local_rank` explicitly. This will be done by the launcher in the same way as if you launched with `torch.distributed.launch` from PyTorch. + +--- + +## 2.2 Mixed Precision Training (fp16) + +Now that we are setup to use the DeepSpeed engine with our model we can start trying out a few different features of DeepSpeed. One feature is mixed precision training that utilizes half precision (floating-point 16 or fp16) data types. If you want to learn more about how mixed precision training works please refer to the Mixed Precision Training paper [[3]](https://arxiv.org/pdf/1710.03740v3.pdf) from Baidu and NVIDIA on the topic. + +To enable this mode in DeepSpeed we need to update our `ds_config` before the engine is created. Please see [fp16 training options](https://www.deepspeed.ai/docs/config-json/#fp16-training-options) in the config documentation for more information. In our case let's simple enable it by adding the following to our `ds_config` dictionary: + +```python + "fp16": { + "enabled": True + } +``` + +The memory reduction by switching from fp32 to fp16 results in the *model parameters* using half the amount of GPU memory, however the overall GPU memory reduction is not as simple. Since fp16 has half the available bits as fp32 it is not able to represent the same expressiveness as fp32, which can result in numeric instabilities during training. We are able to get around these instabilities in most cases by keeping some states in fp16 and others remain in fp32 (see Section 3 in [[3]](https://arxiv.org/pdf/1710.03740v3.pdf) if you'd like to learn more). + +The primary reason to utilize fp16 training is due to *Tensor Cores*. If you are training with NVIDIA V100 or A100 GPUs they include Tensor Cores which in some cases can accelerate computation by as much as 8x if certain conditions are met. One of the most important conditions is that your model parameters are stored as fp16. For more details on other conditions and tips to better utilize these cores please see this guide from NVIDIA on [Tips for Optimizing GPU Performance Using Tensor Cores](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/). + +--- +📌 **Note:** At the start of training you will probably see several log messages about loss scaling and overflows, this is normal. In order for fp16 training to be numerically stable we utilize a common technique called "loss scaling" (similar to Section 3.2 in [[3]](https://arxiv.org/pdf/1710.03740v3.pdf)). This attempts to find a scaling value to mitigate gradient over/under-flows during training. + +--- + +## 2.3 Zero Redundancy Optimizer (ZeRO) + +ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our three papers [[4](https://arxiv.org/pdf/1910.02054.pdf), [5](https://www.usenix.org/system/files/atc21-ren-jie.pdf), [6](https://arxiv.org/abs/2104.07857)] that explore different optimizations in this space. We will focus on two features of ZeRO here, ZeRO Stage 1 and ZeRO-Offload. For further information, please refer to our [tutorial deep diving ZeRO](https://www.deepspeed.ai/tutorials/zero/) and our [tutorial deep diving ZeRO Offload](https://www.deepspeed.ai/tutorials/zero-offload/) on our website. + +* ZeRO Stage 1: The optimizer states (e.g., for the Adam optimizer, 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. +* ZeRO-Offload: Supports efficiently offloading optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be trained on a single GPU. + +To enable ZeRO Stage 1 in DeepSpeed we need to again update our `ds_config` before the engine is created. Please see [ZeRO optimizations](https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training) in the DeepSpeed config documentation for more information. In our case let's simply enable stage 1 it by adding the following to our `ds_config` dictionary: + +```python + "zero_optimization": { + "stage": 1 + } +``` + +We can re-run our training now with ZeRO stage 1 enabled and should see a per-GPU memory reduction as we scale up the total number of GPUs. Typically you can now use this extra GPU memory to either scale up your model size or scale up your per-GPU training batch size. However, if we only have 1 GPU available we probably want to enable ZeRO-Offload to allow us to train larger model sizes. Please update your `ds_config` to include the following: + +```python + "zero_optimization": { + "stage": 1, + "offload_optimizer": { + "device": "cpu" + } + } +``` + +This config will now allow us to train a much larger model than we were previously able to do. For example on a single P40 GPU with 24GB of memory we are unable to train a 2 billion parameter model (i.e., `--num_layers 24 --h_dim 4096`), however with ZeRO-Offload we now can! + +```bash +deepspeed train_bert.py --checkpoint_dir . --num_layers 24 --h_dim 4096 +``` + +--- +📌 **Note:** Earlier on when we setup `deepspeed.initialize` we chose not to explicitly pass an optimizer and instead let the DeepSpeed engine instantiate one for us. This is especially useful now that we are using ZeRO-Offload. DeepSpeed includes a highly optimized version of Adam that executes purely on CPU. This means that DeepSpeed will detect if you are using ZeRO-Offload w. Adam and switch to optimized CPUAdam variant. + +--- + +## References +> [1] +[Vaswani et al. Attention is all you need. +In Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS'17)](https://arxiv.org/pdf/1706.03762.pdf) + +> [2] +[J. Devlin et al. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT'19)](https://aclanthology.org/N19-1423.pdf) + +> [3] +[P. Micikevicius et al. Mixed Precision Training (ICLR'18)](https://arxiv.org/pdf/1710.03740v3.pdf) -launching across multiple GPUs +> [4]> +[S. Rajbhandari, J. Rasley, O. Ruwase, Y. He. ZeRO: memory optimizations toward training trillion parameter models. (SC‘20)](https://arxiv.org/pdf/1910.02054.pdf) -fp16 - * how to enable - * show memory reduction when enabled via nvidia-smi - * brief overview of how fp16 training works (e.g., loss scaling) +> [5] +[J. Ren, S. Rajbhandari, R. Aminabadi, O. Ruwase, S. Yang, M. Zhang, D. Li, Y. He. ZeRO-Offload: Democratizing Billion-Scale Model Training. (ATC'21)](https://www.usenix.org/system/files/atc21-ren-jie.pdf) -zero - * introduce how zero reduces memory - * introduce zero offload - * update config to use z1 + offload to showcase a model that can only run with offload enabled +> [6] +[S. Rajbhandari, O. Ruwase, J. Rasley, S. Smith, Y. He. ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning (SC'21)](https://arxiv.org/abs/2104.07857) From 5edfa023a45faeb11e4b857f1b17ec26333ec1ce Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 20 Oct 2021 22:11:52 -0700 Subject: [PATCH 15/18] Update README.md --- Intro101/README.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/Intro101/README.md b/Intro101/README.md index 35dea05c5..a7db26930 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -2,6 +2,23 @@ In this tutorial, we will create and train a Transformer encoder on the Masked Language Modeling (MLM) task. Then we will show the changes necessary to integrate Deepspeed, and show some of the advantages of doing so. +Table of contents +================= + + + * [(1) Training a Transformer Encoder (BERT / Roberta) model for MLM](#1-training-a-transformer-encoder-bert--roberta-model-for-mlm) + * [1.0 Some Good Practices](#10-some-good-practices) + * [1.1 The Masked Language Modeling Task](#11-the-masked-language-modeling-task) + * [1.2 Creating a Transformer model](#12-creating-a-transformer-model) + * [1.3 Training the Model](#13-training-the-model) + * [(2) Integrating DeepSpeed For More Efficient Training](#2-integrating-deepspeed-for-more-efficient-training) + * [2.0 Core DeepSpeed Code Changes](#20-core-deepspeed-code-changes) + * [2.1 Launching Training](#21-launching-training) + * [2.2 Mixed Precision Training (fp16)](#22-mixed-precision-training-fp16) + * [2.3 Zero Redundancy Optimizer (ZeRO)](#23-zero-redundancy-optimizer-zero) + * [References](#references) + + ## 1. Training a Transformer Encoder (BERT / Roberta) model for MLM ### 1.0 Some Good Practices @@ -189,7 +206,7 @@ _, client_state = model.load_checkpoint(load_dir=load_checkpoint_dir) checkpoint_step = client_state['checkpoint_step'] ``` -## 2.1 Launching training +## 2.1 Launching Training We are now ready to launch our training! As a convenience, DeepSpeed provides its own launcher that is seamlessly compatible with internal clusters at MSFT (e.g., ITP). You can now try running your model on your available GPU(s) with the command below. By default this will attempt to run data-parallel training across all available GPUs on the current machine + any external machines listed in your `/job/hostfile`. Please read [more details about the DeepSpeed launcher](https://www.deepspeed.ai/getting-started/#launching-deepspeed-training) on our website. From ee7258b1b60f40e2676ed2e9f434b0ca1239a57a Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 20 Oct 2021 22:12:22 -0700 Subject: [PATCH 16/18] Update README.md --- Intro101/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Intro101/README.md b/Intro101/README.md index a7db26930..a88f0a2c4 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -1,6 +1,6 @@ -# Training an Masked Language Model with PyTorch and Deepspeed +# Training a Masked Language Model with PyTorch and DeepSpeed -In this tutorial, we will create and train a Transformer encoder on the Masked Language Modeling (MLM) task. Then we will show the changes necessary to integrate Deepspeed, and show some of the advantages of doing so. +In this tutorial, we will create and train a Transformer encoder on the Masked Language Modeling (MLM) task. Then we will show the changes necessary to integrate DeepSpeed, and show some of the advantages of doing so. Table of contents ================= From fffad1f6ceaf467519b319761c658627257ee69e Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Wed, 20 Oct 2021 22:14:16 -0700 Subject: [PATCH 17/18] Update train_bert.py --- Intro101/train_bert.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index 67c07c60b..c36f34152 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -725,6 +725,7 @@ def train( ) model = model.to(device) logger.info("Model Creation Done") + logger.info(f"Total number of model parameters: {sum([p.numel() for p in model.parameters()]):,d}") ################################ ###### Create Optimizer ####### ################################ From 3ec9c13bf11d3eca9d65a5d7fcf960e842d284c7 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 21 Oct 2021 05:49:13 +0000 Subject: [PATCH 18/18] add ds example code + a few minor edits --- Intro101/README.md | 4 +- Intro101/train_bert.py | 2 +- Intro101/train_bert_ds.py | 805 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 809 insertions(+), 2 deletions(-) create mode 100644 Intro101/train_bert_ds.py diff --git a/Intro101/README.md b/Intro101/README.md index a88f0a2c4..a046c0dad 100644 --- a/Intro101/README.md +++ b/Intro101/README.md @@ -182,8 +182,10 @@ This will create the DeepSpeed training engine based on the previously instantia Next we want to update our training-loop to use the new model engine with the following changes: +* `model.to(device)` can be removed + * DeepSpeed will be careful on when to move the model to GPU to reduce GPU memory usage (e.g., converts to half on CPU then moves to GPU) * `optimizer.zero_grad()` can be removed - * The DeepSpeed engine will do this for you at the right time. + * DeepSpeed will do this for you at the right time. * Replace `loss.backward()` with `model.backward(loss)` * There are several cases where the engine will properly scale the loss when using certain features (e.g., fp16, gradient-accumulation). * Replace `optimizer.step()` with `model.step()` diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py index c36f34152..14d61f00c 100644 --- a/Intro101/train_bert.py +++ b/Intro101/train_bert.py @@ -725,7 +725,6 @@ def train( ) model = model.to(device) logger.info("Model Creation Done") - logger.info(f"Total number of model parameters: {sum([p.numel() for p in model.parameters()]):,d}") ################################ ###### Create Optimizer ####### ################################ @@ -745,6 +744,7 @@ def train( ################################ ####### The Training Loop ###### ################################ + logger.info(f"Total number of model parameters: {sum([p.numel() for p in model.parameters()]):,d}") model.train() losses = [] for step, batch in enumerate(data_iterator, start=start_step): diff --git a/Intro101/train_bert_ds.py b/Intro101/train_bert_ds.py new file mode 100644 index 000000000..421d03daf --- /dev/null +++ b/Intro101/train_bert_ds.py @@ -0,0 +1,805 @@ +""" +Modified version of train_bert.py that adds DeepSpeed +""" + +import datetime +import json +import pathlib +import re +import string +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union + +import datasets +import fire +import loguru +import numpy as np +import pytz +import sh +import torch +import torch.nn as nn +import deepspeed +from torch.utils.data import DataLoader, Dataset +from torch.utils.tensorboard import SummaryWriter +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.models.roberta import RobertaConfig, RobertaModel +from transformers.models.roberta.modeling_roberta import ( + RobertaLMHead, + RobertaPreTrainedModel, +) + +logger = loguru.logger + +###################################################################### +############### Dataset Creation Related Functions ################### +###################################################################### + + +TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +def collate_function( + batch: List[Tuple[List[int], List[int]]], pad_token_id: int +) -> Dict[str, torch.Tensor]: + """Collect a list of masked token indices, and labels, and + batch them, padding to max length in the batch. + """ + max_length = max(len(token_ids) for token_ids, _ in batch) + padded_token_ids = [ + token_ids + [pad_token_id for _ in range(0, max_length - len(token_ids))] + for token_ids, _ in batch + ] + padded_labels = [ + labels + [pad_token_id for _ in range(0, max_length - len(labels))] + for _, labels in batch + ] + src_tokens = torch.LongTensor(padded_token_ids) + tgt_tokens = torch.LongTensor(padded_labels) + attention_mask = src_tokens.ne(pad_token_id).type_as(src_tokens) + return { + "src_tokens": src_tokens, + "tgt_tokens": tgt_tokens, + "attention_mask": attention_mask, + } + + +def masking_function( + text: str, + tokenizer: TokenizerType, + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + max_length: int, +) -> Tuple[List[int], List[int]]: + """Given a text string, randomly mask wordpieces for Bert MLM + training. + + Args: + text (str): + The input text + tokenizer (TokenizerType): + The tokenizer for tokenization + mask_prob (float): + What fraction of tokens to mask + random_replace_prob (float): + Of the masked tokens, how many should be replaced with + random tokens (improves performance) + unmask_replace_prob (float): + Of the masked tokens, how many should be replaced with + the original token (improves performance) + max_length (int): + The maximum sequence length to consider. Note that for + Bert style models, this is a function of the number of + positional embeddings you learn + + Returns: + Tuple[List[int], List[int]]: + The masked token ids (based on the tokenizer passed), + and the output labels (padded with `tokenizer.pad_token_id`) + """ + # Note: By default, encode does add the BOS and EOS token + # Disabling that behaviour to make this more clear + tokenized_ids = ( + [tokenizer.bos_token_id] + + tokenizer.encode( + text, add_special_tokens=False, truncation=True, max_length=max_length - 2 + ) + + [tokenizer.eos_token_id] + ) + seq_len = len(tokenized_ids) + tokenized_ids = np.array(tokenized_ids) + subword_mask = np.full(len(tokenized_ids), False) + + # Masking the BOS and EOS token leads to slightly worse performance + low = 1 + high = len(subword_mask) - 1 + mask_choices = np.arange(low, high) + num_subwords_to_mask = max(int((mask_prob * (high - low)) + np.random.rand()), 1) + subword_mask[ + np.random.choice(mask_choices, num_subwords_to_mask, replace=False) + ] = True + + # Create the labels first + labels = np.full(seq_len, tokenizer.pad_token_id) + labels[subword_mask] = tokenized_ids[subword_mask] + + tokenized_ids[subword_mask] = tokenizer.mask_token_id + + # Now of the masked tokens, choose how many to replace with random and how many to unmask + rand_or_unmask_prob = random_replace_prob + unmask_replace_prob + if rand_or_unmask_prob > 0: + rand_or_unmask = subword_mask & ( + np.random.rand(len(tokenized_ids)) < rand_or_unmask_prob + ) + if random_replace_prob == 0: + unmask = rand_or_unmask + rand_mask = None + elif unmask_replace_prob == 0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = unmask_replace_prob / rand_or_unmask_prob + decision = np.random.rand(len(tokenized_ids)) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + if unmask is not None: + tokenized_ids[unmask] = labels[unmask] + if rand_mask is not None: + weights = np.ones(tokenizer.vocab_size) + weights[tokenizer.all_special_ids] = 0 + probs = weights / weights.sum() + num_rand = rand_mask.sum() + tokenized_ids[rand_mask] = np.random.choice( + tokenizer.vocab_size, num_rand, p=probs + ) + return tokenized_ids.tolist(), labels.tolist() + + +class WikiTextMLMDataset(Dataset): + """A [Map style dataset](https://pytorch.org/docs/stable/data.html) + for iterating over the wikitext dataset. Note that this assumes + the dataset can fit in memory. For larger datasets + you'd want to shard them and use an iterable dataset (eg: see + [Infinibatch](https://github.com/microsoft/infinibatch)) + + Args: + Dataset (datasets.arrow_dataset.Dataset): + The wikitext dataset + masking_function (Callable[[str], Tuple[List[int], List[int]]]) + The masking function. To generate one training instance, + the masking function is applied to the `text` of a dataset + record + + """ + + def __init__( + self, + dataset: datasets.arrow_dataset.Dataset, + masking_function: Callable[[str], Tuple[List[int], List[int]]], + ) -> None: + self.dataset = dataset + self.masking_function = masking_function + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Tuple[List[int], List[int]]: + tokens, labels = self.masking_function(self.dataset[idx]["text"]) + return (tokens, labels) + + +T = TypeVar("T") + + +class InfiniteIterator(object): + def __init__(self, iterable: Iterable[T]) -> None: + self._iterable = iterable + self._iterator = iter(self._iterable) + + def __iter__(self): + return self + + def __next__(self) -> T: + next_item = None + try: + next_item = next(self._iterator) + except StopIteration: + self._iterator = iter(self._iterable) + next_item = next(self._iterator) + return next_item + + +def create_data_iterator( + mask_prob: float, + random_replace_prob: float, + unmask_replace_prob: float, + batch_size: int, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", +) -> InfiniteIterator: + """Create the dataloader. + + Args: + mask_prob (float): + Fraction of tokens to mask + random_replace_prob (float): + Fraction of masked tokens to replace with random token + unmask_replace_prob (float): + Fraction of masked tokens to replace with the actual token + batch_size (int): + The batch size of the generated tensors + max_seq_length (int, optional): + The maximum sequence length for the MLM task. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + + Returns: + InfiniteIterator: + The torch DataLoader, wrapped in an InfiniteIterator class, to + be able to continuously generate samples + + """ + wikitext_dataset = datasets.load_dataset("wikitext", "wikitext-2-v1", split="train") + wikitext_dataset = wikitext_dataset.filter(lambda record: record["text"] != "").map( + lambda record: {"text": record["text"].rstrip("\n")} + ) + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + masking_function_partial = partial( + masking_function, + tokenizer=tokenizer, + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + max_length=max_seq_length, + ) + dataset = WikiTextMLMDataset(wikitext_dataset, masking_function_partial) + collate_fn_partial = partial(collate_function, pad_token_id=tokenizer.pad_token_id) + dataloader = DataLoader( + dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_partial + ) + + return InfiniteIterator(dataloader) + + +###################################################################### +############### Model Creation Related Functions ##################### +###################################################################### + + +class RobertaLMHeadWithMaskedPredict(RobertaLMHead): + def __init__( + self, config: RobertaConfig, embedding_weight: Optional[torch.Tensor] = None + ) -> None: + super(RobertaLMHeadWithMaskedPredict, self).__init__(config) + if embedding_weight is not None: + self.decoder.weight = embedding_weight + + def forward( # pylint: disable=arguments-differ + self, + features: torch.Tensor, + masked_token_indices: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """The current `transformers` library does not provide support + for masked_token_indices. This function provides the support, by + running the final forward pass only for the masked indices. This saves + memory + + Args: + features (torch.Tensor): + The features to select from. Shape (batch, seq_len, h_dim) + masked_token_indices (torch.Tensor, optional): + The indices of masked tokens for index select. Defaults to None. + Shape: (num_masked_tokens,) + + Returns: + torch.Tensor: + The index selected features. Shape (num_masked_tokens, h_dim) + + """ + if masked_token_indices is not None: + features = torch.index_select( + features.view(-1, features.shape[-1]), 0, masked_token_indices + ) + return super().forward(features) + + +class RobertaMLMModel(RobertaPreTrainedModel): + def __init__(self, config: RobertaConfig, encoder: RobertaModel) -> None: + super().__init__(config) + self.encoder = encoder + self.lm_head = RobertaLMHeadWithMaskedPredict( + config, self.encoder.embeddings.word_embeddings.weight + ) + self.lm_head.apply(self._init_weights) + + def forward( + self, + src_tokens: torch.Tensor, + attention_mask: torch.Tensor, + tgt_tokens: torch.Tensor, + ) -> torch.Tensor: + """The forward pass for the MLM task + + Args: + src_tokens (torch.Tensor): + The masked token indices. Shape: (batch, seq_len) + attention_mask (torch.Tensor): + The attention mask, since the batches are padded + to the largest sequence. Shape: (batch, seq_len) + tgt_tokens (torch.Tensor): + The output tokens (padded with `config.pad_token_id`) + + Returns: + torch.Tensor: + The MLM loss + """ + # shape: (batch, seq_len, h_dim) + sequence_output, *_ = self.encoder( + input_ids=src_tokens, attention_mask=attention_mask, return_dict=False + ) + + pad_token_id = self.config.pad_token_id + # (labels have also been padded with pad_token_id) + # filter out all masked labels + # shape: (num_masked_tokens,) + masked_token_indexes = torch.nonzero( + (tgt_tokens != pad_token_id).view(-1) + ).view(-1) + # shape: (num_masked_tokens, vocab_size) + prediction_scores = self.lm_head(sequence_output, masked_token_indexes) + # shape: (num_masked_tokens,) + target = torch.index_select(tgt_tokens.view(-1), 0, masked_token_indexes) + + loss_fct = nn.CrossEntropyLoss(ignore_index=-1) + + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), target + ) + return masked_lm_loss + + +def create_model( + num_layers: int, num_heads: int, ff_dim: int, h_dim: int, dropout: float +) -> RobertaMLMModel: + """Create a Bert model with the specified `num_heads`, `ff_dim`, + `h_dim` and `dropout` + + Args: + num_layers (int): + The number of layers + num_heads (int): + The number of attention heads + ff_dim (int): + The intermediate hidden size of + the feed forward block of the + transformer + h_dim (int): + The hidden dim of the intermediate + representations of the transformer + dropout (float): + The value of dropout to be used. + Note that we apply the same dropout + to both the attention layers and the + FF layers + + Returns: + RobertaMLMModel: + A Roberta model for MLM task + + """ + roberta_config_dict = { + "attention_probs_dropout_prob": dropout, + "bos_token_id": 0, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout_prob": dropout, + "hidden_size": h_dim, + "initializer_range": 0.02, + "intermediate_size": ff_dim, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 514, + "model_type": "roberta", + "num_attention_heads": num_heads, + "num_hidden_layers": num_layers, + "pad_token_id": 1, + "type_vocab_size": 1, + "vocab_size": 50265, + } + roberta_config = RobertaConfig.from_dict(roberta_config_dict) + roberta_encoder = RobertaModel(roberta_config) + roberta_model = RobertaMLMModel(roberta_config, roberta_encoder) + return roberta_model + + +###################################################################### +########### Experiment Management Related Functions ################## +###################################################################### + + +def get_unique_identifier(length: int = 8) -> str: + """Create a unique identifier by choosing `length` + random characters from list of ascii characters and numbers + """ + alphabet = string.ascii_lowercase + string.digits + uuid = "".join(alphabet[ix] for ix in np.random.choice(len(alphabet), length)) + return uuid + + +def create_experiment_dir( + checkpoint_dir: pathlib.Path, all_arguments: Dict[str, Any] +) -> pathlib.Path: + """Create an experiment directory and save all arguments in it. + Additionally, also store the githash and gitdiff. Finally create + a directory for `Tensorboard` logs. The structure would look something + like + checkpoint_dir + `-experiment-name + |- hparams.json + |- githash.log + |- gitdiff.log + `- tb_dir/ + + Args: + checkpoint_dir (pathlib.Path): + The base checkpoint directory + all_arguments (Dict[str, Any]): + The arguments to save + + Returns: + pathlib.Path: The experiment directory + """ + # experiment name follows the following convention + # {exp_type}.{YYYY}.{MM}.{DD}.{HH}.{MM}.{SS}.{uuid} + current_time = datetime.datetime.now(pytz.timezone("US/Pacific")) + expname = "bert_pretrain.{0}.{1}.{2}.{3}.{4}.{5}.{6}".format( + current_time.year, + current_time.month, + current_time.day, + current_time.hour, + current_time.minute, + current_time.second, + get_unique_identifier(), + ) + exp_dir = checkpoint_dir / expname + exp_dir.mkdir(exist_ok=False) + hparams_file = exp_dir / "hparams.json" + with hparams_file.open("w") as handle: + json.dump(obj=all_arguments, fp=handle, indent=2) + # Save the git hash + try: + gitlog = sh.git.log("-1", format="%H", _tty_out=False, _fg=False) + with (exp_dir / "githash.log").open("w") as handle: + handle.write(gitlog.stdout.decode("utf-8")) + except sh.ErrorReturnCode_128: + logger.info( + "Seems like the code is not running from" + " within a git repo, so hash will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) + # And the git diff + try: + gitdiff = sh.git.diff(_fg=False, _tty_out=False) + with (exp_dir / "gitdiff.log").open("w") as handle: + handle.write(gitdiff.stdout.decode("utf-8")) + except sh.ErrorReturnCode_129: + logger.info( + "Seems like the code is not running from" + " within a git repo, so diff will" + " not be stored. However, it" + " is strongly advised to use" + " version control." + ) + # Finally create the Tensorboard Dir + tb_dir = exp_dir / "tb_dir" + tb_dir.mkdir() + return exp_dir + + +###################################################################### +################ Checkpoint Related Functions ######################## +###################################################################### + + +def load_model_checkpoint( + load_checkpoint_dir: pathlib.Path, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, +) -> Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + """Loads the optimizer state dict and model state dict from the load_checkpoint_dir + into the passed model and optimizer. Searches for the most recent checkpoint to + load from + + Args: + load_checkpoint_dir (pathlib.Path): + The base checkpoint directory to load from + model (torch.nn.Module): + The model to load the checkpoint weights into + optimizer (torch.optim.Optimizer): + The optimizer to load the checkpoint weigths into + + Returns: + Tuple[int, torch.nn.Module, torch.optim.Optimizer]: + The checkpoint step, model with state_dict loaded and + optimizer with state_dict loaded + + """ + logger.info(f"Loading model and optimizer checkpoint from {load_checkpoint_dir}") + checkpoint_files = list( + filter( + lambda path: re.search(r"iter_(?P\d+)\.pt", path.name) is not None, + load_checkpoint_dir.glob("*.pt"), + ) + ) + assert len(checkpoint_files) > 0, "No checkpoints found in directory" + checkpoint_files = sorted( + checkpoint_files, + key=lambda path: int( + re.search(r"iter_(?P\d+)\.pt", path.name).group("iter_no") + ), + ) + latest_checkpoint_path = checkpoint_files[-1] + checkpoint_step = int( + re.search(r"iter_(?P\d+)\.pt", latest_checkpoint_path.name).group( + "iter_no" + ) + ) + + state_dict = torch.load(latest_checkpoint_path) + model.load_state_dict(state_dict["model"], strict=True) + optimizer.load_state_dict(state_dict["optimizer"]) + logger.info( + f"Loading model and optimizer checkpoints done. Loaded from {latest_checkpoint_path}" + ) + return checkpoint_step, model, optimizer + + +###################################################################### +######################## Driver Functions ############################ +###################################################################### + + +def train( + checkpoint_dir: str = None, + load_checkpoint_dir: str = None, + # Dataset Parameters + mask_prob: float = 0.15, + random_replace_prob: float = 0.1, + unmask_replace_prob: float = 0.1, + max_seq_length: int = 512, + tokenizer: str = "roberta-base", + # Model Parameters + num_layers: int = 6, + num_heads: int = 8, + ff_dim: int = 512, + h_dim: int = 256, + dropout: float = 0.1, + # Training Parameters + batch_size: int = 8, + num_iterations: int = 10000, + checkpoint_every: int = 1000, + log_every: int = 10, + local_rank: int = -1, +) -> pathlib.Path: + """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf) + (transformer encoder only) model for MLM Task + + Args: + checkpoint_dir (str): + The base experiment directory to save experiments to + mask_prob (float, optional): + The fraction of tokens to mask. Defaults to 0.15. + random_replace_prob (float, optional): + The fraction of masked tokens to replace with random token. + Defaults to 0.1. + unmask_replace_prob (float, optional): + The fraction of masked tokens to leave unchanged. + Defaults to 0.1. + max_seq_length (int, optional): + The maximum sequence length of the examples. Defaults to 512. + tokenizer (str, optional): + The tokenizer to use. Defaults to "roberta-base". + num_layers (int, optional): + The number of layers in the Bert model. Defaults to 6. + num_heads (int, optional): + Number of attention heads to use. Defaults to 8. + ff_dim (int, optional): + Size of the intermediate dimension in the FF layer. + Defaults to 512. + h_dim (int, optional): + Size of intermediate representations. + Defaults to 256. + dropout (float, optional): + Amout of Dropout to use. Defaults to 0.1. + batch_size (int, optional): + The minibatch size. Defaults to 8. + num_iterations (int, optional): + Total number of iterations to run the model for. + Defaults to 10000. + checkpoint_every (int, optional): + Save checkpoint after these many steps. + + ..note :: + + You want this to be frequent enough that you can + resume training in case it crashes, but not so much + that you fill up your entire storage ! + + Defaults to 1000. + log_every (int, optional): + Print logs after these many steps. Defaults to 10. + local_rank (int, optional): + Which GPU to run on (-1 for CPU). Defaults to -1. + + Returns: + pathlib.Path: The final experiment directory + + """ + device = ( + torch.device("cuda", local_rank) + if (local_rank > -1) and torch.cuda.is_available() + else torch.device("cpu") + ) + ################################ + ###### Create Exp. Dir ######### + ################################ + if checkpoint_dir is None and load_checkpoint_dir is None: + logger.error("Need to specify one of checkpoint_dir" " or load_checkpoint_dir") + return + if checkpoint_dir is not None and load_checkpoint_dir is not None: + logger.error("Cannot specify both checkpoint_dir" " and load_checkpoint_dir") + return + if checkpoint_dir: + logger.info("Creating Experiment Directory") + checkpoint_dir = pathlib.Path(checkpoint_dir) + checkpoint_dir.mkdir(exist_ok=True) + all_arguments = { + # Dataset Params + "mask_prob": mask_prob, + "random_replace_prob": random_replace_prob, + "unmask_replace_prob": unmask_replace_prob, + "max_seq_length": max_seq_length, + "tokenizer": tokenizer, + # Model Params + "num_layers": num_layers, + "num_heads": num_heads, + "ff_dim": ff_dim, + "h_dim": h_dim, + "dropout": dropout, + # Training Params + "batch_size": batch_size, + "num_iterations": num_iterations, + "checkpoint_every": checkpoint_every, + } + exp_dir = create_experiment_dir(checkpoint_dir, all_arguments) + logger.info(f"Experiment Directory created at {exp_dir}") + else: + logger.info("Loading from Experiment Directory") + load_checkpoint_dir = pathlib.Path(load_checkpoint_dir) + assert load_checkpoint_dir.exists() + with (load_checkpoint_dir / "hparams.json").open("r") as handle: + hparams = json.load(handle) + # Set the hparams + # Dataset Params + mask_prob = hparams.get("mask_prob", mask_prob) + tokenizer = hparams.get("tokenizer", tokenizer) + random_replace_prob = hparams.get("random_replace_prob", random_replace_prob) + unmask_replace_prob = hparams.get("unmask_replace_prob", unmask_replace_prob) + max_seq_length = hparams.get("max_seq_length", max_seq_length) + # Model Params + ff_dim = hparams.get("ff_dim", ff_dim) + h_dim = hparams.get("h_dim", h_dim) + dropout = hparams.get("dropout", dropout) + num_layers = hparams.get("num_layers", num_layers) + num_heads = hparams.get("num_heads", num_heads) + # Training Params + batch_size = hparams.get("batch_size", batch_size) + _num_iterations = hparams.get("num_iterations", num_iterations) + num_iterations = max(num_iterations, _num_iterations) + checkpoint_every = hparams.get("checkpoint_every", checkpoint_every) + exp_dir = load_checkpoint_dir + # Tensorboard writer + tb_dir = exp_dir / "tb_dir" + assert tb_dir.exists() + summary_writer = SummaryWriter(log_dir=tb_dir) + ################################ + ###### Create Datasets ######### + ################################ + logger.info("Creating Datasets") + data_iterator = create_data_iterator( + mask_prob=mask_prob, + random_replace_prob=random_replace_prob, + unmask_replace_prob=unmask_replace_prob, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + batch_size=batch_size, + ) + logger.info("Dataset Creation Done") + ################################ + ###### Create Model ############ + ################################ + logger.info("Creating Model") + model = create_model( + num_layers=num_layers, + num_heads=num_heads, + ff_dim=ff_dim, + h_dim=h_dim, + dropout=dropout, + ) + logger.info("Model Creation Done") + ################################ + ###### DeepSpeed engine ######## + ################################ + logger.info("Creating DeepSpeed engine") + ds_config = { + "train_micro_batch_size_per_gpu": batch_size, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": True + }, + "zero_optimization": { + "stage": 1, + "offload_optimizer": { + "device": "cpu" + } + } + } + model, _, _, _ = deepspeed.initialize(model=model, + model_parameters=model.parameters(), + config=ds_config) + logger.info("DeepSpeed engine created") + ################################ + #### Load Model checkpoint ##### + ################################ + start_step = 1 + if load_checkpoint_dir is not None: + _, client_state = model.load_checkpoint(load_dir=load_checkpoint_dir) + checkpoint_step = client_state['checkpoint_step'] + start_step = checkpoint_step + 1 + + ################################ + ####### The Training Loop ###### + ################################ + logger.info(f"Total number of model parameters: {sum([p.numel() for p in model.parameters()]):,d}") + model.train() + losses = [] + for step, batch in enumerate(data_iterator, start=start_step): + if step >= num_iterations: + break + # Move the tensors to device + for key, value in batch.items(): + batch[key] = value.to(device) + # Forward pass + loss = model(**batch) + # Backward pass + model.backward(loss) + # Optimizer Step + model.step() + losses.append(loss.item()) + if step % log_every == 0: + logger.info("Loss: {0:.4f}".format(np.mean(losses))) + summary_writer.add_scalar(f"Train/loss", np.mean(losses), step) + if step % checkpoint_every == 0: + model.save_checkpoint(save_dir=exp_dir, client_state={'checkpoint_step': step}) + logger.info( + "Saved model to {0}".format(exp_dir) + ) + # Save the last checkpoint if not saved yet + if step % checkpoint_every != 0: + model.save_checkpoint(save_dir=exp_dir, client_state={'checkpoint_step': step}) + logger.info( + "Saved model to {0}".format(exp_dir) + ) + + return exp_dir + + +if __name__ == "__main__": + fire.Fire(train)