diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 357dc5a7bd17..c24c5958b388 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -18,7 +18,7 @@ from nemo.collections import llm -def mk_hf_dataset(tokenizer): +def make_squad_hf_dataset(tokenizer): EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN def formatting_prompts_func(examples): @@ -45,11 +45,9 @@ def formatting_prompts_func(examples): 'labels': tokens[1:] + [tokens[-1]], } - from datasets import load_dataset - - dataset = load_dataset("rajpurkar/squad", split="train") - dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2) - return dataset + datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=tokenizer.eos_token_id) + datamodule.map(formatting_prompts_func, batched=False, batch_size=2) + return datamodule if __name__ == '__main__': @@ -80,9 +78,7 @@ def formatting_prompts_func(examples): llm.api.finetune( model=llm.HFAutoModelForCausalLM(args.model), - data=llm.HFDatasetDataModule( - mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id - ), + data=make_squad_hf_dataset(tokenizer.tokenizer), trainer=nl.Trainer( devices=args.devices, max_steps=args.max_steps, diff --git a/nemo/collections/llm/gpt/data/api.py b/nemo/collections/llm/gpt/data/api.py index 374bee83b8b2..b4e603186bf4 100644 --- a/nemo/collections/llm/gpt/data/api.py +++ b/nemo/collections/llm/gpt/data/api.py @@ -41,8 +41,8 @@ def dolly() -> pl.LightningDataModule: @run.cli.factory @run.autoconvert -def hf_dataset(dataset: str) -> pl.LightningDataModule: - return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2) +def hf_dataset(path: str) -> pl.LightningDataModule: + return HFDatasetDataModule(path=path, global_batch_size=16, micro_batch_size=2) __all__ = ["mock", "squad", "dolly", "hf_dataset"] diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 0f45ecf265b7..039e5b90b096 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -12,16 +12,108 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datasets.dataset_dict import lightning.pytorch as pl import torch +from datasets import load_dataset from torch.utils.data import DataLoader + from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + + +def make_dataset_splits(path, split, split_aliases, kwargs): + """ + Loads a dataset with datasets.load_dataset and + returns a dictionary containing all dataset splits. + + For example: + + ans = make_dataset_splits("dataset-id") + $ ds = load_dataset("dataset-id") + $ print(ds) + > DatasetDict({ + > train: Dataset({ + > features: ['id', 'title', 'context', 'question', 'answers'], + > num_rows: 87599 + > }) + > validation: Dataset({ + > features: ['id', 'title', 'context', 'question', 'answers'], + > num_rows: 10570 + > }) + > }) + + In this case the value of `ans` (returned value) will be: + $ print(ans) + > { + > "train": Dataset .. (with 87599 rows), + > "val": Dataset .. 
(with 10570 rows), + > } + """ + dataset = load_dataset(path, split=split, **kwargs) + + split_names = ['train', 'test', 'val'] + dataset_splits = {split: None for split in split_names} + + alias_to_split = {} + for split_name, _split_aliases in split_aliases.items(): + assert split_name in split_names + for alias in _split_aliases: + alias_to_split[alias] = split_name + + if isinstance(dataset, datasets.dataset_dict.DatasetDict): + dataset_split_names = dataset.keys() + logging.info(f"HF dataset has the following splits: {dataset_split_names}") + for alias_split_name, split in dataset.items(): + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = split + elif isinstance(split, list): + logging.info(f"Loaded HF dataset will use " + str(split) + " splits.") + assert isinstance(dataset, list) + for i, alias_split_name in enumerate(split): + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = dataset[i] + elif isinstance(split, str): + logging.info(f"Loaded HF dataset has a single split.") + assert not isinstance(dataset, list) + alias_split_name = split + if '+' in alias_split_name: + raise ValueError("Split concatenation not supported") + elif '[' in alias_split_name: + alias_split_name = alias_split_name.split('[')[0] + split_name = alias_to_split[alias_split_name] + assert dataset_splits[split_name] is None + dataset_splits[split_name] = dataset + else: + raise ValueError("Expected split name to be None, str or a list") + + assert ( + sum(map(lambda x: x is not None, dataset_splits.values())) > 0 + ), "Expected at least one dataset to have been initialized" + return dataset_splits class HFDatasetDataModule(pl.LightningDataModule): + """HFDatasetDataModule wraps HF's load_dataset (datasets library) + so that it can be used within NeMo. + Users can select whether to use an mcore-sampler via use_mcore_sampler arg. + + Usage examples: + + - loading a single split (train) from a dataset + llm.HFDatasetDataModule("rajpurkar/squad", split="train") + + - loading multiple splits (train, validation) from a dataset + llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"]) + """ + def __init__( self, - dataset, + path, + collate_fn=None, + split=None, num_workers=2, pin_memory=True, persistent_workers=True, @@ -31,11 +123,29 @@ def __init__( pad_token_id=0, use_mcore_sampler=False, mcore_dataloader_type='cyclic', + train_aliases=["train", "training"], + test_aliases=["test", "testing"], + val_aliases=["val", "validation", "valid", "eval"], + **kwargs, ) -> None: super().__init__() assert pad_token_id is not None - self.dataset = dataset + logging.info(f"Loading HF dataset from {path}") + + # A dataset usually will have several splits (e.g. train, val, test, etc). + # We map synonym names to canonical names (train, test, val). + # A synonym can be a prefix/suffixed word e.g. train <> training. + split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases} + + # self.dataset_splits will hold the actual dataset for each split. 
+        self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)
+
+        if collate_fn is None:
+            self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
+        else:
+            self._collate_fn = collate_fn
+
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.persistent_workers = persistent_workers
@@ -84,17 +194,52 @@ def setup(self, stage: str):
             dataloader_type=self.mcore_dataloader_type,
         )
 
-    def train_dataloader(self, collate_fn=None):
-        from nemo.lightning.data import add_megatron_sampler
+    def _make_dataloader(self, dataset, collate_fn=None):
+        assert dataset is not None
 
         if collate_fn is None:
             collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
 
         return DataLoader(
-            self.dataset,
+            dataset,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=self.persistent_workers,
             collate_fn=collate_fn,
             batch_size=self.micro_batch_size,
         )
+
+    @property
+    def train(self):
+        return self.dataset_splits['train']
+
+    @property
+    def val(self):
+        return self.dataset_splits['val']
+
+    @property
+    def test(self):
+        return self.dataset_splits['test']
+
+    def train_dataloader(self):
+        return self._make_dataloader(self.train, self._collate_fn)
+
+    def val_dataloader(self):
+        return self._make_dataloader(self.val, self._collate_fn)
+
+    def test_dataloader(self):
+        return self._make_dataloader(self.test, self._collate_fn)
+
+    def map(self, function=None, split_names=None, **kwargs):
+        if isinstance(split_names, str):
+            dataset_splits = {split_names: self.dataset_splits[split_names]}
+        elif isinstance(split_names, list):
+            dataset_splits = {k: self.dataset_splits[k] for k in split_names}
+        else:
+            dataset_splits = self.dataset_splits
+
+        for split_name, subset in dataset_splits.items():
+            if subset is None:
+                continue
+            # Write the mapped split back to self.dataset_splits so the result
+            # persists even when only a subset of split names was requested.
+            self.dataset_splits[split_name] = subset.map(function, **kwargs)
diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py
index 795d6efadd3a..6c89a1b42b15 100644
--- a/nemo/collections/llm/inference/base.py
+++ b/nemo/collections/llm/inference/base.py
@@ -25,9 +25,6 @@
 from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
     AbstractModelInferenceWrapper,
 )
-from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
-    EncoderDecoderTextGenerationController,
-)
 from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
     SimpleTextGenerationController,
 )
@@ -232,6 +229,10 @@ def generate(
     Returns:
         dict: A dictionary containing the generated results.
""" + from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import ( + EncoderDecoderTextGenerationController, + ) + if encoder_prompts is not None: text_generation_controller = EncoderDecoderTextGenerationController( inference_wrapped_model=model, tokenizer=tokenizer diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py index 940c0e51ee92..743d16f57c2b 100644 --- a/nemo/collections/llm/t5/model/t5.py +++ b/nemo/collections/llm/t5/model/t5.py @@ -20,7 +20,7 @@ import torch import torch.distributed from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig -from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper + from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.spec_utils import ModuleSpec @@ -319,6 +319,7 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, padded_vocab_size=self.tokenizer.vocab_size, ) + from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config) return model_inference_wrapper diff --git a/tests/collections/llm/gpt/data/test_hf_datamodule.py b/tests/collections/llm/gpt/data/test_hf_datamodule.py new file mode 100644 index 000000000000..a8d264701d39 --- /dev/null +++ b/tests/collections/llm/gpt/data/test_hf_datamodule.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections import llm + +DATA_PATH = "/home/TestData/lite/hf_cache/squad/" + + +def test_load_single_split(): + ds = llm.HFDatasetDataModule( + path=DATA_PATH, + split='train', + seq_length=512, + micro_batch_size=2, + global_batch_size=2, + ) + from datasets.arrow_dataset import Dataset + + assert isinstance(ds.dataset_splits, dict) + assert len(ds.dataset_splits) == 3 + assert 'train' in ds.dataset_splits + assert ds.dataset_splits['train'] is not None + assert ds.train is not None + assert isinstance(ds.dataset_splits['train'], Dataset) + assert 'val' in ds.dataset_splits + assert ds.dataset_splits['val'] is None + assert ds.val is None + assert 'test' in ds.dataset_splits + assert ds.dataset_splits['test'] is None + assert ds.test is None + + +def test_load_nonexistent_split(): + exception_msg = '' + expected_msg = '''Unknown split "this_split_name_should_not_exist". 
+Should be one of ['train', 'validation'].'''
+    try:
+        llm.HFDatasetDataModule(
+            path=DATA_PATH,
+            split='this_split_name_should_not_exist',
+            seq_length=512,
+            micro_batch_size=2,
+            global_batch_size=2,
+        )
+    except ValueError as e:
+        exception_msg = str(e)
+    assert exception_msg == expected_msg, exception_msg
+
+
+def test_load_multiple_split():
+    ds = llm.HFDatasetDataModule(
+        path=DATA_PATH,
+        split=['train', 'validation'],
+        seq_length=512,
+        micro_batch_size=2,
+        global_batch_size=2,
+    )
+    from datasets.arrow_dataset import Dataset
+
+    assert isinstance(ds.dataset_splits, dict)
+    assert len(ds.dataset_splits) == 3
+    assert 'train' in ds.dataset_splits
+    assert ds.dataset_splits['train'] is not None
+    assert ds.train is not None
+    assert isinstance(ds.dataset_splits['train'], Dataset)
+    assert isinstance(ds.train, Dataset)
+    assert 'val' in ds.dataset_splits
+    assert ds.dataset_splits['val'] is not None
+    assert ds.val is not None
+    assert isinstance(ds.dataset_splits['val'], Dataset)
+    assert isinstance(ds.val, Dataset)
+    assert 'test' in ds.dataset_splits
+    assert ds.dataset_splits['test'] is None
+    assert ds.test is None
+
+
+def test_validate_dataset_asset_accessibility_file_does_not_exist():
+    raised_exception = False
+    try:
+        llm.HFDatasetDataModule(
+            path="/this/path/should/not/exist/",
+            seq_length=512,
+            micro_batch_size=2,
+            global_batch_size=2,
+        )
+    except FileNotFoundError:
+        raised_exception = True
+
+    assert raised_exception, "Expected to raise a FileNotFoundError"
+
+
+def test_validate_dataset_asset_accessibility_file_is_none():
+    raised_exception = False
+    try:
+        llm.HFDatasetDataModule(
+            path=None,
+            seq_length=512,
+            micro_batch_size=2,
+            global_batch_size=2,
+        )
+    except TypeError:
+        raised_exception = True
+
+    assert raised_exception, "Expected to raise a TypeError"
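
A minimal usage sketch of the datamodule API this patch adds, mirroring the updated example script above. Assumptions are hedged: `to_features` is a hypothetical stand-in for a tokenization function such as `formatting_prompts_func`, and `pad_token_id=0` stands in for the tokenizer's eos_token_id used in the example.

    from nemo.collections import llm

    # Wrap a single HF split in the new datamodule; path can be a hub id or a local directory.
    datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=0)

    def to_features(example):
        # Hypothetical placeholder: a real version tokenizes the example and returns
        # 'input_ids', 'labels' and 'loss_mask' fields for the default collate_fn.
        return example

    # map() applies the function to every loaded split (here only 'train').
    datamodule.map(to_features, batched=False)

    # Splits are exposed as properties and through per-split dataloaders.
    assert datamodule.train is not None and datamodule.val is None
    train_loader = datamodule.train_dataloader()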