diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 101107dddc17..6b2470791a86 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -3700,6 +3700,17 @@ jobs:
           TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp
         AFTER_SCRIPT: |
           rm -rf nemo_experiments
+
+  L2_HF_Transformer_SFT_FSDP2_2gpu:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_FSDP2_2gpu') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft_fsdp2.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments

   L2_HF_Transformer_PT_2gpu:
     needs: [ cicd-test-container-setup ]
@@ -3722,6 +3733,17 @@ jobs:
           TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft_nemorun.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2 --strategy ddp
         AFTER_SCRIPT: |
           rm -rf nemo_experiments
+
+  L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2:
+    needs: [ cicd-test-container-setup ]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        TRANSFORMERS_OFFLINE=1 python tests/collections/llm/hf/sft_nemorun_fsdp2.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --max-steps 10 --devices 2
+      AFTER_SCRIPT: |
+        rm -rf nemo_experiments

   L2_HF_Transformer_PT_2gpu_nemorun:
     needs: [ cicd-test-container-setup ]
@@ -5047,6 +5069,8 @@ jobs:
       - L2_NeMo_2_PTQ_Llama2_FP8
       - L2_NeMo_2_jit_callback
       - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING
+      - L2_HF_Transformer_SFT_FSDP2_2gpu
+      - L2_HF_Transformer_SFT_2gpu_nemorun_fsdp2
     if: always()
     runs-on: ubuntu-latest
    steps:
diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
index cea7264543ff..abe966229ffe 100644
--- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
+++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py
@@ -20,6 +20,7 @@
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.llm import fn
 from nemo.lightning import io
+from nemo.lightning.pytorch.strategies.utils import fsdp2_strategy_parallelize
 from nemo.utils import logging
@@ -91,6 +92,10 @@ def configure_model(self):
             config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code
         )

+        # Apply FSDP2 and TP to the model
+        if self.device_mesh is not None:
+            fsdp2_strategy_parallelize(self.model, device_mesh=self.device_mesh)
+
         if self.model_accelerator is not None:
             self.model_accelerator(self.model)
@@ -99,7 +104,7 @@ def forward(self, batch):
         return self.model(**batch)

-    def training_step(self, batch):
+    def training_step(self, batch, batch_idx=None):
         labels = batch.pop('labels').to(self.model.device)
         loss_mask = batch.pop('loss_mask', None)
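`configure_model` only parallelizes when the strategy has populated `self.device_mesh`. A sketch of the mesh this hook expects: the dimension names match the keys indexed by `fsdp2_strategy_parallelize` later in this diff, and the 2x1 shape mirrors the 2-GPU CI jobs. Building the mesh manually like this is only illustrative; in practice Lightning's ModelParallelStrategy (which FSDP2Strategy subclasses) creates it.

```python
# Illustration only: run under torchrun with 2 ranks; not part of this diff.
from torch.distributed.device_mesh import init_device_mesh

device_mesh = init_device_mesh(
    "cuda",
    (2, 1),  # (data_parallel, tensor_parallel)
    mesh_dim_names=("data_parallel", "tensor_parallel"),
)
# configure_model() would then shard the HF model via:
# fsdp2_strategy_parallelize(self.model, device_mesh=device_mesh)
```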
diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py
index e01a2d5e5765..9ad6822243a9 100644
--- a/nemo/lightning/__init__.py
+++ b/nemo/lightning/__init__.py
@@ -31,7 +31,7 @@
 from nemo.lightning.pytorch.optim import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule, lr_scheduler
 from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision
 from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
-from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy
+from nemo.lightning.pytorch.strategies import FSDP2Strategy, FSDPStrategy, MegatronStrategy
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop
 from nemo.lightning.resume import AutoResume
@@ -60,6 +60,7 @@ def _is_slurm_interactive_mode():
     "MegatronMixedPrecision",
     "MegatronOptimizerModule",
     "FSDPStrategy",
+    "FSDP2Strategy",
     "RestoreConfig",
     "lr_scheduler",
     "NeMoLogger",
diff --git a/nemo/lightning/pytorch/strategies/__init__.py b/nemo/lightning/pytorch/strategies/__init__.py
index 9ef58bcc9023..b01ff14a10cc 100644
--- a/nemo/lightning/pytorch/strategies/__init__.py
+++ b/nemo/lightning/pytorch/strategies/__init__.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from nemo.lightning.pytorch.strategies.fsdp2_strategy import FSDP2Strategy
 from nemo.lightning.pytorch.strategies.fsdp_strategy import FSDPStrategy
 from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy

-
 __all__ = [
     "FSDPStrategy",
+    "FSDP2Strategy",
     "MegatronStrategy",
 ]
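With both exports in place, the strategy is reachable from the top-level `nemo.lightning` namespace. A minimal construction sketch, mirroring the arguments the test scripts in this PR pass (the values are illustrative for a 2-GPU run):

```python
from nemo import lightning as nl

trainer = nl.Trainer(
    devices=2,
    accelerator="gpu",
    strategy=nl.FSDP2Strategy(data_parallel_size=2, tensor_parallel_size=1),
)
```

The constructor also accepts `ckpt_load_optimizer` / `ckpt_save_optimizer` flags, documented in the new strategy module below.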
diff --git a/nemo/lightning/pytorch/strategies/fsdp2_strategy.py b/nemo/lightning/pytorch/strategies/fsdp2_strategy.py
new file mode 100644
index 000000000000..d59dca7be3aa
--- /dev/null
+++ b/nemo/lightning/pytorch/strategies/fsdp2_strategy.py
@@ -0,0 +1,276 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, Literal, Optional, Union
+
+import lightning.pytorch as pl
+import torch
+from lightning.fabric.plugins import CheckpointIO
+from lightning.fabric.strategies.fsdp import _get_sharded_state_dict_context
+from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy as PLModelParallelStrategy
+from lightning.pytorch.trainer.states import TrainerFn
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch.distributed.checkpoint.state_dict import (  # get_state_dict,
+    StateDictOptions,
+    get_optimizer_state_dict,
+    set_state_dict,
+)
+from torch.utils.data import DataLoader
+from typing_extensions import override
+
+from nemo.lightning import io
+from nemo.lightning.pytorch.strategies.utils import (
+    ckpt_to_dir,
+    create_checkpoint_io,
+    fix_progress_bar,
+    init_model_parallel,
+    mcore_to_pyt_sharded_state_dict,
+    pyt_to_mcore_state_dict,
+    setup_data_sampler,
+    setup_parallel_ranks,
+)
+
+
+class FSDP2Strategy(PLModelParallelStrategy, io.IOMixin):
+    """FSDP2 strategy plugin for PyTorch Lightning.
+
+    This strategy implements Fully Sharded Data Parallel 2 (FSDP2) using PyTorch's native
+    FSDP2 APIs. Compared with MegatronStrategy, FSDP2Strategy is designed to be more
+    lightweight: it applies minimal modifications on top of Lightning's ModelParallelStrategy
+    (which supports FSDP2 + TP parallelization) while preserving the features needed for
+    compatibility with NeMo and MCore. By default, this strategy wraps each TransformerLayer
+    with FSDP2.
+
+    Note:
+        This strategy is designed to work with NVIDIA's Megatron-LM framework and requires
+        model implementations that are compatible with Megatron's parallelism techniques.
+    Note:
+        Due to the different optimizer structure (FSDP2 only uses torch-native optimizers),
+        MegatronStrategy cannot resume training from checkpoints saved by FSDP2Strategy, and
+        vice versa. However, the model weight structure is kept compatible, so switching
+        strategies is possible if users only need the weights, not the optimizer states
+        (e.g., run pretraining with Megatron 4D parallelism and run SFT with FSDP2).
+    """
+
+    def __init__(
+        self,
+        data_parallel_size: Union[Literal["auto"], int] = "auto",
+        tensor_parallel_size: Union[Literal["auto"], int] = "auto",
+        ckpt_load_optimizer: bool = True,
+        ckpt_save_optimizer: bool = True,
+        data_sampler=None,
+        **kwargs,
+    ):
+        super().__init__(data_parallel_size=data_parallel_size, tensor_parallel_size=tensor_parallel_size, **kwargs)
+
+        self.data_sampler = data_sampler
+        self.ckpt_load_optimizer = ckpt_load_optimizer
+        self.ckpt_save_optimizer = ckpt_save_optimizer
+
+    @override
+    def setup_environment(self) -> None:
+        setup_parallel_ranks(self)
+        super().setup_environment()
+        init_model_parallel(self.model)
+
+    @override
+    def setup(self, trainer: pl.Trainer) -> None:
+        self.trainer = trainer
+        setup_data_sampler(self.trainer)
+        fix_progress_bar(trainer)
+        super().setup(trainer)
+
+    def _get_loss_reduction(self, step_type: str):
+        for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
+            if hasattr(self.lightning_module, fn_name):
+                return getattr(self.lightning_module, fn_name)
+        return None
+
+    def _step_proxy(self, step_type, batch, batch_idx=None):
+        method_name = f"{step_type}_step"
+        if self.model != self.lightning_module:
+            loss = self._forward_redirection(self.model, self.lightning_module, method_name, batch, batch_idx)
+        else:
+            loss = getattr(self.lightning_module, method_name)(batch, batch_idx)
+
+        _loss_reduction = self._get_loss_reduction(step_type)
+        if _loss_reduction:
+            return _loss_reduction.forward(batch, loss)
+        return loss, {'avg': loss}
+
+    @override
+    def training_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
+        assert self.lightning_module is not None
+        assert self.model is not None
+        with self.precision_plugin.train_step_context():
+            loss, reduced = self._step_proxy("training", batch, batch_idx)
+
+            self.lightning_module.log(
+                'global_step',
+                self.trainer.global_step,
+                prog_bar=True,
+                rank_zero_only=True,
+                batch_size=1,
+            )
+
+            self.lightning_module.log(
+                'step',
+                self.trainer.global_step,
+            )
+            self.lightning_module.log(
+                'reduced_train_loss', reduced['avg'], prog_bar=True, rank_zero_only=True, batch_size=1
+            )
+
+            # returns unreduced loss for backward
+            return loss
+
+    @override
+    def validation_step(self, batch, batch_idx=None) -> Any:
+        assert self.lightning_module is not None
+        assert self.model is not None
+        with self.precision_plugin.val_step_context():
+            loss, reduced = self._step_proxy("validation", batch, batch_idx)
+            self.lightning_module.log('val_loss', reduced['avg'], rank_zero_only=True, batch_size=1)
+            return loss
+
+    @override
+    def test_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
+        assert self.lightning_module is not None
+        assert self.model is not None
+        with self.precision_plugin.test_step_context():
+            loss, reduced = self._step_proxy("test", batch, batch_idx)
+            self.lightning_module.log('test_loss', reduced['avg'], rank_zero_only=True, batch_size=1)
+
+            return loss
+
+    @override
+    def predict_step(self, batch, batch_idx=None) -> STEP_OUTPUT:
+        assert self.lightning_module is not None
+        assert self.model is not None
+        with self.precision_plugin.predict_step_context():
+            loss, reduced = self._step_proxy("predict", batch, batch_idx)
+            return reduced
+
+    @override
+    def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
+        if self.data_sampler:
+            return self.data_sampler.transform_dataloader(dataloader)
+
+        return dataloader
+
+    @property
+    @override
+    def checkpoint_io(self) -> CheckpointIO:
+        if not self._checkpoint_io:
+            self._checkpoint_io = create_checkpoint_io()
+
+        return self._checkpoint_io
+
+    @checkpoint_io.setter
+    def checkpoint_io(self, io: CheckpointIO) -> None:
+        self._checkpoint_io = io
+
+    @property
+    def current_epoch_step(self) -> int:
+        """Get the value of the step within an epoch."""
+        return max(
+            self.trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed,
+            self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed,
+        )
+
+    @override
+    def remove_checkpoint(self, filepath: Union[str, Path]) -> None:
+        # Taken from MegatronStrategy
+        ckpt = ckpt_to_dir(filepath)
+        if self.is_global_zero:
+            if os.path.islink(ckpt):
+                os.unlink(ckpt)
+            else:
+                shutil.rmtree(ckpt)
+
+    @override
+    def save_checkpoint(
+        self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None
+    ) -> None:
+        """Converts PyT checkpoints to MCore format and saves them using the MCore dist-ckpt library."""
+        checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(
+            checkpoint.pop("state_dict"), device_mesh=self.device_mesh
+        )
+        checkpoint["state_dict"] = OrderedDict([])
+
+        if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING:
+            # Clear the optimizer states. This handles the case where ckpt_save_optimizer=False.
+            # Ideally, the optimizer state dicts should not be generated in this case.
+            checkpoint["optimizer_states"] = {}
+
+            ## replace unsharded optimizer_states with sharded dict.
+            ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called,
+            ## the checkpoint will contain only model weights. Optimizer states will be omitted.
+            if self.ckpt_save_optimizer:
+                checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers)
+                pyt_to_mcore_state_dict(
+                    checkpoint['optimizer']['state'], prefix="optimizer.state.", device_mesh=self.device_mesh
+                )
+
+        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
+
+    @override
+    def load_checkpoint(self, checkpoint_path: str | Path) -> Dict[str, Any]:
+        """PTL method which we override to integrate distributed checkpoints for FSDP models.
+        Unlike MegatronStrategy, both model and optimizer states are restored within
+        this method.
+
+        The logic here is slightly more complicated:
+        1. Obtain PyT state dicts (sharded & unflattened) for model and optim -> torch::ShardedTensor
+        2. Convert to MCore state dicts -> mcore::ShardedTensor
+        3. Load from checkpoint using MCore dist ckpt API -> torch::Tensor
+        4. Convert to PyT state dicts (sharded & unflattened) -> torch::ShardedTensor
+        5. Load into model and optim using PyT dist ckpt API
+        6. Return the loaded checkpoint for lightning to load other metadata
+        """
+        path = Path(self.broadcast(checkpoint_path))
+        torch.cuda.empty_cache()
+
+        # TODO: the elegant way to load both state dicts. Need pytorch 2.3.1
+        # msd, osd = get_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))
+        sharded_state_dict = {}
+        with _get_sharded_state_dict_context(self.model):
+            msd = self.model.state_dict()
+            pyt_to_mcore_state_dict(msd, device_mesh=self.device_mesh)
+            sharded_state_dict["sharded_state_dict"] = msd
+
+        if self.ckpt_load_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
+            osd = get_optimizer_state_dict(self.model, self.optimizers, options=StateDictOptions(cpu_offload=True))
+            pyt_to_mcore_state_dict(osd['state'], prefix="optimizer.state.", device_mesh=self.device_mesh)
+            sharded_state_dict["optimizer"] = osd
+
+        checkpoint = self.checkpoint_io.load_checkpoint(path, sharded_state_dict=sharded_state_dict)
+        mcore_to_pyt_sharded_state_dict(checkpoint['sharded_state_dict'], msd)
+
+        if self.ckpt_load_optimizer and self.trainer.state.fn == TrainerFn.FITTING:
+            mcore_to_pyt_sharded_state_dict(checkpoint['optimizer']['state'], osd['state'])
+
+        set_state_dict(
+            self.model,
+            self.optimizers if self.ckpt_load_optimizer else [],
+            model_state_dict=checkpoint['sharded_state_dict'],
+            optim_state_dict=checkpoint['optimizer'] if self.ckpt_load_optimizer else None,
+        )
+
+        return checkpoint
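One detail of the file above worth calling out for reviewers: `_get_loss_reduction` and `_step_proxy` implement an optional hook. If the attached LightningModule exposes a `training_loss_reduction` attribute (or a generic `loss_reduction`), its `forward(batch, loss)` is invoked and must return a `(loss, {'avg': ...})` pair; otherwise the raw loss is used directly. A hedged sketch of a conforming object; the class itself is hypothetical, only the attribute name and return shape come from the strategy code:

```python
import torch


class MeanLossReduction:
    """Minimal object satisfying the contract _step_proxy expects."""

    def forward(self, batch: dict, loss: torch.Tensor):
        # The first element flows to backward; reduced['avg'] is what gets logged.
        return loss, {"avg": loss.detach().mean()}


# Attach it so _get_loss_reduction("training") finds it:
# my_lightning_module.training_loss_reduction = MeanLossReduction()
```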
diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py
index 8b3daab30b19..d38753bd7935 100644
--- a/nemo/lightning/pytorch/strategies/megatron_strategy.py
+++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py
@@ -116,7 +116,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
             across GPU ranks. Defaults to 1.
         virtual_pipeline_model_parallel_size (Optional[int]): Interleaved pipeline parallelism used to improve
             performance by reducing the pipeline bubble. Defaults to None.
-        microbatch_group_size_per_vp_stage(Optional[int]): the number of micro-batches that are executed
+        microbatch_group_size_per_vp_stage (Optional[int]): the number of micro-batches that are executed
             at a time for a given virtual stage (both forward and backward). Defaults to None, in which
             case it is set to pipeline_parallel_size, which specifies a depth-first schedule.
         context_parallel_size (int): Splits network input along sequence dimension across GPU ranks.
diff --git a/nemo/lightning/pytorch/strategies/utils.py b/nemo/lightning/pytorch/strategies/utils.py
index 4f5a78419d6d..51e4a7dbfa19 100644
--- a/nemo/lightning/pytorch/strategies/utils.py
+++ b/nemo/lightning/pytorch/strategies/utils.py
@@ -25,6 +25,8 @@
 from megatron.core.dist_checkpointing.mapping import ShardedBase, ShardedObject, ShardedTensor
 from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor
 from megatron.core.transformer.utils import _get_extra_state_offsets
+from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+from torch.distributed._composable.fsdp.fully_shard import fully_shard
 from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
 from torch.distributed._tensor import DTensor, Replicate, Shard
 from torch.distributed.device_mesh import DeviceMesh
@@ -328,3 +330,41 @@ def _convert(state_dict, k, sh_key, v, prepend_offsets, prefix="", allow_shape_m
             _convert(state_dict, k, sh_key, v, prepend_offsets, prefix, allow_shape_mismatch, device_mesh)

     return state_dict
+
+
+# Taken and modified from torchtitan
+# https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py
+def fsdp2_strategy_parallelize(model, device_mesh: DeviceMesh = None):
+    """Apply parallelisms and activation checkpointing to the model.
+
+    NOTE: The passed-in model should preferably be on the meta device. Otherwise,
+    the model must fit in GPU or CPU memory.
+    """
+    dp_mesh = device_mesh["data_parallel"]
+    tp_mesh = device_mesh["tensor_parallel"]
+
+    assert tp_mesh.size() == 1, "Tensor parallelism is not supported yet."
+
+    if dp_mesh.size() > 1:
+        assert dp_mesh.ndim == 1  # Hybrid-sharding not supported
+
+    # NOTE: Currently, the user is required to manually handle precision settings such as the `mp_policy` here
+    # because the model parallel strategy does not respect all settings of `Fabric(precision=...)` at the moment.
+    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
+
+    fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+    for layer_id, transformer_block in enumerate(model.model.layers):
+        # Apply activation checkpointing
+        # transformer_block = checkpoint_wrapper(transformer_block)
+        # As an optimization, do not reshard after forward for the last
+        # transformer block since FSDP would prefetch it immediately
+        reshard_after_forward = int(layer_id) < len(model.model.layers) - 1
+        fully_shard(
+            transformer_block,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+        model.model.layers[layer_id] = transformer_block
+    model = fully_shard(model, **fsdp_config)
+
+    return model
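`fsdp2_strategy_parallelize` leaves activation checkpointing as a commented-out step in the per-layer loop. If it were enabled, a plausible form is PyTorch's wrapper shown below; the import and placement are my assumptions, nothing in this diff actually wires it up:

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

# Inside the per-layer loop above, before fully_shard(transformer_block, ...):
transformer_block = checkpoint_wrapper(transformer_block)
```

The `reshard_after_forward` toggle in the loop is a memory/communication trade-off taken from torchtitan: the last block keeps its unsharded parameters because FSDP would prefetch them again immediately for the backward pass.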
diff --git a/tests/collections/llm/hf/sft_fsdp2.py b/tests/collections/llm/hf/sft_fsdp2.py
new file mode 100755
index 000000000000..300b4a08c596
--- /dev/null
+++ b/tests/collections/llm/hf/sft_fsdp2.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fiddle as fdl
+import torch
+from lightning.pytorch.loggers import WandbLogger
+from packaging.version import Version as PkgVersion
+from utils import get_torch_version_str
+
+from nemo import lightning as nl
+from nemo.collections import llm
+from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
+
+DATA_PATH = '/home/TestData/lite/hf_cache/squad/'
+
+
+def make_squad_hf_dataset(data_path, tokenizer):
+    EOS_TOKEN = tokenizer.eos_token  # Must add EOS_TOKEN
+
+    def formatting_prompts_func(examples):
+        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+{}
+
+### Input:
+{}
+
+### Response:
+{}"""
+        instruction = examples["context"]
+        input = examples["question"]
+        output = examples["answers"]['text']
+        if isinstance(output, list):
+            output = output[0]
+        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
+        ans = tokenizer(text)
+        ans['labels'] = ans['input_ids']
+        return ans
+
+    tokenizer = getattr(tokenizer, 'tokenizer', tokenizer)
+    datamodule = llm.HFDatasetDataModule(data_path, split="train[:100]", pad_token_id=tokenizer.eos_token_id)
+
+    datamodule.map(
+        formatting_prompts_func,
+        batched=False,
+        batch_size=2,
+        remove_columns=["id", "title", "context", "question", 'answers'],
+    )
+
+    return datamodule
+
+
+if __name__ == '__main__':
+    if PkgVersion(get_torch_version_str()) >= PkgVersion("2.4"):
+        import argparse
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--model', default='meta-llama/Llama-3.2-1B')
+        parser.add_argument('--devices', type=int, default=2)
+        parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
+        parser.add_argument('--model-accelerator', default=None, choices=['te'])
+        parser.add_argument('--max-steps', type=int, default=5)
+        parser.add_argument("--fp8-autocast", default=False, action='store_true')
+        parser.add_argument('--wandb-project', type=str, default=None)
+        parser.add_argument('--model-save-path', type=str, default=None)
+        args = parser.parse_args()
+
+        wandb = None
+        if args.wandb_project is not None:
+            model = '_'.join(args.model.split('/')[-2:])
+            wandb = WandbLogger(
+                project=args.wandb_project,
+                name=f'{model}_dev{args.devices}_strat_fsdp2',
+            )
+        grad_clip = None
+        use_dist_samp = False
+
+        model_accelerator = None
+        if args.model_accelerator == "te":
+            from functools import partial
+
+            from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate
+
+            model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)
+
+        model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
+        tokenizer = model.tokenizer
+
+        llm.api.finetune(
+            model=model,
+            data=make_squad_hf_dataset(DATA_PATH, tokenizer),
+            trainer=nl.Trainer(
+                devices=args.devices,
+                max_steps=args.max_steps,
+                accelerator=args.accelerator,
+                strategy=nl.FSDP2Strategy(data_parallel_size=2, tensor_parallel_size=1),
+                log_every_n_steps=1,
+                limit_val_batches=0.0,
+                num_sanity_val_steps=0,
+                accumulate_grad_batches=10,
+                gradient_clip_val=grad_clip,
+                use_distributed_sampler=use_dist_samp,
+                callbacks=[],
+                logger=wandb,
+            ),
+            optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
+            log=None,
+        )
+
+        # Check memory usage compared to the non-parallelized version
+        assert (
+            torch.cuda.max_memory_allocated(device=None) / 1024 / 1024 < 29326
+        ), f"using {torch.cuda.max_memory_allocated(device=None) / 1024 / 1024} MB, larger than 29326 MB when not using parallelization."
+
+        if args.model_accelerator:
+            if args.model_accelerator == "te":
+                te_acc = is_te_accelerated(model.model)
+                assert te_acc, "Transformer Engine acceleration was unsuccessful"
+                print("TE Accelerated: ", te_acc)
+
+        if args.model_save_path is not None:
+            model.save_pretrained(args.model_save_path)
diff --git a/tests/collections/llm/hf/sft_nemorun_fsdp2.py b/tests/collections/llm/hf/sft_nemorun_fsdp2.py
new file mode 100644
index 000000000000..53dd863cb185
--- /dev/null
+++ b/tests/collections/llm/hf/sft_nemorun_fsdp2.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nemo_run as run
+from packaging.version import Version as PkgVersion
+from utils import get_torch_version_str
+
+import nemo.lightning as nl
+from nemo.collections import llm
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.llm.gpt.data.hf_dataset import SquadHFDataModule
+
+DATA_PATH = '/lustre/fsw/coreai_dlalgo_llm/boxiangw/squad'
+
+
+def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:
+    # Env vars for jobs are configured here
+    env_vars = {
+        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
+        "NCCL_NVLS_ENABLE": "0",
+        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
+        "NVTE_ASYNC_AMAX_REDUCTION": "1",
+        "NVTE_FUSED_ATTN": "0",
+    }
+
+    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)
+
+    return executor
+
+
+if __name__ == '__main__':
+    if PkgVersion(get_torch_version_str()) >= PkgVersion("2.4"):
+        import argparse
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--model', default='meta-llama/Meta-Llama-3-8B-Instruct')
+        parser.add_argument('--devices', type=int, default=2)
+        parser.add_argument('--accelerator', default='gpu', choices=['gpu'])
+        parser.add_argument('--max-steps', type=int, default=100)
+        args = parser.parse_args()
+
+        recipe = llm.hf_auto_model_for_causal_lm.finetune_recipe(
+            model_name=args.model,
+            name="sft",
+            num_nodes=1,
+            num_gpus_per_node=args.devices,
+            peft_scheme='none',
+            max_steps=args.max_steps,
+        )
+        recipe.trainer.val_check_interval = 50
+
+        tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model)
+        recipe.data = run.Config(
+            SquadHFDataModule,
+            path_or_dataset=DATA_PATH,
+            split="train[:100]",
+            pad_token_id=tokenizer.tokenizer.eos_token_id,
+            tokenizer=run.Config(AutoTokenizer, pretrained_model_name=args.model),
+        )
+
+        recipe.trainer.strategy = run.Config(nl.FSDP2Strategy, data_parallel_size=2, tensor_parallel_size=1)
+        recipe.trainer.plugins = None
+        executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)
+        run.run(recipe, executor=executor)
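Note that `DATA_PATH` above is a cluster-specific Lustre path. Since `recipe.data` is a `run.Config`, the dataset location can be repointed before launch; the override below is a hypothetical example for local runs, not something this diff does:

```python
# Hypothetical override before run.run(recipe, executor=executor):
recipe.data.path_or_dataset = "/path/to/a/local/squad/copy"
```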
diff --git a/tests/collections/llm/hf/utils.py b/tests/collections/llm/hf/utils.py
new file mode 100644
index 000000000000..2f1730a5fa32
--- /dev/null
+++ b/tests/collections/llm/hf/utils.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from importlib.metadata import version
+
+
+def get_torch_version_str():
+    import torch
+
+    if hasattr(torch, '__version__'):
+        return str(torch.__version__)
+    else:
+        return version("torch")
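Both test scripts import this helper and wrap their main blocks in the same version gate, so on older torch builds they exit quietly instead of failing on FSDP2-only APIs. A standalone sketch of the gate; the 2.4 floor presumably tracks the per-parameter `fully_shard` path used by the strategy, though the diff itself does not state the reason:

```python
from packaging.version import Version as PkgVersion

from utils import get_torch_version_str  # the helper defined above

if PkgVersion(get_torch_version_str()) >= PkgVersion("2.4"):
    print("torch >= 2.4: FSDP2 test paths enabled")
```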
diff --git a/tests/lightning/pytorch/strategies/test_fsdp2_strategy.py b/tests/lightning/pytorch/strategies/test_fsdp2_strategy.py
new file mode 100644
index 000000000000..5432e0df0420
--- /dev/null
+++ b/tests/lightning/pytorch/strategies/test_fsdp2_strategy.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from importlib.metadata import version
+from unittest.mock import patch
+
+from packaging.version import Version as PkgVersion
+
+from nemo.lightning.pytorch.strategies import FSDP2Strategy
+
+
+def get_torch_version_str():
+    import torch
+
+    if hasattr(torch, '__version__'):
+        return str(torch.__version__)
+    else:
+        return version("torch")
+
+
+if PkgVersion(get_torch_version_str()) >= PkgVersion("2.4"):
+
+    class TestFSDP2Strategy:
+        @patch('nemo.lightning.pytorch.strategies.fsdp2_strategy.create_checkpoint_io')
+        def test_checkpoint_io(self, mock_create_checkpoint_io):
+            class Dummy: ...
+
+            mock_create_checkpoint_io.side_effect = lambda *args, **kwargs: Dummy()
+            strategy = FSDP2Strategy()
+
+            first_io = strategy.checkpoint_io
+            mock_create_checkpoint_io.assert_called_once()
+
+            assert first_io == strategy.checkpoint_io
+
+            new_io = object()
+            strategy.checkpoint_io = new_io
+            assert new_io == strategy.checkpoint_io
+
+            strategy2 = FSDP2Strategy()
+            second_io = strategy2.checkpoint_io
+            mock_create_checkpoint_io.assert_called()
+
+            assert first_io != second_io
+            assert second_io == strategy2.checkpoint_io