14 changes: 5 additions & 9 deletions examples/llm/peft/hf.py
@@ -18,7 +18,7 @@
from nemo.collections import llm


def mk_hf_dataset(tokenizer):
def make_squad_hf_dataset(tokenizer):
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN

def formatting_prompts_func(examples):
@@ -45,11 +45,9 @@ def formatting_prompts_func(examples):
'labels': tokens[1:] + [tokens[-1]],
}

from datasets import load_dataset

dataset = load_dataset("rajpurkar/squad", split="train")
dataset = dataset.map(formatting_prompts_func, batched=False, batch_size=2)
return dataset
datamodule = llm.HFDatasetDataModule("rajpurkar/squad", split="train", pad_token_id=tokenizer.eos_token_id)
datamodule.map(formatting_prompts_func, batched=False, batch_size=2)
return datamodule


if __name__ == '__main__':
@@ -80,9 +78,7 @@ def formatting_prompts_func(examples):

llm.api.finetune(
model=llm.HFAutoModelForCausalLM(args.model),
data=llm.HFDatasetDataModule(
mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id
),
data=make_squad_hf_dataset(tokenizer.tokenizer),
trainer=nl.Trainer(
devices=args.devices,
max_steps=args.max_steps,
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/data/api.py
@@ -41,8 +41,8 @@ def dolly() -> pl.LightningDataModule:

@run.cli.factory
@run.autoconvert
def hf_dataset(dataset: str) -> pl.LightningDataModule:
return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2)
def hf_dataset(path: str) -> pl.LightningDataModule:
return HFDatasetDataModule(path=path, global_batch_size=16, micro_batch_size=2)


__all__ = ["mock", "squad", "dolly", "hf_dataset"]
154 changes: 149 additions & 5 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -12,16 +12,108 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datasets.dataset_dict
import lightning.pytorch as pl
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging


def make_dataset_splits(path, split, split_aliases, kwargs):
"""
Loads a dataset with datasets.load_dataset and
returns a dictionary containing all dataset splits.

For example:

ans = make_dataset_splits("dataset-id")
$ ds = load_dataset("dataset-id")
$ print(ds)
> DatasetDict({
> train: Dataset({
> features: ['id', 'title', 'context', 'question', 'answers'],
> num_rows: 87599
> })
> validation: Dataset({
> features: ['id', 'title', 'context', 'question', 'answers'],
> num_rows: 10570
> })
> })

In this case, the returned value `ans` will be:
$ print(ans)
> {
> "train": Dataset .. (with 87599 rows),
> "val": Dataset .. (with 10570 rows),
> }
"""
dataset = load_dataset(path, split=split, **kwargs)

split_names = ['train', 'test', 'val']
dataset_splits = {split: None for split in split_names}

alias_to_split = {}
for split_name, _split_aliases in split_aliases.items():
assert split_name in split_names
for alias in _split_aliases:
alias_to_split[alias] = split_name

if isinstance(dataset, datasets.dataset_dict.DatasetDict):
dataset_split_names = dataset.keys()
logging.info(f"HF dataset has the following splits: {dataset_split_names}")
for alias_split_name, split in dataset.items():
split_name = alias_to_split[alias_split_name]
assert dataset_splits[split_name] is None
dataset_splits[split_name] = split
elif isinstance(split, list):
logging.info(f"Loaded HF dataset will use " + str(split) + " splits.")
assert isinstance(dataset, list)
for i, alias_split_name in enumerate(split):
split_name = alias_to_split[alias_split_name]
assert dataset_splits[split_name] is None
dataset_splits[split_name] = dataset[i]
elif isinstance(split, str):
logging.info(f"Loaded HF dataset has a single split.")
assert not isinstance(dataset, list)
alias_split_name = split
if '+' in alias_split_name:
raise ValueError("Split concatenation not supported")
elif '[' in alias_split_name:
alias_split_name = alias_split_name.split('[')[0]
split_name = alias_to_split[alias_split_name]
assert dataset_splits[split_name] is None
dataset_splits[split_name] = dataset
else:
raise ValueError("Expected split name to be None, str or a list")

assert (
sum(map(lambda x: x is not None, dataset_splits.values())) > 0
), "Expected at least one dataset to have been initialized"
return dataset_splits


class HFDatasetDataModule(pl.LightningDataModule):
"""HFDatasetDataModule wraps HF's load_dataset (datasets library)
so that it can be used within NeMo.
Users can select whether to use a Megatron-Core (mcore) sampler via the use_mcore_sampler arg.

Usage examples:

- loading a single split (train) from a dataset
llm.HFDatasetDataModule("rajpurkar/squad", split="train")

- loading multiple splits (train, validation) from a dataset
llm.HFDatasetDataModule("rajpurkar/squad", split=["train", "validation"])
"""

def __init__(
self,
dataset,
path,
collate_fn=None,
split=None,
num_workers=2,
pin_memory=True,
persistent_workers=True,
@@ -31,11 +123,29 @@ def __init__(
pad_token_id=0,
use_mcore_sampler=False,
mcore_dataloader_type='cyclic',
train_aliases=["train", "training"],
test_aliases=["test", "testing"],
val_aliases=["val", "validation", "valid", "eval"],
**kwargs,
) -> None:
super().__init__()
assert pad_token_id is not None

self.dataset = dataset
logging.info(f"Loading HF dataset from {path}")

# A dataset will usually have several splits (e.g. train, val, test).
# We map synonym split names to canonical names (train, test, val).
# A synonym can be a prefixed/suffixed variant of the word, e.g. train <-> training.
split_aliases = {'train': train_aliases, 'test': test_aliases, 'val': val_aliases}

# self.dataset_splits will hold the actual dataset for each split.
self.dataset_splits = make_dataset_splits(path, split, split_aliases, kwargs)

if collate_fn is None:
self._collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
else:
self._collate_fn = collate_fn

self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
@@ -84,17 +194,51 @@ def setup(self, stage: str):
dataloader_type=self.mcore_dataloader_type,
)

def train_dataloader(self, collate_fn=None):
from nemo.lightning.data import add_megatron_sampler
def _make_dataloader(self, dataset, collate_fn=None):
assert dataset is not None

if collate_fn is None:
collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)

return DataLoader(
self.dataset,
dataset,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=self.persistent_workers,
collate_fn=collate_fn,
batch_size=self.micro_batch_size,
)

@property
def train(self):
return self.dataset_splits['train']

@property
def val(self):
return self.dataset_splits['val']

@property
def test(self):
return self.dataset_splits['test']

def train_dataloader(self):
return self._make_dataloader(self.train, self._collate_fn)

def val_dataloader(self):
return self._make_dataloader(self.val, self._collate_fn)

def test_dataloader(self):
return self._make_dataloader(self.test, self._collate_fn)

def map(self, function=None, split_names=None, **kwargs):
if isinstance(split_names, str):
dataset_splits = {split_names: self.dataset_splits[split_names]}
elif isinstance(split_names, list):
dataset_splits = {k: self.dataset_splits[k] for k in split_names}
else:
dataset_splits = self.dataset_splits

for split_name, subset in dataset_splits.items():
if subset is None:
continue
self.dataset_splits[split_name] = subset.map(function, **kwargs)
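
Taken together, the new data module supports a workflow along the lines of the sketch below. This is an illustration based on the signatures shown in this diff, not code from the PR; the preprocessing function, batch sizes, and the choice of SQuAD are placeholder assumptions.

from nemo.collections import llm

# Load the train and validation splits of SQuAD; aliases such as "validation"
# are mapped to the canonical "val" split by make_dataset_splits.
datamodule = llm.HFDatasetDataModule(
    "rajpurkar/squad",
    split=["train", "validation"],
    micro_batch_size=2,
    global_batch_size=16,
)

# Hypothetical per-example preprocessing; the real example in
# examples/llm/peft/hf.py builds 'tokens'/'labels' pairs with a tokenizer.
def add_question_length(example):
    example["question_len"] = len(example["question"])
    return example

# map() applies the function to every split that was loaded
# (or only to the splits named via split_names).
datamodule.map(add_question_length, batched=False)

# Canonical splits are exposed as properties, and each dataloader is built
# with the module's collate_fn and micro_batch_size.
print(datamodule.train.num_rows, datamodule.val.num_rows)
train_loader = datamodule.train_dataloader()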
7 changes: 4 additions & 3 deletions nemo/collections/llm/inference/base.py
Original file line number Diff line number Diff line change
@@ -25,9 +25,6 @@
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
AbstractModelInferenceWrapper,
)
from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
EncoderDecoderTextGenerationController,
)
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
)
@@ -232,6 +229,10 @@ def generate(
Returns:
dict: A dictionary containing the generated results.
"""
from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
EncoderDecoderTextGenerationController,
)

if encoder_prompts is not None:
text_generation_controller = EncoderDecoderTextGenerationController(
inference_wrapped_model=model, tokenizer=tokenizer
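
The hunk above moves the EncoderDecoderTextGenerationController import from module scope into generate(), presumably so that importing nemo.collections.llm.inference.base does not require that Megatron-Core class up front. A runnable, generic sketch of the deferred-import pattern (the function and the 'statistics' module below are illustrative stand-ins, not NeMo or Megatron-Core APIs):

def summarize(values, detailed=False):
    if detailed:
        # Deferred import: the dependency is paid for (and must be available)
        # only on the code path that actually uses it.
        import statistics
        return {"mean": statistics.mean(values), "stdev": statistics.pstdev(values)}
    return {"count": len(values)}

print(summarize([1, 2, 3]))                 # no deferred import happens
print(summarize([1, 2, 3], detailed=True))  # 'statistics' is imported on demand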
3 changes: 2 additions & 1 deletion nemo/collections/llm/t5/model/t5.py
@@ -20,7 +20,7 @@
import torch
import torch.distributed
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper

from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.spec_utils import ModuleSpec
@@ -319,6 +319,7 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres
inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
padded_vocab_size=self.tokenizer.vocab_size,
)
from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper

model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config)
return model_inference_wrapper
114 changes: 114 additions & 0 deletions tests/collections/llm/gpt/data/test_hf_datamodule.py
@@ -0,0 +1,114 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections import llm

DATA_PATH = "/home/TestData/lite/hf_cache/squad/"


def test_load_single_split():
ds = llm.HFDatasetDataModule(
path=DATA_PATH,
split='train',
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
from datasets.arrow_dataset import Dataset

assert isinstance(ds.dataset_splits, dict)
assert len(ds.dataset_splits) == 3
assert 'train' in ds.dataset_splits
assert ds.dataset_splits['train'] is not None
assert ds.train is not None
assert isinstance(ds.dataset_splits['train'], Dataset)
assert 'val' in ds.dataset_splits
assert ds.dataset_splits['val'] is None
assert ds.val is None
assert 'test' in ds.dataset_splits
assert ds.dataset_splits['test'] is None
assert ds.test is None


def test_load_nonexistent_split():
exception_msg = ''
expected_msg = '''Unknown split "this_split_name_should_not_exist". Should be one of ['train', 'validation'].'''
try:
llm.HFDatasetDataModule(
path=DATA_PATH,
split='this_split_name_should_not_exist',
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
except ValueError as e:
exception_msg = str(e)
assert exception_msg == expected_msg, exception_msg


def test_load_multiple_split():
ds = llm.HFDatasetDataModule(
path=DATA_PATH,
split=['train', 'validation'],
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
from datasets.arrow_dataset import Dataset

assert isinstance(ds.dataset_splits, dict)
assert len(ds.dataset_splits) == 3
assert 'train' in ds.dataset_splits
assert ds.dataset_splits['train'] is not None
assert ds.train is not None
assert isinstance(ds.dataset_splits['train'], Dataset)
assert isinstance(ds.train, Dataset)
assert 'val' in ds.dataset_splits
assert ds.dataset_splits['val'] is not None
assert ds.val is not None
assert isinstance(ds.dataset_splits['val'], Dataset)
assert isinstance(ds.val, Dataset)
assert 'test' in ds.dataset_splits
assert ds.dataset_splits['test'] is None
assert ds.test is None


def test_validate_dataset_asset_accessibility_file_does_not_exist():
raised_exception = False
try:
llm.HFDatasetDataModule(
path="/this/path/should/not/exist/",
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
except FileNotFoundError:
raised_exception = True

assert raised_exception == True, "Expected to raise a FileNotFoundError"


def test_validate_dataset_asset_accessibility_file_is_none():
raised_exception = False
try:
llm.HFDatasetDataModule(
path=None,
seq_length=512,
micro_batch_size=2,
global_batch_size=2,
)
except TypeError:
raised_exception = True

assert raised_exception == True, "Expected to raise a TypeError"
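
The new tests assume the SQuAD data is already cached at the CI path in DATA_PATH; a hypothetical way to run just this module locally (after pointing DATA_PATH at an accessible copy of the dataset) is:

import pytest

# Equivalent to invoking pytest on the new test file from the repository root.
pytest.main(["-v", "tests/collections/llm/gpt/data/test_hf_datamodule.py"])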