From 90a1536ef8af45b7eadbc0bd1a8c724afb3070f6 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 11 Dec 2024 14:36:46 -0800
Subject: [PATCH 1/7] Add from_dict method

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/gpt/data/hf_dataset.py | 79 ++++++++++++---------
 1 file changed, 47 insertions(+), 32 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 039e5b90b096..58555fd23956 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import datasets.dataset_dict
 import lightning.pytorch as pl
 import torch
 from datasets import load_dataset
@@ -22,38 +21,39 @@
 from nemo.utils import logging
 
 
-def make_dataset_splits(path, split, split_aliases, kwargs):
+def make_dataset_splits(dataset, split, split_aliases):
     """
-    Loads a dataset with datasets.load_dataset and
-    returns a dictionary containing all dataset splits.
+    Given a dataset (e.g. from datasets.load_dataset or datasets.Dataset.from_dict), it
+    returns a dictionary containing the corresponding dataset splits.
 
     For example:
-    ans = make_dataset_splits("dataset-id")
-    $ ds = load_dataset("dataset-id")
-    $ print(ds)
-    > DatasetDict({
-    >     train: Dataset({
-    >         features: ['id', 'title', 'context', 'question', 'answers'],
-    >         num_rows: 87599
-    >     })
-    >     validation: Dataset({
-    >         features: ['id', 'title', 'context', 'question', 'answers'],
-    >         num_rows: 10570
-    >     })
-    > })
-
-    In this case the value of `ans` (returned value) will be:
+    $ ds = load_dataset("dataset-id")
+    $ ans = make_dataset_splits(ds)
+
+    # `ds` contains the following
+    $ print(ds)
+    > DatasetDict({
+    >     train: Dataset({
+    >         features: ['id', 'title', 'context', 'question', 'answers'],
+    >         num_rows: 87599
+    >     })
+    >     validation: Dataset({
+    >         features: ['id', 'title', 'context', 'question', 'answers'],
+    >         num_rows: 10570
+    >     })
+    > })
+
+    # In this case the value of `ans` (returned value) will be:
     $ print(ans)
     > {
    >     "train": Dataset .. (with 87599 rows),
    >     "val": Dataset .. (with 10570 rows),
    > }
    """
-    dataset = load_dataset(path, split=split, **kwargs)
-
+    from datasets import Dataset, DatasetDict
     split_names = ['train', 'test', 'val']
-    dataset_splits = {split: None for split in split_names}
+    dataset_splits = {_split: None for _split in split_names}
 
     alias_to_split = {}
     for split_name, _split_aliases in split_aliases.items():
@@ -61,7 +61,10 @@ def make_dataset_splits(path, split, split_aliases, kwargs):
         for alias in _split_aliases:
             alias_to_split[alias] = split_name
 
-    if isinstance(dataset, datasets.dataset_dict.DatasetDict):
+    if isinstance(dataset, Dataset):
+        assert isinstance(split, str), "Expected split to be a string, but got " + str(type(split))
+        dataset_splits[split] = dataset
+    elif isinstance(dataset, DatasetDict):
         dataset_split_names = dataset.keys()
         logging.info(f"HF dataset has the following splits: {dataset_split_names}")
         for alias_split_name, split in dataset.items():
@@ -89,9 +92,8 @@ def make_dataset_splits(path, split, split_aliases, kwargs):
     else:
         raise ValueError("Expected split name to be None, str or a list")
 
-    assert (
-        sum(map(lambda x: x is not None, dataset_splits.values())) > 0
-    ), "Expected at least one dataset to have been initialized"
+    num_init_splits = sum(map(lambda x: x is not None, dataset_splits.values()))
+    assert num_init_splits > 0, f"Expected at least one split to have been initialized, got {num_init_splits}"
 
     return dataset_splits
 
@@ -111,9 +113,9 @@ class HFDatasetDataModule(pl.LightningDataModule):
 
     def __init__(
         self,
-        path,
-        collate_fn=None,
+        path_or_dataset,
         split=None,
+        collate_fn=None,
         num_workers=2,
         pin_memory=True,
         persistent_workers=True,
@@ -130,16 +132,23 @@ def __init__(
     ) -> None:
         super().__init__()
         assert pad_token_id is not None
-
-        logging.info(f"Loading HF dataset from {path}")
-
+        from datasets import Dataset, DatasetDict
         # A dataset usually will have several splits (e.g. train, val, test, etc).
         # We map synonym names to canonical names (train, test, val).
         # A synonym can be a prefix/suffixed word e.g. train <> training.
         split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases}
 
         # self.dataset_splits will hold the actual dataset for each split.
-        self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)
+        if isinstance(path_or_dataset, str):
+            logging.info(f"Loading HF dataset from {path_or_dataset}")
+            dataset = load_dataset(path_or_dataset, split=split, **kwargs)
+        elif isinstance(path_or_dataset, Dataset) or isinstance(path_or_dataset, DatasetDict):
+            dataset = path
+        else:
+            raise ValueError("Expecter `path_or_dataset` to be str, Dataset, DatasetDict, but got "\
+                + str(type(path_or_dataset)))
+
+        self.dataset_splits = make_dataset_splits(dataset, split, split_aliases)
 
         if collate_fn is None:
             self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
@@ -157,6 +166,12 @@ def __init__(
         self.use_mcore_sampler = use_mcore_sampler
         self.mcore_dataloader_type = mcore_dataloader_type
 
+    @staticmethod
+    def from_dict(dataset_dict, split, **kwargs):
+        from datasets import Dataset
+        dataset = Dataset.from_dict(dataset_dict)
+        return HFDatasetDataModule(path=dataset, split=split, **kwargs)
+
     @staticmethod
     def collate_fn(batch, pad_token_id=0):
         def batchify(tensor):

From 66724196f03619200fab27afcf0d63f871c5172a Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 11 Dec 2024 14:39:54 -0800
Subject: [PATCH 2/7] add test_load_from_dict

Signed-off-by: Alexandros Koumparoulis
---
 .../llm/gpt/data/test_hf_datamodule.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py
index a8d264701d39..6d83be92146f 100644
--- a/tests/collections/llm/gpt/data/test_hf_datamodule.py
+++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -112,3 +112,22 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
         raised_exception = True
 
     assert raised_exception == True, "Expected to raise a ValueError"
+
+def test_load_from_dict():
+    data = {'text': "Below is an instruction that describes a task, paired with an input that "}
+
+    datamodule = llm.HFDatasetDataModule.from_dict(
+        {"text": [data['text'] for _ in range(101)]},
+        split='train',
+        global_batch_size=4,
+        micro_batch_size=1,
+    )
+    assert is not None datamodule
+    assert isinstance(datamodule, llm.HFDatasetDataModule)
+    assert hasattr(datamodule, 'train')
+    assert datamodule.train is not None
+    assert len(datamodule.train) == 101
+    assert hasattr(datamodule, 'val')
+    assert datamodule.val is None
+    assert hasattr(datamodule, 'test')
+    assert datamodule.test is None
\ No newline at end of file

From a10326e578eafab5632d81635e8c6f2728fbb8e1 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 11 Dec 2024 14:42:21 -0800
Subject: [PATCH 3/7] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/gpt/data/hf_dataset.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 58555fd23956..9530a473f97a 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -141,7 +141,7 @@ def __init__(
         # self.dataset_splits will hold the actual dataset for each split.
         if isinstance(path_or_dataset, str):
             logging.info(f"Loading HF dataset from {path_or_dataset}")
-            dataset = load_dataset(path_or_dataset, split=split, **kwargs)
+            dataset = load_dataset(path_or_dataset, split=split, **kwargs)
         elif isinstance(path_or_dataset, Dataset) or isinstance(path_or_dataset, DatasetDict):
             dataset = path

From 0644c57388a706fa593466e3bcc881c6841db924 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 11 Dec 2024 14:44:27 -0800
Subject: [PATCH 4/7] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/gpt/data/hf_dataset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 9530a473f97a..060b802ffc9f 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -143,7 +143,8 @@ def __init__(
             logging.info(f"Loading HF dataset from {path_or_dataset}")
             dataset = load_dataset(path_or_dataset, split=split, **kwargs)
         elif isinstance(path_or_dataset, Dataset) or isinstance(path_or_dataset, DatasetDict):
-            dataset = path
+            logging.info(f"Using passed HF dataset {str(path_or_dataset)}")
+            dataset = path_or_dataset
         else:
             raise ValueError("Expecter `path_or_dataset` to be str, Dataset, DatasetDict, but got "\
                 + str(type(path_or_dataset)))

From 8e82e57d9ce7b9f4087f7a5095c7eef57b766930 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 11 Dec 2024 14:44:39 -0800
Subject: [PATCH 5/7] add test_load_from_dict

Signed-off-by: Alexandros Koumparoulis
---
 tests/collections/llm/gpt/data/test_hf_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py
index 6d83be92146f..135166738acb 100644
--- a/tests/collections/llm/gpt/data/test_hf_datamodule.py
+++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -114,7 +114,7 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
 
     assert raised_exception == True, "Expected to raise a ValueError"
 
 def test_load_from_dict():
-    data = {'text': "Below is an instruction that describes a task, paired with an input that "} 
+    data = {'text': "Below is an instruction that describes a task, paired with an input that "}
 
     datamodule = llm.HFDatasetDataModule.from_dict(

From 11f1202c0ff0682da2e37091d63fa943d4d5a5ac Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 11 Dec 2024 14:49:08 -0800
Subject: [PATCH 6/7] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo/collections/llm/gpt/data/hf_dataset.py |  4 ++--
 .../llm/gpt/data/test_hf_datamodule.py      | 16 ++++++++--------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index 060b802ffc9f..ded77f77cfc7 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -146,7 +146,7 @@ def __init__(
             logging.info(f"Using passed HF dataset {str(path_or_dataset)}")
             dataset = path_or_dataset
         else:
-            raise ValueError("Expecter `path_or_dataset` to be str, Dataset, DatasetDict, but got "\
+            raise ValueError("Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got "\
                 + str(type(path_or_dataset)))
 
         self.dataset_splits = make_dataset_splits(dataset, split, split_aliases)
@@ -171,7 +171,7 @@ def __init__(
     def from_dict(dataset_dict, split, **kwargs):
         from datasets import Dataset
         dataset = Dataset.from_dict(dataset_dict)
-        return HFDatasetDataModule(path=dataset, split=split, **kwargs)
+        return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)
 
     @staticmethod
     def collate_fn(batch, pad_token_id=0):
diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py
index 135166738acb..cad804b2678d 100644
--- a/tests/collections/llm/gpt/data/test_hf_datamodule.py
+++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -19,7 +19,7 @@
 
 def test_load_single_split():
     ds = llm.HFDatasetDataModule(
-        path=DATA_PATH,
+        path_or_dataset=DATA_PATH,
         split='train',
         seq_length=512,
         micro_batch_size=2,
@@ -46,7 +46,7 @@ def test_load_nonexistent_split():
     expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
     try:
         llm.HFDatasetDataModule(
-            path=DATA_PATH,
+            path_or_dataset=DATA_PATH,
             split='this_split_name_should_not_exist',
             seq_length=512,
             micro_batch_size=2,
@@ -59,7 +59,7 @@ def test_load_nonexistent_split():
 
 def test_load_multiple_split():
     ds = llm.HFDatasetDataModule(
-        path=DATA_PATH,
+        path_or_dataset=DATA_PATH,
         split=['train', 'validation'],
         seq_length=512,
         micro_batch_size=2,
@@ -88,7 +88,7 @@ def test_validate_dataset_asset_accessibility_file_does_not_exist():
     raised_exception = False
     try:
         llm.HFDatasetDataModule(
-            path="/this/path/should/not/exist/",
+            path_or_dataset="/this/path/should/not/exist/",
             seq_length=512,
             micro_batch_size=2,
             global_batch_size=2,
@@ -103,13 +103,13 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
     raised_exception = False
     try:
         llm.HFDatasetDataModule(
-            path=None,
+            path_or_dataset=None,
             seq_length=512,
             micro_batch_size=2,
             global_batch_size=2,
         )
-    except TypeError:
-        raised_exception = True
+    except ValueError as e:
+        raised_exception = str(e).startswith("Expected `path_or_dataset` to be str, Dataset, DatasetDict")
 
     assert raised_exception == True, "Expected to raise a ValueError"
 
@@ -122,7 +122,7 @@ def test_load_from_dict():
         global_batch_size=4,
         micro_batch_size=1,
     )
-    assert is not None datamodule
+    assert datamodule is not None
     assert isinstance(datamodule, llm.HFDatasetDataModule)
     assert hasattr(datamodule, 'train')

From 329d7372f64807498d7393e59f909216259d5c3b Mon Sep 17 00:00:00 2001
From: akoumpa
Date: Wed, 11 Dec 2024 22:49:53 +0000
Subject: [PATCH 7/7] Apply isort and black reformatting

Signed-off-by: akoumpa
---
 nemo/collections/llm/gpt/data/hf_dataset.py          | 8 ++++++--
 tests/collections/llm/gpt/data/test_hf_datamodule.py | 7 +++++--
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py
index ded77f77cfc7..73b6444a6e9c 100644
--- a/nemo/collections/llm/gpt/data/hf_dataset.py
+++ b/nemo/collections/llm/gpt/data/hf_dataset.py
@@ -52,6 +52,7 @@ def make_dataset_splits(dataset, split, split_aliases):
     > }
     """
     from datasets import Dataset, DatasetDict
+
     split_names = ['train', 'test', 'val']
     dataset_splits = {_split: None for _split in split_names}
 
@@ -133,6 +134,7 @@ def __init__(
         super().__init__()
         assert pad_token_id is not None
         from datasets import Dataset, DatasetDict
+
         # A dataset usually will have several splits (e.g. train, val, test, etc).
         # We map synonym names to canonical names (train, test, val).
         # A synonym can be a prefix/suffixed word e.g. train <> training.
@@ -146,8 +148,9 @@ def __init__(
             logging.info(f"Using passed HF dataset {str(path_or_dataset)}")
             dataset = path_or_dataset
         else:
-            raise ValueError("Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got "\
-                + str(type(path_or_dataset)))
+            raise ValueError(
+                "Expected `path_or_dataset` to be str, Dataset, DatasetDict, but got " + str(type(path_or_dataset))
+            )
 
         self.dataset_splits = make_dataset_splits(dataset, split, split_aliases)
 
@@ -170,6 +173,7 @@ def __init__(
     @staticmethod
     def from_dict(dataset_dict, split, **kwargs):
         from datasets import Dataset
+
         dataset = Dataset.from_dict(dataset_dict)
         return HFDatasetDataModule(path_or_dataset=dataset, split=split, **kwargs)
 
diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py
index cad804b2678d..58f7c02e091b 100644
--- a/tests/collections/llm/gpt/data/test_hf_datamodule.py
+++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -109,10 +109,13 @@ def test_validate_dataset_asset_accessibility_file_is_none(): # tokenizer, trai
             global_batch_size=2,
         )
     except ValueError as e:
-        raised_exception = str(e).startswith("Expected `path_or_dataset` to be str, Dataset, DatasetDict")
+        raised_exception = str(e).startswith(
+            "Expected `path_or_dataset` to be str, Dataset, DatasetDict"
+        )
 
     assert raised_exception == True, "Expected to raise a ValueError"
 
+
 def test_load_from_dict():
     data = {'text': "Below is an instruction that describes a task, paired with an input that "}
 
@@ -130,4 +133,4 @@
     assert hasattr(datamodule, 'val')
     assert datamodule.val is None
     assert hasattr(datamodule, 'test')
-    assert datamodule.test is None
\ No newline at end of file
+    assert datamodule.test is None
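
Reviewer note: taken together, the series adds an `HFDatasetDataModule.from_dict` entry point that wraps `datasets.Dataset.from_dict` and routes the result through the new `path_or_dataset` argument. Below is a minimal usage sketch mirroring `test_load_from_dict` above; the sample string and the `from nemo.collections import llm` import path are assumptions taken from the test file, not part of the series itself.

    # Usage sketch for the new from_dict entry point (mirrors test_load_from_dict;
    # the sample string is a placeholder).
    from nemo.collections import llm

    datamodule = llm.HFDatasetDataModule.from_dict(
        {"text": ["sample text"] * 101},  # columnar dict, as datasets.Dataset.from_dict expects
        split='train',  # a plain Dataset carries no split names, so one must be given explicitly
        global_batch_size=4,
        micro_batch_size=1,
    )

    # make_dataset_splits populates only the named split; the other canonical
    # splits (val, test) stay None.
    assert datamodule.train is not None and len(datamodule.train) == 101
    assert datamodule.val is None and datamodule.test is None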