diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py
index feaeb3b90c1..b1b61cd87b7 100644
--- a/megatron/data/biencoder_dataset_utils.py
+++ b/megatron/data/biencoder_dataset_utils.py
@@ -187,13 +187,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
                      '(seconds): {:4f}'.format(
                          time.time() - start_time))
 
-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    assert counts[0].item() == torch.distributed.get_world_size(
-        group=mpu.get_data_parallel_group())
+    # Wait until rank 0 generates the index file.
+    torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
 
     # Load indexed dataset.
     print_rank_0(' > loading indexed mapping from {}'.format(
diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py
index 51388e75c92..81acb6cde64 100644
--- a/megatron/data/dataset_utils.py
+++ b/megatron/data/dataset_utils.py
@@ -699,15 +699,9 @@ def get_samples_mapping(indexed_dataset,
         print_rank_0(' > elasped time to build and save samples mapping '
                      '(seconds): {:4f}'.format(
                          time.time() - start_time))
-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
-    assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+
+    # Wait until rank 0 generates the index file.
+    torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
 
     # Load indexed dataset.
     print_rank_0(' > loading indexed mapping from {}'.format(
diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py
index 993f74d8b99..ca16f38efbd 100644
--- a/megatron/data/gpt_dataset.py
+++ b/megatron/data/gpt_dataset.py
@@ -299,15 +299,8 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         print_rank_0(' > elasped time to build and save shuffle-idx mapping'
                      ' (seconds): {:4f}'.format(time.time() - start_time))
 
-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
-    assert counts[0].item() == (
-        torch.distributed.get_world_size() //
-        torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
+    # Wait until rank 0 generates the index file.
+    torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
 
     # Load mappings.
     start_time = time.time()
diff --git a/megatron/data/realm_dataset_utils.py b/megatron/data/realm_dataset_utils.py
index 8b8ed038f29..dd33fcd2886 100644
--- a/megatron/data/realm_dataset_utils.py
+++ b/megatron/data/realm_dataset_utils.py
@@ -177,13 +177,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
                      '(seconds): {:4f}'.format(
                          time.time() - start_time))
 
-    # This should be a barrier but nccl barrier assumes
-    # device_index=rank which is not the case for model
-    # parallel case
-    counts = torch.cuda.LongTensor([1])
-    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
-    assert counts[0].item() == torch.distributed.get_world_size(
-        group=mpu.get_data_parallel_group())
+    # Wait until rank 0 generates the index file.
+    torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])
 
     # Load indexed dataset.
     print_rank_0(' > loading indexed mapping from {}'.format(
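For context, a minimal sketch of the "rank 0 writes, everyone waits" pattern these hunks switch to, assuming the job was launched with torchrun (so LOCAL_RANK is set) and the NCCL process group is already initialized. The names build_or_wait_for_index, build_index_file, and index_path are hypothetical stand-ins for the index-building code in these files, not identifiers from the repository; the real code gates building on the relevant rank-0 check in each file.

    import os

    import torch


    def build_or_wait_for_index(index_path, build_index_file):
        """Rank 0 builds the index file; every other rank blocks until it is done."""
        if torch.distributed.get_rank() == 0 and not os.path.isfile(index_path):
            build_index_file(index_path)

        # Per the comment removed above, a bare NCCL barrier() assumes
        # device_index == rank, which does not hold when several ranks share a
        # node under model parallelism; passing device_ids pins the barrier
        # collective to this process's own local GPU.
        torch.distributed.barrier(device_ids=[int(os.environ['LOCAL_RANK'])])

Compared with the removed all_reduce-and-assert workaround, this relies on torch.distributed.barrier's device_ids argument (available in recent PyTorch releases) to run the barrier on the correct device, so no counting trick is needed.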