diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index fa2f4b190e30..656df82fd3e7 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5062,6 +5062,36 @@ jobs: rm -rf /tmp/nemo2_ckpt rm -rf /tmp/nemo2_ptq_engine + L2_NeMo_2_Distill_Llama3_TP1PP2: + needs: [pre-flight, cicd-test-container-build] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.pre-flight.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python tests/collections/llm/gpt_distillation.py \ + --name nemo2_llama_distill \ + --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ + --student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ + --tp_size 1 \ + --cp_size 1 \ + --pp_size 2 \ + --devices 2 \ + --log_dir /tmp/distill_logs \ + --max_steps 5 \ + --gbs 4 \ + --mbs 1 \ + --data_paths 1.0 /home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ + --index_mapping_dir examples/nlp/language_modeling/gpt_index_mappings \ + --seq_length 2048 \ + --warmup_steps 1 \ + --val_check_interval 5 \ + --log_interval 5 \ + --limit_val_batches 2 + + AFTER_SCRIPT: | + rm -rf /tmp/distill_logs + L2_NeMo_2_Export_In_Framework: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml @@ -5321,6 +5351,7 @@ jobs: - L2_Megatron_GPT_Reranker - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact - L2_NeMo_2_PTQ_Llama2_FP8 + - L2_NeMo_2_Distill_Llama3_TP1PP2 - L2_NeMo_2_Export_In_Framework - L2_NeMo_2_jit_callback - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING diff --git a/nemo/collections/llm/distillation/__init__.py b/nemo/collections/llm/distillation/__init__.py new file mode 100644 index 000000000000..78eddd3b5ad7 --- /dev/null +++ b/nemo/collections/llm/distillation/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .loss import LogitsKLLoss +from .model import DistillationGPTModel + +__all__ = ["LogitsKLLoss", "DistillationGPTModel"] diff --git a/nemo/collections/llm/distillation/loss.py b/nemo/collections/llm/distillation/loss.py new file mode 100644 index 000000000000..f936f56759c1 --- /dev/null +++ b/nemo/collections/llm/distillation/loss.py @@ -0,0 +1,184 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABCMeta +from typing import TYPE_CHECKING, Tuple + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from torch import Tensor +from torch.nn.modules.loss import _Loss + +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig + + +class BaseLoss(_Loss, metaclass=ABCMeta): + """Abstract base class for Megatron distillation losses.""" + + def __init__(self, model_config: "TransformerConfig"): + """ + Constructor. + + Args: + model_config: MCore transformer config. + """ + super().__init__() + self._config = model_config + + def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + """Prepares inputs safely for loss computation.""" + if isinstance(predictions, tuple): + # `ColumnParallelLinear` returns bias too + predictions, targets = predictions[0], targets[0] + targets = targets.detach() + + return predictions, targets + + def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor: + """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking.""" + loss = loss.transpose(0, 1).contiguous() + return loss, tp_reduce + + +class LogitsKLLoss(BaseLoss): + """Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim.""" + + def __init__(self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False): + """ + Constructor. + + Args: + model_config: MCore transformer config. + temperature: Divide tensors by this value prior to calculating loss. + reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher) + """ + super().__init__(model_config) + self._temperature = temperature + self._reverse = reverse + + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """ + Forward function. + + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + # Division by temp should happen prior to finding max for both student and teacher. + # Currently we don't use temperature in any of ours runs (temp=1.0) + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + + # Compute local softmax, and the reweight to compute global softmax. + if self._config.tensor_model_parallel_size > 1: + + # Maximum value along vocab dimension across all GPUs. + teacher_logits_max, _ = torch.max(output_teacher, dim=-1) + torch.distributed.all_reduce( + teacher_logits_max, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_tensor_model_parallel_group(), + ) + output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) + + denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) + # We can't use standard reduction function here since the computation + # that follows it isn't identical across TP ranks. + denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group()) + + # Maximum value along vocab dimension across all GPUs. 
+ student_logits_max, _ = torch.max(output_student, dim=-1) + torch.distributed.all_reduce( + student_logits_max, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_tensor_model_parallel_group(), + ) + output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() + + denom_student = torch.sum(torch.exp(output_student), dim=-1) + denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group()) + + slen, bsz, sharded_vocab_size = output_student.shape + student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( + slen, bsz, sharded_vocab_size + ) + teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand( + slen, bsz, sharded_vocab_size + ) + + if self._reverse: + loss = torch.sum( + F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True), + dim=-1, + ) + else: + loss = torch.sum( + F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True), + dim=-1, + ) + + else: + if self._reverse: + loss = torch.sum( + F.kl_div( + F.log_softmax(output_teacher, dim=-1), + F.softmax(output_student, dim=-1), + reduction="none", + ), + dim=-1, + ) + else: + loss = torch.sum( + F.kl_div( + F.log_softmax(output_student, dim=-1), + F.softmax(output_teacher, dim=-1), + reduction="none", + ), + dim=-1, + ) + + return self.post_forward(loss, tp_reduce=True) + + +class _AllReduce(torch.autograd.Function): + """Implementation from old PyTorch `torch.distributed.nn.parallel`.""" + + @staticmethod + def forward(ctx, op, group, tensor): + # pylint: disable=C0116 + ctx.group, ctx.op = group, op + tensor = tensor.clone() + torch.distributed.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=C0116 + return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) + + +def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD): + """Custom all-reduce function. + + Needed instead of other all-reduce functions available when the computation following + the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. + """ + return _AllReduce.apply(op, group, tensor) diff --git a/nemo/collections/llm/distillation/model.py b/nemo/collections/llm/distillation/model.py new file mode 100644 index 000000000000..4bd1510a4b7c --- /dev/null +++ b/nemo/collections/llm/distillation/model.py @@ -0,0 +1,257 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple + +import modelopt.torch.distill as mtd +import torch +from megatron.core import parallel_state +from megatron.core.transformer.module import Float16Module as MCoreFloat16Module +from torch import Tensor, nn + +from nemo.collections import llm +from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction +from nemo.utils.model_utils import unwrap_model + +from .utils import ( + LoopingCachedDataIterator, + adjust_distillation_model_for_mcore, + load_distillation_config, + teacher_provider, +) + +if TYPE_CHECKING: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + from nemo.lightning.pytorch.optim import OptimizerModule + + +def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, Tensor]: + """Same as base GPT's data step but with ability to move attention mask to CPU.""" + batch = next(dataloader_iter) + + if isinstance(batch, tuple) and len(batch) == 3: + batch = batch[0] + + required_device_keys = set() + required_host_keys = set() + + if attn_mask_cpu: + # [ModelOpt]: We cache data for PP distillation, and save GPU mem by storing masks on CPU mem. + required_host_keys.add("attention_mask") + else: + required_device_keys.add("attention_mask") + + if 'cu_seqlens' in batch: + required_device_keys.add('cu_seqlens') + required_host_keys.add('cu_seqlens_argmin') + required_host_keys.add('max_seqlen') + + if parallel_state.is_pipeline_first_stage(): + required_device_keys.update(("tokens", "position_ids")) + if parallel_state.is_pipeline_last_stage(): + required_device_keys.update(("labels", "loss_mask")) + + batch_required_keys = {} + for key, val in batch.items(): + if key in required_device_keys: + batch_required_keys[key] = val.cuda(non_blocking=True) + elif key in required_host_keys: + batch_required_keys[key] = val.cpu() + else: + batch_required_keys[key] = None + + # slice batch along sequence dimension for context parallelism + output = get_batch_on_this_context_parallel_rank(batch_required_keys) + + return output + + +class _DistillationLossReduction(MaskedTokenLossReduction): + """Custom masking and reduction callable used only in training mode.""" + + def __init__(self, distillation_loss_fn, *args, **kwargs): + super().__init__(*args, **kwargs) + self._distillation_loss_fn = distillation_loss_fn + self._cp_size = parallel_state.get_context_parallel_world_size() + + def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: + if isinstance(forward_out, tuple): + # neva returns (logits, loss_mask) + forward_out, batch["loss_mask"] = forward_out + + # [ModelOpt]: KD loss calculation. 
+ loss_for_ub = self._distillation_loss_fn( + loss_reduction_fn=lambda x: self._masked_token_loss( + x, batch["loss_mask"], batch.get("num_valid_tokens_in_ub") + ) + ) + + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub * self._cp_size, {"avg": reduced_loss} + + def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens_in_ub: Optional[int] = None): + """The function takes as input per-token loss and masks non-required values.""" + if isinstance(loss_output, tuple): + # [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required) + loss_output, tp_reduce = loss_output + else: + tp_reduce = False + losses = loss_output.float() + loss_mask = mask.view(-1).float() + + if self._cp_size > 1: + if num_valid_tokens_in_ub is None: + num_valid_tokens_in_ub = loss_mask.sum() + if num_valid_tokens_in_ub < 0.5: # no valid tokens + num_valid_tokens_in_ub += 1.0 + loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + + if tp_reduce is True: + torch.distributed.all_reduce(loss, group=parallel_state.get_tensor_model_parallel_group()) + + return loss + + +class DistillationGPTModel(llm.GPTModel): + """Custom GPT subclass for distillation-related modifications.""" + + def __init__( + self, + student_config: llm.GPTConfig, + teacher_config: llm.GPTConfig, + teacher_ckpt_path: str, + optim: Optional["OptimizerModule"] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + """ + Constructor. + + This subclass of GPTModel takes the configs of a student and teacher model and overrides + the model construction step to create a ModelOpt `DistillationModel` as the underlying + MCore model. This model abstracts both student and teacher as a single module whose forward + pass runs both, and whose loss function automatically calculates a distillation loss on the + output logits. + + NOTE: This class saves checkpoints which will be re-loaded as the student's original class. + This allows one to continue using the model after distillation without this special class. + + Args: + student_config: Config of student model. + teacher_config: Config of teacher model. + teacher_ckpt_path: Path to teacher checkpoint (to restore weights). + optim: Optimizer. + tokenizer: Tokenizer. + model_transform: Transform to apply to model during setup. + """ + super().__init__(student_config, optim, tokenizer, model_transform) + self._teacher_config = teacher_config + self._teacher_ckpt_path = teacher_ckpt_path + self._train_called = False + + if self.config.virtual_pipeline_model_parallel_size is not None: + raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") + + def configure_model(self): + if hasattr(self, "module"): + return + + model = self.config.configure_model(self.tokenizer) + + # Ensure same for both models. + for attr in [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "sequence_parallel", + "pipeline_dtype", + ]: + setattr(self._teacher_config, attr, getattr(self.config, attr)) + + # [ModelOpt] Intialize DistillationModel. 
+        distill_cfg = load_distillation_config(self.config)
+        kd_config = {
+            "teacher_model": (
+                teacher_provider,
+                [self._teacher_config, self._teacher_ckpt_path],
+                {"tokenizer": self.tokenizer, "trainer": self.trainer},
+            ),
+            "criterion": distill_cfg["criterion"],
+            "loss_balancer": distill_cfg["loss_balancer"],
+        }
+        distillation_model = mtd.convert(model, mode=[("kd_loss", kd_config)])
+
+        # Additional MCore-specific tweaks needed.
+        adjust_distillation_model_for_mcore(distillation_model, model_cfg=self.config, distill_cfg=distill_cfg)
+
+        self.module = distillation_model
+
+    def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> Dict[str, Tensor]:
+        # NOTE: Ignores `self.config.data_step_fn`
+        if cache_num_batches:
+            batches = [
+                gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=True) for _ in range(cache_num_batches)
+            ]
+            return LoopingCachedDataIterator(batches)
+        elif isinstance(dataloader_iter, LoopingCachedDataIterator):
+            batch = next(dataloader_iter)
+            if "attention_mask" in batch:
+                batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True)  # move back to GPU
+            return batch
+        else:
+            return gpt_distillation_data_step(dataloader_iter)
+
+    def get_inference_wrapper(self, *args, **kwargs) -> Tensor:
+        raise NotImplementedError(
+            "Please restore a checkpoint of this model to its original class to call `get_inference_wrapper`"
+        )
+
+    @property
+    def training_loss_reduction(self) -> _DistillationLossReduction:
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = _DistillationLossReduction(
+                distillation_loss_fn=self.core_module.compute_kd_loss
+            )
+        return self._training_loss_reduction
+
+    def load_state_dict(self, state_dict, *args, **kwargs):
+        # pylint: disable=C0116
+        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+        # `super()` would go to `nn.Module` and skip the Context Manager in `mtd.DistillationModel.load_state_dict()`
+        return self.core_module.load_state_dict(state_dict, *args, **kwargs)
+
+    @property
+    def core_module(self):
+        # pylint: disable=C0116
+        return unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module))
+
+    def train(self, mode: bool = True):
+        # pylint: disable=C0116
+        self._train_called = True
+        return super().train(mode)
+
+    def __setattr__(self, name, value):
+        # HACK: PTL calls `module.training = True` after sanity check, bypassing `module.train()` which we depend on.
+        if name == "training":
+            if not self._train_called:
+                self.train(value)
+                return
+            self._train_called = False
+        return super().__setattr__(name, value)
diff --git a/nemo/collections/llm/distillation/utils.py b/nemo/collections/llm/distillation/utils.py
new file mode 100644
index 000000000000..3ed48865d19f
--- /dev/null
+++ b/nemo/collections/llm/distillation/utils.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from contextlib import contextmanager +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict + +import modelopt.torch.distill as mtd +import modelopt.torch.opt as mto +import torch +from megatron.core import parallel_state +from torch import Tensor + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.utils import logging + +from .loss import LogitsKLLoss + +if TYPE_CHECKING: + from megatron.core.dist_checkpointing.mapping import ShardedStateDict + from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel + from megatron.core.transformer.transformer_config import TransformerConfig + + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +def load_distillation_config(cfg: "TransformerConfig") -> Dict[str, Any]: + """Create a default distillation config for MCore GPT Models.""" + logit_pair = ("output_layer", "output_layer") # logit module names for MCoreGPTModel + distill_cfg = { + "criterion": {}, + "loss_balancer": _DummyLossBalancer(), # HACK: to appease ModelOpt until validation relaxed + "skip_lm_loss": True, + } + if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage(): + distill_cfg["criterion"][logit_pair] = LogitsKLLoss(cfg) + + return distill_cfg + + +class _DummyLossBalancer(mtd.DistillationLossBalancer): + def forward(self, loss_dict): + # pylint: disable=C0116 + return next(iter(loss_dict.values())) + + +def teacher_provider( + config: llm.GPTConfig, ckpt_path: str, tokenizer: "TokenizerSpec", trainer: nl.Trainer +) -> "MCoreGPTModel": + """Teacher model factory (must be a non-local function to pickle).""" + logging.info("Distillation: Loading teacher weights...") + + # TODO(aanoosheh): Replace spec with modelopt one + model = config.configure_model(tokenizer) + + sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.")} + checkpoint = trainer.strategy.checkpoint_io.load_checkpoint(ckpt_path, sharded_state_dict) + state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()} + model.load_state_dict(state_dict) + + torch.cuda.empty_cache() + logging.info("Distillation: ...teacher weights loaded.") + return model + + +class LoopingCachedDataIterator: + """Iterator which takes in a sequence and cycles through it when exhausted.""" + + def __init__(self, data): + self.data = data + self.it = iter(self.data) + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self.it) + except StopIteration: + self.it = iter(self.data) + return next(self.it) + + +def adjust_distillation_model_for_mcore( + model: mtd.DistillationModel, model_cfg: "TransformerConfig", distill_cfg: Dict[str, Any] +): + """Extra modifcations to ``mtd.DistillationModel`` requried for Megatron-Core.""" + + # Get rid of ModelOpt Distillation state + # NOTE: If re-placed, above losses need modifcation as `TransformerConfig` has non-pickleable elements. + mto.ModeloptStateManager(model)._state.pop() + + # Hide teacher during `sharded_state_dict` method. + def _sharded_state_dict(self, *args, **kwargs) -> "ShardedStateDict": + with self.hide_teacher_model(): + return self._sharded_state_dict(*args, **kwargs) + + model._sharded_state_dict = model.sharded_state_dict + model.sharded_state_dict = MethodType(_sharded_state_dict, model) + + # Skip `lm_loss` bypassing it when training if not needed for backprop. 
+ def _compute_language_model_loss(self, labels, logits) -> Tensor: + if self.training: + return torch.zeros_like(labels, dtype=logits.dtype) + return self._compute_language_model_loss(labels, logits) + + if distill_cfg["skip_lm_loss"]: + model._compute_language_model_loss = model.compute_language_model_loss + model.compute_language_model_loss = MethodType(_compute_language_model_loss, model) + + # Skip `lm_loss` always for teacher. + def _compute_language_model_loss(self, labels, logits) -> Tensor: + return torch.zeros_like(labels, dtype=logits.dtype) + + model.teacher_model.compute_language_model_loss = MethodType(_compute_language_model_loss, model.teacher_model) + + if model_cfg.pipeline_model_parallel_size > 1: + + def _set_input_tensor(self, input_tensor: Tensor): + obj = self.teacher_model if self._only_teacher_fwd else self + return type(self).set_input_tensor(obj, input_tensor) + + # Pipeline-parallel Distillation requires a way to cache input batches for subsequent + # forward calls, as well as a way to pass through output tensors to teacher model. + model.set_input_tensor = MethodType(_set_input_tensor, model) + + @contextmanager + def _swap_teacher_config(self, model_wrapper): + try: + if hasattr(model_wrapper, "config"): + model_wrapper._config = model_wrapper.config + model_wrapper.config = self.teacher_model.config + yield + finally: + del model_wrapper.config + if hasattr(model_wrapper, "_config"): + model_wrapper.config = model_wrapper._config + del model_wrapper._config + + # Pipeline-parallel forward function relies on the config in the model to know what + # hidden size of tensor to communicate to next stage. + model.swap_teacher_config = MethodType(_swap_teacher_config, model) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 1a28dc26b25c..4af9c6c1263b 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -40,7 +40,7 @@ # TODO: Clean this up with a getter and install instructions _grad_accum_fusion_available = True try: - import fused_weight_gradient_mlp_cuda + import fused_weight_gradient_mlp_cuda # noqa: F401 # pylint: disable=unused-import except ImportError: _grad_accum_fusion_available = False @@ -237,7 +237,8 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC # TP, CP group to the TE modules. # Deep iterate but skip self to avoid infinite recursion. if HAVE_TE and self.use_transformer_engine_full_layer_spec: - # Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py + # Copied from: + # https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py if parallel_state.get_tensor_model_parallel_world_size() > 1: for index, child in enumerate(model.modules()): if index == 0: @@ -414,7 +415,8 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres vocab_size = self.config.vocab_size else: raise ValueError( - 'Unable to find vocab size. Either pass in a tokenizer with vocab size, or set vocab size in the model config' + 'Unable to find vocab size.' 
+ ' Either pass in a tokenizer with vocab size, or set vocab size in the model config' ) inference_wrapper_config = InferenceWrapperConfig( @@ -460,8 +462,8 @@ def get_batch_on_this_context_parallel_rank(batch) -> Dict[str, torch.Tensor]: val.shape[seq_dim] // (2 * cp_size), *val.shape[(seq_dim + 1) :], ) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda( - non_blocking=True + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).to( + _val.device, non_blocking=True ) _val = _val.index_select(seq_dim, index) _val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :]) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 90499b69f216..7d0b5ba31f0c 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -187,7 +187,8 @@ def __init__( convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None, ) -> None: from megatron.core import parallel_state - from megatron.core.tensor_parallel import set_defaults_if_not_set_tensor_model_parallel_attributes + + from nemo.utils.model_utils import unwrap_model _pipeline: List[nn.Module] if isinstance(pipeline, nn.ModuleList): @@ -219,6 +220,20 @@ def __init__( self.ddp_config = ddp_config self.convert_module_fn = convert_module_fn + # [ModelOpt]: Detect Pipeline-parallel Distillation mode. + self._unwrapped_model = [unwrap_model(self)] + # Avoid re-registering module which breaks the inherited `ModuleList` somehow. + if ( + hasattr(self.unwrapped_model, "teacher_model") + and parallel_state.get_pipeline_model_parallel_world_size() > 1 + ): + self._kd_teacher_in_pp = True + assert ( + not self.ddp_config.overlap_grad_reduce + ), "Pipeline-parallel Distillation currently incomatible with `overlap_grad_reduce` DDP option." + else: + self._kd_teacher_in_pp = False + def forward( self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]], @@ -242,9 +257,12 @@ def forward( Args: data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): The input data for the model. forward_only (bool, optional): If True, only perform the forward pass. Defaults to True. - data_step (Optional[Callable[[Iterator[DataT]], DataT]], optional): Function to process the data. Defaults to None. - forward_step (Optional[Callable[[nn.Module, DataT], Tensor]], optional): Function to perform the forward pass. Defaults to None. - loss_reduction (Optional[MegatronLossReduction[DataT, Any]], optional): Function to reduce the loss. Defaults to None. + data_step (Optional[Callable[[Iterator[DataT]], DataT]], optional): Function to process the data. + Defaults to None. + forward_step (Optional[Callable[[nn.Module, DataT], Tensor]], optional): Function to perform the + forward pass. Defaults to None. + loss_reduction (Optional[MegatronLossReduction[DataT, Any]], optional): Function to reduce the + loss. Defaults to None. seq_length (Optional[int], optional): Sequence length for the model. Defaults to None. micro_batch_size (Optional[int], optional): Size of the micro batch. Defaults to None. num_microbatches (Optional[int], optional): Number of microbatches. Defaults to None. 
@@ -269,6 +287,37 @@ def forward( else: forward_step_func = _forward_step + if self._kd_teacher_in_pp: + assert wrap_forward_step + _teacher_forward_context = {} + if isinstance(_data_step, _ModuleStepFunction): + _data_step = _data_step(self.module) + + def _dummy_reduction(output_tensor, *args, **kwargs): + return output_tensor.new_tensor(-1), {} + + teacher_forward_step_func = self.wrapped_forward_step( + forward_step=_forward_step, + data_step=_data_step, + loss_reduction=_dummy_reduction, + context=_teacher_forward_context, + ) + teacher_step = MegatronStep.infer( + self, + None, # updated later below once we actually know `num_microbatches` + teacher_forward_step_func, + forward_only=True, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + seq_length=seq_length, + step_i=step_i, + ) + _teacher_forward_context["step"] = teacher_step + teacher_step = self.callbacks.transform_event("on_megatron_step_start", teacher_step) + + data = _data_step(data, cache_num_batches=teacher_step.num_microbatches) + teacher_step.data = data + step = MegatronStep.infer( self, data, @@ -283,7 +332,14 @@ def forward( step = self.callbacks.transform_event("on_megatron_step_start", step) self.callbacks.event("on_megatron_microbatches_start", step=step) - microbatch_outputs = step() + if self._kd_teacher_in_pp: + with self.unwrapped_model.only_teacher_forward(): + with self.unwrapped_model.swap_teacher_config(self.module): + teacher_step() + with self.unwrapped_model.only_student_forward(): + microbatch_outputs = step() + else: + microbatch_outputs = step() self.callbacks.event("on_megatron_microbatches_end", step=step, microbatch_outputs=microbatch_outputs) if microbatch_outputs: @@ -292,7 +348,7 @@ def forward( ) if isinstance(_loss_reduction, _ModuleStepFunction): - _loss_reduction = _loss_reduction(self[0]) + _loss_reduction = _loss_reduction(self.module) reduced = _loss_reduction.reduce(microbatch_outputs) self.callbacks.event( @@ -551,14 +607,15 @@ def init_model_parallel(self): msg = ( f" > number of parameters on (tensor, pipeline) model parallel rank " - f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): " + f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): " # pylint: disable=line-too-long f"{num_params}" ) logging.info(msg) if num_params != num_trainable_params: logging.info( - f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)" + " > number of trainable parameters: " + f"{num_trainable_params} ({num_trainable_params / num_params:.2%} of total)" ) if self.convert_module_fn: @@ -700,6 +757,10 @@ def pipeline(self) -> Union[ModelT, List[ModelT]]: def module(self) -> ModelT: return self[0] + @property + def unwrapped_model(self): + return self._unwrapped_model[0] + @override def __getattr__(self, item: Any) -> Any: try: @@ -846,8 +907,8 @@ class CallbackConnector: Each of these methods corresponds to a specific stage in the model's operation. You can define these methods in your callback functions to perform specific actions at these stages. - There is no need for the class to be a subclass of a specific parent class. As long as the class contains the methods outlined above, - it can be used as a callback. + There is no need for the class to be a subclass of a specific parent class. + As long as the class contains the methods outlined above, it can be used as a callback. 
""" def __init__(self, callbacks=None) -> None: @@ -1076,7 +1137,8 @@ class MegatronStep(Generic[ModelT, DataT]): micro_batch_size (Optional[int]): Size of each micro-batch. seq_length (Optional[int]): Sequence length for the current step. num_microbatches (Optional[int]): Number of micro-batches in this step. - decoder_seq_length (Optional[int]): Sequence length of decoder (used only in encoder-decoder style models) for the current step. + decoder_seq_length (Optional[int]): Sequence length of decoder (used only in + encoder-decoder style models) for the current step. Type Parameters: ModelT: The type of the model being used. @@ -1648,9 +1710,7 @@ def __init__(self, validation_step: bool = False, val_drop_last: bool = True) -> def forward( self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """Taken from: - https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 . - """ + """Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .""" # pylint: disable=line-too-long from megatron.core import parallel_state from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group @@ -1688,7 +1748,7 @@ def forward( return loss_for_ub * cp_size, {"avg": reduced_loss} def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: - """Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" + """Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" # pylint: disable=line-too-long if losses_reduced_per_micro_batch: if "avg" in losses_reduced_per_micro_batch[0]: loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch] diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 5d7d019c6099..974263eb7f68 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -24,13 +24,15 @@ from enum import Enum from functools import lru_cache from pathlib import Path -from typing import List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union import wrapt from nemo.utils import AppState, logging -from nemo.utils.data_utils import resolve_cache_dir # imported for compatibility: model_utils.resolve_cache_dir() -from nemo.utils.data_utils import is_datastore_path +from nemo.utils.data_utils import ( # imported for compatibility: model_utils.resolve_cache_dir() # noqa: F401 # pylint: disable=unused-import,line-too-long + is_datastore_path, + resolve_cache_dir, +) # TODO @blisc: Perhaps refactor instead of import guarding @@ -43,6 +45,12 @@ except ModuleNotFoundError: _HAS_HYDRA = False +if TYPE_CHECKING: + import lightning.pytorch as pl + + from nemo.core.classes import ModelPT, PretrainedModelInfo + from nemo.core.config.modelPT import NemoConfig + MODEL_CONFIG = "model_config.yaml" _VAL_TEST_FASTPATH_KEY = 'ds_item' @@ -92,7 +100,7 @@ def load_config(model_file: str) -> DictConfig: return model_config -def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]): +def unwrap_model(model, module_instances: Optional[Union[Type, Tuple[Type]]] = None): """Unwrap model from wrapper classes like Float16Module, for example.""" # TODO: Import this from megatron.core once moved there from megatron.training. 
@@ -102,8 +110,12 @@ def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]): return_list = False unwrapped_model = [] for model_module in model: - while isinstance(model_module, module_instances): - model_module = model_module.module + if module_instances: + while isinstance(model_module, module_instances): + model_module = model_module.module + else: # remove any wrappers that have a '.module' attribute + while hasattr(model_module, "module"): + model_module = model_module.module unwrapped_model.append(model_module) if not return_list: return unwrapped_model[0] @@ -342,7 +354,8 @@ def resolve_validation_dataloaders(model: 'ModelPT'): if len(ds_names) > 0: if len(ds_names) != len(ds_values): raise ValueError( - f"Number of names ({len(ds_names)}) does not match number of datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" + f"Number of names ({len(ds_names)}) does not match number of " + f"datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" ) model._validation_names = [parse_dataset_as_name(n) for n in ds_names] else: @@ -436,7 +449,8 @@ def resolve_test_dataloaders(model: 'ModelPT'): if len(ds_names) > 0: if len(ds_names) != len(ds_values): raise ValueError( - f"Number of names ({len(ds_names)}) does not match number of datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" + f"Number of names ({len(ds_names)}) does not match number of " + f"datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" ) model._test_names = [parse_dataset_as_name(n) for n in ds_names] else: @@ -648,7 +662,8 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op return True, msg else: msg = ( - f"Lib {lib_name} version ({lib_ver}) is not {operator.__name__} than required version {checked_version}.\n" + f"Lib {lib_name} version ({lib_ver}) is not {operator.__name__} " + f"than required version {checked_version}.\n" f"Please upgrade the lib using either pip or conda to the latest version." ) return False, msg @@ -692,7 +707,7 @@ def inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=False): if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1: filepath = f'{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}{fsdp_shard}/{basename}' else: - filepath = f'{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_{app_state.pipeline_model_parallel_rank:03d}/{basename}' + filepath = f'{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_{app_state.pipeline_model_parallel_rank:03d}/{basename}' # pylint: disable=line-too-long return filepath else: fsdp_shard = f'/fsdp_shard_{app_state.data_parallel_rank:05d}' if fsdp_sharded_ckpt else '' diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py new file mode 100644 index 000000000000..64271551b08c --- /dev/null +++ b/scripts/llm/gpt_distillation.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from argparse import ArgumentParser + +from lightning.pytorch.loggers import TensorBoardLogger +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm import distillation as distill +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler + +# Suppress lengthy HF warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def get_args(): + """Parse the command line arguments.""" + parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""") + + parser.add_argument("--name", type=str, required=True, help="""Experiment name""") + parser.add_argument("--teacher_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--student_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--tp_size", type=int, default=1, help="""Tensor parallel size""") + parser.add_argument("--cp_size", type=int, default=1, help="""Context parallel size""") + parser.add_argument("--pp_size", type=int, default=1, help="""Pipeline parallel size""") + parser.add_argument("--precision", type=str, default="bf16-mixed", help="""Datatype for models and optimizer""") + parser.add_argument("--devices", type=int, default=1, help="""Number of GPUs to use per node""") + parser.add_argument("--num_nodes", type=int, default=1, help="""Number of nodes to use""") + parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""") + parser.add_argument("--max_steps", type=int, required=True, help="""Number of global batches to process""") + parser.add_argument("--gbs", type=int, required=True, help="""Global Batch Size""") + parser.add_argument("--mbs", type=int, required=True, help="""Micro-batch Size""") + parser.add_argument("--data_paths", nargs="+", required=True, help="""List of tokenized data paths to load from""") + parser.add_argument("--split", type=str, default="99,1,0", help="""Train,Val,Test ratios to split data""") + parser.add_argument("--index_mapping_dir", type=str, default=None, help="""Folder to write cached data indices""") + parser.add_argument("--seq_length", type=int, required=True, help="""Number of tokens per input sample""") + parser.add_argument("--lr", type=float, default=3e-5, help="""Base LR for Cosine-Annealing scheduler""") + parser.add_argument("--min_lr", type=float, default=2e-7, help="""Minimum LR for Cosine-Annealing scheduler""") + parser.add_argument("--warmup_steps", type=int, default=50, help="""Number of scheduler warmup steps""") + parser.add_argument("--val_check_interval", type=int, default=100, help="""Validate + checkpoint every _ steps""") + parser.add_argument("--limit_val_batches", type=int, default=32, help="""Number of batches per validation stage""") + parser.add_argument("--log_interval", type=int, default=10, help="""Write to log every _ steps""") + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + ## Initialize the strategy and trainer + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + context_parallel_size=args.cp_size, + sequence_parallel=(args.tp_size > 1), + ) + trainer = nl.Trainer( + devices=args.devices, + num_nodes=args.num_nodes, + max_steps=args.max_steps, + 
log_every_n_steps=args.log_interval, + val_check_interval=args.val_check_interval, + limit_val_batches=args.limit_val_batches, + strategy=strategy, + accelerator="gpu", + plugins=nl.MegatronMixedPrecision(precision=args.precision), + ) + + ## Load both models and combine into an aggregate module + _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") + _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") + + tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) + assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." + + model = distill.DistillationGPTModel( + _student_model.config, + _teacher_model.config, + teacher_ckpt_path=args.teacher_path, + tokenizer=tokenizer, + ) + # TODO(aanoosheh): Replace spec with modelopt one + model.__io__ = _student_model.__io__ # HACK: model saves and restores as original class + + # Set up dataset + data = llm.PreTrainingDataModule( + paths=args.data_paths, + seq_length=args.seq_length, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + split=args.split, + index_mapping_dir=args.index_mapping_dir, + tokenizer=tokenizer, + ) + + ## Set up optimizer + opt_config = OptimizerConfig( + optimizer="adam", + lr=args.lr, + bf16=("bf16" in args.precision), + use_distributed_optimizer=True, + ) + sched = CosineAnnealingScheduler( + max_steps=args.max_steps, + warmup_steps=args.warmup_steps, + constant_steps=0, + min_lr=args.min_lr, + ) + opt = nl.MegatronOptimizerModule(opt_config, sched) + + # Set up checkpointing and logging + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + save_top_k=1, + every_n_train_steps=args.val_check_interval, + ) + logger = nl.NeMoLogger( + name=args.name, + log_dir=args.log_dir, + ckpt=checkpoint_callback, + tensorboard=TensorBoardLogger(os.path.join(args.log_dir, args.name)), + update_logger_directory=False, + ) + + # Set up resume and/or restore functionality + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + restore_config=nl.RestoreConfig(path=args.student_path), + ) + + # Run + llm.train( + model=model, + data=data, + optim=opt, + tokenizer="model", + trainer=trainer, + log=logger, + resume=resume, + ) diff --git a/tests/collections/llm/gpt_distillation.py b/tests/collections/llm/gpt_distillation.py new file mode 100644 index 000000000000..4f8e0c1cef37 --- /dev/null +++ b/tests/collections/llm/gpt_distillation.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from argparse import ArgumentParser + +from lightning.pytorch.loggers import TensorBoardLogger +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.common.tokenizers.huggingface import AutoTokenizer +from nemo.collections.llm import distillation as distill +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from tests.collections.llm.common import Llama3ConfigCI + +# Suppress lengthy HF warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def get_args(): + """Parse the command line arguments.""" + parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""") + + parser.add_argument("--name", type=str, required=True, help="""Experiment name""") + parser.add_argument("--teacher_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--student_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--tp_size", type=int, default=1, help="""Tensor parallel size""") + parser.add_argument("--cp_size", type=int, default=1, help="""Context parallel size""") + parser.add_argument("--pp_size", type=int, default=1, help="""Pipeline parallel size""") + parser.add_argument("--precision", type=str, default="bf16-mixed", help="""Datatype for models and optimizer""") + parser.add_argument("--devices", type=int, default=1, help="""Number of GPUs to use per node""") + parser.add_argument("--num_nodes", type=int, default=1, help="""Number of nodes to use""") + parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""") + parser.add_argument("--max_steps", type=int, required=True, help="""Number of global batches to process""") + parser.add_argument("--gbs", type=int, required=True, help="""Global Batch Size""") + parser.add_argument("--mbs", type=int, required=True, help="""Micro-batch Size""") + parser.add_argument("--data_paths", nargs="+", required=True, help="""List of tokenized data paths to load from""") + parser.add_argument("--split", type=str, default="99,1,0", help="""Train,Val,Test ratios to split data""") + parser.add_argument("--index_mapping_dir", type=str, default=None, help="""Folder to write cached data indices""") + parser.add_argument("--seq_length", type=int, required=True, help="""Number of tokens per input sample""") + parser.add_argument("--lr", type=float, default=3e-5, help="""Base LR for Cosine-Annealing scheduler""") + parser.add_argument("--min_lr", type=float, default=2e-7, help="""Minimum LR for Cosine-Annealing scheduler""") + parser.add_argument("--warmup_steps", type=int, default=50, help="""Number of scheduler warmup steps""") + parser.add_argument("--val_check_interval", type=int, default=100, help="""Validate + checkpoint every _ steps""") + parser.add_argument("--limit_val_batches", type=int, default=32, help="""Number of batches per validation stage""") + parser.add_argument("--log_interval", type=int, default=10, help="""Write to log every _ steps""") + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + ## Initialize the strategy and trainer + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + context_parallel_size=args.cp_size, + sequence_parallel=(args.tp_size > 1), + ) + trainer = nl.Trainer( + 
devices=args.devices, + num_nodes=args.num_nodes, + max_steps=args.max_steps, + log_every_n_steps=args.log_interval, + val_check_interval=args.val_check_interval, + limit_val_batches=args.limit_val_batches, + strategy=strategy, + accelerator="gpu", + plugins=nl.MegatronMixedPrecision(precision=args.precision), + ) + + ## Load both models and combine into an aggregate module + # NOTE: Special model and tokenizer for CI runs only + tokenizer = AutoTokenizer("gpt2") + _student_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) + _teacher_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) + + model = distill.DistillationGPTModel( + _student_model.config, + _teacher_model.config, + teacher_ckpt_path=args.teacher_path, + tokenizer=tokenizer, + ) + # TODO(aanoosheh): Replace spec with modelopt one + model.__io__ = _student_model.__io__ # HACK: model saves and restores as original class + + # Set up dataset + data = llm.PreTrainingDataModule( + paths=args.data_paths, + seq_length=args.seq_length, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + split=args.split, + index_mapping_dir=args.index_mapping_dir, + tokenizer=tokenizer, + ) + + ## Set up optimizer + opt_config = OptimizerConfig( + optimizer="adam", + lr=args.lr, + bf16=("bf16" in args.precision), + use_distributed_optimizer=True, + ) + sched = CosineAnnealingScheduler( + max_steps=args.max_steps, + warmup_steps=args.warmup_steps, + constant_steps=0, + min_lr=args.min_lr, + ) + opt = nl.MegatronOptimizerModule(opt_config, sched) + + # Set up checkpointing and logging + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + save_top_k=1, + every_n_train_steps=args.val_check_interval, + ) + logger = nl.NeMoLogger( + name=args.name, + log_dir=args.log_dir, + ckpt=checkpoint_callback, + tensorboard=TensorBoardLogger(os.path.join(args.log_dir, args.name)), + update_logger_directory=False, + ) + + # Set up resume and/or restore functionality + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + restore_config=nl.RestoreConfig(path=args.student_path), + ) + + # Run + llm.train( + model=model, + data=data, + optim=opt, + tokenizer="model", + trainer=trainer, + log=logger, + resume=resume, + )
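
Reviewer note (not part of the patch): under tensor parallelism each rank in `LogitsKLLoss` only holds a vocab shard of the logits, so the loss subtracts a globally all-reduced max, all-reduces only the softmax denominators via the autograd-aware `all_reduce_autograd`, computes a per-shard partial KL, and defers the final cross-rank summation to the `tp_reduce=True` flag handled later in `_DistillationLossReduction`. The single-process sketch below is an illustration only: it emulates the TP collectives by splitting the vocab dimension into shards and maxing/summing across them, and checks that this decomposition reproduces the full-vocab KL. The shapes and the emulated world size `tp=4` are arbitrary.

```python
# Minimal sketch (assumption: illustrative only, not shipped in this PR) of the
# sharded-softmax KL decomposition used by LogitsKLLoss under tensor parallelism.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, b, vocab, tp = 4, 2, 32, 4  # sequence, batch, vocab size, emulated TP world size

student = torch.randn(s, b, vocab)
teacher = torch.randn(s, b, vocab)

# Reference: full-vocab forward KL(teacher || student), summed over vocab -> [s, b]
ref = F.kl_div(
    F.log_softmax(student, dim=-1),
    F.log_softmax(teacher, dim=-1),
    reduction="none",
    log_target=True,
).sum(dim=-1)

# Emulated TP path: each "rank" holds one vocab shard of both logits tensors.
student_shards = student.chunk(tp, dim=-1)
teacher_shards = teacher.chunk(tp, dim=-1)

# all-reduce(MAX) of per-shard maxima -> global max per position
s_max = torch.stack([x.max(dim=-1).values for x in student_shards]).max(dim=0).values
t_max = torch.stack([x.max(dim=-1).values for x in teacher_shards]).max(dim=0).values

# all-reduce(SUM) of per-shard softmax denominators -> global denominators
s_denom = sum(torch.exp(x - s_max.unsqueeze(-1)).sum(dim=-1) for x in student_shards)
t_denom = sum(torch.exp(x - t_max.unsqueeze(-1)).sum(dim=-1) for x in teacher_shards)

# Per-shard partial KL; summing over shards plays the role of the extra
# tensor-parallel all-reduce requested via `tp_reduce=True`.
partial = []
for s_shard, t_shard in zip(student_shards, teacher_shards):
    s_logp = (s_shard - s_max.unsqueeze(-1)) - torch.log(s_denom).unsqueeze(-1)
    t_logp = (t_shard - t_max.unsqueeze(-1)) - torch.log(t_denom).unsqueeze(-1)
    partial.append(F.kl_div(s_logp, t_logp, reduction="none", log_target=True).sum(dim=-1))
sharded = torch.stack(partial).sum(dim=0)

torch.testing.assert_close(sharded, ref, rtol=1e-5, atol=1e-5)
print("sharded KL matches full-vocab KL, shape:", tuple(sharded.shape))  # (s, b)
```

The same decomposition is why the loss returns `tp_reduce=True` instead of calling a standard reduction inside the loss: each rank's partial KL differs, so the reduction must happen after the per-rank numerators are formed, which `_masked_token_loss` does with an explicit all-reduce over the tensor-model-parallel group.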