From a05a0358fba4758deb8539b3f062d882b70509ac Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 11:10:58 -0800 Subject: [PATCH 01/29] Make HfDatasetDataModule a datasets.load_dataset wrapper Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 56 ++++++++++++++++++--- 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 0f45ecf265b7..125b16efadd9 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -16,12 +16,30 @@ import torch from torch.utils.data import DataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler +from datasets import load_dataset +import datasets.dataset_dict.DatasetDict +def listify(x): + if isinstance(x, list): + return x + return [x] + +def extract_split(dataset, split_names): + if isinstance(dataset, datasets.dataset_dict.DatasetDict): + for split_name in split_names: + if split_name in dataset: + return dataset[split_name] + raise ValueError(("Dataset does not contain any of " + str(split_names) + \ + "; available splits= " + str(dataset.keys())) + ) + else: + return dataset class HFDatasetDataModule(pl.LightningDataModule): def __init__( self, - dataset, + path, + split=None, num_workers=2, pin_memory=True, persistent_workers=True, @@ -31,11 +49,13 @@ def __init__( pad_token_id=0, use_mcore_sampler=False, mcore_dataloader_type='cyclic', + **kwargs, ) -> None: super().__init__() assert pad_token_id is not None - self.dataset = dataset + self.dataset = load_dataset(path, **kwargs) + self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers @@ -84,17 +104,41 @@ def setup(self, stage: str): dataloader_type=self.mcore_dataloader_type, ) - def train_dataloader(self, collate_fn=None): - from nemo.lightning.data import add_megatron_sampler - + def _make_dataloader(self, dataset, collate_fn=None): if collate_fn is None: collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) return DataLoader( - self.dataset, + dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, collate_fn=collate_fn, batch_size=self.micro_batch_size, ) + + def train_dataloader(self, collate_fn=None, split_names=["train", "training"]): + dataset = extract_split(self.dataset, split_names) + return self._make_dataloader(dataset, collate_fn) + + def val_dataloader(self, collate_fn=None, split_names=["val", "validation", "eval"]): + dataset = extract_split(self.dataset, split_names) + return self._make_dataloader(dataset, collate_fn) + + def test_dataloader(self, collate_fn=None, split_names=["test", "testing"]): + dataset = extract_split(self.dataset, split_names) + return self._make_dataloader(dataset, collate_fn) + + def map(self, function=None, split_names=None, **kwargs): + if split_names is not None: + dataset = extract_split(self.dataset, split_names) + else: + dataset = self.dataset + + if isinstance(dataset, datasets.dataset_dict.DatasetDict): + dataset_iter = dataset.values() + else: + dataset_iter = [dataset] + + for subset in dataset_iter: + subset.map(function, **kwargs) \ No newline at end of file From 8cacad8529fd2f7133754bb777e15fdd38023e8b Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 11:23:13 -0800 Subject: [PATCH 02/29] add logging Signed-off-by: Alexandros Koumparoulis --- 
nemo/collections/llm/gpt/data/hf_dataset.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 125b16efadd9..63b368cb613f 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import lightning.pytorch as pl -import torch -from torch.utils.data import DataLoader -from nemo.lightning.pytorch.plugins import MegatronDataSampler from datasets import load_dataset +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging +from torch.utils.data import DataLoader + import datasets.dataset_dict.DatasetDict +import lightning.pytorch as pl +import torch def listify(x): if isinstance(x, list): @@ -54,7 +56,15 @@ def __init__( super().__init__() assert pad_token_id is not None + logging.info(f"Loading HF dataset from {path}") + self.dataset = load_dataset(path, **kwargs) + if isinstance(self.dataset, datasets.dataset_dict.DatasetDict): + split_names = self.dataset.keys() + logging.info(f"HF dataset has the following splits: {split_names}") + else: + logging.info(f"Loaded HF dataset has a single split.") + self.num_workers = num_workers self.pin_memory = pin_memory From d863a5844b82d0d08c309b95c3614ea5b627b07c Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 11:24:43 -0800 Subject: [PATCH 03/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 63b368cb613f..5527781565b4 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -141,14 +141,14 @@ def test_dataloader(self, collate_fn=None, split_names=["test", "testing"]): def map(self, function=None, split_names=None, **kwargs): if split_names is not None: - dataset = extract_split(self.dataset, split_names) + datasets = extract_split(self.dataset, split_names) else: - dataset = self.dataset + datasets = self.dataset if isinstance(dataset, datasets.dataset_dict.DatasetDict): - dataset_iter = dataset.values() + dataset_iter = datasets.values() else: - dataset_iter = [dataset] + dataset_iter = [datasets] for subset in dataset_iter: subset.map(function, **kwargs) \ No newline at end of file From 23ff8504b0a399398755881007c2181b87308f40 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 11:28:22 -0800 Subject: [PATCH 04/29] Update HFDatasetDataModule Signed-off-by: Alexandros Koumparoulis --- examples/llm/peft/hf.py | 14 +++++--------- nemo/collections/llm/gpt/data/api.py | 4 ++-- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 357dc5a7bd17..c24c5958b388 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -18,7 +18,7 @@ from nemo.collections import llm -def mk_hf_dataset(tokenizer): +def make_squad_hf_dataset(tokenizer): EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN def formatting_prompts_func(examples): @@ -45,11 +45,9 @@ def formatting_prompts_func(examples): 'labels': tokens[1:] + [tokens[-1]], } - from datasets import load_dataset - - dataset = load_dataset("rajpurkar/squad", split="train") - dataset = 
dataset.map(formatting_prompts_func, batched=False, batch_size=2) - return dataset + datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=tokenizer.eos_token_id) + datamodule.map(formatting_prompts_func, batched=False, batch_size=2) + return datamodule if __name__ == '__main__': @@ -80,9 +78,7 @@ def formatting_prompts_func(examples): llm.api.finetune( model=llm.HFAutoModelForCausalLM(args.model), - data=llm.HFDatasetDataModule( - mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id - ), + data=make_squad_hf_dataset(tokenizer.tokenizer), trainer=nl.Trainer( devices=args.devices, max_steps=args.max_steps, diff --git a/nemo/collections/llm/gpt/data/api.py b/nemo/collections/llm/gpt/data/api.py index 374bee83b8b2..b4e603186bf4 100644 --- a/nemo/collections/llm/gpt/data/api.py +++ b/nemo/collections/llm/gpt/data/api.py @@ -41,8 +41,8 @@ def dolly() -> pl.LightningDataModule: @run.cli.factory @run.autoconvert -def hf_dataset(dataset: str) -> pl.LightningDataModule: - return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2) +def hf_dataset(path: str) -> pl.LightningDataModule: + return HFDatasetDataModule(path=path, global_batch_size=16, micro_batch_size=2) __all__ = ["mock", "squad", "dolly", "hf_dataset"] From ca25a9688e5d404a352ee1578c01374558ce5ed0 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:13:50 -0800 Subject: [PATCH 05/29] refactor Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 115 ++++++++++++++++---- 1 file changed, 91 insertions(+), 24 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 5527781565b4..edaae38116ee 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -26,16 +26,80 @@ def listify(x): return x return [x] -def extract_split(dataset, split_names): +def is_dataset_dict(dataset): + return isinstance(dataset, datasets.dataset_dict.DatasetDict) + +def extract_matching_split(dataset, split_names): + assert is_dataset_dict(dataset) + for split_name in split_names: + if split_name in dataset: + return dataset[split_name] + raise ValueError(("Dataset does not contain any of " + str(split_names) + \ + "; available splits= " + str(dataset.keys())) + ) + + +def make_dataset_splits(path, split=None, **kwargs): + """ + Loads a dataset with datasets.load_dataset and returns a dict containing dataset splits, + For example: + + ans = make_dataset_splits("dataset-id") + $ ds = load_dataset("dataset-id") + $ print(ds) + > DatasetDict({ + > train: Dataset({ + > features: ['id', 'title', 'context', 'question', 'answers'], + > num_rows: 87599 + > }) + > validation: Dataset({ + > features: ['id', 'title', 'context', 'question', 'answers'], + > num_rows: 10570 + > }) + > }) + + In this case the value of `ans` (returned value) will be: + $ print(ans) + > { + > "train": Dataset .. (with 87599 rows), + > "val": Dataset .. 
(with 10570 rows), + > } + """ + split_names = ['train', 'test', 'val'] + dataset_splits = {split: None for split in split_names} + + alias_to_split = {} + for split_name, aliases in zip(split_names, [train_aliases, test_aliases, val_aliases]): + for alias in aliases: + alias_to_split[alias] = split_name + + dataset = load_dataset(path, split=split, **kwargs) + if isinstance(dataset, datasets.dataset_dict.DatasetDict): - for split_name in split_names: - if split_name in dataset: - return dataset[split_name] - raise ValueError(("Dataset does not contain any of " + str(split_names) + \ - "; available splits= " + str(dataset.keys())) - ) + dataset_split_names = dataset.keys() + logging.info(f"HF dataset has the following splits: {dataset_split_names}") + for alias_split_name, split in dataset.items(): + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = split + elif isinstance(split, list): + logging.info(f"Loaded HF dataset will use " + str(self.split_names) + " splits.") + assert isinstance(dataset, list) + for i, alias_split_name in enumerate(self.split_names): + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = dataset[i] + elif isinstance(split, str): + logging.info(f"Loaded HF dataset has a single split.") + assert not isinstance(dataset, list) + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = dataset else: - return dataset + raise ValueError("Expected split name to be None, str or a list") + + assert sum(map(lambda x: x is not None, dataset_splits.values())) > 0, "Expected at least one dataset to have been initialized" + return dataset_splits class HFDatasetDataModule(pl.LightningDataModule): def __init__( @@ -51,6 +115,9 @@ def __init__( pad_token_id=0, use_mcore_sampler=False, mcore_dataloader_type='cyclic', + train_aliases=["train", "training"], + test_aliases=["test", "testing"], + val_aliases=["val", "validation", "eval"], **kwargs, ) -> None: super().__init__() @@ -58,13 +125,8 @@ def __init__( logging.info(f"Loading HF dataset from {path}") - self.dataset = load_dataset(path, **kwargs) - if isinstance(self.dataset, datasets.dataset_dict.DatasetDict): - split_names = self.dataset.keys() - logging.info(f"HF dataset has the following splits: {split_names}") - else: - logging.info(f"Loaded HF dataset has a single split.") - + # self.dataset_splits will hold the actual dataset for each split. 
+ self.dataset_splits = make_dataset_splits(path, split, **kwargs) self.num_workers = num_workers self.pin_memory = pin_memory @@ -127,17 +189,22 @@ def _make_dataloader(self, dataset, collate_fn=None): batch_size=self.micro_batch_size, ) - def train_dataloader(self, collate_fn=None, split_names=["train", "training"]): - dataset = extract_split(self.dataset, split_names) - return self._make_dataloader(dataset, collate_fn) + def _extract_split_from_dict(self, split_names): + if is_dataset_dict(self.dataset): + return extract_matching_split(self.dataset, split_names) + else: + if self.split is not None: + assert any(map(lambda x: x in split_names, self.split)) + return self.dataset + + def train_dataloader(self, collate_fn=None): + return self._make_dataloader(self.dataset_splits['train'], collate_fn) - def val_dataloader(self, collate_fn=None, split_names=["val", "validation", "eval"]): - dataset = extract_split(self.dataset, split_names) - return self._make_dataloader(dataset, collate_fn) + def val_dataloader(self, collate_fn=None): + return self._make_dataloader(self.dataset_splits['val'], collate_fn) - def test_dataloader(self, collate_fn=None, split_names=["test", "testing"]): - dataset = extract_split(self.dataset, split_names) - return self._make_dataloader(dataset, collate_fn) + def test_dataloader(self, collate_fn=None): + return self._make_dataloader(self.dataset_splits['test'], collate_fn) def map(self, function=None, split_names=None, **kwargs): if split_names is not None: From 635499e1df61e07f42b63080f3f3100e3f544a09 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:14:29 -0800 Subject: [PATCH 06/29] refactor fixup Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 26 --------------------- 1 file changed, 26 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index edaae38116ee..bb8e533fb4cd 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -21,24 +21,6 @@ import lightning.pytorch as pl import torch -def listify(x): - if isinstance(x, list): - return x - return [x] - -def is_dataset_dict(dataset): - return isinstance(dataset, datasets.dataset_dict.DatasetDict) - -def extract_matching_split(dataset, split_names): - assert is_dataset_dict(dataset) - for split_name in split_names: - if split_name in dataset: - return dataset[split_name] - raise ValueError(("Dataset does not contain any of " + str(split_names) + \ - "; available splits= " + str(dataset.keys())) - ) - - def make_dataset_splits(path, split=None, **kwargs): """ Loads a dataset with datasets.load_dataset and returns a dict containing dataset splits, @@ -189,14 +171,6 @@ def _make_dataloader(self, dataset, collate_fn=None): batch_size=self.micro_batch_size, ) - def _extract_split_from_dict(self, split_names): - if is_dataset_dict(self.dataset): - return extract_matching_split(self.dataset, split_names) - else: - if self.split is not None: - assert any(map(lambda x: x in split_names, self.split)) - return self.dataset - def train_dataloader(self, collate_fn=None): return self._make_dataloader(self.dataset_splits['train'], collate_fn) From b47c40dbcfd0861a6ae985a6c1c775906db4edc4 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:16:44 -0800 Subject: [PATCH 07/29] refactor fixup #2 Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 15 ++++++--------- 1 file changed, 6 
insertions(+), 9 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index bb8e533fb4cd..fb1ff9b7ff63 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -181,15 +181,12 @@ def test_dataloader(self, collate_fn=None): return self._make_dataloader(self.dataset_splits['test'], collate_fn) def map(self, function=None, split_names=None, **kwargs): - if split_names is not None: - datasets = extract_split(self.dataset, split_names) + if isintance(split_names, str): + dataset_splits = [self.dataset[split_names]] + elif isintance(split_names, list): + dataset_splits = [self.dataset[k] for k in split_names] else: - datasets = self.dataset + dataset_splits = self.dataset.values() - if isinstance(dataset, datasets.dataset_dict.DatasetDict): - dataset_iter = datasets.values() - else: - dataset_iter = [datasets] - - for subset in dataset_iter: + for subset in dataset_splits: subset.map(function, **kwargs) \ No newline at end of file From 56c29ac31020f1b5c33e086a97a11358d8df0652 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:17:33 -0800 Subject: [PATCH 08/29] do not expand Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index fb1ff9b7ff63..76937274c794 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -21,7 +21,7 @@ import lightning.pytorch as pl import torch -def make_dataset_splits(path, split=None, **kwargs): +def make_dataset_splits(path, split, kwargs): """ Loads a dataset with datasets.load_dataset and returns a dict containing dataset splits, For example: @@ -108,7 +108,7 @@ def __init__( logging.info(f"Loading HF dataset from {path}") # self.dataset_splits will hold the actual dataset for each split. 
- self.dataset_splits = make_dataset_splits(path, split, **kwargs) + self.dataset_splits = make_dataset_splits(path, split, kwargs) self.num_workers = num_workers self.pin_memory = pin_memory From 8ce2ba7ee1780c0e5812b45545b2fff41dae4a89 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:28:40 -0800 Subject: [PATCH 09/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 76937274c794..69c8a7f69928 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -74,6 +74,11 @@ def make_dataset_splits(path, split, kwargs): elif isinstance(split, str): logging.info(f"Loaded HF dataset has a single split.") assert not isinstance(dataset, list) + alias_split_name = split + if '+' in alias_split_name: + raise ValueError("Split concatenation not supported") + elif '[' in alias_split_name: + alias_split_name = alias_split_name.split('[')[0] split_name = alias_to_split[alias_split_name] assert dataset_splits[split_name] is None dataset_splits[split_name] = dataset From e55dc2ef674f4ddb618ee6dcc34e9dc275bdcde9 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:30:34 -0800 Subject: [PATCH 10/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 69c8a7f69928..851c4a9ff60f 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -23,7 +23,9 @@ def make_dataset_splits(path, split, kwargs): """ - Loads a dataset with datasets.load_dataset and returns a dict containing dataset splits, + Loads a dataset with datasets.load_dataset and + returns a dictionary containing all dataset splits. 
+ For example: ans = make_dataset_splits("dataset-id") @@ -85,7 +87,8 @@ def make_dataset_splits(path, split, kwargs): else: raise ValueError("Expected split name to be None, str or a list") - assert sum(map(lambda x: x is not None, dataset_splits.values())) > 0, "Expected at least one dataset to have been initialized" + assert sum(map(lambda x: x is not None, dataset_splits.values())) > 0, \ + "Expected at least one dataset to have been initialized" return dataset_splits class HFDatasetDataModule(pl.LightningDataModule): From 1b711aa84672b2bbdce15e7fa8a1c4e743d14e66 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:50:50 -0800 Subject: [PATCH 11/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 851c4a9ff60f..f672e1624b1a 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -17,7 +17,7 @@ from nemo.utils import logging from torch.utils.data import DataLoader -import datasets.dataset_dict.DatasetDict +import datasets.dataset_dict import lightning.pytorch as pl import torch From 8047ecbd44da1bcd7fc556edcfed440136f9b38e Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:56:47 -0800 Subject: [PATCH 12/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index f672e1624b1a..204bf13bdc6e 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -21,7 +21,7 @@ import lightning.pytorch as pl import torch -def make_dataset_splits(path, split, kwargs): +def make_dataset_splits(path, split, split_aliases, kwargs): """ Loads a dataset with datasets.load_dataset and returns a dictionary containing all dataset splits. @@ -49,15 +49,17 @@ def make_dataset_splits(path, split, kwargs): > "val": Dataset .. 
(with 10570 rows), > } """ + dataset = load_dataset(path, split=split, **kwargs) + split_names = ['train', 'test', 'val'] dataset_splits = {split: None for split in split_names} alias_to_split = {} - for split_name, aliases in zip(split_names, [train_aliases, test_aliases, val_aliases]): - for alias in aliases: + for split_name, _split_aliases in split_aliases.items(): + assert split_name in split_names + for alias in _split_aliases: alias_to_split[alias] = split_name - dataset = load_dataset(path, split=split, **kwargs) if isinstance(dataset, datasets.dataset_dict.DatasetDict): dataset_split_names = dataset.keys() @@ -67,9 +69,9 @@ def make_dataset_splits(path, split, kwargs): assert dataset_splits[split_name] is None dataset_splits[split_name] = split elif isinstance(split, list): - logging.info(f"Loaded HF dataset will use " + str(self.split_names) + " splits.") + logging.info(f"Loaded HF dataset will use " + str(split) + " splits.") assert isinstance(dataset, list) - for i, alias_split_name in enumerate(self.split_names): + for i, alias_split_name in enumerate(split): split_name = alias_to_split[alias_split_name] assert dataset_splits[split_name] is None dataset_splits[split_name] = dataset[i] @@ -116,7 +118,9 @@ def __init__( logging.info(f"Loading HF dataset from {path}") # self.dataset_splits will hold the actual dataset for each split. - self.dataset_splits = make_dataset_splits(path, split, kwargs) + self.dataset_splits = make_dataset_splits(path, split, { + 'train':train_aliases, 'test': test_aliases, 'val': val_aliases}, + kwargs) self.num_workers = num_workers self.pin_memory = pin_memory From 1b5989d4d4352881368a9e952487f1e78c4af1e9 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:57:29 -0800 Subject: [PATCH 13/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 204bf13bdc6e..f83e893c0b65 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -118,9 +118,12 @@ def __init__( logging.info(f"Loading HF dataset from {path}") # self.dataset_splits will hold the actual dataset for each split. - self.dataset_splits = make_dataset_splits(path, split, { - 'train':train_aliases, 'test': test_aliases, 'val': val_aliases}, - kwargs) + split_aliases = { + 'train': train_aliases, + 'test': test_aliases, + 'val': val_aliases + } + self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs) self.num_workers = num_workers self.pin_memory = pin_memory From cd446be83f8cddeff8ec6835aba47dd1bc808893 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:59:02 -0800 Subject: [PATCH 14/29] doc Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index f83e893c0b65..ec1854ecb351 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -117,12 +117,16 @@ def __init__( logging.info(f"Loading HF dataset from {path}") - # self.dataset_splits will hold the actual dataset for each split. + # A dataset usually will have several splits (e.g. train, val, test, etc). 
+ # We canonicalize synonym names to canonical names (train, test, val). + # A synonym can be a prefix/suffixed word e.g. train <> training. split_aliases = { 'train': train_aliases, 'test': test_aliases, 'val': val_aliases } + + # self.dataset_splits will hold the actual dataset for each split. self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs) self.num_workers = num_workers From 2777d2f54f69a02b97dc225db5e764830d969f62 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 12:59:30 -0800 Subject: [PATCH 15/29] doc Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index ec1854ecb351..4cccf023314d 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -118,7 +118,7 @@ def __init__( logging.info(f"Loading HF dataset from {path}") # A dataset usually will have several splits (e.g. train, val, test, etc). - # We canonicalize synonym names to canonical names (train, test, val). + # We map synonym names to canonical names (train, test, val). # A synonym can be a prefix/suffixed word e.g. train <> training. split_aliases = { 'train': train_aliases, From 24782462cfb7a45397ccb0ea2d1d6ec7ffbcc2e1 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 13:00:11 -0800 Subject: [PATCH 16/29] add synonym Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 4cccf023314d..404a6289e9fe 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -109,7 +109,7 @@ def __init__( mcore_dataloader_type='cyclic', train_aliases=["train", "training"], test_aliases=["test", "testing"], - val_aliases=["val", "validation", "eval"], + val_aliases=["val", "validation", "valid", "eval"], **kwargs, ) -> None: super().__init__() From 957e7884a97aff73127d7a1fd2663ff66105cd7a Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 13:15:19 -0800 Subject: [PATCH 17/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 404a6289e9fe..373547157e37 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -200,12 +200,12 @@ def test_dataloader(self, collate_fn=None): return self._make_dataloader(self.dataset_splits['test'], collate_fn) def map(self, function=None, split_names=None, **kwargs): - if isintance(split_names, str): - dataset_splits = [self.dataset[split_names]] - elif isintance(split_names, list): - dataset_splits = [self.dataset[k] for k in split_names] + if isinstance(split_names, str): + dataset_splits = [self.dataset_splits[split_names]] + elif isinstance(split_names, list): + dataset_splits = [self.dataset_splits[k] for k in split_names] else: - dataset_splits = self.dataset.values() + dataset_splits = self.dataset_splits.values() for subset in dataset_splits: - subset.map(function, **kwargs) \ No newline at end of file + subset.map(function, **kwargs) From 
3272f0f92de0fdc6b7b1404cb9f37906462f9649 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 13:15:44 -0800 Subject: [PATCH 18/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 373547157e37..e8687fccb3e3 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -208,4 +208,6 @@ def map(self, function=None, split_names=None, **kwargs): dataset_splits = self.dataset_splits.values() for subset in dataset_splits: + if subset is None: + continue subset.map(function, **kwargs) From 024f668b531259c55993b51dd3bec914b9ccba52 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 13:27:11 -0800 Subject: [PATCH 19/29] fix Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index e8687fccb3e3..bd941181f3f9 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -178,6 +178,8 @@ def setup(self, stage: str): ) def _make_dataloader(self, dataset, collate_fn=None): + assert dataset is not None + if collate_fn is None: collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) @@ -201,13 +203,13 @@ def test_dataloader(self, collate_fn=None): def map(self, function=None, split_names=None, **kwargs): if isinstance(split_names, str): - dataset_splits = [self.dataset_splits[split_names]] + dataset_splits = {'split_names': self.dataset_splits[split_names]} elif isinstance(split_names, list): - dataset_splits = [self.dataset_splits[k] for k in split_names] + dataset_splits = {k: self.dataset_splits[k] for k in split_names} else: - dataset_splits = self.dataset_splits.values() + dataset_splits = self.dataset_splits - for subset in dataset_splits: + for split_name, subset in dataset_splits.items(): if subset is None: continue - subset.map(function, **kwargs) + dataset_splits[split_name] = subset.map(function, **kwargs) From cf24f5f30f322bf636039994c9efd3ed97806a07 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Dec 2024 13:31:14 -0800 Subject: [PATCH 20/29] typo Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index bd941181f3f9..2ef52362e038 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -203,7 +203,7 @@ def test_dataloader(self, collate_fn=None): def map(self, function=None, split_names=None, **kwargs): if isinstance(split_names, str): - dataset_splits = {'split_names': self.dataset_splits[split_names]} + dataset_splits = {split_names: self.dataset_splits[split_names]} elif isinstance(split_names, list): dataset_splits = {k: self.dataset_splits[k] for k in split_names} else: From 2ea77e174b64540c34a3854c11973ce65724fe5e Mon Sep 17 00:00:00 2001 From: akoumpa Date: Fri, 6 Dec 2024 21:32:07 +0000 Subject: [PATCH 21/29] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/llm/gpt/data/hf_dataset.py | 22 ++++++++++----------- 1 file changed, 10 
insertions(+), 12 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 2ef52362e038..f36c8465e6ff 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datasets.dataset_dict +import lightning.pytorch as pl +import torch from datasets import load_dataset +from torch.utils.data import DataLoader + from nemo.lightning.pytorch.plugins import MegatronDataSampler from nemo.utils import logging -from torch.utils.data import DataLoader -import datasets.dataset_dict -import lightning.pytorch as pl -import torch def make_dataset_splits(path, split, split_aliases, kwargs): """ @@ -60,7 +61,6 @@ def make_dataset_splits(path, split, split_aliases, kwargs): for alias in _split_aliases: alias_to_split[alias] = split_name - if isinstance(dataset, datasets.dataset_dict.DatasetDict): dataset_split_names = dataset.keys() logging.info(f"HF dataset has the following splits: {dataset_split_names}") @@ -89,10 +89,12 @@ def make_dataset_splits(path, split, split_aliases, kwargs): else: raise ValueError("Expected split name to be None, str or a list") - assert sum(map(lambda x: x is not None, dataset_splits.values())) > 0, \ - "Expected at least one dataset to have been initialized" + assert ( + sum(map(lambda x: x is not None, dataset_splits.values())) > 0 + ), "Expected at least one dataset to have been initialized" return dataset_splits + class HFDatasetDataModule(pl.LightningDataModule): def __init__( self, @@ -120,11 +122,7 @@ def __init__( # A dataset usually will have several splits (e.g. train, val, test, etc). # We map synonym names to canonical names (train, test, val). # A synonym can be a prefix/suffixed word e.g. train <> training. - split_aliases = { - 'train': train_aliases, - 'test': test_aliases, - 'val': val_aliases - } + split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases} # self.dataset_splits will hold the actual dataset for each split. 
self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs) From 4053b912532966ccdd90e8217bbe521cda90101e Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 10 Dec 2024 12:05:59 -0800 Subject: [PATCH 22/29] Add train/val/test attributes Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index f36c8465e6ff..483b2fdfc15b 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -190,14 +190,26 @@ def _make_dataloader(self, dataset, collate_fn=None): batch_size=self.micro_batch_size, ) + @property + def train(self): + return self.dataset_splits['train'] + + @property + def val(self): + return self.dataset_splits['val'] + + @property + def test(self): + return self.dataset_splits['test'] + def train_dataloader(self, collate_fn=None): - return self._make_dataloader(self.dataset_splits['train'], collate_fn) + return self._make_dataloader(self.train, collate_fn) def val_dataloader(self, collate_fn=None): - return self._make_dataloader(self.dataset_splits['val'], collate_fn) + return self._make_dataloader(self.val, collate_fn) def test_dataloader(self, collate_fn=None): - return self._make_dataloader(self.dataset_splits['test'], collate_fn) + return self._make_dataloader(self.test, collate_fn) def map(self, function=None, split_names=None, **kwargs): if isinstance(split_names, str): From 1749ee74898822e027b7bebd512751dcfafebe20 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 10 Dec 2024 12:06:19 -0800 Subject: [PATCH 23/29] Add test for hf-datamodule Signed-off-by: Alexandros Koumparoulis --- .../llm/gpt/data/test_hf_datamodule.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/collections/llm/gpt/data/test_hf_datamodule.py diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py new file mode 100644 index 000000000000..fddfd33163eb --- /dev/null +++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +import nemo.lightning as nl +from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.collections import llm + + +DATA_PATH = "/home/TestData/lite/hf_cache/squad/" + +def test_load_single_split(): + ds = llm.HFDatasetDataModule( + path=DATA_PATH, + split='train', + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + ) + from datasets.arrow_dataset import Dataset + assert isinstance(ds.dataset_splits, dict) + assert len(ds.dataset_splits) == 3 + assert 'train' in ds.dataset_splits + assert ds.dataset_splits['train'] is not None + assert ds.train is not None + assert isinstance(ds.dataset_splits['train'], Dataset) + assert 'val' in ds.dataset_splits + assert ds.dataset_splits['val'] is None + assert ds.val is None + assert 'test' in ds.dataset_splits + assert ds.dataset_splits['test'] is None + assert ds.test is None + +def test_load_nonexistent_split(): + exception_msg = '' + expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].''' + try: + llm.HFDatasetDataModule( + path=DATA_PATH, + split='this_split_name_should_not_exist', + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + ) + except ValueError as e: + exception_msg = str(e) + assert exception_msg == expected_msg, exception_msg + +def test_load_multiple_split(): + ds = llm.HFDatasetDataModule( + path=DATA_PATH, + split=['train', 'validation'], + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + ) + from datasets.arrow_dataset import Dataset + assert isinstance(ds.dataset_splits, dict) + assert len(ds.dataset_splits) == 3 + assert 'train' in ds.dataset_splits + assert ds.dataset_splits['train'] is not None + assert ds.train is not None + assert isinstance(ds.dataset_splits['train'], Dataset) + assert isinstance(ds.train, Dataset) + assert 'val' in ds.dataset_splits + assert ds.dataset_splits['val'] is not None + assert ds.val is not None + assert isinstance(ds.dataset_splits['val'], Dataset) + assert isinstance(ds.val, Dataset) + assert 'test' in ds.dataset_splits + assert ds.dataset_splits['test'] is None + assert ds.test is None + + + +def test_validate_dataset_asset_accessibility_file_does_not_exist(): + raised_exception = False + try: + data = llm.HFDatasetDataModule( + path="/this/path/should/not/exist/", + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + ) + except FileNotFoundError: + raised_exception = True + + assert raised_exception == True, "Expected to raise a FileNotFoundError" + + +def test_validate_dataset_asset_accessibility_file_is_none(): #tokenizer, trainer): + raised_exception = False + try: + data = llm.HFDatasetDataModule( + path=None, + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + ) + except TypeError: + raised_exception = True + + assert raised_exception == True, "Expected to raise a ValueError" From eceee08390cf16e362f4fb0caac3d2388e137b7a Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 10 Dec 2024 12:07:23 -0800 Subject: [PATCH 24/29] Import lazily to avoid breaking with older megatron versions Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/inference/base.py | 6 +++--- nemo/collections/llm/t5/model/t5.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py index 795d6efadd3a..d40f55ddd387 100644 --- 
a/nemo/collections/llm/inference/base.py +++ b/nemo/collections/llm/inference/base.py @@ -25,9 +25,6 @@ from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( AbstractModelInferenceWrapper, ) -from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( - EncoderDecoderTextGenerationController, -) from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import ( SimpleTextGenerationController, ) @@ -232,6 +229,9 @@ def generate( Returns: dict: A dictionary containing the generated results. """ + from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( + EncoderDecoderTextGenerationController, + ) if encoder_prompts is not None: text_generation_controller = EncoderDecoderTextGenerationController( inference_wrapped_model=model, tokenizer=tokenizer diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py index 940c0e51ee92..cc4ad0665757 100644 --- a/nemo/collections/llm/t5/model/t5.py +++ b/nemo/collections/llm/t5/model/t5.py @@ -20,7 +20,7 @@ import torch import torch.distributed from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig -from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper + from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.spec_utils import ModuleSpec @@ -319,7 +319,7 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, padded_vocab_size=self.tokenizer.vocab_size, ) - + from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config) return model_inference_wrapper From 107d0a4029aa0a8f60b4097ebda9f331b32378fa Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 10 Dec 2024 12:24:45 -0800 Subject: [PATCH 25/29] bot happy Signed-off-by: Alexandros Koumparoulis --- tests/collections/llm/gpt/data/test_hf_datamodule.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py index fddfd33163eb..22ae62e4e26c 100644 --- a/tests/collections/llm/gpt/data/test_hf_datamodule.py +++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py @@ -15,8 +15,6 @@ import pytest import nemo.lightning as nl -from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule -from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.collections import llm @@ -89,7 +87,7 @@ def test_load_multiple_split(): def test_validate_dataset_asset_accessibility_file_does_not_exist(): raised_exception = False try: - data = llm.HFDatasetDataModule( + llm.HFDatasetDataModule( path="/this/path/should/not/exist/", seq_length=512, micro_batch_size=2, @@ -104,7 +102,7 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist(): def test_validate_dataset_asset_accessibility_file_is_none(): #tokenizer, trainer): raised_exception = False try: - data = llm.HFDatasetDataModule( + llm.HFDatasetDataModule( path=None, seq_length=512, micro_batch_size=2, From ab307354de810c7ccb20c09b2b82e89e6b6f60ac Mon Sep 
17 00:00:00 2001 From: akoumpa Date: Tue, 10 Dec 2024 20:25:40 +0000 Subject: [PATCH 26/29] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/llm/inference/base.py | 1 + nemo/collections/llm/t5/model/t5.py | 1 + tests/collections/llm/gpt/data/test_hf_datamodule.py | 8 ++++++-- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py index d40f55ddd387..6c89a1b42b15 100644 --- a/nemo/collections/llm/inference/base.py +++ b/nemo/collections/llm/inference/base.py @@ -232,6 +232,7 @@ def generate( from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( EncoderDecoderTextGenerationController, ) + if encoder_prompts is not None: text_generation_controller = EncoderDecoderTextGenerationController( inference_wrapped_model=model, tokenizer=tokenizer diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py index cc4ad0665757..743d16f57c2b 100644 --- a/nemo/collections/llm/t5/model/t5.py +++ b/nemo/collections/llm/t5/model/t5.py @@ -320,6 +320,7 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres padded_vocab_size=self.tokenizer.vocab_size, ) from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper + model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config) return model_inference_wrapper diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py index 22ae62e4e26c..28e9c33db6bd 100644 --- a/tests/collections/llm/gpt/data/test_hf_datamodule.py +++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py @@ -20,6 +20,7 @@ DATA_PATH = "/home/TestData/lite/hf_cache/squad/" + def test_load_single_split(): ds = llm.HFDatasetDataModule( path=DATA_PATH, @@ -29,6 +30,7 @@ def test_load_single_split(): global_batch_size=2, ) from datasets.arrow_dataset import Dataset + assert isinstance(ds.dataset_splits, dict) assert len(ds.dataset_splits) == 3 assert 'train' in ds.dataset_splits @@ -42,6 +44,7 @@ def test_load_single_split(): assert ds.dataset_splits['test'] is None assert ds.test is None + def test_load_nonexistent_split(): exception_msg = '' expected_msg = '''Unknown split "this_split_name_should_not_exist". 
Should be one of ['train', 'validation'].''' @@ -57,6 +60,7 @@ def test_load_nonexistent_split(): exception_msg = str(e) assert exception_msg == expected_msg, exception_msg + def test_load_multiple_split(): ds = llm.HFDatasetDataModule( path=DATA_PATH, @@ -66,6 +70,7 @@ def test_load_multiple_split(): global_batch_size=2, ) from datasets.arrow_dataset import Dataset + assert isinstance(ds.dataset_splits, dict) assert len(ds.dataset_splits) == 3 assert 'train' in ds.dataset_splits @@ -83,7 +88,6 @@ def test_load_multiple_split(): assert ds.test is None - def test_validate_dataset_asset_accessibility_file_does_not_exist(): raised_exception = False try: @@ -99,7 +103,7 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist(): assert raised_exception == True, "Expected to raise a FileNotFoundError" -def test_validate_dataset_asset_accessibility_file_is_none(): #tokenizer, trainer): +def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trainer): raised_exception = False try: llm.HFDatasetDataModule( From 19c4da4d3d8db12975165eb9a466a791ea1449b6 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 10 Dec 2024 13:07:24 -0800 Subject: [PATCH 27/29] bot happy2 Signed-off-by: Alexandros Koumparoulis --- tests/collections/llm/gpt/data/test_hf_datamodule.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py index 28e9c33db6bd..a8d264701d39 100644 --- a/tests/collections/llm/gpt/data/test_hf_datamodule.py +++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest - -import nemo.lightning as nl from nemo.collections import llm - DATA_PATH = "/home/TestData/lite/hf_cache/squad/" From afaaf8a74819d85ae5aae499262ca03f886a676a Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Tue, 10 Dec 2024 13:15:44 -0800 Subject: [PATCH 28/29] add doc-strings and collate-fn arg Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/data/hf_dataset.py | 30 ++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 483b2fdfc15b..0dd43111826e 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -96,9 +96,22 @@ def make_dataset_splits(path, split, split_aliases, kwargs): class HFDatasetDataModule(pl.LightningDataModule): + """HFDatasetDataModule wraps HF's load_dataset (datasets library) + so that it can be used within NeMo. + Users can select whether to use an mcore-sampler via use_mcore_sampler arg. + + Usage examples: + + - loading a single split (train) from a dataset + llm.HFDatasetDataModule("rajpurkar/squad", split="train") + + - loading multiple splits (train, validation) from a dataset + llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"]) + """ def __init__( self, path, + collate_fn=None, split=None, num_workers=2, pin_memory=True, @@ -127,6 +140,11 @@ def __init__( # self.dataset_splits will hold the actual dataset for each split. 
self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs) + if collate_fn is None: + self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) + else: + self._collate_fn = collate_fn + self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers @@ -202,14 +220,14 @@ def val(self): def test(self): return self.dataset_splits['test'] - def train_dataloader(self, collate_fn=None): - return self._make_dataloader(self.train, collate_fn) + def train_dataloader(self): + return self._make_dataloader(self.train, self._collate_fn) - def val_dataloader(self, collate_fn=None): - return self._make_dataloader(self.val, collate_fn) + def val_dataloader(self): + return self._make_dataloader(self.val, self._collate_fn) - def test_dataloader(self, collate_fn=None): - return self._make_dataloader(self.test, collate_fn) + def test_dataloader(self): + return self._make_dataloader(self.test, self._collate_fn) def map(self, function=None, split_names=None, **kwargs): if isinstance(split_names, str): From 8843d9ed021440d0f5ff84568092edb9fd4d0e80 Mon Sep 17 00:00:00 2001 From: akoumpa Date: Tue, 10 Dec 2024 21:21:28 +0000 Subject: [PATCH 29/29] Apply isort and black reformatting Signed-off-by: akoumpa --- nemo/collections/llm/gpt/data/hf_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 0dd43111826e..039e5b90b096 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -108,6 +108,7 @@ class HFDatasetDataModule(pl.LightningDataModule): - loading multiple splits (train, validation) from a dataset llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"]) """ + def __init__( self, path,
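
A usage sketch of the API as it stands after the last patch in this series. It is illustrative only: the dataset id matches the squad example from PATCH 04, but the add_length helper, the batch sizes, and the asserted values below are hypothetical and not part of the patches themselves.

    from nemo.collections import llm

    # Load two splits. "validation" appears in val_aliases, so it is
    # canonicalized to "val" and becomes reachable as datamodule.val.
    datamodule = llm.HFDatasetDataModule(
        "rajpurkar/squad",
        split=["train", "validation"],
        micro_batch_size=2,
        global_batch_size=16,
        pad_token_id=0,
    )

    # map() forwards to datasets.Dataset.map on every loaded split,
    # or only on the splits named via split_names.
    def add_length(example):
        example["context_length"] = len(example["context"])
        return example

    datamodule.map(add_length)

    # Canonical split accessors; splits that were not loaded stay None.
    assert datamodule.train is not None
    assert datamodule.test is None

    # Each dataloader batches micro_batch_size rows and pads them with
    # the default collate_fn (HFDatasetDataModule.collate_fn called with
    # the configured pad_token_id); a custom collate_fn can instead be
    # passed to the constructor.
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()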