diff --git a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md new file mode 100644 index 000000000..a80e3510c --- /dev/null +++ b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md @@ -0,0 +1 @@ +This is an example of how to use DeepSpeed's curriculum learning (CL) feature, which provides faster and more stable language model pre-training. Currently, it is integrated only for GPT pre-training. Note that there are two curriculum learning examples in two different repos for Megatron-LM GPT-2 pre-training. Each of them has some unique features and limitations. See details in our [tutorial](https://www.deepspeed.ai/tutorials/curriculum-learning/). For technical details, please refer to our [paper](https://arxiv.org/abs/2108.06084). \ No newline at end of file diff --git a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh index 959af6813..338b93f42 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh +++ b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh @@ -11,7 +11,7 @@ SEED=$8 SAVE_INTERVAL=$9 NUM_ITER=${10} NUM_TOKEN=${11} -LR_DECAY_ITER=${12} +LR_DECAY_TOKEN=${12} LR_WARMUP_ITER=${13} CONFIG_TEMPLATE=${14} CURRICULUM_STEP=${15} @@ -74,7 +74,7 @@ else config_json="$script_dir/ds_zero_stage_${stage}_config_${CONFIG}.json" fi -JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_ITER}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}" +JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_TOKEN}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}" LOG_NAME="${JOB_NAME}_${host}_${current_time}" #Actication Checkpointing and Contigious Memory @@ -102,7 +102,7 @@ gpt_options=" \ --batch-size 
$BATCHSIZE \ --train-iters $NUM_ITER \ --train-tokens $NUM_TOKEN \ - --lr-decay-iters $LR_DECAY_ITER \ + --lr-decay-tokens $LR_DECAY_TOKEN \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ --data-path $DATA_PATH \ diff --git a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh index ff7e7e58b..aac11ab03 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh +++ b/Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh @@ -8,10 +8,9 @@ # MP_SIZE=1 # SEED=1234 # SAVE_INTERVAL=5000 - # NUM_ITER=600000 # NUM_TOKEN=157286400000 -# LR_DECAY_ITER=300000 +# LR_DECAY_TOKEN=157286400000 # LR_WARMUP_ITER=3000 # CONFIG_TEMPLATE=false # CURRICULUM_STEP=0 @@ -26,15 +25,13 @@ SEQ_LEN=1024 MP_SIZE=1 SEED=1234 SAVE_INTERVAL=1000 - NUM_ITER=75000 NUM_TOKEN=157286400000 +LR_DECAY_TOKEN=157286400000 LR_WARMUP_ITER=3000 CONFIG_TEMPLATE=true -CURRICULUM_STEP=15000 +CURRICULUM_STEP=45000 CURRICULUM_MIN=64 - -LR_DECAY_ITER=$((37500 + ${CURRICULUM_STEP} / 2)) TAG="${CONFIG}_s${CURRICULUM_MIN}to${SEQ_LEN}_step${CURRICULUM_STEP}" -bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_ITER $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN +bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_TOKEN $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py index bb1cb6779..f95020af5 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py @@ -296,6 +296,8 @@ def _add_learning_rate_args(parser): group.add_argument('--lr-decay-iters', type=int, default=None, help='number of iterations to decay learning rate over,' ' If None defaults to `--train-iters`') + 
group.add_argument('--lr-decay-tokens', type=int, default=None, + help='Learning rate decay tokens.') group.add_argument('--min-lr', type=float, default=0.0, help='Minumum value for learning rate. The scheduler' 'clip values below this threshold.') diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py index 123a805e1..fd8eb6cb1 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py @@ -268,7 +268,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): else: try: iteration = state_dict['iteration'] - args.tokens = state_dict['tokens'] + if 'tokens' in state_dict: + args.tokens = state_dict['tokens'] except KeyError: try: # Backward compatible with older checkpoints iteration = state_dict['total_iters'] diff --git a/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py b/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py index afc5a8d77..19be32b8c 100644 --- a/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py +++ b/Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py @@ -17,7 +17,7 @@ import math -from megatron import print_rank_0 +from megatron import print_rank_0, get_args class AnnealingLR(object): @@ -28,7 +28,7 @@ def __init__(self, optimizer, start_lr, decay_style, last_iter, min_lr=0.0, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): - + args = get_args() # Class values. 
self.optimizer = optimizer self.start_lr = start_lr @@ -37,6 +37,9 @@ def __init__(self, optimizer, start_lr, self.num_iters = last_iter self.end_iter = total_iters assert self.end_iter > 0 + self.lr_decay_tokens = args.lr_decay_tokens + self.num_tokens = 0 + self.warmup_tokens = 0 self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler @@ -44,7 +47,7 @@ def __init__(self, optimizer, start_lr, assert not self.use_checkpoint_lr_scheduler, 'both override and '\ 'use-checkpoint are set.' # Set the learning rate - self.step(self.num_iters) + self.step(self.num_iters, self.num_tokens) print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) @@ -53,16 +56,26 @@ def get_lr(self): https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" # Warmup. - if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter: - return float(self.start_lr) * self.num_iters / self.warmup_iter - - # For any iterations larger than `self.end_iter`, use `self.min_lr`. - if self.num_iters > self.end_iter: - return self.min_lr - # If we are done with the warmup period, use the decay style. - current_iter = self.num_iters - self.warmup_iter - decay_iter = self.end_iter - self.warmup_iter - decay_ratio = float(current_iter) / float(decay_iter) + if self.warmup_iter > 0: + if self.num_iters == self.warmup_iter and self.lr_decay_tokens is not None: + self.warmup_tokens = self.num_tokens + if self.num_iters <= self.warmup_iter: + return float(self.start_lr) * self.num_iters / self.warmup_iter + + if self.lr_decay_tokens is None: + # For any iterations larger than `self.end_iter`, use `self.min_lr`. + if self.num_iters > self.end_iter: + return self.min_lr + # If we are done with the warmup period, use the decay style. 
+ current_iter = self.num_iters - self.warmup_iter + decay_iter = self.end_iter - self.warmup_iter + decay_ratio = float(current_iter) / float(decay_iter) + else: + if self.num_tokens > self.lr_decay_tokens: + return self.min_lr + current_tokens = self.num_tokens - self.warmup_tokens + decay_tokens = self.lr_decay_tokens - self.warmup_tokens + decay_ratio = float(current_tokens) / float(decay_tokens) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 @@ -78,11 +91,15 @@ def get_lr(self): lr = self.start_lr return max(lr, self.min_lr) - def step(self, step_num=None): + def step(self, step_num=None, token_num=None): """Set lr for all parameters groups.""" + args = get_args() if step_num is None: step_num = self.num_iters + 1 + if token_num is None: + token_num = args.tokens self.num_iters = step_num + self.num_tokens = token_num new_lr = self.get_lr() for group in self.optimizer.param_groups: group['lr'] = new_lr @@ -92,6 +109,8 @@ def state_dict(self): 'start_lr': self.start_lr, 'warmup_iter': self.warmup_iter, 'num_iters': self.num_iters, + 'warmup_tokens': self.warmup_tokens, + 'num_tokens': self.num_tokens, 'decay_style': self.decay_style, 'end_iter': self.end_iter, 'min_lr': self.min_lr @@ -128,4 +147,8 @@ def load_state_dict(self, sd): 'decay style') self.num_iters = sd['num_iters'] - self.step(self.num_iters) + if 'warmup_tokens' in sd: + self.warmup_tokens = sd['warmup_tokens'] + if 'num_tokens' in sd: + self.num_tokens = sd['num_tokens'] + self.step(self.num_iters, self.num_tokens)