diff --git a/megatron/__init__.py b/megatron/__init__.py index f670e652aac..b3a03290088 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -12,9 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + import torch import os +logger = logging.getLogger(__name__) + from .package_info import ( __description__, __contact_names__, @@ -38,21 +42,11 @@ from .initialize import initialize_megatron def print_rank_0(message): - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) + logger.info(str(message)) def is_last_rank(): return torch.distributed.get_rank() == ( torch.distributed.get_world_size() - 1) def print_rank_last(message): - """If distributed is initialized, print only on last rank.""" - if torch.distributed.is_initialized(): - if is_last_rank(): - print(message, flush=True) - else: - print(message, flush=True) + logger.info(str(message)) diff --git a/megatron/arguments.py b/megatron/arguments.py index b8c230f5793..fed8440fbd1 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -16,10 +16,13 @@ """Megatron arguments.""" import argparse +import logging import os import torch +logger = logging.getLogger(__name__) + def parse_args(extra_args_provider=None, defaults={}, ignore_unknown_args=False): """Parse all arguments.""" @@ -73,13 +76,12 @@ def parse_args(extra_args_provider=None, defaults={}, 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size) args.data_parallel_size = args.world_size // model_parallel_size - if args.rank == 0: - print('using world size: {}, data-parallel-size: {}, ' - 'tensor-model-parallel size: {}, ' - 'pipeline-model-parallel size: {} '.format( - args.world_size, args.data_parallel_size, - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size), flush=True) + logger.info('using world size: {}, data-parallel-size: {}, ' + 'tensor-model-parallel size: {}, ' + 'pipeline-model-parallel size: {} '.format( + args.world_size, args.data_parallel_size, + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size)) # Deprecated arguments assert args.batch_size is None, '--batch-size argument is no longer ' \ @@ -98,11 +100,9 @@ def parse_args(extra_args_provider=None, defaults={}, # arguments that are passed to the program. We check this by # ensuring the arg is set to None. 
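Note on the logging refactor above: `print_rank_0` and `print_rank_last` no longer check the distributed rank themselves; both now call `logger.info` on every rank, so rank filtering has to come from the logging configuration. The patch does not show where that configuration happens; the snippet below is only a minimal sketch of one possible setup that preserves the old "rank 0 only" behaviour, and its env-var, handler, and format choices are assumptions, not part of this change.

```python
# Hypothetical setup, not part of this diff: filter INFO messages by rank so
# that logger.info() in print_rank_0 keeps the old "only rank 0 prints" behaviour.
import logging
import os

def setup_rank_aware_logging():
    rank = int(os.environ.get("RANK", "0"))  # assumes the usual torch.distributed env var
    logging.basicConfig(
        level=logging.INFO if rank == 0 else logging.WARNING,
        format=f"[rank {rank}] %(name)s: %(message)s",
    )
```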
if getattr(args, key) is not None: - if args.rank == 0: - print('WARNING: overriding default arguments for {key}:{v} \ - with {key}:{v2}'.format(key=key, v=defaults[key], - v2=getattr(args, key)), - flush=True) + logger.warning('Overriding default arguments for {key}:{v} \ + with {key}:{v2}'.format(key=key, v=defaults[key], + v2=getattr(args, key))) else: setattr(args, key, defaults[key]) @@ -111,9 +111,8 @@ def parse_args(extra_args_provider=None, defaults={}, assert args.micro_batch_size > 0 if args.global_batch_size is None: args.global_batch_size = args.micro_batch_size * args.data_parallel_size - if args.rank == 0: - print('setting global batch size to {}'.format( - args.global_batch_size), flush=True) + logger.info('setting global batch size to {}'.format( + args.global_batch_size)) assert args.global_batch_size > 0 if args.num_layers_per_virtual_pipeline_stage is not None: assert args.pipeline_model_parallel_size > 2, \ @@ -140,13 +139,10 @@ def parse_args(extra_args_provider=None, defaults={}, # be done in fp32. if not args.accumulate_allreduce_grads_in_fp32: args.accumulate_allreduce_grads_in_fp32 = True - if args.rank == 0: - print('accumulate and all-reduce gradients in fp32 for ' - 'bfloat16 data type.', flush=True) + logger.info('accumulate and all-reduce gradients in fp32 for ' + 'bfloat16 data type.') - if args.rank == 0: - print('using {} for parameters ...'.format(args.params_dtype), - flush=True) + logger.info('using {} for parameters ...'.format(args.params_dtype)) # If we do accumulation and all-reduces in fp32, we need to have # local DDP and we should set the use-contiguous-buffers-in-ddp. @@ -239,17 +235,14 @@ def parse_args(extra_args_provider=None, defaults={}, def _print_args(args): """Print arguments.""" - if args.rank == 0: - print('------------------------ arguments ------------------------', - flush=True) - str_list = [] - for arg in vars(args): - dots = '.' * (48 - len(arg)) - str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print('-------------------- end of arguments ---------------------', - flush=True) + logger.info('------------------------ arguments ------------------------') + str_list = [] + for arg in vars(args): + dots = '.' * (48 - len(arg)) + str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) + for arg in sorted(str_list, key=lambda x: x.lower()): + logger.info(arg) + logger.info('-------------------- end of arguments ---------------------') def _check_arg_is_not_none(args, arg): @@ -304,6 +297,8 @@ def _add_logging_args(parser): group.add_argument('--log-params-norm', action='store_true', help='If set, calculate and log parameters norm.') + group.add_argument('--log-scales', action='store_true', + help='Log the scales of parameters, gradients and activations.') group.add_argument('--log-num-zeros-in-grad', action='store_true', help='If set, calculate and log the number of zeros in gradient.') group.add_argument('--tensorboard-log-interval', type=int, default=1, diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py index b1b61cd87b7..dee12e1b120 100644 --- a/megatron/data/biencoder_dataset_utils.py +++ b/megatron/data/biencoder_dataset_utils.py @@ -189,6 +189,13 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo # Wait until rank 0 generate the index file. 
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])]) + # It can take some time for the file to be visible on other nodes. + for i in range(120): + if indexmap_filename.is_file(): + break + if i%10==0: + print_rank_0(" Waiting for index file...") + time.sleep(1.0) # Load indexed dataset. print_rank_0(' > loading indexed mapping from {}'.format( diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py index 81acb6cde64..fa8cd2eb867 100644 --- a/megatron/data/dataset_utils.py +++ b/megatron/data/dataset_utils.py @@ -702,6 +702,13 @@ def get_samples_mapping(indexed_dataset, # Wait until rank 0 generate the index file. torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])]) + # It can take some time for the file to be visible on other nodes. + for i in range(120): + if indexmap_filename.is_file(): + break + if i%10==0: + print_rank_0(" Waiting for index file...") + time.sleep(1.0) # Load indexed dataset. print_rank_0(' > loading indexed mapping from {}'.format( diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index ca16f38efbd..815cc985e2c 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -302,6 +302,14 @@ def _build_index_mappings(name, data_prefix, documents, sizes, # Wait until rank 0 generate the index file. torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])]) + # It can take some time for the file to be visible on other nodes. + for i in range(120): + if doc_idx_filename.is_file() and sample_idx_filename.is_file() and shuffle_idx_filename.is_file(): + break + if i%10==0: + print_rank_0(" Waiting for index files...") + time.sleep(1.0) + # Load mappings. start_time = time.time() print_rank_0(' > loading doc-idx mapping from {}'.format( diff --git a/megatron/data/realm_dataset_utils.py b/megatron/data/realm_dataset_utils.py index dd33fcd2886..05ed12d8cdb 100644 --- a/megatron/data/realm_dataset_utils.py +++ b/megatron/data/realm_dataset_utils.py @@ -179,6 +179,13 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo # Wait until rank 0 generate the index file. torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])]) + # It can take some time for the file to be visible on other nodes. + for i in range(120): + if indexmap_filename.is_file(): + break + if i%10==0: + print_rank_0(" Waiting for index file...") + time.sleep(1.0) # Load indexed dataset. 
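The four wait loops added above (biencoder_dataset_utils, dataset_utils, gpt_dataset, realm_dataset_utils) poll for up to 120 seconds because on multi-node runs the index files written by rank 0 may take a while to become visible on other nodes' filesystems; if the timeout expires, the code simply falls through and the subsequent load fails with an ordinary missing-file error. They rely on `time` being imported in each module and on the `*_filename` variables being `pathlib.Path` objects, as the `.is_file()` calls imply. A hedged sketch of a shared helper the four copies could call is below; `wait_for_index_files` is a hypothetical name, not something this diff adds.

```python
# Hypothetical shared helper (not in this diff); assumes pathlib.Path filenames.
import time
from pathlib import Path
from typing import Sequence

from megatron import print_rank_0


def wait_for_index_files(paths: Sequence[Path], timeout_s: int = 120,
                         poll_s: float = 1.0, log_every: int = 10) -> None:
    """Poll until every index file is visible; rank 0 may have just written them."""
    for attempt in range(timeout_s):
        if all(path.is_file() for path in paths):
            break
        if attempt % log_every == 0:
            print_rank_0(" Waiting for index files...")
        time.sleep(poll_s)
```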
print_rank_0(' > loading indexed mapping from {}'.format( diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu index ce42584aa33..a892c069f53 100644 --- a/megatron/fused_kernels/layer_norm_cuda_kernel.cu +++ b/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -21,7 +21,7 @@ #include "ATen/ATen.h" #include "ATen/AccumulateType.h" #include "ATen/cuda/CUDAContext.h" -#include <THC/THCDeviceUtils.cuh> +#include "ATen/cuda/DeviceUtils.cuh" #include <cuda.h> #include <cuda_runtime.h> diff --git a/megatron/metrics.py b/megatron/metrics.py new file mode 100644 index 00000000000..8883a9dbb58 --- /dev/null +++ b/megatron/metrics.py @@ -0,0 +1,73 @@ +import logging +import math + +import torch +from megatron.global_vars import get_args + +logger = logging.getLogger(__name__) + +_iteration=0 +_metrics={} +_LOGGING_WIDTH=50 + +def next_iteration(iteration:int): + global _iteration, _metrics + _metrics={} + _iteration=iteration + + +def record_scale(name:str,x:torch.Tensor,grad=True, bias=None): + global _metrics + if get_log_scales(): + _metrics[f"{name}.scale" if grad else name]=get_scale(x if bias is None else x+bias) + if grad and x.requires_grad: + x.register_hook(lambda g: record_scale(f"{name}.grad",g,False)) + + +def get_scale(x): + return x.detach().float().pow(2).mean().pow(0.5) + + +def get_log_scales(): + args=get_args() + return args.log_scales and (_iteration+1) % args.log_interval == 0 + + +def log_metrics(): + metrics = {} + for key, value in _metrics.items(): + metrics_ = metrics + keys = key.split(".") + for prefix in keys[:-1]: + if prefix not in metrics_: + metrics_[prefix] = {} + metrics_ = metrics_[prefix] + metrics_[keys[-1]] = _format_value(value) + _log_dicts(metrics) + + +def _log_dicts(metrics, indent=0): + for key, value in metrics.items(): + key_ = key.rjust(len(key) + indent) + + # Merge keys when there is only one entry.
+ while isinstance(value, dict) and len(value) == 1: + for value_key, value_ in value.items(): + key_ = ".".join([key_, value_key]) + value = value_ + if isinstance(value, dict): + logger.info(key_ + ":") + _log_dicts(value, indent + 2) + else: + sep = _LOGGING_WIDTH - len(value) - len(key_) - 2 + logger.info(f"{key_.ljust(len(key_)+sep,'.')} {value}") + + +def _format_value(value, precision=5,max_leading_zeros=3): + decimals = 0 if value == 0 or not math.isfinite(value) else precision - math.floor(math.log10(abs(value))) + + if 0 <= decimals <= precision + max_leading_zeros: + value = f"{value:.{decimals}f}" + else: + value = f"{value:.{precision}e}" + return value \ No newline at end of file diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py index 3ff5039d5fe..a649885760b 100644 --- a/megatron/model/bert_model.py +++ b/megatron/model/bert_model.py @@ -15,10 +15,12 @@ """BERT model.""" +import logging import torch from megatron import get_args from megatron import mpu +from megatron.metrics import record_scale from megatron.model.enums import AttnMaskType from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import get_language_model @@ -67,18 +69,20 @@ class BertLMHead(MegatronModule): """ def __init__(self, mpu_vocab_size, hidden_size, init_method, - layernorm_epsilon, parallel_output): + layernorm_epsilon, parallel_output, name_=""): super(BertLMHead, self).__init__() + self.name_=name_ args = get_args() self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) + self.bias.name_=f"{self.name_}.logits.linear_bias" mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) self.parallel_output = parallel_output - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) + self.dense = get_linear_layer(hidden_size, hidden_size, init_method, name_=f"{self.name_}.dense") + self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, name_=f"{self.name_}.layer_norm") self.gelu = torch.nn.functional.gelu if args.openai_gelu: self.gelu = openai_gelu @@ -86,13 +90,16 @@ def __init__(self, mpu_vocab_size, hidden_size, init_method, self.gelu = erf_gelu def forward(self, hidden_states, word_embeddings_weight): + record_scale(f"{self.name_}.hidden",hidden_states) hidden_states = self.dense(hidden_states) hidden_states = self.gelu(hidden_states) + record_scale(f"{self.name_}.gelu",hidden_states) hidden_states = self.layernorm(hidden_states) output = parallel_lm_logits(hidden_states, word_embeddings_weight, self.parallel_output, bias=self.bias) + record_scale(f"{self.name_}.logits",output) return output @@ -129,9 +136,11 @@ def __init__(self, add_binary_head=True, parallel_output=True, pre_process=True, - post_process=True): + post_process=True, + name_="bert"): super(BertModel, self).__init__() args = get_args() + self.name_=name_ self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.add_binary_head = add_binary_head @@ -150,18 +159,20 @@ def __init__(self, init_method=init_method, scaled_init_method=scaled_init_method, pre_process=self.pre_process, - post_process=self.post_process) + post_process=self.post_process, + name_=self.name_) self.initialize_word_embeddings(init_method_normal) if self.post_process: self.lm_head = BertLMHead( self.word_embeddings_weight().size(0), - args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) + args.hidden_size, init_method, args.layernorm_epsilon, parallel_output, + 
name_=f"{self.name_}.output_layer.lm_head") self._lm_head_key = 'lm_head' self.binary_head = None if self.add_binary_head: self.binary_head = get_linear_layer(args.hidden_size, 2, - init_method) + init_method, name_=f"{self.name_}.output_layer.sop_head.binary_head") self._binary_head_key = 'binary_head' def set_input_tensor(self, input_tensor): diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 78645c23613..8218c65a5e5 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -23,6 +23,8 @@ from torch.nn import init import importlib +from megatron.metrics import record_scale + global fused_mix_prec_layer_norm_cuda fused_mix_prec_layer_norm_cuda = None @@ -61,8 +63,9 @@ def backward(ctx, grad_output): class MixedFusedLayerNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5): + def __init__(self, normalized_shape, eps=1e-5, name_=""): super(MixedFusedLayerNorm, self).__init__() + self.name_=name_ global fused_mix_prec_layer_norm_cuda fused_mix_prec_layer_norm_cuda = importlib.import_module( @@ -73,7 +76,9 @@ def __init__(self, normalized_shape, eps=1e-5): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.weight.name_=f"{self.name_}.layer_norm_weight" self.bias = Parameter(torch.Tensor(*normalized_shape)) + self.bias.name_=f"{self.name_}.layer_norm_bias" self.reset_parameters() @@ -85,6 +90,8 @@ def reset_parameters(self): def forward(self, input): - return FusedLayerNormAffineFunction.apply( + output = FusedLayerNormAffineFunction.apply( input, self.weight, self.bias, self.normalized_shape,self.eps) + record_scale(self.name_, output) + return output diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 06330d81395..3bf9a9712cf 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -21,6 +21,7 @@ from megatron import get_args from megatron import mpu from .module import MegatronModule +from megatron.metrics import record_scale from megatron.model.enums import LayerType, AttnMaskType from megatron.model.transformer import ParallelTransformer from megatron.model.utils import get_linear_layer @@ -47,7 +48,7 @@ def get_language_model(num_tokentypes, add_pooler, encoder_attn_mask_type, init_method=None, scaled_init_method=None, add_decoder=False, decoder_attn_mask_type=AttnMaskType.causal, - pre_process=True, post_process=True): + pre_process=True, post_process=True, name_=""): """Build language model and return along with the key to save.""" args = get_args() @@ -68,7 +69,8 @@ def get_language_model(num_tokentypes, add_pooler, decoder_attn_mask_type=decoder_attn_mask_type, add_pooler=add_pooler, pre_process=pre_process, - post_process=post_process + post_process=post_process, + name_=name_ ) # key used for checkpoints. language_model_key = 'language_model' @@ -88,16 +90,20 @@ class Pooler(MegatronModule): bias is set to zero. """ - def __init__(self, hidden_size, init_method): + def __init__(self, hidden_size, init_method, name_=""): super(Pooler, self).__init__() - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) + self.name_=name_ + self.dense = get_linear_layer(hidden_size, hidden_size, init_method, name_=f"{self.name_}.dense") def forward(self, hidden_states, sequence_index=0): # hidden_states: [b, s, h] # sequence_index: index of the token to pool. 
+ record_scale(f"{self.name_}.input",hidden_states) pooled = hidden_states[:, sequence_index, :] + record_scale(f"{self.name_}.pooled",pooled) pooled = self.dense(pooled) pooled = torch.tanh(pooled) + record_scale(f"{self.name_}.tanh",pooled) return pooled @@ -121,7 +127,8 @@ def __init__(self, max_sequence_length, embedding_dropout_prob, init_method, - num_tokentypes=0): + num_tokentypes=0, + name_=""): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -129,17 +136,22 @@ def __init__(self, self.num_tokentypes = num_tokentypes args = get_args() + self.name_=name_ # Word embeddings (parallel). self.word_embeddings = mpu.VocabParallelEmbedding( vocab_size, self.hidden_size, init_method=self.init_method) self._word_embeddings_key = 'word_embeddings' + self.word_embeddings.name_=f"{self.name_}.word_embeddings" + self.word_embeddings.weight.name_=f"{self.word_embeddings.name_}.embedding_weight" # Position embedding (serial). self.position_embeddings = torch.nn.Embedding( max_sequence_length, self.hidden_size) self._position_embeddings_key = 'position_embeddings' + self.position_embeddings.name_=f"{self.name_}.position_embeddings" + self.position_embeddings.weight.name_=f"{self.position_embeddings.name_}.embedding_weight" # Initialize the position embeddings. self.init_method(self.position_embeddings.weight) @@ -151,6 +163,8 @@ def __init__(self, if self.num_tokentypes > 0: self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) + self.tokentype_embeddings.name_=f"{self.name_}.tokentype_embeddings" + self.tokentype_embeddings.weight.name_=f"{self.tokentype_embeddings.name_}.embedding_weight" # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) else: @@ -178,17 +192,24 @@ def add_tokentype_embeddings(self, num_tokentypes): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. + args=get_args() words_embeddings = self.word_embeddings(input_ids) + record_scale(self.word_embeddings.name_,words_embeddings) position_embeddings = self.position_embeddings(position_ids) + record_scale(self.position_embeddings.name_,position_embeddings) embeddings = words_embeddings + position_embeddings if tokentype_ids is not None: assert self.tokentype_embeddings is not None - embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) + tokentype_embeddings=self.tokentype_embeddings(tokentype_ids) + record_scale(self.tokentype_embeddings.name_,tokentype_embeddings) + embeddings = embeddings + tokentype_embeddings else: assert self.tokentype_embeddings is None + record_scale(f"{self.name_}.embeddings",embeddings) # Dropout. embeddings = self.embedding_dropout(embeddings) + record_scale(f"{self.name_}.dropout",embeddings) return embeddings @@ -277,9 +298,11 @@ def __init__(self, decoder_attn_mask_type=AttnMaskType.causal, add_pooler=False, pre_process=True, - post_process=True): + post_process=True, + name_=""): super(TransformerLanguageModel, self).__init__() args = get_args() + self.name_ = name_ self.pre_process = pre_process self.post_process = post_process @@ -298,7 +321,8 @@ def __init__(self, args.max_position_embeddings, args.hidden_dropout, self.init_method, - self.num_tokentypes) + self.num_tokentypes, + name_=f"{self.name_}.input_layer.embedding") self._embedding_key = 'embedding' # Transformer. 
@@ -307,7 +331,8 @@ def __init__(self, output_layer_init_method, self_attn_mask_type=self.encoder_attn_mask_type, pre_process=self.pre_process, - post_process=self.post_process + post_process=self.post_process, + name_=self.name_, ) self._encoder_key = 'encoder' @@ -319,13 +344,14 @@ def __init__(self, self.init_method, output_layer_init_method, layer_type=LayerType.decoder, - self_attn_mask_type=self.decoder_attn_mask_type) + self_attn_mask_type=self.decoder_attn_mask_type, + name_=f"{self.name_}.decoder") self._decoder_key = 'decoder' if self.post_process: # Pooler. if self.add_pooler: - self.pooler = Pooler(self.hidden_size, self.init_method) + self.pooler = Pooler(self.hidden_size, self.init_method, name_=f"{self.name_}.output_layer.sop_head") self._pooler_key = 'pooler' def set_input_tensor(self, input_tensor): diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index ac9d2021892..5a2c91306f6 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -20,6 +20,7 @@ from megatron import get_args from megatron import mpu +from megatron.metrics import record_scale from .module import MegatronModule from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model import LayerNorm @@ -57,9 +58,10 @@ class ParallelMLP(MegatronModule): applied. """ - def __init__(self, init_method, output_layer_init_method): + def __init__(self, init_method, output_layer_init_method, name_=""): super(ParallelMLP, self).__init__() args = get_args() + self.name_=name_ # Project to 4h. self.dense_h_to_4h = mpu.ColumnParallelLinear( @@ -67,7 +69,8 @@ def __init__(self, init_method, output_layer_init_method): args.ffn_hidden_size, gather_output=False, init_method=init_method, - skip_bias_add=True) + skip_bias_add=True, + name_=f"{name_}.dense_0") self.bias_gelu_fusion = args.bias_gelu_fusion self.activation_func = F.gelu @@ -82,7 +85,8 @@ def __init__(self, init_method, output_layer_init_method): args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + name_=f"{name_}.dense_1") def forward(self, hidden_states): @@ -97,6 +101,7 @@ def forward(self, hidden_states): intermediate_parallel = \ self.activation_func(intermediate_parallel + bias_parallel) + record_scale(f"{self.name_}.gelu", intermediate_parallel) # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) return output, output_bias @@ -112,9 +117,11 @@ class ParallelAttention(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding): + attn_mask_type=AttnMaskType.padding, + name_=""): super(ParallelAttention, self).__init__() args = get_args() + self.name_=name_ self.fp16 = args.fp16 self.bf16 = args.bf16 @@ -143,20 +150,23 @@ def __init__(self, init_method, args.hidden_size, 3 * projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + name_=f"{self.name_}.query_key_value") else: assert attention_type == AttnType.cross_attn self.query = mpu.ColumnParallelLinear( args.hidden_size, projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + name_=f"{self.name_}.query") self.key_value = mpu.ColumnParallelLinear( args.hidden_size, 2 * projection_size, gather_output=False, - init_method=init_method) + init_method=init_method, + name_=f"{self.name_}.key_value") coeff = None self.norm_factor = 
math.sqrt(self.hidden_size_per_attention_head) @@ -183,7 +193,8 @@ def __init__(self, init_method, args.hidden_size, input_is_parallel=True, init_method=output_layer_init_method, - skip_bias_add=True) + skip_bias_add=True, + name_=f"{self.name_}.dense") def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False, encoder_output=None): @@ -229,6 +240,10 @@ def forward(self, hidden_states, attention_mask, layer_past=None, self.hidden_size_per_attention_head) query_layer = query_layer.view(*new_tensor_shape) + record_scale(f"{self.name_}.query_layer", query_layer) + record_scale(f"{self.name_}.key_layer", key_layer) + record_scale(f"{self.name_}.value_layer", value_layer) + # ================================== # Adjust key and value for inference # ================================== @@ -277,6 +292,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None, # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) + record_scale(f"{self.name_}.attention_scores", attention_scores) # ================================================== # Update attention mask for inference. [b, np, sq, sk] # ================================================== @@ -301,6 +317,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None, # attention scores and attention mask [b, np, sq, sk] attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + record_scale(f"{self.name_}.attention_probs", attention_probs) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. @@ -342,6 +359,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None, (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) + record_scale(f"{self.name_}.context_layer", context_layer) + # ================= # Output. [sq, b, h] # ================= @@ -388,10 +407,12 @@ class ParallelTransformerLayer(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_number, layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding): + self_attn_mask_type=AttnMaskType.padding, + name_=""): args = get_args() super(ParallelTransformerLayer, self).__init__() + self.name_=name_ self.layer_number = layer_number self.layer_type = layer_type @@ -404,7 +425,9 @@ def __init__(self, init_method, output_layer_init_method, # Layernorm on the input data. self.input_layernorm = LayerNorm( args.hidden_size, - eps=args.layernorm_epsilon) + eps=args.layernorm_epsilon, + name_=f"{self.name_}.input_layer_norm", + ) # Self attention. 
self.self_attention = ParallelAttention( @@ -412,29 +435,35 @@ def __init__(self, init_method, output_layer_init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type) + attn_mask_type=self_attn_mask_type, + name_=f"{self.name_}.self_attention") self.hidden_dropout = args.hidden_dropout self.bias_dropout_fusion = args.bias_dropout_fusion # Layernorm on the attention output self.post_attention_layernorm = LayerNorm( args.hidden_size, - eps=args.layernorm_epsilon) + eps=args.layernorm_epsilon, + name_=f"{self.name_}.post_attention_layer_norm", + ) if self.layer_type == LayerType.decoder: self.inter_attention = ParallelAttention( init_method, output_layer_init_method, layer_number, - attention_type=AttnType.cross_attn) + attention_type=AttnType.cross_attn, + name_=f"{self.name_}.inter_attention") # Layernorm on the attention output. self.post_inter_attention_layernorm = LayerNorm( args.hidden_size, - eps=args.layernorm_epsilon) + eps=args.layernorm_epsilon, + name_=f"{self.name_}.post_inter_attention_layer_norm", + ) # MLP self.mlp = ParallelMLP(init_method, - output_layer_init_method) + output_layer_init_method, name_=f"{self.name_}.mlp") def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, @@ -450,6 +479,8 @@ def forward(self, hidden_states, attention_mask, layer_past=layer_past, get_key_value=get_key_value) + record_scale(f"{self.name_}.attention", attention_output, bias=attention_bias) + if get_key_value: attention_output, presents = attention_output @@ -458,6 +489,7 @@ def forward(self, hidden_states, attention_mask, residual = layernorm_output else: residual = hidden_states + record_scale(f"{self.name_}.attention_residual_input", residual) # jit scripting for a nn.module (with dropout) is not # trigerring the fusion kernel. For now, we use two @@ -471,6 +503,7 @@ def forward(self, hidden_states, attention_mask, else: bias_dropout_add_func = get_bias_dropout_add(self.training) + # re-enable torch grad to enable fused optimization. with torch.enable_grad(): layernorm_input = bias_dropout_add_func( @@ -479,6 +512,8 @@ def forward(self, hidden_states, attention_mask, residual, self.hidden_dropout) + record_scale(f"{self.name_}.attention_residual", layernorm_input) + # Layer norm post the self attention. layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -487,11 +522,13 @@ def forward(self, hidden_states, attention_mask, self.inter_attention(layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output) + record_scale(f"{self.name_}.inter_attention", attention_output, bias=attention_bias) # residual connection if self.apply_residual_connection_post_layernorm: residual = layernorm_output else: residual = layernorm_input + record_scale(f"{self.name_}.inter_attention_residual_input", residual) # re-enable torch grad to enable fused optimization. with torch.enable_grad(): @@ -500,6 +537,7 @@ def forward(self, hidden_states, attention_mask, attention_bias.expand_as(residual), residual, self.hidden_dropout) + record_scale(f"{self.name_}.inter_attention_residual", layernorm_input) # Layer norm post the decoder attention layernorm_output = self.post_inter_attention_layernorm(layernorm_input) @@ -512,6 +550,7 @@ def forward(self, hidden_states, attention_mask, residual = layernorm_output else: residual = layernorm_input + record_scale(f"{self.name_}.mlp_residual_input", residual) # re-enable torch grad to enable fused optimization. 
with torch.enable_grad(): @@ -521,6 +560,8 @@ def forward(self, hidden_states, attention_mask, residual, self.hidden_dropout) + record_scale(f"{self.name_}.mlp_residual", layernorm_input) + if get_key_value: output = [output, presents] @@ -533,9 +574,11 @@ class ParallelTransformer(MegatronModule): def __init__(self, init_method, output_layer_init_method, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, - pre_process=True, post_process=True): + pre_process=True, post_process=True, + name_=""): super(ParallelTransformer, self).__init__() args = get_args() + self.name_=name_ self.bf16 = args.bf16 self.fp32_residual_connection = args.fp32_residual_connection @@ -559,7 +602,8 @@ def build_layer(layer_number): output_layer_init_method, layer_number, layer_type=layer_type, - self_attn_mask_type=self_attn_mask_type) + self_attn_mask_type=self_attn_mask_type, + name_=f"{self.name_}.layer_{layer_number-1}.transformer_layer") if args.virtual_pipeline_model_parallel_size is not None: assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \ 'num_layers_per_stage must be divisible by ' \ @@ -589,7 +633,8 @@ def build_layer(layer_number): # Final layer norm before output. self.final_layernorm = LayerNorm( args.hidden_size, - eps=args.layernorm_epsilon) + eps=args.layernorm_epsilon, + name_=f"{self.name_}.output_layer.final_layer_norm") def _get_layer(self, layer_number): return self.layers[layer_number] diff --git a/megatron/model/utils.py b/megatron/model/utils.py index 465e8aa4ff6..d87616c6d98 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -18,8 +18,7 @@ import math import torch - -from megatron import get_args +from megatron.metrics import record_scale def init_method_normal(sigma): """Init method based on N(0, sigma).""" @@ -31,7 +30,7 @@ def init_(tensor): def scaled_init_method_normal(sigma, num_layers): """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) + std = sigma / math.sqrt(2.0 * max(num_layers,1)) def init_(tensor): return torch.nn.init.normal_(tensor, mean=0.0, std=std) @@ -44,12 +43,26 @@ def attention_mask_func(attention_scores, attention_mask): return attention_scores -def get_linear_layer(rows, columns, init_method): +def get_linear_layer(rows, columns, init_method, name_=""): """Simple linear layer with weight initialization.""" layer = torch.nn.Linear(rows, columns) init_method(layer.weight) with torch.no_grad(): layer.bias.zero_() + layer.name_=name_ + layer.weight.name_=f"{name_}.linear_weight" + layer.bias.name_=f"{name_}.linear_bias" + + + old_forward=layer.forward + + def forward(input): + output=old_forward(input) + record_scale(layer.name_,output) + return output + + layer.forward=forward + return layer @torch.jit.script diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 8dd69f72cb8..9bf58d2b8fa 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -35,7 +35,7 @@ from .utils import divide from .utils import split_tensor_along_last_dim from .utils import VocabUtility -from megatron import get_args +from megatron.metrics import get_args, get_log_scales, record_scale _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, @@ -225,8 +225,9 @@ class ColumnParallelLinear(torch.nn.Module): def __init__(self, input_size, output_size, bias=True, gather_output=True, init_method=init.xavier_normal_, stride=1, keep_master_weight_for_test=False, - skip_bias_add=False): + skip_bias_add=False,name_=""): super(ColumnParallelLinear, 
self).__init__() + self.name_=name_ # Keep input parameters self.input_size = input_size @@ -256,6 +257,7 @@ def __init__(self, input_size, output_size, bias=True, gather_output=True, device=torch.cuda.current_device(), dtype=args.params_dtype)) _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride) + self.weight.name_=f"{self.name_}.linear_weight" if bias: if args.use_cpu_initialization: @@ -270,6 +272,7 @@ def __init__(self, input_size, output_size, bias=True, gather_output=True, # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() + self.bias.name_ = f"{self.name_}.linear_bias" else: self.register_parameter('bias', None) @@ -288,6 +291,7 @@ def forward(self, input_): else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None + record_scale(self.name_, output, bias=output_bias) return output, output_bias @@ -325,8 +329,9 @@ def __init__(self, input_size, output_size, bias=True, input_is_parallel=False, init_method=init.xavier_normal_, stride=1, keep_master_weight_for_test=False, - skip_bias_add=False): + skip_bias_add=False,name_=""): super(RowParallelLinear, self).__init__() + self.name_=name_ # Keep input parameters self.input_size = input_size @@ -356,6 +361,7 @@ def __init__(self, input_size, output_size, bias=True, device=torch.cuda.current_device(), dtype=args.params_dtype)) _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=1, stride=stride) + self.weight.name_ = f"{self.name_}.linear_weight" if bias: if args.use_cpu_initialization: self.bias = Parameter(torch.empty(self.output_size, @@ -367,6 +373,7 @@ def __init__(self, input_size, output_size, bias=True, # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() + self.bias.name_ = f"{self.name_}.linear_bias" else: self.register_parameter('bias', None) @@ -388,5 +395,6 @@ def forward(self, input_): else: output = output_ output_bias = self.bias + record_scale(self.name_, output, bias=output_bias) return output, output_bias diff --git a/megatron/optimizer/__init__.py b/megatron/optimizer/__init__.py index 823a51f4492..7298930daae 100644 --- a/megatron/optimizer/__init__.py +++ b/megatron/optimizer/__init__.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from apex.optimizers import FusedAdam as Adam -from apex.optimizers import FusedSGD as SGD +import warnings + +try: + from apex.optimizers import FusedAdam as Adam + from apex.optimizers import FusedSGD as SGD +except ImportError: + warnings.warn("Apex not found") from megatron import get_args from megatron.model import LayerNorm @@ -52,6 +57,7 @@ def get_megatron_optimizer(model): # Base optimizer. 
param_groups = _get_params_for_weight_decay_optimization(model) + print("weight_decay", args.weight_decay) if args.optimizer == 'adam': optimizer = Adam(param_groups, lr=args.lr, diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py index 036a1d4c4cf..30e1b820ea0 100644 --- a/megatron/optimizer/clip_grads.py +++ b/megatron/optimizer/clip_grads.py @@ -17,9 +17,13 @@ import torch from torch._six import inf +import warnings -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C +try: + from apex.multi_tensor_apply import multi_tensor_applier + import amp_C +except ImportError: + warnings.warn("Apex not found") from megatron import mpu from megatron.model.module import param_is_not_shared diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index 77baddd62ad..175a44b4c8d 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -17,15 +17,20 @@ from abc import ABC from abc import abstractmethod +import warnings import torch -from apex.multi_tensor_apply import multi_tensor_applier -import amp_C +try: + from apex.multi_tensor_apply import multi_tensor_applier + import amp_C +except ImportError: + warnings.warn("Apex not found") from megatron import get_timers from megatron import mpu from megatron import print_rank_0 +from megatron.metrics import record_scale,get_log_scales from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32 @@ -136,6 +141,13 @@ def state_dict(self): def load_state_dict(self, state_dict): pass + def _record_scales(self): + if get_log_scales(): + for group in self.optimizer.param_groups: + for p in group['params']: + name_=getattr(p, "name_", "unknown") + record_scale(f"optimizer.{name_}.scale", p, False) + record_scale(f"optimizer.{name_}.grad", p.grad, False) # Promote state so it can be retrieved or set via # "optimizer_instance.state" @@ -245,6 +257,8 @@ def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad, float16_params_this_group.append(param) # Create a copy main_param = param.detach().clone().float() + if hasattr(param, "name_"): + main_param.name_=param.name_ # Copy tensor model parallel attributes. mpu.copy_tensor_model_parallel_attributes(main_param, param) @@ -406,6 +420,7 @@ def step(self): num_zeros_in_grad = self.count_zeros() if \ self.log_num_zeros_in_grad else None + self._record_scales() # Step the optimizer. self.optimizer.step() @@ -504,6 +519,7 @@ def step(self): num_zeros_in_grad = self.count_zeros() if \ self.log_num_zeros_in_grad else None + self._record_scales() # Update parameters. 
self.optimizer.step() diff --git a/megatron/training.py b/megatron/training.py index 62ed60c1238..61031e29bea 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -25,7 +25,7 @@ import torch from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP -from megatron import get_args +from megatron.metrics import get_args, get_log_scales, next_iteration, log_metrics from megatron import get_timers from megatron import get_tensorboard_writer from megatron import get_current_global_batch_size @@ -535,6 +535,9 @@ def add_to_logging(name): timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) + if get_log_scales(): + log_metrics() + if iteration % args.log_interval == 0: elapsed_time = timers('interval-time').elapsed() elapsed_time_per_iteration = elapsed_time / total_iterations @@ -617,6 +620,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, print_datetime('before the start of training step') report_memory_flag = True while iteration < args.train_iters: + next_iteration(iteration) update_num_microbatches(args.consumed_train_samples) loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ train_step(forward_step_func,
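Taken together, the training-loop changes wire the scale logging in as follows: `next_iteration()` clears the per-iteration metrics dict at the top of each step, `record_scale()` calls in the model (plus `_record_scales()` in the optimizer wrapper, just before `optimizer.step()`) fill it, and `training_log` emits the result via `log_metrics()`. Collection is gated by `get_log_scales()`, which is true only when `--log-scales` is set and the iteration falls on a `--log-interval` boundary, so the extra work is paid only on iterations that are actually logged. A minimal sketch of that per-iteration flow, using the real function names from this diff but with an assumption-level simplified loop body:

```python
# Sketch only: the imports are real, the loop body is a simplification.
from megatron.metrics import next_iteration, get_log_scales, log_metrics

def run_iteration(iteration, forward_backward_func, optimizer):
    next_iteration(iteration)      # reset the per-iteration metrics dict
    forward_backward_func()        # record_scale() tags activations; backward
                                   # hooks add the matching .grad entries
    optimizer.step()               # _record_scales() records param/grad RMS first
    if get_log_scales():           # --log-scales set and a --log-interval boundary
        log_metrics()              # print the dotted names as an indented tree
```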