From e8d4d1cb0f88b6e0b4bf662fcbb182b71c2f4b3e Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Wed, 21 Jul 2021 20:26:21 +0200 Subject: [PATCH 01/13] Integrate EleutherAI's version of rotary embeddings + make some small optimisation --- megatron/model/enums.py | 3 ++ megatron/model/positional_embeddings.py | 51 +++++++++++++++++++++++++ megatron/model/transformer.py | 22 ++++++++++- 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 megatron/model/positional_embeddings.py diff --git a/megatron/model/enums.py b/megatron/model/enums.py index b6992fefa..2b61aa98a 100644 --- a/megatron/model/enums.py +++ b/megatron/model/enums.py @@ -26,3 +26,6 @@ class AttnType(enum.Enum): class AttnMaskType(enum.Enum): padding = 1 causal = 2 + +class PositionalEmbedding(enum.Enum): + rotary = 1 diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py new file mode 100644 index 000000000..8ca31dbf2 --- /dev/null +++ b/megatron/model/positional_embeddings.py @@ -0,0 +1,51 @@ +# Extracted from: https://github.com/EleutherAI/gpt-neox +import torch + + +class RotaryEmbedding(torch.nn.Module): + + def __init__(self, dim, base=10000, precision=torch.half): + super().__init__() + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + self.max_seq_len_cached = None + self.cos_cached = None + self.sin_cached = None + self.precision = precision + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq) + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + if self.precision == torch.bfloat16: + emb = emb.float() + # [sx, 1 (b), 1 (np), hn] + self.cos_cached = emb.cos()[:, None, None, :] + self.sin_cached = emb.sin()[:, None, None, :] + if self.precision == torch.bfloat16: + self.cos_cached = self.cos_cached.bfloat16() + self.sin_cached = self.sin_cached.bfloat16() + return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] + + +# rotary pos emb helpers: + +def rotate_half(x): + x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions + + +@torch.jit.script +def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] + return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) + + +def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16 + cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] 
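# Note on the helper above and the return below: rotate_half pairs feature i
# with feature i + dim/2, whereas the RoFormer paper rotates interleaved
# pairs (2i, 2i + 1). Because the frequency vector is duplicated when the
# cache is built (emb = torch.cat((freqs, freqs), dim=-1)), both halves see
# the same angles, so each (i, i + dim/2) pair still undergoes an ordinary
# 2D rotation and the attention scores come out equivalent up to a fixed
# permutation of feature pairs. The slicing by `offset` just above selects
# the angles for the current query window during incremental decoding.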
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) \ No newline at end of file diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 47433d725..ef62300f1 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -17,11 +17,12 @@ import math import torch import torch.nn.functional as F +from torch import nn from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import AttnMaskType, LayerType, AttnType +from megatron.model.enums import AttnMaskType, LayerType, AttnType, PositionalEmbedding from megatron.model import LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -30,6 +31,8 @@ import deepspeed # flags required to enable jit fusion kernels +from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb + torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) @@ -119,6 +122,7 @@ def __init__(self, init_method, args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 + self.position_embeddings = args.position_embeddings self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 @@ -192,6 +196,9 @@ def __init__(self, init_method, get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint + if self.position_embeddings == PositionalEmbedding.rotary: + self.rotary_emb = RotaryEmbedding(args.hidden_size, precision=args.params_dtype) + def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): # hidden_states: [sq, b, h] @@ -274,6 +281,19 @@ def forward(self, hidden_states, attention_mask, layer_past=None, dtype=query_layer.dtype, device=torch.cuda.current_device()) + # Rotary embeddings + if self.position_embeddings == PositionalEmbedding.rotary: + query_rot, key_rot = query_layer, key_layer + apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb + + seq_len = key_layer.shape[0] + offset = 0 + if layer_past is not None and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_rotary_fn(query_rot, key_rot, cos, sin, offset=offset) + # Raw attention scores. 
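# (The rotary hunk above replaces an additive position embedding: q and k are
#  rotated in place, and since RoPE makes the q.k dot product depend only on
#  relative positions, the plain matmul below is already position-aware. The
#  offset equals the cached past length, so previously cached keys keep their
#  original positions during incremental decoding.)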
[b * np, sq, sk] matmul_result = torch.baddbmm( matmul_result, From 7844641fb3fa32a7e115d7f13278ef51562481ba Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Wed, 21 Jul 2021 20:43:00 +0200 Subject: [PATCH 02/13] Add argument parser for position embeddings --- megatron/arguments.py | 6 ++++++ megatron/model/transformer.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 0c2811117..73dff2f72 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -21,6 +21,9 @@ import torch import deepspeed +from megatron.model.enums import PositionalEmbedding + + def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False): """Parse all arguments.""" @@ -301,6 +304,9 @@ def _add_network_size_args(parser): group.add_argument('--bert-no-binary-head', action='store_false', help='Disable BERT binary head.', dest='bert_binary_head') + group.add_argument('--positional-embeddings', type=PositionalEmbedding, choices=list(PositionalEmbedding), + help='Define positional embeddings strategy.' + ) return parser diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index ef62300f1..73011eedd 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -30,9 +30,9 @@ import deepspeed -# flags required to enable jit fusion kernels from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb +# flags required to enable jit fusion kernels torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_executor(False) torch._C._jit_override_can_fuse_on_cpu(True) @@ -198,6 +198,8 @@ def __init__(self, init_method, if self.position_embeddings == PositionalEmbedding.rotary: self.rotary_emb = RotaryEmbedding(args.hidden_size, precision=args.params_dtype) + else: + raise ValueError("Temporary in order to make sure that argparser works perfectly.") def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): From 836d0440e676f78d8013191d15e587e071c0faae Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Wed, 21 Jul 2021 21:18:17 +0200 Subject: [PATCH 03/13] Making max-absolute-position-embeddings optional --- examples/create_embeddings.sh | 2 +- examples/evaluate_ict_zeroshot_nq.sh | 2 +- examples/evaluate_zeroshot_gpt.sh | 2 +- examples/finetune_mnli_distributed.sh | 2 +- examples/finetune_race_distributed.sh | 2 +- examples/generate_text.sh | 2 +- examples/merge_mp_bert.sh | 2 +- examples/pretrain_bert.sh | 2 +- examples/pretrain_bert_distributed.sh | 2 +- examples/pretrain_bert_distributed_with_mp.sh | 2 +- examples/pretrain_gpt.sh | 2 +- examples/pretrain_gpt3_175B.sh | 2 +- examples/pretrain_gpt_distributed.sh | 2 +- examples/pretrain_gpt_distributed_with_mp.sh | 2 +- examples/pretrain_gpt_tiny.sh | 2 +- examples/pretrain_ict.sh | 2 +- examples/pretrain_t5.sh | 2 +- examples/pretrain_t5_distributed.sh | 2 +- examples/pretrain_t5_distributed_with_mp.sh | 2 +- megatron/arguments.py | 24 ++++---- megatron/checkpointing.py | 2 +- megatron/model/enums.py | 3 +- megatron/model/gpt_model.py | 2 - megatron/model/language_model.py | 60 ++++++++++--------- megatron/model/transformer.py | 8 +-- megatron/mpu/random.py | 4 +- run.sh | 2 +- tools/merge_mp_partitions.py | 2 +- 28 files changed, 78 insertions(+), 67 deletions(-) diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh index 59a5839f7..985a832d2 
100644 --- a/examples/create_embeddings.sh +++ b/examples/create_embeddings.sh @@ -20,7 +20,7 @@ python tools/create_doc_index.py \ --checkpoint-activations \ --seq-length 512 \ --retriever-seq-length 256 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --load ${CHECKPOINT_PATH} \ --evidence-data-path ${EVIDENCE_DATA_DIR} \ --embedding-path ${EMBEDDING_PATH} \ diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_ict_zeroshot_nq.sh index e1ce45a93..a29c52ffc 100644 --- a/examples/evaluate_ict_zeroshot_nq.sh +++ b/examples/evaluate_ict_zeroshot_nq.sh @@ -22,7 +22,7 @@ python tasks/main.py \ --micro-batch-size 128 \ --checkpoint-activations \ --seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --load ${CHECKPOINT_PATH} \ --evidence-data-path ${EVIDENCE_DATA_DIR} \ --embedding-path ${EMBEDDING_PATH} \ diff --git a/examples/evaluate_zeroshot_gpt.sh b/examples/evaluate_zeroshot_gpt.sh index 96fd28f3a..e3a13bf68 100755 --- a/examples/evaluate_zeroshot_gpt.sh +++ b/examples/evaluate_zeroshot_gpt.sh @@ -31,7 +31,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --batch-size 8 \ --checkpoint-activations \ --seq-length 1024 \ - --max-position-embeddings 1024 \ + --max-absolute-position-embeddings 1024 \ --log-interval 10 \ --fp16 \ --no-load-optim \ diff --git a/examples/finetune_mnli_distributed.sh b/examples/finetune_mnli_distributed.sh index 213eb1fa1..80a66fd34 100755 --- a/examples/finetune_mnli_distributed.sh +++ b/examples/finetune_mnli_distributed.sh @@ -34,7 +34,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --lr-decay-style linear \ --lr-warmup-fraction 0.065 \ --seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --save-interval 500000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ diff --git a/examples/finetune_race_distributed.sh b/examples/finetune_race_distributed.sh index 5ac642ee3..c8153016c 100755 --- a/examples/finetune_race_distributed.sh +++ b/examples/finetune_race_distributed.sh @@ -34,7 +34,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --lr-decay-style linear \ --lr-warmup-fraction 0.06 \ --seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --save-interval 100000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ diff --git a/examples/generate_text.sh b/examples/generate_text.sh index eefe8dfbe..e630943fb 100755 --- a/examples/generate_text.sh +++ b/examples/generate_text.sh @@ -10,7 +10,7 @@ python tools/generate_samples_gpt2.py \ --hidden-size 1024 \ --load $CHECKPOINT_PATH \ --num-attention-heads 16 \ - --max-position-embeddings 1024 \ + --max-absolute-position-embeddings 1024 \ --tokenizer-type GPT2BPETokenizer \ --fp16 \ --batch-size 2 \ diff --git a/examples/merge_mp_bert.sh b/examples/merge_mp_bert.sh index 138343328..ffe40ad24 100755 --- a/examples/merge_mp_bert.sh +++ b/examples/merge_mp_bert.sh @@ -14,5 +14,5 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ --hidden-size 1024 \ --num-attention-heads 16 \ --seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --load $CHECKPOINT_PATH diff --git a/examples/pretrain_bert.sh b/examples/pretrain_bert.sh index 9c744ee45..059094d96 100755 --- a/examples/pretrain_bert.sh +++ b/examples/pretrain_bert.sh @@ -12,7 +12,7 @@ python pretrain_bert.py \ --micro-batch-size 4 \ --global-batch-size 8 \ 
--seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 2000000 \ --lr-decay-iters 990000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_bert_distributed.sh b/examples/pretrain_bert_distributed.sh index a833c5a94..eb88f10a4 100755 --- a/examples/pretrain_bert_distributed.sh +++ b/examples/pretrain_bert_distributed.sh @@ -21,7 +21,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 4 \ --global-batch-size 32 \ --seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 1000000 \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ diff --git a/examples/pretrain_bert_distributed_with_mp.sh b/examples/pretrain_bert_distributed_with_mp.sh index 4c50dcc25..477644b9f 100755 --- a/examples/pretrain_bert_distributed_with_mp.sh +++ b/examples/pretrain_bert_distributed_with_mp.sh @@ -23,7 +23,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --num-attention-heads 16 \ --micro-batch-size 2 \ --global-batch-size 16 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 1000000 \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt.sh b/examples/pretrain_gpt.sh index cad6bcc13..e2e0989d5 100755 --- a/examples/pretrain_gpt.sh +++ b/examples/pretrain_gpt.sh @@ -16,7 +16,7 @@ deepspeed --num_gpus 1 pretrain_gpt.py \ --micro-batch-size 4 \ --global-batch-size 8 \ --seq-length 1024 \ - --max-position-embeddings 1024 \ + --max-absolute-position-embeddings 1024 \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt3_175B.sh b/examples/pretrain_gpt3_175B.sh index ad0d244d7..7492be1aa 100755 --- a/examples/pretrain_gpt3_175B.sh +++ b/examples/pretrain_gpt3_175B.sh @@ -22,7 +22,7 @@ options=" \ --hidden-size 12288 \ --num-attention-heads 96 \ --seq-length 2048 \ - --max-position-embeddings 2048 \ + --max-absolute-position-embeddings 2048 \ --micro-batch-size 1 \ --global-batch-size 1536 \ --rampup-batch-size 16 16 5859375 \ diff --git a/examples/pretrain_gpt_distributed.sh b/examples/pretrain_gpt_distributed.sh index 1b4518604..2e6b65dcb 100755 --- a/examples/pretrain_gpt_distributed.sh +++ b/examples/pretrain_gpt_distributed.sh @@ -23,7 +23,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 8 \ --global-batch-size 64 \ --seq-length 1024 \ - --max-position-embeddings 1024 \ + --max-absolute-position-embeddings 1024 \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt_distributed_with_mp.sh b/examples/pretrain_gpt_distributed_with_mp.sh index c67db4c45..3f2ca69c9 100755 --- a/examples/pretrain_gpt_distributed_with_mp.sh +++ b/examples/pretrain_gpt_distributed_with_mp.sh @@ -25,7 +25,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 4 \ --global-batch-size 16 \ --seq-length 1024 \ - --max-position-embeddings 1024 \ + --max-absolute-position-embeddings 1024 \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt_tiny.sh b/examples/pretrain_gpt_tiny.sh index c7d953f10..38576bb1e 100644 --- a/examples/pretrain_gpt_tiny.sh +++ b/examples/pretrain_gpt_tiny.sh @@ -16,7 +16,7 @@ deepspeed --num_gpus 1 pretrain_gpt.py \ --micro-batch-size 4 \ --global-batch-size 8 \ --seq-length 256 \ - --max-position-embeddings 256 \ + --max-absolute-position-embeddings 256 \ --train-iters 
10000 \ --lr-decay-iters 5000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_ict.sh b/examples/pretrain_ict.sh index 8cba0f08b..1b5fef349 100755 --- a/examples/pretrain_ict.sh +++ b/examples/pretrain_ict.sh @@ -18,7 +18,7 @@ python pretrain_ict.py \ --tensor-model-parallel-size 1 \ --micro-batch-size 32 \ --seq-length 256 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 100000 \ --vocab-file bert-vocab.txt \ --tokenizer-type BertWordPieceLowerCase \ diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh index 71fea8489..1d937af8f 100644 --- a/examples/pretrain_t5.sh +++ b/examples/pretrain_t5.sh @@ -16,7 +16,7 @@ python pretrain_t5.py \ --decoder-seq-length 128 \ --micro-batch-size 16 \ --global-batch-size 2048 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh index 778b4ad2a..4a5c359ce 100644 --- a/examples/pretrain_t5_distributed.sh +++ b/examples/pretrain_t5_distributed.sh @@ -25,7 +25,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --decoder-seq-length 128 \ --micro-batch-size 16 \ --global-batch-size 2048 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh index 9be70393d..8ea6e016a 100644 --- a/examples/pretrain_t5_distributed_with_mp.sh +++ b/examples/pretrain_t5_distributed_with_mp.sh @@ -26,7 +26,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 16 \ --global-batch-size 2048 \ --seq-length 512 \ - --max-position-embeddings 512 \ + --max-absolute-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/megatron/arguments.py b/megatron/arguments.py index 73dff2f72..d1b55529f 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -21,7 +21,7 @@ import torch import deepspeed -from megatron.model.enums import PositionalEmbedding +from megatron.model.enums import PositionEmbeddingType def parse_args(extra_args_provider=None, defaults={}, @@ -202,8 +202,7 @@ def parse_args(extra_args_provider=None, defaults={}, 'and lr-warmup-samples' # Check required arguments. 
- required_args = ['num_layers', 'hidden_size', 'num_attention_heads', - 'max_position_embeddings'] + required_args = ['num_layers', 'hidden_size', 'num_attention_heads', 'position_embedding_type'] for req_arg in required_args: _check_arg_is_not_none(args, req_arg) @@ -222,10 +221,15 @@ def parse_args(extra_args_provider=None, defaults={}, assert args.encoder_seq_length is not None args.seq_length = args.encoder_seq_length - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length + if args.position_embedding_type == PositionEmbeddingType.absolute: + assert args.max_absolute_position_embeddings is not None + if args.seq_length is not None: + assert args.max_absolute_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_absolute_position_embeddings >= args.decoder_seq_length + else: + assert args.max_absolute_position_embeddings is None + if args.lr is not None: assert args.min_lr <= args.lr if args.save is not None: @@ -282,7 +286,7 @@ def _add_network_size_args(parser): 'attention. This is set to ' ' args.hidden_size // args.num_attention_heads ' 'if not provided.') - group.add_argument('--max-position-embeddings', type=int, default=None, + group.add_argument('--max-absolute-position-embeddings', type=int, default=None, help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, @@ -304,8 +308,8 @@ def _add_network_size_args(parser): group.add_argument('--bert-no-binary-head', action='store_false', help='Disable BERT binary head.', dest='bert_binary_head') - group.add_argument('--positional-embeddings', type=PositionalEmbedding, choices=list(PositionalEmbedding), - help='Define positional embeddings strategy.' + group.add_argument('--position-embedding-type', type=PositionEmbeddingType, choices=list(PositionEmbeddingType), + help='Define position embedding type.' 
) return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 3cc6a8e2e..f960440ef 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -60,7 +60,7 @@ def _compare(arg_name, old_arg_name=None): _compare('num_layers') _compare('hidden_size') _compare('num_attention_heads') - _compare('max_position_embeddings') + _compare('max_absolute_position_embeddings') if args.vocab_file: _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') diff --git a/megatron/model/enums.py b/megatron/model/enums.py index 2b61aa98a..84d9a4ff0 100644 --- a/megatron/model/enums.py +++ b/megatron/model/enums.py @@ -27,5 +27,6 @@ class AttnMaskType(enum.Enum): padding = 1 causal = 2 -class PositionalEmbedding(enum.Enum): +class PositionEmbeddingType(enum.Enum): rotary = 1 + absolute = 2 diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 293526465..5a78f5007 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -182,7 +182,6 @@ def _to_float16(inputs): EmbeddingPipe, args.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, @@ -224,7 +223,6 @@ def _logits_helper(embedding, lm_output): EmbeddingPipe, args.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, init_method=init_method, num_tokentypes=num_tokentypes, diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index dd5f6972b..1aa873d1b 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -21,7 +21,7 @@ from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import LayerType, AttnMaskType +from megatron.model.enums import LayerType, AttnMaskType, PositionEmbeddingType from megatron.model.transformer import ParallelTransformer from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal, scaled_init_method_normal @@ -107,8 +107,6 @@ class Embedding(MegatronModule): Arguments: hidden_size: hidden size vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding embedding_dropout_prob: dropout probability for embeddings init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value @@ -118,7 +116,6 @@ class Embedding(MegatronModule): def __init__(self, hidden_size, vocab_size, - max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0): @@ -137,11 +134,15 @@ def __init__(self, self._word_embeddings_key = 'word_embeddings' # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) - self._position_embeddings_key = 'position_embeddings' - # Initialize the position embeddings. - self.init_method(self.position_embeddings.weight) + self.position_embedding_type = args.position_embedding_type + if self.position_embedding_type == PositionEmbeddingType.absolute: + max_absolute_position_embeddings = args.max_absolute_position_embeddings + assert max_absolute_position_embeddings is not None + self.position_embeddings = torch.nn.Embedding( + max_absolute_position_embeddings, self.hidden_size) + self._position_embeddings_key = 'position_embeddings' + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) # Token type embedding. 
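# (With this change the embedding layer allocates a learned position table
#  only for PositionEmbeddingType.absolute; for rotary, positions are applied
#  later inside each attention layer, so no table is needed here. Patch 04
#  below makes the invariant explicit by setting self.position_embeddings to
#  None for non-absolute types.)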
# Add this as an optional field that can be added through @@ -179,8 +180,14 @@ def add_tokentype_embeddings(self, num_tokentypes): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings + embeddings = words_embeddings + + if self.position_embedding_type == PositionEmbeddingType.absolute: + assert self.position_embeddings is not None + embeddings = embeddings + self.position_embeddings(position_ids) + else: + assert self.position_embeddings is not None + if tokentype_ids is not None: assert self.tokentype_embeddings is not None embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) @@ -199,9 +206,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', state_dict_ = {} state_dict_[self._word_embeddings_key] \ = self.word_embeddings.state_dict(destination, prefix, keep_vars) - state_dict_[self._position_embeddings_key] \ - = self.position_embeddings.state_dict( - destination, prefix, keep_vars) + if self.position_embedding_type == PositionEmbeddingType.absolute: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict( + destination, prefix, keep_vars) if self.num_tokentypes > 0: state_dict_[self._tokentype_embeddings_key] \ = self.tokentype_embeddings.state_dict( @@ -225,16 +233,17 @@ def load_state_dict(self, state_dict, strict=True): self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. - if self._position_embeddings_key in state_dict: - state_dict_ = state_dict[self._position_embeddings_key] - else: - # for backward compatibility. - state_dict_ = {} - for key in state_dict.keys(): - if 'position_embeddings' in key: - state_dict_[key.split('position_embeddings.')[1]] \ - = state_dict[key] - self.position_embeddings.load_state_dict(state_dict_, strict=strict) + if self.position_embedding_type == PositionEmbeddingType.absolute: + if self._position_embeddings_key in state_dict: + state_dict_ = state_dict[self._position_embeddings_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'position_embeddings' in key: + state_dict_[key.split('position_embeddings.')[1]] \ + = state_dict[key] + self.position_embeddings.load_state_dict(state_dict_, strict=strict) # Tokentype embedding. if self.num_tokentypes > 0: @@ -295,8 +304,6 @@ class TransformerLanguageModel(MegatronModule): Arguments: transformer_hparams: transformer hyperparameters vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding embedding_dropout_prob: dropout probability for embeddings num_tokentypes: size of the token-type embeddings. 
0 value will ignore this embedding @@ -329,7 +336,6 @@ def __init__(self, if self.pre_process: self.embedding = Embedding(self.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, self.init_method, self.num_tokentypes) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 73011eedd..13c3ae142 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -22,7 +22,7 @@ from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import AttnMaskType, LayerType, AttnType, PositionalEmbedding +from megatron.model.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType from megatron.model import LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -122,7 +122,7 @@ def __init__(self, init_method, args = get_args() self.fp16 = args.fp16 self.bf16 = args.bf16 - self.position_embeddings = args.position_embeddings + self.position_embedding_type = args.position_embedding_type self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 @@ -196,7 +196,7 @@ def __init__(self, init_method, get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker checkpoint = deepspeed.checkpointing.checkpoint - if self.position_embeddings == PositionalEmbedding.rotary: + if self.position_embedding_type == PositionEmbeddingType.rotary: self.rotary_emb = RotaryEmbedding(args.hidden_size, precision=args.params_dtype) else: raise ValueError("Temporary in order to make sure that argparser works perfectly.") @@ -284,7 +284,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None, device=torch.cuda.current_device()) # Rotary embeddings - if self.position_embeddings == PositionalEmbedding.rotary: + if self.position_embedding_type == PositionEmbeddingType.rotary: query_rot, key_rot = query_layer, key_layer apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index a1c1a4c71..7bb58310e 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -45,7 +45,9 @@ def init_checkpointed_activations_memory_buffer(): """Initializ the memory buffer for the checkpointed activations.""" args = get_args() - per_layer = args.micro_batch_size * args.max_position_embeddings * \ + # TODO: Remove + assert args.max_absolute_position_embeddings is not None + per_layer = args.micro_batch_size * args.max_absolute_position_embeddings * \ args.hidden_size // args.tensor_model_parallel_size assert args.num_layers % args.checkpoint_num_layers == 0, \ 'number of layers is not divisible by checkpoint-num-layers' diff --git a/run.sh b/run.sh index b8fe8fc36..8dfe986fa 100755 --- a/run.sh +++ b/run.sh @@ -76,7 +76,7 @@ options=" \ --num-attention-heads 32 \ --seq-length $SEQ \ --loss-scale 12 \ - --max-position-embeddings $SEQ \ + --max-absolute-position-embeddings $SEQ \ --micro-batch-size $MICRO_BATCH \ --global-batch-size $GLOBAL_BATCH \ --train-iters 1000 \ diff --git a/tools/merge_mp_partitions.py b/tools/merge_mp_partitions.py index 4dc2d99f8..0d712df18 100644 --- a/tools/merge_mp_partitions.py +++ b/tools/merge_mp_partitions.py @@ -225,7 +225,7 @@ def main(): print(' number of attention heads ....... {}'.format( args.num_attention_heads)) print(' maximum position embeddings ..... 
{}'.format( - args.max_position_embeddings)) + args.max_absolute_position_embeddings)) # Full model. print('> building the full model ...') From c9523eadedf000f4de7001802962c53723473401 Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Wed, 21 Jul 2021 22:18:59 +0200 Subject: [PATCH 04/13] Move enum outside model --- megatron/arguments.py | 4 ++-- megatron/data/Makefile | 2 +- megatron/{model => }/enums.py | 0 megatron/model/bert_model.py | 2 +- megatron/model/biencoder_model.py | 2 +- megatron/model/classification.py | 2 +- megatron/model/fused_softmax.py | 2 +- megatron/model/gpt_model.py | 2 +- megatron/model/language_model.py | 6 ++++-- megatron/model/multiple_choice.py | 2 +- megatron/model/realm_model.py | 2 +- megatron/model/t5_model.py | 2 +- megatron/model/transformer.py | 4 +--- 13 files changed, 16 insertions(+), 16 deletions(-) rename megatron/{model => }/enums.py (100%) diff --git a/megatron/arguments.py b/megatron/arguments.py index d1b55529f..1d75bfb5b 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -21,7 +21,7 @@ import torch import deepspeed -from megatron.model.enums import PositionEmbeddingType +from megatron.enums import PositionEmbeddingType def parse_args(extra_args_provider=None, defaults={}, @@ -308,7 +308,7 @@ def _add_network_size_args(parser): group.add_argument('--bert-no-binary-head', action='store_false', help='Disable BERT binary head.', dest='bert_binary_head') - group.add_argument('--position-embedding-type', type=PositionEmbeddingType, choices=list(PositionEmbeddingType), + group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], choices=list(PositionEmbeddingType), help='Define position embedding type.' ) diff --git a/megatron/data/Makefile b/megatron/data/Makefile index 8f9db7686..707390cdb 100644 --- a/megatron/data/Makefile +++ b/megatron/data/Makefile @@ -1,5 +1,5 @@ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color -CPPFLAGS += $(shell python3 -m pybind11 --includes) +CPPFLAGS += $(shell python -m pybind11 --includes) LIBNAME = helpers LIBEXT = $(shell python3-config --extension-suffix) diff --git a/megatron/model/enums.py b/megatron/enums.py similarity index 100% rename from megatron/model/enums.py rename to megatron/enums.py diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py index 3ff5039d5..4cb650b36 100644 --- a/megatron/model/bert_model.py +++ b/megatron/model/bert_model.py @@ -19,7 +19,7 @@ from megatron import get_args from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import get_language_model from megatron.model import LayerNorm diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 51ac0a060..0f0a6698f 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -8,7 +8,7 @@ from megatron.checkpointing import get_checkpoint_name from megatron import mpu, get_tokenizer from megatron.model.bert_model import bert_position_ids -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.language_model import get_language_model from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal diff --git a/megatron/model/classification.py b/megatron/model/classification.py index d975072f7..94dc5fe7d 100644 --- 
a/megatron/model/classification.py +++ b/megatron/model/classification.py @@ -19,7 +19,7 @@ from megatron import get_args, print_rank_last from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model from megatron.model.utils import get_linear_layer diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 097b29ef4..aa1384207 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -14,7 +14,7 @@ # limitations under the License. import torch -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 5a78f5007..8a1d189e0 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -20,9 +20,9 @@ from megatron import get_args from megatron import mpu +from megatron.enums import AttnMaskType from .module import MegatronModule, fp32_to_float16 -from .enums import AttnMaskType from .language_model import parallel_lm_logits from .language_model import get_language_model from .utils import init_method_normal diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 1aa873d1b..03741d027 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -21,7 +21,7 @@ from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import LayerType, AttnMaskType, PositionEmbeddingType +from megatron.enums import LayerType, AttnMaskType, PositionEmbeddingType from megatron.model.transformer import ParallelTransformer from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal, scaled_init_method_normal @@ -143,6 +143,8 @@ def __init__(self, self._position_embeddings_key = 'position_embeddings' # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) + else: + self.position_embeddings = None # Token type embedding. 
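Earlier in this patch, the parser switched from type=PositionEmbeddingType to type=lambda x: PositionEmbeddingType[x]. The distinction matters because argparse applies type() to the raw string before checking it against choices: PositionEmbeddingType('absolute') is a value lookup (the enum values are ints) and raises ValueError, while PositionEmbeddingType['absolute'] is a name lookup that returns the member. A minimal self-contained sketch of the pattern — illustrative only, not the Megatron parser (the __str__ override is an extra here, added so --help renders clean choice names):

import argparse
import enum


class PositionEmbeddingType(enum.Enum):
    rotary = 1
    absolute = 2

    def __str__(self):
        return self.name


parser = argparse.ArgumentParser()
# type= runs before the choices check, so the name lookup must return an
# enum member that is contained in list(PositionEmbeddingType).
parser.add_argument('--position-embedding-type',
                    type=lambda x: PositionEmbeddingType[x],
                    choices=list(PositionEmbeddingType),
                    default=PositionEmbeddingType.absolute)

args = parser.parse_args(['--position-embedding-type', 'rotary'])
assert args.position_embedding_type is PositionEmbeddingType.rotary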
# Add this as an optional field that can be added through @@ -186,7 +188,7 @@ def forward(self, input_ids, position_ids, tokentype_ids=None): assert self.position_embeddings is not None embeddings = embeddings + self.position_embeddings(position_ids) else: - assert self.position_embeddings is not None + assert self.position_embeddings is None if tokentype_ids is not None: assert self.tokentype_embeddings is not None diff --git a/megatron/model/multiple_choice.py b/megatron/model/multiple_choice.py index c43bd969c..a8445e0cb 100644 --- a/megatron/model/multiple_choice.py +++ b/megatron/model/multiple_choice.py @@ -19,7 +19,7 @@ from megatron import get_args, print_rank_last from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids from megatron.model.language_model import get_language_model from megatron.model.utils import get_linear_layer diff --git a/megatron/model/realm_model.py b/megatron/model/realm_model.py index 5730a85e3..c57f51592 100644 --- a/megatron/model/realm_model.py +++ b/megatron/model/realm_model.py @@ -6,7 +6,7 @@ from megatron.model import BertModel from .module import MegatronModule from megatron import mpu -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal from megatron.model.language_model import get_language_model diff --git a/megatron/model/t5_model.py b/megatron/model/t5_model.py index beb4f0ee5..7b59b9e18 100644 --- a/megatron/model/t5_model.py +++ b/megatron/model/t5_model.py @@ -21,7 +21,7 @@ get_args, mpu ) -from megatron.model.enums import AttnMaskType +from megatron.enums import AttnMaskType from megatron.model.language_model import parallel_lm_logits, get_language_model from megatron.model.transformer import LayerNorm from megatron.model.utils import ( diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 13c3ae142..749038286 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -22,7 +22,7 @@ from megatron import get_args from megatron import mpu from .module import MegatronModule -from megatron.model.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType +from megatron.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType from megatron.model import LayerNorm from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl @@ -198,8 +198,6 @@ def __init__(self, init_method, if self.position_embedding_type == PositionEmbeddingType.rotary: self.rotary_emb = RotaryEmbedding(args.hidden_size, precision=args.params_dtype) - else: - raise ValueError("Temporary in order to make sure that argparser works perfectly.") def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): From 215a38a41e251fb301da8ca7c1d73ec0b4f628b1 Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Fri, 23 Jul 2021 00:56:15 +0200 Subject: [PATCH 05/13] Handle max_seq_len_cached better --- megatron/model/positional_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 8ca31dbf2..6ee9ef8b9 100644 --- a/megatron/model/positional_embeddings.py +++ 
b/megatron/model/positional_embeddings.py @@ -16,7 +16,7 @@ def __init__(self, dim, base=10000, precision=torch.half): def forward(self, x, seq_dim=1, seq_len=None): if seq_len is None: seq_len = x.shape[seq_dim] - if seq_len > self.max_seq_len_cached: + if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = seq_len t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq) freqs = torch.einsum('i,j->ij', t, self.inv_freq) From 0bd2138a19cd32d3d5ae147450adc293654bdcbe Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Fri, 23 Jul 2021 01:02:37 +0200 Subject: [PATCH 06/13] Fix dtype issue in rotary embeddings --- megatron/model/positional_embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 6ee9ef8b9..b5cb0d035 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -18,7 +18,7 @@ def forward(self, x, seq_dim=1, seq_len=None): seq_len = x.shape[seq_dim] if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): self.max_seq_len_cached = seq_len - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq) + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) freqs = torch.einsum('i,j->ij', t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) From a69c75d318ca6851a8222d1a2f429e0955ff5153 Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Fri, 23 Jul 2021 01:08:50 +0200 Subject: [PATCH 07/13] Fix tensor size --- megatron/model/positional_embeddings.py | 7 ++++--- megatron/model/transformer.py | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index b5cb0d035..b8f806087 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -24,9 +24,9 @@ def forward(self, x, seq_dim=1, seq_len=None): emb = torch.cat((freqs, freqs), dim=-1).to(x.device) if self.precision == torch.bfloat16: emb = emb.float() - # [sx, 1 (b), 1 (np), hn] - self.cos_cached = emb.cos()[:, None, None, :] - self.sin_cached = emb.sin()[:, None, None, :] + # [sx, 1 (b * np), hn] + self.cos_cached = emb.cos()[:, None, :] + self.sin_cached = emb.sin()[:, None, :] if self.precision == torch.bfloat16: self.cos_cached = self.cos_cached.bfloat16() self.sin_cached = self.sin_cached.bfloat16() @@ -43,6 +43,7 @@ def rotate_half(x): @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] 
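# (Shape fix in this patch: at the point where rotary is applied, Megatron
#  holds q/k as 3D tensors of shape [sq, b * np, hn], so the cached cos/sin
#  are now indexed [:, None, :] to give [sq, 1, hn] and broadcast over the
#  fused batch * heads dimension. The previous [sq, 1, 1, hn] caching assumed
#  GPT-NeoX's 4D [sq, b, np, hn] layout. The print added below is debug
#  output and is removed again in patch 08.)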
+ print(cos.shape, sin.shape, q.shape, k.shape) return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 749038286..d9c4d0ce3 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -283,7 +283,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None, # Rotary embeddings if self.position_embedding_type == PositionEmbeddingType.rotary: - query_rot, key_rot = query_layer, key_layer apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb seq_len = key_layer.shape[0] @@ -292,7 +291,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None, offset = layer_past[0].shape[0] seq_len += offset cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) - query_layer, key_layer = apply_rotary_fn(query_rot, key_rot, cos, sin, offset=offset) + query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset) # Raw attention scores. [b * np, sq, sk] matmul_result = torch.baddbmm( From fbae8b97330b3c5d7b4320ceadac59f923909b7f Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Fri, 23 Jul 2021 01:15:29 +0200 Subject: [PATCH 08/13] Replace hidden_dim by hidden_size_per_attention_head --- megatron/model/positional_embeddings.py | 1 - megatron/model/transformer.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index b8f806087..3494f9e4e 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -43,7 +43,6 @@ def rotate_half(x): @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...] 
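# (Per-head sizing, fixed in this patch: q and k arrive in this helper
#  already split across attention heads, with last dimension
#  hidden_size_per_attention_head, so the RotaryEmbedding in transformer.py
#  must be constructed with that per-head size. Building it with the full
#  hidden_size, as patch 01 did, produced cos/sin tables of the wrong width.)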
- print(cos.shape, sin.shape, q.shape, k.shape) return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index d9c4d0ce3..c8055015c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -197,7 +197,7 @@ def __init__(self, init_method, checkpoint = deepspeed.checkpointing.checkpoint if self.position_embedding_type == PositionEmbeddingType.rotary: - self.rotary_emb = RotaryEmbedding(args.hidden_size, precision=args.params_dtype) + self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype) def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): From 6bb1a333e8ee6ea86338b4e923077de87a44b1db Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Fri, 23 Jul 2021 01:27:10 +0200 Subject: [PATCH 09/13] Change all examples to new format and improve help in argparser --- examples/create_embeddings.sh | 1 + examples/evaluate_ict_zeroshot_nq.sh | 1 + examples/evaluate_zeroshot_gpt.sh | 1 + examples/finetune_mnli_distributed.sh | 1 + examples/finetune_race_distributed.sh | 1 + examples/generate_text.sh | 1 + examples/merge_mp_bert.sh | 1 + examples/pretrain_bert.sh | 1 + examples/pretrain_bert_distributed.sh | 1 + examples/pretrain_bert_distributed_with_mp.sh | 1 + examples/pretrain_gpt.sh | 1 + examples/pretrain_gpt3_175B.sh | 1 + examples/pretrain_gpt_distributed.sh | 1 + examples/pretrain_gpt_distributed_with_mp.sh | 1 + examples/pretrain_gpt_tiny.sh | 1 + examples/pretrain_ict.sh | 1 + examples/pretrain_t5.sh | 1 + examples/pretrain_t5_distributed.sh | 1 + examples/pretrain_t5_distributed_with_mp.sh | 1 + megatron/arguments.py | 2 +- run.sh | 1 + 21 files changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh index 985a832d2..ddebac92b 100644 --- a/examples/create_embeddings.sh +++ b/examples/create_embeddings.sh @@ -21,6 +21,7 @@ python tools/create_doc_index.py \ --seq-length 512 \ --retriever-seq-length 256 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --load ${CHECKPOINT_PATH} \ --evidence-data-path ${EVIDENCE_DATA_DIR} \ --embedding-path ${EMBEDDING_PATH} \ diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_ict_zeroshot_nq.sh index a29c52ffc..c4d901053 100644 --- a/examples/evaluate_ict_zeroshot_nq.sh +++ b/examples/evaluate_ict_zeroshot_nq.sh @@ -23,6 +23,7 @@ python tasks/main.py \ --checkpoint-activations \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --load ${CHECKPOINT_PATH} \ --evidence-data-path ${EVIDENCE_DATA_DIR} \ --embedding-path ${EMBEDDING_PATH} \ diff --git a/examples/evaluate_zeroshot_gpt.sh b/examples/evaluate_zeroshot_gpt.sh index e3a13bf68..533f968ff 100755 --- a/examples/evaluate_zeroshot_gpt.sh +++ b/examples/evaluate_zeroshot_gpt.sh @@ -32,6 +32,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --checkpoint-activations \ --seq-length 1024 \ --max-absolute-position-embeddings 1024 \ + --position-embedding-type absolute \ --log-interval 10 \ --fp16 \ --no-load-optim \ diff --git a/examples/finetune_mnli_distributed.sh b/examples/finetune_mnli_distributed.sh index 80a66fd34..0e1670e4c 100755 --- a/examples/finetune_mnli_distributed.sh +++ b/examples/finetune_mnli_distributed.sh @@ -35,6 +35,7 @@ python -m torch.distributed.launch 
$DISTRIBUTED_ARGS ./tasks/main.py \ --lr-warmup-fraction 0.065 \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --save-interval 500000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ diff --git a/examples/finetune_race_distributed.sh b/examples/finetune_race_distributed.sh index c8153016c..f60f80a38 100755 --- a/examples/finetune_race_distributed.sh +++ b/examples/finetune_race_distributed.sh @@ -35,6 +35,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --lr-warmup-fraction 0.06 \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --save-interval 100000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ diff --git a/examples/generate_text.sh b/examples/generate_text.sh index e630943fb..a841c3871 100755 --- a/examples/generate_text.sh +++ b/examples/generate_text.sh @@ -11,6 +11,7 @@ python tools/generate_samples_gpt2.py \ --load $CHECKPOINT_PATH \ --num-attention-heads 16 \ --max-absolute-position-embeddings 1024 \ + --position-embedding-type absolute \ --tokenizer-type GPT2BPETokenizer \ --fp16 \ --batch-size 2 \ diff --git a/examples/merge_mp_bert.sh b/examples/merge_mp_bert.sh index ffe40ad24..9469dcf58 100755 --- a/examples/merge_mp_bert.sh +++ b/examples/merge_mp_bert.sh @@ -15,4 +15,5 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ --num-attention-heads 16 \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --load $CHECKPOINT_PATH diff --git a/examples/pretrain_bert.sh b/examples/pretrain_bert.sh index 059094d96..8c3f222ac 100755 --- a/examples/pretrain_bert.sh +++ b/examples/pretrain_bert.sh @@ -13,6 +13,7 @@ python pretrain_bert.py \ --global-batch-size 8 \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 2000000 \ --lr-decay-iters 990000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_bert_distributed.sh b/examples/pretrain_bert_distributed.sh index eb88f10a4..26fecece8 100755 --- a/examples/pretrain_bert_distributed.sh +++ b/examples/pretrain_bert_distributed.sh @@ -22,6 +22,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --global-batch-size 32 \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 1000000 \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ diff --git a/examples/pretrain_bert_distributed_with_mp.sh b/examples/pretrain_bert_distributed_with_mp.sh index 477644b9f..d270a9712 100755 --- a/examples/pretrain_bert_distributed_with_mp.sh +++ b/examples/pretrain_bert_distributed_with_mp.sh @@ -24,6 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 2 \ --global-batch-size 16 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 1000000 \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt.sh b/examples/pretrain_gpt.sh index e2e0989d5..62c4139bb 100755 --- a/examples/pretrain_gpt.sh +++ b/examples/pretrain_gpt.sh @@ -17,6 +17,7 @@ deepspeed --num_gpus 1 pretrain_gpt.py \ --global-batch-size 8 \ --seq-length 1024 \ --max-absolute-position-embeddings 1024 \ + --position-embedding-type absolute \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt3_175B.sh b/examples/pretrain_gpt3_175B.sh index 7492be1aa..ce7f97c0c 100755 --- a/examples/pretrain_gpt3_175B.sh +++ 
b/examples/pretrain_gpt3_175B.sh @@ -23,6 +23,7 @@ options=" \ --num-attention-heads 96 \ --seq-length 2048 \ --max-absolute-position-embeddings 2048 \ + --position-embedding-type absolute \ --micro-batch-size 1 \ --global-batch-size 1536 \ --rampup-batch-size 16 16 5859375 \ diff --git a/examples/pretrain_gpt_distributed.sh b/examples/pretrain_gpt_distributed.sh index 2e6b65dcb..4c1a3044b 100755 --- a/examples/pretrain_gpt_distributed.sh +++ b/examples/pretrain_gpt_distributed.sh @@ -24,6 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --global-batch-size 64 \ --seq-length 1024 \ --max-absolute-position-embeddings 1024 \ + --position-embedding-type absolute \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt_distributed_with_mp.sh b/examples/pretrain_gpt_distributed_with_mp.sh index 3f2ca69c9..d68f78487 100755 --- a/examples/pretrain_gpt_distributed_with_mp.sh +++ b/examples/pretrain_gpt_distributed_with_mp.sh @@ -26,6 +26,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --global-batch-size 16 \ --seq-length 1024 \ --max-absolute-position-embeddings 1024 \ + --position-embedding-type absolute \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt_tiny.sh b/examples/pretrain_gpt_tiny.sh index 38576bb1e..950717c67 100644 --- a/examples/pretrain_gpt_tiny.sh +++ b/examples/pretrain_gpt_tiny.sh @@ -17,6 +17,7 @@ deepspeed --num_gpus 1 pretrain_gpt.py \ --global-batch-size 8 \ --seq-length 256 \ --max-absolute-position-embeddings 256 \ + --position-embedding-type absolute \ --train-iters 10000 \ --lr-decay-iters 5000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_ict.sh b/examples/pretrain_ict.sh index 1b5fef349..4112926f7 100755 --- a/examples/pretrain_ict.sh +++ b/examples/pretrain_ict.sh @@ -19,6 +19,7 @@ python pretrain_ict.py \ --micro-batch-size 32 \ --seq-length 256 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 100000 \ --vocab-file bert-vocab.txt \ --tokenizer-type BertWordPieceLowerCase \ diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh index 1d937af8f..71aabed8a 100644 --- a/examples/pretrain_t5.sh +++ b/examples/pretrain_t5.sh @@ -17,6 +17,7 @@ python pretrain_t5.py \ --micro-batch-size 16 \ --global-batch-size 2048 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh index 4a5c359ce..4a0491ef8 100644 --- a/examples/pretrain_t5_distributed.sh +++ b/examples/pretrain_t5_distributed.sh @@ -26,6 +26,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 16 \ --global-batch-size 2048 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh index 8ea6e016a..3fbd838ee 100644 --- a/examples/pretrain_t5_distributed_with_mp.sh +++ b/examples/pretrain_t5_distributed_with_mp.sh @@ -27,6 +27,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --global-batch-size 2048 \ --seq-length 512 \ --max-absolute-position-embeddings 512 \ + --position-embedding-type absolute \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git 
a/megatron/arguments.py b/megatron/arguments.py index 1d75bfb5b..364906954 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -309,7 +309,7 @@ def _add_network_size_args(parser): help='Disable BERT binary head.', dest='bert_binary_head') group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], choices=list(PositionEmbeddingType), - help='Define position embedding type.' + help='Define position embedding type ("absolute" | "rotary").' ) return parser diff --git a/run.sh b/run.sh index 8dfe986fa..9b7094f9b 100755 --- a/run.sh +++ b/run.sh @@ -77,6 +77,7 @@ options=" \ --seq-length $SEQ \ --loss-scale 12 \ --max-absolute-position-embeddings $SEQ \ + --position-embedding-type absolute \ --micro-batch-size $MICRO_BATCH \ --global-batch-size $GLOBAL_BATCH \ --train-iters 1000 \ From 1481556412e6ec1ae9efe8e62460931023e61f75 Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Fri, 23 Jul 2021 01:44:23 +0200 Subject: [PATCH 10/13] Revert back changes, add comparison with position embedding type when checkpointing and replace args.max_position_embeddings with an upper bound on the sequence sizes --- megatron/checkpointing.py | 1 + megatron/data/Makefile | 2 +- megatron/mpu/random.py | 10 ++++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index f960440ef..ab4f4a51c 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -61,6 +61,7 @@ def _compare(arg_name, old_arg_name=None): _compare('hidden_size') _compare('num_attention_heads') _compare('max_absolute_position_embeddings') + _compare('position_embedding_type') if args.vocab_file: _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') diff --git a/megatron/data/Makefile b/megatron/data/Makefile index 707390cdb..8f9db7686 100644 --- a/megatron/data/Makefile +++ b/megatron/data/Makefile @@ -1,5 +1,5 @@ CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color -CPPFLAGS += $(shell python -m pybind11 --includes) +CPPFLAGS += $(shell python3 -m pybind11 --includes) LIBNAME = helpers LIBEXT = $(shell python3-config --extension-suffix) diff --git a/megatron/mpu/random.py b/megatron/mpu/random.py index 7bb58310e..180ed2c51 100644 --- a/megatron/mpu/random.py +++ b/megatron/mpu/random.py @@ -42,12 +42,14 @@ def init_checkpointed_activations_memory_buffer(): - """Initializ the memory buffer for the checkpointed activations.""" + """Initialize the memory buffer for the checkpointed activations.""" args = get_args() - # TODO: Remove - assert args.max_absolute_position_embeddings is not None - per_layer = args.micro_batch_size * args.max_absolute_position_embeddings * \ + upper_bound_sequence_length = max( + args.seq_length if args.seq_length is not None else 0, + args.decoder_seq_length if args.decoder_seq_length is not None else 0 + ) + per_layer = args.micro_batch_size * upper_bound_sequence_length * \ args.hidden_size // args.tensor_model_parallel_size assert args.num_layers % args.checkpoint_num_layers == 0, \ 'number of layers is not divisible by checkpoint-num-layers' From 0528e39c52ef53d15ce26a8f29b5e6a836c11e09 Mon Sep 17 00:00:00 2001 From: Thomas <ö95242+thomasw21@users.noreply.github.com> Date: Sat, 24 Jul 2021 00:07:10 +0200 Subject: [PATCH 11/13] Revert back changes: - Rename max-absolute-embeddings back to max-absolute-embeddings - Make absolute position embeddings the default --- examples/create_embeddings.sh | 3 +- 
examples/evaluate_ict_zeroshot_nq.sh | 3 +- examples/evaluate_zeroshot_gpt.sh | 5 +- examples/finetune_mnli_distributed.sh | 5 +- examples/finetune_race_distributed.sh | 5 +- examples/generate_text.sh | 3 +- examples/merge_mp_bert.sh | 5 +- examples/pretrain_bert.sh | 3 +- examples/pretrain_bert_distributed.sh | 3 +- examples/pretrain_bert_distributed_with_mp.sh | 3 +- examples/pretrain_gpt.sh | 3 +- examples/pretrain_gpt3_175B.sh | 5 +- examples/pretrain_gpt_distributed.sh | 3 +- examples/pretrain_gpt_distributed_with_mp.sh | 3 +- examples/pretrain_gpt_tiny.sh | 3 +- examples/pretrain_ict.sh | 5 +- examples/pretrain_t5.sh | 3 +- examples/pretrain_t5_distributed.sh | 3 +- examples/pretrain_t5_distributed_with_mp.sh | 3 +- megatron/arguments.py | 15 +- megatron/checkpointing.py | 2 +- megatron/model/language_model.py | 6 +- run.sh | 2 +- run.sh~ | 150 ++++++++++++++++++ tools/merge_mp_partitions.py | 2 +- 25 files changed, 189 insertions(+), 57 deletions(-) create mode 100644 run.sh~ diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh index ddebac92b..59a5839f7 100644 --- a/examples/create_embeddings.sh +++ b/examples/create_embeddings.sh @@ -20,8 +20,7 @@ python tools/create_doc_index.py \ --checkpoint-activations \ --seq-length 512 \ --retriever-seq-length 256 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --load ${CHECKPOINT_PATH} \ --evidence-data-path ${EVIDENCE_DATA_DIR} \ --embedding-path ${EMBEDDING_PATH} \ diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_ict_zeroshot_nq.sh index c4d901053..e1ce45a93 100644 --- a/examples/evaluate_ict_zeroshot_nq.sh +++ b/examples/evaluate_ict_zeroshot_nq.sh @@ -22,8 +22,7 @@ python tasks/main.py \ --micro-batch-size 128 \ --checkpoint-activations \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --load ${CHECKPOINT_PATH} \ --evidence-data-path ${EVIDENCE_DATA_DIR} \ --embedding-path ${EMBEDDING_PATH} \ diff --git a/examples/evaluate_zeroshot_gpt.sh b/examples/evaluate_zeroshot_gpt.sh index 533f968ff..9b98afd52 100755 --- a/examples/evaluate_zeroshot_gpt.sh +++ b/examples/evaluate_zeroshot_gpt.sh @@ -31,9 +31,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --batch-size 8 \ --checkpoint-activations \ --seq-length 1024 \ - --max-absolute-position-embeddings 1024 \ - --position-embedding-type absolute \ - --log-interval 10 \ + --max-position-embeddings 1024 \ + --log-interval 10 \ --fp16 \ --no-load-optim \ --no-load-rng diff --git a/examples/finetune_mnli_distributed.sh b/examples/finetune_mnli_distributed.sh index 0e1670e4c..2d451bbbb 100755 --- a/examples/finetune_mnli_distributed.sh +++ b/examples/finetune_mnli_distributed.sh @@ -34,9 +34,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --lr-decay-style linear \ --lr-warmup-fraction 0.065 \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ - --save-interval 500000 \ + --max-position-embeddings 512 \ + --save-interval 500000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ --eval-interval 100 \ diff --git a/examples/finetune_race_distributed.sh b/examples/finetune_race_distributed.sh index f60f80a38..283d3ca0d 100755 --- a/examples/finetune_race_distributed.sh +++ b/examples/finetune_race_distributed.sh @@ -34,9 +34,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ 
--lr-decay-style linear \ --lr-warmup-fraction 0.06 \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ - --save-interval 100000 \ + --max-position-embeddings 512 \ + --save-interval 100000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ --eval-interval 100 \ diff --git a/examples/generate_text.sh b/examples/generate_text.sh index a841c3871..eefe8dfbe 100755 --- a/examples/generate_text.sh +++ b/examples/generate_text.sh @@ -10,8 +10,7 @@ python tools/generate_samples_gpt2.py \ --hidden-size 1024 \ --load $CHECKPOINT_PATH \ --num-attention-heads 16 \ - --max-absolute-position-embeddings 1024 \ - --position-embedding-type absolute \ + --max-position-embeddings 1024 \ --tokenizer-type GPT2BPETokenizer \ --fp16 \ --batch-size 2 \ diff --git a/examples/merge_mp_bert.sh b/examples/merge_mp_bert.sh index 9469dcf58..838148c82 100755 --- a/examples/merge_mp_bert.sh +++ b/examples/merge_mp_bert.sh @@ -14,6 +14,5 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ --hidden-size 1024 \ --num-attention-heads 16 \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ - --load $CHECKPOINT_PATH + --max-position-embeddings 512 \ + --load $CHECKPOINT_PATH diff --git a/examples/pretrain_bert.sh b/examples/pretrain_bert.sh index 8c3f222ac..9c744ee45 100755 --- a/examples/pretrain_bert.sh +++ b/examples/pretrain_bert.sh @@ -12,8 +12,7 @@ python pretrain_bert.py \ --micro-batch-size 4 \ --global-batch-size 8 \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --train-iters 2000000 \ --lr-decay-iters 990000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_bert_distributed.sh b/examples/pretrain_bert_distributed.sh index 26fecece8..a833c5a94 100755 --- a/examples/pretrain_bert_distributed.sh +++ b/examples/pretrain_bert_distributed.sh @@ -21,8 +21,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 4 \ --global-batch-size 32 \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --train-iters 1000000 \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ diff --git a/examples/pretrain_bert_distributed_with_mp.sh b/examples/pretrain_bert_distributed_with_mp.sh index d270a9712..4c50dcc25 100755 --- a/examples/pretrain_bert_distributed_with_mp.sh +++ b/examples/pretrain_bert_distributed_with_mp.sh @@ -23,8 +23,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --num-attention-heads 16 \ --micro-batch-size 2 \ --global-batch-size 16 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --train-iters 1000000 \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt.sh b/examples/pretrain_gpt.sh index 62c4139bb..cad6bcc13 100755 --- a/examples/pretrain_gpt.sh +++ b/examples/pretrain_gpt.sh @@ -16,8 +16,7 @@ deepspeed --num_gpus 1 pretrain_gpt.py \ --micro-batch-size 4 \ --global-batch-size 8 \ --seq-length 1024 \ - --max-absolute-position-embeddings 1024 \ - --position-embedding-type absolute \ + --max-position-embeddings 1024 \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt3_175B.sh b/examples/pretrain_gpt3_175B.sh index ce7f97c0c..471490e2a 100755 --- a/examples/pretrain_gpt3_175B.sh +++ b/examples/pretrain_gpt3_175B.sh @@ -22,9 
+22,8 @@ options=" \ --hidden-size 12288 \ --num-attention-heads 96 \ --seq-length 2048 \ - --max-absolute-position-embeddings 2048 \ - --position-embedding-type absolute \ - --micro-batch-size 1 \ + --max-position-embeddings 2048 \ + --micro-batch-size 1 \ --global-batch-size 1536 \ --rampup-batch-size 16 16 5859375 \ --train-samples 146484375 \ diff --git a/examples/pretrain_gpt_distributed.sh b/examples/pretrain_gpt_distributed.sh index 4c1a3044b..1b4518604 100755 --- a/examples/pretrain_gpt_distributed.sh +++ b/examples/pretrain_gpt_distributed.sh @@ -23,8 +23,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 8 \ --global-batch-size 64 \ --seq-length 1024 \ - --max-absolute-position-embeddings 1024 \ - --position-embedding-type absolute \ + --max-position-embeddings 1024 \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt_distributed_with_mp.sh b/examples/pretrain_gpt_distributed_with_mp.sh index d68f78487..c67db4c45 100755 --- a/examples/pretrain_gpt_distributed_with_mp.sh +++ b/examples/pretrain_gpt_distributed_with_mp.sh @@ -25,8 +25,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 4 \ --global-batch-size 16 \ --seq-length 1024 \ - --max-absolute-position-embeddings 1024 \ - --position-embedding-type absolute \ + --max-position-embeddings 1024 \ --train-iters 500000 \ --lr-decay-iters 320000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_gpt_tiny.sh b/examples/pretrain_gpt_tiny.sh index 950717c67..c7d953f10 100644 --- a/examples/pretrain_gpt_tiny.sh +++ b/examples/pretrain_gpt_tiny.sh @@ -16,8 +16,7 @@ deepspeed --num_gpus 1 pretrain_gpt.py \ --micro-batch-size 4 \ --global-batch-size 8 \ --seq-length 256 \ - --max-absolute-position-embeddings 256 \ - --position-embedding-type absolute \ + --max-position-embeddings 256 \ --train-iters 10000 \ --lr-decay-iters 5000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_ict.sh b/examples/pretrain_ict.sh index 4112926f7..3cd302a51 100755 --- a/examples/pretrain_ict.sh +++ b/examples/pretrain_ict.sh @@ -18,9 +18,8 @@ python pretrain_ict.py \ --tensor-model-parallel-size 1 \ --micro-batch-size 32 \ --seq-length 256 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ - --train-iters 100000 \ + --max-position-embeddings 512 \ + --train-iters 100000 \ --vocab-file bert-vocab.txt \ --tokenizer-type BertWordPieceLowerCase \ --DDP-impl torch \ diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh index 71aabed8a..71fea8489 100644 --- a/examples/pretrain_t5.sh +++ b/examples/pretrain_t5.sh @@ -16,8 +16,7 @@ python pretrain_t5.py \ --decoder-seq-length 128 \ --micro-batch-size 16 \ --global-batch-size 2048 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh index 4a0491ef8..778b4ad2a 100644 --- a/examples/pretrain_t5_distributed.sh +++ b/examples/pretrain_t5_distributed.sh @@ -25,8 +25,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --decoder-seq-length 128 \ --micro-batch-size 16 \ --global-batch-size 2048 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git 
a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh index 3fbd838ee..9be70393d 100644 --- a/examples/pretrain_t5_distributed_with_mp.sh +++ b/examples/pretrain_t5_distributed_with_mp.sh @@ -26,8 +26,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --micro-batch-size 16 \ --global-batch-size 2048 \ --seq-length 512 \ - --max-absolute-position-embeddings 512 \ - --position-embedding-type absolute \ + --max-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ --save $CHECKPOINT_PATH \ diff --git a/megatron/arguments.py b/megatron/arguments.py index 364906954..5a2e9274c 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -202,7 +202,7 @@ def parse_args(extra_args_provider=None, defaults={}, 'and lr-warmup-samples' # Check required arguments. - required_args = ['num_layers', 'hidden_size', 'num_attention_heads', 'position_embedding_type'] + required_args = ['num_layers', 'hidden_size', 'num_attention_heads'] for req_arg in required_args: _check_arg_is_not_none(args, req_arg) @@ -222,13 +222,13 @@ def parse_args(extra_args_provider=None, defaults={}, args.seq_length = args.encoder_seq_length if args.position_embedding_type == PositionEmbeddingType.absolute: - assert args.max_absolute_position_embeddings is not None + assert args.max_position_embeddings is not None if args.seq_length is not None: - assert args.max_absolute_position_embeddings >= args.seq_length + assert args.max_position_embeddings >= args.seq_length if args.decoder_seq_length is not None: - assert args.max_absolute_position_embeddings >= args.decoder_seq_length + assert args.max_position_embeddings >= args.decoder_seq_length else: - assert args.max_absolute_position_embeddings is None + assert args.max_position_embeddings is None if args.lr is not None: assert args.min_lr <= args.lr @@ -286,7 +286,7 @@ def _add_network_size_args(parser): 'attention. This is set to ' ' args.hidden_size // args.num_attention_heads ' 'if not provided.') - group.add_argument('--max-absolute-position-embeddings', type=int, default=None, + group.add_argument('--max-position-embeddings', type=int, default=None, help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, @@ -309,7 +309,8 @@ def _add_network_size_args(parser): help='Disable BERT binary head.', dest='bert_binary_head') group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], choices=list(PositionEmbeddingType), - help='Define position embedding type ("absolute" | "rotary").' + default=PositionEmbeddingType.absolute, + help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' ) return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index ab4f4a51c..829fb1101 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -60,7 +60,7 @@ def _compare(arg_name, old_arg_name=None): _compare('num_layers') _compare('hidden_size') _compare('num_attention_heads') - _compare('max_absolute_position_embeddings') + _compare('max_position_embeddings') _compare('position_embedding_type') if args.vocab_file: _compare('make_vocab_size_divisible_by') diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 03741d027..1022164ec 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -136,10 +136,10 @@ def __init__(self, # Position embedding (serial). 
self.position_embedding_type = args.position_embedding_type if self.position_embedding_type == PositionEmbeddingType.absolute: - max_absolute_position_embeddings = args.max_absolute_position_embeddings - assert max_absolute_position_embeddings is not None + max_position_embeddings = args.max_position_embeddings + assert max_position_embeddings is not None self.position_embeddings = torch.nn.Embedding( - max_absolute_position_embeddings, self.hidden_size) + max_position_embeddings, self.hidden_size) self._position_embeddings_key = 'position_embeddings' # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) diff --git a/run.sh b/run.sh index 9b7094f9b..b35201703 100755 --- a/run.sh +++ b/run.sh @@ -76,7 +76,7 @@ options=" \ --num-attention-heads 32 \ --seq-length $SEQ \ --loss-scale 12 \ - --max-absolute-position-embeddings $SEQ \ + --max-position-embeddings $SEQ \ --position-embedding-type absolute \ --micro-batch-size $MICRO_BATCH \ --global-batch-size $GLOBAL_BATCH \ diff --git a/run.sh~ b/run.sh~ new file mode 100644 index 000000000..b35201703 --- /dev/null +++ b/run.sh~ @@ -0,0 +1,150 @@ +#!/bin/bash + + +DIR=`pwd` +DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` +mkdir -p $DIR/logs + + +#DATASET_1="" +#DATASET_2="" +#DATASET_3="" +#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}" + +BASE_DATA_PATH=/data/Megatron-LM/data +DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron +VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json +MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt + + +script_path=$(realpath $0) +script_dir=$(dirname $script_path) +CONFIG_JSON="$script_dir/ds_config.json" + +USE_DEEPSPEED=1 +ZERO_STAGE=0 + + +# Debug +#TP=4 +#PP=4 +#LAYERS=8 +#HIDDEN=512 +#SEQ=1024 +#GLOBAL_BATCH=128 +#WORKER_STR="-i worker-0" + + +# 52B +TP=4 +PP=16 +HIDDEN=8192 +LAYERS=64 +SEQ=1024 +GLOBAL_BATCH=1024 +WORKER_STR="" + +MICRO_BATCH=4 + +while [[ $# -gt 0 ]] +do +key="$1" +case $key in + --no-deepspeed) + USE_DEEPSPEED=0; + shift + ;; + -z|--zero-stage) + ZERO_STAGE=$2; + shift + ;; + *) + echo "Unknown argument(s)" + usage + exit 1 + shift + ;; +esac +done + + +options=" \ + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $LAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads 32 \ + --seq-length $SEQ \ + --loss-scale 12 \ + --max-position-embeddings $SEQ \ + --position-embedding-type absolute \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --train-iters 1000 \ + --lr 6.0e-5 \ + --min-lr 6.0e-6 \ + --lr-decay-style cosine \ + --log-interval 1 \ + --eval-iters 40 \ + --eval-interval 1000 \ + --data-path ${DATASET} \ + --vocab-file ${VOCAB_PATH} \ + --merge-file ${MERGE_PATH} \ + --save-interval 1000 \ + --split 98,2,0 \ + --clip-grad 1.0 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.006 \ + --fp16 \ + --checkpoint-activations + " + + +if [[ ${USE_DEEPSPEED} -eq 1 ]]; then + echo "Using DeepSpeed" + options="${options} \ + --deepspeed \ + --deepspeed_config=${CONFIG_JSON} \ + --zero-stage=${ZERO_STAGE} \ + --deepspeed-activation-checkpointing \ + " +fi + + +cat < $CONFIG_JSON +{ + "train_batch_size" : $GLOBAL_BATCH, + "train_micro_batch_size_per_gpu": $MICRO_BATCH, + "steps_per_print": 1, + + "zero_optimization": { + "stage": $ZERO_STAGE + }, + + "gradient_clipping": 1.0, + "prescale_gradients": true, + + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 
12 + }, + + "wall_clock_breakdown" : true +} +EOT + +#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}" +#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}" +run_cmd="deepspeed $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}" + + +echo ${run_cmd} +eval ${run_cmd} + +set +x diff --git a/tools/merge_mp_partitions.py b/tools/merge_mp_partitions.py index 0d712df18..4dc2d99f8 100644 --- a/tools/merge_mp_partitions.py +++ b/tools/merge_mp_partitions.py @@ -225,7 +225,7 @@ def main(): print(' number of attention heads ....... {}'.format( args.num_attention_heads)) print(' maximum position embeddings ..... {}'.format( - args.max_absolute_position_embeddings)) + args.max_position_embeddings)) # Full model. print('> building the full model ...') From 605f58562b023d4077fd0e70b39a422a54ce6940 Mon Sep 17 00:00:00 2001 From: Thomas <95242+thomasw21@users.noreply.github.com> Date: Sat, 24 Jul 2021 00:15:54 +0200 Subject: [PATCH 12/13] Reformat --- examples/evaluate_zeroshot_gpt.sh | 2 +- examples/finetune_mnli_distributed.sh | 2 +- examples/finetune_race_distributed.sh | 2 +- examples/merge_mp_bert.sh | 2 +- examples/pretrain_gpt3_175B.sh | 2 +- examples/pretrain_ict.sh | 2 +- megatron/arguments.py | 3 ++- 7 files changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/evaluate_zeroshot_gpt.sh b/examples/evaluate_zeroshot_gpt.sh index 9b98afd52..96fd28f3a 100755 --- a/examples/evaluate_zeroshot_gpt.sh +++ b/examples/evaluate_zeroshot_gpt.sh @@ -32,7 +32,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --checkpoint-activations \ --seq-length 1024 \ --max-position-embeddings 1024 \ - --log-interval 10 \ + --log-interval 10 \ --fp16 \ --no-load-optim \ --no-load-rng diff --git a/examples/finetune_mnli_distributed.sh b/examples/finetune_mnli_distributed.sh index 2d451bbbb..213eb1fa1 100755 --- a/examples/finetune_mnli_distributed.sh +++ b/examples/finetune_mnli_distributed.sh @@ -35,7 +35,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --lr-warmup-fraction 0.065 \ --seq-length 512 \ --max-position-embeddings 512 \ - --save-interval 500000 \ + --save-interval 500000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ --eval-interval 100 \ diff --git a/examples/finetune_race_distributed.sh b/examples/finetune_race_distributed.sh index 283d3ca0d..5ac642ee3 100755 --- a/examples/finetune_race_distributed.sh +++ b/examples/finetune_race_distributed.sh @@ -35,7 +35,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \ --lr-warmup-fraction 0.06 \ --seq-length 512 \ --max-position-embeddings 512 \ - --save-interval 100000 \ + --save-interval 100000 \ --save $CHECKPOINT_PATH \ --log-interval 10 \ --eval-interval 100 \ diff --git a/examples/merge_mp_bert.sh b/examples/merge_mp_bert.sh index 838148c82..138343328 100755 --- a/examples/merge_mp_bert.sh +++ b/examples/merge_mp_bert.sh @@ -15,4 +15,4 @@ WORLD_SIZE=$TENSOR_MODEL_PARALLEL_SIZE python tools/merge_mp_partitions.py \ --num-attention-heads 16 \ --seq-length 512 \ --max-position-embeddings 512 \ - --load $CHECKPOINT_PATH + --load $CHECKPOINT_PATH diff --git a/examples/pretrain_gpt3_175B.sh b/examples/pretrain_gpt3_175B.sh index 471490e2a..88761b555 100755 --- a/examples/pretrain_gpt3_175B.sh +++ b/examples/pretrain_gpt3_175B.sh @@ -23,7 +23,7 @@ options=" \ --num-attention-heads 96 \ --seq-length 2048 \ --max-position-embeddings 2048 \ - --micro-batch-size 1 \ + --micro-batch-size 1 \ --global-batch-size 1536 \ --rampup-batch-size 16 16
5859375 \ --train-samples 146484375 \ diff --git a/examples/pretrain_ict.sh b/examples/pretrain_ict.sh index 3cd302a51..8cba0f08b 100755 --- a/examples/pretrain_ict.sh +++ b/examples/pretrain_ict.sh @@ -19,7 +19,7 @@ python pretrain_ict.py \ --micro-batch-size 32 \ --seq-length 256 \ --max-position-embeddings 512 \ - --train-iters 100000 \ + --train-iters 100000 \ --vocab-file bert-vocab.txt \ --tokenizer-type BertWordPieceLowerCase \ --DDP-impl torch \ diff --git a/megatron/arguments.py b/megatron/arguments.py index 5a2e9274c..761641b05 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -308,7 +308,8 @@ def _add_network_size_args(parser): group.add_argument('--bert-no-binary-head', action='store_false', help='Disable BERT binary head.', dest='bert_binary_head') - group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], choices=list(PositionEmbeddingType), + group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], + choices=list(PositionEmbeddingType), default=PositionEmbeddingType.absolute, help='Define position embedding type ("absolute" | "rotary"). "absolute" by default.' ) From 99be67e162d188828a5292f8918e5de4ba747fb7 Mon Sep 17 00:00:00 2001 From: Thomas <95242+thomasw21@users.noreply.github.com> Date: Sat, 24 Jul 2021 00:30:26 +0200 Subject: [PATCH 13/13] Remove run.sh~ and revert run.sh --- run.sh | 1 - run.sh~ | 150 -------------------------------------------------------- 2 files changed, 151 deletions(-) delete mode 100644 run.sh~ diff --git a/run.sh b/run.sh index b35201703..b8fe8fc36 100755 --- a/run.sh +++ b/run.sh @@ -77,7 +77,6 @@ options=" \ --seq-length $SEQ \ --loss-scale 12 \ --max-position-embeddings $SEQ \ - --position-embedding-type absolute \ --micro-batch-size $MICRO_BATCH \ --global-batch-size $GLOBAL_BATCH \ --train-iters 1000 \ diff --git a/run.sh~ b/run.sh~ deleted file mode 100644 index b35201703..000000000 --- a/run.sh~ +++ /dev/null @@ -1,150 +0,0 @@ -#!/bin/bash - - -DIR=`pwd` -DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'` -mkdir -p $DIR/logs - - -#DATASET_1="" -#DATASET_2="" -#DATASET_3="" -#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}" - -BASE_DATA_PATH=/data/Megatron-LM/data -DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron -VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json -MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt - - -script_path=$(realpath $0) -script_dir=$(dirname $script_path) -CONFIG_JSON="$script_dir/ds_config.json" - -USE_DEEPSPEED=1 -ZERO_STAGE=0 - - -# Debug -#TP=4 -#PP=4 -#LAYERS=8 -#HIDDEN=512 -#SEQ=1024 -#GLOBAL_BATCH=128 -#WORKER_STR="-i worker-0" - - -# 52B -TP=4 -PP=16 -HIDDEN=8192 -LAYERS=64 -SEQ=1024 -GLOBAL_BATCH=1024 -WORKER_STR="" - -MICRO_BATCH=4 - -while [[ $# -gt 0 ]] -do -key="$1" -case $key in - --no-deepspeed) - USE_DEEPSPEED=0; - shift - ;; - -z|--zero-stage) - ZERO_STAGE=$2; - shift - ;; - *) - echo "Unknown argument(s)" - usage - exit 1 - shift - ;; -esac -done - - -options=" \ - --tensor-model-parallel-size $TP \ - --pipeline-model-parallel-size $PP \ - --num-layers $LAYERS \ - --hidden-size $HIDDEN \ - --num-attention-heads 32 \ - --seq-length $SEQ \ - --loss-scale 12 \ - --max-position-embeddings $SEQ \ - --position-embedding-type absolute \ - --micro-batch-size $MICRO_BATCH \ - --global-batch-size $GLOBAL_BATCH \ - --train-iters 1000 \ - --lr 6.0e-5 \ - --min-lr 6.0e-6 \ - --lr-decay-style cosine \ - --log-interval 1 \ - --eval-iters 40 \ - --eval-interval 1000 \ - --data-path ${DATASET} \ - --vocab-file ${VOCAB_PATH} \ -
--merge-file ${MERGE_PATH} \ - --save-interval 1000 \ - --split 98,2,0 \ - --clip-grad 1.0 \ - --weight-decay 0.1 \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --init-method-std 0.006 \ - --fp16 \ - --checkpoint-activations - " - - -if [[ ${USE_DEEPSPEED} -eq 1 ]]; then - echo "Using DeepSpeed" - options="${options} \ - --deepspeed \ - --deepspeed_config=${CONFIG_JSON} \ - --zero-stage=${ZERO_STAGE} \ - --deepspeed-activation-checkpointing \ - " -fi - - -cat < $CONFIG_JSON -{ - "train_batch_size" : $GLOBAL_BATCH, - "train_micro_batch_size_per_gpu": $MICRO_BATCH, - "steps_per_print": 1, - - "zero_optimization": { - "stage": $ZERO_STAGE - }, - - "gradient_clipping": 1.0, - "prescale_gradients": true, - - "fp16": { - "enabled": true, - "loss_scale": 0, - "loss_scale_window": 500, - "hysteresis": 2, - "min_loss_scale": 1, - "initial_scale_power": 12 - }, - - "wall_clock_breakdown" : true -} -EOT - -#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}" -#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}" -run_cmd="deepspeed $WORKER_STR ${DIR}/pretrain_gpt.py $@ ${options}" - - -echo ${run_cmd} -eval ${run_cmd} - -set +x
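
Notes on the final state of the series: the patches converge on a single --position-embedding-type switch that defaults to absolute. A minimal standalone sketch of the resulting CLI behaviour, assuming only the enum member names visible in the diffs (the real definitions live in Megatron's sources):

import argparse
import enum

# Stand-in for the PositionEmbeddingType enum referenced in megatron/arguments.py;
# member names are taken from the diffs, the values here are illustrative.
class PositionEmbeddingType(enum.Enum):
    absolute = 1
    rotary = 2

parser = argparse.ArgumentParser()
parser.add_argument('--position-embedding-type',
                    type=lambda x: PositionEmbeddingType[x],  # 'rotary' -> PositionEmbeddingType.rotary
                    choices=list(PositionEmbeddingType),
                    default=PositionEmbeddingType.absolute,
                    help='Define position embedding type ("absolute" | "rotary"). '
                         '"absolute" by default.')

# Because PATCH 12/13 adds the default, omitting the flag keeps the old
# behaviour, which is why PATCH 13/13 can drop it from run.sh again.
assert parser.parse_args([]).position_embedding_type is PositionEmbeddingType.absolute
args = parser.parse_args(['--position-embedding-type', 'rotary'])
assert args.position_embedding_type is PositionEmbeddingType.rotary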
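
The assert block PATCH 11/13 reinstates in parse_args() encodes one rule: an absolute model needs a position table at least as long as any sequence it will see, while a rotary model must not declare a table at all. A sketch of that rule as a standalone function (the function name is hypothetical; the real code is inline asserts on the global args):

import enum

class PositionEmbeddingType(enum.Enum):  # stand-in, as above
    absolute = 1
    rotary = 2

def check_position_embedding_args(position_embedding_type, max_position_embeddings,
                                  seq_length=None, decoder_seq_length=None):
    if position_embedding_type == PositionEmbeddingType.absolute:
        # The learned table must cover every position that can occur.
        assert max_position_embeddings is not None
        if seq_length is not None:
            assert max_position_embeddings >= seq_length
        if decoder_seq_length is not None:
            assert max_position_embeddings >= decoder_seq_length
    else:
        # Rotary embeddings are computed on the fly from the position index,
        # so a fixed table size is meaningless and must stay unset.
        assert max_position_embeddings is None

check_position_embedding_args(PositionEmbeddingType.absolute, 2048, seq_length=2048)
check_position_embedding_args(PositionEmbeddingType.rotary, None, seq_length=2048)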
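
PATCH 10/13's rewrite of init_checkpointed_activations_memory_buffer() follows from the same rule: rotary runs leave the embedding-table argument unset, so the activation buffer can no longer be sized from it and is sized from the sequence lengths instead. The arithmetic, sketched with plain parameters in place of the global args object (the function name is hypothetical):

def per_layer_buffer_numel(micro_batch_size, hidden_size,
                           tensor_model_parallel_size,
                           seq_length=None, decoder_seq_length=None):
    # Upper bound over encoder and decoder sequence lengths; None means absent.
    upper_bound_sequence_length = max(
        seq_length if seq_length is not None else 0,
        decoder_seq_length if decoder_seq_length is not None else 0,
    )
    # Same formula as the patch: per-layer activation elements, sharded
    # across tensor-model-parallel ranks.
    return (micro_batch_size * upper_bound_sequence_length * hidden_size
            // tensor_model_parallel_size)

# A T5-style run with encoder length 512 and decoder length 128 is bounded by 512:
assert per_layer_buffer_numel(16, 1024, 2, seq_length=512,
                              decoder_seq_length=128) == 16 * 512 * 1024 // 2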
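
Finally, the language_model.py hunk shows what --max-position-embeddings sizes when the type is absolute: a learned torch.nn.Embedding over position indices. A tiny illustration of the lookup with made-up sizes (Megatron's init method and model parallelism omitted):

import torch

max_position_embeddings, hidden_size = 512, 1024
position_embeddings = torch.nn.Embedding(max_position_embeddings, hidden_size)

# Positions 0..seq_len-1 index rows of the table. This lookup is why absolute
# models must satisfy max_position_embeddings >= seq_length, and why rotary
# models, which never perform it, must not allocate the table.
seq_len = 256
pos_emb = position_embeddings(torch.arange(seq_len))
assert pos_emb.shape == (seq_len, hidden_size)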