1 change: 1 addition & 0 deletions Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/README.md
@@ -0,0 +1 @@
+This is an example of how to use DeepSpeed's curriculum learning (CL) feature, which provides faster and more stable language model pre-training. Currently it is only integrated for GPT pre-training. Note that there are two curriculum learning examples, in two different repos, for Megatron-LM GPT-2 pre-training; each has its own features and limitations. See details in our [tutorial](https://www.deepspeed.ai/tutorials/curriculum-learning/). For technical details please refer to our [paper](https://arxiv.org/abs/2108.06084).
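
For orientation, the curriculum-learning block that these scripts template into the ds_zero_stage_*_config_*.json files looks roughly like the following, shown here as a Python dict. The keys follow the DeepSpeed curriculum-learning tutorial; difficulty_step is an assumed value, and everything should be checked against your DeepSpeed version. The other values echo ds_train.sh below.

curriculum_learning = {
    "enabled": True,
    "curriculum_type": "seqlen",         # difficulty = sequence length
    "min_difficulty": 64,                # CURRICULUM_MIN in ds_train.sh
    "max_difficulty": 1024,              # SEQ_LEN in ds_train.sh
    "schedule_type": "fixed_linear",
    "schedule_config": {
        "total_curriculum_step": 45000,  # CURRICULUM_STEP in ds_train.sh
        "difficulty_step": 8,            # assumed; keeps seqlen a multiple of 8
    },
}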
6 changes: 3 additions & 3 deletions Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_pretrain_gpt2.sh
@@ -11,7 +11,7 @@ SEED=$8
 SAVE_INTERVAL=$9
 NUM_ITER=${10}
 NUM_TOKEN=${11}
-LR_DECAY_ITER=${12}
+LR_DECAY_TOKEN=${12}
 LR_WARMUP_ITER=${13}
 CONFIG_TEMPLATE=${14}
 CURRICULUM_STEP=${15}
@@ -74,7 +74,7 @@ else
 config_json="$script_dir/ds_zero_stage_${stage}_config_${CONFIG}.json"
 fi

-JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_ITER}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}"
+JOB_NAME="gpt2_${MODEL_SIZE}M_bsz${TOTAL_BATCHSIZE}_seq${SEQ_LEN}_lr${LR}_warmup${LR_WARMUP_ITER}_decay${LR_DECAY_TOKEN}_seed${SEED}_${TAG}_stage${stage}_n${NUM_WORKERS}_g${NUM_GPUS_PER_WORKER}_mp${MP_SIZE}"
 LOG_NAME="${JOB_NAME}_${host}_${current_time}"

 #Activation Checkpointing and Contiguous Memory
@@ -102,7 +102,7 @@ gpt_options=" \
 --batch-size $BATCHSIZE \
 --train-iters $NUM_ITER \
 --train-tokens $NUM_TOKEN \
---lr-decay-iters $LR_DECAY_ITER \
+--lr-decay-tokens $LR_DECAY_TOKEN \
 --save $CHECKPOINT_PATH \
 --load $CHECKPOINT_PATH \
 --data-path $DATA_PATH \
11 changes: 4 additions & 7 deletions Megatron-LM-v1.1.5-ZeRO3/curriculum_learning/ds_train.sh
@@ -8,10 +8,9 @@
 # MP_SIZE=1
 # SEED=1234
 # SAVE_INTERVAL=5000
-
 # NUM_ITER=600000
 # NUM_TOKEN=157286400000
-# LR_DECAY_ITER=300000
+# LR_DECAY_TOKEN=157286400000
 # LR_WARMUP_ITER=3000
 # CONFIG_TEMPLATE=false
 # CURRICULUM_STEP=0
@@ -26,15 +25,13 @@ SEQ_LEN=1024
 MP_SIZE=1
 SEED=1234
 SAVE_INTERVAL=1000
-
 NUM_ITER=75000
 NUM_TOKEN=157286400000
+LR_DECAY_TOKEN=157286400000
 LR_WARMUP_ITER=3000
 CONFIG_TEMPLATE=true
-CURRICULUM_STEP=15000
+CURRICULUM_STEP=45000
 CURRICULUM_MIN=64

-LR_DECAY_ITER=$((37500 + ${CURRICULUM_STEP} / 2))
 TAG="${CONFIG}_s${CURRICULUM_MIN}to${SEQ_LEN}_step${CURRICULUM_STEP}"

-bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_ITER $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN
+bash ds_pretrain_gpt2.sh $CONFIG $TAG $MODEL_SIZE $LR $BSZ $SEQ_LEN $MP_SIZE $SEED $SAVE_INTERVAL $NUM_ITER $NUM_TOKEN $LR_DECAY_TOKEN $LR_WARMUP_ITER $CONFIG_TEMPLATE $CURRICULUM_STEP $CURRICULUM_MIN
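
Note what the deleted heuristic was doing: with the old CURRICULUM_STEP=15000, LR_DECAY_ITER=$((37500 + 15000 / 2)) evaluates to 45000 iterations, apparently a hand correction for the shorter sequences (hence fewer tokens) consumed during the curriculum ramp. Decaying by tokens instead (LR_DECAY_TOKEN=157286400000, the full token budget) makes that correction unnecessary.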
2 changes: 2 additions & 0 deletions Megatron-LM-v1.1.5-ZeRO3/megatron/arguments.py
@@ -296,6 +296,8 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-decay-iters', type=int, default=None,
                        help='number of iterations to decay learning rate over.'
                        ' If None, defaults to `--train-iters`.')
+    group.add_argument('--lr-decay-tokens', type=int, default=None,
+                       help='number of tokens to decay learning rate over.')
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minimum value for learning rate. The scheduler '
                        'clips values below this threshold.')
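
A minimal illustration of how the two flags coexist (a hypothetical snippet, not Megatron's actual setup code); the scheduler changes below dispatch on whether --lr-decay-tokens was supplied:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr-decay-iters', type=int, default=None)
parser.add_argument('--lr-decay-tokens', type=int, default=None)

args = parser.parse_args(['--lr-decay-tokens', '157286400000'])
# Mirrors the dispatch in AnnealingLR.get_lr(): token-based decay is used
# whenever --lr-decay-tokens is given, otherwise iteration-based decay.
print(args.lr_decay_tokens is not None)  # True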
3 changes: 2 additions & 1 deletion Megatron-LM-v1.1.5-ZeRO3/megatron/checkpointing.py
@@ -268,7 +268,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     else:
         try:
             iteration = state_dict['iteration']
-            args.tokens = state_dict['tokens']
+            if 'tokens' in state_dict:
+                args.tokens = state_dict['tokens']
         except KeyError:
             try: # Backward compatible with older checkpoints
                 iteration = state_dict['total_iters']
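
Why the guard matters, in toy form (the checkpoint dicts are hypothetical stand-ins): a checkpoint written before this change has no 'tokens' key, so the previous unconditional read raised KeyError and dropped the load into the legacy 'total_iters' fallback, which a current-format checkpoint does not satisfy.

old_ckpt = {'iteration': 300000}                          # pre-change checkpoint
new_ckpt = {'iteration': 300000, 'tokens': 157286400000}  # post-change checkpoint

for state_dict in (old_ckpt, new_ckpt):
    iteration = state_dict['iteration']
    if 'tokens' in state_dict:  # same guard as above
        args_tokens = state_dict['tokens']
    else:
        args_tokens = 0         # older checkpoints simply keep the default
    print(iteration, args_tokens)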
53 changes: 38 additions & 15 deletions Megatron-LM-v1.1.5-ZeRO3/megatron/learning_rates.py
@@ -17,7 +17,7 @@

 import math

-from megatron import print_rank_0
+from megatron import print_rank_0, get_args


 class AnnealingLR(object):
@@ -28,7 +28,7 @@ def __init__(self, optimizer, start_lr,
                  decay_style, last_iter, min_lr=0.0,
                  use_checkpoint_lr_scheduler=True,
                  override_lr_scheduler=False):
-
+        args = get_args()
         # Class values.
         self.optimizer = optimizer
         self.start_lr = start_lr
@@ -37,14 +37,17 @@ self.num_iters = last_iter
         self.num_iters = last_iter
         self.end_iter = total_iters
         assert self.end_iter > 0
+        self.lr_decay_tokens = args.lr_decay_tokens
+        self.num_tokens = 0
+        self.warmup_tokens = 0
         self.decay_style = decay_style
         self.override_lr_scheduler = override_lr_scheduler
         self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
         if self.override_lr_scheduler:
             assert not self.use_checkpoint_lr_scheduler, 'both override and '\
                 'use-checkpoint are set.'
         # Set the learning rate
-        self.step(self.num_iters)
+        self.step(self.num_iters, self.num_tokens)

         print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
@@ -53,16 +56,26 @@ def get_lr(self):
         https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

         # Warmup.
-        if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
-            return float(self.start_lr) * self.num_iters / self.warmup_iter
-
-        # For any iterations larger than `self.end_iter`, use `self.min_lr`.
-        if self.num_iters > self.end_iter:
-            return self.min_lr
-        # If we are done with the warmup period, use the decay style.
-        current_iter = self.num_iters - self.warmup_iter
-        decay_iter = self.end_iter - self.warmup_iter
-        decay_ratio = float(current_iter) / float(decay_iter)
+        if self.warmup_iter > 0:
+            if self.num_iters == self.warmup_iter and self.lr_decay_tokens is not None:
+                self.warmup_tokens = self.num_tokens
+            if self.num_iters <= self.warmup_iter:
+                return float(self.start_lr) * self.num_iters / self.warmup_iter
+
+        if self.lr_decay_tokens is None:
+            # For any iterations larger than `self.end_iter`, use `self.min_lr`.
+            if self.num_iters > self.end_iter:
+                return self.min_lr
+            # If we are done with the warmup period, use the decay style.
+            current_iter = self.num_iters - self.warmup_iter
+            decay_iter = self.end_iter - self.warmup_iter
+            decay_ratio = float(current_iter) / float(decay_iter)
+        else:
+            if self.num_tokens > self.lr_decay_tokens:
+                return self.min_lr
+            current_tokens = self.num_tokens - self.warmup_tokens
+            decay_tokens = self.lr_decay_tokens - self.warmup_tokens
+            decay_ratio = float(current_tokens) / float(decay_tokens)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0

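For intuition, a small standalone sketch of the token-based branch above. Only the decay-ratio computation mirrors the diff; the cosine formula is assumed for illustration, since the decay-style branch itself is outside this hunk:

import math

def token_based_lr(start_lr, min_lr, num_tokens, warmup_tokens, lr_decay_tokens):
    """Sketch of AnnealingLR.get_lr() when lr_decay_tokens is set."""
    if num_tokens > lr_decay_tokens:
        return min_lr
    current_tokens = num_tokens - warmup_tokens
    decay_tokens = lr_decay_tokens - warmup_tokens
    decay_ratio = float(current_tokens) / float(decay_tokens)
    # Assumed cosine decay; linear and other styles would differ only here.
    lr = 0.5 * start_lr * (1.0 + math.cos(math.pi * decay_ratio))
    return max(lr, min_lr)

# Halfway through a 157.3B-token decay window with no warmup tokens, the
# cosine schedule sits at half the peak LR:
print(token_based_lr(1.5e-4, 1e-5, 78_643_200_000, 0, 157_286_400_000))  # ~7.5e-05
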
@@ -78,11 +91,15 @@ def get_lr(self):
             lr = self.start_lr
         return max(lr, self.min_lr)

-    def step(self, step_num=None):
+    def step(self, step_num=None, token_num=None):
         """Set lr for all parameters groups."""
+        args = get_args()
         if step_num is None:
             step_num = self.num_iters + 1
+        if token_num is None:
+            token_num = args.tokens
         self.num_iters = step_num
+        self.num_tokens = token_num
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
@@ -92,6 +109,8 @@ def state_dict(self):
             'start_lr': self.start_lr,
             'warmup_iter': self.warmup_iter,
             'num_iters': self.num_iters,
+            'warmup_tokens': self.warmup_tokens,
+            'num_tokens': self.num_tokens,
             'decay_style': self.decay_style,
             'end_iter': self.end_iter,
             'min_lr': self.min_lr
@@ -128,4 +147,8 @@ def load_state_dict(self, sd):
                                   'decay style')

         self.num_iters = sd['num_iters']
-        self.step(self.num_iters)
+        if 'warmup_tokens' in sd:
+            self.warmup_tokens = sd['warmup_tokens']
+        if 'num_tokens' in sd:
+            self.num_tokens = sd['num_tokens']
+        self.step(self.num_iters, self.num_tokens)
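
Finally, the motivation in numbers: under a seqlen curriculum, early iterations consume far fewer tokens, so an iteration-based decay clock runs fast relative to the data actually seen. A hypothetical calculation (the batch size and exact DeepSpeed schedule shape are assumed; the other numbers match ds_train.sh):

batch_size = 512                             # assumed; BSZ is set elsewhere
full_seq, min_seq, curriculum_steps = 1024, 64, 45000

tokens = 0
for it in range(1, curriculum_steps + 1):
    # Linear ramp of sequence length, roughly DeepSpeed's fixed_linear schedule.
    seqlen = min(full_seq, min_seq + (full_seq - min_seq) * it // curriculum_steps)
    tokens += batch_size * seqlen
print(tokens, batch_size * full_seq * curriculum_steps)
# The curriculum run sees roughly half the tokens of a fixed-seqlen run over the
# same 45000 iterations, which is the gap that token-based LR decay absorbs.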