From 8b04e0e43e180e6ae8ca2442ed37df1f6aa581fb Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 10 Mar 2021 15:48:40 -0800 Subject: [PATCH 01/59] ICT zeroshot evaluation code --- examples/evaluate_ict_zeroshot_nq.sh | 34 +++ megatron/arguments.py | 9 + megatron/data/biencoder_dataset_utils.py | 3 - megatron/data/realm_index.py | 109 +++++----- tasks/orqa/evaluate_orqa.py | 49 +++++ tasks/orqa/evaluate_utils.py | 188 ++++++++++++++++ tasks/orqa/natural_questions/nq.py | 228 +++++++++++++++++++ tasks/orqa/natural_questions/qa_utils.py | 174 +++++++++++++++ tasks/orqa/natural_questions/tokenizers.py | 241 +++++++++++++++++++++ 9 files changed, 980 insertions(+), 55 deletions(-) create mode 100644 examples/evaluate_ict_zeroshot_nq.sh create mode 100644 tasks/orqa/evaluate_orqa.py create mode 100644 tasks/orqa/evaluate_utils.py create mode 100644 tasks/orqa/natural_questions/nq.py create mode 100644 tasks/orqa/natural_questions/qa_utils.py create mode 100644 tasks/orqa/natural_questions/tokenizers.py diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_ict_zeroshot_nq.sh new file mode 100644 index 00000000000..f03270ebdd5 --- /dev/null +++ b/examples/evaluate_ict_zeroshot_nq.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Evaluate natural question test data given Wikipedia embeddings and pretrained +# ICT model + +# Datasets can be downloaded from the following link: +# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py + +EVIDENCE_DATA_DIR= +EMBEDDING_PATH= +CHECKPOINT_PATH= + +QA_FILE= + +python tasks/orqa/evaluate_orqa.py \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --tensor-model-parallel-size 1 \ + --micro-batch-size 128 \ + --checkpoint-activations \ + --seq-length 512 \ + --max-position-embeddings 512 \ + --load ${CHECKPOINT_PATH} \ + --evidence-data-path ${EVIDENCE_DATA_DIR} \ + --embedding-path ${EMBEDDING_PATH} \ + --retriever-seq-length 256 \ + --vocab-file bert-vocab.txt\ + --qa-data-test ${QA_FILE} \ + --num-workers 2 \ + --faiss-use-gpu \ + --retriever-report-topk-accuracies 1 5 20 100 \ + --fp16 + diff --git a/megatron/arguments.py b/megatron/arguments.py index 6e9e06e3d0b..64e803ea89f 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -636,6 +636,10 @@ def _add_data_args(parser): '1) a single data path, 2) multiple datasets in the' 'form: dataset1-weight dataset1-path dataset2-weight ' 'dataset2-path ...') + group.add_argument('--qa-data-dev', type=str, default=None, + help='Path to the QA dataset dev file.') + group.add_argument('--qa-data-test', type=str, default=None, + help='Path to the QA dataset test file.') group.add_argument('--split', type=str, default='969, 30, 1', help='Comma-separated list of proportions for training,' ' validation, and test split. For example the split ' @@ -746,6 +750,11 @@ def _add_biencoder_args(parser): group.add_argument('--embedding-path', type=str, default=None, help='Where to save/load Open-Retrieval Embedding' ' data to/from') + group.add_argument('--faiss-match', type=str, default='string', \ + choices=['regex', 'string'], help="Answer matching '\ + 'logic type") + group.add_argument('--faiss-topk-retrievals', type=int, default=100, + help='Number of blocks to use as top-k during retrieval') # indexer group.add_argument('--indexer-batch-size', type=int, default=128, diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py index 36b8532ee5c..f7b3b961b8c 100644 --- a/megatron/data/biencoder_dataset_utils.py +++ b/megatron/data/biencoder_dataset_utils.py @@ -24,11 +24,8 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None): """Specifically one epoch to be used in an indexing job.""" args = get_args() - world_size = mpu.get_data_parallel_world_size() - rank = mpu.get_data_parallel_rank() if micro_batch_size is None: micro_batch_size = args.micro_batch_size - global_batch_size = micro_batch_size * world_size num_workers = args.num_workers # Use megatron's sampler with consumed samples set to 0 as diff --git a/megatron/data/realm_index.py b/megatron/data/realm_index.py index 5fc0cb55cc8..a4b543c7e04 100644 --- a/megatron/data/realm_index.py +++ b/megatron/data/realm_index.py @@ -116,18 +116,22 @@ def merge_shards_and_save(self): class FaissMIPSIndex(object): - """Wrapper object for a BlockData which similarity search via FAISS under the hood""" - def __init__(self, embed_size, block_data=None, use_gpu=False): + """ + Wrapper object for a BlockData which similarity search via FAISS under the hood + """ + def __init__(self, embed_size, embed_data=None, use_gpu=False): self.embed_size = embed_size - self.block_data = block_data + self.embed_data = embed_data self.use_gpu = use_gpu - self.id_map = dict() - self.block_mips_index = None - self._set_block_index() + self.mips_index = None + self._set_mips_index() - def _set_block_index(self): - """Create a Faiss Flat index with inner product as the metric to search against""" + def _set_mips_index(self): + """ + Create a Faiss Flat index with inner product as the metric + to search against + """ try: import faiss except ImportError: @@ -135,85 +139,86 @@ def _set_block_index(self): if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print("\n> Building index", flush=True) - self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT) + + cpu_index = faiss.IndexFlatIP(self.embed_size) if self.use_gpu: # create resources and config for GpuIndex - res = faiss.StandardGpuResources() - config = faiss.GpuIndexFlatConfig() - config.device = torch.cuda.current_device() + config = faiss.GpuMultipleClonerOptions() + config.shard = True config.useFloat16 = True - - self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config) + gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config) + self.mips_index = faiss.IndexIDMap(gpu_index) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: - print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True) + print(">> Initialized index on GPU", flush=True) else: # CPU index supports IDs so wrap with IDMap - self.block_mips_index = faiss.IndexIDMap(self.block_mips_index) + self.mips_index = faiss.IndexIDMap(cpu_index) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">> Initialized index on CPU", flush=True) - # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built - if self.block_data is not None: - self.add_block_embed_data(self.block_data) + # if we were constructed with a BlockData, then automatically load it + # when the FAISS structure is built + if self.embed_data is not None: + self.add_embed_data(self.embed_data) def reset_index(self): - """Delete existing index and create anew""" - del self.block_mips_index + """Delete existing index and create a new""" + del self.mips_index # reset the block data so that _set_block_index will reload it as well - if self.block_data is not None: - block_data_path = self.block_data.block_data_path - del self.block_data - self.block_data = BlockData(block_data_path) + if self.embed_data is not None: + embed_data_path = self.embed_data.embedding_path + del self.embed_data + self.embed_data = OpenRetreivalDataStore(embed_data_path) + + self._set_mips_index() - self._set_block_index() + def update_index(self): + """Delete existing index and create a new""" + del self.mips_index - def add_block_embed_data(self, all_block_data): + # reset the block data so that _set_mips_index will reload it as well + if self.embed_data is not None: + self.embed_data.load_from_file() + self._set_mips_index() + + def add_embed_data(self, all_embed_data): """Add the embedding of each block to the underlying FAISS index""" # this assumes the embed_data is a dict : {int: np.array} - block_indices, block_embeds = zip(*all_block_data.embed_data.items()) - - # the embeddings have to be entered in as float32 even though the math internally is done with float16. - block_embeds_arr = np.float32(np.array(block_embeds)) - block_indices_arr = np.array(block_indices) + block_indices, block_embeds = zip(*all_embed_data.embed_data.items()) - # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with - if self.use_gpu: - for i, idx in enumerate(block_indices): - self.id_map[i] = idx + # the embeddings have to be entered in as float32 even though the math + # internally is done with float16. + embeds_arr = np.float32(np.array(block_embeds)) + indices_arr = np.array(block_indices) # we no longer need the embedding data since it's in the index now - all_block_data.clear() + all_embed_data.clear() - if self.use_gpu: - self.block_mips_index.add(block_embeds_arr) - else: - self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr) + self.mips_index.add_with_ids(embeds_arr, indices_arr) if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0: print(">>> Finished adding block data to index", flush=True) def search_mips_index(self, query_embeds, top_k, reconstruct=True): - """Get the top-k blocks by the index distance metric. + """ + Get the top-k blocks by the index distance metric. - :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks - if False: return [num_queries x k] array of distances, and another for indices + :param reconstruct: if True: return a [num_queries x k x embed_dim] + array of blocks + if False: return [num_queries x k] array of + distances, and another for indices """ query_embeds = np.float32(detach(query_embeds)) if reconstruct: # get the vectors themselves - top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k) + top_k_block_embeds = self.mips_index.search_and_reconstruct(\ + query_embeds, top_k) return top_k_block_embeds - else: # get distances and indices of closest vectors - distances, block_indices = self.block_mips_index.search(query_embeds, top_k) - if self.use_gpu: - fresh_indices = np.zeros(block_indices.shape) - for i, j in itertools.product(block_indices.shape): - fresh_indices[i, j] = self.id_map[block_indices[i, j]] - block_indices = fresh_indices + distances, block_indices = self.mips_index.search(query_embeds, top_k) return distances, block_indices diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py new file mode 100644 index 00000000000..b878e3219b3 --- /dev/null +++ b/tasks/orqa/evaluate_orqa.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main tasks functionality.""" + +import os +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.join(os.path.pardir, os.path.pardir)))) + +from megatron import get_args +from megatron.initialize import initialize_megatron + +from tasks.orqa.evaluate_utils import ORQAEvaluator + +def main(): + """ + Main program + """ + initialize_megatron(extra_args_provider=None, + args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + args = get_args() + + # Set up the model and evaluator + evaluator = ORQAEvaluator() + + # Run evaluation + if args.qa_data_dev is not None: + evaluator.evaluate(args.qa_data_dev, "DEV") + + if args.qa_data_test is not None: + evaluator.evaluate(args.qa_data_test, "TEST") + +if __name__ == "__main__": + main() + diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py new file mode 100644 index 00000000000..ebee03522e1 --- /dev/null +++ b/tasks/orqa/evaluate_utils.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from megatron import get_args, print_rank_0 +from megatron.checkpointing import load_biencoder_checkpoint +from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset +from tasks.orqa.natural_questions.nq import get_nq_dataset +from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader +from tasks.orqa.natural_questions.nq import process_nq_batch +from tasks.orqa.natural_questions.qa_utils import calculate_matches +from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex +from megatron.model.biencoder_model import biencoder_model_provider +from megatron.training import get_model + +class ORQAEvaluator(object): + def __init__(self): + args = get_args() + self.embedding_size = args.hidden_size + self.faiss_use_gpu = args.faiss_use_gpu + self.evidence_embedder_obj = None + self.evidence_dataset = None + self.mips_index = None + self.eval_dataset = None + + # Get Evidence (Wikipedia) dataset + self.get_evidence_dataset() + + # Load query encoder checkpoint + only_query_model = True + if args.biencoder_shared_query_context_model: + only_query_model = False + + model = get_model(lambda: biencoder_model_provider(only_query_model=\ + only_query_model, biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model)) + + self.model = load_biencoder_checkpoint(model, + only_query_model=only_query_model) + + assert len(self.model) == 1 + self.model[0].eval() + + # Load faiss indexer + self.faiss_wrapper() + + def get_evidence_embedding(self): + # This will load the embedding from the embedding path + self.evidence_embedder_obj = OpenRetreivalDataStore(load_from_path=True) + + def get_evidence_dataset(self): + self.evidence_dataset = get_open_retrieval_wiki_dataset() + + def faiss_wrapper(self): + # Initialize FAISS wrapper on local rank = 0 as the evidence embeddings + # is distributed over all the GPUs in a node and FAISS is not + # thread-safe + args = get_args() + if args.local_rank == 0: + # Get evidence embeddings computed using context encoder + self.get_evidence_embedding() + + assert self.evidence_embedder_obj is not None + self.mips_index = FaissMIPSIndex(embed_size=self.embedding_size, + embed_data=self.evidence_embedder_obj, + use_gpu=self.faiss_use_gpu) + + # Wait for the FAISS index to be initialized in all the nodes + torch.distributed.barrier() + + def generate_query_vectors(self, qa_data, split): + + self.eval_dataset = get_nq_dataset(qa_data, split) + dataloader = get_one_epoch_nq_dataloader(self.eval_dataset) + + query_vectors = [] + reference_list = [] + + for batch in dataloader: + # batch also has query_tokens and query_pad_data + query_tokens, query_mask, query_types, \ + query_len, reference = process_nq_batch(batch) + + assert len(self.model) == 1 + unwrapped_model = self.model[0] + while not hasattr(unwrapped_model, 'embed_text'): + unwrapped_model = unwrapped_model.module + + with torch.no_grad(): + query_logits = unwrapped_model.embed_text( + unwrapped_model.query_model, query_tokens, + query_mask, query_types) + + reference_list.extend(reference) + query_vectors.extend(query_logits.split(1, dim=0)) + if len(query_vectors) % 100 == 0: + print_rank_0('Encoded queries {}'.format(len(query_vectors))) + + query_tensor = torch.cat(query_vectors, dim=0) + print_rank_0('Total encoded queries tensor {}'.format(query_tensor.size())) + + assert query_tensor.size(0) == len(self.eval_dataset) + return query_tensor, reference_list + + def evaluate(self, qa_data, split): + args = get_args() + query_tensor, reference_list = self.generate_query_vectors(qa_data, \ + split) + local_rank = args.local_rank + rank = torch.distributed.get_rank() + device_count = torch.cuda.device_count() + num_nodes = torch.distributed.get_world_size() // device_count + node_id = rank // device_count + + for node in range(num_nodes): + start_rank = node * device_count + end_rank = (node + 1) * device_count + ranks_list = list(range(start_rank, end_rank)) + node_group = torch.distributed.new_group(ranks=ranks_list) + + if node_id == node: + device_start_rank = start_rank + group = node_group + + input_ = torch.empty_like(query_tensor).copy_(query_tensor).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(device_count)] + torch.distributed.all_gather(tensor_list, query_tensor, group=group) + + if local_rank == 0 and self.mips_index is not None: + all_query_tensor = torch.cat(tensor_list, dim=0).contiguous() + + distance, topkindex = self.mips_index.search_mips_index( + all_query_tensor, top_k=args.faiss_topk_retrievals, + reconstruct=False) + distance = torch.from_numpy(distance).cuda() + topkindex = torch.LongTensor(topkindex).cuda() + + if local_rank != 0: + distance = torch.empty(device_count * len(query_tensor), \ + args.faiss_topk_retrievals, dtype=torch.float32).cuda() + topkindex = torch.empty(device_count * len(query_tensor), \ + args.faiss_topk_retrievals, dtype=torch.int64).cuda() + + torch.distributed.broadcast(distance, src=device_start_rank, \ + group=group) + torch.distributed.broadcast(topkindex, src=device_start_rank, \ + group=group) + + distance = torch.split(distance, len(query_tensor), dim=0)\ + [local_rank] + topkindex = torch.split(topkindex, len(query_tensor), dim=0)\ + [local_rank] + + top_ids_and_scores = [] + for darray, topkarray in zip(distance, topkindex): + top_ids_and_scores.append((topkarray.tolist(), darray.tolist())) + + passages = self.evidence_dataset.id2text + match_stats = calculate_matches(passages, + reference_list, + top_ids_and_scores, + workers_num=args.num_workers, + match_type=args.faiss_match) + top_k_hits = match_stats.top_k_hits + + print_rank_0("{} SET RESULTS".format(split)) + print_rank_0("topk-{} documents hits {}".format( + args.faiss_topk_retrievals, top_k_hits)) + top_k_hits = [v / len(top_ids_and_scores) for v in top_k_hits] + print_rank_0("top-k documents hits accuracy {}".format(top_k_hits)) + + for i in args.retriever_report_topk_accuracies: + print_rank_0("top-{}: {:.2f}".format(i, top_k_hits[i-1] * 100)) + + return diff --git a/tasks/orqa/natural_questions/nq.py b/tasks/orqa/natural_questions/nq.py new file mode 100644 index 00000000000..ca07fe4165c --- /dev/null +++ b/tasks/orqa/natural_questions/nq.py @@ -0,0 +1,228 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + Data Loader for Google NQ dataset +""" + +from abc import ABC +import csv +from collections import OrderedDict +import numpy as np + +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset, BatchSampler + +from megatron import print_rank_0, get_args, get_tokenizer, mpu +from megatron.data.biencoder_dataset_utils import make_attention_mask + +def get_nq_dataset(qa_data, split): + args = get_args() + tokenizer = get_tokenizer() + + dataset = NQDataset('Google NQ {} Split'.format(split), + 'Google Natural Questions', + qa_data, + tokenizer, + args.retriever_seq_length) + return dataset + + +def process_nq_batch(batch): + query_tokens = batch['token_ids'].long().cuda() + query_mask = (batch['token_mask'] < 0.5).cuda() + query_types = batch['token_types'].long().cuda() + query_len = batch['seq_len'].long().cuda() + reference = batch['reference'] + + return query_tokens, query_mask, query_types, query_len, reference + + +class CustomDataLoader(DataLoader): + def __init__(self, dataset, eval=False, **kwargs): + if kwargs.get('collate_fn', None) is None: + kwargs['collate_fn'] = self._collate_fn + self.eval = eval + super().__init__(dataset, **kwargs) + + def _collate_fn(self, batch_data): + # generate batch + batch_size = len(batch_data) + tensorized = OrderedDict() + for d in batch_data: + for k, v in d.items(): + tensorized.setdefault(k, []).append(v) + assert len(tensorized) == 5 + + tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids']) + tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask']) + tensorized['token_types'] = torch.LongTensor(tensorized['token_types']) + tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len']) + return tensorized + + +def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None): + """Data loader. Note that batch-size is the local (per GPU) batch-size. + NOTE: This dataloader is not distributed !!! + """ + + args = get_args() + if micro_batch_size is None: + micro_batch_size = args.micro_batch_size + num_workers = args.num_workers + + sampler = torch.utils.data.SequentialSampler(dataset) + # importantly, drop_last must be False to get all the data. + batch_sampler = BatchSampler(sampler, + batch_size=micro_batch_size, + drop_last=False) + + # Data loader. Note that batch size is the per GPU batch size. + data_loader = CustomDataLoader(dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=True) + return data_loader + + +def build_tokens_types_paddings_from_text(src_text, tokenizer, max_seq_length): + """Build token types and paddings, trim if needed, and pad if needed.""" + + src_text_ids = tokenizer.tokenize(src_text) + + return build_tokens_types_paddings_from_ids(src_text_ids, + max_seq_length, + tokenizer.cls, + tokenizer.sep, + tokenizer.pad) + + +def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \ + sep_id, pad_id): + """ + Build token types and paddings, trim if needed, and pad if needed. + + TODO: Design modular interface to reuse this function. This is getting + repeated multiple times in different tasks + """ + + enc_ids = [] + tokentypes_enc = [] + + # [CLS]. + enc_ids.append(cls_id) + tokentypes_enc.append(0) + + # A. + len_src = len(src_ids) + enc_ids.extend(src_ids) + tokentypes_enc.extend([0] * len_src) + + # Cap the size. + if len(enc_ids) > max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. + enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + return enc_ids, tokentypes_enc, num_tokens_enc + + +def build_sample(token_ids, token_types, num_tokens, reference): + """ + Convert to numpy and return a sample consumed by the + batch producer. + """ + + token_ids = np.array(token_ids, dtype=np.int64) + token_types = np.array(token_types, dtype=np.int64) + token_mask = make_attention_mask(token_ids, token_ids) + + sample = ({ + 'token_ids': token_ids, + 'token_mask': token_mask, + 'token_types': token_types, + 'seq_len': num_tokens, + 'reference': reference + }) + return sample + + +class NQDataset(ABC, Dataset): + """ + Open Retrieval Question Answering evaluation using Google NQ dataset. + """ + + def __init__(self, task_name, dataset_name, datapath, + tokenizer, max_seq_length): + # Store inputs. + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + print_rank_0(datapath) + self.samples = self.process_samples_from_single_path(datapath) + print_rank_0(' >> total number of samples: {}'.format(\ + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + ques_tokens, tokentypes_enc, num_tokens_ques = \ + build_tokens_types_paddings_from_text(raw_sample['question'], + self.tokenizer, self.max_seq_length) + + sample = build_sample(ques_tokens, + tokentypes_enc, + num_tokens_ques, + raw_sample['answers']) + return sample + + @staticmethod + def process_samples_from_single_path(filename): + print_rank_0(' > Processing {} ...'.format(filename)) + samples = [] + total = 0 + + with open(filename, 'r') as ifile: + reader = csv.reader(ifile, delimiter='\t') + for row in reader: + question = row[0] + answers = eval(row[1]) + + sample = {'question': question, 'answers': answers} + total += 1 + samples.append(sample) + + if total % 1000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples diff --git a/tasks/orqa/natural_questions/qa_utils.py b/tasks/orqa/natural_questions/qa_utils.py new file mode 100644 index 00000000000..8cd1166db3f --- /dev/null +++ b/tasks/orqa/natural_questions/qa_utils.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" + Set of utilities for Q&A results validation tasks - Retriver passage + validation and Reader predicted answer validation +""" + +import collections +import logging +import string +import unicodedata +from functools import partial +from multiprocessing import Pool as ProcessPool +from typing import Tuple, List, Dict + +import regex as re +from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer + +logger = logging.getLogger(__name__) + +QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\ + 'questions_doc_hits']) + +def calculate_matches(all_docs: Dict[object, Tuple[str, str]], + answers: List[List[str]], closest_docs: List[Tuple[List[object], + List[float]]], workers_num: int, match_type: str) -> QAMatchStats: + """ + Evaluates answers presence in the set of documents. This function is + supposed to be used with a large collection of documents and results. + It internally forks multiple sub-processes for evaluation and then + merges results + :param all_docs: dictionary of the entire documents database. + doc_id -> (doc_text, title) + :param answers: list of answers's list. One list per question + :param closest_docs: document ids of the top results along with their + scores + :param workers_num: amount of parallel threads to process data + :param match_type: type of answer matching. Refer to has_answer code for + available options + :return: matching information tuple. + top_k_hits - a list where the index is the amount of top documents retrieved + and the value is the total amount of valid matches across an entire + dataset. + questions_doc_hits - more detailed info with answer matches for every + question and every retrieved document + """ + global dpr_all_documents + dpr_all_documents = all_docs + + tok_opts = {} + tokenizer = SimpleTokenizer(**tok_opts) + + processes = ProcessPool( + processes=workers_num, + ) + + logger.info('Matching answers in top docs...') + + get_score_partial = partial(check_answer, match_type=match_type, + tokenizer=tokenizer) + + questions_answers_docs = zip(answers, closest_docs) + + scores = processes.map(get_score_partial, questions_answers_docs) + + logger.info('Per question validation results len=%d', len(scores)) + + n_docs = len(closest_docs[0][0]) + top_k_hits = [0] * n_docs + for question_hits in scores: + best_hit = next((i for i, x in enumerate(question_hits) if x), None) + if best_hit is not None: + top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] + + return QAMatchStats(top_k_hits, scores) + + +def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: + """ + Search through all the top docs to see if they have any of the answers. + """ + answers, (doc_ids, doc_scores) = questions_answers_docs + + global dpr_all_documents + hits = [] + + for i, doc_id in enumerate(doc_ids): + doc = dpr_all_documents[doc_id] + text = doc[0] + + answer_found = False + if text is None: # cannot find the document for some reason + logger.warning("no doc in db") + hits.append(False) + continue + + if has_answer(answers, text, tokenizer, match_type): + answer_found = True + hits.append(answer_found) + return hits + + +def has_answer(answers, text, tokenizer, match_type) -> bool: + """ + Check if a document contains an answer string. + If `match_type` is string, token matching is done between the text + and answer. + If `match_type` is regex, we search the whole text with the regex. + """ + text = _normalize(text) + + if match_type == 'string': + # Answer is a list of possible strings + text = tokenizer.tokenize(text).words(uncased=True) + + for single_answer in answers: + single_answer = _normalize(single_answer) + single_answer = tokenizer.tokenize(single_answer) + single_answer = single_answer.words(uncased=True) + + for i in range(0, len(text) - len(single_answer) + 1): + if single_answer == text[i: i + len(single_answer)]: + return True + + elif match_type == 'regex': + # Answer is a regex + for single_answer in answers: + single_answer = _normalize(single_answer) + if regex_match(text, single_answer): + return True + return False + + +def regex_match(text, pattern): + """Test if a regex pattern is contained within a text.""" + try: + pattern = re.compile( + pattern, + flags=re.IGNORECASE + re.UNICODE + re.MULTILINE, + ) + except BaseException: + return False + return pattern.search(text) is not None + + +# function for the reader model answer validation +def exact_match_score(prediction, ground_truth): + return _normalize_answer(prediction) == _normalize_answer(ground_truth) + + +def _normalize_answer(s): + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + +def _normalize(text): + return unicodedata.normalize('NFD', text) diff --git a/tasks/orqa/natural_questions/tokenizers.py b/tasks/orqa/natural_questions/tokenizers.py new file mode 100644 index 00000000000..a5234a529c5 --- /dev/null +++ b/tasks/orqa/natural_questions/tokenizers.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +""" +Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency +""" + +import copy +import logging + +import regex +import spacy + +logger = logging.getLogger(__name__) + + +class Tokens(object): + """A class to represent a list of tokenized text.""" + TEXT = 0 + TEXT_WS = 1 + SPAN = 2 + POS = 3 + LEMMA = 4 + NER = 5 + + def __init__(self, data, annotators, opts=None): + self.data = data + self.annotators = annotators + self.opts = opts or {} + + def __len__(self): + """The number of tokens.""" + return len(self.data) + + def slice(self, i=None, j=None): + """Return a view of the list of tokens from [i, j).""" + new_tokens = copy.copy(self) + new_tokens.data = self.data[i: j] + return new_tokens + + def untokenize(self): + """Returns the original text (with whitespace reinserted).""" + return ''.join([t[self.TEXT_WS] for t in self.data]).strip() + + def words(self, uncased=False): + """Returns a list of the text of each token + + Args: + uncased: lower cases text + """ + if uncased: + return [t[self.TEXT].lower() for t in self.data] + else: + return [t[self.TEXT] for t in self.data] + + def offsets(self): + """Returns a list of [start, end) character offsets of each token.""" + return [t[self.SPAN] for t in self.data] + + def pos(self): + """Returns a list of part-of-speech tags of each token. + Returns None if this annotation was not included. + """ + if 'pos' not in self.annotators: + return None + return [t[self.POS] for t in self.data] + + def lemmas(self): + """Returns a list of the lemmatized text of each token. + Returns None if this annotation was not included. + """ + if 'lemma' not in self.annotators: + return None + return [t[self.LEMMA] for t in self.data] + + def entities(self): + """Returns a list of named-entity-recognition tags of each token. + Returns None if this annotation was not included. + """ + if 'ner' not in self.annotators: + return None + return [t[self.NER] for t in self.data] + + def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): + """Returns a list of all ngrams from length 1 to n. + + Args: + n: upper limit of ngram length + uncased: lower cases text + filter_fn: user function that takes in an ngram list and returns + True or False to keep or not keep the ngram + as_string: return the ngram as a string vs list + """ + + def _skip(gram): + if not filter_fn: + return False + return filter_fn(gram) + + words = self.words(uncased) + ngrams = [(s, e + 1) + for s in range(len(words)) + for e in range(s, min(s + n, len(words))) + if not _skip(words[s:e + 1])] + + # Concatenate into strings + if as_strings: + ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] + + return ngrams + + def entity_groups(self): + """Group consecutive entity tokens with the same NER tag.""" + entities = self.entities() + if not entities: + return None + non_ent = self.opts.get('non_ent', 'O') + groups = [] + idx = 0 + while idx < len(entities): + ner_tag = entities[idx] + # Check for entity tag + if ner_tag != non_ent: + # Chomp the sequence + start = idx + while (idx < len(entities) and entities[idx] == ner_tag): + idx += 1 + groups.append((self.slice(start, idx).untokenize(), ner_tag)) + else: + idx += 1 + return groups + + +class Tokenizer(object): + """Base tokenizer class. + Tokenizers implement tokenize, which should return a Tokens class. + """ + + def tokenize(self, text): + raise NotImplementedError + + def shutdown(self): + pass + + def __del__(self): + self.shutdown() + + +class SimpleTokenizer(Tokenizer): + ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' + NON_WS = r'[^\p{Z}\p{C}]' + + def __init__(self, **kwargs): + """ + Args: + annotators: None or empty set (only tokenizes). + """ + self._regexp = regex.compile( + '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), + flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE + ) + if len(kwargs.get('annotators', {})) > 0: + logger.warning('%s only tokenizes! Skipping annotators: %s' % + (type(self).__name__, kwargs.get('annotators'))) + self.annotators = set() + + def tokenize(self, text): + data = [] + matches = [m for m in self._regexp.finditer(text)] + for i in range(len(matches)): + # Get text + token = matches[i].group() + + # Get whitespace + span = matches[i].span() + start_ws = span[0] + if i + 1 < len(matches): + end_ws = matches[i + 1].span()[0] + else: + end_ws = span[1] + + # Format data + data.append(( + token, + text[start_ws: end_ws], + span, + )) + return Tokens(data, self.annotators) + + +class SpacyTokenizer(Tokenizer): + + def __init__(self, **kwargs): + """ + Args: + annotators: set that can include pos, lemma, and ner. + model: spaCy model to use (either path, or keyword like 'en'). + """ + model = kwargs.get('model', 'en') + self.annotators = copy.deepcopy(kwargs.get('annotators', set())) + nlp_kwargs = {'parser': False} + if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + nlp_kwargs['tagger'] = False + if 'ner' not in self.annotators: + nlp_kwargs['entity'] = False + self.nlp = spacy.load(model, **nlp_kwargs) + + def tokenize(self, text): + # We don't treat new lines as tokens. + clean_text = text.replace('\n', ' ') + tokens = self.nlp.tokenizer(clean_text) + if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]): + self.nlp.tagger(tokens) + if 'ner' in self.annotators: + self.nlp.entity(tokens) + + data = [] + for i in range(len(tokens)): + # Get whitespace + start_ws = tokens[i].idx + if i + 1 < len(tokens): + end_ws = tokens[i + 1].idx + else: + end_ws = tokens[i].idx + len(tokens[i].text) + + data.append(( + tokens[i].text, + text[start_ws: end_ws], + (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), + tokens[i].tag_, + tokens[i].lemma_, + tokens[i].ent_type_, + )) + + # Set special option for non-entity tag: '' vs 'O' in spaCy + return Tokens(data, self.annotators, opts={'non_ent': ''}) From 661553f648f30ab6b227f98cc049aba53803f83e Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Thu, 11 Mar 2021 13:12:32 -0800 Subject: [PATCH 02/59] made more generic, aligned with other tasks --- examples/evaluate_ict_zeroshot_nq.sh | 4 +++- tasks/main.py | 2 ++ tasks/orqa/evaluate_orqa.py | 6 +----- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_ict_zeroshot_nq.sh index f03270ebdd5..e1ce45a9342 100644 --- a/examples/evaluate_ict_zeroshot_nq.sh +++ b/examples/evaluate_ict_zeroshot_nq.sh @@ -12,7 +12,9 @@ CHECKPOINT_PATH= QA_FILE= -python tasks/orqa/evaluate_orqa.py \ +python tasks/main.py \ + --task ICT-ZEROSHOT-NQ \ + --tokenizer-type BertWordPieceLowerCase \ --num-layers 12 \ --hidden-size 768 \ --num-attention-heads 12 \ diff --git a/tasks/main.py b/tasks/main.py index 27c45081398..0c11a40251b 100644 --- a/tasks/main.py +++ b/tasks/main.py @@ -62,6 +62,8 @@ def get_tasks_args(parser): from glue.finetune import main elif args.task in ['LAMBADA', 'WIKITEXT103']: from zeroshot_gpt.evaluate import main + elif args.task in ['ICT-ZEROSHOT-NQ']: + from orqa.evaluate_orqa import main else: raise NotImplementedError('Task {} is not implemented.'.format( args.task)) diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index b878e3219b3..e8590a2343b 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -18,9 +18,6 @@ import os import sys -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), - os.path.join(os.path.pardir, os.path.pardir)))) - from megatron import get_args from megatron.initialize import initialize_megatron @@ -30,8 +27,7 @@ def main(): """ Main program """ - initialize_megatron(extra_args_provider=None, - args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + args = get_args() # Set up the model and evaluator From 43c9137b94edcbaa2a9d1e3c671e938bac4cc937 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Thu, 18 Mar 2021 17:36:14 -0700 Subject: [PATCH 03/59] Fixed based on review recoemmendation --- megatron/arguments.py | 11 ----------- tasks/main.py | 14 ++++++++++++++ tasks/orqa/evaluate_orqa.py | 2 -- tasks/orqa/natural_questions/qa_utils.py | 7 +++++-- tasks/orqa/natural_questions/tokenizers.py | 6 ++++-- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index 64e803ea89f..61b7e5c4a81 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -636,10 +636,6 @@ def _add_data_args(parser): '1) a single data path, 2) multiple datasets in the' 'form: dataset1-weight dataset1-path dataset2-weight ' 'dataset2-path ...') - group.add_argument('--qa-data-dev', type=str, default=None, - help='Path to the QA dataset dev file.') - group.add_argument('--qa-data-test', type=str, default=None, - help='Path to the QA dataset test file.') group.add_argument('--split', type=str, default='969, 30, 1', help='Comma-separated list of proportions for training,' ' validation, and test split. For example the split ' @@ -743,18 +739,11 @@ def _add_biencoder_args(parser): 'square root of hidden size') # faiss index - group.add_argument('--faiss-use-gpu', action='store_true', - help='Whether create the FaissMIPSIndex on GPU') group.add_argument('--block-data-path', type=str, default=None, help='Where to save/load BlockData to/from') group.add_argument('--embedding-path', type=str, default=None, help='Where to save/load Open-Retrieval Embedding' ' data to/from') - group.add_argument('--faiss-match', type=str, default='string', \ - choices=['regex', 'string'], help="Answer matching '\ - 'logic type") - group.add_argument('--faiss-topk-retrievals', type=int, default=100, - help='Number of blocks to use as top-k during retrieval') # indexer group.add_argument('--indexer-batch-size', type=int, default=128, diff --git a/tasks/main.py b/tasks/main.py index 0c11a40251b..fa25e583cfd 100644 --- a/tasks/main.py +++ b/tasks/main.py @@ -47,6 +47,20 @@ def get_tasks_args(parser): help='Sliding window for overlapping evaluation.') group.add_argument('--strict-lambada', action='store_true', help='Use more difficult formulation of lambada.') + # Retriever args + group.add_argument('--qa-data-dev', type=str, default=None, + help='Path to the QA dataset dev file.') + group.add_argument('--qa-data-test', type=str, default=None, + help='Path to the QA dataset test file.') + + # Faiss arguments for retriever + group.add_argument('--faiss-use-gpu', action='store_true', + help='Whether create the FaissMIPSIndex on GPU') + group.add_argument('--faiss-match', type=str, default='string', \ + choices=['regex', 'string'], help="Answer matching '\ + 'logic type") + group.add_argument('--faiss-topk-retrievals', type=int, default=100, + help='Number of blocks to use as top-k during retrieval') return parser diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index e8590a2343b..54c37197f8c 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -19,8 +19,6 @@ import sys from megatron import get_args -from megatron.initialize import initialize_megatron - from tasks.orqa.evaluate_utils import ORQAEvaluator def main(): diff --git a/tasks/orqa/natural_questions/qa_utils.py b/tasks/orqa/natural_questions/qa_utils.py index 8cd1166db3f..24e71e683a4 100644 --- a/tasks/orqa/natural_questions/qa_utils.py +++ b/tasks/orqa/natural_questions/qa_utils.py @@ -2,8 +2,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. + +# The following code has been taken from +# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0 +# licensed as of now. More details on the license can be found +# at https://github.com/facebookresearch/DPR/blob/master/LICENSE """ Set of utilities for Q&A results validation tasks - Retriver passage diff --git a/tasks/orqa/natural_questions/tokenizers.py b/tasks/orqa/natural_questions/tokenizers.py index a5234a529c5..fb23887ebdd 100644 --- a/tasks/orqa/natural_questions/tokenizers.py +++ b/tasks/orqa/natural_questions/tokenizers.py @@ -2,9 +2,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. +# The following code has been taken from +# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0 +# licensed as of now. More details on the license can be found +# at https://github.com/facebookresearch/DPR/blob/master/LICENSE """ Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency From 40565390c4d220288d86d9601027a1f90cbdd9cb Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Fri, 19 Mar 2021 15:37:27 -0700 Subject: [PATCH 04/59] fixed another issue --- tasks/orqa/evaluate_orqa.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index 54c37197f8c..7e6b269231a 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -38,6 +38,3 @@ def main(): if args.qa_data_test is not None: evaluator.evaluate(args.qa_data_test, "TEST") -if __name__ == "__main__": - main() - From 10ff0607a3ab72eec24d8af611c31fa835c0876a Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Fri, 9 Apr 2021 13:59:21 -0700 Subject: [PATCH 05/59] implementing DPR --- megatron/data/biencoder_dataset_utils.py | 7 ++++++ tasks/finetune_utils.py | 2 ++ tasks/main.py | 31 ++++++++++++++++++++++++ 3 files changed, 40 insertions(+) diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py index f7b3b961b8c..58983b54349 100644 --- a/megatron/data/biencoder_dataset_utils.py +++ b/megatron/data/biencoder_dataset_utils.py @@ -20,6 +20,13 @@ def make_attention_mask(source_block, target_block): # (source_length, target_length) return mask +def make_history_mask(block): + length = block.shape[0] + arange = np.arange(length) + history_mask = (arange[None, ] <= arange[:, None]) + history_mask = history_mask.astype(np.int64) + return history_mask + def get_one_epoch_dataloader(dataset, micro_batch_size=None): """Specifically one epoch to be used in an indexing job.""" args = get_args() diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py index 5223cec0ba5..00a8997b090 100644 --- a/tasks/finetune_utils.py +++ b/tasks/finetune_utils.py @@ -248,6 +248,8 @@ def finetune(train_valid_datasets_provider, model_provider, end_of_epoch_callback = end_of_epoch_callback_provider() timers('callback function').stop() + exit() + # Build model, optimizer and learning rate scheduler. timers('model and optimizer').start() model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) diff --git a/tasks/main.py b/tasks/main.py index fa25e583cfd..6feb19d0a37 100644 --- a/tasks/main.py +++ b/tasks/main.py @@ -62,6 +62,35 @@ def get_tasks_args(parser): group.add_argument('--faiss-topk-retrievals', type=int, default=100, help='Number of blocks to use as top-k during retrieval') + # finetune for retriever + group.add_argument('--eval-micro-batch-size', type=int, default=None, + help='Eval Batch size per model instance (local batch ' + 'size). Global batch size is local batch size ' + 'times data parallel size.') + group.add_argument('--train-with-neg', action='store_true', + help='Whether to use negative examples during model ' + 'training') + group.add_argument('--train-hard-neg', type=int, default=0, + help='Number of hard negative exmaples to use during ' + 'training') + + + # parameters for Av.rank validation method + # Following options/arguments have been taken directly from DPR codebase + #group.add_argument("--val-av-rank-start-epoch", type=int, default=10000, + # help="Av.rank validation: the epoch from which to enable this validation") + group.add_argument('--val-av-rank-hard-neg', type=int, default=30, + help='Av.rank validation: how many hard negatives to' + ' take from each question pool') + group.add_argument('--val-av-rank-other-neg', type=int, default=30, + help='Av.rank validation: how many other negatives to' + ' take from each question pool') + #group.add_argument("--val-av-rank-bsz", type=int, default=128, + # help="Av.rank validation: batch size to process passages") + #group.add_argument("--val-av-rank-max-qs", type=int, default=10000, + # help="Av.rank validation: max num of questions") + + return parser @@ -78,6 +107,8 @@ def get_tasks_args(parser): from zeroshot_gpt.evaluate import main elif args.task in ['ICT-ZEROSHOT-NQ']: from orqa.evaluate_orqa import main + elif args.task in ['RET-FINETUNE-NQ']: + from orqa.supervised.finetune import main else: raise NotImplementedError('Task {} is not implemented.'.format( args.task)) From 06076c7ad28ed32cf91faad940e68bce191d3040 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Fri, 23 Apr 2021 14:29:48 -0700 Subject: [PATCH 06/59] implementation dpr --- megatron/model/biencoder_model.py | 37 +++++++++++++++++++++---- megatron/tokenizer/bert_tokenization.py | 29 +++++++++++++++++++ megatron/tokenizer/tokenizer.py | 4 +++ tasks/finetune_utils.py | 30 +++++++++++++------- 4 files changed, 84 insertions(+), 16 deletions(-) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 51ac0a060d4..188877070b5 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -17,7 +17,9 @@ def biencoder_model_provider(only_query_model=False, only_context_model=False, - biencoder_shared_query_context_model=False): + biencoder_shared_query_context_model=False, + pre_process=True, + post_process=True): """Build the model.""" args = get_args() @@ -35,7 +37,9 @@ def biencoder_model_provider(only_query_model=False, only_query_model=only_query_model, only_context_model=only_context_model, biencoder_shared_query_context_model=\ - biencoder_shared_query_context_model) + biencoder_shared_query_context_model, + pre_process=pre_process, + post_process=post_process) return model @@ -48,13 +52,17 @@ def __init__(self, parallel_output=True, only_query_model=False, only_context_model=False, - biencoder_shared_query_context_model=False): + biencoder_shared_query_context_model=False, + pre_process=True, + post_process=True): super(BiEncoderModel, self).__init__() args = get_args() bert_kwargs = dict( num_tokentypes=num_tokentypes, - parallel_output=parallel_output) + parallel_output=parallel_output, + pre_process=pre_process, + post_process=post_process) self.biencoder_shared_query_context_model = \ biencoder_shared_query_context_model @@ -78,6 +86,19 @@ def __init__(self, self.context_model = PretrainedBertModel(**bert_kwargs) self._context_key = 'context_model' + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + #self.language_model.set_input_tensor(input_tensor) + return + # #if self._model_key is not None: + # # print("_model_key {}".format(self._model_key), flush=True) + # print(input_tensor) + # if self._query_key is not None: + # print("_query_key {}".format(self._query_key), flush=True) + # if self._context_key is not None: + # print("_context_key {}".format(self._context_key), flush=True) + # exit() + def forward(self, query_tokens, query_attention_mask, query_types, context_tokens, context_attention_mask, context_types): """Run a forward pass for each of the models and @@ -217,7 +238,7 @@ class PretrainedBertModel(MegatronModule): learned information retrieval.""" def __init__(self, num_tokentypes=2, - parallel_output=True): + parallel_output=True, pre_process=True, post_process=True): super(PretrainedBertModel, self).__init__() args = get_args() @@ -225,6 +246,8 @@ def __init__(self, num_tokentypes=2, self.pad_id = tokenizer.pad self.biencoder_projection_dim = args.biencoder_projection_dim self.parallel_output = parallel_output + self.pre_process = pre_process + self.post_process = post_process init_method = init_method_normal(args.init_method_std) scaled_init_method = scaled_init_method_normal( args.init_method_std, args.num_layers) @@ -234,7 +257,9 @@ def __init__(self, num_tokentypes=2, add_pooler=False, encoder_attn_mask_type=AttnMaskType.padding, init_method=init_method, - scaled_init_method=scaled_init_method) + scaled_init_method=scaled_init_method, + pre_process=self.pre_process, + post_process=self.post_process) if args.biencoder_projection_dim > 0: self.projection_enc = get_linear_layer(args.hidden_size, diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py index a3aa6d907e3..99f9a87958d 100644 --- a/megatron/tokenizer/bert_tokenization.py +++ b/megatron/tokenizer/bert_tokenization.py @@ -181,6 +181,35 @@ def convert_tokens_to_ids(self, tokens): def convert_ids_to_tokens(self, ids): return convert_by_vocab(self.inv_vocab, ids) + @staticmethod + def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True): + """ Converts a sequence of tokens (string) in a single string. """ + + def clean_up_tokenization(out_string): + """ Clean up a list of simple English tokenization artifacts + like spaces before punctuations and abreviated forms. + """ + out_string = ( + out_string.replace(" .", ".") + .replace(" ?", "?") + .replace(" !", "!") + .replace(" ,", ",") + .replace(" ' ", "'") + .replace(" n't", "n't") + .replace(" 'm", "'m") + .replace(" 's", "'s") + .replace(" 've", "'ve") + .replace(" 're", "'re") + ) + return out_string + + text = ' '.join(tokens).replace(' ##', '').strip() + if clean_up_tokenization_spaces: + clean_text = clean_up_tokenization(text) + return clean_text + else: + return text + def vocab_size(self): return len(self.vocab) diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index ec835dbdf5e..0e0ff26de0b 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -155,6 +155,10 @@ def tokenize(self, text): text_tokens = self.tokenizer.tokenize(text) return self.tokenizer.convert_tokens_to_ids(text_tokens) + def decode(self, ids): + tokens = self.tokenizer.convert_ids_to_tokens(ids) + return self.tokenizer.convert_tokens_to_string(tokens) + def decode_token_ids(self, token_ids): tokens = self.tokenizer.convert_ids_to_tokens(token_ids) exclude_list = ['[PAD]', '[CLS]'] diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py index 50f9548d4f6..d8247a5c365 100644 --- a/tasks/finetune_utils.py +++ b/tasks/finetune_utils.py @@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model): return output_tensor, partial(cross_entropy_loss_func, labels) -def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): +def build_data_loader(dataset, micro_batch_size, num_workers, drop_last, + task_collate_fn=None): """Data loader. Note that batch-size is the local (per GPU) batch-size.""" # Sampler. @@ -89,6 +90,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank) + print_rank_0(len(sampler)) + # Data loader. Note that batch size is the per GPU batch size. data_loader = torch.utils.data.DataLoader(dataset, batch_size=micro_batch_size, @@ -96,7 +99,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): shuffle=False, num_workers=num_workers, drop_last=drop_last, - pin_memory=True) + pin_memory=True, + collate_fn=task_collate_fn) return data_loader @@ -112,21 +116,23 @@ def _build_infinite_size_dataloader(dataloader): iterator = dataloader.__iter__() -def _build_train_valid_dataloaders(train_dataset, valid_dataset): +def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None): """Traing and validation dataloaders.""" args = get_args() print_rank_0('building train and validation dataloaders ...') # Training dataset. train_dataloader = build_data_loader(train_dataset, args.micro_batch_size, - args.num_workers, not args.keep_last) + args.num_workers, not args.keep_last, + task_collate_fn) # Set the training iterations. args.train_iters_per_epoch = len(train_dataloader) args.train_iters = args.epochs * args.train_iters_per_epoch # Validation dataset. For this dataset, we do not need to set up # shuffling so we can just use a simple infinite loop. valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size, - args.num_workers, not args.keep_last) + args.num_workers, not args.keep_last, + task_collate_fn) valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) # Now that we've built the data loaders, set batch_size arguments @@ -185,9 +191,10 @@ def _train(model, optimizer, lr_scheduler, forward_step, continue # Set to zero so the next epoch does not skip any batches. start_iteration = 0 - + # Train for one step. out = train_step(forward_step, batch, model, optimizer, lr_scheduler) + losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out iteration += 1 @@ -220,6 +227,10 @@ def _train(model, optimizer, lr_scheduler, forward_step, valid_dataloader, model, iteration, False) + #if iteration == 1000: + # exit() + #break + # Checkpointing at the end of each epoch. if args.save: save_checkpoint(iteration, model, optimizer, lr_scheduler) @@ -231,7 +242,8 @@ def _train(model, optimizer, lr_scheduler, forward_step, def finetune(train_valid_datasets_provider, model_provider, forward_step=_cross_entropy_forward_step, - end_of_epoch_callback_provider=None): + end_of_epoch_callback_provider=None, + task_collate_fn=None): """Main finetune function used across all tasks.""" args = get_args() timers = get_timers() @@ -244,7 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider, if args.epochs > 0: train_dataset, valid_dataset = train_valid_datasets_provider() train_dataloader, valid_dataloader = _build_train_valid_dataloaders( - train_dataset, valid_dataset) + train_dataset, valid_dataset, task_collate_fn) else: args.train_iters = 0 timers('train/valid/test dataset/dataloder').stop() @@ -256,8 +268,6 @@ def finetune(train_valid_datasets_provider, model_provider, end_of_epoch_callback = end_of_epoch_callback_provider() timers('callback function').stop() - exit() - # Build model, optimizer and learning rate scheduler. timers('model and optimizer').start() model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) From b9fcb7b48357bfdd63d1ce617443642d3b697c28 Mon Sep 17 00:00:00 2001 From: mpatwary Date: Wed, 28 Apr 2021 18:18:18 -0700 Subject: [PATCH 07/59] adding dpr code --- tasks/orqa/supervised/data.py | 301 ++++++++++++++++++++++++++++ tasks/orqa/supervised/eval_utils.py | 211 +++++++++++++++++++ tasks/orqa/supervised/finetune.py | 239 ++++++++++++++++++++++ 3 files changed, 751 insertions(+) create mode 100644 tasks/orqa/supervised/data.py create mode 100644 tasks/orqa/supervised/eval_utils.py create mode 100644 tasks/orqa/supervised/finetune.py diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py new file mode 100644 index 00000000000..922de56edf2 --- /dev/null +++ b/tasks/orqa/supervised/data.py @@ -0,0 +1,301 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ORQA dataset.""" + +import json +import random +from abc import ABC +from abc import abstractmethod + +import numpy as np +from torch.utils.data import Dataset + +from megatron import print_rank_0, get_args +from megatron.data.biencoder_dataset_utils import make_attention_mask +from megatron.data.biencoder_dataset_utils import make_history_mask + + +def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length): + ctx_id_list, ctx_types_list = [], [] + for context in ctx_list: + title_ids = tokenizer.tokenize(context['title']) + ctx_ids = tokenizer.tokenize(context['text']) + ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids + + ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids, + max_seq_length, tokenizer.cls, + tokenizer.sep, tokenizer.pad) + ctx_id_list.append(ctx_ids) + ctx_types_list.append(ctx_types) + + return ctx_id_list, ctx_types_list + + +def build_tokens_types_paddings_from_text(query, context, + tokenizer, max_seq_length): + """Build token types and paddings, trim if needed, and pad if needed.""" + + query_ids = tokenizer.tokenize(query) + query_ids, query_types, query_pad_mask = \ + build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \ + tokenizer.cls, tokenizer.sep, tokenizer.pad) + + # Appending the title of the context at front + extended_ctx_ids = None + if context is not None: + title_ids = tokenizer.tokenize(context['title']) + ctx_ids = tokenizer.tokenize(context['text']) + extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids + + ctx_ids, ctx_types, ctx_pad_mask = \ + build_tokens_types_paddings_from_ids(extended_ctx_ids, + max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad) + + return query_ids, query_types, query_pad_mask, \ + ctx_ids, ctx_types, ctx_pad_mask + + +# Similar code tasks/data_utils with some changes +def build_tokens_types_paddings_from_ids(text_ids, max_seq_length, + cls_id, sep_id, pad_id): + """Build token types and paddings, trim if needed, and pad if needed.""" + enc_ids = [] + tokentypes_enc = [] + + # [CLS]. + enc_ids.append(cls_id) + tokentypes_enc.append(0) + + # A. + len_src = len(text_ids) + enc_ids.extend(text_ids) + tokentypes_enc.extend([0] * len_src) + + # Cap the size. + if len(enc_ids) > max_seq_length - 1: + enc_ids = enc_ids[0: max_seq_length - 1] + tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] + + # [SEP]. + enc_ids.append(sep_id) + tokentypes_enc.append(0) + + num_tokens_enc = len(enc_ids) + # Padding. + padding_length = max_seq_length - len(enc_ids) + if padding_length > 0: + enc_ids.extend([pad_id] * padding_length) + tokentypes_enc.extend([pad_id] * padding_length) + + pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) + pad_mask = np.array(pad_mask, dtype=np.int64) + + return enc_ids, tokentypes_enc, pad_mask + + +def build_sample(query_ids, query_types, query_pad_mask, + ctx_ids, ctx_types, ctx_pad_mask, answers, + neg_ctx_id_list=None, neg_ctx_types_list=None, + include_neg=False): + """Convert to numpy and return a sample consumed by the batch producer.""" + + query_ids = np.array(query_ids, dtype=np.int64) + query_types = np.array(query_types, dtype=np.int64) + query_mask = make_attention_mask(query_ids, query_ids) + + ctx_ids = np.array(ctx_ids, dtype=np.int64) + ctx_types = np.array(ctx_types, dtype=np.int64) + ctx_mask = make_attention_mask(ctx_ids, ctx_ids) + + sample = ({ + 'query': query_ids, + 'query_mask': query_mask, + 'query_types': query_types, + 'query_pad_mask': query_pad_mask, + 'context': ctx_ids, + 'context_mask': ctx_mask, + 'context_types': ctx_types, + 'context_pad_mask': ctx_pad_mask, + 'reference': answers + }) + + if include_neg: + neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64) + neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64) + neg_ctx_mask = np.array([make_attention_mask(ids, ids) \ + for ids in neg_ctx_ids], dtype=np.int64) + + sample['neg_context'] = neg_ctx_ids + sample['neg_context_types'] = neg_ctx_id_types + sample['neg_context_mask'] = neg_ctx_mask + + return sample + + +class OpenRetrievalAbstractDataset(ABC, Dataset): + """Open Retrieval base dataset class.""" + + def __init__(self, task_name, dataset_name, datapaths, tokenizer, \ + max_seq_length, evaluate=False): + # Store inputs. + args = get_args() + self.evaluate = evaluate + self.val_av_rank_hard_neg = args.val_av_rank_hard_neg + self.val_av_rank_other_neg = args.val_av_rank_other_neg + self.train_with_neg = args.train_with_neg + self.train_hard_neg = args.train_hard_neg + + self.task_name = task_name + self.dataset_name = dataset_name + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + print_rank_0(' > building {} dataset for {}:'.format(self.task_name, + self.dataset_name)) + # Process the files. + string = ' > paths:' + for path in datapaths: + string += ' ' + path + print_rank_0(string) + self.samples = [] + for datapath in datapaths: + self.samples.extend(self.process_samples_from_single_path(datapath)) + + args = get_args() + if args.sample_rate < 1: # subsample + k = int(len(self.samples) * args.sample_rate) + self.samples = random.sample(self.samples, k) + + print_rank_0(' >> total number of samples: {}'.format( + len(self.samples))) + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + raw_sample = self.samples[idx] + + query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \ + ctx_pad_mask = build_tokens_types_paddings_from_text( \ + raw_sample['question'], raw_sample['pos_context'], \ + self.tokenizer, self.max_seq_length) + + if self.evaluate: + neg_ctx_list = \ + raw_sample['negative_context'][:self.val_av_rank_other_neg] + \ + raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg] + neg_ctx_id_list, neg_ctx_types_list = \ + build_token_types_from_context_list(neg_ctx_list, \ + self.tokenizer, self.max_seq_length) + + elif self.train_with_neg: + hard_negative_ctx = raw_sample['hard_negative_context'] + negative_ctx = raw_sample['negative_context'] + if True: # TODO: fix this or remove this condition + random.shuffle(hard_negative_ctx) + random.shuffle(negative_ctx) + + neg_ctx_list = hard_negative_ctx[:self.train_hard_neg] + # In the Google NQ dataset by DPR paper, there are around more than + # 50 missing hard negatives in training data. + # In those cases, substitute hard negatives by simple negatives. + if len(neg_ctx_list) < self.train_hard_neg: + neg_ctx_list += negative_ctx[:self.train_hard_neg - \ + len(neg_ctx_list)] + + neg_ctx_id_list, neg_ctx_types_list = \ + build_token_types_from_context_list(neg_ctx_list, + self.tokenizer, self.max_seq_length) + else: + neg_ctx_id_list = None + neg_ctx_types_list = None + + sample = build_sample(query_ids, query_types, query_pad_mask, + ctx_ids, ctx_types, ctx_pad_mask, + raw_sample['answers'], + neg_ctx_id_list, neg_ctx_types_list, + include_neg=self.evaluate or self.train_with_neg) + + return sample + + @staticmethod + @abstractmethod + def process_samples_from_single_path(filename): + """Abstract method that takes a filename and + returns a list of dataset samples, each sample being a dict of + {'text': string, 'text': string} + """ + pass + + + +def normalize_question(question): + if question[-1] == '?': + question = question[:-1] + return question + +class NQSupervisedDataset(OpenRetrievalAbstractDataset): + + def __init__(self, name, datapaths, tokenizer, max_seq_length, \ + evaluate=False): + super().__init__('natural_questions_ret', + name, + datapaths, + tokenizer, + max_seq_length, + evaluate=evaluate) + + @staticmethod + def process_samples_from_single_path(filename): + """"Implement abstract method.""" + print_rank_0(' > Processing {} ...'.format(filename)) + samples = [] + total = 0 + + with open(filename, 'r', encoding="utf-8") as f: + data = json.load(f) + for row in data: + question = normalize_question(row['question']) + pos_context = row['positive_ctxs'][0] + + # Hard Negative Contexts + if len(row['hard_negative_ctxs']) > 0: + hard_neg_context = row['hard_negative_ctxs'] + else: + hard_neg_context = [] + + # Negative Contexts + if len(row['negative_ctxs']) > 0: + neg_context = row['negative_ctxs'] + else: + neg_context = [] + + answers = row['answers'] + sample = {'question': question, + 'pos_context': pos_context, + 'hard_negative_context': hard_neg_context, + 'negative_context': neg_context, + 'answers': answers} + total += 1 + samples.append(sample) + + if total % 5000 == 0: + print_rank_0(' > processed {} so far ...'.format(total)) + + print_rank_0(' >> processed {} samples.'.format(len(samples))) + return samples + + + diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py new file mode 100644 index 00000000000..729367266d3 --- /dev/null +++ b/tasks/orqa/supervised/eval_utils.py @@ -0,0 +1,211 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation utilities.""" +from collections import OrderedDict +import math +import numpy as np +import time +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader + +from megatron import get_args, print_rank_0 +from megatron import mpu +from megatron.utils import average_losses_across_data_parallel_group +from tasks.finetune_utils import build_data_loader + +def task_collate_fn(batch_data): + # generate batch + batch_size = len(batch_data) + tensorized = OrderedDict() + for d in batch_data: + for k, v in d.items(): + tensorized.setdefault(k, []).append(v) + # assert len(tensorized) == 12 + + tensorized['query'] = torch.LongTensor(tensorized['query']) + tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask']) + tensorized['query_types'] = torch.LongTensor(tensorized['query_types']) + tensorized['query_pad_mask'] = \ + torch.LongTensor(tensorized['query_pad_mask']) + + tensorized['context'] = torch.LongTensor(tensorized['context']) + tensorized['context_mask'] = \ + torch.LongTensor(tensorized['context_mask']) + tensorized['context_types'] = \ + torch.LongTensor(tensorized['context_types']) + tensorized['context_pad_mask'] = \ + torch.LongTensor(tensorized['context_pad_mask']) + + if 'neg_context' in tensorized: + tensorized['neg_context'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context'])) + tensorized['neg_context_mask'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context_mask'])) + tensorized['neg_context_types'] = \ + torch.LongTensor(np.concatenate(tensorized['neg_context_types'])) + + return tensorized + + + +def process_batch(batch): + """Process batch and produce inputs for the model.""" + query_tokens = batch['query'].long().cuda() + query_mask = (batch['query_mask'] < 0.5).cuda() + query_types = batch['query_types'].long().cuda() + query_pad_mask = batch['query_pad_mask'].long().cuda() + + context_tokens = batch['context'].long().cuda() + context_mask = (batch['context_mask'] < 0.5).cuda() + context_types = batch['context_types'].long().cuda() + context_pad_mask = batch['context_pad_mask'].long().cuda() + + if 'neg_context' in batch: + neg_context_tokens = batch['neg_context'].long().cuda() + neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda() + neg_context_types = batch['neg_context_types'].long().cuda() + else: + neg_context_tokens = None + neg_context_mask = None + neg_context_types = None + + reference = batch['reference'] + + return query_tokens, query_mask, query_types, query_pad_mask, \ + context_tokens, context_mask, context_types, context_pad_mask, \ + neg_context_tokens, neg_context_mask, neg_context_types, reference + +def accuracy_func_provider(single_dataset_provider, rank0sampler=False): +#, datapath, +# rank0sampler=False): + """Provide function that calculates accuracies.""" + args = get_args() + + print_rank_0("accuracy_func_provider is CALLED") + + # Build dataloaders + datapath = args.valid_data + dataset = single_dataset_provider(datapath) + + drop_last = False + if mpu.get_data_parallel_world_size() > 1 and not rank0sampler: + drop_last = True + + print_rank_0(datapath) + print_rank_0(rank0sampler) + + dataloader = build_data_loader(dataset, + args.eval_micro_batch_size, + num_workers=args.num_workers, + drop_last=drop_last, + task_collate_fn=task_collate_fn) + #shuffle=False, + #rank0sampler=rank0sampler) + dataloaders = (dataset.dataset_name, dataloader) + + def metrics_func(model, epoch, output_predictions=False): + print_rank_0('calculating metrics by accuracy func in ORQA...') + + if output_predictions: + assert rank0sampler + names = 'predictions' + name, dataloader = dataloaders + if args.task == "RET-FINETUNE-NQ": + start_time = time.time() + output = retrieval_loss(model, dataloader) + stats_dict, total = output + format_string = "" + for k, v in stats_dict.items(): + format_string += "|{} = {:.2f}".format(k, v / total) + print_rank_0("epoch:{}{}".format(epoch, format_string)) + print_rank_0("taken time to calcuate metrics {:.3f}".format(\ + time.time() - start_time)) + else: + raise AssertionError("{} Task not supported".format(args.task)) + + return metrics_func + + +def retrieval_loss(model, dataloader): + args = get_args() + total = 0 + topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \ + args.retriever_report_topk_accuracies} + stats_dict = dict(rank=0, **topk_stats_dict) + + assert len(model) == 1 + unwrapped_model = model[0] + unwrapped_model.eval() + + with torch.no_grad(): + # For all the batches in the dataset. + for batch in dataloader: + # Run the model forward. + query_tokens, query_mask, query_types, _, \ + context_tokens, context_mask, context_types, _, \ + neg_context_tokens, neg_context_mask, neg_context_types, \ + reference = process_batch(batch) + + query_logits, context_logits = unwrapped_model(query_tokens, + query_mask, query_types, + torch.cat([context_tokens, neg_context_tokens]), + torch.cat([context_mask, neg_context_mask]), + torch.cat([context_types, neg_context_types])) + + retrieval_scores = torch.matmul(query_logits, + torch.transpose(context_logits, 0, 1)) + + if args.retriever_score_scaling: + retrieval_scores = retrieval_scores / \ + math.sqrt(args.hidden_size) + + local_batch_size = query_logits.shape[0] + labels = torch.arange(local_batch_size).long().cuda() + + softmax_scores = F.softmax(retrieval_scores, dim=1) + sorted_vals, sorted_indices = torch.topk(softmax_scores, + k=softmax_scores.shape[1], + sorted=True) + + def topk_accuracy(k): + return torch.cuda.FloatTensor( + [sum([int(labels[i] in sorted_indices[i, :k]) for i in \ + range(local_batch_size)])]) + + def get_rank(): + return torch.cuda.FloatTensor( + [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \ + for i in range(local_batch_size)])]) + + topk_accs = [topk_accuracy(k) for k in \ + args.retriever_report_topk_accuracies] + rank = get_rank() + losses = average_losses_across_data_parallel_group([rank, \ + *topk_accs]) + + # create stats_dict with retrieval loss and all specified + # top-k accuracies + topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \ + zip(args.retriever_report_topk_accuracies, losses[1:])} + temp_stats_dict = dict(rank=losses[0], **topk_acc_dict) + for k in stats_dict.keys(): + stats_dict[k] += temp_stats_dict[k] + total += local_batch_size + + unwrapped_model.train() + + return stats_dict, total diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py new file mode 100644 index 00000000000..8dfd4474695 --- /dev/null +++ b/tasks/orqa/supervised/finetune.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ORQA finetuning/evaluation.""" + +from functools import partial + +import math +import torch +import torch.nn.functional as F + +from megatron import get_args +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu +from megatron import print_rank_0 +from megatron.utils import average_losses_across_data_parallel_group +from megatron.model.biencoder_model import biencoder_model_provider +#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider +#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune +from pretrain_ict import get_group_world_size_rank +from tasks.finetune_utils import finetune +from tasks.orqa.supervised.eval_utils import accuracy_func_provider +from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn + +def orqa(Dataset): # , name_from_datapath_func): + + def cross_entropy_forward_step(batch, model): + """Simple forward step with cross-entropy loss.""" + args = get_args() + timers = get_timers() + tokenizer = get_tokenizer() + + # Get the batch. + timers('batch generator').start() + try: + batch_ = next(batch) + except BaseException: + batch_ = batch + + query_tokens, query_mask, query_types, query_pad_mask, \ + context_tokens, context_mask, context_types, context_pad_mask, \ + neg_context_tokens, neg_context_mask, neg_context_types, \ + reference = process_batch(batch_) + + timers('batch generator').stop() + local_batch_size = query_tokens.shape[0] + + # Text representation of query and context + query_list, context_list = [], [] + for i in range(local_batch_size): + query_list.append(tokenizer.decode(query_tokens[i].tolist())) + context_list.append(tokenizer.decode(context_tokens[i].tolist())) + + if neg_context_tokens is not None: + context_tokens = torch.cat([context_tokens, neg_context_tokens]) + context_mask = torch.cat([context_mask, neg_context_mask]) + context_types = torch.cat([context_types, neg_context_types]) + + # Forward model. + #query_logits, context_logits = model(query_tokens, query_mask, + output_tensor = model(query_tokens, query_mask, + query_types, context_tokens, + context_mask, context_types) + + return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens) + + + #def cross_entropy_loss_func(labels, output_tensor): + def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor): + args = get_args() + + local_batch_size = query_tokens.shape[0] + group, rank, world_size = get_group_world_size_rank() + # recall we assert that model_parallel_size == 1 + global_batch_size = world_size * local_batch_size + + query_logits, context_logits = output_tensor + + if world_size > 1: + input_ = torch.empty_like(context_logits).copy_(\ + context_logits).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank].copy_(input_) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Check if all-gather happens in order + assert tensor_list[rank].sum().item() == \ + context_logits.sum().item() + + # Preserves the gradient + tensor_list[rank] = context_logits + all_context_logits = torch.cat(tensor_list, dim=0).contiguous() + + # Query tensors + input_ = torch.empty_like(query_logits).copy_(\ + query_logits).detach_() + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank].copy_(input_) + torch.distributed.all_gather(tensor_list, input_, group=group) + + # Check if all-gather happens in order + assert tensor_list[rank].sum().item() == query_logits.sum().item() + + # Preserves the gradient + tensor_list[rank] = query_logits + all_query_logits = torch.cat(tensor_list, dim=0).contiguous() + else: + all_query_logits = query_logits + all_context_logits = context_logits + + retrieval_scores = torch.matmul(all_query_logits, + torch.transpose(all_context_logits, 0, 1)) + # Scaling the retrieval scores + if args.retriever_score_scaling: + retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size) + + if args.train_with_neg: + # if the world size is 3, local batch size is 4, and + # local context size is 8, what we want is + # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] + labels = [] + local_context_size = context_tokens.shape[0] + for i in range(world_size): + j = i * local_context_size + labels.extend(list(range(j, j + local_batch_size))) + labels = torch.LongTensor(labels).cuda() + assert len(labels) == global_batch_size + else: + labels = torch.arange(global_batch_size).long().cuda() + + # Cross-entropy loss. + softmax_scores = F.log_softmax(retrieval_scores, dim=1) + + loss = F.nll_loss(softmax_scores, labels, reduction='mean') + + max_score, max_idxs = torch.max(softmax_scores, 1) + correct_predictions_count = (max_idxs == labels).sum().float() + + # Reduce loss for logging. + reduced_loss = average_losses_across_data_parallel_group([loss, \ + correct_predictions_count]) + + # Loss scaling for correct losses in Supervised Retrieval + loss = loss * mpu.get_data_parallel_world_size() + + return loss, {'lm loss': reduced_loss[0], + 'correct_prediction_count': reduced_loss[1]} + + + def train_valid_datasets_provider(): + """Build train and validation dataset.""" + args = get_args() + tokenizer = get_tokenizer() + + train_dataset = Dataset('training', + args.train_data, + tokenizer, + args.retriever_seq_length, + evaluate=False) + valid_dataset = Dataset('validation', + args.valid_data, + tokenizer, + args.retriever_seq_length, + evaluate=True) + return train_dataset, valid_dataset + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + args = get_args() + print_rank_0('building retriever model for {} ...'.format(args.task)) + model = biencoder_model_provider(only_context_model=False, + only_query_model=False, + biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model, + pre_process=pre_process, post_process=post_process) + return model + + def single_dataset_provider(datapath): + args = get_args() + tokenizer = get_tokenizer() + + #name = name_from_datapath_func(datapath) + name = datapath[0].split('/')[-1].split('.')[0] + return Dataset(name, + datapath, + tokenizer, + args.retriever_seq_length, + evaluate=True) + + #def distributed_metrics_func_provider(): + def metrics_func_provider(): + """Provide metrics callback function.""" + + #def name_from_datapath(datapath): + # return datapath[0].split('/')[-1].split('.')[0] + + return accuracy_func_provider(single_dataset_provider) + + #def rank0_metrics_func_provider(datapath): + # """Provide metrics callback function.""" + # return accuracy_func_provider(single_dataset_provider, datapath, + # rank0sampler=True) + + """Finetune/evaluate.""" + finetune(train_valid_datasets_provider, + model_provider, + forward_step=cross_entropy_forward_step, + end_of_epoch_callback_provider=metrics_func_provider, + task_collate_fn=task_collate_fn) + #,end_of_training_callback_provider=rank0_metrics_func_provider) + + +def main(): + args = get_args() + + if args.task == 'RET-FINETUNE-NQ': + from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset + + #def name_from_datapath(datapath): + # return datapath[0].split('/')[-1].split('.')[0] + + else: + raise NotImplementedError('ORQA task {} is not implemented.'.format( + args.task)) + + orqa(Dataset) #, name_from_datapath) From 80047314780a9d9e26f027527d9d85a9f93fd5cf Mon Sep 17 00:00:00 2001 From: mpatwary Date: Wed, 28 Apr 2021 18:22:57 -0700 Subject: [PATCH 08/59] removed commnets --- megatron/model/biencoder_model.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 188877070b5..5410a56eeaa 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -88,16 +88,10 @@ def __init__(self, def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" + #this is just a placeholder and will be needed when model + #parallelism will be used #self.language_model.set_input_tensor(input_tensor) return - # #if self._model_key is not None: - # # print("_model_key {}".format(self._model_key), flush=True) - # print(input_tensor) - # if self._query_key is not None: - # print("_query_key {}".format(self._query_key), flush=True) - # if self._context_key is not None: - # print("_context_key {}".format(self._context_key), flush=True) - # exit() def forward(self, query_tokens, query_attention_mask, query_types, context_tokens, context_attention_mask, context_types): From f415dc850838ed30e4a4d2fdcb85920035251488 Mon Sep 17 00:00:00 2001 From: mpatwary Date: Wed, 28 Apr 2021 18:24:36 -0700 Subject: [PATCH 09/59] removed commnets --- tasks/finetune_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py index d8247a5c365..b4fb78c67ca 100644 --- a/tasks/finetune_utils.py +++ b/tasks/finetune_utils.py @@ -90,8 +90,6 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last, sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank) - print_rank_0(len(sampler)) - # Data loader. Note that batch size is the per GPU batch size. data_loader = torch.utils.data.DataLoader(dataset, batch_size=micro_batch_size, From a8d172b31e921d0ad0889660443846098496d696 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 28 Apr 2021 18:26:01 -0700 Subject: [PATCH 10/59] removed commnets --- tasks/finetune_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py index b4fb78c67ca..215df897a24 100644 --- a/tasks/finetune_utils.py +++ b/tasks/finetune_utils.py @@ -225,10 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step, valid_dataloader, model, iteration, False) - #if iteration == 1000: - # exit() - #break - # Checkpointing at the end of each epoch. if args.save: save_checkpoint(iteration, model, optimizer, lr_scheduler) From 220637f945488d409500c4132e891009cf1ef964 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Tue, 11 May 2021 12:40:14 -0700 Subject: [PATCH 11/59] DPR evaluation debugging --- megatron/arguments.py | 6 +++++ megatron/checkpointing.py | 7 ++++-- megatron/indexer.py | 39 +++++++++++++++++++++++++------ megatron/learning_rates.py | 18 +++++++++++++- megatron/model/biencoder_model.py | 4 ++++ megatron/model/language_model.py | 7 +++++- tasks/finetune_utils.py | 4 ++++ tasks/orqa/supervised/finetune.py | 29 ++++++++++++++++++----- 8 files changed, 97 insertions(+), 17 deletions(-) diff --git a/megatron/arguments.py b/megatron/arguments.py index f7aa44ccea2..603ce14da73 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -478,6 +478,12 @@ def _add_learning_rate_args(parser): group.add_argument('--min-lr', type=float, default=0.0, help='Minumum value for learning rate. The scheduler' 'clip values below this threshold.') + group.add_argument('--override-lr-new', action='store_true', + help='Reset the values of the scheduler (learning rate,' + 'warmup iterations, minimum learning rate, maximum ' + 'number of iterations, and decay style from input ' + 'arguments and ignore values from checkpoints. Note' + 'that all the above values will be reset.') group.add_argument('--override-lr-scheduler', action='store_true', help='Reset the values of the scheduler (learning rate,' 'warmup iterations, minimum learning rate, maximum ' diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 43dfa16b044..0cd033b1b0a 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -413,8 +413,11 @@ def load_biencoder_checkpoint(model, only_query_model=False, if only_context_model: ret_state_dict.pop('query_model') - assert len(model) == 1 - model[0].load_state_dict(ret_state_dict) + #print_rank_0(len(model)) + #sys.exit() + #assert len(model) == 1 + #model[0].load_state_dict(ret_state_dict) + model.load_state_dict(ret_state_dict) torch.distributed.barrier() if mpu.get_data_parallel_rank() == 0: diff --git a/megatron/indexer.py b/megatron/indexer.py index c0d1ca7de15..dba4ecba082 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -2,7 +2,7 @@ import torch import torch.distributed as dist -from megatron import get_args +from megatron import get_args, print_rank_0 from megatron import mpu from megatron.checkpointing import load_biencoder_checkpoint from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset @@ -25,6 +25,8 @@ def __init__(self): self.evidence_embedder_obj = None self.biencoder_shared_query_context_model = \ args.biencoder_shared_query_context_model + self.pre_process = True + self.post_process = True # need to know whether we're using a REALM checkpoint (args.load) # or ICT checkpoint @@ -47,15 +49,22 @@ def load_attributes(self): if self.biencoder_shared_query_context_model: only_context_model = False - model = get_model(lambda: biencoder_model_provider(only_context_model \ + #model = get_model(lambda: biencoder_model_provider(only_context_model \ + # = only_context_model, biencoder_shared_query_context_model = \ + # self.biencoder_shared_query_context_model, \ + # pre_process=self.pre_process, post_process=self.post_process)) + + model = biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ - self.biencoder_shared_query_context_model)) + self.biencoder_shared_query_context_model, \ + pre_process=self.pre_process, post_process=self.post_process) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) - assert len(self.model) == 1 - self.model[0].eval() + #assert len(self.model) == 1 + #self.model[0].eval() + self.model.eval() self.dataset = get_open_retrieval_wiki_dataset() self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ @@ -83,10 +92,12 @@ def build_and_save_index(self): distributed setting will be consolidated by the rank 0 process and saved as a final pickled BlockData. """ - assert len(self.model) == 1 - unwrapped_model = self.model[0] + #assert len(self.model) == 1 + #unwrapped_model = self.model[0] + unwrapped_model = self.model while not hasattr(unwrapped_model, 'embed_text'): unwrapped_model = unwrapped_model.module + print_rank_0("hasattr") while True: try: @@ -97,12 +108,26 @@ def build_and_save_index(self): except (StopIteration, IndexError): break + print_rank_0(context_tokens) + print_rank_0(context_mask) + print_rank_0(context_types) + #if torch.cuda.is_available(): + # print_rank_0("cuda available") + #print_rank_0(torch.cuda.current_device()) + #print_rank_0(torch.cuda.get_device_name()) + print_rank_0(next(unwrapped_model.parameters()).device) + print_rank_0(next(unwrapped_model.context_model.parameters()).device) + #print_rank_0("After get_open_retrieval_batch") + # TODO: can we add with torch.no_grad() to reduce memory usage # detach, separate fields and add to BlockData assert context_mask.dtype == torch.bool context_logits = unwrapped_model.embed_text( unwrapped_model.context_model, context_tokens, context_mask, context_types) + + sys.exit() + context_logits = detach(context_logits) row_id = detach(row_id) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index d200bdb176a..18ce635614f 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -18,6 +18,7 @@ import math from megatron import print_rank_0 +from megatron import get_args class AnnealingLR(object): """Anneals the learning rate.""" @@ -59,6 +60,7 @@ def get_lr(self): """Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" + #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr)) # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: return self.max_lr * float(self.num_steps) / \ @@ -87,7 +89,21 @@ def get_lr(self): else: raise Exception('{} decay style is not supported.'.format( self.decay_style)) - + + args = get_args() + + if args.override_lr_new: + mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps) + mod_num_steps_ = mod_num_steps_ - self.warmup_steps + use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps) + should_use_lr = self.min_lr + coeff * delta_lr + print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps)) + else: + use_lr = self.min_lr + coeff * delta_lr + print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps)) + + return use_lr + return self.min_lr + coeff * delta_lr diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 5410a56eeaa..0e85d262337 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -266,6 +266,10 @@ def forward(self, input_ids, attention_mask, tokentype_ids=None): #extended_attention_mask = bert_extended_attention_mask(attention_mask) position_ids = bert_position_ids(input_ids) + print_rank_0(input_ids.device) + print_rank_0(position_ids.device) + print_rank_0(extended_attention_mask.device) + print_rank_0(tokentype_ids.device) lm_output = self.language_model(input_ids, position_ids, diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 06330d81395..0f81b384297 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -18,7 +18,7 @@ import torch import torch.nn.functional as F -from megatron import get_args +from megatron import get_args, print_rank_0 from megatron import mpu from .module import MegatronModule from megatron.model.enums import LayerType, AttnMaskType @@ -338,6 +338,11 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, get_key_value=False, pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden=False): + print_rank_0("before self.embedding") + print_rank_0(enc_input_ids.device) + print_rank_0(enc_position_ids.device) + print_rank_0(tokentype_ids.device) + # Embeddings. if self.pre_process: embedding_output = self.embedding(enc_input_ids, enc_position_ids, diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py index 215df897a24..be260a523f6 100644 --- a/tasks/finetune_utils.py +++ b/tasks/finetune_utils.py @@ -16,6 +16,7 @@ """Finetune utilities.""" from functools import partial +import sys import torch @@ -225,6 +226,9 @@ def _train(model, optimizer, lr_scheduler, forward_step, valid_dataloader, model, iteration, False) + #if iteration == 600: + # sys.exit() + # Checkpointing at the end of each epoch. if args.save: save_checkpoint(iteration, model, optimizer, lr_scheduler) diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py index 8dfd4474695..4e6d230756e 100644 --- a/tasks/orqa/supervised/finetune.py +++ b/tasks/orqa/supervised/finetune.py @@ -34,6 +34,8 @@ from tasks.finetune_utils import finetune from tasks.orqa.supervised.eval_utils import accuracy_func_provider from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn +from tasks.orqa.evaluate_utils import ORQAEvaluator +from megatron.indexer import IndexBuilder def orqa(Dataset): # , name_from_datapath_func): @@ -226,14 +228,29 @@ def metrics_func_provider(): def main(): args = get_args() - if args.task == 'RET-FINETUNE-NQ': - from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset + #if args.task == 'RET-FINETUNE-NQ': + # from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset #def name_from_datapath(datapath): # return datapath[0].split('/')[-1].split('.')[0] - else: - raise NotImplementedError('ORQA task {} is not implemented.'.format( - args.task)) + #else: + # raise NotImplementedError('ORQA task {} is not implemented.'.format( + # args.task)) + + #orqa(Dataset) #, name_from_datapath) + + index_builder = IndexBuilder() + index_builder.build_and_save_index() + print_rank_0("Build and save indices: done!") + + # Set up the model and evaluator + #evaluator = ORQAEvaluator() + + # Run evaluation + #if args.qa_data_dev is not None: + # evaluator.evaluate(args.qa_data_dev, "DEV") + #if args.qa_data_test is not None: + # evaluator.evaluate(args.qa_data_test, "TEST") + - orqa(Dataset) #, name_from_datapath) From d2d5086ee709810b62b0969ab0ac8a82f2d0f5a7 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Tue, 11 May 2021 14:47:46 -0700 Subject: [PATCH 12/59] DPR ongoing --- megatron/indexer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/megatron/indexer.py b/megatron/indexer.py index dba4ecba082..6c7ec04cdcd 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -49,15 +49,15 @@ def load_attributes(self): if self.biencoder_shared_query_context_model: only_context_model = False - #model = get_model(lambda: biencoder_model_provider(only_context_model \ - # = only_context_model, biencoder_shared_query_context_model = \ - # self.biencoder_shared_query_context_model, \ - # pre_process=self.pre_process, post_process=self.post_process)) - - model = biencoder_model_provider(only_context_model \ + model = get_model(lambda: biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ self.biencoder_shared_query_context_model, \ - pre_process=self.pre_process, post_process=self.post_process) + pre_process=self.pre_process, post_process=self.post_process)) + + #model = biencoder_model_provider(only_context_model \ + # = only_context_model, biencoder_shared_query_context_model = \ + # self.biencoder_shared_query_context_model, \ + # pre_process=self.pre_process, post_process=self.post_process) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) From 6d03d7af29d81505e41c2cfe7b593d6142fa7864 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Tue, 11 May 2021 22:51:17 -0700 Subject: [PATCH 13/59] DPR finetune and evaluation --- megatron/checkpointing.py | 8 +++--- megatron/indexer.py | 39 +++++++++-------------------- megatron/model/biencoder_model.py | 20 ++++++++------- megatron/model/language_model.py | 5 ---- pretrain_ict.py | 14 +++++++---- tasks/orqa/evaluate_orqa.py | 20 ++++++++++++++- tasks/orqa/evaluate_utils.py | 11 ++++++--- tasks/orqa/supervised/finetune.py | 41 ++++++++++++------------------- 8 files changed, 78 insertions(+), 80 deletions(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 0cd033b1b0a..f8f16d38261 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -413,11 +413,9 @@ def load_biencoder_checkpoint(model, only_query_model=False, if only_context_model: ret_state_dict.pop('query_model') - #print_rank_0(len(model)) - #sys.exit() - #assert len(model) == 1 - #model[0].load_state_dict(ret_state_dict) - model.load_state_dict(ret_state_dict) + assert len(model) == 1 + model[0].load_state_dict(ret_state_dict) + torch.distributed.barrier() if mpu.get_data_parallel_rank() == 0: diff --git a/megatron/indexer.py b/megatron/indexer.py index 6c7ec04cdcd..33ce50a65d0 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -45,26 +45,25 @@ def load_attributes(self): """ Load the necessary attributes: model, dataloader and empty BlockData """ + args = get_args() only_context_model = True if self.biencoder_shared_query_context_model: only_context_model = False - model = get_model(lambda: biencoder_model_provider(only_context_model \ - = only_context_model, biencoder_shared_query_context_model = \ - self.biencoder_shared_query_context_model, \ - pre_process=self.pre_process, post_process=self.post_process)) + args.only_context_model = only_context_model + args.only_query_model = False + + model = get_model(biencoder_model_provider) - #model = biencoder_model_provider(only_context_model \ + #model = get_model(lambda: biencoder_model_provider(only_context_model \ # = only_context_model, biencoder_shared_query_context_model = \ - # self.biencoder_shared_query_context_model, \ - # pre_process=self.pre_process, post_process=self.post_process) + # self.biencoder_shared_query_context_model)) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) - #assert len(self.model) == 1 - #self.model[0].eval() - self.model.eval() + assert len(self.model) == 1 + self.model[0].eval() self.dataset = get_open_retrieval_wiki_dataset() self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \ @@ -92,12 +91,11 @@ def build_and_save_index(self): distributed setting will be consolidated by the rank 0 process and saved as a final pickled BlockData. """ - #assert len(self.model) == 1 - #unwrapped_model = self.model[0] - unwrapped_model = self.model + assert len(self.model) == 1 + unwrapped_model = self.model[0] + while not hasattr(unwrapped_model, 'embed_text'): unwrapped_model = unwrapped_model.module - print_rank_0("hasattr") while True: try: @@ -108,17 +106,6 @@ def build_and_save_index(self): except (StopIteration, IndexError): break - print_rank_0(context_tokens) - print_rank_0(context_mask) - print_rank_0(context_types) - #if torch.cuda.is_available(): - # print_rank_0("cuda available") - #print_rank_0(torch.cuda.current_device()) - #print_rank_0(torch.cuda.get_device_name()) - print_rank_0(next(unwrapped_model.parameters()).device) - print_rank_0(next(unwrapped_model.context_model.parameters()).device) - #print_rank_0("After get_open_retrieval_batch") - # TODO: can we add with torch.no_grad() to reduce memory usage # detach, separate fields and add to BlockData assert context_mask.dtype == torch.bool @@ -126,8 +113,6 @@ def build_and_save_index(self): unwrapped_model.context_model, context_tokens, context_mask, context_types) - sys.exit() - context_logits = detach(context_logits) row_id = detach(row_id) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 0e85d262337..404eb07a9a1 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -15,14 +15,21 @@ from megatron.model.utils import scaled_init_method_normal from .module import MegatronModule -def biencoder_model_provider(only_query_model=False, - only_context_model=False, - biencoder_shared_query_context_model=False, - pre_process=True, +#def biencoder_model_provider(only_query_model=False, +# only_context_model=False, +# biencoder_shared_query_context_model=False, +# pre_process=True, +# post_process=True): + +def biencoder_model_provider(pre_process=True, post_process=True): """Build the model.""" args = get_args() + biencoder_shared_query_context_model = args.biencoder_shared_query_context_model + only_context_model = args.only_context_model + only_query_model = args.only_query_model + assert mpu.get_tensor_model_parallel_world_size() == 1 and \ mpu.get_pipeline_model_parallel_world_size() == 1, \ "Model parallel size > 1 not supported for ICT" @@ -266,11 +273,6 @@ def forward(self, input_ids, attention_mask, tokentype_ids=None): #extended_attention_mask = bert_extended_attention_mask(attention_mask) position_ids = bert_position_ids(input_ids) - print_rank_0(input_ids.device) - print_rank_0(position_ids.device) - print_rank_0(extended_attention_mask.device) - print_rank_0(tokentype_ids.device) - lm_output = self.language_model(input_ids, position_ids, extended_attention_mask, diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 0f81b384297..abf1082b4a6 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -338,11 +338,6 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, get_key_value=False, pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden=False): - print_rank_0("before self.embedding") - print_rank_0(enc_input_ids.device) - print_rank_0(enc_position_ids.device) - print_rank_0(tokentype_ids.device) - # Embeddings. if self.pre_process: embedding_output = self.embedding(enc_input_ids, enc_position_ids, diff --git a/pretrain_ict.py b/pretrain_ict.py index 1438b3d5782..9d861de30c5 100644 --- a/pretrain_ict.py +++ b/pretrain_ict.py @@ -33,11 +33,15 @@ def pretrain_ict_model_provider(): args = get_args() - model = biencoder_model_provider( - only_context_model=False, - only_query_model=False, - biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model) + args.only_context_model = False + args.only_query_model = False + model = biencoder_model_provider() + + #model = biencoder_model_provider( + # only_context_model=False, + # only_query_model=False, + # biencoder_shared_query_context_model=\ + # args.biencoder_shared_query_context_model) return model def get_group_world_size_rank(): diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index 7e6b269231a..c1fe46e717a 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -19,6 +19,7 @@ import sys from megatron import get_args +from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator def main(): @@ -28,6 +29,23 @@ def main(): args = get_args() + """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset + - Include all args needed for initial model specification + + Other key args: + --block-data-path: path to write to + --ict-load or --realm-load: path to checkpoint with which to embed + --data-path and --titles-data-path: paths for dataset + --indexer-log-interval: reporting interval + --indexer-batch-size: size specific for indexer jobs + + Check README.md for example script + """ + + index_builder = IndexBuilder() + index_builder.build_and_save_index() + print_rank_0("Build and save indices: done!") + # Set up the model and evaluator evaluator = ORQAEvaluator() @@ -37,4 +55,4 @@ def main(): if args.qa_data_test is not None: evaluator.evaluate(args.qa_data_test, "TEST") - + diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py index ebee03522e1..add4e64df5f 100644 --- a/tasks/orqa/evaluate_utils.py +++ b/tasks/orqa/evaluate_utils.py @@ -44,9 +44,14 @@ def __init__(self): if args.biencoder_shared_query_context_model: only_query_model = False - model = get_model(lambda: biencoder_model_provider(only_query_model=\ - only_query_model, biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model)) + args.only_query_model = only_query_model + args.only_context_model = False + + #model = get_model(lambda: biencoder_model_provider(only_query_model=\ + # only_query_model, biencoder_shared_query_context_model=\ + # args.biencoder_shared_query_context_model)) + + model = get_model(biencoder_model_provider) self.model = load_biencoder_checkpoint(model, only_query_model=only_query_model) diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py index 4e6d230756e..1c108be61aa 100644 --- a/tasks/orqa/supervised/finetune.py +++ b/tasks/orqa/supervised/finetune.py @@ -16,6 +16,7 @@ """ORQA finetuning/evaluation.""" from functools import partial +import sys import math import torch @@ -183,11 +184,15 @@ def model_provider(pre_process=True, post_process=True): """Build the model.""" args = get_args() print_rank_0('building retriever model for {} ...'.format(args.task)) - model = biencoder_model_provider(only_context_model=False, - only_query_model=False, - biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model, - pre_process=pre_process, post_process=post_process) + args.only_context_model=False + args.only_query_model=False + model = biencoder_model_provider() + + #model = biencoder_model_provider(only_context_model=False, + # only_query_model=False, + # biencoder_shared_query_context_model=\ + # args.biencoder_shared_query_context_model, + # pre_process=pre_process, post_process=post_process) return model def single_dataset_provider(datapath): @@ -228,29 +233,15 @@ def metrics_func_provider(): def main(): args = get_args() - #if args.task == 'RET-FINETUNE-NQ': - # from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset + if args.task == 'RET-FINETUNE-NQ': + from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset #def name_from_datapath(datapath): # return datapath[0].split('/')[-1].split('.')[0] - #else: - # raise NotImplementedError('ORQA task {} is not implemented.'.format( - # args.task)) - - #orqa(Dataset) #, name_from_datapath) - - index_builder = IndexBuilder() - index_builder.build_and_save_index() - print_rank_0("Build and save indices: done!") - - # Set up the model and evaluator - #evaluator = ORQAEvaluator() - - # Run evaluation - #if args.qa_data_dev is not None: - # evaluator.evaluate(args.qa_data_dev, "DEV") - #if args.qa_data_test is not None: - # evaluator.evaluate(args.qa_data_test, "TEST") + else: + raise NotImplementedError('ORQA task {} is not implemented.'.format( + args.task)) + orqa(Dataset) #, name_from_datapath) From f926720502490ef8b3efdf32362097c94d2671cc Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 12 May 2021 15:01:51 -0700 Subject: [PATCH 14/59] fixing model evaluation of retriver --- megatron/indexer.py | 7 ++++--- megatron/model/biencoder_model.py | 20 ++++++++++---------- pretrain_ict.py | 16 ++++++++-------- tasks/main.py | 2 +- tasks/orqa/evaluate_orqa.py | 13 +++++++++++++ tasks/orqa/evaluate_utils.py | 12 ++++++------ tasks/orqa/supervised/finetune.py | 17 +++++++++-------- 7 files changed, 51 insertions(+), 36 deletions(-) diff --git a/megatron/indexer.py b/megatron/indexer.py index 33ce50a65d0..3a226778d70 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -53,11 +53,12 @@ def load_attributes(self): args.only_context_model = only_context_model args.only_query_model = False - model = get_model(biencoder_model_provider) + #model = get_model(biencoder_model_provider) #model = get_model(lambda: biencoder_model_provider(only_context_model \ - # = only_context_model, biencoder_shared_query_context_model = \ - # self.biencoder_shared_query_context_model)) + model = get_model(biencoder_model_provider(only_context_model \ + = only_context_model, biencoder_shared_query_context_model = \ + self.biencoder_shared_query_context_model)) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 404eb07a9a1..7aefcc3e301 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -15,20 +15,20 @@ from megatron.model.utils import scaled_init_method_normal from .module import MegatronModule -#def biencoder_model_provider(only_query_model=False, -# only_context_model=False, -# biencoder_shared_query_context_model=False, -# pre_process=True, +#def biencoder_model_provider(pre_process=True, # post_process=True): - -def biencoder_model_provider(pre_process=True, + +def biencoder_model_provider(only_query_model=False, + only_context_model=False, + biencoder_shared_query_context_model=False, + pre_process=True, post_process=True): """Build the model.""" - args = get_args() + #args = get_args() - biencoder_shared_query_context_model = args.biencoder_shared_query_context_model - only_context_model = args.only_context_model - only_query_model = args.only_query_model + #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model + #only_context_model = args.only_context_model + #only_query_model = args.only_query_model assert mpu.get_tensor_model_parallel_world_size() == 1 and \ mpu.get_pipeline_model_parallel_world_size() == 1, \ diff --git a/pretrain_ict.py b/pretrain_ict.py index 9d861de30c5..8a5876da49c 100644 --- a/pretrain_ict.py +++ b/pretrain_ict.py @@ -33,15 +33,15 @@ def pretrain_ict_model_provider(): args = get_args() - args.only_context_model = False - args.only_query_model = False - model = biencoder_model_provider() + #args.only_context_model = False + #args.only_query_model = False + #model = biencoder_model_provider() - #model = biencoder_model_provider( - # only_context_model=False, - # only_query_model=False, - # biencoder_shared_query_context_model=\ - # args.biencoder_shared_query_context_model) + model = biencoder_model_provider( + only_context_model=False, + only_query_model=False, + biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model) return model def get_group_world_size_rank(): diff --git a/tasks/main.py b/tasks/main.py index 3056b72a70b..29fd44f39af 100644 --- a/tasks/main.py +++ b/tasks/main.py @@ -110,7 +110,7 @@ def get_tasks_args(parser): from glue.finetune import main elif args.task in ['LAMBADA', 'WIKITEXT103']: from zeroshot_gpt.evaluate import main - elif args.task in ['ICT-ZEROSHOT-NQ']: + elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']: from orqa.evaluate_orqa import main elif args.task in ['RET-FINETUNE-NQ']: from orqa.supervised.finetune import main diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index c1fe46e717a..49d19db8c58 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -18,6 +18,15 @@ import os import sys +#sys.path.append( +# os.path.abspath( +# os.path.join( +# os.path.join(os.path.dirname(__file__), os.path.pardir), +# os.path.pardir, +# ) +# ) +#) + from megatron import get_args from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator @@ -26,6 +35,8 @@ def main(): """ Main program """ + #initialize_megatron(extra_args_provider=None, + # args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) args = get_args() @@ -42,6 +53,8 @@ def main(): Check README.md for example script """ + #print_rank_0("Starting index builder!") + index_builder = IndexBuilder() index_builder.build_and_save_index() print_rank_0("Build and save indices: done!") diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py index add4e64df5f..3d64edc5039 100644 --- a/tasks/orqa/evaluate_utils.py +++ b/tasks/orqa/evaluate_utils.py @@ -44,14 +44,14 @@ def __init__(self): if args.biencoder_shared_query_context_model: only_query_model = False - args.only_query_model = only_query_model - args.only_context_model = False + #args.only_query_model = only_query_model + #args.only_context_model = False - #model = get_model(lambda: biencoder_model_provider(only_query_model=\ - # only_query_model, biencoder_shared_query_context_model=\ - # args.biencoder_shared_query_context_model)) + model = get_model(lambda: biencoder_model_provider(only_query_model=\ + only_query_model, biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model)) - model = get_model(biencoder_model_provider) + #model = get_model(biencoder_model_provider) self.model = load_biencoder_checkpoint(model, only_query_model=only_query_model) diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py index 1c108be61aa..6b323f434d5 100644 --- a/tasks/orqa/supervised/finetune.py +++ b/tasks/orqa/supervised/finetune.py @@ -184,15 +184,16 @@ def model_provider(pre_process=True, post_process=True): """Build the model.""" args = get_args() print_rank_0('building retriever model for {} ...'.format(args.task)) - args.only_context_model=False - args.only_query_model=False - model = biencoder_model_provider() + #args.only_context_model=False + #args.only_query_model=False + #model = biencoder_model_provider() - #model = biencoder_model_provider(only_context_model=False, - # only_query_model=False, - # biencoder_shared_query_context_model=\ - # args.biencoder_shared_query_context_model, - # pre_process=pre_process, post_process=post_process) + model = biencoder_model_provider(only_context_model=False, + only_query_model=False, + biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model, + pre_process=pre_process, post_process=post_process) + return model def single_dataset_provider(datapath): From 54093417f61e354f241ebe1f2ab3e0b66234f524 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 12 May 2021 15:23:35 -0700 Subject: [PATCH 15/59] added pre ad post process --- megatron/indexer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/indexer.py b/megatron/indexer.py index 3a226778d70..3c6a7853ee8 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -58,7 +58,8 @@ def load_attributes(self): #model = get_model(lambda: biencoder_model_provider(only_context_model \ model = get_model(biencoder_model_provider(only_context_model \ = only_context_model, biencoder_shared_query_context_model = \ - self.biencoder_shared_query_context_model)) + self.biencoder_shared_query_context_model, + pre_process=True, post_process=True)) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) From 7e335e15b5fd629aec6460b2e07441b8910942e8 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 12 May 2021 15:34:56 -0700 Subject: [PATCH 16/59] added pre ad post process --- tasks/orqa/evaluate_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py index 3d64edc5039..aa981ac9d23 100644 --- a/tasks/orqa/evaluate_utils.py +++ b/tasks/orqa/evaluate_utils.py @@ -47,9 +47,11 @@ def __init__(self): #args.only_query_model = only_query_model #args.only_context_model = False + #model = get_model(lambda: biencoder_model_provider(only_query_model=\ model = get_model(lambda: biencoder_model_provider(only_query_model=\ only_query_model, biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model)) + args.biencoder_shared_query_context_model, + pre_process=True, post_process=True)) #model = get_model(biencoder_model_provider) From f64977fdfe325aefbdf09edb132fc1a77010ebe6 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Thu, 13 May 2021 00:15:47 -0700 Subject: [PATCH 17/59] evaluation works! --- megatron/indexer.py | 17 ++++++++++------- megatron/model/biencoder_model.py | 19 +++++++++++++++++++ tasks/orqa/evaluate_orqa.py | 2 +- tasks/orqa/evaluate_utils.py | 14 +++++++++----- 4 files changed, 39 insertions(+), 13 deletions(-) diff --git a/megatron/indexer.py b/megatron/indexer.py index 3c6a7853ee8..367ce9d13ac 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -9,7 +9,7 @@ from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader from megatron.data.realm_index import detach, OpenRetreivalDataStore -from megatron.model.biencoder_model import biencoder_model_provider +from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model @@ -50,16 +50,19 @@ def load_attributes(self): if self.biencoder_shared_query_context_model: only_context_model = False - args.only_context_model = only_context_model - args.only_query_model = False + #args.only_context_model = only_context_model + #args.only_query_model = False #model = get_model(biencoder_model_provider) + model = get_model(get_model_provider(only_context_model=only_context_model, + biencoder_shared_query_context_model=self.biencoder_shared_query_context_model)) + + #model = get_model(lambda: biencoder_model_provider(only_context_model \ #model = get_model(lambda: biencoder_model_provider(only_context_model \ - model = get_model(biencoder_model_provider(only_context_model \ - = only_context_model, biencoder_shared_query_context_model = \ - self.biencoder_shared_query_context_model, - pre_process=True, post_process=True)) + # = only_context_model, biencoder_shared_query_context_model = \ + # self.biencoder_shared_query_context_model, + # pre_process=True, post_process=True) self.model = load_biencoder_checkpoint(model, only_context_model=only_context_model) diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py index 7aefcc3e301..5fb1dd313f6 100644 --- a/megatron/model/biencoder_model.py +++ b/megatron/model/biencoder_model.py @@ -15,6 +15,25 @@ from megatron.model.utils import scaled_init_method_normal from .module import MegatronModule +def get_model_provider(only_query_model=False, only_context_model=False, + biencoder_shared_query_context_model=False): + + def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building Bienoder model ...') + model = biencoder_model_provider(only_query_model=only_query_model, + only_context_model = only_context_model, + biencoder_shared_query_context_model = \ + biencoder_shared_query_context_model, + pre_process=True, post_process=True) + + return model + + return model_provider + + + #def biencoder_model_provider(pre_process=True, # post_process=True): diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index 49d19db8c58..a9c52e3b35e 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -27,7 +27,7 @@ # ) #) -from megatron import get_args +from megatron import get_args, print_rank_0 from megatron.indexer import IndexBuilder from tasks.orqa.evaluate_utils import ORQAEvaluator diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py index aa981ac9d23..d677fc189e3 100644 --- a/tasks/orqa/evaluate_utils.py +++ b/tasks/orqa/evaluate_utils.py @@ -23,7 +23,7 @@ from tasks.orqa.natural_questions.nq import process_nq_batch from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex -from megatron.model.biencoder_model import biencoder_model_provider +from megatron.model.biencoder_model import get_model_provider from megatron.training import get_model class ORQAEvaluator(object): @@ -47,11 +47,15 @@ def __init__(self): #args.only_query_model = only_query_model #args.only_context_model = False + model = get_model(get_model_provider(only_query_model=only_query_model, + biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)) + + + #model = get_model(lambda: biencoder_model_provider(only_query_model=\ #model = get_model(lambda: biencoder_model_provider(only_query_model=\ - model = get_model(lambda: biencoder_model_provider(only_query_model=\ - only_query_model, biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model, - pre_process=True, post_process=True)) + # only_query_model, biencoder_shared_query_context_model=\ + # args.biencoder_shared_query_context_model, + # pre_process=True, post_process=True)) #model = get_model(biencoder_model_provider) From dca47cfbeac78cc87e37919d08303c3ddfea95ac Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Fri, 14 May 2021 01:39:19 -0700 Subject: [PATCH 18/59] debugging DPR --- megatron/indexer.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/megatron/indexer.py b/megatron/indexer.py index 367ce9d13ac..cf557c83930 100644 --- a/megatron/indexer.py +++ b/megatron/indexer.py @@ -1,4 +1,5 @@ import sys +import time import torch import torch.distributed as dist @@ -102,7 +103,12 @@ def build_and_save_index(self): while not hasattr(unwrapped_model, 'embed_text'): unwrapped_model = unwrapped_model.module + counter = 0 + start_time = time.time() + cur_time = start_time while True: + #start_time = time.time() + t1 = time.time() try: # batch also has query_tokens and query_pad_data row_id, context_tokens, context_mask, context_types, \ @@ -111,6 +117,8 @@ def build_and_save_index(self): except (StopIteration, IndexError): break + #print_rank_0("get batch time {}".format(cur_time - time.time())) + t2 = time.time() # TODO: can we add with torch.no_grad() to reduce memory usage # detach, separate fields and add to BlockData assert context_mask.dtype == torch.bool @@ -120,10 +128,18 @@ def build_and_save_index(self): context_logits = detach(context_logits) row_id = detach(row_id) - + #print_rank_0("embed text {}".format(cur_time - time.time())) + t3 = time.time() + self.evidence_embedder_obj.add_block_data(row_id, context_logits) self.track_and_report_progress(batch_size=len(row_id)) - + #print_rank_0("add block time {}".format(cur_time - time.time())) + t4 = time.time() + counter += 1 + if counter % 1000 == 0: + print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time)) + print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3)) + cur_time = time.time() # This process signals to finalize its shard and then synchronize with # the other processes self.evidence_embedder_obj.save_shard() From 3f7553749a67ef1ba0895d0836971287016327a5 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Mon, 17 May 2021 15:03:17 -0700 Subject: [PATCH 19/59] fix copy-n-paste error remove erroneous arg. --- examples/pretrain_t5_distributed_with_mp.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh index 9be70393d86..c907004abc6 100644 --- a/examples/pretrain_t5_distributed_with_mp.sh +++ b/examples/pretrain_t5_distributed_with_mp.sh @@ -25,7 +25,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \ --decoder-seq-length 128 \ --micro-batch-size 16 \ --global-batch-size 2048 \ - --seq-length 512 \ --max-position-embeddings 512 \ --train-iters 1000000 \ --lr-decay-iters 1000000 \ From 07ca952e924c7c0f92e01a326ed11cc6dd58ca11 Mon Sep 17 00:00:00 2001 From: Devrim <46989091+devrimcavusoglu@users.noreply.github.com> Date: Tue, 18 May 2021 03:27:39 +0300 Subject: [PATCH 20/59] Typo fix in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 489d3771de6..a01af805bba 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,7 @@ T5_ARGS="--num-layers 24 \ OUTPUT_ARGS=<same as those in BERT pretraining above> python pretrain_t5.py \ - $BERT_ARGS \ + $T5_ARGS \ $OUTPUT_ARGS \ --save $CHECKPOINT_PATH \ --load $CHECKPOINT_PATH \ From 2dae74b6f56339535354f3d352bba43d5f372477 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 18 May 2021 12:10:16 -0700 Subject: [PATCH 21/59] t5 fixes --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 489d3771de6..de208d5b3bf 100644 --- a/README.md +++ b/README.md @@ -103,6 +103,11 @@ python tools/preprocess_data.py \ The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension. +For T5 use the same preprocessing as BERT, perhaps renaming it to: +
+       --output-prefix my-t5 \
+
+ Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:
 python tools/preprocess_data.py \
@@ -237,13 +242,14 @@ T5_ARGS="--num-layers 24 \
          --micro-batch-size 16 \
          --global-batch-size 2048 \
          --vocab-file $VOCAB_FILE \
+         --vocab-extra-ids 100 \
          --split 949,50,1 \
          --fp16"
 
 OUTPUT_ARGS=<same as those in BERT pretraining above>
 
 python pretrain_t5.py \
-       $BERT_ARGS \
+       $T5_ARGS \
        $OUTPUT_ARGS \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \

From 7a0710ec2f5af7b32c48b805b4d263c2bd3e396c Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 15:31:46 -0700
Subject: [PATCH 22/59] before cleaning the comments

---
 megatron/data/biencoder_dataset_utils.py | 6 ------
 tasks/orqa/supervised/data.py            | 2 --
 2 files changed, 8 deletions(-)

diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py
index 58983b54349..dccf060f62a 100644
--- a/megatron/data/biencoder_dataset_utils.py
+++ b/megatron/data/biencoder_dataset_utils.py
@@ -20,12 +20,6 @@ def make_attention_mask(source_block, target_block):
     # (source_length, target_length)
     return mask
 
-def make_history_mask(block):
-    length = block.shape[0]
-    arange = np.arange(length)
-    history_mask = (arange[None, ] <= arange[:, None])
-    history_mask = history_mask.astype(np.int64)
-    return history_mask
 
 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""
diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py
index 922de56edf2..466212729fa 100644
--- a/tasks/orqa/supervised/data.py
+++ b/tasks/orqa/supervised/data.py
@@ -25,8 +25,6 @@
 
 from megatron import print_rank_0, get_args
 from megatron.data.biencoder_dataset_utils import make_attention_mask
-from megatron.data.biencoder_dataset_utils import make_history_mask
-
 
 def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
     ctx_id_list, ctx_types_list = [], []

From ccae9dbdeb71ccb2809bc41b10dd3440ee1d037a Mon Sep 17 00:00:00 2001
From: Vijay Korthikanti 
Date: Tue, 18 May 2021 15:47:01 -0700
Subject: [PATCH 23/59] vit pipeline fixes

---
 megatron/checkpointing.py      |   2 +-
 megatron/model/vit_model.py    | 128 +++++++++++++++++++--------------
 pretrain_vit.py                |  34 +++++----
 tasks/vision/classification.py |   5 +-
 tasks/vision/eval_utils.py     |  76 ++++++++++++++------
 tasks/vision/finetune_utils.py |  63 ++++++++++------
 6 files changed, 197 insertions(+), 111 deletions(-)

diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 14e7971abed..7898c7d4b19 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -60,8 +60,8 @@ def _compare(arg_name, old_arg_name=None):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
-    _compare('max_position_embeddings')
     if args.vocab_file:
+        _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')
diff --git a/megatron/model/vit_model.py b/megatron/model/vit_model.py
index 84a52a8294a..a1a86cfff3a 100644
--- a/megatron/model/vit_model.py
+++ b/megatron/model/vit_model.py
@@ -50,11 +50,11 @@ def __init__(self, hidden_size, num_classes):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        x = hidden_states[:, sequence_index, :]
-        x = self.dense_in(x)
-        x = torch.tanh(x)
-        x = self.dense_out(x)
-        return x
+        hidden_state = hidden_states[:, sequence_index, :]
+        dense_in_result = self.dense_in(hidden_state)
+        tanh_result = torch.tanh(dense_in_result)
+        dense_out_result = self.dense_out(tanh_result)
+        return dense_out_result
 
 
 def twod_interpolate_position_embeddings_hook(
@@ -122,8 +122,12 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""
 
-    def __init__(self, num_classes, finetune=False):
-        super(VitModel, self).__init__()
+    def __init__(self, 
+                 num_classes,
+                 finetune=False,
+                 pre_process=True,
+                 post_process=True):
+        super(VitModel, self).__init__(share_word_embeddings=False)
         args = get_args()
 
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -136,6 +140,8 @@ def __init__(self, num_classes, finetune=False):
                 args.init_method_std, args.num_layers
             )
 
+        self.pre_process = pre_process
+        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim
@@ -148,63 +154,81 @@ def __init__(self, num_classes, finetune=False):
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
 
-        # cls_token
-        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
-        torch.nn.init.zeros_(self.cls_token)
+        if self.pre_process:
+            # cls_token
+            self.cls_token = torch.nn.Parameter(
+                torch.randn(1, 1, self.hidden_size)
+            )
+            torch.nn.init.zeros_(self.cls_token)
 
-        # Linear encoder
-        self.linear_encoder = torch.nn.Linear(
-            self.flatten_dim, self.hidden_size
-        )
+            # Linear encoder
+            self.linear_encoder = torch.nn.Linear(
+                self.flatten_dim, self.hidden_size
+            )
 
-        # embedding
-        self.position_embeddings = torch.nn.Embedding(
-            self.seq_length, self.hidden_size
-        )
-        init_method_normal(args.init_method_std)(
-            self.position_embeddings.weight
-        )
-        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
+            # embedding
+            self.position_embeddings = torch.nn.Embedding(
+                self.seq_length, self.hidden_size
+            )
+            init_method_normal(args.init_method_std)(
+                self.position_embeddings.weight
+            )
+            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
 
-        self.position_embeddings._register_load_state_dict_pre_hook(
-            twod_interpolate_position_embeddings_hook
-        )
+            self.position_embeddings._register_load_state_dict_pre_hook(
+                twod_interpolate_position_embeddings_hook
+            )
 
-        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
+            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
 
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, self.scaled_init_method
+            self.init_method, 
+            self.scaled_init_method,
+            pre_process=self.pre_process,
+            post_process=self.post_process
         )
 
-        # MLP head
-        if not self.finetune:
-            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
-        else:
-            self.class_head = get_linear_layer(
-                self.hidden_size, num_classes, torch.nn.init.zeros_
+        if self.post_process:
+            # MLP head
+            if not self.finetune:
+                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+            else:
+                self.class_head = get_linear_layer(
+                    self.hidden_size, num_classes, torch.nn.init.zeros_
+                )
+
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        self.transformer.set_input_tensor(input_tensor)
+
+    def forward(self, input):
+
+        if self.pre_process:
+            rearranged_input = einops.rearrange(
+                input,
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=self.patch_dim,
+                p2=self.patch_dim,
             )
 
-    def forward(self, x):
-        x = einops.rearrange(
-            x,
-            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-            p1=self.patch_dim,
-            p2=self.patch_dim,
-        )
+            assert rearranged_input.dtype == torch.half
+            encoder_output = self.linear_encoder(rearranged_input)
+            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
+            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
 
-        assert x.dtype == torch.half
-        x = self.linear_encoder(x)
-        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
+            token_embeddings = concatenated_tokens + \
+                self.position_embeddings(self.position_ids)
+            hidden_states = self.embedding_dropout(token_embeddings)
+        else:
+            hidden_states = input
 
-        x = x + self.position_embeddings(self.position_ids)
-        x = self.embedding_dropout(x)
-        x = self.transformer(x, None)
+        hidden_states = self.transformer(hidden_states, None)
 
-        if not self.finetune:
-            x = self.mlp_head(x)
-        else:
-            x = self.class_head(x[:, 0, :])
+        if self.post_process:
+            if not self.finetune:
+                hidden_states = self.mlp_head(hidden_states)
+            else:
+                hidden_states = self.class_head(hidden_states[:, 0, :])
 
-        return x
+        return hidden_states
diff --git a/pretrain_vit.py b/pretrain_vit.py
index 16ec10439a0..7770c68d5d5 100644
--- a/pretrain_vit.py
+++ b/pretrain_vit.py
@@ -17,19 +17,22 @@
 
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
-def model_provider():
+def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
     print_rank_0("building VIT model ...")
     args = get_args()
 
-    model = VitModel(num_classes=args.num_classes)
+    model = VitModel(num_classes=args.num_classes,
+                     pre_process=pre_process,
+                     post_process=post_process)
     return model
 
 def get_batch(data_iterator):
@@ -42,10 +45,21 @@ def get_batch(data_iterator):
 
     return images, labels
 
-def forward_step(data_iterator, model, input_tensor):
+def loss_func(labels, output_tensor):
+    logits = output_tensor.contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+
+def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()
-    assert input_tensor is None
 
     # Get the batch.
     timers("batch-generator").start()
@@ -56,17 +70,9 @@ def forward_step(data_iterator, model, input_tensor):
     timers("batch-generator").stop()
 
     # Forward model. lm_labels
-    logits = model(images).contiguous().float()
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
+    output_tensor = model(images)
 
+    return output_tensor, partial(loss_func, labels)
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
diff --git a/tasks/vision/classification.py b/tasks/vision/classification.py
index 5232b3f5492..71e840757e3 100644
--- a/tasks/vision/classification.py
+++ b/tasks/vision/classification.py
@@ -34,13 +34,14 @@ def train_valid_datasets_provider():
         )
         return train_ds, valid_ds
 
-    def model_provider():
+    def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()
 
         print_rank_0("building classification model for ImageNet ...")
 
-        return VitModel(num_classes=args.num_classes, finetune=True)
+        return VitModel(num_classes=args.num_classes, finetune=True,
+                        pre_process=pre_process, post_process=post_process)
 
     """Finetune/evaluate."""
     finetune(
diff --git a/tasks/vision/eval_utils.py b/tasks/vision/eval_utils.py
index aabc04a1594..3a194119c17 100644
--- a/tasks/vision/eval_utils.py
+++ b/tasks/vision/eval_utils.py
@@ -16,10 +16,14 @@
 """Evaluation utilities."""
 
 import os
+from functools import partial
+
 import torch
+
 from megatron import get_args
-from megatron import print_rank_0
+from megatron import print_rank_0, print_rank_last
 from megatron import mpu
+from megatron.schedules import get_forward_backward_func
 from tasks.vision.finetune_utils import build_data_loader
 from tasks.vision.finetune_utils import process_batch
 from torchvision import datasets, transforms
@@ -56,7 +60,7 @@ def metrics_func(model, epoch):
         print_rank_0("calculating metrics ...")
         correct, total = calculate_correct_answers(model, dataloader, epoch)
         percent = float(correct) * 100.0 / float(total)
-        print_rank_0(
+        print_rank_last(
             " >> |epoch: {}| overall: correct / total = {} / {} = "
             "{:.4f} %".format(epoch, correct, total, percent)
         )
@@ -67,29 +71,61 @@ def metrics_func(model, epoch):
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""
 
-    model.eval()
+    args = get_args()
+    forward_backward_func = get_forward_backward_func()
+    for m in model:
+        m.eval()
+
+    def loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        loss_dict = {}
+        # Compute the correct answers.
+        predicted = torch.argmax(logits, dim=-1)
+        corrects = (predicted == labels).float()
+        # Add to the counters.
+        loss_dict['total'] = labels.size(0)
+        loss_dict['correct'] = corrects.sum().item()
+
+        return 0, loss_dict
+
+    #defined inside to capture output_predictions
+    def correct_answers_forward_step(batch, model):
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+
+        # Forward model.
+        args = get_args()
+        output_tensor = model(images)
+
+        return output_tensor, partial(loss_func, labels)
+
     with torch.no_grad():
         # For all the batches in the dataset.
         total = 0
         correct = 0
         for _, batch in enumerate(dataloader):
-            # Run the model forward.
-            images, labels = process_batch(batch)
-            logits = model(images).contiguous().float()
-            # Add output predictions.
-            # Compute the correct answers.
-            predicted = torch.argmax(logits, dim=-1)
-            corrects = (predicted == labels).float()
-            # Add to the counters.
-            total += labels.size(0)
-            correct += corrects.sum().item()
-    model.train()
+
+            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
+                                               optimizer=None, timers=None, forward_only=True)
+
+            for loss_dict in loss_dicts:
+                total += loss_dict['total']
+                correct += loss_dict['correct']
+
+    for m in model:
+        m.train()
 
     # Reduce.
-    unreduced = torch.cuda.LongTensor([correct, total])
-    torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group())
+    if mpu.is_pipeline_last_stage():
+        unreduced = torch.cuda.LongTensor([correct, total])
+        torch.distributed.all_reduce(unreduced,
+                                     group=mpu.get_data_parallel_group())
 
-    # Print on screen.
-    correct_ans = unreduced[0].item()
-    total_count = unreduced[1].item()
-    return correct_ans, total_count
+        # Print on screen.
+        correct_ans = unreduced[0].item()
+        total_count = unreduced[1].item()
+        return correct_ans, total_count
diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py
index afde4aa8940..f9743883ca3 100644
--- a/tasks/vision/finetune_utils.py
+++ b/tasks/vision/finetune_utils.py
@@ -17,6 +17,7 @@
 
 import torch
 import torch.nn.functional as F
+from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
@@ -38,10 +39,21 @@ def process_batch(batch):
     return images, labels
 
 
-def _cross_entropy_forward_step(batch, model, input_tensor):
+def cross_entropy_loss_func(labels, output_tensor):
+    logits = output_tensor
+
+    # Cross-entropy loss.
+    loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+    # Reduce loss for logging.
+    averaged_loss = average_losses_across_data_parallel_group([loss])
+
+    return loss, {'lm loss': averaged_loss[0]}
+
+
+def _cross_entropy_forward_step(batch, model):
     """Simple forward step with cross-entropy loss."""
     timers = get_timers()
-    assert input_tensor is None
 
     # Get the batch.
     timers("batch generator").start()
@@ -52,16 +64,10 @@ def _cross_entropy_forward_step(batch, model, input_tensor):
     images, labels = process_batch(batch_)
     timers("batch generator").stop()
 
-    # Forward model.
-    logits = model(images).contiguous().float()
-
-    # Cross-entropy loss.
-    loss = F.cross_entropy(logits, labels)
-
-    # Reduce loss for logging.
-    average_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss, {"lm loss": average_loss[0]}
+   # Forward model.
+    output_tensor = model(images)
+  
+    return output_tensor, partial(cross_entropy_loss_func, labels)
 
 
 def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
@@ -103,23 +109,28 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     """Traing and validation dataloaders."""
     args = get_args()
 
-    print_rank_0("building train and validation dataloaders ...")
+    print_rank_0('building train and validation dataloaders ...')
     # Training dataset.
-    train_dataloader = build_data_loader(
-        train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
-    )
+    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
+                                           args.num_workers, not args.keep_last)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
-    valid_dataloader_ = build_data_loader(
-        valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last
-    )
+    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
+                                            args.num_workers, not args.keep_last)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
 
-    return train_dataloader, valid_dataloader
+    # Now that we've built the data loaders, set batch_size arguments
+    # to the actual batch size the model will see for this dataset.
+    # This is necessary so pipeline transfers know what size they are
+    # and the LR schedule, which is based on samples seen, gets set
+    # correctly.
+    args.orig_micro_batch_size = args.micro_batch_size
+    args.orig_global_batch_size = args.global_batch_size
 
+    return train_dataloader, valid_dataloader
 
 def _train(
     model,
@@ -135,7 +146,8 @@ def _train(
     timers = get_timers()
 
     # Turn on training mode which enables dropout.
-    model.train()
+    for m in model:
+        m.train()
 
     # Tracking loss.
     losses_dict_sum = {}
@@ -166,12 +178,16 @@ def _train(
             start_iteration = 0
 
             # Train for one step.
-            losses_dict, skipped_iter = train_step(
+            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                 forward_step, batch, model, optimizer, lr_scheduler
             )
             iteration += 1
 
             # Logging.
+            params_norm = None
+            if args.log_params_norm:
+                params_norm = calc_params_l2_norm(model)
+
             report_memory_flag = training_log(
                 losses_dict,
                 losses_dict_sum,
@@ -180,6 +196,9 @@ def _train(
                 optimizer.get_loss_scale().item(),
                 report_memory_flag,
                 skipped_iter,
+                grad_norm,
+                params_norm,
+                num_zeros_in_grad
             )
 
             # Autoresume

From 2eaf6c790d71928924fc2d679a2f9193e9d3be30 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 16:44:32 -0700
Subject: [PATCH 24/59] cleaning the code

---
 megatron/indexer.py                 | 32 ++++++++---------
 megatron/model/biencoder_model.py   | 25 +++++--------
 megatron/model/language_model.py    |  2 +-
 pretrain_ict.py                     |  2 +-
 tasks/finetune_utils.py             |  8 ++---
 tasks/main.py                       |  4 +--
 tasks/orqa/evaluate_orqa.py         | 35 +++++--------------
 tasks/orqa/evaluate_utils.py        |  3 +-
 tasks/orqa/supervised/data.py       |  6 ++--
 tasks/orqa/supervised/eval_utils.py |  9 ++---
 tasks/orqa/supervised/finetune.py   | 54 +++++++----------------------
 11 files changed, 56 insertions(+), 124 deletions(-)

diff --git a/megatron/indexer.py b/megatron/indexer.py
index cf557c83930..c88a74f1555 100644
--- a/megatron/indexer.py
+++ b/megatron/indexer.py
@@ -26,8 +26,8 @@ def __init__(self):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        self.pre_process = True
-        self.post_process = True
+        #self.pre_process = True
+        #self.post_process = True
 
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
@@ -46,7 +46,7 @@ def load_attributes(self):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        args = get_args()
+        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
@@ -56,7 +56,7 @@ def load_attributes(self):
 
         #model = get_model(biencoder_model_provider)
 
-        model = get_model(get_model_provider(only_context_model=only_context_model, 
+        model = get_model(get_model_provider(only_context_model=only_context_model,
             biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
 
         #model = get_model(lambda: biencoder_model_provider(only_context_model \
@@ -103,12 +103,12 @@ def build_and_save_index(self):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
 
-        counter = 0
-        start_time = time.time()
-        cur_time = start_time
+        #counter = 0
+        #start_time = time.time()
+        #cur_time = start_time
         while True:
             #start_time = time.time()
-            t1 = time.time()
+            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
@@ -118,7 +118,7 @@ def build_and_save_index(self):
                 break
 
             #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            t2 = time.time()
+            #t2 = time.time()
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
@@ -129,17 +129,17 @@ def build_and_save_index(self):
             context_logits = detach(context_logits)
             row_id = detach(row_id)
             #print_rank_0("embed text {}".format(cur_time - time.time()))
-            t3 = time.time()
+            #t3 = time.time()
  
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
             #print_rank_0("add block time {}".format(cur_time - time.time()))
-            t4 = time.time()
-            counter += 1
-            if counter % 1000 == 0:
-                print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-                print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
-                cur_time = time.time()
+            #t4 = time.time()
+            #counter += 1
+            #if counter % 1000 == 0:
+            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
+            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
+            #    cur_time = time.time()
         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index 5fb1dd313f6..6478c06bd0c 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -15,17 +15,17 @@
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
 
-def get_model_provider(only_query_model=False, only_context_model=False, 
+def get_model_provider(only_query_model=False, only_context_model=False,
         biencoder_shared_query_context_model=False):
 
     def model_provider(pre_process=True, post_process=True):
         """Build the model."""
 
         print_rank_0('building Bienoder model ...')
-        model = biencoder_model_provider(only_query_model=only_query_model, 
-                only_context_model = only_context_model, 
+        model = biencoder_model_provider(only_query_model=only_query_model,
+                only_context_model = only_context_model,
                 biencoder_shared_query_context_model = \
-                biencoder_shared_query_context_model, 
+                biencoder_shared_query_context_model,
                 pre_process=True, post_process=True)
 
         return model
@@ -33,21 +33,12 @@ def model_provider(pre_process=True, post_process=True):
     return model_provider
 
 
-
-#def biencoder_model_provider(pre_process=True, 
-#                             post_process=True):
- 
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
                              biencoder_shared_query_context_model=False,
                              pre_process=True,
                              post_process=True):
     """Build the model."""
-    #args = get_args()
-
-    #biencoder_shared_query_context_model = args.biencoder_shared_query_context_model
-    #only_context_model = args.only_context_model
-    #only_query_model = args.only_query_model
 
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -63,7 +54,7 @@ def biencoder_model_provider(only_query_model=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
         biencoder_shared_query_context_model=\
-            biencoder_shared_query_context_model,
+        biencoder_shared_query_context_model,
         pre_process=pre_process,
         post_process=post_process)
 
@@ -114,9 +105,9 @@ def __init__(self,
 
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
-        #this is just a placeholder and will be needed when model
-        #parallelism will be used
-        #self.language_model.set_input_tensor(input_tensor)
+        # this is just a placeholder and will be needed when model
+        # parallelism will be used
+        # self.language_model.set_input_tensor(input_tensor)
         return
 
     def forward(self, query_tokens, query_attention_mask, query_types,
diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py
index abf1082b4a6..06330d81395 100644
--- a/megatron/model/language_model.py
+++ b/megatron/model/language_model.py
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
diff --git a/pretrain_ict.py b/pretrain_ict.py
index 8a5876da49c..336bb494cda 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -36,7 +36,7 @@ def pretrain_ict_model_provider():
     #args.only_context_model = False
     #args.only_query_model = False
     #model = biencoder_model_provider()
- 
+
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index be260a523f6..94cb367b3b0 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -16,7 +16,6 @@
 """Finetune utilities."""
 
 from functools import partial
-import sys
 
 import torch
 
@@ -81,7 +80,7 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)
 
 
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last, 
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
         task_collate_fn=None):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
 
@@ -190,7 +189,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                 continue
             # Set to zero so the next epoch does not skip any batches.
             start_iteration = 0
-    
+
             # Train for one step.
             out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
 
@@ -226,9 +225,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)
 
-            #if iteration == 600:
-            #    sys.exit()
-
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
diff --git a/tasks/main.py b/tasks/main.py
index 29fd44f39af..59b377a4ba7 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -89,8 +89,8 @@ def get_tasks_args(parser):
     #                    help="Av.rank validation: batch size to process passages")
     #group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
     #                    help="Av.rank validation: max num of questions")
- 
- 
+
+
     return parser
 
 
diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py
index a9c52e3b35e..87c59ea30e2 100644
--- a/tasks/orqa/evaluate_orqa.py
+++ b/tasks/orqa/evaluate_orqa.py
@@ -15,18 +15,6 @@
 
 """Main tasks functionality."""
 
-import os
-import sys
-
-#sys.path.append(
-#    os.path.abspath(
-#        os.path.join(
-#            os.path.join(os.path.dirname(__file__), os.path.pardir),
-#            os.path.pardir,
-#        )
-#    )
-#)
-
 from megatron import get_args, print_rank_0
 from megatron.indexer import IndexBuilder
 from tasks.orqa.evaluate_utils import ORQAEvaluator
@@ -35,30 +23,23 @@ def main():
     """
     Main program
     """
-    #initialize_megatron(extra_args_provider=None,
-    #                    args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
 
     args = get_args()
 
-    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
-    - Include all args needed for initial model specification
-
-    Other key args:
-        --block-data-path: path to write to
-        --ict-load or --realm-load: path to checkpoint with which to embed
-        --data-path and --titles-data-path: paths for dataset
-        --indexer-log-interval: reporting interval
-        --indexer-batch-size: size specific for indexer jobs
-
-    Check README.md for example script
+    """
+    Create a BlockData data structure by running an IndexBuilder over an
+    ICT Dataset and then evaluate on NQ task
     """
 
-    #print_rank_0("Starting index builder!")
+    print_rank_0("Starting index builder!")
 
     index_builder = IndexBuilder()
     index_builder.build_and_save_index()
     print_rank_0("Build and save indices: done!")
 
+
+    print_rank_0("Starting evaluations!")
+
     # Set up the model and evaluator
     evaluator = ORQAEvaluator()
 
@@ -68,4 +49,4 @@ def main():
 
     if args.qa_data_test is not None:
         evaluator.evaluate(args.qa_data_test, "TEST")
-    
+
diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py
index d677fc189e3..06fcf57766a 100644
--- a/tasks/orqa/evaluate_utils.py
+++ b/tasks/orqa/evaluate_utils.py
@@ -47,10 +47,9 @@ def __init__(self):
         #args.only_query_model = only_query_model
         #args.only_context_model = False
 
-        model = get_model(get_model_provider(only_query_model=only_query_model, 
+        model = get_model(get_model_provider(only_query_model=only_query_model,
             biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
 
-
         #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #model = get_model(lambda: biencoder_model_provider(only_query_model=\
         #    only_query_model, biencoder_shared_query_context_model=\
diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py
index 466212729fa..e2de454c92f 100644
--- a/tasks/orqa/supervised/data.py
+++ b/tasks/orqa/supervised/data.py
@@ -104,9 +104,9 @@ def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
     return enc_ids, tokentypes_enc, pad_mask
 
 
-def build_sample(query_ids, query_types, query_pad_mask, 
+def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
-                neg_ctx_id_list=None, neg_ctx_types_list=None, 
+                neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
     """Convert to numpy and return a sample consumed by the batch producer."""
 
@@ -295,5 +295,3 @@ def process_samples_from_single_path(filename):
         print_rank_0(' >> processed {} samples.'.format(len(samples)))
         return samples
 
-
-
diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py
index 729367266d3..67dca512b0d 100644
--- a/tasks/orqa/supervised/eval_utils.py
+++ b/tasks/orqa/supervised/eval_utils.py
@@ -34,7 +34,6 @@ def task_collate_fn(batch_data):
     for d in batch_data:
         for k, v in d.items():
             tensorized.setdefault(k, []).append(v)
-    # assert len(tensorized) == 12
 
     tensorized['query'] = torch.LongTensor(tensorized['query'])
     tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
@@ -90,8 +89,6 @@ def process_batch(batch):
            neg_context_tokens, neg_context_mask, neg_context_types, reference
 
 def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
-#, datapath, 
-#    rank0sampler=False):
     """Provide function that calculates accuracies."""
     args = get_args()
 
@@ -112,9 +109,7 @@ def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
                                    args.eval_micro_batch_size,
                                    num_workers=args.num_workers,
                                    drop_last=drop_last,
-                                   task_collate_fn=task_collate_fn) 
-                                   #shuffle=False,
-                                   #rank0sampler=rank0sampler)
+                                   task_collate_fn=task_collate_fn)
     dataloaders = (dataset.dataset_name, dataloader)
 
     def metrics_func(model, epoch, output_predictions=False):
@@ -197,7 +192,7 @@ def get_rank():
             losses = average_losses_across_data_parallel_group([rank, \
                 *topk_accs])
 
-            # create stats_dict with retrieval loss and all specified 
+            # create stats_dict with retrieval loss and all specified
             # top-k accuracies
             topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                 zip(args.retriever_report_topk_accuracies, losses[1:])}
diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py
index 6b323f434d5..d6db0362afc 100644
--- a/tasks/orqa/supervised/finetune.py
+++ b/tasks/orqa/supervised/finetune.py
@@ -22,27 +22,21 @@
 import torch
 import torch.nn.functional as F
 
-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tokenizer
-from megatron import mpu
-from megatron import print_rank_0
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron import get_args, get_timers, get_tokenizer
+from megatron import mpu, print_rank_0
+from megatron.indexer import IndexBuilder
 from megatron.model.biencoder_model import biencoder_model_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import accuracy_func_provider
-#from tasks.t5_model_utils.finetune_utils_open_retrieval import finetune
+from megatron.utils import average_losses_across_data_parallel_group
 from pretrain_ict import get_group_world_size_rank
 from tasks.finetune_utils import finetune
 from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
-from megatron.indexer import IndexBuilder
 
-def orqa(Dataset): # , name_from_datapath_func):
+def orqa(Dataset):
 
     def cross_entropy_forward_step(batch, model):
         """Simple forward step with cross-entropy loss."""
-        args = get_args()
         timers = get_timers()
         tokenizer = get_tokenizer()
 
@@ -73,17 +67,15 @@ def cross_entropy_forward_step(batch, model):
             context_types = torch.cat([context_types, neg_context_types])
 
         # Forward model.
-        #query_logits, context_logits = model(query_tokens, query_mask, 
-        output_tensor = model(query_tokens, query_mask, 
-                                        query_types, context_tokens, 
+        output_tensor = model(query_tokens, query_mask,
+                                        query_types, context_tokens,
                                         context_mask, context_types)
 
-        return output_tensor, partial(cross_entropy_loss_func_, query_tokens, context_tokens)
+        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
 
 
-    #def cross_entropy_loss_func(labels, output_tensor):
-    def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
-        args = get_args() 
+    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
+        args = get_args()
 
         local_batch_size = query_tokens.shape[0]
         group, rank, world_size = get_group_world_size_rank()
@@ -184,12 +176,9 @@ def model_provider(pre_process=True, post_process=True):
         """Build the model."""
         args = get_args()
         print_rank_0('building retriever model for {} ...'.format(args.task))
-        #args.only_context_model=False
-        #args.only_query_model=False
-        #model = biencoder_model_provider()
-        
+
         model = biencoder_model_provider(only_context_model=False,
-                    only_query_model=False, 
+                    only_query_model=False,
                     biencoder_shared_query_context_model=\
                     args.biencoder_shared_query_context_model,
                     pre_process=pre_process, post_process=post_process)
@@ -200,7 +189,6 @@ def single_dataset_provider(datapath):
         args = get_args()
         tokenizer = get_tokenizer()
 
-        #name = name_from_datapath_func(datapath)
         name = datapath[0].split('/')[-1].split('.')[0]
         return Dataset(name,
                        datapath,
@@ -208,41 +196,25 @@ def single_dataset_provider(datapath):
                        args.retriever_seq_length,
                        evaluate=True)
 
-    #def distributed_metrics_func_provider():
     def metrics_func_provider():
         """Provide metrics callback function."""
-
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
-        
         return accuracy_func_provider(single_dataset_provider)
 
-    #def rank0_metrics_func_provider(datapath):
-    #    """Provide metrics callback function."""
-    #    return accuracy_func_provider(single_dataset_provider, datapath,
-    #                                  rank0sampler=True)
-
     """Finetune/evaluate."""
     finetune(train_valid_datasets_provider,
              model_provider,
              forward_step=cross_entropy_forward_step,
              end_of_epoch_callback_provider=metrics_func_provider,
              task_collate_fn=task_collate_fn)
-            #,end_of_training_callback_provider=rank0_metrics_func_provider)
-
 
 def main():
     args = get_args()
 
     if args.task == 'RET-FINETUNE-NQ':
         from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
-
-        #def name_from_datapath(datapath):
-        #    return datapath[0].split('/')[-1].split('.')[0]
-
     else:
         raise NotImplementedError('ORQA task {} is not implemented.'.format(
             args.task))
 
-    orqa(Dataset) #, name_from_datapath)
+    orqa(Dataset)
 

From 25293807ca7889fd651f3260a82d437fd19f38d1 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 21:23:01 -0700
Subject: [PATCH 25/59] additional cleaning

---
 megatron/arguments.py        |  6 ------
 megatron/checkpointing.py    |  1 -
 megatron/indexer.py          | 39 +++++-------------------------------
 megatron/learning_rates.py   | 16 ---------------
 tasks/finetune_utils.py      |  3 ++-
 tasks/orqa/evaluate_utils.py | 11 ----------
 6 files changed, 7 insertions(+), 69 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index d3a9ad30102..b8c230f5793 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -479,12 +479,6 @@ def _add_learning_rate_args(parser):
     group.add_argument('--min-lr', type=float, default=0.0,
                        help='Minumum value for learning rate. The scheduler'
                        'clip values below this threshold.')
-    group.add_argument('--override-lr-new', action='store_true',
-                       help='Reset the values of the scheduler (learning rate,'
-                       'warmup iterations, minimum learning rate, maximum '
-                       'number of iterations, and decay style from input '
-                       'arguments and ignore values from checkpoints. Note'
-                       'that all the above values will be reset.')
     group.add_argument('--override-lr-scheduler', action='store_true',
                        help='Reset the values of the scheduler (learning rate,'
                        'warmup iterations, minimum learning rate, maximum '
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 7cd13c03b55..14e7971abed 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -419,7 +419,6 @@ def load_biencoder_checkpoint(model, only_query_model=False,
 
     assert len(model) == 1
     model[0].load_state_dict(ret_state_dict)
-
     torch.distributed.barrier()
 
     if mpu.get_data_parallel_rank() == 0:
diff --git a/megatron/indexer.py b/megatron/indexer.py
index c88a74f1555..d2ff9e36f85 100644
--- a/megatron/indexer.py
+++ b/megatron/indexer.py
@@ -26,13 +26,10 @@ def __init__(self):
         self.evidence_embedder_obj = None
         self.biencoder_shared_query_context_model = \
             args.biencoder_shared_query_context_model
-        #self.pre_process = True
-        #self.post_process = True
 
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        #self.using_realm_chkpt = args.ict_load is None
 
         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size
@@ -46,24 +43,13 @@ def load_attributes(self):
         """
         Load the necessary attributes: model, dataloader and empty BlockData
         """
-        #args = get_args()
         only_context_model = True
         if self.biencoder_shared_query_context_model:
             only_context_model = False
 
-        #args.only_context_model = only_context_model
-        #args.only_query_model = False
-
-        #model = get_model(biencoder_model_provider)
-
-        model = get_model(get_model_provider(only_context_model=only_context_model,
-            biencoder_shared_query_context_model=self.biencoder_shared_query_context_model))
-
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #model = get_model(lambda: biencoder_model_provider(only_context_model \
-        #    = only_context_model, biencoder_shared_query_context_model = \
-        #    self.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True)
+        model = get_model(get_model_provider(only_context_model=\
+            only_context_model, biencoder_shared_query_context_model=\
+            self.biencoder_shared_query_context_model))
 
         self.model = load_biencoder_checkpoint(model,
                 only_context_model=only_context_model)
@@ -103,12 +89,7 @@ def build_and_save_index(self):
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
 
-        #counter = 0
-        #start_time = time.time()
-        #cur_time = start_time
         while True:
-            #start_time = time.time()
-            #t1 = time.time()
             try:
                 # batch also has query_tokens and query_pad_data
                 row_id, context_tokens, context_mask, context_types, \
@@ -117,8 +98,6 @@ def build_and_save_index(self):
             except (StopIteration, IndexError):
                 break
 
-            #print_rank_0("get batch time {}".format(cur_time - time.time()))
-            #t2 = time.time()
             # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
             assert context_mask.dtype == torch.bool
@@ -128,18 +107,10 @@ def build_and_save_index(self):
 
             context_logits = detach(context_logits)
             row_id = detach(row_id)
-            #print_rank_0("embed text {}".format(cur_time - time.time()))
-            #t3 = time.time()
- 
+
             self.evidence_embedder_obj.add_block_data(row_id, context_logits)
             self.track_and_report_progress(batch_size=len(row_id))
-            #print_rank_0("add block time {}".format(cur_time - time.time()))
-            #t4 = time.time()
-            #counter += 1
-            #if counter % 1000 == 0:
-            #    print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
-            #    print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
-            #    cur_time = time.time()
+
         # This process signals to finalize its shard and then synchronize with
         # the other processes
         self.evidence_embedder_obj.save_shard()
diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py
index 18ce635614f..c53af8d54ac 100644
--- a/megatron/learning_rates.py
+++ b/megatron/learning_rates.py
@@ -18,7 +18,6 @@
 import math
 
 from megatron import print_rank_0
-from megatron import get_args
 
 class AnnealingLR(object):
     """Anneals the learning rate."""
@@ -60,7 +59,6 @@ def get_lr(self):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
 
-        #print_rank_0("self.warmup_steps {} self.num_steps {} self.decay_steps {} self.min_lr {} self.maxlr {}".format(self.warmup_steps, self.num_steps, self.decay_steps, self.min_lr, self.max_lr))
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \
@@ -90,20 +88,6 @@ def get_lr(self):
             raise Exception('{} decay style is not supported.'.format(
                 self.decay_style))
 
-        args = get_args()
-
-        if args.override_lr_new:
-            mod_num_steps_ = min(self.num_steps, self.decay_steps - self.warmup_steps)
-            mod_num_steps_ = mod_num_steps_ - self.warmup_steps
-            use_lr = delta_lr * float(self.decay_steps - mod_num_steps_) / float(self.decay_steps)
-            should_use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} should_use_lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, should_use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-        else:
-            use_lr = self.min_lr + coeff * delta_lr
-            print_rank_0("num_steps {} decay_steps {} decay_ratio {} coeff {} delta_lr {} use lr {} self.warmup_steps {} self.num_steps {} self.decay_steps {}".format(num_steps_, decay_steps_, decay_ratio, coeff, delta_lr, use_lr, self.warmup_steps, self.num_steps, self.decay_steps))
-
-        return use_lr
-
         return self.min_lr + coeff * delta_lr
 
 
diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index 94cb367b3b0..cee4e1ee83c 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -114,7 +114,8 @@ def _build_infinite_size_dataloader(dataloader):
             iterator = dataloader.__iter__()
 
 
-def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset, 
+    task_collate_fn=None):
     """Traing and validation dataloaders."""
     args = get_args()
 
diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py
index 06fcf57766a..f750f3f99c3 100644
--- a/tasks/orqa/evaluate_utils.py
+++ b/tasks/orqa/evaluate_utils.py
@@ -44,20 +44,9 @@ def __init__(self):
         if args.biencoder_shared_query_context_model:
             only_query_model = False
 
-        #args.only_query_model = only_query_model
-        #args.only_context_model = False
-
         model = get_model(get_model_provider(only_query_model=only_query_model,
             biencoder_shared_query_context_model=args.biencoder_shared_query_context_model))
 
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
-        #model = get_model(lambda: biencoder_model_provider(only_query_model=\
-        #    only_query_model, biencoder_shared_query_context_model=\
-        #    args.biencoder_shared_query_context_model,
-        #    pre_process=True, post_process=True))
-
-        #model = get_model(biencoder_model_provider)
-
         self.model = load_biencoder_checkpoint(model,
                 only_query_model=only_query_model)
 

From 8e44d619b00577114f902f24814eee98dd1a7fde Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 21:55:17 -0700
Subject: [PATCH 26/59] renaming the folders

---
 pretrain_ict.py                                          | 6 ++----
 tasks/orqa/evaluate_utils.py                             | 9 +++++----
 tasks/orqa/{natural_questions => unsupervised}/nq.py     | 0
 .../orqa/{natural_questions => unsupervised}/qa_utils.py | 2 +-
 .../{natural_questions => unsupervised}/tokenizers.py    | 0
 5 files changed, 8 insertions(+), 9 deletions(-)
 rename tasks/orqa/{natural_questions => unsupervised}/nq.py (100%)
 rename tasks/orqa/{natural_questions => unsupervised}/qa_utils.py (98%)
 rename tasks/orqa/{natural_questions => unsupervised}/tokenizers.py (100%)

diff --git a/pretrain_ict.py b/pretrain_ict.py
index 336bb494cda..220e2725480 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -33,15 +33,13 @@
 
 def pretrain_ict_model_provider():
     args = get_args()
-    #args.only_context_model = False
-    #args.only_query_model = False
-    #model = biencoder_model_provider()
 
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
                 biencoder_shared_query_context_model=\
-                    args.biencoder_shared_query_context_model)
+                args.biencoder_shared_query_context_model)
+
     return model
 
 def get_group_world_size_rank():
diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py
index f750f3f99c3..08b1e929b3e 100644
--- a/tasks/orqa/evaluate_utils.py
+++ b/tasks/orqa/evaluate_utils.py
@@ -18,13 +18,14 @@
 from megatron import get_args, print_rank_0
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
-from tasks.orqa.natural_questions.nq import get_nq_dataset
-from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader
-from tasks.orqa.natural_questions.nq import process_nq_batch
-from tasks.orqa.natural_questions.qa_utils import calculate_matches
 from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex
 from megatron.model.biencoder_model import get_model_provider
 from megatron.training import get_model
+from tasks.orqa.unsupervised.nq import get_nq_dataset
+from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader
+from tasks.orqa.unsupervised.nq import process_nq_batch
+from tasks.orqa.unsupervised.qa_utils import calculate_matches
+
 
 class ORQAEvaluator(object):
     def __init__(self):
diff --git a/tasks/orqa/natural_questions/nq.py b/tasks/orqa/unsupervised/nq.py
similarity index 100%
rename from tasks/orqa/natural_questions/nq.py
rename to tasks/orqa/unsupervised/nq.py
diff --git a/tasks/orqa/natural_questions/qa_utils.py b/tasks/orqa/unsupervised/qa_utils.py
similarity index 98%
rename from tasks/orqa/natural_questions/qa_utils.py
rename to tasks/orqa/unsupervised/qa_utils.py
index 24e71e683a4..811a05834a4 100644
--- a/tasks/orqa/natural_questions/qa_utils.py
+++ b/tasks/orqa/unsupervised/qa_utils.py
@@ -22,7 +22,7 @@
 from typing import Tuple, List, Dict
 
 import regex as re
-from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer
+from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer
 
 logger = logging.getLogger(__name__)
 
diff --git a/tasks/orqa/natural_questions/tokenizers.py b/tasks/orqa/unsupervised/tokenizers.py
similarity index 100%
rename from tasks/orqa/natural_questions/tokenizers.py
rename to tasks/orqa/unsupervised/tokenizers.py

From 113c63623050267e2d5387334042ccc3f008cc9c Mon Sep 17 00:00:00 2001
From: Jared Casper 
Date: Tue, 18 May 2021 23:18:32 -0700
Subject: [PATCH 27/59] Add temporary assert to finetuning until it can be
 fixed.

---
 tasks/finetune_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index 918417b41b2..e6c31e88f0a 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -19,7 +19,7 @@
 
 import torch
 
-from megatron import get_args
+from megatron import get_args, get_num_microbatches
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
@@ -154,6 +154,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     args = get_args()
     timers = get_timers()
 
+    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
+
     # Turn on training mode which enables dropout.
     for m in model:
         m.train()

From 7577931bcee062ce6213f1147fbb034fb35c307c Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 23:39:13 -0700
Subject: [PATCH 28/59] Fixed issues with ICT pretraining

---
 examples/create_embeddings.sh | 32 -----------------------
 pretrain_ict.py               | 49 +++++++++++++++++++++--------------
 tools/create_doc_index.py     | 33 -----------------------
 3 files changed, 29 insertions(+), 85 deletions(-)
 delete mode 100644 examples/create_embeddings.sh
 delete mode 100644 tools/create_doc_index.py

diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh
deleted file mode 100644
index 59a5839f7e2..00000000000
--- a/examples/create_embeddings.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/bash
-
-# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
-
-RANK=0
-WORLD_SIZE=1
-
-# Wikipedia data can be downloaded from the following link:
-# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
-EVIDENCE_DATA_DIR=
-EMBEDDING_PATH=
-CHECKPOINT_PATH=
-
-python tools/create_doc_index.py \
-    --num-layers 12 \
-    --hidden-size 768 \
-    --num-attention-heads 12 \
-    --tensor-model-parallel-size 1 \
-    --micro-batch-size 128 \
-    --checkpoint-activations \
-    --seq-length 512 \
-    --retriever-seq-length 256 \
-    --max-position-embeddings 512 \
-    --load ${CHECKPOINT_PATH} \
-    --evidence-data-path ${EVIDENCE_DATA_DIR} \
-    --embedding-path ${EMBEDDING_PATH} \
-    --indexer-log-interval 1000 \
-    --indexer-batch-size 128 \
-    --vocab-file bert-vocab.txt \
-    --num-workers 2 \
-    --fp16
-
diff --git a/pretrain_ict.py b/pretrain_ict.py
index 220e2725480..79759250f18 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 """Pretrain BERT for Inverse Cloze Task"""
+
+from functools import partial
 import math
 
 import torch
@@ -31,14 +33,15 @@
 from megatron.utils import average_losses_across_data_parallel_group
 
 
-def pretrain_ict_model_provider():
+def pretrain_ict_model_provider(pre_process=True, post_process=True):
     args = get_args()
 
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
                 biencoder_shared_query_context_model=\
-                args.biencoder_shared_query_context_model)
+                args.biencoder_shared_query_context_model,
+                pre_process=pre_process, post_process=post_process)
 
     return model
 
@@ -79,25 +82,9 @@ def backward(ctx, grad_output):
         output = output_list[rank].contiguous()
         return output
 
-def forward_step(data_iterator, model, input_tensor):
-    """Forward step."""
+def loss_func(output_tensor):
     args = get_args()
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator').start()
-    query_tokens, query_mask, \
-    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
-    timers('batch-generator').stop()
-
-    # Query and Context Types
-    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
-    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
-
-    # Forward model.
-    query_logits, context_logits = model(query_tokens, query_mask,
-                                    query_types, context_tokens,
-                                    context_mask, context_types)
+    query_logits, context_logits = output_tensor
 
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
@@ -139,6 +126,28 @@ def topk_accuracy(k):
     return loss, stats_dict
 
 
+
+def forward_step(data_iterator, model):
+    """Forward step."""
+    args = get_args()
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator').start()
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+    # Forward model.
+    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
+                        context_mask, context_types)
+
+    return output_tensor, partial(loss_func)
+
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid and test datasets."""
     args = get_args()
diff --git a/tools/create_doc_index.py b/tools/create_doc_index.py
deleted file mode 100644
index 4448d0e2908..00000000000
--- a/tools/create_doc_index.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
-                                             os.path.pardir)))
-
-from megatron import print_rank_0
-from megatron.indexer import IndexBuilder
-from megatron.initialize import initialize_megatron
-
-
-def main():
-    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset
-    - Include all args needed for initial model specification
-
-    Other key args:
-        --block-data-path: path to write to
-        --ict-load or --realm-load: path to checkpoint with which to embed
-        --data-path and --titles-data-path: paths for dataset
-        --indexer-log-interval: reporting interval
-        --indexer-batch-size: size specific for indexer jobs
-
-    Check README.md for example script
-    """
-
-    initialize_megatron(extra_args_provider=None,
-                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = IndexBuilder()
-    index_builder.build_and_save_index()
-    print_rank_0("Build and save indices: done!")
-
-if __name__ == "__main__":
-    main()
-

From dfb6a9b3753d7f749d50f96a07c29ba681aafafc Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 23:49:25 -0700
Subject: [PATCH 29/59] updated the evaluation script for retriver

---
 ...aluate_ict_zeroshot_nq.sh => evaluate_retriever_nq.sh} | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)
 rename examples/{evaluate_ict_zeroshot_nq.sh => evaluate_retriever_nq.sh} (81%)

diff --git a/examples/evaluate_ict_zeroshot_nq.sh b/examples/evaluate_retriever_nq.sh
similarity index 81%
rename from examples/evaluate_ict_zeroshot_nq.sh
rename to examples/evaluate_retriever_nq.sh
index e1ce45a9342..8519c3166da 100644
--- a/examples/evaluate_ict_zeroshot_nq.sh
+++ b/examples/evaluate_retriever_nq.sh
@@ -1,19 +1,19 @@
 #!/bin/bash
 
 # Evaluate natural question test data given Wikipedia embeddings and pretrained
-# ICT model
+# ICT model or a finetuned model for Natural Question task
 
 # Datasets can be downloaded from the following link:
 # https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
 
 EVIDENCE_DATA_DIR=
 EMBEDDING_PATH=
-CHECKPOINT_PATH=
+CHECKPOINT_PATH=
 
-QA_FILE=
+QA_FILE=
 
 python tasks/main.py \
-    --task ICT-ZEROSHOT-NQ \
+    --task RETRIEVER-EVAL \
     --tokenizer-type BertWordPieceLowerCase \
     --num-layers 12 \
     --hidden-size 768 \

From f21a6628cd6f887c19e757cb62cc90bacdc8e0d7 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 23:50:28 -0700
Subject: [PATCH 30/59] updated the evaluation script for retriver

---
 megatron/data/biencoder_dataset_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/megatron/data/biencoder_dataset_utils.py b/megatron/data/biencoder_dataset_utils.py
index dccf060f62a..f7b3b961b8c 100644
--- a/megatron/data/biencoder_dataset_utils.py
+++ b/megatron/data/biencoder_dataset_utils.py
@@ -20,7 +20,6 @@ def make_attention_mask(source_block, target_block):
     # (source_length, target_length)
     return mask
 
-
 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""
     args = get_args()

From a41e47812057169cd8eda1f20ea055c319db8d38 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 23:51:54 -0700
Subject: [PATCH 31/59] updated the evaluation script for retriver

---
 megatron/model/biencoder_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index 6478c06bd0c..0573dbef0e4 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -26,7 +26,7 @@ def model_provider(pre_process=True, post_process=True):
                 only_context_model = only_context_model,
                 biencoder_shared_query_context_model = \
                 biencoder_shared_query_context_model,
-                pre_process=True, post_process=True)
+                pre_process=pre_process, post_process=post_process)
 
         return model
 

From 825375cf3895410f6c63bff2ab07f3debe9336ea Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 18 May 2021 23:58:30 -0700
Subject: [PATCH 32/59] updated the evaluation script for retriver

---
 tasks/main.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/tasks/main.py b/tasks/main.py
index 59b377a4ba7..6d8fc8f5fd6 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -77,18 +77,12 @@ def get_tasks_args(parser):
 
     # parameters for Av.rank validation method
     # Following options/arguments have been taken directly from DPR codebase
-    #group.add_argument("--val-av-rank-start-epoch", type=int, default=10000,
-    #                    help="Av.rank validation: the epoch from which to enable this validation")
     group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
                         help='Av.rank validation: how many hard negatives to'
                         ' take from each question pool')
     group.add_argument('--val-av-rank-other-neg', type=int, default=30,
                         help='Av.rank validation: how many other negatives to'
                         ' take from each question pool')
-    #group.add_argument("--val-av-rank-bsz", type=int, default=128,
-    #                    help="Av.rank validation: batch size to process passages")
-    #group.add_argument("--val-av-rank-max-qs", type=int, default=10000,
-    #                    help="Av.rank validation: max num of questions")
 
 
     return parser

From d078e54ab6142f9bab04c4edb7804d5e417f3746 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 19 May 2021 17:57:08 -0700
Subject: [PATCH 33/59] added exit interval for finetuning

---
 tasks/finetune_utils.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index cee4e1ee83c..a6b4c9fde80 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -16,7 +16,7 @@
 """Finetune utilities."""
 
 from functools import partial
-
+import sys
 import torch
 
 from megatron import get_args
@@ -215,9 +215,11 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                                   optimizer, lr_scheduler)
 
             # Checkpointing
+            saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                saved_checkpoint = True
 
             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,6 +228,14 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)
 
+            # Exiting based on iterations
+            if args.exit_interval and iteration % args.exit_interval == 0:
+                if not saved_checkpoint:
+                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                torch.distributed.barrier()
+                print_rank_0('exiting program at iteration {}'.format(iteration))
+                sys.exit()
+
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)

From 63121a9e0f5b85dcff046fb2f918557f0c885594 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 19 May 2021 23:10:16 -0700
Subject: [PATCH 34/59] updating the scripts

---
 examples/evaluate_retriever_nq.sh          |  5 +-
 examples/finetune_retriever_distributed.sh | 56 ++++++++++++++++++++++
 megatron/model/biencoder_model.py          |  2 +-
 3 files changed, 61 insertions(+), 2 deletions(-)
 create mode 100755 examples/finetune_retriever_distributed.sh

diff --git a/examples/evaluate_retriever_nq.sh b/examples/evaluate_retriever_nq.sh
index 8519c3166da..8191af8476a 100644
--- a/examples/evaluate_retriever_nq.sh
+++ b/examples/evaluate_retriever_nq.sh
@@ -32,5 +32,8 @@ python tasks/main.py \
     --num-workers 2 \
     --faiss-use-gpu \
     --retriever-report-topk-accuracies 1 5 20 100 \
-    --fp16
+    --fp16 \
+    --indexer-log-interval 1000 \
+    --indexer-batch-size 128
+
 
diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh
new file mode 100755
index 00000000000..6592ed51b64
--- /dev/null
+++ b/examples/finetune_retriever_distributed.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+# Finetune a BERT or pretrained ICT model using Google natural question data 
+# Datasets can be downloaded from the following link:
+# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
+
+WORLD_SIZE=8
+
+DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
+                  --nnodes 1 \
+                  --node_rank 0 \
+                  --master_addr localhost \
+                  --master_port 6000"
+
+CHECKPOINT_PATH=
+
+# Load either of the below
+BERT_LOAD_PATH=
+PRETRAINED_CHECKPOINT=
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
+        --task RET-FINETUNE-NQ \
+        --train-with-neg \
+        --train-hard-neg 1 \
+        --pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
+        --num-layers 12 \
+        --hidden-size 768 \
+        --num-attention-heads 12 \
+        --tensor-model-parallel-size 1 \
+        --tokenizer-type BertWordPieceLowerCase \
+        --train-data nq-train.json \
+        --valid-data nq-dev.json \
+        --save ${CHECKPOINT_PATH} \
+        --load ${CHECKPOINT_PATH} \
+        --vocab-file bert-vocab.txt \
+        --bert-load ${BERT_LOAD_PATH} \
+        --save-interval 5000 \
+        --log-interval 10 \
+        --eval-interval 25000 \
+        --eval-iters 100 \
+        --indexer-log-interval 1000 \
+        --faiss-use-gpu \
+        --DDP-impl torch \
+        --fp16 \
+        --retriever-report-topk-accuracies 1 5 10 20 100 \
+        --seq-length 512 \
+        --retriever-seq-length 256 \
+        --max-position-embeddings 512 \
+        --retriever-score-scaling \
+        --epochs 80 \
+        --micro-batch-size 8 \
+        --eval-micro-batch-size 16 \
+        --indexer-batch-size 128 \
+        --lr 2e-5 \
+        --lr-warmup-fraction 0.01 \
+        --weight-decay 1e-1
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index 0573dbef0e4..e1f94bf1c43 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -320,7 +320,7 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
 
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
-        print_rank_0("loading BERT weights")
+        print_rank_0("loading pretrained weights")
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
 

From fda81a212b7370c7ea0252ac1d4352244abb2c78 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 25 May 2021 13:08:54 -0700
Subject: [PATCH 35/59] updating no load rng

---
 tasks/finetune_utils.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index a6b4c9fde80..fc8380a5346 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -285,8 +285,11 @@ def finetune(train_valid_datasets_provider, model_provider,
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         original_load = args.load
         args.load = args.pretrained_checkpoint
+        original_rng = args.no_load_rng
+        args.no_load_rng = True
         _ = load_checkpoint(model, None, None)
         args.load = original_load
+        args.no_load_rng = original_rng
         # This is critical when only model is loaded. We should make sure
         # main parameters are also updated.
         optimizer.reload_model_params()

From c7c65bbb69a683d14750d09c57724aad540617ea Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Thu, 3 Jun 2021 09:51:54 -0700
Subject: [PATCH 36/59] updating script

---
 examples/evaluate_retriever_nq.sh          | 1 -
 examples/finetune_retriever_distributed.sh | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/evaluate_retriever_nq.sh b/examples/evaluate_retriever_nq.sh
index 8191af8476a..8b87be3024a 100644
--- a/examples/evaluate_retriever_nq.sh
+++ b/examples/evaluate_retriever_nq.sh
@@ -29,7 +29,6 @@ python tasks/main.py \
     --retriever-seq-length 256 \
     --vocab-file  bert-vocab.txt\
     --qa-data-test ${QA_FILE} \
-    --num-workers 2 \
     --faiss-use-gpu \
     --retriever-report-topk-accuracies 1 5 20 100 \
     --fp16 \
diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh
index 6592ed51b64..592427bf4e3 100755
--- a/examples/finetune_retriever_distributed.sh
+++ b/examples/finetune_retriever_distributed.sh
@@ -36,7 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
         --bert-load ${BERT_LOAD_PATH} \
         --save-interval 5000 \
         --log-interval 10 \
-        --eval-interval 25000 \
+        --eval-interval 250000 \
         --eval-iters 100 \
         --indexer-log-interval 1000 \
         --faiss-use-gpu \

From 3dadd16d38e9e3089caddba193fb1317ffa338c6 Mon Sep 17 00:00:00 2001
From: Deepak Narayanan 
Date: Mon, 7 Jun 2021 18:42:56 +0000
Subject: [PATCH 37/59] Update T5 scripts

---
 examples/pretrain_t5.sh                     | 5 +++--
 examples/pretrain_t5_distributed.sh         | 5 +++--
 examples/pretrain_t5_distributed_with_mp.sh | 5 +++--
 3 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh
index 71fea8489a7..91fd5929bf6 100644
--- a/examples/pretrain_t5.sh
+++ b/examples/pretrain_t5.sh
@@ -15,7 +15,7 @@ python pretrain_t5.py \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 2048 \
+       --global-batch-size 16 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -35,4 +35,5 @@ python pretrain_t5.py \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16
+       --fp16 \
+       --vocab-extra-ids 100
diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh
index 778b4ad2a36..2beb1cdaca6 100644
--- a/examples/pretrain_t5_distributed.sh
+++ b/examples/pretrain_t5_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 2048 \
+       --global-batch-size 128 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -44,4 +44,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16
+       --fp16 \
+       --vocab-extra-ids 100
diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh
index 9be70393d86..0012ac9a55d 100644
--- a/examples/pretrain_t5_distributed_with_mp.sh
+++ b/examples/pretrain_t5_distributed_with_mp.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 2048 \
+       --global-batch-size 128 \
        --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
@@ -45,4 +45,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16
+       --fp16  \
+       --vocab-extra-ids 100

From 04c79f306d17ec9bd1b8445f46d8b1dbcff8f272 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Tue, 8 Jun 2021 11:49:55 -0700
Subject: [PATCH 38/59] resolved hang issue

---
 tasks/orqa/supervised/finetune.py | 41 ++++++++++++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py
index d6db0362afc..6edc4b888db 100644
--- a/tasks/orqa/supervised/finetune.py
+++ b/tasks/orqa/supervised/finetune.py
@@ -47,6 +47,8 @@ def cross_entropy_forward_step(batch, model):
         except BaseException:
             batch_ = batch
 
+        group, rank, world_size = get_group_world_size_rank()
+
         query_tokens, query_mask, query_types, query_pad_mask, \
         context_tokens, context_mask, context_types, context_pad_mask, \
         neg_context_tokens, neg_context_mask, neg_context_types, \
@@ -54,6 +56,7 @@ def cross_entropy_forward_step(batch, model):
 
         timers('batch generator').stop()
         local_batch_size = query_tokens.shape[0]
+        #print("rank {} query_tokens {} context_tokens {} batch {} neg_context_tokens {}".format(rank, query_tokens.size(), context_tokens.size(), local_batch_size, neg_context_tokens.size()), flush=True)
 
         # Text representation of query and context
         query_list, context_list = [], []
@@ -61,16 +64,49 @@ def cross_entropy_forward_step(batch, model):
             query_list.append(tokenizer.decode(query_tokens[i].tolist()))
             context_list.append(tokenizer.decode(context_tokens[i].tolist()))
 
+        if neg_context_tokens.size()[0] > 200:
+            current_length = neg_context_tokens.size()[0]
+            first_dim = torch.tensor([[neg_context_tokens.size()[0]]], device=torch.cuda.current_device())
+            neg_context_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+            neg_context_list[rank].copy_(first_dim)
+            torch.distributed.all_gather(neg_context_list, first_dim, group=group)
+            all_neg_context_list = torch.cat(neg_context_list, dim=0).contiguous()
+            max_length = torch.max(all_neg_context_list)
+            torch.set_printoptions(profile="full")
+
+            if max_length > current_length:
+                print("rank {} before pad neg_context_tokens {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
+            neg_context_tokens = torch.nn.functional.pad(input=neg_context_tokens, pad=(0, 0, 0, max_length - neg_context_tokens.size()[0]))
+
+            input_ = torch.empty_like(neg_context_tokens).copy_(\
+                neg_context_tokens).detach_()
+            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+            tensor_list[rank].copy_(input_)
+            torch.distributed.all_gather(tensor_list, input_, group=group)
+
+            if max_length > current_length:
+                print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
+                print("rank {} after pad neg_context_tokens current_length {}".format(rank, neg_context_tokens[current_length]), flush=True)
+                print("rank {} after pad neg_context_tokens max_length-1 {}".format(rank, neg_context_tokens[max_length-1]), flush=True)
+
+            if rank == 0:
+                print("rank {} other pad neg_context_tokens current_length-1 {}".format(rank, tensor_list[5][451]), flush=True)
+                print("rank {} other pad neg_context_tokens current_length {}".format(rank, tensor_list[5][452]), flush=True)
+                print("rank {} other pad neg_context_tokens max_length-1 {}".format(rank, tensor_list[5][max_length-1]), flush=True)
+
+            torch.set_printoptions(profile="default")
+            exit()
+
         if neg_context_tokens is not None:
             context_tokens = torch.cat([context_tokens, neg_context_tokens])
             context_mask = torch.cat([context_mask, neg_context_mask])
             context_types = torch.cat([context_types, neg_context_types])
 
+        #print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)
         # Forward model.
         output_tensor = model(query_tokens, query_mask,
                                         query_types, context_tokens,
                                         context_mask, context_types)
-
         return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
 
 
@@ -85,10 +121,13 @@ def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
         query_logits, context_logits = output_tensor
 
         if world_size > 1:
+            #print("rank {} query_logits {} context_logits {}".format(rank, query_logits.size(), context_logits.size()))
             input_ = torch.empty_like(context_logits).copy_(\
                 context_logits).detach_()
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
             tensor_list[rank].copy_(input_)
+            #print_rank_0("At cross_entropy_loss_func")
+            #print("rank {} input_ {}".format(rank, input_.size()))
             torch.distributed.all_gather(tensor_list, input_, group=group)
 
             # Check if all-gather happens in order

From ebfbfcec9d5d62df804fe3af75b4006e5ae34fde Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 9 Jun 2021 01:09:23 -0700
Subject: [PATCH 39/59] fixed the tensor size miss-mass issue

---
 tasks/orqa/supervised/finetune.py | 89 +++++++++++++++++++------------
 1 file changed, 56 insertions(+), 33 deletions(-)

diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py
index 6edc4b888db..8f2b505fec7 100644
--- a/tasks/orqa/supervised/finetune.py
+++ b/tasks/orqa/supervised/finetune.py
@@ -33,6 +33,44 @@
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
 
+# input_ is a 2D tensor
+def check_and_append_tensor_for_gather(group, rank, world_size, input_):
+
+    # gather the size of the first dimension of the tensor from all ranks
+    current_length = input_.size()[0]
+    first_dim = torch.tensor([[current_length]], 
+        device=torch.cuda.current_device())
+    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+    input_list[rank].copy_(first_dim)
+    torch.distributed.all_gather(input_list, first_dim, group=group)
+    all_input_list = torch.cat(input_list, dim=0).contiguous()
+    max_length = torch.max(all_input_list)
+    min_length = torch.min(all_input_list)
+
+    #if rank == 0:
+    #    print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
+    #    print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+
+    if max_length > current_length:
+        #print("rank {} before pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
+        #torch.set_printoptions(profile="full")
+        
+        #input_ = torch.nn.functional.pad(input=input_, 
+        #    pad=(0, 0, 0, max_length - current_length))
+        padding=tuple([0] * (input_.dim() * 2 - 1)) + \
+            tuple([max_length - current_length])
+        input_ = F.pad(input=input_, pad=padding)
+
+        #print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
+        #print("rank {} after pad neg_context_tokens current_length {}".format(rank, input_[current_length]), flush=True)
+        #print("rank {} after pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+
+    #if rank == 0:
+    #    print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
+    #    print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+        
+    return input_
+
 def orqa(Dataset):
 
     def cross_entropy_forward_step(batch, model):
@@ -56,7 +94,6 @@ def cross_entropy_forward_step(batch, model):
 
         timers('batch generator').stop()
         local_batch_size = query_tokens.shape[0]
-        #print("rank {} query_tokens {} context_tokens {} batch {} neg_context_tokens {}".format(rank, query_tokens.size(), context_tokens.size(), local_batch_size, neg_context_tokens.size()), flush=True)
 
         # Text representation of query and context
         query_list, context_list = [], []
@@ -64,44 +101,30 @@ def cross_entropy_forward_step(batch, model):
             query_list.append(tokenizer.decode(query_tokens[i].tolist()))
             context_list.append(tokenizer.decode(context_tokens[i].tolist()))
 
-        if neg_context_tokens.size()[0] > 200:
-            current_length = neg_context_tokens.size()[0]
-            first_dim = torch.tensor([[neg_context_tokens.size()[0]]], device=torch.cuda.current_device())
-            neg_context_list = [torch.empty_like(first_dim) for _ in range(world_size)]
-            neg_context_list[rank].copy_(first_dim)
-            torch.distributed.all_gather(neg_context_list, first_dim, group=group)
-            all_neg_context_list = torch.cat(neg_context_list, dim=0).contiguous()
-            max_length = torch.max(all_neg_context_list)
-            torch.set_printoptions(profile="full")
-
-            if max_length > current_length:
-                print("rank {} before pad neg_context_tokens {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
-            neg_context_tokens = torch.nn.functional.pad(input=neg_context_tokens, pad=(0, 0, 0, max_length - neg_context_tokens.size()[0]))
-
-            input_ = torch.empty_like(neg_context_tokens).copy_(\
-                neg_context_tokens).detach_()
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            tensor_list[rank].copy_(input_)
-            torch.distributed.all_gather(tensor_list, input_, group=group)
-
-            if max_length > current_length:
-                print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
-                print("rank {} after pad neg_context_tokens current_length {}".format(rank, neg_context_tokens[current_length]), flush=True)
-                print("rank {} after pad neg_context_tokens max_length-1 {}".format(rank, neg_context_tokens[max_length-1]), flush=True)
-
-            if rank == 0:
-                print("rank {} other pad neg_context_tokens current_length-1 {}".format(rank, tensor_list[5][451]), flush=True)
-                print("rank {} other pad neg_context_tokens current_length {}".format(rank, tensor_list[5][452]), flush=True)
-                print("rank {} other pad neg_context_tokens max_length-1 {}".format(rank, tensor_list[5][max_length-1]), flush=True)
-
-            torch.set_printoptions(profile="default")
-            exit()
+        #if rank == 5:
+        #    print("rank {} before query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(), 
+        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
+ 
+        if neg_context_tokens is not None: # and neg_context_tokens.size()[0] > local_batch_size:
+            neg_context_tokens = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_tokens)
+            neg_context_mask = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_mask)
+            neg_context_types = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_types)
+            #exit()
+
+        #if rank == 5:
+        #    print("rank {} middle query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(), 
+        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
+ 
 
         if neg_context_tokens is not None:
             context_tokens = torch.cat([context_tokens, neg_context_tokens])
             context_mask = torch.cat([context_mask, neg_context_mask])
             context_types = torch.cat([context_types, neg_context_types])
 
+        #if rank == 5:
+        #    print("rank {} after query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {}".format(rank, query_tokens.size(), query_mask.size(), 
+        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size()), flush=True)
+
         #print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)
         # Forward model.
         output_tensor = model(query_tokens, query_mask,

From e46f3260639355254e34975c531b0d628bfc9583 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 9 Jun 2021 22:45:05 -0700
Subject: [PATCH 40/59] fixed the evaluation hangs

---
 tasks/orqa/supervised/finetune.py | 48 ++++++-------------------------
 1 file changed, 9 insertions(+), 39 deletions(-)

diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py
index 8f2b505fec7..50ca85fc810 100644
--- a/tasks/orqa/supervised/finetune.py
+++ b/tasks/orqa/supervised/finetune.py
@@ -47,28 +47,13 @@ def check_and_append_tensor_for_gather(group, rank, world_size, input_):
     max_length = torch.max(all_input_list)
     min_length = torch.min(all_input_list)
 
-    #if rank == 0:
-    #    print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
-    #    print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
-
+    # if the size are different than the max, extend the tensor
+    # accordingly
     if max_length > current_length:
-        #print("rank {} before pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
-        #torch.set_printoptions(profile="full")
-        
-        #input_ = torch.nn.functional.pad(input=input_, 
-        #    pad=(0, 0, 0, max_length - current_length))
         padding=tuple([0] * (input_.dim() * 2 - 1)) + \
             tuple([max_length - current_length])
         input_ = F.pad(input=input_, pad=padding)
 
-        #print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
-        #print("rank {} after pad neg_context_tokens current_length {}".format(rank, input_[current_length]), flush=True)
-        #print("rank {} after pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
-
-    #if rank == 0:
-    #    print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
-    #    print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
-        
     return input_
 
 def orqa(Dataset):
@@ -101,31 +86,19 @@ def cross_entropy_forward_step(batch, model):
             query_list.append(tokenizer.decode(query_tokens[i].tolist()))
             context_list.append(tokenizer.decode(context_tokens[i].tolist()))
 
-        #if rank == 5:
-        #    print("rank {} before query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(), 
-        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
- 
-        if neg_context_tokens is not None: # and neg_context_tokens.size()[0] > local_batch_size:
-            neg_context_tokens = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_tokens)
-            neg_context_mask = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_mask)
-            neg_context_types = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_types)
-            #exit()
-
-        #if rank == 5:
-        #    print("rank {} middle query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(), 
-        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
- 
+        if neg_context_tokens is not None:
+            neg_context_tokens = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_tokens)
+            neg_context_mask = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_mask)
+            neg_context_types = check_and_append_tensor_for_gather(group,
+                rank, world_size, neg_context_types)
 
         if neg_context_tokens is not None:
             context_tokens = torch.cat([context_tokens, neg_context_tokens])
             context_mask = torch.cat([context_mask, neg_context_mask])
             context_types = torch.cat([context_types, neg_context_types])
 
-        #if rank == 5:
-        #    print("rank {} after query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {}".format(rank, query_tokens.size(), query_mask.size(), 
-        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size()), flush=True)
-
-        #print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)
         # Forward model.
         output_tensor = model(query_tokens, query_mask,
                                         query_types, context_tokens,
@@ -144,13 +117,10 @@ def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
         query_logits, context_logits = output_tensor
 
         if world_size > 1:
-            #print("rank {} query_logits {} context_logits {}".format(rank, query_logits.size(), context_logits.size()))
             input_ = torch.empty_like(context_logits).copy_(\
                 context_logits).detach_()
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
             tensor_list[rank].copy_(input_)
-            #print_rank_0("At cross_entropy_loss_func")
-            #print("rank {} input_ {}".format(rank, input_.size()))
             torch.distributed.all_gather(tensor_list, input_, group=group)
 
             # Check if all-gather happens in order

From a983cab331c3cbb937eaf1f0679c647ed942ede2 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 9 Jun 2021 22:51:21 -0700
Subject: [PATCH 41/59] Adding readme

---
 tasks/orqa/README.md | 57 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 tasks/orqa/README.md

diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
new file mode 100644
index 00000000000..dd3d75cc542
--- /dev/null
+++ b/tasks/orqa/README.md
@@ -0,0 +1,57 @@
+The following steps show how to run unsupervised and supervised trainining and evaluation for retriever for open domain question answering.
+
+
+## REALM Pipeline
+The following sections (will) reflect the three stages of training a REALM system. For now it's just the ICT code.
+Loosely, they are pretraining the retriever modules, then jointly training the language model and the retriever, and then finetuning a question answering head on the language model with fixed retriever.
+
+### Inverse Cloze Task (ICT) Pretraining
+1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document.
+Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body.
+Refer to the following script
+
+python preprocess_data.py \
+    --input /path/to/corpus.json \
+    --json-keys text title \
+    --split-sentences \
+    --tokenizer-type BertWordPieceLowerCase \
+    --vocab-file /path/to/vocab.txt \
+    --output-prefix corpus_indexed \
+    --workers 5  # works well for 10 CPU cores. Scale up accordingly.
+
+ +2. Use a custom samples mapping function in place of `megatron/data/realm_dataset_utils.get_block_samples_mapping` if required. To do this, you will need to implement a new function in C++ inside of `megatron/data/helpers.cpp`. The samples mapping data structure is used to select the data that will constitute every training sample in advance of the training loop. + The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block. +3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task. +In REALM, this is an uncased bert base model trained with the standard hyperparameters. +4. Use `pretrain_ict.py` to train an `ICTBertModel` which uses two BERT-based encoders to encode queries and blocks to perform retrieval with. +The script below trains the ICT model from REALM. It refrences a pretrained BERT model (step 3) in the `--bert-load` argument. The batch size used in the paper is 4096, so this would need to be run with data parallel world size 32. +
+python pretrain_ict.py \
+    --num-layers 12 \
+    --num-attention-heads 12 \
+    --hidden-size 768 \
+    --batch-size 128 \
+    --seq-length 256 \
+    --max-position-embeddings 256 \
+    --ict-head-size 128 \
+    --train-iters 100000 \
+    --checkpoint-activations \
+    --bert-load /path/to/pretrained_bert \
+    --load checkpoints \
+    --save checkpoints \
+    --data-path /path/to/indexed_dataset \
+    --titles-data-path /path/to/titles_indexed_dataset \
+    --vocab-file /path/to/vocab.txt \
+    --lr 0.0001 \
+    --num-workers 2 \
+    --lr-decay-style linear \
+    --weight-decay 1e-2 \
+    --clip-grad 1.0 \
+    --warmup .01 \
+    --save-interval 3000 \
+    --query-in-block-prob 0.1 \
+    --fp16
+
+
+ From d562d7b5fddb7bf7e6140d35d05b2b4ea331f6e0 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 9 Jun 2021 23:17:35 -0700 Subject: [PATCH 42/59] Adding readme --- tasks/orqa/README.md | 89 +++++++++++++++++++++++--------------------- 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md index dd3d75cc542..3d2f0216d9f 100644 --- a/tasks/orqa/README.md +++ b/tasks/orqa/README.md @@ -1,57 +1,62 @@ -The following steps show how to run unsupervised and supervised trainining and evaluation for retriever for open domain question answering. +We present below the steps on show how to run unsupervised and supervised trainining and evaluation for retriever for [open domain question answering](https://arxiv.org/abs/2101.00408). - -## REALM Pipeline -The following sections (will) reflect the three stages of training a REALM system. For now it's just the ICT code. -Loosely, they are pretraining the retriever modules, then jointly training the language model and the retriever, and then finetuning a question answering head on the language model with fixed retriever. +## End-to-End Training of Neural Retrievers for Open-Domain Question Answering + +We use two stages for retriever pretraining and finetuning, (i) unsupervised pretraining, and (ii) supervised finetuning. + +### Unsupervised pretraining +1. We use the following to preprocess dataset for Inverse Cloze Task (ICT) task, we call unsupervised pretraining. Having a corpus in loose JSON format with the intension of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document, and another with the body. -### Inverse Cloze Task (ICT) Pretraining -1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. -Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body. -Refer to the following script
-python preprocess_data.py \
+python tools/preprocess_data.py \
     --input /path/to/corpus.json \
     --json-keys text title \
     --split-sentences \
     --tokenizer-type BertWordPieceLowerCase \
     --vocab-file /path/to/vocab.txt \
     --output-prefix corpus_indexed \
-    --workers 5  # works well for 10 CPU cores. Scale up accordingly.
+    --workers 10
 
-2. Use a custom samples mapping function in place of `megatron/data/realm_dataset_utils.get_block_samples_mapping` if required. To do this, you will need to implement a new function in C++ inside of `megatron/data/helpers.cpp`. The samples mapping data structure is used to select the data that will constitute every training sample in advance of the training loop. - The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block. -3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task. -In REALM, this is an uncased bert base model trained with the standard hyperparameters. -4. Use `pretrain_ict.py` to train an `ICTBertModel` which uses two BERT-based encoders to encode queries and blocks to perform retrieval with. -The script below trains the ICT model from REALM. It refrences a pretrained BERT model (step 3) in the `--bert-load` argument. The batch size used in the paper is 4096, so this would need to be run with data parallel world size 32. +2. The `examples/pretrain_ict.sh` script runs single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses pretrained BERT model with batch size of 4096 (hence need data parallel world size of 32). +
-python pretrain_ict.py \
-    --num-layers 12 \
-    --num-attention-heads 12 \
-    --hidden-size 768 \
-    --batch-size 128 \
-    --seq-length 256 \
-    --max-position-embeddings 256 \
-    --ict-head-size 128 \
-    --train-iters 100000 \
-    --checkpoint-activations \
-    --bert-load /path/to/pretrained_bert \
-    --load checkpoints \
-    --save checkpoints \
-    --data-path /path/to/indexed_dataset \
-    --titles-data-path /path/to/titles_indexed_dataset \
-    --vocab-file /path/to/vocab.txt \
-    --lr 0.0001 \
-    --num-workers 2 \
-    --lr-decay-style linear \
-    --weight-decay 1e-2 \
-    --clip-grad 1.0 \
-    --warmup .01 \
-    --save-interval 3000 \
-    --query-in-block-prob 0.1 \
-    --fp16
 
+PRETRAINED_BERT_PATH=
+TEXT_DATA_PATH=
+TITLE_DATA_PATH=
+CHECKPOINT_PATH=
+
+python pretrain_ict.py \
+        --num-layers 12 \
+        --hidden-size 768 \
+        --num-attention-heads 12 \
+        --tensor-model-parallel-size 1 \
+        --micro-batch-size 32 \
+        --seq-length 256 \
+        --max-position-embeddings 512 \
+        --train-iters 100000 \
+        --vocab-file bert-vocab.txt \
+        --tokenizer-type BertWordPieceLowerCase \
+        --DDP-impl torch \
+        --bert-load ${PRETRAINED_BERT_PATH} \
+        --log-interval 100 \
+        --eval-interval 1000 \
+        --eval-iters 10 \
+        --retriever-report-topk-accuracies 1 5 10 20 100 \
+        --retriever-score-scaling \
+        --load $CHECKPOINT_PATH \
+        --save $CHECKPOINT_PATH \
+        --data-path ${TEXT_DATA_PATH} \
+        --titles-data-path ${TITLE_DATA_PATH} \
+        --lr 0.0001 \
+        --lr-decay-style linear \
+        --weight-decay 1e-2 \
+        --clip-grad 1.0 \
+        --lr-warmup-fraction 0.01 \
+        --save-interval 4000 \
+        --exit-interval 8000 \
+        --query-in-block-prob 0.1 \
+        --fp16
 
From 1095d7e6fb171268e1ffc3f56ad7e63323ab420e Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 9 Jun 2021 23:19:12 -0700 Subject: [PATCH 43/59] Adding readme --- tasks/orqa/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md index 3d2f0216d9f..18c7b051757 100644 --- a/tasks/orqa/README.md +++ b/tasks/orqa/README.md @@ -22,10 +22,10 @@ python tools/preprocess_data.py \
 
-PRETRAINED_BERT_PATH=
-TEXT_DATA_PATH=
-TITLE_DATA_PATH=
-CHECKPOINT_PATH=
+PRETRAINED_BERT_PATH="Specify path of pretrained BERT model"
+TEXT_DATA_PATH="Specify path and file prefix of the text data"
+TITLE_DATA_PATH="Specify path and file prefix od the titles"
+CHECKPOINT_PATH="Specify path"
 
 python pretrain_ict.py \
         --num-layers 12 \

From bab5cc4e37e259c39497fe0a76bbdf69aa5e3e51 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 9 Jun 2021 23:38:30 -0700
Subject: [PATCH 44/59] Adding readme

---
 tasks/orqa/README.md | 49 ++++++++++----------------------------------
 1 file changed, 11 insertions(+), 38 deletions(-)

diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
index 18c7b051757..7dcc276c1fa 100644
--- a/tasks/orqa/README.md
+++ b/tasks/orqa/README.md
@@ -20,43 +20,16 @@ python tools/preprocess_data.py \
 
 2. The `examples/pretrain_ict.sh` script runs single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses pretrained BERT model with batch size of 4096 (hence need data parallel world size of 32).
 
-
+3. Evaluate the pretrained ICT model using `examples/evaluate_retriever_nq.sh` for natural question answering dataset.
 
-PRETRAINED_BERT_PATH="Specify path of pretrained BERT model"
-TEXT_DATA_PATH="Specify path and file prefix of the text data"
-TITLE_DATA_PATH="Specify path and file prefix od the titles"
-CHECKPOINT_PATH="Specify path"
-
-python pretrain_ict.py \
-        --num-layers 12 \
-        --hidden-size 768 \
-        --num-attention-heads 12 \
-        --tensor-model-parallel-size 1 \
-        --micro-batch-size 32 \
-        --seq-length 256 \
-        --max-position-embeddings 512 \
-        --train-iters 100000 \
-        --vocab-file bert-vocab.txt \
-        --tokenizer-type BertWordPieceLowerCase \
-        --DDP-impl torch \
-        --bert-load ${PRETRAINED_BERT_PATH} \
-        --log-interval 100 \
-        --eval-interval 1000 \
-        --eval-iters 10 \
-        --retriever-report-topk-accuracies 1 5 10 20 100 \
-        --retriever-score-scaling \
-        --load $CHECKPOINT_PATH \
-        --save $CHECKPOINT_PATH \
-        --data-path ${TEXT_DATA_PATH} \
-        --titles-data-path ${TITLE_DATA_PATH} \
-        --lr 0.0001 \
-        --lr-decay-style linear \
-        --weight-decay 1e-2 \
-        --clip-grad 1.0 \
-        --lr-warmup-fraction 0.01 \
-        --save-interval 4000 \
-        --exit-interval 8000 \
-        --query-in-block-prob 0.1 \
-        --fp16
-
+### Supervised finetuning + +1. We use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). We use the script `examples/finetune_retriever_distributed.sh` for this purpose. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples. + +2. We evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. + + +More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408). +The reader component will be available soon. + From 8661ca26ceb9f2e06dc3eaf08ed021838acd9edf Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 9 Jun 2021 23:47:41 -0700 Subject: [PATCH 45/59] Adding readme --- tasks/orqa/README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md index 7dcc276c1fa..dfcab731c50 100644 --- a/tasks/orqa/README.md +++ b/tasks/orqa/README.md @@ -1,11 +1,10 @@ -We present below the steps on show how to run unsupervised and supervised trainining and evaluation for retriever for [open domain question answering](https://arxiv.org/abs/2101.00408). ## End-to-End Training of Neural Retrievers for Open-Domain Question Answering -We use two stages for retriever pretraining and finetuning, (i) unsupervised pretraining, and (ii) supervised finetuning. +We present below the steps on show how to run unsupervised and supervised trainining and evaluation for retriever for [open domain question answering](https://arxiv.org/abs/2101.00408). ### Unsupervised pretraining -1. We use the following to preprocess dataset for Inverse Cloze Task (ICT) task, we call unsupervised pretraining. Having a corpus in loose JSON format with the intension of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document, and another with the body. +1. We use the following to preprocess dataset for Inverse Cloze Task (ICT) task, we call unsupervised pretraining. Having a corpus in loose JSON format with the intension of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document, and another with the body.
 python tools/preprocess_data.py \

From 293554aa352056a9fd1dc21a55a1faf3c7c980d1 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 9 Jun 2021 23:49:32 -0700
Subject: [PATCH 46/59] Adding readme

---
 tasks/orqa/README.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
index dfcab731c50..dc3352733f6 100644
--- a/tasks/orqa/README.md
+++ b/tasks/orqa/README.md
@@ -17,13 +17,13 @@ python tools/preprocess_data.py \
     --workers 10
 
-2. The `examples/pretrain_ict.sh` script runs single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses pretrained BERT model with batch size of 4096 (hence need data parallel world size of 32). +2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses pretrained BERT model with batch size of 4096 (hence need data parallel world size of 32). -3. Evaluate the pretrained ICT model using `examples/evaluate_retriever_nq.sh` for natural question answering dataset. +3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset. ### Supervised finetuning -1. We use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). We use the script `examples/finetune_retriever_distributed.sh` for this purpose. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples. +1. We use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). We use the script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) for this purpose. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples. 2. We evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. From e287bf0074e2a6038f5cd4360b687eedc086a068 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Wed, 9 Jun 2021 23:51:20 -0700 Subject: [PATCH 47/59] Adding readme --- tasks/orqa/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md index dc3352733f6..b5528dc768b 100644 --- a/tasks/orqa/README.md +++ b/tasks/orqa/README.md @@ -4,7 +4,7 @@ We present below the steps on show how to run unsupervised and supervised trainining and evaluation for retriever for [open domain question answering](https://arxiv.org/abs/2101.00408). ### Unsupervised pretraining -1. We use the following to preprocess dataset for Inverse Cloze Task (ICT) task, we call unsupervised pretraining. Having a corpus in loose JSON format with the intension of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document, and another with the body. +1. We use the following to preprocess dataset for Inverse Cloze Task (ICT) task, we call unsupervised pretraining. Having a corpus in loose JSON format with the intension of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. We run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document, and another with the body.
 python tools/preprocess_data.py \
@@ -19,7 +19,7 @@ python tools/preprocess_data.py \
 
 2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses pretrained BERT model with batch size of 4096 (hence need data parallel world size of 32).
 
-3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset.
+3. We evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset.
 
 ### Supervised finetuning
 

From c45109ed6059a66c32aeafa2c6c212fa37c5ec8c Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Wed, 9 Jun 2021 23:52:51 -0700
Subject: [PATCH 48/59] Adding readme

---
 examples/finetune_retriever_distributed.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh
index 592427bf4e3..535a2e053d4 100755
--- a/examples/finetune_retriever_distributed.sh
+++ b/examples/finetune_retriever_distributed.sh
@@ -36,7 +36,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
         --bert-load ${BERT_LOAD_PATH} \
         --save-interval 5000 \
         --log-interval 10 \
-        --eval-interval 250000 \
+        --eval-interval 20000 \
         --eval-iters 100 \
         --indexer-log-interval 1000 \
         --faiss-use-gpu \

From 473127f985cbfbc19f1ec7bf814d66768bfa241f Mon Sep 17 00:00:00 2001
From: Jared Casper 
Date: Thu, 10 Jun 2021 20:45:38 +0000
Subject: [PATCH 49/59] Clean up README.md a bit

---
 tasks/orqa/README.md | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
index b5528dc768b..e457e1f2890 100644
--- a/tasks/orqa/README.md
+++ b/tasks/orqa/README.md
@@ -1,10 +1,9 @@
-
 ## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
 
-We present below the steps on show how to run unsupervised and supervised trainining and evaluation for retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
+Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
 
 ### Unsupervised pretraining
-1. We use the following to preprocess dataset for Inverse Cloze Task (ICT) task, we call unsupervised pretraining. Having a corpus in loose JSON format with the intension of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. We run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document, and another with the body.
+1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT) task, which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
 
 
 python tools/preprocess_data.py \
@@ -17,16 +16,15 @@ python tools/preprocess_data.py \
     --workers 10
 
-2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses pretrained BERT model with batch size of 4096 (hence need data parallel world size of 32). +2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model with a batch size of 4096 (hence the need for a data parallel world size of 32). -3. We evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset. +3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset. ### Supervised finetuning -1. We use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). We use the script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) for this purpose. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples. - -2. We evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. +1. Use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to do this. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples. +2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408). From 28450473ba3e22f31d8ff66a08e8db8f54385ea0 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Thu, 10 Jun 2021 14:02:20 -0700 Subject: [PATCH 50/59] addressed comments --- tasks/orqa/supervised/finetune.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py index 50ca85fc810..aed65ac9791 100644 --- a/tasks/orqa/supervised/finetune.py +++ b/tasks/orqa/supervised/finetune.py @@ -45,7 +45,6 @@ def check_and_append_tensor_for_gather(group, rank, world_size, input_): torch.distributed.all_gather(input_list, first_dim, group=group) all_input_list = torch.cat(input_list, dim=0).contiguous() max_length = torch.max(all_input_list) - min_length = torch.min(all_input_list) # if the size are different than the max, extend the tensor # accordingly From 9d350c9c102d55b1b2116b51a2c01af290e97084 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Thu, 10 Jun 2021 15:10:01 -0700 Subject: [PATCH 51/59] updated readme --- tasks/orqa/README.md | 18 +++++++++++------- tasks/orqa/supervised/data.py | 3 +++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md index e457e1f2890..de53f9f633c 100644 --- a/tasks/orqa/README.md +++ b/tasks/orqa/README.md @@ -2,8 +2,10 @@ Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408). -### Unsupervised pretraining -1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT) task, which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body. +### Retriever Training + +#### Unsupervised pretraining by ICT +1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
 python tools/preprocess_data.py \
@@ -16,17 +18,19 @@ python tools/preprocess_data.py \
     --workers 10
 
-2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model with a batch size of 4096 (hence the need for a data parallel world size of 32). +2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model and we use a total of batch size of 4096 for the ICT training. -3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for natural question answering dataset. +3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf). -### Supervised finetuning +#### Supervised finetuning -1. Use the above pretrained ICT model to finetune using [Google's natural question answering dataset](https://ai.google.com/research/NaturalQuestions/). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to do this. Our finetuning consists of score scaling, longer training (80 epochs), and hard negative examples. +1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top [DPR training](https://arxiv.org/abs/2004.04906). 2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408). +### Reader Training + The reader component will be available soon. - + diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py index e2de454c92f..b45a842b61c 100644 --- a/tasks/orqa/supervised/data.py +++ b/tasks/orqa/supervised/data.py @@ -244,6 +244,9 @@ def normalize_question(question): question = question[:-1] return question +# The following class reads the datasets for training retriever as +# prepared by the DPR codebase (https://github.com/facebookresearch/DPR) + class NQSupervisedDataset(OpenRetrievalAbstractDataset): def __init__(self, name, datapaths, tokenizer, max_seq_length, \ From baf2e2a4232bafd2e228872ff4c5211f6de3dc29 Mon Sep 17 00:00:00 2001 From: Mostofa Patwary Date: Thu, 10 Jun 2021 15:29:04 -0700 Subject: [PATCH 52/59] updated readme --- tasks/orqa/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md index de53f9f633c..05960cb921f 100644 --- a/tasks/orqa/README.md +++ b/tasks/orqa/README.md @@ -4,7 +4,7 @@ Below we present the steps to run unsupervised and supervised trainining and eva ### Retriever Training -#### Unsupervised pretraining by ICT +##### Unsupervised pretraining by ICT 1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
@@ -22,7 +22,7 @@ python tools/preprocess_data.py \
 
 3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
 
-#### Supervised finetuning
+##### Supervised finetuning
 
 1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top [DPR training](https://arxiv.org/abs/2004.04906).
 

From 32da2e78cf4484acf1e1b865d37233a78234c6c0 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Thu, 10 Jun 2021 15:30:02 -0700
Subject: [PATCH 53/59] updated readme

---
 tasks/orqa/README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
index 05960cb921f..e8c36cf3270 100644
--- a/tasks/orqa/README.md
+++ b/tasks/orqa/README.md
@@ -2,9 +2,9 @@
 
 Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
 
-### Retriever Training
+## Retriever Training
 
-##### Unsupervised pretraining by ICT
+#### Unsupervised pretraining by ICT
 1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
 
 
@@ -22,7 +22,7 @@ python tools/preprocess_data.py \
 
 3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf).
 
-##### Supervised finetuning
+#### Supervised finetuning
 
 1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top [DPR training](https://arxiv.org/abs/2004.04906).
 
@@ -30,7 +30,7 @@ python tools/preprocess_data.py \
 
 More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408).
 
-### Reader Training
+## Reader Training
 
 The reader component will be available soon.
 

From 4c92ca82c5c6f7157246abdaa83a0d65aab19630 Mon Sep 17 00:00:00 2001
From: Mostofa Patwary 
Date: Thu, 10 Jun 2021 15:57:46 -0700
Subject: [PATCH 54/59] updated readme

---
 tasks/orqa/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
index e8c36cf3270..a8e8f8e6fab 100644
--- a/tasks/orqa/README.md
+++ b/tasks/orqa/README.md
@@ -4,7 +4,7 @@ Below we present the steps to run unsupervised and supervised trainining and eva
 
 ## Retriever Training
 
-#### Unsupervised pretraining by ICT
+#### Unsupervised pretraining
 1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
 
 

From 2812b2ca55ba0c61b4ef5ead4ddf07c5b5f83110 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase 
Date: Mon, 2 Aug 2021 15:11:22 +0000
Subject: [PATCH 55/59] zero.Init() with mpu

---
 megatron/training.py | 130 +++++++++++++++++++++----------------------
 pretrain_gpt.py      |  14 +++--
 2 files changed, 74 insertions(+), 70 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index a46bd76480b..cfbcc2e57d2 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -15,6 +15,35 @@
 
 """Pretrain utilities."""
 
+import deepspeed
+from megatron.utils import report_memory, flops_calculator
+from megatron.schedules import forward_backward_pipelining_with_interleaving
+from megatron.schedules import forward_backward_pipelining_without_interleaving
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.utils import calc_params_l2_norm
+from megatron.data.data_samplers import build_pretraining_data_loader
+from megatron.utils import unwrap_model
+from megatron.utils import check_adlr_autoresume_termination
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.learning_rates import AnnealingLR
+from megatron.initialize import write_args_to_tensorboard
+from megatron.initialize import initialize_megatron
+from megatron.optimizer import get_megatron_optimizer
+from megatron.model import Float16Module
+from megatron.checkpointing import save_checkpoint
+from megatron.checkpointing import load_checkpoint
+from megatron import print_rank_last
+from megatron import print_rank_0
+from megatron import mpu
+from megatron import update_num_microbatches
+from megatron import is_last_rank
+from megatron import get_num_microbatches
+from megatron import get_current_global_batch_size
+from megatron import get_tensorboard_writer
+from megatron import get_timers
+from megatron import get_args
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+import torch
 from datetime import datetime
 import math
 import sys
@@ -22,38 +51,6 @@
 # The earliest we can measure the start time.
 _TRAIN_START_TIME = time.time()
 
-import torch
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-
-from megatron import get_args
-from megatron import get_timers
-from megatron import get_tensorboard_writer
-from megatron import get_current_global_batch_size
-from megatron import get_num_microbatches
-from megatron import is_last_rank
-from megatron import update_num_microbatches
-from megatron import mpu
-from megatron import print_rank_0
-from megatron import print_rank_last
-from megatron.checkpointing import load_checkpoint
-from megatron.checkpointing import save_checkpoint
-from megatron.model import Float16Module
-from megatron.optimizer import get_megatron_optimizer
-from megatron.initialize import initialize_megatron
-from megatron.initialize import write_args_to_tensorboard
-from megatron.learning_rates import AnnealingLR
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import unwrap_model
-from megatron.data.data_samplers import build_pretraining_data_loader
-from megatron.utils import calc_params_l2_norm
-from megatron.schedules import forward_backward_no_pipelining
-from megatron.schedules import forward_backward_pipelining_without_interleaving
-from megatron.schedules import forward_backward_pipelining_with_interleaving
-from megatron.utils import report_memory, flops_calculator
-
-import deepspeed
-
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -162,6 +159,7 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
+
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything
@@ -186,7 +184,7 @@ def update_train_iters(args):
         # Constant phase
         # Note that we throw away any partial last batch.
         iterations += (args.train_samples - consumed_samples) // \
-                      args.global_batch_size
+            args.global_batch_size
         args.train_iters = iterations
 
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
@@ -218,7 +216,6 @@ def get_model(model_provider_func):
             post_process=post_process
         )
 
-
     if not isinstance(model, list):
         model = [model]
 
@@ -234,10 +231,10 @@ def get_model(model_provider_func):
     if mpu.get_data_parallel_rank() == 0:
         print(' > number of parameters on (tensor, pipeline) '
               'model parallel rank ({}, {}): {}'.format(
-            mpu.get_tensor_model_parallel_rank(),
-            mpu.get_pipeline_model_parallel_rank(),
-            sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
-                 for model_module in model])), flush=True)
+                  mpu.get_tensor_model_parallel_rank(),
+                  mpu.get_pipeline_model_parallel_rank(),
+                  sum([sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() for p in model_module.parameters()])
+                       for model_module in model])), flush=True)
 
     if args.deepspeed:
         return model
@@ -361,7 +358,7 @@ def setup_model_and_optimizer(model_provider_func):
 
     # get model without FP16 and/or TorchDDP wrappers
     if args.iteration == 0 and len(unwrapped_model) == 1 \
-        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
+            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
         print_rank_0("Initializing ICT from pretrained BERT model")
         unwrapped_model[0].init_state_dict_from_bert()
         if args.fp16:
@@ -382,7 +379,7 @@ def train_step(forward_step_func, data_iterator,
         skipped_iter = 0
         grad_norm = 0.
         num_zeros_in_grad = 0
-        return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
+        return {'lm loss': loss}, skipped_iter, grad_norm, num_zeros_in_grad
 
     # Set grad to zero.
     if not args.deepspeed:
@@ -442,8 +439,8 @@ def train_step(forward_step_func, data_iterator,
     timers('optimizer').start()
     if args.deepspeed:
         increment = get_num_microbatches() * \
-                    args.micro_batch_size * \
-                    args.data_parallel_size
+            args.micro_batch_size * \
+            args.data_parallel_size
         model[0].step(lr_kwargs={'increment': increment})
         update_successful = model[0].was_step_applied()
     else:
@@ -458,8 +455,8 @@ def train_step(forward_step_func, data_iterator,
     else:
         if update_successful:
             increment = get_num_microbatches() * \
-                        args.micro_batch_size * \
-                        args.data_parallel_size
+                args.micro_batch_size * \
+                args.data_parallel_size
             lr_scheduler.step(increment=increment)
             skipped_iter = 0
         else:
@@ -507,8 +504,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         else:
             value = loss_dict[key].float().sum().item()
             is_nan = value == float('inf') or \
-                     value == -float('inf') or \
-                     value != value
+                value == -float('inf') or \
+                value != value
             got_nan = got_nan or is_nan
     total_loss_dict[nan_iters_key] = total_loss_dict.get(
         nan_iters_key, 0) + int(got_nan)
@@ -542,10 +539,10 @@ def add_to_logging(name):
         get_num_microbatches()
 
     total_iterations = total_loss_dict[advanced_iters_key] + \
-                       total_loss_dict[skipped_iters_key]
+        total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0) and \
        is_last_rank():
         if args.log_learning_rate_to_tensorboard:
             writer.add_scalar('learning-rate', learning_rate, iteration)
@@ -556,7 +553,7 @@ def add_to_logging(name):
             writer.add_scalar('batch-size vs samples', batch_size,
                               args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key , loss_dict[key], iteration)
+            writer.add_scalar(key, loss_dict[key], iteration)
             writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
         if args.log_loss_scale_to_tensorboard:
@@ -598,7 +595,7 @@ def add_to_logging(name):
             if key not in [advanced_iters_key, skipped_iters_key,
                            nan_iters_key]:
                 avg = total_loss_dict[key].item() / \
-                      float(max(1, total_loss_dict[advanced_iters_key]))
+                    float(max(1, total_loss_dict[advanced_iters_key]))
                 if avg > 0.0:
                     log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
@@ -666,11 +663,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.deepspeed:
             # inform deepspeed of any batch size changes
             global_batch_size = mpu.get_data_parallel_world_size() * \
-                                args.micro_batch_size * \
-                                get_num_microbatches()
+                args.micro_batch_size * \
+                get_num_microbatches()
             model[0].set_train_batch_size(global_batch_size)
 
-
         loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,
@@ -679,8 +675,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
-                                       args.micro_batch_size * \
-                                       get_num_microbatches()
+            args.micro_batch_size * \
+            get_num_microbatches()
 
         # Logging.
         if args.deepspeed:
@@ -743,7 +739,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
 
-
     return iteration
 
 
@@ -772,17 +767,17 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                     forward_backward_func = forward_backward_pipelining_without_interleaving
             else:
                 forward_backward_func = forward_backward_no_pipelining
-            
+
             if args.deepspeed:
                 # DeepSpeed uses eval_batch() and already aggregates losses.
                 assert isinstance(model, list) and len(model) == 1
                 loss = model[0].eval_batch(data_iterator)
-                loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
+                loss_dicts = [{'lm loss': loss}] * get_num_microbatches()
             else:
                 loss_dicts = forward_backward_func(
                     forward_step_func, data_iterator, model, optimizer=None,
                     timers=None, forward_only=True)
-            
+
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 # Reduce across processes.
                 for loss_dict in loss_dicts:
@@ -791,8 +786,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                             key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
 
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
-                                           * args.micro_batch_size \
-                                           * get_num_microbatches()
+                * args.micro_batch_size \
+                * get_num_microbatches()
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()
@@ -802,6 +797,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
 
     return total_loss_dict
 
+
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
                                iteration, verbose=False):
@@ -839,6 +835,7 @@ def cyclic_iter(iter):
         for x in iter:
             yield x
 
+
 def build_train_valid_test_data_iterators(
         build_train_valid_test_datasets_provider):
     """XXX"""
@@ -854,6 +851,8 @@ def build_train_valid_test_data_iterators(
             'only backward compatiblity support for iteration-based training'
         args.consumed_train_samples = args.iteration * args.global_batch_size
     if args.iteration > 0 and args.consumed_valid_samples == 0:
+        print_rank_0(
+            f'iteration = {args.iteration} consumed_valid_samples = {args.consumed_valid_samples} consumed train samples = {args.consumed_train_samples}')
         assert args.train_samples is None, \
             'only backward compatiblity support for iteration-based training'
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
@@ -868,7 +867,7 @@ def build_train_valid_test_data_iterators(
         else:
             train_samples = args.train_iters * args.global_batch_size
         eval_iters = (args.train_iters // args.eval_interval + 1) * \
-                     args.eval_iters
+            args.eval_iters
         test_iters = args.eval_iters
         train_val_test_num_samples = [train_samples,
                                       eval_iters * args.global_batch_size,
@@ -907,26 +906,25 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
 
-
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']
 
     if train_dataloader is not None:
         train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
-                              else iter(cyclic_iter(train_dataloader))
+            else iter(cyclic_iter(train_dataloader))
     else:
         train_data_iterator = None
 
     if valid_dataloader is not None:
         valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
-                              else iter(cyclic_iter(valid_dataloader))
+            else iter(cyclic_iter(valid_dataloader))
     else:
         valid_data_iterator = None
 
     if test_dataloader is not None:
         test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
-                             else iter(cyclic_iter(test_dataloader))
+            else iter(cyclic_iter(test_dataloader))
     else:
         test_data_iterator = None
 
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 4e6e745cf6b..4230beefecf 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -33,6 +33,7 @@
 import os
 import subprocess
 
+
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
@@ -41,9 +42,10 @@ def model_provider(pre_process=True, post_process=True):
 
     args = get_args()
     with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
-                             remote_device=None if args.remote_device=='none' else args.remote_device,
+                             remote_device=None if args.remote_device == 'none' else args.remote_device,
                              config=args.deepspeed_config,
-                             enabled=args.zero_stage==3):
+                             enabled=args.zero_stage == 3,
+                             mpu=mpu):
         if args.deepspeed:
             model = GPTModelPipe(
                 num_tokentypes=0,
@@ -59,14 +61,14 @@ def model_provider(pre_process=True, post_process=True):
             attention_mask = torch.tril(torch.ones(
                 (1, args.seq_length, args.seq_length), device=torch.cuda.current_device())).view(
                     1, 1, args.seq_length, args.seq_length)
-            
+
             # Convert attention mask to binary:
             attention_mask = (attention_mask < 0.5)
             if args.fp16:
                 attention_mask = attention_mask.half()
             elif args.bf16:
                 attention_mask = attention_mask.bfloat16()
-            
+
             args.attn_mask = attention_mask
 
         else:
@@ -111,6 +113,7 @@ def get_batch(data_iterator):
 
     return tokens, labels, loss_mask, attention_mask, position_ids
 
+
 def get_batch_pipe(data):
     """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`"""
     args = get_args()
@@ -138,6 +141,7 @@ def get_batch_pipe(data):
 
     return (tokens, position_ids, attention_mask), (labels, loss_mask)
 
+
 def loss_func(loss_mask, output_tensor):
     losses = output_tensor.float()
     loss_mask = loss_mask.view(-1).float()
@@ -184,10 +188,12 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 
     return train_ds, valid_ds, test_ds
 
+
 def command_exists(cmd):
     result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True)
     return result.wait() == 0
 
+
 def git_ds_info():
     from deepspeed.env_report import main as ds_report
     ds_report()

From 7222a978dc02e5fb6be4873a2abf865145f47146 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase 
Date: Mon, 2 Aug 2021 15:57:01 +0000
Subject: [PATCH 56/59] Remove debug print

---
 megatron/training.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index cfbcc2e57d2..7edeb196078 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -851,8 +851,6 @@ def build_train_valid_test_data_iterators(
             'only backward compatiblity support for iteration-based training'
         args.consumed_train_samples = args.iteration * args.global_batch_size
     if args.iteration > 0 and args.consumed_valid_samples == 0:
-        print_rank_0(
-            f'iteration = {args.iteration} consumed_valid_samples = {args.consumed_valid_samples} consumed train samples = {args.consumed_train_samples}')
         assert args.train_samples is None, \
             'only backward compatiblity support for iteration-based training'
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \

From 3644a9d6db6946ab22926a0a83c5baf35a8da67b Mon Sep 17 00:00:00 2001
From: Tunji Ruwase 
Date: Mon, 2 Aug 2021 16:15:21 +0000
Subject: [PATCH 57/59] Manual revert training.py to main verion

---
 megatron/training.py | 128 ++++++++++++++++++++++---------------------
 1 file changed, 66 insertions(+), 62 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 7edeb196078..a46bd76480b 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -15,35 +15,6 @@
 
 """Pretrain utilities."""
 
-import deepspeed
-from megatron.utils import report_memory, flops_calculator
-from megatron.schedules import forward_backward_pipelining_with_interleaving
-from megatron.schedules import forward_backward_pipelining_without_interleaving
-from megatron.schedules import forward_backward_no_pipelining
-from megatron.utils import calc_params_l2_norm
-from megatron.data.data_samplers import build_pretraining_data_loader
-from megatron.utils import unwrap_model
-from megatron.utils import check_adlr_autoresume_termination
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.learning_rates import AnnealingLR
-from megatron.initialize import write_args_to_tensorboard
-from megatron.initialize import initialize_megatron
-from megatron.optimizer import get_megatron_optimizer
-from megatron.model import Float16Module
-from megatron.checkpointing import save_checkpoint
-from megatron.checkpointing import load_checkpoint
-from megatron import print_rank_last
-from megatron import print_rank_0
-from megatron import mpu
-from megatron import update_num_microbatches
-from megatron import is_last_rank
-from megatron import get_num_microbatches
-from megatron import get_current_global_batch_size
-from megatron import get_tensorboard_writer
-from megatron import get_timers
-from megatron import get_args
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-import torch
 from datetime import datetime
 import math
 import sys
@@ -51,6 +22,38 @@
 # The earliest we can measure the start time.
 _TRAIN_START_TIME = time.time()
 
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+
+from megatron import get_args
+from megatron import get_timers
+from megatron import get_tensorboard_writer
+from megatron import get_current_global_batch_size
+from megatron import get_num_microbatches
+from megatron import is_last_rank
+from megatron import update_num_microbatches
+from megatron import mpu
+from megatron import print_rank_0
+from megatron import print_rank_last
+from megatron.checkpointing import load_checkpoint
+from megatron.checkpointing import save_checkpoint
+from megatron.model import Float16Module
+from megatron.optimizer import get_megatron_optimizer
+from megatron.initialize import initialize_megatron
+from megatron.initialize import write_args_to_tensorboard
+from megatron.learning_rates import AnnealingLR
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
+from megatron.data.data_samplers import build_pretraining_data_loader
+from megatron.utils import calc_params_l2_norm
+from megatron.schedules import forward_backward_no_pipelining
+from megatron.schedules import forward_backward_pipelining_without_interleaving
+from megatron.schedules import forward_backward_pipelining_with_interleaving
+from megatron.utils import report_memory, flops_calculator
+
+import deepspeed
+
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -159,7 +162,6 @@ def pretrain(train_valid_test_dataset_provider,
                                    test_data_iterator, model,
                                    0, True)
 
-
 def update_train_iters(args):
 
     # For iteration-based training, we don't need to do anything
@@ -184,7 +186,7 @@ def update_train_iters(args):
         # Constant phase
         # Note that we throw away any partial last batch.
         iterations += (args.train_samples - consumed_samples) // \
-            args.global_batch_size
+                      args.global_batch_size
         args.train_iters = iterations
 
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
@@ -216,6 +218,7 @@ def get_model(model_provider_func):
             post_process=post_process
         )
 
+
     if not isinstance(model, list):
         model = [model]
 
@@ -231,10 +234,10 @@ def get_model(model_provider_func):
     if mpu.get_data_parallel_rank() == 0:
         print(' > number of parameters on (tensor, pipeline) '
               'model parallel rank ({}, {}): {}'.format(
-                  mpu.get_tensor_model_parallel_rank(),
-                  mpu.get_pipeline_model_parallel_rank(),
-                  sum([sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() for p in model_module.parameters()])
-                       for model_module in model])), flush=True)
+            mpu.get_tensor_model_parallel_rank(),
+            mpu.get_pipeline_model_parallel_rank(),
+            sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
+                 for model_module in model])), flush=True)
 
     if args.deepspeed:
         return model
@@ -358,7 +361,7 @@ def setup_model_and_optimizer(model_provider_func):
 
     # get model without FP16 and/or TorchDDP wrappers
     if args.iteration == 0 and len(unwrapped_model) == 1 \
-            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
+        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
         print_rank_0("Initializing ICT from pretrained BERT model")
         unwrapped_model[0].init_state_dict_from_bert()
         if args.fp16:
@@ -379,7 +382,7 @@ def train_step(forward_step_func, data_iterator,
         skipped_iter = 0
         grad_norm = 0.
         num_zeros_in_grad = 0
-        return {'lm loss': loss}, skipped_iter, grad_norm, num_zeros_in_grad
+        return {'lm loss' : loss}, skipped_iter, grad_norm, num_zeros_in_grad
 
     # Set grad to zero.
     if not args.deepspeed:
@@ -439,8 +442,8 @@ def train_step(forward_step_func, data_iterator,
     timers('optimizer').start()
     if args.deepspeed:
         increment = get_num_microbatches() * \
-            args.micro_batch_size * \
-            args.data_parallel_size
+                    args.micro_batch_size * \
+                    args.data_parallel_size
         model[0].step(lr_kwargs={'increment': increment})
         update_successful = model[0].was_step_applied()
     else:
@@ -455,8 +458,8 @@ def train_step(forward_step_func, data_iterator,
     else:
         if update_successful:
             increment = get_num_microbatches() * \
-                args.micro_batch_size * \
-                args.data_parallel_size
+                        args.micro_batch_size * \
+                        args.data_parallel_size
             lr_scheduler.step(increment=increment)
             skipped_iter = 0
         else:
@@ -504,8 +507,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         else:
             value = loss_dict[key].float().sum().item()
             is_nan = value == float('inf') or \
-                value == -float('inf') or \
-                value != value
+                     value == -float('inf') or \
+                     value != value
             got_nan = got_nan or is_nan
     total_loss_dict[nan_iters_key] = total_loss_dict.get(
         nan_iters_key, 0) + int(got_nan)
@@ -539,10 +542,10 @@ def add_to_logging(name):
         get_num_microbatches()
 
     total_iterations = total_loss_dict[advanced_iters_key] + \
-        total_loss_dict[skipped_iters_key]
+                       total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
        is_last_rank():
         if args.log_learning_rate_to_tensorboard:
             writer.add_scalar('learning-rate', learning_rate, iteration)
@@ -553,7 +556,7 @@ def add_to_logging(name):
             writer.add_scalar('batch-size vs samples', batch_size,
                               args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key , loss_dict[key], iteration)
             writer.add_scalar(key + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
         if args.log_loss_scale_to_tensorboard:
@@ -595,7 +598,7 @@ def add_to_logging(name):
             if key not in [advanced_iters_key, skipped_iters_key,
                            nan_iters_key]:
                 avg = total_loss_dict[key].item() / \
-                    float(max(1, total_loss_dict[advanced_iters_key]))
+                      float(max(1, total_loss_dict[advanced_iters_key]))
                 if avg > 0.0:
                     log_string += ' {}: {:.6E} |'.format(key, avg)
                 total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
@@ -663,10 +666,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         if args.deepspeed:
             # inform deepspeed of any batch size changes
             global_batch_size = mpu.get_data_parallel_world_size() * \
-                args.micro_batch_size * \
-                get_num_microbatches()
+                                args.micro_batch_size * \
+                                get_num_microbatches()
             model[0].set_train_batch_size(global_batch_size)
 
+
         loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func,
                        train_data_iterator,
@@ -675,8 +679,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        lr_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
-            args.micro_batch_size * \
-            get_num_microbatches()
+                                       args.micro_batch_size * \
+                                       get_num_microbatches()
 
         # Logging.
         if args.deepspeed:
@@ -739,6 +743,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
 
+
     return iteration
 
 
@@ -767,17 +772,17 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                     forward_backward_func = forward_backward_pipelining_without_interleaving
             else:
                 forward_backward_func = forward_backward_no_pipelining
-
+            
             if args.deepspeed:
                 # DeepSpeed uses eval_batch() and already aggregates losses.
                 assert isinstance(model, list) and len(model) == 1
                 loss = model[0].eval_batch(data_iterator)
-                loss_dicts = [{'lm loss': loss}] * get_num_microbatches()
+                loss_dicts = [{'lm loss' : loss}] * get_num_microbatches()
             else:
                 loss_dicts = forward_backward_func(
                     forward_step_func, data_iterator, model, optimizer=None,
                     timers=None, forward_only=True)
-
+            
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 # Reduce across processes.
                 for loss_dict in loss_dicts:
@@ -786,8 +791,8 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                             key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
 
             args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
-                * args.micro_batch_size \
-                * get_num_microbatches()
+                                           * args.micro_batch_size \
+                                           * get_num_microbatches()
     # Move model back to the train mode.
     for model_module in model:
         model_module.train()
@@ -797,7 +802,6 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
 
     return total_loss_dict
 
-
 def evaluate_and_print_results(prefix, forward_step_func,
                                data_iterator, model,
                                iteration, verbose=False):
@@ -835,7 +839,6 @@ def cyclic_iter(iter):
         for x in iter:
             yield x
 
-
 def build_train_valid_test_data_iterators(
         build_train_valid_test_datasets_provider):
     """XXX"""
@@ -865,7 +868,7 @@ def build_train_valid_test_data_iterators(
         else:
             train_samples = args.train_iters * args.global_batch_size
         eval_iters = (args.train_iters // args.eval_interval + 1) * \
-            args.eval_iters
+                     args.eval_iters
         test_iters = args.eval_iters
         train_val_test_num_samples = [train_samples,
                                       eval_iters * args.global_batch_size,
@@ -904,25 +907,26 @@ def build_train_valid_test_data_iterators(
     args.do_valid = flags[1].item()
     args.do_test = flags[2].item()
 
+
     # Build iterators.
     dl_type = args.dataloader_type
     assert dl_type in ['single', 'cyclic']
 
     if train_dataloader is not None:
         train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
-            else iter(cyclic_iter(train_dataloader))
+                              else iter(cyclic_iter(train_dataloader))
     else:
         train_data_iterator = None
 
     if valid_dataloader is not None:
         valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
-            else iter(cyclic_iter(valid_dataloader))
+                              else iter(cyclic_iter(valid_dataloader))
     else:
         valid_data_iterator = None
 
     if test_dataloader is not None:
         test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
-            else iter(cyclic_iter(test_dataloader))
+                             else iter(cyclic_iter(test_dataloader))
     else:
         test_data_iterator = None
 

From 34511160936c351575f92b56aefc0c8e2db916ea Mon Sep 17 00:00:00 2001
From: Tunji Ruwase 
Date: Mon, 9 Aug 2021 14:15:26 +0000
Subject: [PATCH 58/59] Use new argument name

---
 pretrain_gpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index 4230beefecf..7dd8ad0d564 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -43,7 +43,7 @@ def model_provider(pre_process=True, post_process=True):
     args = get_args()
     with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                              remote_device=None if args.remote_device == 'none' else args.remote_device,
-                             config=args.deepspeed_config,
+                             config_dict_or_path=args.deepspeed_config,
                              enabled=args.zero_stage == 3,
                              mpu=mpu):
         if args.deepspeed:

From d64338a1c8a04275000d08eb09b3faf398a4c283 Mon Sep 17 00:00:00 2001
From: Tunji Ruwase 
Date: Mon, 9 Aug 2021 19:44:17 +0000
Subject: [PATCH 59/59] Revert "Merge branch 'master' of
 github.com:microsoft/Megatron-DeepSpeed into olruwase/zero_init_mpu"

This reverts commit 3e10ebab948aec17be872382480f2aaa6247e0c0, reversing
changes made to 3644a9d6db6946ab22926a0a83c5baf35a8da67b.
---
 README.md                                     |   8 +-
 examples/create_embeddings.sh                 |  32 ++
 ...ever_nq.sh => evaluate_ict_zeroshot_nq.sh} |  14 +-
 examples/finetune_retriever_distributed.sh    |  56 ----
 examples/pretrain_t5.sh                       |   5 +-
 examples/pretrain_t5_distributed.sh           |   5 +-
 examples/pretrain_t5_distributed_with_mp.sh   |   6 +-
 megatron/checkpointing.py                     |   2 +-
 megatron/indexer.py                           |  12 +-
 megatron/learning_rates.py                    |   2 +-
 megatron/model/biencoder_model.py             |  53 +---
 megatron/model/vit_model.py                   | 128 +++-----
 megatron/tokenizer/bert_tokenization.py       |  29 --
 pretrain_ict.py                               |  51 ++-
 pretrain_vit.py                               |  34 +-
 tasks/finetune_utils.py                       |  40 +--
 tasks/main.py                                 |  27 +-
 tasks/orqa/README.md                          |  36 ---
 tasks/orqa/evaluate_orqa.py                   |  20 +-
 tasks/orqa/evaluate_utils.py                  |  16 +-
 .../{unsupervised => natural_questions}/nq.py |   0
 .../qa_utils.py                               |   2 +-
 .../tokenizers.py                             |   0
 tasks/orqa/supervised/data.py                 | 300 ------------------
 tasks/orqa/supervised/eval_utils.py           | 206 ------------
 tasks/orqa/supervised/finetune.py             | 251 ---------------
 tasks/vision/classification.py                |   5 +-
 tasks/vision/eval_utils.py                    |  76 ++---
 tasks/vision/finetune_utils.py                |  63 ++--
 tools/create_doc_index.py                     |  33 ++
 30 files changed, 248 insertions(+), 1264 deletions(-)
 create mode 100644 examples/create_embeddings.sh
 rename examples/{evaluate_retriever_nq.sh => evaluate_ict_zeroshot_nq.sh} (75%)
 delete mode 100755 examples/finetune_retriever_distributed.sh
 delete mode 100644 tasks/orqa/README.md
 rename tasks/orqa/{unsupervised => natural_questions}/nq.py (100%)
 rename tasks/orqa/{unsupervised => natural_questions}/qa_utils.py (98%)
 rename tasks/orqa/{unsupervised => natural_questions}/tokenizers.py (100%)
 delete mode 100644 tasks/orqa/supervised/data.py
 delete mode 100644 tasks/orqa/supervised/eval_utils.py
 delete mode 100644 tasks/orqa/supervised/finetune.py
 create mode 100644 tools/create_doc_index.py

diff --git a/README.md b/README.md
index de208d5b3bf..489d3771de6 100644
--- a/README.md
+++ b/README.md
@@ -103,11 +103,6 @@ python tools/preprocess_data.py \
 
 The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension.
 
-For T5 use the same preprocessing as BERT, perhaps renaming it to:
-
-       --output-prefix my-t5 \
-
- Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type:
 python tools/preprocess_data.py \
@@ -242,14 +237,13 @@ T5_ARGS="--num-layers 24 \
          --micro-batch-size 16 \
          --global-batch-size 2048 \
          --vocab-file $VOCAB_FILE \
-         --vocab-extra-ids 100 \
          --split 949,50,1 \
          --fp16"
 
 OUTPUT_ARGS=<same as those in BERT pretraining above>
 
 python pretrain_t5.py \
-       $T5_ARGS \
+       $BERT_ARGS \
        $OUTPUT_ARGS \
        --save $CHECKPOINT_PATH \
        --load $CHECKPOINT_PATH \
diff --git a/examples/create_embeddings.sh b/examples/create_embeddings.sh
new file mode 100644
index 00000000000..59a5839f7e2
--- /dev/null
+++ b/examples/create_embeddings.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
+
+RANK=0
+WORLD_SIZE=1
+
+# Wikipedia data can be downloaded from the following link:
+# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
+EVIDENCE_DATA_DIR=
+EMBEDDING_PATH=
+CHECKPOINT_PATH=
+
+python tools/create_doc_index.py \
+    --num-layers 12 \
+    --hidden-size 768 \
+    --num-attention-heads 12 \
+    --tensor-model-parallel-size 1 \
+    --micro-batch-size 128 \
+    --checkpoint-activations \
+    --seq-length 512 \
+    --retriever-seq-length 256 \
+    --max-position-embeddings 512 \
+    --load ${CHECKPOINT_PATH} \
+    --evidence-data-path ${EVIDENCE_DATA_DIR} \
+    --embedding-path ${EMBEDDING_PATH} \
+    --indexer-log-interval 1000 \
+    --indexer-batch-size 128 \
+    --vocab-file bert-vocab.txt \
+    --num-workers 2 \
+    --fp16
+
diff --git a/examples/evaluate_retriever_nq.sh b/examples/evaluate_ict_zeroshot_nq.sh
similarity index 75%
rename from examples/evaluate_retriever_nq.sh
rename to examples/evaluate_ict_zeroshot_nq.sh
index 8b87be3024a..e1ce45a9342 100644
--- a/examples/evaluate_retriever_nq.sh
+++ b/examples/evaluate_ict_zeroshot_nq.sh
@@ -1,19 +1,19 @@
 #!/bin/bash
 
 # Evaluate natural question test data given Wikipedia embeddings and pretrained
-# ICT model or a finetuned model for Natural Question task
+# ICT model
 
 # Datasets can be downloaded from the following link:
 # https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
 
 EVIDENCE_DATA_DIR=
 EMBEDDING_PATH=
-CHECKPOINT_PATH=
+CHECKPOINT_PATH=
 
-QA_FILE=
+QA_FILE=
 
 python tasks/main.py \
-    --task RETRIEVER-EVAL \
+    --task ICT-ZEROSHOT-NQ \
     --tokenizer-type BertWordPieceLowerCase \
     --num-layers 12 \
     --hidden-size 768 \
@@ -29,10 +29,8 @@ python tasks/main.py \
     --retriever-seq-length 256 \
     --vocab-file  bert-vocab.txt\
     --qa-data-test ${QA_FILE} \
+    --num-workers 2 \
     --faiss-use-gpu \
     --retriever-report-topk-accuracies 1 5 20 100 \
-    --fp16 \
-    --indexer-log-interval 1000 \
-    --indexer-batch-size 128
-
+    --fp16
 
diff --git a/examples/finetune_retriever_distributed.sh b/examples/finetune_retriever_distributed.sh
deleted file mode 100755
index 535a2e053d4..00000000000
--- a/examples/finetune_retriever_distributed.sh
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/bin/bash
-
-# Finetune a BERT or pretrained ICT model using Google natural question data 
-# Datasets can be downloaded from the following link:
-# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
-
-WORLD_SIZE=8
-
-DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
-                  --nnodes 1 \
-                  --node_rank 0 \
-                  --master_addr localhost \
-                  --master_port 6000"
-
-CHECKPOINT_PATH=
-
-# Load either of the below
-BERT_LOAD_PATH=
-PRETRAINED_CHECKPOINT=
-
-python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
-        --task RET-FINETUNE-NQ \
-        --train-with-neg \
-        --train-hard-neg 1 \
-        --pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
-        --num-layers 12 \
-        --hidden-size 768 \
-        --num-attention-heads 12 \
-        --tensor-model-parallel-size 1 \
-        --tokenizer-type BertWordPieceLowerCase \
-        --train-data nq-train.json \
-        --valid-data nq-dev.json \
-        --save ${CHECKPOINT_PATH} \
-        --load ${CHECKPOINT_PATH} \
-        --vocab-file bert-vocab.txt \
-        --bert-load ${BERT_LOAD_PATH} \
-        --save-interval 5000 \
-        --log-interval 10 \
-        --eval-interval 20000 \
-        --eval-iters 100 \
-        --indexer-log-interval 1000 \
-        --faiss-use-gpu \
-        --DDP-impl torch \
-        --fp16 \
-        --retriever-report-topk-accuracies 1 5 10 20 100 \
-        --seq-length 512 \
-        --retriever-seq-length 256 \
-        --max-position-embeddings 512 \
-        --retriever-score-scaling \
-        --epochs 80 \
-        --micro-batch-size 8 \
-        --eval-micro-batch-size 16 \
-        --indexer-batch-size 128 \
-        --lr 2e-5 \
-        --lr-warmup-fraction 0.01 \
-        --weight-decay 1e-1
diff --git a/examples/pretrain_t5.sh b/examples/pretrain_t5.sh
index 91fd5929bf6..71fea8489a7 100644
--- a/examples/pretrain_t5.sh
+++ b/examples/pretrain_t5.sh
@@ -15,7 +15,7 @@ python pretrain_t5.py \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 16 \
+       --global-batch-size 2048 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -35,5 +35,4 @@ python pretrain_t5.py \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16 \
-       --vocab-extra-ids 100
+       --fp16
diff --git a/examples/pretrain_t5_distributed.sh b/examples/pretrain_t5_distributed.sh
index 2beb1cdaca6..778b4ad2a36 100644
--- a/examples/pretrain_t5_distributed.sh
+++ b/examples/pretrain_t5_distributed.sh
@@ -24,7 +24,7 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 128 \
+       --global-batch-size 2048 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -44,5 +44,4 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16 \
-       --vocab-extra-ids 100
+       --fp16
diff --git a/examples/pretrain_t5_distributed_with_mp.sh b/examples/pretrain_t5_distributed_with_mp.sh
index 23f1cd664e3..9be70393d86 100644
--- a/examples/pretrain_t5_distributed_with_mp.sh
+++ b/examples/pretrain_t5_distributed_with_mp.sh
@@ -24,7 +24,8 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --encoder-seq-length 512 \
        --decoder-seq-length 128 \
        --micro-batch-size 16 \
-       --global-batch-size 128 \
+       --global-batch-size 2048 \
+       --seq-length 512 \
        --max-position-embeddings 512 \
        --train-iters 1000000 \
        --lr-decay-iters 1000000 \
@@ -44,5 +45,4 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS \
        --save-interval 10000 \
        --eval-interval 1000 \
        --eval-iters 10 \
-       --fp16  \
-       --vocab-extra-ids 100
+       --fp16
diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 49c4b6533ff..3cc6a8e2e41 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -60,8 +60,8 @@ def _compare(arg_name, old_arg_name=None):
     _compare('num_layers')
     _compare('hidden_size')
     _compare('num_attention_heads')
+    _compare('max_position_embeddings')
     if args.vocab_file:
-        _compare('max_position_embeddings')
         _compare('make_vocab_size_divisible_by')
         _compare('padded_vocab_size')
         _compare('tokenizer_type')
diff --git a/megatron/indexer.py b/megatron/indexer.py
index d2ff9e36f85..c0d1ca7de15 100644
--- a/megatron/indexer.py
+++ b/megatron/indexer.py
@@ -1,16 +1,15 @@
 import sys
-import time
 import torch
 import torch.distributed as dist
 
-from megatron import get_args, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
 from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, OpenRetreivalDataStore
-from megatron.model.biencoder_model import get_model_provider
+from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import get_model
 
 
@@ -30,6 +29,7 @@ def __init__(self):
         # need to know whether we're using a REALM checkpoint (args.load)
         # or ICT checkpoint
         assert not (args.load and args.ict_load)
+        #self.using_realm_chkpt = args.ict_load is None
 
         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size
@@ -47,8 +47,8 @@ def load_attributes(self):
         if self.biencoder_shared_query_context_model:
             only_context_model = False
 
-        model = get_model(get_model_provider(only_context_model=\
-            only_context_model, biencoder_shared_query_context_model=\
+        model = get_model(lambda: biencoder_model_provider(only_context_model \
+            = only_context_model, biencoder_shared_query_context_model = \
             self.biencoder_shared_query_context_model))
 
         self.model = load_biencoder_checkpoint(model,
@@ -85,7 +85,6 @@ def build_and_save_index(self):
         """
         assert len(self.model) == 1
         unwrapped_model = self.model[0]
-
         while not hasattr(unwrapped_model, 'embed_text'):
             unwrapped_model = unwrapped_model.module
 
@@ -104,7 +103,6 @@ def build_and_save_index(self):
             context_logits = unwrapped_model.embed_text(
                 unwrapped_model.context_model, context_tokens, context_mask,
                 context_types)
-
             context_logits = detach(context_logits)
             row_id = detach(row_id)
 
diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py
index c53af8d54ac..d200bdb176a 100644
--- a/megatron/learning_rates.py
+++ b/megatron/learning_rates.py
@@ -87,7 +87,7 @@ def get_lr(self):
         else:
             raise Exception('{} decay style is not supported.'.format(
                 self.decay_style))
-
+       
         return self.min_lr + coeff * delta_lr
 
 
diff --git a/megatron/model/biencoder_model.py b/megatron/model/biencoder_model.py
index e1f94bf1c43..51ac0a060d4 100644
--- a/megatron/model/biencoder_model.py
+++ b/megatron/model/biencoder_model.py
@@ -15,30 +15,11 @@
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule
 
-def get_model_provider(only_query_model=False, only_context_model=False,
-        biencoder_shared_query_context_model=False):
-
-    def model_provider(pre_process=True, post_process=True):
-        """Build the model."""
-
-        print_rank_0('building Bienoder model ...')
-        model = biencoder_model_provider(only_query_model=only_query_model,
-                only_context_model = only_context_model,
-                biencoder_shared_query_context_model = \
-                biencoder_shared_query_context_model,
-                pre_process=pre_process, post_process=post_process)
-
-        return model
-
-    return model_provider
-
-
 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
-                             biencoder_shared_query_context_model=False,
-                             pre_process=True,
-                             post_process=True):
+                             biencoder_shared_query_context_model=False):
     """Build the model."""
+    args = get_args()
 
     assert mpu.get_tensor_model_parallel_world_size() == 1 and \
         mpu.get_pipeline_model_parallel_world_size() == 1, \
@@ -54,9 +35,7 @@ def biencoder_model_provider(only_query_model=False,
         only_query_model=only_query_model,
         only_context_model=only_context_model,
         biencoder_shared_query_context_model=\
-        biencoder_shared_query_context_model,
-        pre_process=pre_process,
-        post_process=post_process)
+            biencoder_shared_query_context_model)
 
     return model
 
@@ -69,17 +48,13 @@ def __init__(self,
                  parallel_output=True,
                  only_query_model=False,
                  only_context_model=False,
-                 biencoder_shared_query_context_model=False,
-                 pre_process=True,
-                 post_process=True):
+                 biencoder_shared_query_context_model=False):
         super(BiEncoderModel, self).__init__()
         args = get_args()
 
         bert_kwargs = dict(
             num_tokentypes=num_tokentypes,
-            parallel_output=parallel_output,
-            pre_process=pre_process,
-            post_process=post_process)
+            parallel_output=parallel_output)
 
         self.biencoder_shared_query_context_model = \
             biencoder_shared_query_context_model
@@ -103,13 +78,6 @@ def __init__(self,
                 self.context_model = PretrainedBertModel(**bert_kwargs)
                 self._context_key = 'context_model'
 
-    def set_input_tensor(self, input_tensor):
-        """See megatron.model.transformer.set_input_tensor()"""
-        # this is just a placeholder and will be needed when model
-        # parallelism will be used
-        # self.language_model.set_input_tensor(input_tensor)
-        return
-
     def forward(self, query_tokens, query_attention_mask, query_types,
                 context_tokens, context_attention_mask, context_types):
         """Run a forward pass for each of the models and
@@ -249,7 +217,7 @@ class PretrainedBertModel(MegatronModule):
     learned information retrieval."""
 
     def __init__(self, num_tokentypes=2,
-            parallel_output=True, pre_process=True, post_process=True):
+            parallel_output=True):
         super(PretrainedBertModel, self).__init__()
 
         args = get_args()
@@ -257,8 +225,6 @@ def __init__(self, num_tokentypes=2,
         self.pad_id = tokenizer.pad
         self.biencoder_projection_dim = args.biencoder_projection_dim
         self.parallel_output = parallel_output
-        self.pre_process = pre_process
-        self.post_process = post_process
         init_method = init_method_normal(args.init_method_std)
         scaled_init_method = scaled_init_method_normal(
             args.init_method_std, args.num_layers)
@@ -268,9 +234,7 @@ def __init__(self, num_tokentypes=2,
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
-            scaled_init_method=scaled_init_method,
-            pre_process=self.pre_process,
-            post_process=self.post_process)
+            scaled_init_method=scaled_init_method)
 
         if args.biencoder_projection_dim > 0:
             self.projection_enc = get_linear_layer(args.hidden_size,
@@ -283,6 +247,7 @@ def forward(self, input_ids, attention_mask, tokentype_ids=None):
         #extended_attention_mask = bert_extended_attention_mask(attention_mask)
         position_ids = bert_position_ids(input_ids)
 
+
         lm_output = self.language_model(input_ids,
                                         position_ids,
                                         extended_attention_mask,
@@ -320,7 +285,7 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
 
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
-        print_rank_0("loading pretrained weights")
+        print_rank_0("loading BERT weights")
         self.language_model.load_state_dict(
             state_dict[self._language_model_key], strict=strict)
 
diff --git a/megatron/model/vit_model.py b/megatron/model/vit_model.py
index a1a86cfff3a..84a52a8294a 100644
--- a/megatron/model/vit_model.py
+++ b/megatron/model/vit_model.py
@@ -50,11 +50,11 @@ def __init__(self, hidden_size, num_classes):
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
-        hidden_state = hidden_states[:, sequence_index, :]
-        dense_in_result = self.dense_in(hidden_state)
-        tanh_result = torch.tanh(dense_in_result)
-        dense_out_result = self.dense_out(tanh_result)
-        return dense_out_result
+        x = hidden_states[:, sequence_index, :]
+        x = self.dense_in(x)
+        x = torch.tanh(x)
+        x = self.dense_out(x)
+        return x
 
 
 def twod_interpolate_position_embeddings_hook(
@@ -122,12 +122,8 @@ def twod_interpolate_position_embeddings_hook(
 class VitModel(MegatronModule):
     """Vision Transformer Model."""
 
-    def __init__(self, 
-                 num_classes,
-                 finetune=False,
-                 pre_process=True,
-                 post_process=True):
-        super(VitModel, self).__init__(share_word_embeddings=False)
+    def __init__(self, num_classes, finetune=False):
+        super(VitModel, self).__init__()
         args = get_args()
 
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
@@ -140,8 +136,6 @@ def __init__(self,
                 args.init_method_std, args.num_layers
             )
 
-        self.pre_process = pre_process
-        self.post_process = post_process
         self.hidden_size = args.hidden_size
         self.num_classes = num_classes
         self.patch_dim = args.patch_dim
@@ -154,81 +148,63 @@ def __init__(self,
         self.seq_length = self.num_patches + 1
         self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
 
-        if self.pre_process:
-            # cls_token
-            self.cls_token = torch.nn.Parameter(
-                torch.randn(1, 1, self.hidden_size)
-            )
-            torch.nn.init.zeros_(self.cls_token)
+        # cls_token
+        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
+        torch.nn.init.zeros_(self.cls_token)
 
-            # Linear encoder
-            self.linear_encoder = torch.nn.Linear(
-                self.flatten_dim, self.hidden_size
-            )
+        # Linear encoder
+        self.linear_encoder = torch.nn.Linear(
+            self.flatten_dim, self.hidden_size
+        )
 
-            # embedding
-            self.position_embeddings = torch.nn.Embedding(
-                self.seq_length, self.hidden_size
-            )
-            init_method_normal(args.init_method_std)(
-                self.position_embeddings.weight
-            )
-            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
+        # embedding
+        self.position_embeddings = torch.nn.Embedding(
+            self.seq_length, self.hidden_size
+        )
+        init_method_normal(args.init_method_std)(
+            self.position_embeddings.weight
+        )
+        self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
 
-            self.position_embeddings._register_load_state_dict_pre_hook(
-                twod_interpolate_position_embeddings_hook
-            )
+        self.position_embeddings._register_load_state_dict_pre_hook(
+            twod_interpolate_position_embeddings_hook
+        )
 
-            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
+        self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
 
         # Transformer
         self.transformer = ParallelTransformer(
-            self.init_method, 
-            self.scaled_init_method,
-            pre_process=self.pre_process,
-            post_process=self.post_process
+            self.init_method, self.scaled_init_method
         )
 
-        if self.post_process:
-            # MLP head
-            if not self.finetune:
-                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
-            else:
-                self.class_head = get_linear_layer(
-                    self.hidden_size, num_classes, torch.nn.init.zeros_
-                )
-
-    def set_input_tensor(self, input_tensor):
-        """See megatron.model.transformer.set_input_tensor()"""
-        self.transformer.set_input_tensor(input_tensor)
-
-    def forward(self, input):
-
-        if self.pre_process:
-            rearranged_input = einops.rearrange(
-                input,
-                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
-                p1=self.patch_dim,
-                p2=self.patch_dim,
+        # MLP head
+        if not self.finetune:
+            self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
+        else:
+            self.class_head = get_linear_layer(
+                self.hidden_size, num_classes, torch.nn.init.zeros_
             )
 
-            assert rearranged_input.dtype == torch.half
-            encoder_output = self.linear_encoder(rearranged_input)
-            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
-            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
+    def forward(self, x):
+        x = einops.rearrange(
+            x,
+            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+            p1=self.patch_dim,
+            p2=self.patch_dim,
+        )
 
-            token_embeddings = concatenated_tokens + \
-                self.position_embeddings(self.position_ids)
-            hidden_states = self.embedding_dropout(token_embeddings)
-        else:
-            hidden_states = input
+        assert x.dtype == torch.half
+        x = self.linear_encoder(x)
+        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+        x = torch.cat((cls_tokens, x), dim=1)
 
-        hidden_states = self.transformer(hidden_states, None)
+        x = x + self.position_embeddings(self.position_ids)
+        x = self.embedding_dropout(x)
+        x = self.transformer(x, None)
 
-        if self.post_process:
-            if not self.finetune:
-                hidden_states = self.mlp_head(hidden_states)
-            else:
-                hidden_states = self.class_head(hidden_states[:, 0, :])
+        if not self.finetune:
+            x = self.mlp_head(x)
+        else:
+            x = self.class_head(x[:, 0, :])
 
-        return hidden_states
+        return x
diff --git a/megatron/tokenizer/bert_tokenization.py b/megatron/tokenizer/bert_tokenization.py
index 99f9a87958d..a3aa6d907e3 100644
--- a/megatron/tokenizer/bert_tokenization.py
+++ b/megatron/tokenizer/bert_tokenization.py
@@ -181,35 +181,6 @@ def convert_tokens_to_ids(self, tokens):
     def convert_ids_to_tokens(self, ids):
         return convert_by_vocab(self.inv_vocab, ids)
 
-    @staticmethod
-    def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
-        """ Converts a sequence of tokens (string) in a single string. """
-
-        def clean_up_tokenization(out_string):
-            """ Clean up a list of simple English tokenization artifacts
-            like spaces before punctuations and abreviated forms.
-            """
-            out_string = (
-                out_string.replace(" .", ".")
-                    .replace(" ?", "?")
-                    .replace(" !", "!")
-                    .replace(" ,", ",")
-                    .replace(" ' ", "'")
-                    .replace(" n't", "n't")
-                    .replace(" 'm", "'m")
-                    .replace(" 's", "'s")
-                    .replace(" 've", "'ve")
-                    .replace(" 're", "'re")
-            )
-            return out_string
-
-        text = ' '.join(tokens).replace(' ##', '').strip()
-        if clean_up_tokenization_spaces:
-            clean_text = clean_up_tokenization(text)
-            return clean_text
-        else:
-            return text
-
     def vocab_size(self):
         return len(self.vocab)
 
diff --git a/pretrain_ict.py b/pretrain_ict.py
index 79759250f18..1438b3d5782 100644
--- a/pretrain_ict.py
+++ b/pretrain_ict.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 
 """Pretrain BERT for Inverse Cloze Task"""
-
-from functools import partial
 import math
 
 import torch
@@ -33,16 +31,13 @@
 from megatron.utils import average_losses_across_data_parallel_group
 
 
-def pretrain_ict_model_provider(pre_process=True, post_process=True):
+def pretrain_ict_model_provider():
     args = get_args()
-
     model = biencoder_model_provider(
                 only_context_model=False,
                 only_query_model=False,
                 biencoder_shared_query_context_model=\
-                args.biencoder_shared_query_context_model,
-                pre_process=pre_process, post_process=post_process)
-
+                    args.biencoder_shared_query_context_model)
     return model
 
 def get_group_world_size_rank():
@@ -82,9 +77,25 @@ def backward(ctx, grad_output):
         output = output_list[rank].contiguous()
         return output
 
-def loss_func(output_tensor):
+def forward_step(data_iterator, model, input_tensor):
+    """Forward step."""
     args = get_args()
-    query_logits, context_logits = output_tensor
+    timers = get_timers()
+
+    # Get the batch.
+    timers('batch-generator').start()
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
+    timers('batch-generator').stop()
+
+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+    # Forward model.
+    query_logits, context_logits = model(query_tokens, query_mask,
+                                    query_types, context_tokens,
+                                    context_mask, context_types)
 
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
@@ -126,28 +137,6 @@ def topk_accuracy(k):
     return loss, stats_dict
 
 
-
-def forward_step(data_iterator, model):
-    """Forward step."""
-    args = get_args()
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator').start()
-    query_tokens, query_mask, \
-    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
-    timers('batch-generator').stop()
-
-    # Query and Context Types
-    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
-    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
-
-    # Forward model.
-    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
-                        context_mask, context_types)
-
-    return output_tensor, partial(loss_func)
-
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid and test datasets."""
     args = get_args()
diff --git a/pretrain_vit.py b/pretrain_vit.py
index 7770c68d5d5..16ec10439a0 100644
--- a/pretrain_vit.py
+++ b/pretrain_vit.py
@@ -17,22 +17,19 @@
 
 import torch
 import torch.nn.functional as F
-from functools import partial
 from megatron import get_args, get_timers, mpu, print_rank_0
 from megatron.data.vit_dataset import build_train_valid_datasets
 from megatron.model.vit_model import VitModel
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
-def model_provider(pre_process=True, post_process=True):
+def model_provider():
     """Build the model."""
 
     print_rank_0("building VIT model ...")
     args = get_args()
 
-    model = VitModel(num_classes=args.num_classes,
-                     pre_process=pre_process,
-                     post_process=post_process)
+    model = VitModel(num_classes=args.num_classes)
     return model
 
 def get_batch(data_iterator):
@@ -45,21 +42,10 @@ def get_batch(data_iterator):
 
     return images, labels
 
-def loss_func(labels, output_tensor):
-    logits = output_tensor.contiguous().float()
-    loss = F.cross_entropy(logits, labels)
-
-    outputs = torch.argmax(logits, -1)
-    correct = (outputs == labels).float()
-    accuracy = torch.mean(correct)
-
-    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
-
-    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
-
-def forward_step(data_iterator, model):
+def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     timers = get_timers()
+    assert input_tensor is None
 
     # Get the batch.
     timers("batch-generator").start()
@@ -70,9 +56,17 @@ def forward_step(data_iterator, model):
     timers("batch-generator").stop()
 
     # Forward model. lm_labels
-    output_tensor = model(images)
+    logits = model(images).contiguous().float()
+    loss = F.cross_entropy(logits, labels)
+
+    outputs = torch.argmax(logits, -1)
+    correct = (outputs == labels).float()
+    accuracy = torch.mean(correct)
+
+    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])
+
+    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
 
-    return output_tensor, partial(loss_func, labels)
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid, and test datasets."""
diff --git a/tasks/finetune_utils.py b/tasks/finetune_utils.py
index 9411b1849fc..918417b41b2 100644
--- a/tasks/finetune_utils.py
+++ b/tasks/finetune_utils.py
@@ -16,10 +16,10 @@
 """Finetune utilities."""
 
 from functools import partial
-import sys
+
 import torch
 
-from megatron import get_args, get_num_microbatches
+from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
@@ -80,8 +80,7 @@ def _cross_entropy_forward_step(batch, model):
     return output_tensor, partial(cross_entropy_loss_func, labels)
 
 
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
-        task_collate_fn=None):
+def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""
 
     # Sampler.
@@ -97,8 +96,7 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                                               shuffle=False,
                                               num_workers=num_workers,
                                               drop_last=drop_last,
-                                              pin_memory=True,
-                                              collate_fn=task_collate_fn)
+                                              pin_memory=True)
 
     return data_loader
 
@@ -114,24 +112,21 @@ def _build_infinite_size_dataloader(dataloader):
             iterator = dataloader.__iter__()
 
 
-def _build_train_valid_dataloaders(train_dataset, valid_dataset, 
-    task_collate_fn=None):
+def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     """Traing and validation dataloaders."""
     args = get_args()
 
     print_rank_0('building train and validation dataloaders ...')
     # Training dataset.
     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
-                                         args.num_workers, not args.keep_last,
-                                         task_collate_fn)
+                                         args.num_workers, not args.keep_last)
     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch
     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
-                                          args.num_workers, not args.keep_last,
-                                          task_collate_fn)
+                                          args.num_workers, not args.keep_last)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
 
     # Now that we've built the data loaders, set batch_size arguments
@@ -159,8 +154,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
     args = get_args()
     timers = get_timers()
 
-    assert get_num_microbatches() == 1, "finetuning with gradient accumulation doesn't currently work"
-
     # Turn on training mode which enables dropout.
     for m in model:
         m.train()
@@ -195,7 +188,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 
             # Train for one step.
             out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
-
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
             iteration += 1
 
@@ -217,11 +209,9 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                                   optimizer, lr_scheduler)
 
             # Checkpointing
-            saved_checkpoint = False
             if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
                 save_checkpoint(iteration, model, optimizer, lr_scheduler)
-                saved_checkpoint = True
 
             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -230,14 +220,6 @@ def _train(model, optimizer, lr_scheduler, forward_step,
                                            valid_dataloader, model,
                                            iteration, False)
 
-            # Exiting based on iterations
-            if args.exit_interval and iteration % args.exit_interval == 0:
-                if not saved_checkpoint:
-                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
-                torch.distributed.barrier()
-                print_rank_0('exiting program at iteration {}'.format(iteration))
-                sys.exit()
-
         # Checkpointing at the end of each epoch.
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -249,8 +231,7 @@ def _train(model, optimizer, lr_scheduler, forward_step,
 
 def finetune(train_valid_datasets_provider, model_provider,
              forward_step=_cross_entropy_forward_step,
-             end_of_epoch_callback_provider=None,
-             task_collate_fn=None):
+             end_of_epoch_callback_provider=None):
     """Main finetune function used across all tasks."""
     args = get_args()
     timers = get_timers()
@@ -263,7 +244,7 @@ def finetune(train_valid_datasets_provider, model_provider,
     if args.epochs > 0:
         train_dataset, valid_dataset = train_valid_datasets_provider()
         train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
-            train_dataset, valid_dataset, task_collate_fn)
+            train_dataset, valid_dataset)
     else:
         args.train_iters = 0
     timers('train/valid/test dataset/dataloder').stop()
@@ -287,11 +268,8 @@ def finetune(train_valid_datasets_provider, model_provider,
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
         original_load = args.load
         args.load = args.pretrained_checkpoint
-        original_rng = args.no_load_rng
-        args.no_load_rng = True
         _ = load_checkpoint(model, None, None)
         args.load = original_load
-        args.no_load_rng = original_rng
         # This is critical when only model is loaded. We should make sure
         # main parameters are also updated.
         optimizer.reload_model_params()
diff --git a/tasks/main.py b/tasks/main.py
index 6d8fc8f5fd6..f5bd5ad6930 100644
--- a/tasks/main.py
+++ b/tasks/main.py
@@ -62,29 +62,6 @@ def get_tasks_args(parser):
     group.add_argument('--faiss-topk-retrievals', type=int, default=100,
                        help='Number of blocks to use as top-k during retrieval')
 
-    # finetune for retriever
-    group.add_argument('--eval-micro-batch-size', type=int, default=None,
-                       help='Eval Batch size per model instance (local batch '
-                            'size). Global batch size is local batch size '
-                            'times data parallel size.')
-    group.add_argument('--train-with-neg', action='store_true',
-                       help='Whether to use negative examples during model '
-                        'training')
-    group.add_argument('--train-hard-neg', type=int, default=0,
-                       help='Number of hard negative exmaples to use during '
-                        'training')
-
-
-    # parameters for Av.rank validation method
-    # Following options/arguments have been taken directly from DPR codebase
-    group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
-                        help='Av.rank validation: how many hard negatives to'
-                        ' take from each question pool')
-    group.add_argument('--val-av-rank-other-neg', type=int, default=30,
-                        help='Av.rank validation: how many other negatives to'
-                        ' take from each question pool')
-
-
     return parser
 
 
@@ -104,10 +81,8 @@ def get_tasks_args(parser):
         from glue.finetune import main
     elif args.task in ['LAMBADA', 'WIKITEXT103']:
         from zeroshot_gpt.evaluate import main
-    elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
+    elif args.task in ['ICT-ZEROSHOT-NQ']:
         from orqa.evaluate_orqa import main
-    elif args.task in ['RET-FINETUNE-NQ']:
-        from orqa.supervised.finetune import main
     else:
         raise NotImplementedError('Task {} is not implemented.'.format(
             args.task))
diff --git a/tasks/orqa/README.md b/tasks/orqa/README.md
deleted file mode 100644
index a8e8f8e6fab..00000000000
--- a/tasks/orqa/README.md
+++ /dev/null
@@ -1,36 +0,0 @@
-## End-to-End Training of Neural Retrievers for Open-Domain Question Answering
-
-Below we present the steps to run unsupervised and supervised trainining and evaluation of the retriever for [open domain question answering](https://arxiv.org/abs/2101.00408).
-
-## Retriever Training
-
-#### Unsupervised pretraining
-1. Use `tools/preprocess_data.py` to preprocess the dataset for Inverse Cloze Task (ICT), which we call unsupervised pretraining. This script takes as input a corpus in loose JSON format and creates fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block and multiple blocks per document. Run [`tools/preprocess_data.py`](../../tools/preprocess_data.py) to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. We construct two datasets, one with the title of every document and another with the body.
-
-
-python tools/preprocess_data.py \
-    --input /path/to/corpus.json \
-    --json-keys text title \
-    --split-sentences \
-    --tokenizer-type BertWordPieceLowerCase \
-    --vocab-file /path/to/vocab.txt \
-    --output-prefix corpus_indexed \
-    --workers 10
-
- -2. The [`examples/pretrain_ict.sh`](../../examples/pretrain_ict.sh) script runs a single GPU 217M parameter biencoder model for ICT retriever training. Single GPU training is primarily intended for debugging purposes, as the code is developed for distributed training. The script uses a pretrained BERT model and we use a total of batch size of 4096 for the ICT training. - -3. Evaluate the pretrained ICT model using [`examples/evaluate_retriever_nq.sh`](../../examples/evaluate_retriever_nq.sh) for [Google's Natural Questions Open dataset](https://arxiv.org/pdf/1906.00300.pdf). - -#### Supervised finetuning - -1. Use the above pretrained ICT model to finetune using [Google's Natural Questions Open dataset](https://github.com/google-research/language/tree/master/language/orqa). The script [`examples/finetune_retriever_distributed.sh`](../../examples/finetune_retriever_distributed.sh) provides an example for how to perform the training. Our finetuning process includes retriever score scaling and longer training (80 epochs) on top [DPR training](https://arxiv.org/abs/2004.04906). - -2. Evaluate the finetuned model using the same evaluation script as mentioned above for the unsupervised model. - -More details on the retriever are available in [our paper](https://arxiv.org/abs/2101.00408). - -## Reader Training - -The reader component will be available soon. - diff --git a/tasks/orqa/evaluate_orqa.py b/tasks/orqa/evaluate_orqa.py index 87c59ea30e2..7e6b269231a 100644 --- a/tasks/orqa/evaluate_orqa.py +++ b/tasks/orqa/evaluate_orqa.py @@ -15,8 +15,10 @@ """Main tasks functionality.""" -from megatron import get_args, print_rank_0 -from megatron.indexer import IndexBuilder +import os +import sys + +from megatron import get_args from tasks.orqa.evaluate_utils import ORQAEvaluator def main(): @@ -26,20 +28,6 @@ def main(): args = get_args() - """ - Create a BlockData data structure by running an IndexBuilder over an - ICT Dataset and then evaluate on NQ task - """ - - print_rank_0("Starting index builder!") - - index_builder = IndexBuilder() - index_builder.build_and_save_index() - print_rank_0("Build and save indices: done!") - - - print_rank_0("Starting evaluations!") - # Set up the model and evaluator evaluator = ORQAEvaluator() diff --git a/tasks/orqa/evaluate_utils.py b/tasks/orqa/evaluate_utils.py index 08b1e929b3e..ebee03522e1 100644 --- a/tasks/orqa/evaluate_utils.py +++ b/tasks/orqa/evaluate_utils.py @@ -18,14 +18,13 @@ from megatron import get_args, print_rank_0 from megatron.checkpointing import load_biencoder_checkpoint from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset +from tasks.orqa.natural_questions.nq import get_nq_dataset +from tasks.orqa.natural_questions.nq import get_one_epoch_nq_dataloader +from tasks.orqa.natural_questions.nq import process_nq_batch +from tasks.orqa.natural_questions.qa_utils import calculate_matches from megatron.data.realm_index import OpenRetreivalDataStore, FaissMIPSIndex -from megatron.model.biencoder_model import get_model_provider +from megatron.model.biencoder_model import biencoder_model_provider from megatron.training import get_model -from tasks.orqa.unsupervised.nq import get_nq_dataset -from tasks.orqa.unsupervised.nq import get_one_epoch_nq_dataloader -from tasks.orqa.unsupervised.nq import process_nq_batch -from tasks.orqa.unsupervised.qa_utils import calculate_matches - class ORQAEvaluator(object): def __init__(self): @@ -45,8 +44,9 @@ def __init__(self): if args.biencoder_shared_query_context_model: only_query_model = False - model = get_model(get_model_provider(only_query_model=only_query_model, - biencoder_shared_query_context_model=args.biencoder_shared_query_context_model)) + model = get_model(lambda: biencoder_model_provider(only_query_model=\ + only_query_model, biencoder_shared_query_context_model=\ + args.biencoder_shared_query_context_model)) self.model = load_biencoder_checkpoint(model, only_query_model=only_query_model) diff --git a/tasks/orqa/unsupervised/nq.py b/tasks/orqa/natural_questions/nq.py similarity index 100% rename from tasks/orqa/unsupervised/nq.py rename to tasks/orqa/natural_questions/nq.py diff --git a/tasks/orqa/unsupervised/qa_utils.py b/tasks/orqa/natural_questions/qa_utils.py similarity index 98% rename from tasks/orqa/unsupervised/qa_utils.py rename to tasks/orqa/natural_questions/qa_utils.py index 811a05834a4..24e71e683a4 100644 --- a/tasks/orqa/unsupervised/qa_utils.py +++ b/tasks/orqa/natural_questions/qa_utils.py @@ -22,7 +22,7 @@ from typing import Tuple, List, Dict import regex as re -from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer +from tasks.orqa.natural_questions.tokenizers import SimpleTokenizer logger = logging.getLogger(__name__) diff --git a/tasks/orqa/unsupervised/tokenizers.py b/tasks/orqa/natural_questions/tokenizers.py similarity index 100% rename from tasks/orqa/unsupervised/tokenizers.py rename to tasks/orqa/natural_questions/tokenizers.py diff --git a/tasks/orqa/supervised/data.py b/tasks/orqa/supervised/data.py deleted file mode 100644 index b45a842b61c..00000000000 --- a/tasks/orqa/supervised/data.py +++ /dev/null @@ -1,300 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ORQA dataset.""" - -import json -import random -from abc import ABC -from abc import abstractmethod - -import numpy as np -from torch.utils.data import Dataset - -from megatron import print_rank_0, get_args -from megatron.data.biencoder_dataset_utils import make_attention_mask - -def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length): - ctx_id_list, ctx_types_list = [], [] - for context in ctx_list: - title_ids = tokenizer.tokenize(context['title']) - ctx_ids = tokenizer.tokenize(context['text']) - ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids - - ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(ctx_ids, - max_seq_length, tokenizer.cls, - tokenizer.sep, tokenizer.pad) - ctx_id_list.append(ctx_ids) - ctx_types_list.append(ctx_types) - - return ctx_id_list, ctx_types_list - - -def build_tokens_types_paddings_from_text(query, context, - tokenizer, max_seq_length): - """Build token types and paddings, trim if needed, and pad if needed.""" - - query_ids = tokenizer.tokenize(query) - query_ids, query_types, query_pad_mask = \ - build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \ - tokenizer.cls, tokenizer.sep, tokenizer.pad) - - # Appending the title of the context at front - extended_ctx_ids = None - if context is not None: - title_ids = tokenizer.tokenize(context['title']) - ctx_ids = tokenizer.tokenize(context['text']) - extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids - - ctx_ids, ctx_types, ctx_pad_mask = \ - build_tokens_types_paddings_from_ids(extended_ctx_ids, - max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad) - - return query_ids, query_types, query_pad_mask, \ - ctx_ids, ctx_types, ctx_pad_mask - - -# Similar code tasks/data_utils with some changes -def build_tokens_types_paddings_from_ids(text_ids, max_seq_length, - cls_id, sep_id, pad_id): - """Build token types and paddings, trim if needed, and pad if needed.""" - enc_ids = [] - tokentypes_enc = [] - - # [CLS]. - enc_ids.append(cls_id) - tokentypes_enc.append(0) - - # A. - len_src = len(text_ids) - enc_ids.extend(text_ids) - tokentypes_enc.extend([0] * len_src) - - # Cap the size. - if len(enc_ids) > max_seq_length - 1: - enc_ids = enc_ids[0: max_seq_length - 1] - tokentypes_enc = tokentypes_enc[0: max_seq_length - 1] - - # [SEP]. - enc_ids.append(sep_id) - tokentypes_enc.append(0) - - num_tokens_enc = len(enc_ids) - # Padding. - padding_length = max_seq_length - len(enc_ids) - if padding_length > 0: - enc_ids.extend([pad_id] * padding_length) - tokentypes_enc.extend([pad_id] * padding_length) - - pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length) - pad_mask = np.array(pad_mask, dtype=np.int64) - - return enc_ids, tokentypes_enc, pad_mask - - -def build_sample(query_ids, query_types, query_pad_mask, - ctx_ids, ctx_types, ctx_pad_mask, answers, - neg_ctx_id_list=None, neg_ctx_types_list=None, - include_neg=False): - """Convert to numpy and return a sample consumed by the batch producer.""" - - query_ids = np.array(query_ids, dtype=np.int64) - query_types = np.array(query_types, dtype=np.int64) - query_mask = make_attention_mask(query_ids, query_ids) - - ctx_ids = np.array(ctx_ids, dtype=np.int64) - ctx_types = np.array(ctx_types, dtype=np.int64) - ctx_mask = make_attention_mask(ctx_ids, ctx_ids) - - sample = ({ - 'query': query_ids, - 'query_mask': query_mask, - 'query_types': query_types, - 'query_pad_mask': query_pad_mask, - 'context': ctx_ids, - 'context_mask': ctx_mask, - 'context_types': ctx_types, - 'context_pad_mask': ctx_pad_mask, - 'reference': answers - }) - - if include_neg: - neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64) - neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64) - neg_ctx_mask = np.array([make_attention_mask(ids, ids) \ - for ids in neg_ctx_ids], dtype=np.int64) - - sample['neg_context'] = neg_ctx_ids - sample['neg_context_types'] = neg_ctx_id_types - sample['neg_context_mask'] = neg_ctx_mask - - return sample - - -class OpenRetrievalAbstractDataset(ABC, Dataset): - """Open Retrieval base dataset class.""" - - def __init__(self, task_name, dataset_name, datapaths, tokenizer, \ - max_seq_length, evaluate=False): - # Store inputs. - args = get_args() - self.evaluate = evaluate - self.val_av_rank_hard_neg = args.val_av_rank_hard_neg - self.val_av_rank_other_neg = args.val_av_rank_other_neg - self.train_with_neg = args.train_with_neg - self.train_hard_neg = args.train_hard_neg - - self.task_name = task_name - self.dataset_name = dataset_name - self.tokenizer = tokenizer - self.max_seq_length = max_seq_length - print_rank_0(' > building {} dataset for {}:'.format(self.task_name, - self.dataset_name)) - # Process the files. - string = ' > paths:' - for path in datapaths: - string += ' ' + path - print_rank_0(string) - self.samples = [] - for datapath in datapaths: - self.samples.extend(self.process_samples_from_single_path(datapath)) - - args = get_args() - if args.sample_rate < 1: # subsample - k = int(len(self.samples) * args.sample_rate) - self.samples = random.sample(self.samples, k) - - print_rank_0(' >> total number of samples: {}'.format( - len(self.samples))) - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - raw_sample = self.samples[idx] - - query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \ - ctx_pad_mask = build_tokens_types_paddings_from_text( \ - raw_sample['question'], raw_sample['pos_context'], \ - self.tokenizer, self.max_seq_length) - - if self.evaluate: - neg_ctx_list = \ - raw_sample['negative_context'][:self.val_av_rank_other_neg] + \ - raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg] - neg_ctx_id_list, neg_ctx_types_list = \ - build_token_types_from_context_list(neg_ctx_list, \ - self.tokenizer, self.max_seq_length) - - elif self.train_with_neg: - hard_negative_ctx = raw_sample['hard_negative_context'] - negative_ctx = raw_sample['negative_context'] - if True: # TODO: fix this or remove this condition - random.shuffle(hard_negative_ctx) - random.shuffle(negative_ctx) - - neg_ctx_list = hard_negative_ctx[:self.train_hard_neg] - # In the Google NQ dataset by DPR paper, there are around more than - # 50 missing hard negatives in training data. - # In those cases, substitute hard negatives by simple negatives. - if len(neg_ctx_list) < self.train_hard_neg: - neg_ctx_list += negative_ctx[:self.train_hard_neg - \ - len(neg_ctx_list)] - - neg_ctx_id_list, neg_ctx_types_list = \ - build_token_types_from_context_list(neg_ctx_list, - self.tokenizer, self.max_seq_length) - else: - neg_ctx_id_list = None - neg_ctx_types_list = None - - sample = build_sample(query_ids, query_types, query_pad_mask, - ctx_ids, ctx_types, ctx_pad_mask, - raw_sample['answers'], - neg_ctx_id_list, neg_ctx_types_list, - include_neg=self.evaluate or self.train_with_neg) - - return sample - - @staticmethod - @abstractmethod - def process_samples_from_single_path(filename): - """Abstract method that takes a filename and - returns a list of dataset samples, each sample being a dict of - {'text': string, 'text': string} - """ - pass - - - -def normalize_question(question): - if question[-1] == '?': - question = question[:-1] - return question - -# The following class reads the datasets for training retriever as -# prepared by the DPR codebase (https://github.com/facebookresearch/DPR) - -class NQSupervisedDataset(OpenRetrievalAbstractDataset): - - def __init__(self, name, datapaths, tokenizer, max_seq_length, \ - evaluate=False): - super().__init__('natural_questions_ret', - name, - datapaths, - tokenizer, - max_seq_length, - evaluate=evaluate) - - @staticmethod - def process_samples_from_single_path(filename): - """"Implement abstract method.""" - print_rank_0(' > Processing {} ...'.format(filename)) - samples = [] - total = 0 - - with open(filename, 'r', encoding="utf-8") as f: - data = json.load(f) - for row in data: - question = normalize_question(row['question']) - pos_context = row['positive_ctxs'][0] - - # Hard Negative Contexts - if len(row['hard_negative_ctxs']) > 0: - hard_neg_context = row['hard_negative_ctxs'] - else: - hard_neg_context = [] - - # Negative Contexts - if len(row['negative_ctxs']) > 0: - neg_context = row['negative_ctxs'] - else: - neg_context = [] - - answers = row['answers'] - sample = {'question': question, - 'pos_context': pos_context, - 'hard_negative_context': hard_neg_context, - 'negative_context': neg_context, - 'answers': answers} - total += 1 - samples.append(sample) - - if total % 5000 == 0: - print_rank_0(' > processed {} so far ...'.format(total)) - - print_rank_0(' >> processed {} samples.'.format(len(samples))) - return samples - diff --git a/tasks/orqa/supervised/eval_utils.py b/tasks/orqa/supervised/eval_utils.py deleted file mode 100644 index 67dca512b0d..00000000000 --- a/tasks/orqa/supervised/eval_utils.py +++ /dev/null @@ -1,206 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Evaluation utilities.""" -from collections import OrderedDict -import math -import numpy as np -import time -import torch -import torch.nn.functional as F -from torch.utils.data import DataLoader - -from megatron import get_args, print_rank_0 -from megatron import mpu -from megatron.utils import average_losses_across_data_parallel_group -from tasks.finetune_utils import build_data_loader - -def task_collate_fn(batch_data): - # generate batch - batch_size = len(batch_data) - tensorized = OrderedDict() - for d in batch_data: - for k, v in d.items(): - tensorized.setdefault(k, []).append(v) - - tensorized['query'] = torch.LongTensor(tensorized['query']) - tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask']) - tensorized['query_types'] = torch.LongTensor(tensorized['query_types']) - tensorized['query_pad_mask'] = \ - torch.LongTensor(tensorized['query_pad_mask']) - - tensorized['context'] = torch.LongTensor(tensorized['context']) - tensorized['context_mask'] = \ - torch.LongTensor(tensorized['context_mask']) - tensorized['context_types'] = \ - torch.LongTensor(tensorized['context_types']) - tensorized['context_pad_mask'] = \ - torch.LongTensor(tensorized['context_pad_mask']) - - if 'neg_context' in tensorized: - tensorized['neg_context'] = \ - torch.LongTensor(np.concatenate(tensorized['neg_context'])) - tensorized['neg_context_mask'] = \ - torch.LongTensor(np.concatenate(tensorized['neg_context_mask'])) - tensorized['neg_context_types'] = \ - torch.LongTensor(np.concatenate(tensorized['neg_context_types'])) - - return tensorized - - - -def process_batch(batch): - """Process batch and produce inputs for the model.""" - query_tokens = batch['query'].long().cuda() - query_mask = (batch['query_mask'] < 0.5).cuda() - query_types = batch['query_types'].long().cuda() - query_pad_mask = batch['query_pad_mask'].long().cuda() - - context_tokens = batch['context'].long().cuda() - context_mask = (batch['context_mask'] < 0.5).cuda() - context_types = batch['context_types'].long().cuda() - context_pad_mask = batch['context_pad_mask'].long().cuda() - - if 'neg_context' in batch: - neg_context_tokens = batch['neg_context'].long().cuda() - neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda() - neg_context_types = batch['neg_context_types'].long().cuda() - else: - neg_context_tokens = None - neg_context_mask = None - neg_context_types = None - - reference = batch['reference'] - - return query_tokens, query_mask, query_types, query_pad_mask, \ - context_tokens, context_mask, context_types, context_pad_mask, \ - neg_context_tokens, neg_context_mask, neg_context_types, reference - -def accuracy_func_provider(single_dataset_provider, rank0sampler=False): - """Provide function that calculates accuracies.""" - args = get_args() - - print_rank_0("accuracy_func_provider is CALLED") - - # Build dataloaders - datapath = args.valid_data - dataset = single_dataset_provider(datapath) - - drop_last = False - if mpu.get_data_parallel_world_size() > 1 and not rank0sampler: - drop_last = True - - print_rank_0(datapath) - print_rank_0(rank0sampler) - - dataloader = build_data_loader(dataset, - args.eval_micro_batch_size, - num_workers=args.num_workers, - drop_last=drop_last, - task_collate_fn=task_collate_fn) - dataloaders = (dataset.dataset_name, dataloader) - - def metrics_func(model, epoch, output_predictions=False): - print_rank_0('calculating metrics by accuracy func in ORQA...') - - if output_predictions: - assert rank0sampler - names = 'predictions' - name, dataloader = dataloaders - if args.task == "RET-FINETUNE-NQ": - start_time = time.time() - output = retrieval_loss(model, dataloader) - stats_dict, total = output - format_string = "" - for k, v in stats_dict.items(): - format_string += "|{} = {:.2f}".format(k, v / total) - print_rank_0("epoch:{}{}".format(epoch, format_string)) - print_rank_0("taken time to calcuate metrics {:.3f}".format(\ - time.time() - start_time)) - else: - raise AssertionError("{} Task not supported".format(args.task)) - - return metrics_func - - -def retrieval_loss(model, dataloader): - args = get_args() - total = 0 - topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \ - args.retriever_report_topk_accuracies} - stats_dict = dict(rank=0, **topk_stats_dict) - - assert len(model) == 1 - unwrapped_model = model[0] - unwrapped_model.eval() - - with torch.no_grad(): - # For all the batches in the dataset. - for batch in dataloader: - # Run the model forward. - query_tokens, query_mask, query_types, _, \ - context_tokens, context_mask, context_types, _, \ - neg_context_tokens, neg_context_mask, neg_context_types, \ - reference = process_batch(batch) - - query_logits, context_logits = unwrapped_model(query_tokens, - query_mask, query_types, - torch.cat([context_tokens, neg_context_tokens]), - torch.cat([context_mask, neg_context_mask]), - torch.cat([context_types, neg_context_types])) - - retrieval_scores = torch.matmul(query_logits, - torch.transpose(context_logits, 0, 1)) - - if args.retriever_score_scaling: - retrieval_scores = retrieval_scores / \ - math.sqrt(args.hidden_size) - - local_batch_size = query_logits.shape[0] - labels = torch.arange(local_batch_size).long().cuda() - - softmax_scores = F.softmax(retrieval_scores, dim=1) - sorted_vals, sorted_indices = torch.topk(softmax_scores, - k=softmax_scores.shape[1], - sorted=True) - - def topk_accuracy(k): - return torch.cuda.FloatTensor( - [sum([int(labels[i] in sorted_indices[i, :k]) for i in \ - range(local_batch_size)])]) - - def get_rank(): - return torch.cuda.FloatTensor( - [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \ - for i in range(local_batch_size)])]) - - topk_accs = [topk_accuracy(k) for k in \ - args.retriever_report_topk_accuracies] - rank = get_rank() - losses = average_losses_across_data_parallel_group([rank, \ - *topk_accs]) - - # create stats_dict with retrieval loss and all specified - # top-k accuracies - topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \ - zip(args.retriever_report_topk_accuracies, losses[1:])} - temp_stats_dict = dict(rank=losses[0], **topk_acc_dict) - for k in stats_dict.keys(): - stats_dict[k] += temp_stats_dict[k] - total += local_batch_size - - unwrapped_model.train() - - return stats_dict, total diff --git a/tasks/orqa/supervised/finetune.py b/tasks/orqa/supervised/finetune.py deleted file mode 100644 index aed65ac9791..00000000000 --- a/tasks/orqa/supervised/finetune.py +++ /dev/null @@ -1,251 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""ORQA finetuning/evaluation.""" - -from functools import partial -import sys - -import math -import torch -import torch.nn.functional as F - -from megatron import get_args, get_timers, get_tokenizer -from megatron import mpu, print_rank_0 -from megatron.indexer import IndexBuilder -from megatron.model.biencoder_model import biencoder_model_provider -from megatron.utils import average_losses_across_data_parallel_group -from pretrain_ict import get_group_world_size_rank -from tasks.finetune_utils import finetune -from tasks.orqa.supervised.eval_utils import accuracy_func_provider -from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn -from tasks.orqa.evaluate_utils import ORQAEvaluator - -# input_ is a 2D tensor -def check_and_append_tensor_for_gather(group, rank, world_size, input_): - - # gather the size of the first dimension of the tensor from all ranks - current_length = input_.size()[0] - first_dim = torch.tensor([[current_length]], - device=torch.cuda.current_device()) - input_list = [torch.empty_like(first_dim) for _ in range(world_size)] - input_list[rank].copy_(first_dim) - torch.distributed.all_gather(input_list, first_dim, group=group) - all_input_list = torch.cat(input_list, dim=0).contiguous() - max_length = torch.max(all_input_list) - - # if the size are different than the max, extend the tensor - # accordingly - if max_length > current_length: - padding=tuple([0] * (input_.dim() * 2 - 1)) + \ - tuple([max_length - current_length]) - input_ = F.pad(input=input_, pad=padding) - - return input_ - -def orqa(Dataset): - - def cross_entropy_forward_step(batch, model): - """Simple forward step with cross-entropy loss.""" - timers = get_timers() - tokenizer = get_tokenizer() - - # Get the batch. - timers('batch generator').start() - try: - batch_ = next(batch) - except BaseException: - batch_ = batch - - group, rank, world_size = get_group_world_size_rank() - - query_tokens, query_mask, query_types, query_pad_mask, \ - context_tokens, context_mask, context_types, context_pad_mask, \ - neg_context_tokens, neg_context_mask, neg_context_types, \ - reference = process_batch(batch_) - - timers('batch generator').stop() - local_batch_size = query_tokens.shape[0] - - # Text representation of query and context - query_list, context_list = [], [] - for i in range(local_batch_size): - query_list.append(tokenizer.decode(query_tokens[i].tolist())) - context_list.append(tokenizer.decode(context_tokens[i].tolist())) - - if neg_context_tokens is not None: - neg_context_tokens = check_and_append_tensor_for_gather(group, - rank, world_size, neg_context_tokens) - neg_context_mask = check_and_append_tensor_for_gather(group, - rank, world_size, neg_context_mask) - neg_context_types = check_and_append_tensor_for_gather(group, - rank, world_size, neg_context_types) - - if neg_context_tokens is not None: - context_tokens = torch.cat([context_tokens, neg_context_tokens]) - context_mask = torch.cat([context_mask, neg_context_mask]) - context_types = torch.cat([context_types, neg_context_types]) - - # Forward model. - output_tensor = model(query_tokens, query_mask, - query_types, context_tokens, - context_mask, context_types) - return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens) - - - def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor): - args = get_args() - - local_batch_size = query_tokens.shape[0] - group, rank, world_size = get_group_world_size_rank() - # recall we assert that model_parallel_size == 1 - global_batch_size = world_size * local_batch_size - - query_logits, context_logits = output_tensor - - if world_size > 1: - input_ = torch.empty_like(context_logits).copy_(\ - context_logits).detach_() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank].copy_(input_) - torch.distributed.all_gather(tensor_list, input_, group=group) - - # Check if all-gather happens in order - assert tensor_list[rank].sum().item() == \ - context_logits.sum().item() - - # Preserves the gradient - tensor_list[rank] = context_logits - all_context_logits = torch.cat(tensor_list, dim=0).contiguous() - - # Query tensors - input_ = torch.empty_like(query_logits).copy_(\ - query_logits).detach_() - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank].copy_(input_) - torch.distributed.all_gather(tensor_list, input_, group=group) - - # Check if all-gather happens in order - assert tensor_list[rank].sum().item() == query_logits.sum().item() - - # Preserves the gradient - tensor_list[rank] = query_logits - all_query_logits = torch.cat(tensor_list, dim=0).contiguous() - else: - all_query_logits = query_logits - all_context_logits = context_logits - - retrieval_scores = torch.matmul(all_query_logits, - torch.transpose(all_context_logits, 0, 1)) - # Scaling the retrieval scores - if args.retriever_score_scaling: - retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size) - - if args.train_with_neg: - # if the world size is 3, local batch size is 4, and - # local context size is 8, what we want is - # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19] - labels = [] - local_context_size = context_tokens.shape[0] - for i in range(world_size): - j = i * local_context_size - labels.extend(list(range(j, j + local_batch_size))) - labels = torch.LongTensor(labels).cuda() - assert len(labels) == global_batch_size - else: - labels = torch.arange(global_batch_size).long().cuda() - - # Cross-entropy loss. - softmax_scores = F.log_softmax(retrieval_scores, dim=1) - - loss = F.nll_loss(softmax_scores, labels, reduction='mean') - - max_score, max_idxs = torch.max(softmax_scores, 1) - correct_predictions_count = (max_idxs == labels).sum().float() - - # Reduce loss for logging. - reduced_loss = average_losses_across_data_parallel_group([loss, \ - correct_predictions_count]) - - # Loss scaling for correct losses in Supervised Retrieval - loss = loss * mpu.get_data_parallel_world_size() - - return loss, {'lm loss': reduced_loss[0], - 'correct_prediction_count': reduced_loss[1]} - - - def train_valid_datasets_provider(): - """Build train and validation dataset.""" - args = get_args() - tokenizer = get_tokenizer() - - train_dataset = Dataset('training', - args.train_data, - tokenizer, - args.retriever_seq_length, - evaluate=False) - valid_dataset = Dataset('validation', - args.valid_data, - tokenizer, - args.retriever_seq_length, - evaluate=True) - return train_dataset, valid_dataset - - def model_provider(pre_process=True, post_process=True): - """Build the model.""" - args = get_args() - print_rank_0('building retriever model for {} ...'.format(args.task)) - - model = biencoder_model_provider(only_context_model=False, - only_query_model=False, - biencoder_shared_query_context_model=\ - args.biencoder_shared_query_context_model, - pre_process=pre_process, post_process=post_process) - - return model - - def single_dataset_provider(datapath): - args = get_args() - tokenizer = get_tokenizer() - - name = datapath[0].split('/')[-1].split('.')[0] - return Dataset(name, - datapath, - tokenizer, - args.retriever_seq_length, - evaluate=True) - - def metrics_func_provider(): - """Provide metrics callback function.""" - return accuracy_func_provider(single_dataset_provider) - - """Finetune/evaluate.""" - finetune(train_valid_datasets_provider, - model_provider, - forward_step=cross_entropy_forward_step, - end_of_epoch_callback_provider=metrics_func_provider, - task_collate_fn=task_collate_fn) - -def main(): - args = get_args() - - if args.task == 'RET-FINETUNE-NQ': - from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset - else: - raise NotImplementedError('ORQA task {} is not implemented.'.format( - args.task)) - - orqa(Dataset) - diff --git a/tasks/vision/classification.py b/tasks/vision/classification.py index 71e840757e3..5232b3f5492 100644 --- a/tasks/vision/classification.py +++ b/tasks/vision/classification.py @@ -34,14 +34,13 @@ def train_valid_datasets_provider(): ) return train_ds, valid_ds - def model_provider(pre_process=True, post_process=True): + def model_provider(): """Build the model.""" args = get_args() print_rank_0("building classification model for ImageNet ...") - return VitModel(num_classes=args.num_classes, finetune=True, - pre_process=pre_process, post_process=post_process) + return VitModel(num_classes=args.num_classes, finetune=True) """Finetune/evaluate.""" finetune( diff --git a/tasks/vision/eval_utils.py b/tasks/vision/eval_utils.py index 3a194119c17..aabc04a1594 100644 --- a/tasks/vision/eval_utils.py +++ b/tasks/vision/eval_utils.py @@ -16,14 +16,10 @@ """Evaluation utilities.""" import os -from functools import partial - import torch - from megatron import get_args -from megatron import print_rank_0, print_rank_last +from megatron import print_rank_0 from megatron import mpu -from megatron.schedules import get_forward_backward_func from tasks.vision.finetune_utils import build_data_loader from tasks.vision.finetune_utils import process_batch from torchvision import datasets, transforms @@ -60,7 +56,7 @@ def metrics_func(model, epoch): print_rank_0("calculating metrics ...") correct, total = calculate_correct_answers(model, dataloader, epoch) percent = float(correct) * 100.0 / float(total) - print_rank_last( + print_rank_0( " >> |epoch: {}| overall: correct / total = {} / {} = " "{:.4f} %".format(epoch, correct, total, percent) ) @@ -71,61 +67,29 @@ def metrics_func(model, epoch): def calculate_correct_answers(model, dataloader, epoch): """Calculate correct over total answers""" - args = get_args() - forward_backward_func = get_forward_backward_func() - for m in model: - m.eval() - - def loss_func(labels, output_tensor): - logits = output_tensor - - loss_dict = {} - # Compute the correct answers. - predicted = torch.argmax(logits, dim=-1) - corrects = (predicted == labels).float() - # Add to the counters. - loss_dict['total'] = labels.size(0) - loss_dict['correct'] = corrects.sum().item() - - return 0, loss_dict - - #defined inside to capture output_predictions - def correct_answers_forward_step(batch, model): - try: - batch_ = next(batch) - except BaseException: - batch_ = batch - images, labels = process_batch(batch_) - - # Forward model. - args = get_args() - output_tensor = model(images) - - return output_tensor, partial(loss_func, labels) - + model.eval() with torch.no_grad(): # For all the batches in the dataset. total = 0 correct = 0 for _, batch in enumerate(dataloader): - - loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model, - optimizer=None, timers=None, forward_only=True) - - for loss_dict in loss_dicts: - total += loss_dict['total'] - correct += loss_dict['correct'] - - for m in model: - m.train() + # Run the model forward. + images, labels = process_batch(batch) + logits = model(images).contiguous().float() + # Add output predictions. + # Compute the correct answers. + predicted = torch.argmax(logits, dim=-1) + corrects = (predicted == labels).float() + # Add to the counters. + total += labels.size(0) + correct += corrects.sum().item() + model.train() # Reduce. - if mpu.is_pipeline_last_stage(): - unreduced = torch.cuda.LongTensor([correct, total]) - torch.distributed.all_reduce(unreduced, - group=mpu.get_data_parallel_group()) + unreduced = torch.cuda.LongTensor([correct, total]) + torch.distributed.all_reduce(unreduced, group=mpu.get_data_parallel_group()) - # Print on screen. - correct_ans = unreduced[0].item() - total_count = unreduced[1].item() - return correct_ans, total_count + # Print on screen. + correct_ans = unreduced[0].item() + total_count = unreduced[1].item() + return correct_ans, total_count diff --git a/tasks/vision/finetune_utils.py b/tasks/vision/finetune_utils.py index f9743883ca3..afde4aa8940 100644 --- a/tasks/vision/finetune_utils.py +++ b/tasks/vision/finetune_utils.py @@ -17,7 +17,6 @@ import torch import torch.nn.functional as F -from functools import partial from megatron import get_args from megatron import print_rank_0 from megatron import get_timers @@ -39,21 +38,10 @@ def process_batch(batch): return images, labels -def cross_entropy_loss_func(labels, output_tensor): - logits = output_tensor - - # Cross-entropy loss. - loss = F.cross_entropy(logits.contiguous().float(), labels) - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {'lm loss': averaged_loss[0]} - - -def _cross_entropy_forward_step(batch, model): +def _cross_entropy_forward_step(batch, model, input_tensor): """Simple forward step with cross-entropy loss.""" timers = get_timers() + assert input_tensor is None # Get the batch. timers("batch generator").start() @@ -64,10 +52,16 @@ def _cross_entropy_forward_step(batch, model): images, labels = process_batch(batch_) timers("batch generator").stop() - # Forward model. - output_tensor = model(images) - - return output_tensor, partial(cross_entropy_loss_func, labels) + # Forward model. + logits = model(images).contiguous().float() + + # Cross-entropy loss. + loss = F.cross_entropy(logits, labels) + + # Reduce loss for logging. + average_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {"lm loss": average_loss[0]} def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): @@ -109,29 +103,24 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset): """Traing and validation dataloaders.""" args = get_args() - print_rank_0('building train and validation dataloaders ...') + print_rank_0("building train and validation dataloaders ...") # Training dataset. - train_dataloader = build_data_loader(train_dataset, args.micro_batch_size, - args.num_workers, not args.keep_last) + train_dataloader = build_data_loader( + train_dataset, args.micro_batch_size, args.num_workers, not args.keep_last + ) # Set the training iterations. args.train_iters_per_epoch = len(train_dataloader) args.train_iters = args.epochs * args.train_iters_per_epoch # Validation dataset. For this dataset, we do not need to set up # shuffling so we can just use a simple infinite loop. - valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size, - args.num_workers, not args.keep_last) + valid_dataloader_ = build_data_loader( + valid_dataset, args.micro_batch_size, args.num_workers, not args.keep_last + ) valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) - # Now that we've built the data loaders, set batch_size arguments - # to the actual batch size the model will see for this dataset. - # This is necessary so pipeline transfers know what size they are - # and the LR schedule, which is based on samples seen, gets set - # correctly. - args.orig_micro_batch_size = args.micro_batch_size - args.orig_global_batch_size = args.global_batch_size - return train_dataloader, valid_dataloader + def _train( model, optimizer, @@ -146,8 +135,7 @@ def _train( timers = get_timers() # Turn on training mode which enables dropout. - for m in model: - m.train() + model.train() # Tracking loss. losses_dict_sum = {} @@ -178,16 +166,12 @@ def _train( start_iteration = 0 # Train for one step. - losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step( + losses_dict, skipped_iter = train_step( forward_step, batch, model, optimizer, lr_scheduler ) iteration += 1 # Logging. - params_norm = None - if args.log_params_norm: - params_norm = calc_params_l2_norm(model) - report_memory_flag = training_log( losses_dict, losses_dict_sum, @@ -196,9 +180,6 @@ def _train( optimizer.get_loss_scale().item(), report_memory_flag, skipped_iter, - grad_norm, - params_norm, - num_zeros_in_grad ) # Autoresume diff --git a/tools/create_doc_index.py b/tools/create_doc_index.py new file mode 100644 index 00000000000..4448d0e2908 --- /dev/null +++ b/tools/create_doc_index.py @@ -0,0 +1,33 @@ +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.pardir))) + +from megatron import print_rank_0 +from megatron.indexer import IndexBuilder +from megatron.initialize import initialize_megatron + + +def main(): + """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset + - Include all args needed for initial model specification + + Other key args: + --block-data-path: path to write to + --ict-load or --realm-load: path to checkpoint with which to embed + --data-path and --titles-data-path: paths for dataset + --indexer-log-interval: reporting interval + --indexer-batch-size: size specific for indexer jobs + + Check README.md for example script + """ + + initialize_megatron(extra_args_provider=None, + args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'}) + index_builder = IndexBuilder() + index_builder.build_and_save_index() + print_rank_0("Build and save indices: done!") + +if __name__ == "__main__": + main() +