Merged
18 changes: 6 additions & 12 deletions megatron/__init__.py
@@ -12,9 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import torch
import os

logger = logging.getLogger(__name__)

from .package_info import (
__description__,
__contact_names__,
@@ -38,21 +42,11 @@
from .initialize import initialize_megatron

def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
logger.info(str(message))

def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)

def print_rank_last(message):
"""If distributed is initialized, print only on last rank."""
if torch.distributed.is_initialized():
if is_last_rank():
print(message, flush=True)
else:
print(message, flush=True)
logger.info(str(message))
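
The diff above replaces the rank-guarded print calls with module-level loggers, but does not show where per-rank filtering is configured. As a rough sketch only (an assumption, not code from this PR), one common pattern is to set the log level per rank during initialization so that only rank 0 emits INFO records:

import logging
import os

def configure_rank_aware_logging():
    # Hypothetical helper (not part of this PR): emit INFO records only on
    # rank 0 and keep other ranks at WARNING, so that logger.info() calls such
    # as the new print_rank_0() do not produce duplicate output on every rank.
    rank = int(os.environ.get("RANK", "0"))
    level = logging.INFO if rank == 0 else logging.WARNING
    logging.basicConfig(
        level=level,
        format=f"[rank {rank}] %(name)s: %(message)s",
    )

How this PR actually filters output per rank (and how print_rank_last keeps its last-rank behavior) is not visible in this diff.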
59 changes: 27 additions & 32 deletions megatron/arguments.py
@@ -16,10 +16,13 @@
"""Megatron arguments."""

import argparse
import logging
import os

import torch

logger = logging.getLogger(__name__)

def parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse all arguments."""
@@ -73,13 +76,12 @@ def parse_args(extra_args_provider=None, defaults={},
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
logger.info('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size))

# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
@@ -98,11 +100,9 @@ def parse_args(extra_args_provider=None, defaults={},
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
logger.warning('Overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)))
else:
setattr(args, key, defaults[key])

@@ -111,9 +111,8 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
logger.info('setting global batch size to {}'.format(
args.global_batch_size))
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
@@ -140,13 +139,10 @@ def parse_args(extra_args_provider=None, defaults={},
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
logger.info('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.')

if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
logger.info('using {} for parameters ...'.format(args.params_dtype))

# If we do accumulation and all-reduces in fp32, we need to have
# local DDP and we should set the use-contiguous-buffers-in-ddp.
@@ -239,17 +235,14 @@ def parse_args(extra_args_provider=None, defaults={},

def _print_args(args):
"""Print arguments."""
if args.rank == 0:
print('------------------------ arguments ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print('-------------------- end of arguments ---------------------',
flush=True)
logger.info('------------------------ arguments ------------------------')
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
logger.info(arg)
logger.info('-------------------- end of arguments ---------------------')


def _check_arg_is_not_none(args, arg):
@@ -304,6 +297,8 @@ def _add_logging_args(parser):

group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.')
group.add_argument('--log-scales', action='store_true',
help='Log the scales of parameters, gradients and activations.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
7 changes: 7 additions & 0 deletions megatron/data/biencoder_dataset_utils.py
@@ -189,6 +189,13 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo

# Wait until rank 0 generate the index file.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
# It can take some time for the file to be visible on other nodes.
for i in range(120):
if indexmap_filename.is_file():
break
if i%10==0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
7 changes: 7 additions & 0 deletions megatron/data/dataset_utils.py
@@ -702,6 +702,13 @@ def get_samples_mapping(indexed_dataset,

# Wait until rank 0 generate the index file.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
# It can take some time for the file to be visible on other nodes.
for i in range(120):
if indexmap_filename.is_file():
break
Collaborator: Maybe add a log message, so the job does not look stuck (maybe every 10 s).

Collaborator (Author): Is it really necessary? Worst case it will crash after 2 minutes.

Collaborator (Author): Added some log messages and copied them to the other datasets.

if i%10==0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
8 changes: 8 additions & 0 deletions megatron/data/gpt_dataset.py
@@ -302,6 +302,14 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# Wait until rank 0 generate the index file.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])

# It can take some time for the file to be visible on other nodes.
for i in range(120):
if doc_idx_filename.is_file() and sample_idx_filename.is_file() and shuffle_idx_filename.is_file():
break
if i%10==0:
print_rank_0(" Waiting for index files...")
time.sleep(1.0)

# Load mappings.
start_time = time.time()
print_rank_0(' > loading doc-idx mapping from {}'.format(
7 changes: 7 additions & 0 deletions megatron/data/realm_dataset_utils.py
@@ -179,6 +179,13 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo

# Wait until rank 0 generate the index file.
torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
# It can take some time for the file to be visible on other nodes.
for i in range(120):
if indexmap_filename.is_file():
break
if i%10==0:
print_rank_0(" Waiting for index file...")
time.sleep(1.0)

# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
2 changes: 1 addition & 1 deletion megatron/fused_kernels/layer_norm_cuda_kernel.cu
@@ -21,7 +21,7 @@
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include "ATen/cuda/DeviceUtils.cuh"

#include <cuda.h>
#include <cuda_runtime.h>
73 changes: 73 additions & 0 deletions megatron/metrics.py
@@ -0,0 +1,73 @@
import logging
import math

import torch
from megatron.global_vars import get_args

logger = logging.getLogger(__name__)

_iteration=0
_metrics={}
_LOGGING_WIDTH=50

def next_iteration(iteration:int):
global _iteration, _metrics
_metrics={}
_iteration=iteration


def record_scale(name:str,x:torch.Tensor,grad=True, bias=None):
global _metrics
if get_log_scales():
_metrics[f"{name}.scale" if grad else name]=get_scale(x if bias is None else x+bias)
if grad and x.requires_grad:
x.register_hook(lambda g: record_scale(f"{name}.grad",g,False))


def get_scale(x):
return x.detach().float().pow(2).mean().pow(0.5)


def get_log_scales():
args=get_args()
return args.log_scales and (_iteration+1) % args.log_interval == 0


def log_metrics():
metrics = {}
for key, value in _metrics.items():
metrics_ = metrics
keys = key.split(".")
for prefix in keys[:-1]:
if prefix not in metrics_:
metrics_[prefix] = {}
metrics_ = metrics_[prefix]
metrics_[keys[-1]] = _format_value(value)
_log_dicts(metrics)


def _log_dicts(metrics, indent=0):
for key, value in metrics.items():
key_ = key.rjust(len(key) + indent)

# Merge keys when there is only one entry.
while isinstance(value, dict) and len(value) == 1:
for value_key, value_ in value.items():
key_ = ".".join([key_, value_key])
value = value_
if isinstance(value, dict):
logger.info(key_ + ":")
_log_dicts(value, indent + 2)
else:
sep = _LOGGING_WIDTH - len(value) - len(key_) - 2
logger.info(f"{key_.ljust(len(key_)+sep,'.')} {value}")


def _format_value(value, precision=5,max_leading_zeros=3):
decimals = 0 if value == 0 or not math.isfinite(value) else precision - math.floor(math.log10(abs(value)))

if 0 <= decimals <= precision + max_leading_zeros:
value = f"{value:.{decimals}f}"
else:
value = f"{value:.{precision}e}"
return value
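
To make the flow of the new metrics module concrete, here is a small, self-contained illustration (plain PyTorch outside of Megatron, not part of this PR) of how record_scale captures both an activation's RMS scale and, via register_hook, the scale of its gradient during the backward pass:

import torch

metrics = {}

def get_scale(x):
    # Root-mean-square of the tensor, mirroring metrics.get_scale above.
    return x.detach().float().pow(2).mean().pow(0.5)

def record_scale(name, x, grad=True):
    # Record the forward scale; if the tensor participates in autograd,
    # also hook the backward pass to record the gradient scale later.
    metrics[f"{name}.scale" if grad else name] = get_scale(x)
    if grad and x.requires_grad:
        x.register_hook(lambda g: record_scale(f"{name}.grad", g, grad=False))

linear = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)
y = linear(x)
record_scale("linear.output", y)   # stores "linear.output.scale"
y.sum().backward()                 # hook fires, stores "linear.output.grad"
print({k: round(float(v), 5) for k, v in metrics.items()})

In the PR itself the recording calls are guarded by get_log_scales(), so the extra work only happens on iterations where --log-scales output is actually emitted.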
25 changes: 18 additions & 7 deletions megatron/model/bert_model.py
@@ -15,10 +15,12 @@

"""BERT model."""

import logging
import torch

from megatron import get_args
from megatron import mpu
from megatron.metrics import record_scale
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
@@ -67,32 +69,37 @@ class BertLMHead(MegatronModule):
"""

def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
layernorm_epsilon, parallel_output, name_=""):

super(BertLMHead, self).__init__()
self.name_=name_

args = get_args()

self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.name_=f"{self.name_}.logits.linear_bias"
mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output

self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.dense = get_linear_layer(hidden_size, hidden_size, init_method, name_=f"{self.name_}.dense")
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon, name_=f"{self.name_}.layer_norm")
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
elif args.onnx_safe:
self.gelu = erf_gelu

def forward(self, hidden_states, word_embeddings_weight):
record_scale(f"{self.name_}.hidden",hidden_states)
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
record_scale(f"{self.name_}.gelu",hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
record_scale(f"{self.name_}.logits",output)
return output


Expand Down Expand Up @@ -129,9 +136,11 @@ def __init__(self,
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
post_process=True,
name_="bert"):
super(BertModel, self).__init__()
args = get_args()
self.name_=name_

self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.add_binary_head = add_binary_head
@@ -150,18 +159,20 @@ def __init__(self,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
post_process=self.post_process,
name_=self.name_)

self.initialize_word_embeddings(init_method_normal)
if self.post_process:
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output,
name_=f"{self.name_}.output_layer.lm_head")
self._lm_head_key = 'lm_head'
self.binary_head = None
if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
init_method, name_=f"{self.name_}.output_layer.sop_head.binary_head")
self._binary_head_key = 'binary_head'

def set_input_tensor(self, input_tensor):