From 4b426db5c50d181a9cd909d74526f3c9579eb294 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 25 Dec 2024 08:33:02 -0800 Subject: [PATCH 01/22] First draft of distill script port to 2.0 Signed-off-by: Asha Anoosheh --- scripts/llm/gpt_distillation.py | 491 ++++++++++++++++++++++++++++++++ 1 file changed, 491 insertions(+) create mode 100644 scripts/llm/gpt_distillation.py diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py new file mode 100644 index 000000000000..98696a815b2d --- /dev/null +++ b/scripts/llm/gpt_distillation.py @@ -0,0 +1,491 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types +from abc import ABCMeta +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import modelopt.torch.distill as mtd +import modelopt.torch.opt as mto +import torch +import torch.nn.functional as F +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel +from megatron.core.optimizer import OptimizerConfig +from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_world_size, + get_tensor_model_parallel_group, +) +from megatron.core.transformer.transformer_config import TransformerConfig +from torch import Tensor +from torch.nn.modules.loss import _Loss + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm.quantization import load_with_modelopt_layer_spec +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec +from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.lightning.megatron_parallel import MaskedTokenLossReduction +from nemo.utils import logging + + +@dataclass +class DistillationGPTConfig(llm.GPTConfig): + kd_teacher_restore_from_path: str = "" # default set only for dataclass inheritance + + def configure_model(self, *args, **kwargs) -> MCoreGPTModel: + if not self.kd_teacher_restore_from_path: + raise ValueError("Config attribute `kd_teacher_restore_from_path` must be set.") + + model = super().configure_model(*args, **kwargs) + + # [ModelOpt] Intialize DistillationModel. + distill_cfg = load_distillation_config(self) + kd_config = { + "teacher_model": (_teacher_provider, [], {"cfg": self}), + "criterion": distill_cfg["criterion"], + "loss_balancer": distill_cfg["loss_balancer"], + } + model = mtd.convert(model, mode=[("kd_loss", kd_config)]) + + # Additional MCore-specific tweaks needed. 
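+        # (Editor's note, summarizing what follows in this file: `adjust_distillation_model_for_mcore`
+        # drops the ModelOpt distillation state, hides the teacher from `sharded_state_dict`,
+        # skips the standard LM loss during training, and adds extra input/config handling when
+        # pipeline parallelism is enabled. See that function further down in this file.)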
+ adjust_distillation_model_for_mcore(model, model_cfg=self, distill_cfg=distill_cfg) + + return model + + +class _DistillationLossReduction(MaskedTokenLossReduction): + """Custom masking and reduction callable used only in training mode.""" + + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self._distillation_model: mtd.DistillationModel = model.module + self._cp_size = get_context_parallel_world_size() + + def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: + if isinstance(forward_out, tuple): + # neva returns (logits, loss_mask) + forward_out, batch["loss_mask"] = forward_out + + # [ModelOpt]: KD loss calculation. + loss_for_ub = self._distillation_model.compute_kd_loss( + loss_reduction_fn=lambda x: self._masked_token_loss(x, batch["loss_mask"], batch['num_valid_tokens_in_ub']) + ) + + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub * self._cp_size, {"avg": reduced_loss} + + def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens_in_ub: Optional[int] = None): + """ + The function takes as input per-token loss and masks non-required values. + """ + if isinstance(loss_output, tuple): + # [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required) + loss_output, tp_reduce = loss_output + losses = loss_output.float() + loss_mask = mask.view(-1).float() + + if self._cp_size > 1: + if num_valid_tokens_in_ub is None: + num_valid_tokens_in_ub = loss_mask.sum() + if num_valid_tokens_in_ub < 0.5: # no valid tokens + num_valid_tokens_in_ub += 1.0 + loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll + torch.distributed.all_reduce(loss, group=get_context_parallel_group()) + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + + if tp_reduce is True: + torch.distributed.all_reduce(loss, group=get_tensor_model_parallel_group()) + + return loss + + +class DistillationGPTModel(llm.GPTModel): + """Custom GPT subclass for distillation-related modifications.""" + + @property + def training_loss_reduction(self) -> _DistillationLossReduction: + if not self._training_loss_reduction: + self._training_loss_reduction = _DistillationLossReduction() + + return self._training_loss_reduction + + +######################################################## + + +class BaseLoss(_Loss, metaclass=ABCMeta): + """Abstract base class for Megatron distillation losses.""" + + def __init__(self, model_config: TransformerConfig): + """ + Constructor. + + Args: + model_config: MCore transformer config. + """ + super().__init__() + self._config = model_config + + def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + """Prepares inputs safely for loss computation.""" + if isinstance(predictions, tuple): + # `ColumnParallelLinear` returns bias too + predictions, targets = predictions[0], targets[0] + targets = targets.detach() + + return predictions, targets + + def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor: + """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking.""" + loss = loss.transpose(0, 1).contiguous() + return loss, tp_reduce + + +class LogitsKLLoss(BaseLoss): + """Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim.""" + + def __init__(self, model_config: TransformerConfig, temperature: float = 1.0, reverse: bool = False): + """ + Constructor. 
+ + Args: + model_config: MCore transformer config. + temperature: Divide tensors by this value prior to calculating loss. + reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher) + """ + super().__init__(model_config) + self._temperature = temperature + self._reverse = reverse + + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """ + Forward function. + + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + # Division by temp should happen prior to finding max for both student and teacher. + # Currently we don't use temperature in any of ours runs (temp=1.0) + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + + # Compute local softmax, and the reweight to compute global softmax. + if self._config.tensor_model_parallel_size > 1: + + # Maximum value along vocab dimension across all GPUs. + teacher_logits_max, _ = torch.max(output_teacher, dim=-1) + torch.distributed.all_reduce( + teacher_logits_max, + op=torch.distributed.ReduceOp.MAX, + group=get_tensor_model_parallel_group(), + ) + output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) + + denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) + # We can't use standard reduction function here since the computation + # that follows it isn't identical across TP ranks. + denom_teacher = all_reduce_autograd(denom_teacher, group=get_tensor_model_parallel_group()) + + # Maximum value along vocab dimension across all GPUs. + student_logits_max, _ = torch.max(output_student, dim=-1) + torch.distributed.all_reduce( + student_logits_max, + op=torch.distributed.ReduceOp.MAX, + group=get_tensor_model_parallel_group(), + ) + output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() + + denom_student = torch.sum(torch.exp(output_student), dim=-1) + denom_student = all_reduce_autograd(denom_student, group=get_tensor_model_parallel_group()) + + slen, bsz, sharded_vocab_size = output_student.shape + student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( + slen, bsz, sharded_vocab_size + ) + teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand( + slen, bsz, sharded_vocab_size + ) + + if self._reverse: + loss = torch.sum( + F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True), + dim=-1, + ) + else: + loss = torch.sum( + F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True), + dim=-1, + ) + + else: + if self._reverse: + loss = torch.sum( + F.kl_div( + F.log_softmax(output_teacher, dim=-1), + F.softmax(output_student, dim=-1), + reduction="none", + ), + dim=-1, + ) + else: + loss = torch.sum( + F.kl_div( + F.log_softmax(output_student, dim=-1), + F.softmax(output_teacher, dim=-1), + reduction="none", + ), + dim=-1, + ) + + return self.post_forward(loss, tp_reduce=True) + + +class _AllReduce(torch.autograd.Function): + """Implementation from old PyTorch `torch.distributed.nn.parallel`.""" + + @staticmethod + def forward(ctx, op, group, tensor): + ctx.group, ctx.op = group, op + tensor = tensor.clone() + torch.distributed.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) + + 
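+# Illustrative sketch (not executed here, names are made up) of why this autograd-aware reduction
+# is needed: `torch.distributed.all_reduce` is invisible to autograd, whereas `_AllReduce.apply`
+# all-reduces the incoming gradient in its backward pass. Roughly:
+#
+#     denom = all_reduce_autograd(local_denom, group=get_tensor_model_parallel_group())
+#     loss = (shard_logits - denom.log().unsqueeze(-1)).sum()  # downstream math differs per TP rank
+#     loss.backward()  # grad w.r.t. `local_denom` is summed across the TP group in backward
+#
+# `LogitsKLLoss` above relies on this for the TP-sharded softmax denominators.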
+def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD): + """Custom all-reduce function. + + Needed instead of other all-reduce functions available when the computation following + the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. + """ + return _AllReduce.apply(op, group, tensor) + + +######################################################## + + +def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: + """Create a default distillation config for MCore GPT Models. + + Args: + student_cfg: Model config for student model. + """ + logit_pair = ("output_layer", "output_layer") # logit module names for MCoreGPTModel + cfg = { + "criterion": {tuple(logit_pair): LogitsKLLoss(cfg)}, + "loss_balancer": None, + "skip_lm_loss": True, + } + return cfg + + +def _adjust_layer_index_for_pp(submodule_name, model_cfg): + """ + Adjust any sequence-based layer indices found in a submodule name for Pipeline Parallelism. + + For example, on PP=2, layer called `"decoder.layers.17.input_layernorm"` in a model with 32 layers + will be assumed to be evely distributed among PP ranks and now be referenced on the final rank as + `"decoder.layers.2.input_layernorm"`. + + NOTE: Operates under assumption of being final PP rank and only one numerical index per layer name. + """ + ... + # TODO + + +def _teacher_provider(cfg: TransformerConfig) -> MCoreGPTModel: + """Teacher model factory (must be a non-local function to pickle).""" + + logging.info("Distillation: Loading teacher weights...") + model = load_with_modelopt_layer_spec( + cfg.kd_teacher_restore_from_path, + tensor_model_parallel_size=cfg.tensor_model_parallel_size, + pipeline_model_parallel_size=cfg.pipeline_model_parallel_size, + inference_only=True, + ) + logging.info("Distillation: ... teacher weights loaded.") + return model.module + + +def adjust_distillation_model_for_mcore( + model: mtd.DistillationModel, model_cfg: TransformerConfig, distill_cfg: Dict[str, Any] +): + """Extra modifcations to ``mtd.DistillationModel`` requried for Megatron-Core.""" + + # HACK: Get rid of ModelOpt Distillation state + # NOTE: If re-placed, above losses need modifcation as `TransformerConfig` has non-pickleable elements. + mto.ModeloptStateManager(model)._state.pop() + + # HACK: Hide teacher during `sharded_state_dict` method. + def _sharded_state_dict(self, *args, **kwargs) -> ShardedStateDict: + with self.hide_teacher_model(): + return self._sharded_state_dict(*args, **kwargs) + + model._sharded_state_dict = model.sharded_state_dict + model.sharded_state_dict = types.MethodType(_sharded_state_dict, model) + + # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop. + def _compute_language_model_loss(self, labels, logits) -> Tensor: + if self.training: + return torch.zeros_like(labels) + return self._compute_language_model_loss(labels, logits) + + if distill_cfg["skip_lm_loss"]: + model._compute_language_model_loss = model.compute_language_model_loss + model.compute_language_model_loss = types.MethodType(_compute_language_model_loss, model) + + # HACK: Skip `lm_loss` always for teacher. 
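+    # (The teacher is used only for its logits during distillation, so its LM loss is never
+    # needed for backprop; returning zeros shaped like `labels` keeps the downstream loss
+    # plumbing intact while avoiding the teacher's cross-entropy computation.)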
+ def _compute_language_model_loss(self, labels, logits) -> Tensor: + return torch.zeros_like(labels) + + model.teacher_model.compute_language_model_loss = types.MethodType( + _compute_language_model_loss, model.teacher_model + ) + + if model_cfg.pipeline_model_parallel_size > 1: + + def _set_input_tensor(self, input_tensor: Tensor): + obj = self.teacher_model if self._only_teacher_fwd else self + return type(self).set_input_tensor(obj, input_tensor) + + # HACK: Pipeline-parallel Distillation requires a way to cache input batches for subsequent + # forward calls, as well as a way to pass through output tensors to teacher model. + model.set_input_tensor = types.MethodType(_set_input_tensor, model) + + @contextmanager + def _swap_teacher_config(self, model_wrapper): + try: + if hasattr(model_wrapper, "config"): + model_wrapper._config = model_wrapper.config + model_wrapper.config = self.teacher_model.config + yield + finally: + del model_wrapper.config + if hasattr(model_wrapper, "_config"): + model_wrapper.config = model_wrapper._config + del model_wrapper._config + + # HACK: Pipeline-parallel forward function relies on the config in the model to know what + # hidden size of tensor to communicate to next stage. + model.swap_teacher_config = types.MethodType(_swap_teacher_config, model) + + +######################################################## + + +if __name__ == "__main__": + logging.info("Distillation enabled.") + + seq_length = 2048 + global_batch_size = 16 + tp = 1 + pp = 1 + + # TODO: setup the dummy dataset + data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size) + + TEACHER_PATH = "./test_teacher/" + # + import os + import sys + + from megatron.core import dist_checkpointing + + from nemo.lightning.io.pl import ckpt_to_weights_subdir + + if not os.path.exists(TEACHER_PATH): + gpt_config = llm.GPTConfig( + num_layers=9, + hidden_size=384, + ffn_hidden_size=1536, + num_attention_heads=6, + seq_length=seq_length, + init_method_std=0.023, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + make_vocab_size_divisible_by=128, + transformer_layer_spec=get_gpt_layer_modelopt_spec(), + ) + model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer) + dist_checkpointing.save(model.sharded_state_dict(), str(ckpt_to_weights_subdir(TEACHER_PATH, is_saving=True))) + sys.exit(0) + # + + ## initialize a small GPT model + gpt_config = DistillationGPTConfig( + num_layers=6, + hidden_size=384, + ffn_hidden_size=1536, + num_attention_heads=6, + seq_length=seq_length, + init_method_std=0.023, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + make_vocab_size_divisible_by=128, + transformer_layer_spec=get_gpt_layer_modelopt_spec(), + kd_teacher_restore_from_path=TEACHER_PATH, + ) + model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer) + + ## initialize the strategy + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + pipeline_dtype=torch.bfloat16, + ) + + ## setup the optimizer + opt_config = OptimizerConfig( + optimizer='adam', + lr=3e-5, + bf16=True, + ) + opt = nl.MegatronOptimizerModule(config=opt_config) + + trainer = nl.Trainer( + devices=1, ## you can change the number of devices to suit your setup + max_steps=50, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + nemo_logger = nl.NeMoLogger( + log_dir="test_logdir", ## logs and checkpoints will be written here + ) + + llm.train( + model=model, + 
data=data, + trainer=trainer, + log=nemo_logger, + tokenizer='data', + optim=opt, + ) From bc432379fbba905a46071447295cad1629661732 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 6 Jan 2025 12:50:32 -0800 Subject: [PATCH 02/22] Pipeline-parallel changes Signed-off-by: Asha Anoosheh --- nemo/collections/llm/gpt/model/base.py | 4 +- nemo/lightning/megatron_parallel.py | 51 ++++++++++- scripts/llm/gpt_distillation.py | 115 ++++++++++++++++++------- 3 files changed, 137 insertions(+), 33 deletions(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 1a28dc26b25c..62389099e807 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -460,8 +460,8 @@ def get_batch_on_this_context_parallel_rank(batch) -> Dict[str, torch.Tensor]: val.shape[seq_dim] // (2 * cp_size), *val.shape[(seq_dim + 1) :], ) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda( - non_blocking=True + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).to( + _val.device, non_blocking=True ) _val = _val.index_select(seq_dim, index) _val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :]) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 90499b69f216..03464d9883c8 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -187,7 +187,10 @@ def __init__( convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None, ) -> None: from megatron.core import parallel_state - from megatron.core.tensor_parallel import set_defaults_if_not_set_tensor_model_parallel_attributes + from megatron.core.transformer.module import Float16Module as MCoreFloat16Module + + from nemo.collections.nlp.modules.common.megatron.module import Float16Module + from nemo.utils.model_utils import unwrap_model _pipeline: List[nn.Module] if isinstance(pipeline, nn.ModuleList): @@ -219,6 +222,19 @@ def __init__( self.ddp_config = ddp_config self.convert_module_fn = convert_module_fn + # [ModelOpt]: Detect Pipeline-parallel Distillation mode. + self._unwrapped_model = [unwrap_model(self[0].module, (Float16Module, MCoreFloat16Module))] + if ( + hasattr(self._unwrapped_model[0], "teacher_model") + and parallel_state.get_pipeline_model_parallel_world_size() > 1 + ): + self._kd_teacher_in_pp = True + assert ( + not self.ddp_config.overlap_grad_reduce + ), "Pipeline-parallel Distillation currently incomatible with `overlap_grad_reduce` DDP option." 
+ else: + self._kd_teacher_in_pp = False + def forward( self, data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]], @@ -269,6 +285,30 @@ def forward( else: forward_step_func = _forward_step + if self._kd_teacher_in_pp: + assert wrap_forward_step and num_microbatches + data = _data_step(data, cache_num_batches=num_microbatches) + + def _dummy_reduction(output_tensor, *args, **kwargs): + return output_tensor.new_tensor(-1), {} + + teacher_forward_step_func = self.wrapped_forward_step( + forward_step=_forward_step, + data_step=_data_step, + loss_reduction=_dummy_reduction, + context=_forward_context, + ) + teacher_step = MegatronStep.infer( + self, + data, + teacher_forward_step_func, + forward_only=True, + micro_batch_size=micro_batch_size, + num_microbatches=num_microbatches, + seq_length=seq_length, + step_i=step_i, + ) + step = MegatronStep.infer( self, data, @@ -283,7 +323,14 @@ def forward( step = self.callbacks.transform_event("on_megatron_step_start", step) self.callbacks.event("on_megatron_microbatches_start", step=step) - microbatch_outputs = step() + if self._kd_teacher_in_pp: + with self._unwrapped_model[0].only_teacher_forward(): + with self._unwrapped_model[0].swap_teacher_config(self[0].module): + teacher_step() + with self._unwrapped_model[0].only_student_forward(): + microbatch_outputs = step() + else: + microbatch_outputs = step() self.callbacks.event("on_megatron_microbatches_end", step=step, microbatch_outputs=microbatch_outputs) if microbatch_outputs: diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 98696a815b2d..f45d6ed45803 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -16,26 +16,23 @@ from abc import ABCMeta from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import modelopt.torch.distill as mtd import modelopt.torch.opt as mto import torch import torch.nn.functional as F +from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel from megatron.core.optimizer import OptimizerConfig -from megatron.core.parallel_state import ( - get_context_parallel_group, - get_context_parallel_world_size, - get_tensor_model_parallel_group, -) from megatron.core.transformer.transformer_config import TransformerConfig from torch import Tensor from torch.nn.modules.loss import _Loss from nemo import lightning as nl from nemo.collections import llm +from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank from nemo.collections.llm.quantization import load_with_modelopt_layer_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group @@ -43,13 +40,60 @@ from nemo.utils import logging +def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]: + batch = next(dataloader_iter) + + _batch: dict + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + required_device_keys = set() + required_host_keys = set() + + if attn_mask_cpu: + # [ModelOpt]: We cache data for PP distillation, and save GPU mem by storing masks on CPU mem. 
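+        # (Under pipeline-parallel distillation each cached micro-batch is consumed twice, once
+        # by the teacher-only forward pass and once by the student. The attention mask is
+        # typically the largest per-sample tensor, so it is parked in host memory here and only
+        # moved back to the GPU in `data_step` when the cached batch is actually used.)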
+ required_host_keys.add("attention_mask") + else: + required_device_keys.add("attention_mask") + + if 'cu_seqlens' in _batch: + required_device_keys.add('cu_seqlens') + required_host_keys.add('cu_seqlens_argmin') + required_host_keys.add('max_seqlen') + + if parallel_state.is_pipeline_first_stage(): + required_device_keys.update(("tokens", "position_ids")) + if parallel_state.is_pipeline_last_stage(): + required_device_keys.update(("labels", "loss_mask")) + + _batch_required_keys = {} + for key, val in _batch.items(): + if key in required_device_keys: + _batch_required_keys[key] = val.cuda(non_blocking=True) + elif key in required_host_keys: + _batch_required_keys[key] = val.cpu() + else: + _batch_required_keys[key] = None + + # slice batch along sequence dimension for context parallelism + output = get_batch_on_this_context_parallel_rank(_batch_required_keys) + + return output + + @dataclass class DistillationGPTConfig(llm.GPTConfig): kd_teacher_restore_from_path: str = "" # default set only for dataclass inheritance + data_step_fn: Callable = gpt_distillation_data_step + def configure_model(self, *args, **kwargs) -> MCoreGPTModel: if not self.kd_teacher_restore_from_path: raise ValueError("Config attribute `kd_teacher_restore_from_path` must be set.") + if self.virtual_pipeline_model_parallel_size is not None: + raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") model = super().configure_model(*args, **kwargs) @@ -74,7 +118,7 @@ class _DistillationLossReduction(MaskedTokenLossReduction): def __init__(self, model, *args, **kwargs): super().__init__(*args, **kwargs) self._distillation_model: mtd.DistillationModel = model.module - self._cp_size = get_context_parallel_world_size() + self._cp_size = parallel_state.get_context_parallel_world_size() def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: if isinstance(forward_out, tuple): @@ -105,19 +149,43 @@ def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens if num_valid_tokens_in_ub < 0.5: # no valid tokens num_valid_tokens_in_ub += 1.0 loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll - torch.distributed.all_reduce(loss, group=get_context_parallel_group()) + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) else: loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll if tp_reduce is True: - torch.distributed.all_reduce(loss, group=get_tensor_model_parallel_group()) + torch.distributed.all_reduce(loss, group=parallel_state.get_tensor_model_parallel_group()) return loss +class _LoopingCachedDataIterator: + def __init__(self, data): + self.data = data + self.it = iter(self.data) + + def __next__(self): + try: + return next(self.it) + except StopIteration: + self.it = iter(self.data) + return next(self.it) + + class DistillationGPTModel(llm.GPTModel): """Custom GPT subclass for distillation-related modifications.""" + def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> Dict[str, torch.Tensor]: + if cache_num_batches: + batches = [self.config.data_step_fn(dataloader_iter, attn_mask_cpu=True) for _ in range(cache_num_batches)] + return _LoopingCachedDataIterator(batches) + elif isinstance(dataloader_iter, _LoopingCachedDataIterator): + batch = next(dataloader_iter) + batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU + return batch + else: + return 
self.config.data_step_fn(dataloader_iter) + @property def training_loss_reduction(self) -> _DistillationLossReduction: if not self._training_loss_reduction: @@ -199,26 +267,26 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: torch.distributed.all_reduce( teacher_logits_max, op=torch.distributed.ReduceOp.MAX, - group=get_tensor_model_parallel_group(), + group=parallel_state.get_tensor_model_parallel_group(), ) output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) # We can't use standard reduction function here since the computation # that follows it isn't identical across TP ranks. - denom_teacher = all_reduce_autograd(denom_teacher, group=get_tensor_model_parallel_group()) + denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group()) # Maximum value along vocab dimension across all GPUs. student_logits_max, _ = torch.max(output_student, dim=-1) torch.distributed.all_reduce( student_logits_max, op=torch.distributed.ReduceOp.MAX, - group=get_tensor_model_parallel_group(), + group=parallel_state.get_tensor_model_parallel_group(), ) output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() denom_student = torch.sum(torch.exp(output_student), dim=-1) - denom_student = all_reduce_autograd(denom_student, group=get_tensor_model_parallel_group()) + denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group()) slen, bsz, sharded_vocab_size = output_student.shape student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( @@ -296,26 +364,15 @@ def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: student_cfg: Model config for student model. """ logit_pair = ("output_layer", "output_layer") # logit module names for MCoreGPTModel - cfg = { - "criterion": {tuple(logit_pair): LogitsKLLoss(cfg)}, + distill_cfg = { + "criterion": {}, "loss_balancer": None, "skip_lm_loss": True, } - return cfg - + if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage(): + distill_cfg["criterion"][tuple(logit_pair)] = LogitsKLLoss(cfg) -def _adjust_layer_index_for_pp(submodule_name, model_cfg): - """ - Adjust any sequence-based layer indices found in a submodule name for Pipeline Parallelism. - - For example, on PP=2, layer called `"decoder.layers.17.input_layernorm"` in a model with 32 layers - will be assumed to be evely distributed among PP ranks and now be referenced on the final rank as - `"decoder.layers.2.input_layernorm"`. - - NOTE: Operates under assumption of being final PP rank and only one numerical index per layer name. - """ - ... 
- # TODO + return distill_cfg def _teacher_provider(cfg: TransformerConfig) -> MCoreGPTModel: From 3f7b9c6140f370045a89de36ee5b055cdba9e2bf Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Sat, 11 Jan 2025 10:06:28 -0800 Subject: [PATCH 03/22] Basic distillation running Signed-off-by: Asha Anoosheh --- nemo/lightning/megatron_parallel.py | 14 +-- scripts/llm/gpt_distillation.py | 153 ++++++++++++++++------------ 2 files changed, 96 insertions(+), 71 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 03464d9883c8..0eacb8f2b0ca 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -187,7 +187,7 @@ def __init__( convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None, ) -> None: from megatron.core import parallel_state - from megatron.core.transformer.module import Float16Module as MCoreFloat16Module + from megatron.core.transformer.module import Float16Module as McoreFloat16Module from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.utils.model_utils import unwrap_model @@ -223,9 +223,9 @@ def __init__( self.convert_module_fn = convert_module_fn # [ModelOpt]: Detect Pipeline-parallel Distillation mode. - self._unwrapped_model = [unwrap_model(self[0].module, (Float16Module, MCoreFloat16Module))] + self._unwrapped_model = unwrap_model(self.module, (DDP, Float16Module, McoreFloat16Module)) if ( - hasattr(self._unwrapped_model[0], "teacher_model") + hasattr(self._unwrapped_model, "teacher_model") and parallel_state.get_pipeline_model_parallel_world_size() > 1 ): self._kd_teacher_in_pp = True @@ -324,10 +324,10 @@ def _dummy_reduction(output_tensor, *args, **kwargs): self.callbacks.event("on_megatron_microbatches_start", step=step) if self._kd_teacher_in_pp: - with self._unwrapped_model[0].only_teacher_forward(): - with self._unwrapped_model[0].swap_teacher_config(self[0].module): + with self._unwrapped_model.only_teacher_forward(): + with self._unwrapped_model.swap_teacher_config(self.module): teacher_step() - with self._unwrapped_model[0].only_student_forward(): + with self._unwrapped_model.only_student_forward(): microbatch_outputs = step() else: microbatch_outputs = step() @@ -339,7 +339,7 @@ def _dummy_reduction(output_tensor, *args, **kwargs): ) if isinstance(_loss_reduction, _ModuleStepFunction): - _loss_reduction = _loss_reduction(self[0]) + _loss_reduction = _loss_reduction(self.module) reduced = _loss_reduction.reduce(microbatch_outputs) self.callbacks.event( diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index f45d6ed45803..fd6909e89d1f 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -26,6 +26,7 @@ from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_config import TransformerConfig from torch import Tensor from torch.nn.modules.loss import _Loss @@ -33,11 +34,13 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank -from nemo.collections.llm.quantization import load_with_modelopt_layer_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec +from 
nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group -from nemo.lightning.megatron_parallel import MaskedTokenLossReduction +from nemo.lightning import io +from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction from nemo.utils import logging +from nemo.utils.model_utils import unwrap_model def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]: @@ -115,9 +118,9 @@ def configure_model(self, *args, **kwargs) -> MCoreGPTModel: class _DistillationLossReduction(MaskedTokenLossReduction): """Custom masking and reduction callable used only in training mode.""" - def __init__(self, model, *args, **kwargs): + def __init__(self, distillation_loss_fn, *args, **kwargs): super().__init__(*args, **kwargs) - self._distillation_model: mtd.DistillationModel = model.module + self._distillation_loss_fn = distillation_loss_fn self._cp_size = parallel_state.get_context_parallel_world_size() def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -126,8 +129,10 @@ def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor forward_out, batch["loss_mask"] = forward_out # [ModelOpt]: KD loss calculation. - loss_for_ub = self._distillation_model.compute_kd_loss( - loss_reduction_fn=lambda x: self._masked_token_loss(x, batch["loss_mask"], batch['num_valid_tokens_in_ub']) + loss_for_ub = self._distillation_loss_fn( + loss_reduction_fn=lambda x: self._masked_token_loss( + x, batch["loss_mask"], batch.get("num_valid_tokens_in_ub") + ) ) reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) @@ -159,19 +164,6 @@ def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens return loss -class _LoopingCachedDataIterator: - def __init__(self, data): - self.data = data - self.it = iter(self.data) - - def __next__(self): - try: - return next(self.it) - except StopIteration: - self.it = iter(self.data) - return next(self.it) - - class DistillationGPTModel(llm.GPTModel): """Custom GPT subclass for distillation-related modifications.""" @@ -189,7 +181,10 @@ def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> @property def training_loss_reduction(self) -> _DistillationLossReduction: if not self._training_loss_reduction: - self._training_loss_reduction = _DistillationLossReduction() + core_module = unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module)) + self._training_loss_reduction = _DistillationLossReduction( + distillation_loss_fn=core_module.compute_kd_loss + ) return self._training_loss_reduction @@ -357,7 +352,7 @@ def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.d ######################################################## -def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: +def load_distillation_config(cfg: DistillationGPTConfig) -> Dict[str, Any]: """Create a default distillation config for MCore GPT Models. 
Args: @@ -370,23 +365,49 @@ def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: "skip_lm_loss": True, } if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage(): - distill_cfg["criterion"][tuple(logit_pair)] = LogitsKLLoss(cfg) + distill_cfg["criterion"][logit_pair] = LogitsKLLoss(cfg) return distill_cfg -def _teacher_provider(cfg: TransformerConfig) -> MCoreGPTModel: +def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel: """Teacher model factory (must be a non-local function to pickle).""" logging.info("Distillation: Loading teacher weights...") - model = load_with_modelopt_layer_spec( - cfg.kd_teacher_restore_from_path, + strategy = nl.MegatronStrategy( tensor_model_parallel_size=cfg.tensor_model_parallel_size, + context_parallel_size=cfg.context_parallel_size, pipeline_model_parallel_size=cfg.pipeline_model_parallel_size, - inference_only=True, + ckpt_load_optimizer=False, + ckpt_parallel_save_optim=False, + setup_optimizers=False, + ddp="pytorch", ) + trainer = nl.Trainer( + devices=cfg.tensor_model_parallel_size, + num_nodes=cfg.context_parallel_size * cfg.pipeline_model_parallel_size, + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False) + model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module)) + logging.info("Distillation: ... teacher weights loaded.") - return model.module + return model + + +class _LoopingCachedDataIterator: + def __init__(self, data): + self.data = data + self.it = iter(self.data) + + def __next__(self): + try: + return next(self.it) + except StopIteration: + self.it = iter(self.data) + return next(self.it) def adjust_distillation_model_for_mcore( @@ -458,6 +479,8 @@ def _swap_teacher_config(self, model_wrapper): if __name__ == "__main__": logging.info("Distillation enabled.") + TEACHER_PATH = "./test_teacher/" + seq_length = 2048 global_batch_size = 16 tp = 1 @@ -466,33 +489,56 @@ def _swap_teacher_config(self, model_wrapper): # TODO: setup the dummy dataset data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size) - TEACHER_PATH = "./test_teacher/" - # - import os - import sys + ## initialize the strategy + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=tp, + pipeline_model_parallel_size=pp, + ) + trainer = nl.Trainer( + devices=1, ## you can change the number of devices to suit your setup + max_steps=50, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) - from megatron.core import dist_checkpointing + common_model_kwargs = dict( + seq_length=seq_length, + init_method_std=0.023, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + make_vocab_size_divisible_by=128, + transformer_layer_spec=get_gpt_layer_modelopt_spec(), + ) - from nemo.lightning.io.pl import ckpt_to_weights_subdir + ############# TEACHER HACK ############# + import os + import sys if not os.path.exists(TEACHER_PATH): + from lightning.pytorch.trainer.states import TrainerFn + gpt_config = llm.GPTConfig( num_layers=9, hidden_size=384, ffn_hidden_size=1536, num_attention_heads=6, - seq_length=seq_length, - init_method_std=0.023, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - make_vocab_size_divisible_by=128, - transformer_layer_spec=get_gpt_layer_modelopt_spec(), + **common_model_kwargs, ) model = llm.GPTModel(gpt_config, 
tokenizer=data.tokenizer) - dist_checkpointing.save(model.sharded_state_dict(), str(ckpt_to_weights_subdir(TEACHER_PATH, is_saving=True))) + + strategy.ckpt_save_optimizer = False # otherwise need to do `model._trainer = trainer` + trainer.state.fn = TrainerFn.FITTING # needed for proper save. + trainer.strategy.connect(model) + trainer.strategy.setup_environment() + with trainer.init_module(): + model.configure_model() + + io.ModelConnector().nemo_save(TEACHER_PATH, trainer) + sys.exit(0) - # + ########################################## ## initialize a small GPT model gpt_config = DistillationGPTConfig( @@ -500,24 +546,11 @@ def _swap_teacher_config(self, model_wrapper): hidden_size=384, ffn_hidden_size=1536, num_attention_heads=6, - seq_length=seq_length, - init_method_std=0.023, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - make_vocab_size_divisible_by=128, - transformer_layer_spec=get_gpt_layer_modelopt_spec(), + **common_model_kwargs, kd_teacher_restore_from_path=TEACHER_PATH, ) model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer) - ## initialize the strategy - strategy = nl.MegatronStrategy( - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, - pipeline_dtype=torch.bfloat16, - ) - ## setup the optimizer opt_config = OptimizerConfig( optimizer='adam', @@ -526,14 +559,6 @@ def _swap_teacher_config(self, model_wrapper): ) opt = nl.MegatronOptimizerModule(config=opt_config) - trainer = nl.Trainer( - devices=1, ## you can change the number of devices to suit your setup - max_steps=50, - accelerator="gpu", - strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), - ) - nemo_logger = nl.NeMoLogger( log_dir="test_logdir", ## logs and checkpoints will be written here ) From d87e64f5b3e17358a3b965edf1a60ecfad453e99 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 14 Jan 2025 08:36:24 -0800 Subject: [PATCH 04/22] Add CLI args Signed-off-by: Asha Anoosheh --- scripts/llm/gpt_distillation.py | 188 +++++++++++++++++++------------- 1 file changed, 114 insertions(+), 74 deletions(-) diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index fd6909e89d1f..0efc558525bf 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -14,8 +14,9 @@ import types from abc import ABCMeta +from argparse import ArgumentParser from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Callable, Dict, Optional, Tuple import modelopt.torch.distill as mtd @@ -34,15 +35,55 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec + +# from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group from nemo.lightning import io +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler from nemo.utils import logging from nemo.utils.model_utils 
import unwrap_model +def get_args(): + """ + Parse the command line arguments. + """ + parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""") + + parser.add_argument("--name", type=str, required=True, help="""Experiment name""") + parser.add_argument("--teacher_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--student_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--tp", type=int, default=1, help="""Tensor parallel size""") + parser.add_argument("--cp", type=int, default=1, help="""Context parallel size""") + parser.add_argument("--pp", type=int, default=1, help="""Pipeline parallel size""") + parser.add_argument("--enable_sp", action="store_true", help="""Enable Sequence parallelism""") + parser.add_argument("--precision", type=str, default="bf16-mixed", help="""Datatype for models and optimizer""") + parser.add_argument("--devices", type=int, default=1, help="""Number of GPUs to use per node""") + parser.add_argument("--nodes", type=int, default=1, help="""Number of nodes to use""") + parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""") + parser.add_argument("--steps", type=int, required=True, help="""Number of global batches to process""") + parser.add_argument("--global_batch_size", type=int, required=True, help="""Data samples per optimizer step""") + parser.add_argument("--micro_batch_size", type=int, required=True, help="""Data samples per forward pass""") + parser.add_argument("--data_paths", nargs='+', required=True, help="""List of tokenized data paths to load from""") + parser.add_argument("--split", type=str, default="99,1,0", help="""""") + parser.add_argument("--index_mapping_dir", type=str, default=None, help="""""") + parser.add_argument("--sequence_length", type=int, required=True, help="""Number of tokens per input sample""") + parser.add_argument("--lr", type=float, default=3e-5, help="""""") + parser.add_argument("--min_lr", type=float, default=2e-7, help="""""") + parser.add_argument("--warmup_steps", type=int, default=50, help="""""") + parser.add_argument("--val_check_interval", type=int, default=100, help="""""") + parser.add_argument("--limit_val_batches", type=int, default=32, help="""""") + parser.add_argument("--limit_test_batches", type=int, default=32, help="""""") + parser.add_argument("--log_interval", type=int, default=10, help="""""") + + args = parser.parse_args() + return args + + def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]: batch = next(dataloader_iter) @@ -173,7 +214,8 @@ def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> return _LoopingCachedDataIterator(batches) elif isinstance(dataloader_iter, _LoopingCachedDataIterator): batch = next(dataloader_iter) - batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU + if "attention_mask" in batch: + batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU return batch else: return self.config.data_step_fn(dataloader_iter) @@ -390,6 +432,7 @@ def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel: plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), ) + # TODO(aanoosheh): Replace spec with modelopt one model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False) model = unwrap_model(model.module, (DDP, 
Float16Module, MCoreFloat16Module)) @@ -479,95 +522,92 @@ def _swap_teacher_config(self, model_wrapper): if __name__ == "__main__": logging.info("Distillation enabled.") - TEACHER_PATH = "./test_teacher/" + args = get_args() - seq_length = 2048 - global_batch_size = 16 - tp = 1 - pp = 1 - - # TODO: setup the dummy dataset - data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size) - - ## initialize the strategy + ## initialize the strategy and trainer strategy = nl.MegatronStrategy( - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, + tensor_model_parallel_size=args.tp, + pipeline_model_parallel_size=args.pp, + context_parallel_size=args.cp, + sequence_parallel=args.enable_sp, ) trainer = nl.Trainer( - devices=1, ## you can change the number of devices to suit your setup - max_steps=50, - accelerator="gpu", + devices=args.devices, + num_nodes=args.nodes, + max_steps=args.steps, + log_every_n_steps=args.log_interval, + val_check_interval=args.val_check_interval, + limit_val_batches=args.limit_val_batches, + limit_test_batches=args.limit_test_batches, strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), - ) - - common_model_kwargs = dict( - seq_length=seq_length, - init_method_std=0.023, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - make_vocab_size_divisible_by=128, - transformer_layer_spec=get_gpt_layer_modelopt_spec(), + accelerator="gpu", + plugins=nl.MegatronMixedPrecision(precision=args.precision), ) - ############# TEACHER HACK ############# - import os - import sys - - if not os.path.exists(TEACHER_PATH): - from lightning.pytorch.trainer.states import TrainerFn + ## load student model + model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") + assert hasattr(model, "tokenizer"), "Please provide a model checkpoint with tokenizer included." + model_config, tokenizer = model.config, model.tokenizer - gpt_config = llm.GPTConfig( - num_layers=9, - hidden_size=384, - ffn_hidden_size=1536, - num_attention_heads=6, - **common_model_kwargs, - ) - model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer) - - strategy.ckpt_save_optimizer = False # otherwise need to do `model._trainer = trainer` - trainer.state.fn = TrainerFn.FITTING # needed for proper save. 
- trainer.strategy.connect(model) - trainer.strategy.setup_environment() - with trainer.init_module(): - model.configure_model() - - io.ModelConnector().nemo_save(TEACHER_PATH, trainer) - - sys.exit(0) - ########################################## - - ## initialize a small GPT model - gpt_config = DistillationGPTConfig( - num_layers=6, - hidden_size=384, - ffn_hidden_size=1536, - num_attention_heads=6, - **common_model_kwargs, - kd_teacher_restore_from_path=TEACHER_PATH, + model_config = DistillationGPTConfig( + **asdict(model_config), + # transformer_layer_spec=get_gpt_layer_modelopt_spec(), # TODO(aanoosheh) + kd_teacher_restore_from_path=args.teacher_path, + ) + model = DistillationGPTModel(model_config, tokenizer=tokenizer) + + # setup the dataset + data = llm.PreTrainingDataModule( + paths=args.data_paths, + seq_length=args.sequence_length, + micro_batch_size=args.micro_batch_size, + global_batch_size=args.global_batch_size, + split=args.split, + index_mapping_dir=args.index_mapping_dir, + tokenizer=tokenizer, ) - model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer) + + # auto-resume setup + # resume = nl.AutoResume( + # resume_if_exists=True, + # resume_ignore_no_checkpoint=True, + # resume_from_directory=LOG_DIR, + # restore_config=nl.RestoreConfig(path=STUDENT_PATH) if STUDENT_PATH else None, + # ) ## setup the optimizer opt_config = OptimizerConfig( - optimizer='adam', - lr=3e-5, - bf16=True, + optimizer="adam", + lr=args.lr, + bf16=("bf16" in args.precision), + use_distributed_optimizer=True, + ) + sched = CosineAnnealingScheduler( + max_steps=args.steps, + warmup_steps=args.warmup_steps, + constant_steps=0, + min_lr=args.min_lr, ) - opt = nl.MegatronOptimizerModule(config=opt_config) + opt = nl.MegatronOptimizerModule(opt_config, sched) - nemo_logger = nl.NeMoLogger( - log_dir="test_logdir", ## logs and checkpoints will be written here + # checkpointing and logging + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + save_top_k=1, + every_n_train_steps=args.val_check_interval, + ) + logger = nl.NeMoLogger( + name=args.name, + log_dir=args.log_dir, + ckpt=checkpoint_callback, ) + # run llm.train( model=model, data=data, - trainer=trainer, - log=nemo_logger, - tokenizer='data', optim=opt, + tokenizer='model', + trainer=trainer, + log=logger, ) From 12019bf86634319ca3b8d12869151853717f4acd Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Thu, 16 Jan 2025 12:52:54 -0800 Subject: [PATCH 05/22] Most fixes Signed-off-by: Asha Anoosheh --- nemo/lightning/megatron_parallel.py | 51 ++++---- scripts/llm/gpt_distillation.py | 182 ++++++++++++++++------------ 2 files changed, 134 insertions(+), 99 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 0eacb8f2b0ca..99e8246dfd91 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -223,9 +223,10 @@ def __init__( self.convert_module_fn = convert_module_fn # [ModelOpt]: Detect Pipeline-parallel Distillation mode. - self._unwrapped_model = unwrap_model(self.module, (DDP, Float16Module, McoreFloat16Module)) + self._unwrapped_model = [unwrap_model(self.module.module, (DDP, Float16Module, McoreFloat16Module))] + # Avoid re-registering module which breaks the inherited `ModuleList` somehow. 
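+        # (Storing the unwrapped module inside a plain Python list keeps `nn.Module.__setattr__`
+        # from registering it as a child of this `ModuleList` subclass; it is accessed through
+        # the `unwrapped_model` property below instead.)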
if ( - hasattr(self._unwrapped_model, "teacher_model") + hasattr(self.unwrapped_model, "teacher_model") and parallel_state.get_pipeline_model_parallel_world_size() > 1 ): self._kd_teacher_in_pp = True @@ -280,13 +281,15 @@ def forward( forward_step=_forward_step, data_step=_data_step, loss_reduction=_loss_reduction, - context=_forward_context, + context={}, ) else: forward_step_func = _forward_step if self._kd_teacher_in_pp: - assert wrap_forward_step and num_microbatches + assert wrap_forward_step + if isinstance(_data_step, _ModuleStepFunction): + _data_step = _data_step(self.module) data = _data_step(data, cache_num_batches=num_microbatches) def _dummy_reduction(output_tensor, *args, **kwargs): @@ -324,10 +327,10 @@ def _dummy_reduction(output_tensor, *args, **kwargs): self.callbacks.event("on_megatron_microbatches_start", step=step) if self._kd_teacher_in_pp: - with self._unwrapped_model.only_teacher_forward(): - with self._unwrapped_model.swap_teacher_config(self.module): + with self.unwrapped_model.only_teacher_forward(): + with self.unwrapped_model.swap_teacher_config(self.module): teacher_step() - with self._unwrapped_model.only_student_forward(): + with self.unwrapped_model.only_student_forward(): microbatch_outputs = step() else: microbatch_outputs = step() @@ -521,7 +524,7 @@ def wrapped_forward_step_func(dataloader_iter, model): _data_step = data_step batch = _data_step(dataloader_iter) - step = context["step"] + step = context.get("step") if isinstance(loss_reduction, _ModuleStepFunction): forward_callback = loss_reduction(model) @@ -533,12 +536,13 @@ def wrapped_forward_step_func(dataloader_iter, model): else: _forward_step = forward_step - self.callbacks.event( - "on_megatron_microbatch_start", - step=step, - batch=batch, - forward_callback=forward_callback, - ) + if step is not None: + self.callbacks.event( + "on_megatron_microbatch_start", + step=step, + batch=batch, + forward_callback=forward_callback, + ) if self.precision_plugin and parallel_state.is_pipeline_first_stage(): batch = self.precision_plugin.convert_input(batch) @@ -557,13 +561,14 @@ def wrapped_forward_step_func(dataloader_iter, model): if self.precision_plugin and parallel_state.is_pipeline_last_stage(): output_tensor = self.precision_plugin.convert_output(output_tensor) - self.callbacks.event( - "on_megatron_microbatch_end", - step=step, - batch=batch, - output=output_tensor, - forward_callback=forward_callback, - ) + if step is not None: + self.callbacks.event( + "on_megatron_microbatch_end", + step=step, + batch=batch, + output=output_tensor, + forward_callback=forward_callback, + ) return output_tensor, forward_callback @@ -747,6 +752,10 @@ def pipeline(self) -> Union[ModelT, List[ModelT]]: def module(self) -> ModelT: return self[0] + @property + def unwrapped_model(self): + return self._unwrapped_model[0] + @override def __getattr__(self, item: Any) -> Any: try: diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 0efc558525bf..b8bb2286e1dd 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import types from abc import ABCMeta from argparse import ArgumentParser from contextlib import contextmanager -from dataclasses import asdict, dataclass -from typing import Any, Callable, Dict, Optional, Tuple +from types import MethodType +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple import modelopt.torch.distill as mtd import modelopt.torch.opt as mto @@ -29,24 +28,26 @@ from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_config import TransformerConfig -from torch import Tensor +from torch import Tensor, nn from torch.nn.modules.loss import _Loss from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank - -# from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec +from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group from nemo.lightning import io from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction from nemo.lightning.pytorch.callbacks import ModelCheckpoint -from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, OptimizerModule from nemo.utils import logging from nemo.utils.model_utils import unwrap_model +if TYPE_CHECKING: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + def get_args(): """ @@ -127,35 +128,6 @@ def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str return output -@dataclass -class DistillationGPTConfig(llm.GPTConfig): - kd_teacher_restore_from_path: str = "" # default set only for dataclass inheritance - - data_step_fn: Callable = gpt_distillation_data_step - - def configure_model(self, *args, **kwargs) -> MCoreGPTModel: - if not self.kd_teacher_restore_from_path: - raise ValueError("Config attribute `kd_teacher_restore_from_path` must be set.") - if self.virtual_pipeline_model_parallel_size is not None: - raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") - - model = super().configure_model(*args, **kwargs) - - # [ModelOpt] Intialize DistillationModel. - distill_cfg = load_distillation_config(self) - kd_config = { - "teacher_model": (_teacher_provider, [], {"cfg": self}), - "criterion": distill_cfg["criterion"], - "loss_balancer": distill_cfg["loss_balancer"], - } - model = mtd.convert(model, mode=[("kd_loss", kd_config)]) - - # Additional MCore-specific tweaks needed. 
- adjust_distillation_model_for_mcore(model, model_cfg=self, distill_cfg=distill_cfg) - - return model - - class _DistillationLossReduction(MaskedTokenLossReduction): """Custom masking and reduction callable used only in training mode.""" @@ -208,9 +180,48 @@ def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens class DistillationGPTModel(llm.GPTModel): """Custom GPT subclass for distillation-related modifications.""" + def __init__( + self, + kd_teacher_restore_from_path: str, + config: llm.GPTConfig, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config, optim, tokenizer, model_transform) + self._kd_teacher_restore_from_path = kd_teacher_restore_from_path + + if not self._kd_teacher_restore_from_path: + raise ValueError("Config attribute `kd_teacher_restore_from_path` must be set.") + if self.config.virtual_pipeline_model_parallel_size is not None: + raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") + + def configure_model(self): + if hasattr(self, "module"): + return + + model = self.config.configure_model(self.tokenizer) + + # [ModelOpt] Intialize DistillationModel. + distill_cfg = load_distillation_config(self.config) + kd_config = { + "teacher_model": (_teacher_provider, [self._kd_teacher_restore_from_path], {"trainer": self.trainer}), + "criterion": distill_cfg["criterion"], + "loss_balancer": distill_cfg["loss_balancer"], + } + model = mtd.convert(model, mode=[("kd_loss", kd_config)]) + + # Additional MCore-specific tweaks needed. + adjust_distillation_model_for_mcore(model, model_cfg=self.config, distill_cfg=distill_cfg) + + self.module = model + def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> Dict[str, torch.Tensor]: + # NOTE: Ignores `self.config.data_step_fn` if cache_num_batches: - batches = [self.config.data_step_fn(dataloader_iter, attn_mask_cpu=True) for _ in range(cache_num_batches)] + batches = [ + gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=True) for _ in range(cache_num_batches) + ] return _LoopingCachedDataIterator(batches) elif isinstance(dataloader_iter, _LoopingCachedDataIterator): batch = next(dataloader_iter) @@ -218,7 +229,7 @@ def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU return batch else: - return self.config.data_step_fn(dataloader_iter) + return gpt_distillation_data_step(dataloader_iter) @property def training_loss_reduction(self) -> _DistillationLossReduction: @@ -394,7 +405,7 @@ def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.d ######################################################## -def load_distillation_config(cfg: DistillationGPTConfig) -> Dict[str, Any]: +def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: """Create a default distillation config for MCore GPT Models. 
Args: @@ -403,7 +414,7 @@ def load_distillation_config(cfg: DistillationGPTConfig) -> Dict[str, Any]: logit_pair = ("output_layer", "output_layer") # logit module names for MCoreGPTModel distill_cfg = { "criterion": {}, - "loss_balancer": None, + "loss_balancer": _DummyLossBalancer(), # HACK: to appease ModelOpt until validation relaxed "skip_lm_loss": True, } if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage(): @@ -412,31 +423,30 @@ def load_distillation_config(cfg: DistillationGPTConfig) -> Dict[str, Any]: return distill_cfg -def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel: - """Teacher model factory (must be a non-local function to pickle).""" +class _DummyLossBalancer(mtd.DistillationLossBalancer): + def forward(loss_dict): + return next(iter(loss_dict.values())) + +def _teacher_provider(teacher_path: str, trainer: nl.Trainer) -> MCoreGPTModel: + """Teacher model factory (must be a non-local function to pickle).""" logging.info("Distillation: Loading teacher weights...") - strategy = nl.MegatronStrategy( - tensor_model_parallel_size=cfg.tensor_model_parallel_size, - context_parallel_size=cfg.context_parallel_size, - pipeline_model_parallel_size=cfg.pipeline_model_parallel_size, - ckpt_load_optimizer=False, - ckpt_parallel_save_optim=False, - setup_optimizers=False, - ddp="pytorch", - ) - trainer = nl.Trainer( - devices=cfg.tensor_model_parallel_size, - num_nodes=cfg.context_parallel_size * cfg.pipeline_model_parallel_size, - strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), - ) + + ckpt_load_optimizer = trainer.strategy.ckpt_load_optimizer + _setup_optimizers = trainer.strategy._setup_optimizers + trainer.strategy.ckpt_load_optimizer = False + trainer.strategy._setup_optimizers = False + orig_model = trainer.model # TODO(aanoosheh): Replace spec with modelopt one - model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False) + model, _ = io.ModelConnector().nemo_load(teacher_path, trainer, cpu=False) model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module)) - logging.info("Distillation: ... teacher weights loaded.") + trainer.strategy.ckpt_load_optimizer = ckpt_load_optimizer + trainer.strategy._setup_optimizers = _setup_optimizers + trainer.strategy.connect(orig_model) + + logging.info("Distillation: ...teacher weights loaded.") return model @@ -468,7 +478,7 @@ def _sharded_state_dict(self, *args, **kwargs) -> ShardedStateDict: return self._sharded_state_dict(*args, **kwargs) model._sharded_state_dict = model.sharded_state_dict - model.sharded_state_dict = types.MethodType(_sharded_state_dict, model) + model.sharded_state_dict = MethodType(_sharded_state_dict, model) # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop. def _compute_language_model_loss(self, labels, logits) -> Tensor: @@ -478,15 +488,13 @@ def _compute_language_model_loss(self, labels, logits) -> Tensor: if distill_cfg["skip_lm_loss"]: model._compute_language_model_loss = model.compute_language_model_loss - model.compute_language_model_loss = types.MethodType(_compute_language_model_loss, model) + model.compute_language_model_loss = MethodType(_compute_language_model_loss, model) # HACK: Skip `lm_loss` always for teacher. 
def _compute_language_model_loss(self, labels, logits) -> Tensor: return torch.zeros_like(labels) - model.teacher_model.compute_language_model_loss = types.MethodType( - _compute_language_model_loss, model.teacher_model - ) + model.teacher_model.compute_language_model_loss = MethodType(_compute_language_model_loss, model.teacher_model) if model_cfg.pipeline_model_parallel_size > 1: @@ -496,7 +504,7 @@ def _set_input_tensor(self, input_tensor: Tensor): # HACK: Pipeline-parallel Distillation requires a way to cache input batches for subsequent # forward calls, as well as a way to pass through output tensors to teacher model. - model.set_input_tensor = types.MethodType(_set_input_tensor, model) + model.set_input_tensor = MethodType(_set_input_tensor, model) @contextmanager def _swap_teacher_config(self, model_wrapper): @@ -513,11 +521,28 @@ def _swap_teacher_config(self, model_wrapper): # HACK: Pipeline-parallel forward function relies on the config in the model to know what # hidden size of tensor to communicate to next stage. - model.swap_teacher_config = types.MethodType(_swap_teacher_config, model) + model.swap_teacher_config = MethodType(_swap_teacher_config, model) ######################################################## +# # # +from dataclasses import dataclass + +from nemo.collections.llm.gpt.model.llama import Llama31Config + + +@dataclass +class Llama31Config4B(Llama31Config): + rotary_base: int = 500000 + seq_length: int = 131072 + num_layers: int = 16 + hidden_size: int = 4096 + ffn_hidden_size: int = 14336 + num_attention_heads: int = 32 + + +# # # if __name__ == "__main__": logging.info("Distillation enabled.") @@ -545,16 +570,12 @@ def _swap_teacher_config(self, model_wrapper): ) ## load student model - model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") - assert hasattr(model, "tokenizer"), "Please provide a model checkpoint with tokenizer included." - model_config, tokenizer = model.config, model.tokenizer - - model_config = DistillationGPTConfig( - **asdict(model_config), - # transformer_layer_spec=get_gpt_layer_modelopt_spec(), # TODO(aanoosheh) - kd_teacher_restore_from_path=args.teacher_path, - ) - model = DistillationGPTModel(model_config, tokenizer=tokenizer) + _model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") + assert hasattr(_model, "tokenizer"), "Please provide a model checkpoint with tokenizer included." 
+ model_config, tokenizer = _model.config, _model.tokenizer + # model_config.transformer_layer_spec = get_gpt_layer_modelopt_spec() # TODO(aanoosheh) + model = DistillationGPTModel(args.teacher_path, model_config, tokenizer=tokenizer) + _setup_trainer_and_restore_model(args.student_path, trainer, model) # setup the dataset data = llm.PreTrainingDataModule( @@ -567,11 +588,11 @@ def _swap_teacher_config(self, model_wrapper): tokenizer=tokenizer, ) - # auto-resume setup + # TODO auto-resume setup # resume = nl.AutoResume( # resume_if_exists=True, # resume_ignore_no_checkpoint=True, - # resume_from_directory=LOG_DIR, + # resume_from_directory=args.log_dir, # restore_config=nl.RestoreConfig(path=STUDENT_PATH) if STUDENT_PATH else None, # ) @@ -602,6 +623,11 @@ def _swap_teacher_config(self, model_wrapper): ckpt=checkpoint_callback, ) + import os + + # suppress warning + os.environ["TOKENIZERS_PARALLELISM"] = "false" + # run llm.train( model=model, From c3f5fb22242898f6c9b771930bf2eb2fce63adaa Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Fri, 17 Jan 2025 15:05:12 -0800 Subject: [PATCH 06/22] Fix callbacks in PP loop Signed-off-by: Asha Anoosheh --- nemo/lightning/megatron_parallel.py | 8 ++++++-- scripts/llm/gpt_distillation.py | 8 ++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 99e8246dfd91..03fd54bf2b9e 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -281,13 +281,15 @@ def forward( forward_step=_forward_step, data_step=_data_step, loss_reduction=_loss_reduction, - context={}, + context=_forward_context, ) else: forward_step_func = _forward_step if self._kd_teacher_in_pp: assert wrap_forward_step + _teacher_forward_context = {} + if isinstance(_data_step, _ModuleStepFunction): _data_step = _data_step(self.module) data = _data_step(data, cache_num_batches=num_microbatches) @@ -299,7 +301,7 @@ def _dummy_reduction(output_tensor, *args, **kwargs): forward_step=_forward_step, data_step=_data_step, loss_reduction=_dummy_reduction, - context=_forward_context, + context=_teacher_forward_context, ) teacher_step = MegatronStep.infer( self, @@ -311,6 +313,8 @@ def _dummy_reduction(output_tensor, *args, **kwargs): seq_length=seq_length, step_i=step_i, ) + _teacher_forward_context["step"] = teacher_step + teacher_step = self.callbacks.transform_event("on_megatron_step_start", teacher_step) step = MegatronStep.infer( self, diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index b8bb2286e1dd..9a1a5e4f8c7f 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os from abc import ABCMeta from argparse import ArgumentParser from contextlib import contextmanager @@ -433,7 +434,7 @@ def _teacher_provider(teacher_path: str, trainer: nl.Trainer) -> MCoreGPTModel: logging.info("Distillation: Loading teacher weights...") ckpt_load_optimizer = trainer.strategy.ckpt_load_optimizer - _setup_optimizers = trainer.strategy._setup_optimizers + setup_optimizers = trainer.strategy._setup_optimizers trainer.strategy.ckpt_load_optimizer = False trainer.strategy._setup_optimizers = False orig_model = trainer.model @@ -443,7 +444,7 @@ def _teacher_provider(teacher_path: str, trainer: nl.Trainer) -> MCoreGPTModel: model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module)) trainer.strategy.ckpt_load_optimizer = ckpt_load_optimizer - trainer.strategy._setup_optimizers = _setup_optimizers + trainer.strategy._setup_optimizers = setup_optimizers trainer.strategy.connect(orig_model) logging.info("Distillation: ...teacher weights loaded.") @@ -575,6 +576,7 @@ class Llama31Config4B(Llama31Config): model_config, tokenizer = _model.config, _model.tokenizer # model_config.transformer_layer_spec = get_gpt_layer_modelopt_spec() # TODO(aanoosheh) model = DistillationGPTModel(args.teacher_path, model_config, tokenizer=tokenizer) + model.trainer = trainer _setup_trainer_and_restore_model(args.student_path, trainer, model) # setup the dataset @@ -623,8 +625,6 @@ class Llama31Config4B(Llama31Config): ckpt=checkpoint_callback, ) - import os - # suppress warning os.environ["TOKENIZERS_PARALLELISM"] = "false" From 8770138f27065a867ffa02290aff61c8b05f9c6a Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Sat, 18 Jan 2025 10:09:45 -0800 Subject: [PATCH 07/22] More fixes Signed-off-by: Asha Anoosheh --- nemo/lightning/megatron_parallel.py | 37 ++++++++++++++--------------- scripts/llm/gpt_distillation.py | 16 +++++++++---- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 03fd54bf2b9e..9713657f67b9 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -289,10 +289,8 @@ def forward( if self._kd_teacher_in_pp: assert wrap_forward_step _teacher_forward_context = {} - if isinstance(_data_step, _ModuleStepFunction): _data_step = _data_step(self.module) - data = _data_step(data, cache_num_batches=num_microbatches) def _dummy_reduction(output_tensor, *args, **kwargs): return output_tensor.new_tensor(-1), {} @@ -305,7 +303,7 @@ def _dummy_reduction(output_tensor, *args, **kwargs): ) teacher_step = MegatronStep.infer( self, - data, + None, # updated later below once we actually know `num_microbatches` teacher_forward_step_func, forward_only=True, micro_batch_size=micro_batch_size, @@ -316,6 +314,9 @@ def _dummy_reduction(output_tensor, *args, **kwargs): _teacher_forward_context["step"] = teacher_step teacher_step = self.callbacks.transform_event("on_megatron_step_start", teacher_step) + data = _data_step(data, cache_num_batches=teacher_step.num_microbatches) + teacher_step.data = data + step = MegatronStep.infer( self, data, @@ -528,7 +529,7 @@ def wrapped_forward_step_func(dataloader_iter, model): _data_step = data_step batch = _data_step(dataloader_iter) - step = context.get("step") + step = context["step"] if isinstance(loss_reduction, _ModuleStepFunction): forward_callback = loss_reduction(model) @@ -540,13 +541,12 @@ def wrapped_forward_step_func(dataloader_iter, model): else: _forward_step = forward_step - if step is not None: - 
self.callbacks.event( - "on_megatron_microbatch_start", - step=step, - batch=batch, - forward_callback=forward_callback, - ) + self.callbacks.event( + "on_megatron_microbatch_start", + step=step, + batch=batch, + forward_callback=forward_callback, + ) if self.precision_plugin and parallel_state.is_pipeline_first_stage(): batch = self.precision_plugin.convert_input(batch) @@ -565,14 +565,13 @@ def wrapped_forward_step_func(dataloader_iter, model): if self.precision_plugin and parallel_state.is_pipeline_last_stage(): output_tensor = self.precision_plugin.convert_output(output_tensor) - if step is not None: - self.callbacks.event( - "on_megatron_microbatch_end", - step=step, - batch=batch, - output=output_tensor, - forward_callback=forward_callback, - ) + self.callbacks.event( + "on_megatron_microbatch_end", + step=step, + batch=batch, + output=output_tensor, + forward_callback=forward_callback, + ) return output_tensor, forward_callback diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 9a1a5e4f8c7f..4ef3eb8ac8e4 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -425,7 +425,7 @@ def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: class _DummyLossBalancer(mtd.DistillationLossBalancer): - def forward(loss_dict): + def forward(self, loss_dict): return next(iter(loss_dict.values())) @@ -435,8 +435,10 @@ def _teacher_provider(teacher_path: str, trainer: nl.Trainer) -> MCoreGPTModel: ckpt_load_optimizer = trainer.strategy.ckpt_load_optimizer setup_optimizers = trainer.strategy._setup_optimizers + trainer.strategy.ckpt_load_optimizer = False trainer.strategy._setup_optimizers = False + # trainer.strategy.megatron_parallel.ddp_config = None orig_model = trainer.model # TODO(aanoosheh): Replace spec with modelopt one @@ -445,6 +447,7 @@ def _teacher_provider(teacher_path: str, trainer: nl.Trainer) -> MCoreGPTModel: trainer.strategy.ckpt_load_optimizer = ckpt_load_optimizer trainer.strategy._setup_optimizers = setup_optimizers + # trainer.strategy.megatron_parallel.ddp_config = ddp_config trainer.strategy.connect(orig_model) logging.info("Distillation: ...teacher weights loaded.") @@ -456,6 +459,9 @@ def __init__(self, data): self.data = data self.it = iter(self.data) + def __iter__(self): + return self + def __next__(self): try: return next(self.it) @@ -484,7 +490,7 @@ def _sharded_state_dict(self, *args, **kwargs) -> ShardedStateDict: # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop. def _compute_language_model_loss(self, labels, logits) -> Tensor: if self.training: - return torch.zeros_like(labels) + return torch.zeros_like(labels, dtype=logits.dtype) return self._compute_language_model_loss(labels, logits) if distill_cfg["skip_lm_loss"]: @@ -493,7 +499,7 @@ def _compute_language_model_loss(self, labels, logits) -> Tensor: # HACK: Skip `lm_loss` always for teacher. 
def _compute_language_model_loss(self, labels, logits) -> Tensor: - return torch.zeros_like(labels) + return torch.zeros_like(labels, dtype=logits.dtype) model.teacher_model.compute_language_model_loss = MethodType(_compute_language_model_loss, model.teacher_model) @@ -576,8 +582,8 @@ class Llama31Config4B(Llama31Config): model_config, tokenizer = _model.config, _model.tokenizer # model_config.transformer_layer_spec = get_gpt_layer_modelopt_spec() # TODO(aanoosheh) model = DistillationGPTModel(args.teacher_path, model_config, tokenizer=tokenizer) - model.trainer = trainer - _setup_trainer_and_restore_model(args.student_path, trainer, model) + # model.trainer = trainer + # _setup_trainer_and_restore_model(args.student_path, trainer, model) # setup the dataset data = llm.PreTrainingDataModule( From 758132d045bd5c90cd13d049a85d334b349bc2a8 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 20 Jan 2025 13:47:36 -0800 Subject: [PATCH 08/22] Rework checkpoint loading Signed-off-by: Asha Anoosheh --- scripts/llm/gpt_distillation.py | 102 ++++++++++++++++++-------------- 1 file changed, 59 insertions(+), 43 deletions(-) diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 4ef3eb8ac8e4..e3d760353bd5 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -35,10 +35,8 @@ from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank -from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group -from nemo.lightning import io from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction from nemo.lightning.pytorch.callbacks import ModelCheckpoint @@ -183,17 +181,17 @@ class DistillationGPTModel(llm.GPTModel): def __init__( self, - kd_teacher_restore_from_path: str, - config: llm.GPTConfig, + student_config: llm.GPTConfig, + teacher_config: llm.GPTConfig, + teacher_ckpt_path: str, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(config, optim, tokenizer, model_transform) - self._kd_teacher_restore_from_path = kd_teacher_restore_from_path + super().__init__(student_config, optim, tokenizer, model_transform) + self._teacher_config = teacher_config + self._teacher_ckpt_path = teacher_ckpt_path - if not self._kd_teacher_restore_from_path: - raise ValueError("Config attribute `kd_teacher_restore_from_path` must be set.") if self.config.virtual_pipeline_model_parallel_size is not None: raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") @@ -206,7 +204,11 @@ def configure_model(self): # [ModelOpt] Intialize DistillationModel. 
distill_cfg = load_distillation_config(self.config) kd_config = { - "teacher_model": (_teacher_provider, [self._kd_teacher_restore_from_path], {"trainer": self.trainer}), + "teacher_model": ( + _teacher_provider, + [self._teacher_config, self._teacher_ckpt_path], + {"tokenizer": self.tokenizer, "trainer": self.trainer}, + ), "criterion": distill_cfg["criterion"], "loss_balancer": distill_cfg["loss_balancer"], } @@ -235,13 +237,21 @@ def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> @property def training_loss_reduction(self) -> _DistillationLossReduction: if not self._training_loss_reduction: - core_module = unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module)) self._training_loss_reduction = _DistillationLossReduction( - distillation_loss_fn=core_module.compute_kd_loss + distillation_loss_fn=self.core_module.compute_kd_loss ) return self._training_loss_reduction + def load_state_dict(self, state_dict, *args, **kwargs): + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # `super()` would go to `nn.Module` and skip the Context Manager in `mtd.DistillationModel.load_state_dict()` + return self.core_module.load_state_dict(state_dict, *args, *kwargs) + + @property + def core_module(self): + return unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module)) + ######################################################## @@ -429,27 +439,25 @@ def forward(self, loss_dict): return next(iter(loss_dict.values())) -def _teacher_provider(teacher_path: str, trainer: nl.Trainer) -> MCoreGPTModel: +def _teacher_provider( + config: llm.GPTConfig, ckpt_path: str, tokenizer: "TokenizerSpec", trainer: nl.Trainer +) -> MCoreGPTModel: """Teacher model factory (must be a non-local function to pickle).""" logging.info("Distillation: Loading teacher weights...") - ckpt_load_optimizer = trainer.strategy.ckpt_load_optimizer - setup_optimizers = trainer.strategy._setup_optimizers - - trainer.strategy.ckpt_load_optimizer = False - trainer.strategy._setup_optimizers = False - # trainer.strategy.megatron_parallel.ddp_config = None - orig_model = trainer.model - # TODO(aanoosheh): Replace spec with modelopt one - model, _ = io.ModelConnector().nemo_load(teacher_path, trainer, cpu=False) - model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module)) + model = config.configure_model(tokenizer) - trainer.strategy.ckpt_load_optimizer = ckpt_load_optimizer - trainer.strategy._setup_optimizers = setup_optimizers - # trainer.strategy.megatron_parallel.ddp_config = ddp_config - trainer.strategy.connect(orig_model) + sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.")} + checkpoint = trainer.strategy.checkpoint_io.load_checkpoint( + ckpt_path, + sharded_state_dict=sharded_state_dict, + ) + state_dict = checkpoint["state_dict"] + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + model.load_state_dict(state_dict) + torch.cuda.empty_cache() logging.info("Distillation: ...teacher weights loaded.") return model @@ -576,14 +584,21 @@ class Llama31Config4B(Llama31Config): plugins=nl.MegatronMixedPrecision(precision=args.precision), ) - ## load student model - _model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") - assert hasattr(_model, "tokenizer"), "Please provide a model checkpoint with tokenizer included." 
- model_config, tokenizer = _model.config, _model.tokenizer - # model_config.transformer_layer_spec = get_gpt_layer_modelopt_spec() # TODO(aanoosheh) - model = DistillationGPTModel(args.teacher_path, model_config, tokenizer=tokenizer) - # model.trainer = trainer - # _setup_trainer_and_restore_model(args.student_path, trainer, model) + ## load the combined teacher-student model + _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") + _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") + + tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) + assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." + + # TODO(aanoosheh): Replace spec with modelopt one + model = DistillationGPTModel( + _student_model.config, + _teacher_model.config, + teacher_ckpt_path=args.teacher_path, + tokenizer=tokenizer, + ) + model.__io__ = _student_model.__io__ # HACK: model saves and restores as original class # setup the dataset data = llm.PreTrainingDataModule( @@ -596,14 +611,6 @@ class Llama31Config4B(Llama31Config): tokenizer=tokenizer, ) - # TODO auto-resume setup - # resume = nl.AutoResume( - # resume_if_exists=True, - # resume_ignore_no_checkpoint=True, - # resume_from_directory=args.log_dir, - # restore_config=nl.RestoreConfig(path=STUDENT_PATH) if STUDENT_PATH else None, - # ) - ## setup the optimizer opt_config = OptimizerConfig( optimizer="adam", @@ -631,7 +638,15 @@ class Llama31Config4B(Llama31Config): ckpt=checkpoint_callback, ) - # suppress warning + # auto-resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_from_directory=args.log_dir, + resume_ignore_no_checkpoint=True, + restore_config=nl.RestoreConfig(path=args.student_path), + ) + + # suppress HF warning os.environ["TOKENIZERS_PARALLELISM"] = "false" # run @@ -642,4 +657,5 @@ class Llama31Config4B(Llama31Config): tokenizer='model', trainer=trainer, log=logger, + resume=resume, ) From 01864d330b477deb29aa9ba4cdd778c9b187f38b Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 21 Jan 2025 09:58:03 -0800 Subject: [PATCH 09/22] Resolve seemingly remaining bugs Signed-off-by: Asha Anoosheh --- scripts/llm/gpt_distillation.py | 35 ++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index e3d760353bd5..8014bc7bd1ac 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -157,6 +157,8 @@ def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens if isinstance(loss_output, tuple): # [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required) loss_output, tp_reduce = loss_output + else: + tp_reduce = False losses = loss_output.float() loss_mask = mask.view(-1).float() @@ -191,6 +193,7 @@ def __init__( super().__init__(student_config, optim, tokenizer, model_transform) self._teacher_config = teacher_config self._teacher_ckpt_path = teacher_ckpt_path + self._train_called = False if self.config.virtual_pipeline_model_parallel_size is not None: raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") @@ -201,6 +204,16 @@ def configure_model(self): model = self.config.configure_model(self.tokenizer) + # Ensure same for both models. 
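+        # The teacher must be instantiated with the same parallelism layout and dtype settings as the student
+        # so its shards line up with the existing process groups. Only these runtime attributes are copied;
+        # the teacher keeps its own architectural hyperparameters from its restored config.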
+ for attr in [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "sequence_parallel", + "pipeline_dtype", + ]: + setattr(self._teacher_config, attr, getattr(self.config, attr)) + # [ModelOpt] Intialize DistillationModel. distill_cfg = load_distillation_config(self.config) kd_config = { @@ -252,6 +265,19 @@ def load_state_dict(self, state_dict, *args, **kwargs): def core_module(self): return unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module)) + def train(self, mode: bool = True): + self._train_called = True + return super().train(mode) + + def __setattr__(self, name, value): + # HACK: PTL calls `module.training = True` after sanity check, bypassing `module.train()` which we depend on. + if name == "training": + if not self._train_called: + self.train(value) + return + self._train_called = False + return super().__setattr__(name, value) + ######################################################## @@ -449,12 +475,8 @@ def _teacher_provider( model = config.configure_model(tokenizer) sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.")} - checkpoint = trainer.strategy.checkpoint_io.load_checkpoint( - ckpt_path, - sharded_state_dict=sharded_state_dict, - ) - state_dict = checkpoint["state_dict"] - state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + checkpoint = trainer.strategy.checkpoint_io.load_checkpoint(ckpt_path, sharded_state_dict) + state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()} model.load_state_dict(state_dict) torch.cuda.empty_cache() @@ -641,7 +663,6 @@ class Llama31Config4B(Llama31Config): # auto-resume setup resume = nl.AutoResume( resume_if_exists=True, - resume_from_directory=args.log_dir, resume_ignore_no_checkpoint=True, restore_config=nl.RestoreConfig(path=args.student_path), ) From e0cc0bc37107a33a7a30a1732fc056479568574a Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 21 Jan 2025 11:37:23 -0800 Subject: [PATCH 10/22] Refactor into multiple files Signed-off-by: Asha Anoosheh --- nemo/collections/llm/distillation/__init__.py | 18 + nemo/collections/llm/distillation/loss.py | 182 ++++++ nemo/collections/llm/distillation/model.py | 235 ++++++++ nemo/collections/llm/distillation/utils.py | 155 +++++ scripts/llm/gpt_distillation.py | 549 +----------------- 5 files changed, 613 insertions(+), 526 deletions(-) create mode 100644 nemo/collections/llm/distillation/__init__.py create mode 100644 nemo/collections/llm/distillation/loss.py create mode 100644 nemo/collections/llm/distillation/model.py create mode 100644 nemo/collections/llm/distillation/utils.py diff --git a/nemo/collections/llm/distillation/__init__.py b/nemo/collections/llm/distillation/__init__.py new file mode 100644 index 000000000000..78eddd3b5ad7 --- /dev/null +++ b/nemo/collections/llm/distillation/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .loss import LogitsKLLoss +from .model import DistillationGPTModel + +__all__ = ["LogitsKLLoss", "DistillationGPTModel"] diff --git a/nemo/collections/llm/distillation/loss.py b/nemo/collections/llm/distillation/loss.py new file mode 100644 index 000000000000..e970386a058c --- /dev/null +++ b/nemo/collections/llm/distillation/loss.py @@ -0,0 +1,182 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABCMeta +from typing import TYPE_CHECKING, Tuple + +import torch +import torch.nn.functional as F +from megatron.core import parallel_state +from torch import Tensor +from torch.nn.modules.loss import _Loss + +if TYPE_CHECKING: + from megatron.core.transformer.transformer_config import TransformerConfig + + +class BaseLoss(_Loss, metaclass=ABCMeta): + """Abstract base class for Megatron distillation losses.""" + + def __init__(self, model_config: "TransformerConfig"): + """ + Constructor. + + Args: + model_config: MCore transformer config. + """ + super().__init__() + self._config = model_config + + def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + """Prepares inputs safely for loss computation.""" + if isinstance(predictions, tuple): + # `ColumnParallelLinear` returns bias too + predictions, targets = predictions[0], targets[0] + targets = targets.detach() + + return predictions, targets + + def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor: + """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking.""" + loss = loss.transpose(0, 1).contiguous() + return loss, tp_reduce + + +class LogitsKLLoss(BaseLoss): + """Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim.""" + + def __init__(self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False): + """ + Constructor. + + Args: + model_config: MCore transformer config. + temperature: Divide tensors by this value prior to calculating loss. + reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher) + """ + super().__init__(model_config) + self._temperature = temperature + self._reverse = reverse + + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """ + Forward function. + + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + # Division by temp should happen prior to finding max for both student and teacher. + # Currently we don't use temperature in any of ours runs (temp=1.0) + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + + # Compute local softmax, and the reweight to compute global softmax. + if self._config.tensor_model_parallel_size > 1: + + # Maximum value along vocab dimension across all GPUs. 
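+            # The vocab dimension is sharded across TP ranks, so each rank only holds a slice of the logits.
+            # Subtracting the global max (all-reduce MAX) keeps the exponentials numerically stable, and the
+            # true softmax denominator is recovered by all-reducing the per-rank sums of exponentials below.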
+ teacher_logits_max, _ = torch.max(output_teacher, dim=-1) + torch.distributed.all_reduce( + teacher_logits_max, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_tensor_model_parallel_group(), + ) + output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) + + denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) + # We can't use standard reduction function here since the computation + # that follows it isn't identical across TP ranks. + denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group()) + + # Maximum value along vocab dimension across all GPUs. + student_logits_max, _ = torch.max(output_student, dim=-1) + torch.distributed.all_reduce( + student_logits_max, + op=torch.distributed.ReduceOp.MAX, + group=parallel_state.get_tensor_model_parallel_group(), + ) + output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() + + denom_student = torch.sum(torch.exp(output_student), dim=-1) + denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group()) + + slen, bsz, sharded_vocab_size = output_student.shape + student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( + slen, bsz, sharded_vocab_size + ) + teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand( + slen, bsz, sharded_vocab_size + ) + + if self._reverse: + loss = torch.sum( + F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True), + dim=-1, + ) + else: + loss = torch.sum( + F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True), + dim=-1, + ) + + else: + if self._reverse: + loss = torch.sum( + F.kl_div( + F.log_softmax(output_teacher, dim=-1), + F.softmax(output_student, dim=-1), + reduction="none", + ), + dim=-1, + ) + else: + loss = torch.sum( + F.kl_div( + F.log_softmax(output_student, dim=-1), + F.softmax(output_teacher, dim=-1), + reduction="none", + ), + dim=-1, + ) + + return self.post_forward(loss, tp_reduce=True) + + +class _AllReduce(torch.autograd.Function): + """Implementation from old PyTorch `torch.distributed.nn.parallel`.""" + + @staticmethod + def forward(ctx, op, group, tensor): + ctx.group, ctx.op = group, op + tensor = tensor.clone() + torch.distributed.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) + + +def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD): + """Custom all-reduce function. + + Needed instead of other all-reduce functions available when the computation following + the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. + """ + return _AllReduce.apply(op, group, tensor) diff --git a/nemo/collections/llm/distillation/model.py b/nemo/collections/llm/distillation/model.py new file mode 100644 index 000000000000..9b69fc722bb3 --- /dev/null +++ b/nemo/collections/llm/distillation/model.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple + +import modelopt.torch.distill as mtd +import torch +from megatron.core import parallel_state +from megatron.core.transformer.module import Float16Module as MCoreFloat16Module +from torch import Tensor, nn + +from nemo.collections import llm +from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction +from nemo.utils.model_utils import unwrap_model + +from .utils import ( + LoopingCachedDataIterator, + adjust_distillation_model_for_mcore, + load_distillation_config, + teacher_provider, +) + +if TYPE_CHECKING: + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + from nemo.lightning.pytorch.optim import OptimizerModule + + +def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, Tensor]: + batch = next(dataloader_iter) + + if isinstance(batch, tuple) and len(batch) == 3: + batch = batch[0] + + required_device_keys = set() + required_host_keys = set() + + if attn_mask_cpu: + # [ModelOpt]: We cache data for PP distillation, and save GPU mem by storing masks on CPU mem. + required_host_keys.add("attention_mask") + else: + required_device_keys.add("attention_mask") + + if 'cu_seqlens' in batch: + required_device_keys.add('cu_seqlens') + required_host_keys.add('cu_seqlens_argmin') + required_host_keys.add('max_seqlen') + + if parallel_state.is_pipeline_first_stage(): + required_device_keys.update(("tokens", "position_ids")) + if parallel_state.is_pipeline_last_stage(): + required_device_keys.update(("labels", "loss_mask")) + + batch_required_keys = {} + for key, val in batch.items(): + if key in required_device_keys: + batch_required_keys[key] = val.cuda(non_blocking=True) + elif key in required_host_keys: + batch_required_keys[key] = val.cpu() + else: + batch_required_keys[key] = None + + # slice batch along sequence dimension for context parallelism + output = get_batch_on_this_context_parallel_rank(batch_required_keys) + + return output + + +class _DistillationLossReduction(MaskedTokenLossReduction): + """Custom masking and reduction callable used only in training mode.""" + + def __init__(self, distillation_loss_fn, *args, **kwargs): + super().__init__(*args, **kwargs) + self._distillation_loss_fn = distillation_loss_fn + self._cp_size = parallel_state.get_context_parallel_world_size() + + def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: + if isinstance(forward_out, tuple): + # neva returns (logits, loss_mask) + forward_out, batch["loss_mask"] = forward_out + + # [ModelOpt]: KD loss calculation. 
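+        # `self._distillation_loss_fn` is the `compute_kd_loss` method bound on the underlying
+        # `mtd.DistillationModel`; the lambda passed in reduces the per-token KD loss with the batch's loss
+        # mask (and handles the optional extra TP-reduction flag) before the result is averaged across the
+        # data-parallel group.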
+ loss_for_ub = self._distillation_loss_fn( + loss_reduction_fn=lambda x: self._masked_token_loss( + x, batch["loss_mask"], batch.get("num_valid_tokens_in_ub") + ) + ) + + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub * self._cp_size, {"avg": reduced_loss} + + def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens_in_ub: Optional[int] = None): + """ + The function takes as input per-token loss and masks non-required values. + """ + if isinstance(loss_output, tuple): + # [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required) + loss_output, tp_reduce = loss_output + else: + tp_reduce = False + losses = loss_output.float() + loss_mask = mask.view(-1).float() + + if self._cp_size > 1: + if num_valid_tokens_in_ub is None: + num_valid_tokens_in_ub = loss_mask.sum() + if num_valid_tokens_in_ub < 0.5: # no valid tokens + num_valid_tokens_in_ub += 1.0 + loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + + if tp_reduce is True: + torch.distributed.all_reduce(loss, group=parallel_state.get_tensor_model_parallel_group()) + + return loss + + +class DistillationGPTModel(llm.GPTModel): + """Custom GPT subclass for distillation-related modifications.""" + + def __init__( + self, + student_config: llm.GPTConfig, + teacher_config: llm.GPTConfig, + teacher_ckpt_path: str, + optim: Optional["OptimizerModule"] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(student_config, optim, tokenizer, model_transform) + self._teacher_config = teacher_config + self._teacher_ckpt_path = teacher_ckpt_path + self._train_called = False + + if self.config.virtual_pipeline_model_parallel_size is not None: + raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") + + def configure_model(self): + if hasattr(self, "module"): + return + + model = self.config.configure_model(self.tokenizer) + + # Ensure same for both models. + for attr in [ + "tensor_model_parallel_size", + "pipeline_model_parallel_size", + "context_parallel_size", + "sequence_parallel", + "pipeline_dtype", + ]: + setattr(self._teacher_config, attr, getattr(self.config, attr)) + + # [ModelOpt] Intialize DistillationModel. + distill_cfg = load_distillation_config(self.config) + kd_config = { + "teacher_model": ( + teacher_provider, + [self._teacher_config, self._teacher_ckpt_path], + {"tokenizer": self.tokenizer, "trainer": self.trainer}, + ), + "criterion": distill_cfg["criterion"], + "loss_balancer": distill_cfg["loss_balancer"], + } + distillation_model = mtd.convert(model, mode=[("kd_loss", kd_config)]) + + # Additional MCore-specific tweaks needed. 
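+        # The helper below applies the MCore-specific workarounds documented in `utils.py`: it drops the
+        # ModelOpt distillation state so it is not persisted into checkpoints, hides the teacher from
+        # `sharded_state_dict`, bypasses the standard LM loss when it is not needed for backprop, and (for
+        # pipeline parallelism) patches `set_input_tensor` plus a config-swap context used by the PP schedule.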
+        adjust_distillation_model_for_mcore(distillation_model, model_cfg=self.config, distill_cfg=distill_cfg)
+
+        self.module = distillation_model
+
+    def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> Dict[str, Tensor]:
+        # NOTE: Ignores `self.config.data_step_fn`
+        if cache_num_batches:
+            batches = [
+                gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=True) for _ in range(cache_num_batches)
+            ]
+            return LoopingCachedDataIterator(batches)
+        elif isinstance(dataloader_iter, LoopingCachedDataIterator):
+            batch = next(dataloader_iter)
+            if "attention_mask" in batch:
+                batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True)  # move back to GPU
+            return batch
+        else:
+            return gpt_distillation_data_step(dataloader_iter)
+
+    def get_inference_wrapper(self, *args, **kwargs) -> Tensor:
+        raise NotImplementedError(
+            "Please restore a checkpoint of this model to its original class to call `get_inference_wrapper`"
+        )
+
+    @property
+    def training_loss_reduction(self) -> _DistillationLossReduction:
+        if not self._training_loss_reduction:
+            self._training_loss_reduction = _DistillationLossReduction(
+                distillation_loss_fn=self.core_module.compute_kd_loss
+            )
+        return self._training_loss_reduction
+
+    def load_state_dict(self, state_dict, *args, **kwargs):
+        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+        # `super()` would go to `nn.Module` and skip the Context Manager in `mtd.DistillationModel.load_state_dict()`
+        return self.core_module.load_state_dict(state_dict, *args, **kwargs)
+
+    @property
+    def core_module(self):
+        return unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module))
+
+    def train(self, mode: bool = True):
+        self._train_called = True
+        return super().train(mode)
+
+    def __setattr__(self, name, value):
+        # HACK: PTL calls `module.training = True` after sanity check, bypassing `module.train()` which we depend on.
+        if name == "training":
+            if not self._train_called:
+                self.train(value)
+                return
+            self._train_called = False
+        return super().__setattr__(name, value)
diff --git a/nemo/collections/llm/distillation/utils.py b/nemo/collections/llm/distillation/utils.py
new file mode 100644
index 000000000000..12a10337c210
--- /dev/null
+++ b/nemo/collections/llm/distillation/utils.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
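+
+"""Utilities for GPT knowledge distillation: default distillation-config creation, teacher checkpoint
+loading, and the MCore-specific patches applied to ``mtd.DistillationModel``."""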
+
+from contextlib import contextmanager
+from types import MethodType
+from typing import TYPE_CHECKING, Any, Dict
+
+import modelopt.torch.distill as mtd
+import modelopt.torch.opt as mto
+import torch
+from megatron.core import parallel_state
+from torch import Tensor
+
+from nemo import lightning as nl
+from nemo.collections import llm
+from nemo.utils import logging
+
+from .loss import LogitsKLLoss
+
+if TYPE_CHECKING:
+    from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+    from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
+    from megatron.core.transformer.transformer_config import TransformerConfig
+
+    from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+
+
+def load_distillation_config(cfg: "TransformerConfig") -> Dict[str, Any]:
+    """Create a default distillation config for MCore GPT Models.
+
+    Args:
+        cfg: Model config of the student model.
+    """
+    logit_pair = ("output_layer", "output_layer")  # logit module names for MCoreGPTModel
+    distill_cfg = {
+        "criterion": {},
+        "loss_balancer": _DummyLossBalancer(),  # HACK: to appease ModelOpt until validation relaxed
+        "skip_lm_loss": True,
+    }
+    if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
+        distill_cfg["criterion"][logit_pair] = LogitsKLLoss(cfg)
+
+    return distill_cfg
+
+
+class _DummyLossBalancer(mtd.DistillationLossBalancer):
+    def forward(self, loss_dict):
+        return next(iter(loss_dict.values()))
+
+
+def teacher_provider(
+    config: llm.GPTConfig, ckpt_path: str, tokenizer: "TokenizerSpec", trainer: nl.Trainer
+) -> "MCoreGPTModel":
+    """Teacher model factory (must be a non-local function to pickle)."""
+    logging.info("Distillation: Loading teacher weights...")
+
+    # TODO(aanoosheh): Replace spec with modelopt one
+    model = config.configure_model(tokenizer)
+
+    sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.")}
+    checkpoint = trainer.strategy.checkpoint_io.load_checkpoint(ckpt_path, sharded_state_dict)
+    state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()}
+    model.load_state_dict(state_dict)
+
+    torch.cuda.empty_cache()
+    logging.info("Distillation: ...teacher weights loaded.")
+    return model
+
+
+class LoopingCachedDataIterator:
+    def __init__(self, data):
+        self.data = data
+        self.it = iter(self.data)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return next(self.it)
+        except StopIteration:
+            self.it = iter(self.data)
+            return next(self.it)
+
+
+def adjust_distillation_model_for_mcore(
+    model: mtd.DistillationModel, model_cfg: "TransformerConfig", distill_cfg: Dict[str, Any]
+):
+    """Extra modifications to ``mtd.DistillationModel`` required for Megatron-Core."""
+
+    # HACK: Get rid of ModelOpt Distillation state
+    # NOTE: If re-placed, above losses need modification as `TransformerConfig` has non-pickleable elements.
+    mto.ModeloptStateManager(model)._state.pop()
+
+    # HACK: Hide teacher during `sharded_state_dict` method.
+    def _sharded_state_dict(self, *args, **kwargs) -> "ShardedStateDict":
+        with self.hide_teacher_model():
+            return self._sharded_state_dict(*args, **kwargs)
+
+    model._sharded_state_dict = model.sharded_state_dict
+    model.sharded_state_dict = MethodType(_sharded_state_dict, model)
+
+    # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop.
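+    # While `skip_lm_loss` is set, the student's cross-entropy is only needed in eval mode: during training
+    # the KD loss is the sole objective, so a zero tensor (in the logits dtype) is returned instead of paying
+    # for the full vocab-parallel cross-entropy. The original method remains reachable as
+    # `self._compute_language_model_loss` and is still used outside of training.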
+ def _compute_language_model_loss(self, labels, logits) -> Tensor: + if self.training: + return torch.zeros_like(labels, dtype=logits.dtype) + return self._compute_language_model_loss(labels, logits) + + if distill_cfg["skip_lm_loss"]: + model._compute_language_model_loss = model.compute_language_model_loss + model.compute_language_model_loss = MethodType(_compute_language_model_loss, model) + + # HACK: Skip `lm_loss` always for teacher. + def _compute_language_model_loss(self, labels, logits) -> Tensor: + return torch.zeros_like(labels, dtype=logits.dtype) + + model.teacher_model.compute_language_model_loss = MethodType(_compute_language_model_loss, model.teacher_model) + + if model_cfg.pipeline_model_parallel_size > 1: + + def _set_input_tensor(self, input_tensor: Tensor): + obj = self.teacher_model if self._only_teacher_fwd else self + return type(self).set_input_tensor(obj, input_tensor) + + # HACK: Pipeline-parallel Distillation requires a way to cache input batches for subsequent + # forward calls, as well as a way to pass through output tensors to teacher model. + model.set_input_tensor = MethodType(_set_input_tensor, model) + + @contextmanager + def _swap_teacher_config(self, model_wrapper): + try: + if hasattr(model_wrapper, "config"): + model_wrapper._config = model_wrapper.config + model_wrapper.config = self.teacher_model.config + yield + finally: + del model_wrapper.config + if hasattr(model_wrapper, "_config"): + model_wrapper.config = model_wrapper._config + del model_wrapper._config + + # HACK: Pipeline-parallel forward function relies on the config in the model to know what + # hidden size of tensor to communicate to next stage. + model.swap_teacher_config = MethodType(_swap_teacher_config, model) diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 8014bc7bd1ac..609a18a0a841 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,39 +13,16 @@ # limitations under the License. 
import os -from abc import ABCMeta from argparse import ArgumentParser -from contextlib import contextmanager -from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple -import modelopt.torch.distill as mtd -import modelopt.torch.opt as mto -import torch -import torch.nn.functional as F -from megatron.core import parallel_state -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel from megatron.core.optimizer import OptimizerConfig -from megatron.core.transformer.module import Float16Module as MCoreFloat16Module -from megatron.core.transformer.transformer_config import TransformerConfig -from torch import Tensor, nn -from torch.nn.modules.loss import _Loss from nemo import lightning as nl from nemo.collections import llm -from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank -from nemo.collections.nlp.modules.common.megatron.module import Float16Module -from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.collections.llm import distillation as distill from nemo.lightning.ckpt_utils import ckpt_to_context_subdir -from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction from nemo.lightning.pytorch.callbacks import ModelCheckpoint -from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, OptimizerModule -from nemo.utils import logging -from nemo.utils.model_utils import unwrap_model - -if TYPE_CHECKING: - from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler def get_args(): @@ -57,21 +34,20 @@ def get_args(): parser.add_argument("--name", type=str, required=True, help="""Experiment name""") parser.add_argument("--teacher_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") parser.add_argument("--student_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") - parser.add_argument("--tp", type=int, default=1, help="""Tensor parallel size""") - parser.add_argument("--cp", type=int, default=1, help="""Context parallel size""") - parser.add_argument("--pp", type=int, default=1, help="""Pipeline parallel size""") - parser.add_argument("--enable_sp", action="store_true", help="""Enable Sequence parallelism""") + parser.add_argument("--tp_size", type=int, default=1, help="""Tensor parallel size""") + parser.add_argument("--cp_size", type=int, default=1, help="""Context parallel size""") + parser.add_argument("--pp_size", type=int, default=1, help="""Pipeline parallel size""") parser.add_argument("--precision", type=str, default="bf16-mixed", help="""Datatype for models and optimizer""") parser.add_argument("--devices", type=int, default=1, help="""Number of GPUs to use per node""") - parser.add_argument("--nodes", type=int, default=1, help="""Number of nodes to use""") + parser.add_argument("--num_nodes", type=int, default=1, help="""Number of nodes to use""") parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""") - parser.add_argument("--steps", type=int, required=True, help="""Number of global batches to process""") - parser.add_argument("--global_batch_size", type=int, required=True, help="""Data samples per optimizer step""") - parser.add_argument("--micro_batch_size", type=int, required=True, help="""Data samples per forward pass""") + 
parser.add_argument("--max_steps", type=int, required=True, help="""Number of global batches to process""") + parser.add_argument("--gbs", type=int, required=True, help="""Data samples per optimizer step""") + parser.add_argument("--mbs", type=int, required=True, help="""Data samples per forward pass""") parser.add_argument("--data_paths", nargs='+', required=True, help="""List of tokenized data paths to load from""") parser.add_argument("--split", type=str, default="99,1,0", help="""""") parser.add_argument("--index_mapping_dir", type=str, default=None, help="""""") - parser.add_argument("--sequence_length", type=int, required=True, help="""Number of tokens per input sample""") + parser.add_argument("--seq_length", type=int, required=True, help="""Number of tokens per input sample""") parser.add_argument("--lr", type=float, default=3e-5, help="""""") parser.add_argument("--min_lr", type=float, default=2e-7, help="""""") parser.add_argument("--warmup_steps", type=int, default=50, help="""""") @@ -80,489 +56,12 @@ def get_args(): parser.add_argument("--limit_test_batches", type=int, default=32, help="""""") parser.add_argument("--log_interval", type=int, default=10, help="""""") - args = parser.parse_args() - return args - - -def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]: - batch = next(dataloader_iter) - - _batch: dict - if isinstance(batch, tuple) and len(batch) == 3: - _batch = batch[0] - else: - _batch = batch - - required_device_keys = set() - required_host_keys = set() - - if attn_mask_cpu: - # [ModelOpt]: We cache data for PP distillation, and save GPU mem by storing masks on CPU mem. - required_host_keys.add("attention_mask") - else: - required_device_keys.add("attention_mask") - - if 'cu_seqlens' in _batch: - required_device_keys.add('cu_seqlens') - required_host_keys.add('cu_seqlens_argmin') - required_host_keys.add('max_seqlen') - - if parallel_state.is_pipeline_first_stage(): - required_device_keys.update(("tokens", "position_ids")) - if parallel_state.is_pipeline_last_stage(): - required_device_keys.update(("labels", "loss_mask")) - - _batch_required_keys = {} - for key, val in _batch.items(): - if key in required_device_keys: - _batch_required_keys[key] = val.cuda(non_blocking=True) - elif key in required_host_keys: - _batch_required_keys[key] = val.cpu() - else: - _batch_required_keys[key] = None - - # slice batch along sequence dimension for context parallelism - output = get_batch_on_this_context_parallel_rank(_batch_required_keys) - - return output - - -class _DistillationLossReduction(MaskedTokenLossReduction): - """Custom masking and reduction callable used only in training mode.""" - - def __init__(self, distillation_loss_fn, *args, **kwargs): - super().__init__(*args, **kwargs) - self._distillation_loss_fn = distillation_loss_fn - self._cp_size = parallel_state.get_context_parallel_world_size() - - def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]: - if isinstance(forward_out, tuple): - # neva returns (logits, loss_mask) - forward_out, batch["loss_mask"] = forward_out - - # [ModelOpt]: KD loss calculation. 
- loss_for_ub = self._distillation_loss_fn( - loss_reduction_fn=lambda x: self._masked_token_loss( - x, batch["loss_mask"], batch.get("num_valid_tokens_in_ub") - ) - ) - - reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - return loss_for_ub * self._cp_size, {"avg": reduced_loss} - - def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens_in_ub: Optional[int] = None): - """ - The function takes as input per-token loss and masks non-required values. - """ - if isinstance(loss_output, tuple): - # [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required) - loss_output, tp_reduce = loss_output - else: - tp_reduce = False - losses = loss_output.float() - loss_mask = mask.view(-1).float() - - if self._cp_size > 1: - if num_valid_tokens_in_ub is None: - num_valid_tokens_in_ub = loss_mask.sum() - if num_valid_tokens_in_ub < 0.5: # no valid tokens - num_valid_tokens_in_ub += 1.0 - loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll - torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) - else: - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll - - if tp_reduce is True: - torch.distributed.all_reduce(loss, group=parallel_state.get_tensor_model_parallel_group()) - - return loss - - -class DistillationGPTModel(llm.GPTModel): - """Custom GPT subclass for distillation-related modifications.""" - - def __init__( - self, - student_config: llm.GPTConfig, - teacher_config: llm.GPTConfig, - teacher_ckpt_path: str, - optim: Optional[OptimizerModule] = None, - tokenizer: Optional["TokenizerSpec"] = None, - model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, - ): - super().__init__(student_config, optim, tokenizer, model_transform) - self._teacher_config = teacher_config - self._teacher_ckpt_path = teacher_ckpt_path - self._train_called = False - - if self.config.virtual_pipeline_model_parallel_size is not None: - raise ValueError("ModelOpt Distillation incompatible with interleaved pipeline schedule.") - - def configure_model(self): - if hasattr(self, "module"): - return - - model = self.config.configure_model(self.tokenizer) - - # Ensure same for both models. - for attr in [ - "tensor_model_parallel_size", - "pipeline_model_parallel_size", - "context_parallel_size", - "sequence_parallel", - "pipeline_dtype", - ]: - setattr(self._teacher_config, attr, getattr(self.config, attr)) - - # [ModelOpt] Intialize DistillationModel. - distill_cfg = load_distillation_config(self.config) - kd_config = { - "teacher_model": ( - _teacher_provider, - [self._teacher_config, self._teacher_ckpt_path], - {"tokenizer": self.tokenizer, "trainer": self.trainer}, - ), - "criterion": distill_cfg["criterion"], - "loss_balancer": distill_cfg["loss_balancer"], - } - model = mtd.convert(model, mode=[("kd_loss", kd_config)]) - - # Additional MCore-specific tweaks needed. 
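The `_masked_token_loss` helper above reduces the per-token distillation loss to a mask-weighted average, so padded positions never contribute. A tiny standalone sketch of that arithmetic, with made-up values and the CP/TP all-reduces omitted:

import torch

# Per-token KD losses for a micro-batch of 2 sequences x 4 tokens, shape [b, s].
per_token_loss = torch.tensor([[2.0, 1.0, 3.0, 0.5],
                               [1.5, 2.5, 0.0, 0.0]])
# 1 = real token, 0 = padding.
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 1.0, 0.0, 0.0]])

losses = per_token_loss.view(-1).float()
mask = loss_mask.view(-1).float()
loss = torch.sum(losses * mask) / mask.sum()
print(loss)  # tensor(1.7500) == (2.0 + 1.0 + 3.0 + 0.5 + 1.5 + 2.5) / 6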
- adjust_distillation_model_for_mcore(model, model_cfg=self.config, distill_cfg=distill_cfg) - - self.module = model - - def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) -> Dict[str, torch.Tensor]: - # NOTE: Ignores `self.config.data_step_fn` - if cache_num_batches: - batches = [ - gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=True) for _ in range(cache_num_batches) - ] - return _LoopingCachedDataIterator(batches) - elif isinstance(dataloader_iter, _LoopingCachedDataIterator): - batch = next(dataloader_iter) - if "attention_mask" in batch: - batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU - return batch - else: - return gpt_distillation_data_step(dataloader_iter) - - @property - def training_loss_reduction(self) -> _DistillationLossReduction: - if not self._training_loss_reduction: - self._training_loss_reduction = _DistillationLossReduction( - distillation_loss_fn=self.core_module.compute_kd_loss - ) - - return self._training_loss_reduction - - def load_state_dict(self, state_dict, *args, **kwargs): - state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} - # `super()` would go to `nn.Module` and skip the Context Manager in `mtd.DistillationModel.load_state_dict()` - return self.core_module.load_state_dict(state_dict, *args, *kwargs) - - @property - def core_module(self): - return unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module)) - - def train(self, mode: bool = True): - self._train_called = True - return super().train(mode) - - def __setattr__(self, name, value): - # HACK: PTL calls `module.training = True` after sanity check, bypassing `module.train()` which we depend on. - if name == "training": - if not self._train_called: - self.train(value) - return - self._train_called = False - return super().__setattr__(name, value) + return parser.parse_args() ######################################################## -class BaseLoss(_Loss, metaclass=ABCMeta): - """Abstract base class for Megatron distillation losses.""" - - def __init__(self, model_config: TransformerConfig): - """ - Constructor. - - Args: - model_config: MCore transformer config. - """ - super().__init__() - self._config = model_config - - def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: - """Prepares inputs safely for loss computation.""" - if isinstance(predictions, tuple): - # `ColumnParallelLinear` returns bias too - predictions, targets = predictions[0], targets[0] - targets = targets.detach() - - return predictions, targets - - def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor: - """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking.""" - loss = loss.transpose(0, 1).contiguous() - return loss, tp_reduce - - -class LogitsKLLoss(BaseLoss): - """Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim.""" - - def __init__(self, model_config: TransformerConfig, temperature: float = 1.0, reverse: bool = False): - """ - Constructor. - - Args: - model_config: MCore transformer config. - temperature: Divide tensors by this value prior to calculating loss. - reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher) - """ - super().__init__(model_config) - self._temperature = temperature - self._reverse = reverse - - def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: - """ - Forward function. 
- - Args: - predictions: Student model tensors (size [s, b, h]) - targets: Teacher model tensors (size [s, b, h]) - - Returns: - KLD loss of tensors (size [b, s]) - """ - predictions, targets = self.pre_forward(predictions, targets) - - # Division by temp should happen prior to finding max for both student and teacher. - # Currently we don't use temperature in any of ours runs (temp=1.0) - output_teacher = targets.float() / self._temperature - output_student = predictions.float() / self._temperature - - # Compute local softmax, and the reweight to compute global softmax. - if self._config.tensor_model_parallel_size > 1: - - # Maximum value along vocab dimension across all GPUs. - teacher_logits_max, _ = torch.max(output_teacher, dim=-1) - torch.distributed.all_reduce( - teacher_logits_max, - op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_tensor_model_parallel_group(), - ) - output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) - - denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) - # We can't use standard reduction function here since the computation - # that follows it isn't identical across TP ranks. - denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group()) - - # Maximum value along vocab dimension across all GPUs. - student_logits_max, _ = torch.max(output_student, dim=-1) - torch.distributed.all_reduce( - student_logits_max, - op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_tensor_model_parallel_group(), - ) - output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() - - denom_student = torch.sum(torch.exp(output_student), dim=-1) - denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group()) - - slen, bsz, sharded_vocab_size = output_student.shape - student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( - slen, bsz, sharded_vocab_size - ) - teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand( - slen, bsz, sharded_vocab_size - ) - - if self._reverse: - loss = torch.sum( - F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True), - dim=-1, - ) - else: - loss = torch.sum( - F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True), - dim=-1, - ) - - else: - if self._reverse: - loss = torch.sum( - F.kl_div( - F.log_softmax(output_teacher, dim=-1), - F.softmax(output_student, dim=-1), - reduction="none", - ), - dim=-1, - ) - else: - loss = torch.sum( - F.kl_div( - F.log_softmax(output_student, dim=-1), - F.softmax(output_teacher, dim=-1), - reduction="none", - ), - dim=-1, - ) - - return self.post_forward(loss, tp_reduce=True) - - -class _AllReduce(torch.autograd.Function): - """Implementation from old PyTorch `torch.distributed.nn.parallel`.""" - - @staticmethod - def forward(ctx, op, group, tensor): - ctx.group, ctx.op = group, op - tensor = tensor.clone() - torch.distributed.all_reduce(tensor, op=op, group=group) - return tensor - - @staticmethod - def backward(ctx, grad_output): - return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) - - -def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD): - """Custom all-reduce function. - - Needed instead of other all-reduce functions available when the computation following - the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. 
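Setting aside the tensor-parallel bookkeeping (the distributed max subtraction, the manually reweighted softmax denominators, and the differentiable all-reduce motivated above), the criterion is a temperature-scaled, token-level KL divergence. A single-process sketch of that core computation, mirroring the non-TP branch of `LogitsKLLoss.forward`; the toy shapes are assumptions:

import torch
import torch.nn.functional as F

temperature, reverse = 1.0, False
s, b, v = 8, 2, 32  # sequence length, micro-batch size, vocab size (toy values)
student_logits = torch.randn(s, b, v, requires_grad=True)
teacher_logits = torch.randn(s, b, v).detach()  # teacher outputs are constants in the loss

output_student = student_logits.float() / temperature
output_teacher = teacher_logits.float() / temperature

if reverse:
    # Reverse KL, KL(student || teacher): penalizes student mass placed where the
    # teacher assigns little probability (mode-seeking).
    loss = F.kl_div(
        F.log_softmax(output_teacher, dim=-1), F.softmax(output_student, dim=-1), reduction="none"
    ).sum(dim=-1)
else:
    # Forward KL, KL(teacher || student): pulls the student toward covering the full
    # teacher distribution.
    loss = F.kl_div(
        F.log_softmax(output_student, dim=-1), F.softmax(output_teacher, dim=-1), reduction="none"
    ).sum(dim=-1)

loss = loss.transpose(0, 1).contiguous()  # [s, b] -> [b, s], as in `post_forward`
print(loss.shape)  # torch.Size([2, 8]); masking and reduction happen downstream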
- """ - return _AllReduce.apply(op, group, tensor) - - -######################################################## - - -def load_distillation_config(cfg: TransformerConfig) -> Dict[str, Any]: - """Create a default distillation config for MCore GPT Models. - - Args: - student_cfg: Model config for student model. - """ - logit_pair = ("output_layer", "output_layer") # logit module names for MCoreGPTModel - distill_cfg = { - "criterion": {}, - "loss_balancer": _DummyLossBalancer(), # HACK: to appease ModelOpt until validation relaxed - "skip_lm_loss": True, - } - if cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage(): - distill_cfg["criterion"][logit_pair] = LogitsKLLoss(cfg) - - return distill_cfg - - -class _DummyLossBalancer(mtd.DistillationLossBalancer): - def forward(self, loss_dict): - return next(iter(loss_dict.values())) - - -def _teacher_provider( - config: llm.GPTConfig, ckpt_path: str, tokenizer: "TokenizerSpec", trainer: nl.Trainer -) -> MCoreGPTModel: - """Teacher model factory (must be a non-local function to pickle).""" - logging.info("Distillation: Loading teacher weights...") - - # TODO(aanoosheh): Replace spec with modelopt one - model = config.configure_model(tokenizer) - - sharded_state_dict = {"state_dict": model.sharded_state_dict(prefix="module.")} - checkpoint = trainer.strategy.checkpoint_io.load_checkpoint(ckpt_path, sharded_state_dict) - state_dict = {k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()} - model.load_state_dict(state_dict) - - torch.cuda.empty_cache() - logging.info("Distillation: ...teacher weights loaded.") - return model - - -class _LoopingCachedDataIterator: - def __init__(self, data): - self.data = data - self.it = iter(self.data) - - def __iter__(self): - return self - - def __next__(self): - try: - return next(self.it) - except StopIteration: - self.it = iter(self.data) - return next(self.it) - - -def adjust_distillation_model_for_mcore( - model: mtd.DistillationModel, model_cfg: TransformerConfig, distill_cfg: Dict[str, Any] -): - """Extra modifcations to ``mtd.DistillationModel`` requried for Megatron-Core.""" - - # HACK: Get rid of ModelOpt Distillation state - # NOTE: If re-placed, above losses need modifcation as `TransformerConfig` has non-pickleable elements. - mto.ModeloptStateManager(model)._state.pop() - - # HACK: Hide teacher during `sharded_state_dict` method. - def _sharded_state_dict(self, *args, **kwargs) -> ShardedStateDict: - with self.hide_teacher_model(): - return self._sharded_state_dict(*args, **kwargs) - - model._sharded_state_dict = model.sharded_state_dict - model.sharded_state_dict = MethodType(_sharded_state_dict, model) - - # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop. - def _compute_language_model_loss(self, labels, logits) -> Tensor: - if self.training: - return torch.zeros_like(labels, dtype=logits.dtype) - return self._compute_language_model_loss(labels, logits) - - if distill_cfg["skip_lm_loss"]: - model._compute_language_model_loss = model.compute_language_model_loss - model.compute_language_model_loss = MethodType(_compute_language_model_loss, model) - - # HACK: Skip `lm_loss` always for teacher. 
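`_LoopingCachedDataIterator` above exists so that the input microbatches cached for pipeline-parallel distillation can be replayed on subsequent forward calls; it simply cycles through a stored sequence forever. A trivial usage sketch (the strings stand in for real batch dicts):

class LoopingCachedDataIterator:
    """Iterator which takes in a sequence and cycles through it when exhausted."""

    def __init__(self, data):
        self.data = data
        self.it = iter(self.data)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.it)
        except StopIteration:
            self.it = iter(self.data)
            return next(self.it)


cached_batches = ["batch-0", "batch-1", "batch-2"]
it = LoopingCachedDataIterator(cached_batches)
print([next(it) for _ in range(5)])  # ['batch-0', 'batch-1', 'batch-2', 'batch-0', 'batch-1']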
- def _compute_language_model_loss(self, labels, logits) -> Tensor: - return torch.zeros_like(labels, dtype=logits.dtype) - - model.teacher_model.compute_language_model_loss = MethodType(_compute_language_model_loss, model.teacher_model) - - if model_cfg.pipeline_model_parallel_size > 1: - - def _set_input_tensor(self, input_tensor: Tensor): - obj = self.teacher_model if self._only_teacher_fwd else self - return type(self).set_input_tensor(obj, input_tensor) - - # HACK: Pipeline-parallel Distillation requires a way to cache input batches for subsequent - # forward calls, as well as a way to pass through output tensors to teacher model. - model.set_input_tensor = MethodType(_set_input_tensor, model) - - @contextmanager - def _swap_teacher_config(self, model_wrapper): - try: - if hasattr(model_wrapper, "config"): - model_wrapper._config = model_wrapper.config - model_wrapper.config = self.teacher_model.config - yield - finally: - del model_wrapper.config - if hasattr(model_wrapper, "_config"): - model_wrapper.config = model_wrapper._config - del model_wrapper._config - - # HACK: Pipeline-parallel forward function relies on the config in the model to know what - # hidden size of tensor to communicate to next stage. - model.swap_teacher_config = MethodType(_swap_teacher_config, model) - - -######################################################## - # # # from dataclasses import dataclass @@ -582,21 +81,19 @@ class Llama31Config4B(Llama31Config): # # # if __name__ == "__main__": - logging.info("Distillation enabled.") - args = get_args() ## initialize the strategy and trainer strategy = nl.MegatronStrategy( - tensor_model_parallel_size=args.tp, - pipeline_model_parallel_size=args.pp, - context_parallel_size=args.cp, - sequence_parallel=args.enable_sp, + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + context_parallel_size=args.cp_size, + sequence_parallel=True, ) trainer = nl.Trainer( devices=args.devices, - num_nodes=args.nodes, - max_steps=args.steps, + num_nodes=args.num_nodes, + max_steps=args.max_steps, log_every_n_steps=args.log_interval, val_check_interval=args.val_check_interval, limit_val_batches=args.limit_val_batches, @@ -614,7 +111,7 @@ class Llama31Config4B(Llama31Config): assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." 
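One note on the parallelism flags handed to `MegatronStrategy` above: the product of the tensor-, pipeline-, and context-parallel sizes must divide the total number of ranks, and whatever remains becomes the data-parallel size. The helper below is not part of this script, only a sanity-check sketch of that standard Megatron relationship:

def data_parallel_size(devices: int, num_nodes: int, tp: int, pp: int, cp: int) -> int:
    """Return the implied data-parallel size, or raise if the layout is impossible."""
    world_size = devices * num_nodes
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(f"world size {world_size} is not divisible by TP*PP*CP = {model_parallel}")
    return world_size // model_parallel


# The integration test added later in this series runs TP=1, PP=2, CP=1 on 2 GPUs:
print(data_parallel_size(devices=2, num_nodes=1, tp=1, pp=2, cp=1))  # 1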
# TODO(aanoosheh): Replace spec with modelopt one - model = DistillationGPTModel( + model = distill.DistillationGPTModel( _student_model.config, _teacher_model.config, teacher_ckpt_path=args.teacher_path, @@ -625,9 +122,9 @@ class Llama31Config4B(Llama31Config): # setup the dataset data = llm.PreTrainingDataModule( paths=args.data_paths, - seq_length=args.sequence_length, - micro_batch_size=args.micro_batch_size, - global_batch_size=args.global_batch_size, + seq_length=args.seq_length, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, split=args.split, index_mapping_dir=args.index_mapping_dir, tokenizer=tokenizer, @@ -641,7 +138,7 @@ class Llama31Config4B(Llama31Config): use_distributed_optimizer=True, ) sched = CosineAnnealingScheduler( - max_steps=args.steps, + max_steps=args.max_steps, warmup_steps=args.warmup_steps, constant_steps=0, min_lr=args.min_lr, From 2438a854b7a3f60c99b01ccbbf1cf59fb512a45a Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 22 Jan 2025 07:32:44 -0800 Subject: [PATCH 11/22] Integration test Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index fa2f4b190e30..bdb4562ba68d 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5062,6 +5062,36 @@ jobs: rm -rf /tmp/nemo2_ckpt rm -rf /tmp/nemo2_ptq_engine + L2_NeMo_2_Distill_pp2_Llama2: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_pp2_Llama2') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python scripts/llm/gpt_distillation.py \ + --name nemo2_llama_distill \ + --teacher_path /home/TestData/nemo2_ckpt/llama_68M \ + --student_path /home/TestData/nemo2_ckpt/llama_68M \ + --tp_size 2 \ + --cp_size 1 \ + --pp_size 2 \ + --devices 4 \ + --log_dir /tmp/nemo2_llama_distill \ + --max_steps 5 \ + --gbs 4 \ + --mbs 1 \ + --data_paths 1.0 /home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ + --index_mapping_dir examples/nlp/language_modeling/gpt_index_mappings \ + --seq_length 2048 \ + --warmup_steps 1 \ + --val_check_interval 5 \ + --log_interval 5 \ + --limit_val_batches 2 + + AFTER_SCRIPT: | + rm -rf /tmp/nemo2_llama_distill + L2_NeMo_2_Export_In_Framework: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml @@ -5321,6 +5351,7 @@ jobs: - L2_Megatron_GPT_Reranker - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact - L2_NeMo_2_PTQ_Llama2_FP8 + - L2_NeMo_2_Distill_pp2_Llama2 - L2_NeMo_2_Export_In_Framework - L2_NeMo_2_jit_callback - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING From d644382336af96a5f434c2ace19e0f5ad8b24f36 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 22 Jan 2025 07:59:21 -0800 Subject: [PATCH 12/22] Clean up strings Signed-off-by: Asha Anoosheh --- nemo/collections/llm/distillation/utils.py | 6 +- scripts/llm/gpt_distillation.py | 70 ++++++++-------------- 2 files changed, 25 insertions(+), 51 deletions(-) diff --git a/nemo/collections/llm/distillation/utils.py b/nemo/collections/llm/distillation/utils.py index 12a10337c210..4c33b350fc46 100644 --- a/nemo/collections/llm/distillation/utils.py +++ b/nemo/collections/llm/distillation/utils.py @@ -37,11 +37,7 @@ def load_distillation_config(cfg: "TransformerConfig") -> Dict[str, Any]: - """Create a default 
distillation config for MCore GPT Models. - - Args: - student_cfg: Model config for student model. - """ + """Create a default distillation config for MCore GPT Models.""" logit_pair = ("output_layer", "output_layer") # logit module names for MCoreGPTModel distill_cfg = { "criterion": {}, diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 609a18a0a841..694a4ca0a660 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -24,11 +24,12 @@ from nemo.lightning.pytorch.callbacks import ModelCheckpoint from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +# Suppress lengthy HF warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + def get_args(): - """ - Parse the command line arguments. - """ + """Parse the command line arguments.""" parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""") parser.add_argument("--name", type=str, required=True, help="""Experiment name""") @@ -42,19 +43,18 @@ def get_args(): parser.add_argument("--num_nodes", type=int, default=1, help="""Number of nodes to use""") parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""") parser.add_argument("--max_steps", type=int, required=True, help="""Number of global batches to process""") - parser.add_argument("--gbs", type=int, required=True, help="""Data samples per optimizer step""") - parser.add_argument("--mbs", type=int, required=True, help="""Data samples per forward pass""") - parser.add_argument("--data_paths", nargs='+', required=True, help="""List of tokenized data paths to load from""") - parser.add_argument("--split", type=str, default="99,1,0", help="""""") - parser.add_argument("--index_mapping_dir", type=str, default=None, help="""""") + parser.add_argument("--gbs", type=int, required=True, help="""Global Batch Size""") + parser.add_argument("--mbs", type=int, required=True, help="""Micro-batch Size""") + parser.add_argument("--data_paths", nargs="+", required=True, help="""List of tokenized data paths to load from""") + parser.add_argument("--split", type=str, default="99,1,0", help="""Train,Val,Test ratios to split data""") + parser.add_argument("--index_mapping_dir", type=str, default=None, help="""Folder to write cached data indices""") parser.add_argument("--seq_length", type=int, required=True, help="""Number of tokens per input sample""") - parser.add_argument("--lr", type=float, default=3e-5, help="""""") - parser.add_argument("--min_lr", type=float, default=2e-7, help="""""") - parser.add_argument("--warmup_steps", type=int, default=50, help="""""") - parser.add_argument("--val_check_interval", type=int, default=100, help="""""") - parser.add_argument("--limit_val_batches", type=int, default=32, help="""""") - parser.add_argument("--limit_test_batches", type=int, default=32, help="""""") - parser.add_argument("--log_interval", type=int, default=10, help="""""") + parser.add_argument("--lr", type=float, default=3e-5, help="""Base LR for Cosine-Annealing scheduler""") + parser.add_argument("--min_lr", type=float, default=2e-7, help="""Minimum LR for Cosine-Annealing scheduler""") + parser.add_argument("--warmup_steps", type=int, default=50, help="""Number of scheduler warmup steps""") + parser.add_argument("--val_check_interval", type=int, default=100, help="""Run validation every _ steps""") + parser.add_argument("--limit_val_batches", type=int, default=32, help="""Number of batches per validation stage""") + 
parser.add_argument("--log_interval", type=int, default=10, help="""Write to log every _ steps""") return parser.parse_args() @@ -62,28 +62,10 @@ def get_args(): ######################################################## -# # # -from dataclasses import dataclass - -from nemo.collections.llm.gpt.model.llama import Llama31Config - - -@dataclass -class Llama31Config4B(Llama31Config): - rotary_base: int = 500000 - seq_length: int = 131072 - num_layers: int = 16 - hidden_size: int = 4096 - ffn_hidden_size: int = 14336 - num_attention_heads: int = 32 - - -# # # - if __name__ == "__main__": args = get_args() - ## initialize the strategy and trainer + ## Initialize the strategy and trainer strategy = nl.MegatronStrategy( tensor_model_parallel_size=args.tp_size, pipeline_model_parallel_size=args.pp_size, @@ -97,29 +79,28 @@ class Llama31Config4B(Llama31Config): log_every_n_steps=args.log_interval, val_check_interval=args.val_check_interval, limit_val_batches=args.limit_val_batches, - limit_test_batches=args.limit_test_batches, strategy=strategy, accelerator="gpu", plugins=nl.MegatronMixedPrecision(precision=args.precision), ) - ## load the combined teacher-student model + ## Load both models and combine into an aggregate module _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." - # TODO(aanoosheh): Replace spec with modelopt one model = distill.DistillationGPTModel( _student_model.config, _teacher_model.config, teacher_ckpt_path=args.teacher_path, tokenizer=tokenizer, ) + # TODO(aanoosheh): Replace spec with modelopt one model.__io__ = _student_model.__io__ # HACK: model saves and restores as original class - # setup the dataset + # Set up dataset data = llm.PreTrainingDataModule( paths=args.data_paths, seq_length=args.seq_length, @@ -130,7 +111,7 @@ class Llama31Config4B(Llama31Config): tokenizer=tokenizer, ) - ## setup the optimizer + ## Set up optimizer opt_config = OptimizerConfig( optimizer="adam", lr=args.lr, @@ -145,7 +126,7 @@ class Llama31Config4B(Llama31Config): ) opt = nl.MegatronOptimizerModule(opt_config, sched) - # checkpointing and logging + # Set up checkpointing and logging checkpoint_callback = ModelCheckpoint( monitor="val_loss", save_top_k=1, @@ -157,22 +138,19 @@ class Llama31Config4B(Llama31Config): ckpt=checkpoint_callback, ) - # auto-resume setup + # Set up resume and/or restore functionality resume = nl.AutoResume( resume_if_exists=True, resume_ignore_no_checkpoint=True, restore_config=nl.RestoreConfig(path=args.student_path), ) - # suppress HF warning - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # run + # Run llm.train( model=model, data=data, optim=opt, - tokenizer='model', + tokenizer="model", trainer=trainer, log=logger, resume=resume, From a6a9f071647d00b24be10ed26964d356ad5fa045 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 22 Jan 2025 09:42:09 -0800 Subject: [PATCH 13/22] Appease linter Signed-off-by: Asha Anoosheh --- nemo/collections/llm/distillation/loss.py | 2 ++ nemo/collections/llm/distillation/model.py | 8 +++++--- nemo/collections/llm/distillation/utils.py | 3 +++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/distillation/loss.py 
b/nemo/collections/llm/distillation/loss.py index e970386a058c..f936f56759c1 100644 --- a/nemo/collections/llm/distillation/loss.py +++ b/nemo/collections/llm/distillation/loss.py @@ -163,6 +163,7 @@ class _AllReduce(torch.autograd.Function): @staticmethod def forward(ctx, op, group, tensor): + # pylint: disable=C0116 ctx.group, ctx.op = group, op tensor = tensor.clone() torch.distributed.all_reduce(tensor, op=op, group=group) @@ -170,6 +171,7 @@ def forward(ctx, op, group, tensor): @staticmethod def backward(ctx, grad_output): + # pylint: disable=C0116 return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) diff --git a/nemo/collections/llm/distillation/model.py b/nemo/collections/llm/distillation/model.py index 9b69fc722bb3..618c3ef4f78e 100644 --- a/nemo/collections/llm/distillation/model.py +++ b/nemo/collections/llm/distillation/model.py @@ -40,6 +40,7 @@ def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, Tensor]: + """Same as base GPT's data step but with ability to move attention mask to CPU.""" batch = next(dataloader_iter) if isinstance(batch, tuple) and len(batch) == 3: @@ -103,9 +104,7 @@ def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor return loss_for_ub * self._cp_size, {"avg": reduced_loss} def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens_in_ub: Optional[int] = None): - """ - The function takes as input per-token loss and masks non-required values. - """ + """The function takes as input per-token loss and masks non-required values.""" if isinstance(loss_output, tuple): # [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required) loss_output, tp_reduce = loss_output @@ -213,15 +212,18 @@ def training_loss_reduction(self) -> _DistillationLossReduction: return self._training_loss_reduction def load_state_dict(self, state_dict, *args, **kwargs): + # pylint: disable=C0116 state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} # `super()` would go to `nn.Module` and skip the Context Manager in `mtd.DistillationModel.load_state_dict()` return self.core_module.load_state_dict(state_dict, *args, *kwargs) @property def core_module(self): + # pylint: disable=C0116 return unwrap_model(self.module, (DDP, Float16Module, MCoreFloat16Module)) def train(self, mode: bool = True): + # pylint: disable=C0116 self._train_called = True return super().train(mode) diff --git a/nemo/collections/llm/distillation/utils.py b/nemo/collections/llm/distillation/utils.py index 4c33b350fc46..00078628b66a 100644 --- a/nemo/collections/llm/distillation/utils.py +++ b/nemo/collections/llm/distillation/utils.py @@ -52,6 +52,7 @@ def load_distillation_config(cfg: "TransformerConfig") -> Dict[str, Any]: class _DummyLossBalancer(mtd.DistillationLossBalancer): def forward(self, loss_dict): + # pylint: disable=C0116 return next(iter(loss_dict.values())) @@ -75,6 +76,8 @@ def teacher_provider( class LoopingCachedDataIterator: + """Iterator which takes in a sequence and cycles through it when exhausted.""" + def __init__(self, data): self.data = data self.it = iter(self.data) From eff0d9018a99e4a3c5a374cca9a4a7455c7d0242 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Thu, 23 Jan 2025 05:58:38 -0800 Subject: [PATCH 14/22] Remediate failing tests Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 4 ++-- nemo/lightning/megatron_parallel.py | 4 +--- nemo/utils/model_utils.py | 10 +++++++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff 
--git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index bdb4562ba68d..97a156ee7aa6 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5073,10 +5073,10 @@ jobs: --name nemo2_llama_distill \ --teacher_path /home/TestData/nemo2_ckpt/llama_68M \ --student_path /home/TestData/nemo2_ckpt/llama_68M \ - --tp_size 2 \ + --tp_size 1 \ --cp_size 1 \ --pp_size 2 \ - --devices 4 \ + --devices 2 \ --log_dir /tmp/nemo2_llama_distill \ --max_steps 5 \ --gbs 4 \ diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 9713657f67b9..9b21e9052b93 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -187,9 +187,7 @@ def __init__( convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None, ) -> None: from megatron.core import parallel_state - from megatron.core.transformer.module import Float16Module as McoreFloat16Module - from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.utils.model_utils import unwrap_model _pipeline: List[nn.Module] @@ -223,7 +221,7 @@ def __init__( self.convert_module_fn = convert_module_fn # [ModelOpt]: Detect Pipeline-parallel Distillation mode. - self._unwrapped_model = [unwrap_model(self.module.module, (DDP, Float16Module, McoreFloat16Module))] + self._unwrapped_model = [unwrap_model(self)] # Avoid re-registering module which breaks the inherited `ModuleList` somehow. if ( hasattr(self.unwrapped_model, "teacher_model") diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 5d7d019c6099..a64e8d7b4a22 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -92,7 +92,7 @@ def load_config(model_file: str) -> DictConfig: return model_config -def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]): +def unwrap_model(model, module_instances: Optional[Union[Type, Tuple[Type]]] = None): """Unwrap model from wrapper classes like Float16Module, for example.""" # TODO: Import this from megatron.core once moved there from megatron.training. 
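With `module_instances` now optional, the follow-up hunk lets `unwrap_model` default to peeling off any wrapper that exposes a `.module` attribute. A toy illustration of that intended behaviour; the `Wrapper` class here is a stand-in, not one of NeMo's real wrapper types:

import torch.nn as nn


class Wrapper(nn.Module):
    """Stand-in for DDP/Float16Module-style wrappers that expose `.module`."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


core = nn.Linear(4, 4)
wrapped = Wrapper(Wrapper(core))  # e.g. a precision wrapper around a DDP wrapper

# Default path of the updated helper: strip anything that has a `.module` attribute.
m = wrapped
while hasattr(m, "module"):
    m = m.module
assert m is core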
@@ -102,8 +102,12 @@ def unwrap_model(model, module_instances: Union[Type, Tuple[Type]]): return_list = False unwrapped_model = [] for model_module in model: - while isinstance(model_module, module_instances): - model_module = model_module.module + if module_instances: + while isinstance(model_module, module_instances): + model_module = model_module.module + else: # remove any wrappers that have a '.module' attribute + while hasattr(model_module, "module"): + model_module = model_module.module unwrapped_model.append(model_module) if not return_list: return unwrapped_model[0] From 7650d8f804fe71eb0947f20b97ee991aa672428c Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Fri, 24 Jan 2025 06:22:44 -0800 Subject: [PATCH 15/22] Update CICD model definition Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 13 +++++++------ scripts/llm/gpt_distillation.py | 4 ++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 97a156ee7aa6..e090ffaa8ef3 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5062,17 +5062,17 @@ jobs: rm -rf /tmp/nemo2_ckpt rm -rf /tmp/nemo2_ptq_engine - L2_NeMo_2_Distill_pp2_Llama2: + L2_NeMo_2_Distill_pp2_Llama3: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_pp2_Llama2') || needs.cicd-test-container-setup.outputs.all == 'true' + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_pp2_Llama3') || needs.cicd-test-container-setup.outputs.all == 'true' with: RUNNER: self-hosted-azure SCRIPT: | python scripts/llm/gpt_distillation.py \ --name nemo2_llama_distill \ - --teacher_path /home/TestData/nemo2_ckpt/llama_68M \ - --student_path /home/TestData/nemo2_ckpt/llama_68M \ + --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ + --student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --tp_size 1 \ --cp_size 1 \ --pp_size 2 \ @@ -5087,7 +5087,8 @@ jobs: --warmup_steps 1 \ --val_check_interval 5 \ --log_interval 5 \ - --limit_val_batches 2 + --limit_val_batches 2 \ + --cicd_run AFTER_SCRIPT: | rm -rf /tmp/nemo2_llama_distill @@ -5351,7 +5352,7 @@ jobs: - L2_Megatron_GPT_Reranker - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact - L2_NeMo_2_PTQ_Llama2_FP8 - - L2_NeMo_2_Distill_pp2_Llama2 + - L2_NeMo_2_Distill_pp2_Llama3 - L2_NeMo_2_Export_In_Framework - L2_NeMo_2_jit_callback - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 694a4ca0a660..35878afc5431 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -55,6 +55,7 @@ def get_args(): parser.add_argument("--val_check_interval", type=int, default=100, help="""Run validation every _ steps""") parser.add_argument("--limit_val_batches", type=int, default=32, help="""Number of batches per validation stage""") parser.add_argument("--log_interval", type=int, default=10, help="""Write to log every _ steps""") + parser.add_argument("--cicd_run", action="store_true", help="Used only when called by NeMo CI") return parser.parse_args() @@ -85,6 +86,9 @@ def get_args(): ) ## Load both models and combine into an aggregate module + if args.cicd_run: + from tests.collections.llm.common import Llama3ConfigCI # pylint: disable=W0611 + _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") _teacher_model = 
nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") From 140412bccf849d7d46b4175d48db2142b198adfa Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Fri, 24 Jan 2025 10:00:02 -0800 Subject: [PATCH 16/22] Divert TB logger to same log_dir Signed-off-by: Asha Anoosheh --- scripts/llm/gpt_distillation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 35878afc5431..ca04be712a54 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -15,6 +15,7 @@ import os from argparse import ArgumentParser +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core.optimizer import OptimizerConfig from nemo import lightning as nl @@ -140,6 +141,8 @@ def get_args(): name=args.name, log_dir=args.log_dir, ckpt=checkpoint_callback, + tensorboard=TensorBoardLogger(os.path.join(args.log_dir, args.name)), + update_logger_directory=False, ) # Set up resume and/or restore functionality From 3d33570f57811762e92d7440bef30fb81bbf73ae Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Mon, 27 Jan 2025 07:18:42 -0800 Subject: [PATCH 17/22] Load CICD model specially Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 4 ++-- nemo/collections/llm/distillation/utils.py | 14 ++++++++++++++ scripts/llm/gpt_distillation.py | 12 +++++++----- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index e090ffaa8ef3..c1ad8c676e9e 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5071,8 +5071,8 @@ jobs: SCRIPT: | python scripts/llm/gpt_distillation.py \ --name nemo2_llama_distill \ - --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ - --student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ + --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v3 \ + --student_path /home/TestData/nemo2_ckpt/llama_68M_v3 \ --tp_size 1 \ --cp_size 1 \ --pp_size 2 \ diff --git a/nemo/collections/llm/distillation/utils.py b/nemo/collections/llm/distillation/utils.py index 00078628b66a..b442e01b6977 100644 --- a/nemo/collections/llm/distillation/utils.py +++ b/nemo/collections/llm/distillation/utils.py @@ -152,3 +152,17 @@ def _swap_teacher_config(self, model_wrapper): # HACK: Pipeline-parallel forward function relies on the config in the model to know what # hidden size of tensor to communicate to next stage. 
model.swap_teacher_config = MethodType(_swap_teacher_config, model) + + +def load_cicd_models(student_path: str): + # pylint: disable=C0116 + import os.path + + from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer + from tests.collections.llm.common import Llama3ConfigCI + + tokenizer = get_nmt_tokenizer(tokenizer_model=os.path.join(student_path, "dummy_tokenizer.model")) + student_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) + teacher_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) + + return student_model, teacher_model, tokenizer diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index ca04be712a54..e4d210efabf0 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -88,13 +88,15 @@ def get_args(): ## Load both models and combine into an aggregate module if args.cicd_run: - from tests.collections.llm.common import Llama3ConfigCI # pylint: disable=W0611 + from nemo.collections.llm.distillation.utils import load_cicd_models - _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") - _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") + _student_model, _teacher_model, tokenizer = load_cicd_models(args.student_path) + else: + _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") + _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") - tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) - assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." + tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) + assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." 
model = distill.DistillationGPTModel( _student_model.config, From 94e6b04701d880f197d189d6ebdb93e97c935e37 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 28 Jan 2025 06:42:58 -0800 Subject: [PATCH 18/22] Fix SP flag Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 5 +++-- scripts/llm/gpt_distillation.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index c1ad8c676e9e..b9abb625daa5 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5069,10 +5069,11 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | + export CUDA_LAUNCH_BLOCKING=1 && \ python scripts/llm/gpt_distillation.py \ --name nemo2_llama_distill \ - --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v3 \ - --student_path /home/TestData/nemo2_ckpt/llama_68M_v3 \ + --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ + --student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --tp_size 1 \ --cp_size 1 \ --pp_size 2 \ diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index e4d210efabf0..6224369ca1f1 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -72,7 +72,7 @@ def get_args(): tensor_model_parallel_size=args.tp_size, pipeline_model_parallel_size=args.pp_size, context_parallel_size=args.cp_size, - sequence_parallel=True, + sequence_parallel=(args.tp_size > 1), ) trainer = nl.Trainer( devices=args.devices, From 1513acc40e14e010d4e68f2b26e1369eedac80af Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 29 Jan 2025 06:27:51 -0800 Subject: [PATCH 19/22] Move test into own script Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 16 +-- nemo/collections/llm/distillation/model.py | 20 +++ nemo/collections/llm/distillation/utils.py | 26 +--- scripts/llm/gpt_distillation.py | 19 +-- tests/collections/llm/gpt_distillation.py | 157 +++++++++++++++++++++ 5 files changed, 195 insertions(+), 43 deletions(-) create mode 100644 tests/collections/llm/gpt_distillation.py diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index b9abb625daa5..50c7321fe65b 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5062,15 +5062,14 @@ jobs: rm -rf /tmp/nemo2_ckpt rm -rf /tmp/nemo2_ptq_engine - L2_NeMo_2_Distill_pp2_Llama3: + L2_NeMo_2_Distill_Llama3_TP1PP2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_pp2_Llama3') || needs.cicd-test-container-setup.outputs.all == 'true' + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.cicd-test-container-setup.outputs.all == 'true' with: RUNNER: self-hosted-azure SCRIPT: | - export CUDA_LAUNCH_BLOCKING=1 && \ - python scripts/llm/gpt_distillation.py \ + python tests/collections/llm/gpt_distillation.py \ --name nemo2_llama_distill \ --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ --student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \ @@ -5078,7 +5077,7 @@ jobs: --cp_size 1 \ --pp_size 2 \ --devices 2 \ - --log_dir /tmp/nemo2_llama_distill \ + --log_dir /tmp/distill_logs \ --max_steps 5 \ --gbs 4 \ --mbs 1 \ @@ -5088,11 +5087,10 @@ jobs: --warmup_steps 1 \ --val_check_interval 5 \ --log_interval 5 \ - --limit_val_batches 2 \ - --cicd_run + --limit_val_batches 2 AFTER_SCRIPT: | - rm -rf /tmp/nemo2_llama_distill + rm -rf 
/tmp/distill_logs L2_NeMo_2_Export_In_Framework: needs: [pre-flight, cicd-test-container-build] @@ -5353,7 +5351,7 @@ jobs: - L2_Megatron_GPT_Reranker - L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact - L2_NeMo_2_PTQ_Llama2_FP8 - - L2_NeMo_2_Distill_pp2_Llama3 + - L2_NeMo_2_Distill_Llama3_TP1PP2 - L2_NeMo_2_Export_In_Framework - L2_NeMo_2_jit_callback - L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING diff --git a/nemo/collections/llm/distillation/model.py b/nemo/collections/llm/distillation/model.py index 618c3ef4f78e..4bd1510a4b7c 100644 --- a/nemo/collections/llm/distillation/model.py +++ b/nemo/collections/llm/distillation/model.py @@ -141,6 +141,26 @@ def __init__( tokenizer: Optional["TokenizerSpec"] = None, model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, ): + """ + Constructor. + + This subclass of GPTModel takes the configs of a student and teacher model and overrides + the model construction step to create a ModelOpt `DistillationModel` as the underlying + MCore model. This model abstracts both student and teacher as a single module whose forward + pass runs both, and whose loss function automatically calculates a distillation loss on the + output logits. + + NOTE: This class saves checkpoints which will be re-loaded as the student's original class. + This allows one to continue using the model after distillation without this special class. + + Args: + student_config: Config of student model. + teacher_config: Config of teacher model. + teacher_ckpt_path: Path to teacher checkpoint (to restore weights). + optim: Optimizer. + tokenizer: Tokenizer. + model_transform: Transform to apply to model during setup. + """ super().__init__(student_config, optim, tokenizer, model_transform) self._teacher_config = teacher_config self._teacher_ckpt_path = teacher_ckpt_path diff --git a/nemo/collections/llm/distillation/utils.py b/nemo/collections/llm/distillation/utils.py index b442e01b6977..3ed48865d19f 100644 --- a/nemo/collections/llm/distillation/utils.py +++ b/nemo/collections/llm/distillation/utils.py @@ -98,11 +98,11 @@ def adjust_distillation_model_for_mcore( ): """Extra modifcations to ``mtd.DistillationModel`` requried for Megatron-Core.""" - # HACK: Get rid of ModelOpt Distillation state + # Get rid of ModelOpt Distillation state # NOTE: If re-placed, above losses need modifcation as `TransformerConfig` has non-pickleable elements. mto.ModeloptStateManager(model)._state.pop() - # HACK: Hide teacher during `sharded_state_dict` method. + # Hide teacher during `sharded_state_dict` method. def _sharded_state_dict(self, *args, **kwargs) -> "ShardedStateDict": with self.hide_teacher_model(): return self._sharded_state_dict(*args, **kwargs) @@ -110,7 +110,7 @@ def _sharded_state_dict(self, *args, **kwargs) -> "ShardedStateDict": model._sharded_state_dict = model.sharded_state_dict model.sharded_state_dict = MethodType(_sharded_state_dict, model) - # HACK: Skip `lm_loss` bypassing it when training if not needed for backprop. + # Skip `lm_loss` bypassing it when training if not needed for backprop. def _compute_language_model_loss(self, labels, logits) -> Tensor: if self.training: return torch.zeros_like(labels, dtype=logits.dtype) @@ -120,7 +120,7 @@ def _compute_language_model_loss(self, labels, logits) -> Tensor: model._compute_language_model_loss = model.compute_language_model_loss model.compute_language_model_loss = MethodType(_compute_language_model_loss, model) - # HACK: Skip `lm_loss` always for teacher. + # Skip `lm_loss` always for teacher. 
def _compute_language_model_loss(self, labels, logits) -> Tensor: return torch.zeros_like(labels, dtype=logits.dtype) @@ -132,7 +132,7 @@ def _set_input_tensor(self, input_tensor: Tensor): obj = self.teacher_model if self._only_teacher_fwd else self return type(self).set_input_tensor(obj, input_tensor) - # HACK: Pipeline-parallel Distillation requires a way to cache input batches for subsequent + # Pipeline-parallel Distillation requires a way to cache input batches for subsequent # forward calls, as well as a way to pass through output tensors to teacher model. model.set_input_tensor = MethodType(_set_input_tensor, model) @@ -149,20 +149,6 @@ def _swap_teacher_config(self, model_wrapper): model_wrapper.config = model_wrapper._config del model_wrapper._config - # HACK: Pipeline-parallel forward function relies on the config in the model to know what + # Pipeline-parallel forward function relies on the config in the model to know what # hidden size of tensor to communicate to next stage. model.swap_teacher_config = MethodType(_swap_teacher_config, model) - - -def load_cicd_models(student_path: str): - # pylint: disable=C0116 - import os.path - - from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer - from tests.collections.llm.common import Llama3ConfigCI - - tokenizer = get_nmt_tokenizer(tokenizer_model=os.path.join(student_path, "dummy_tokenizer.model")) - student_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) - teacher_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) - - return student_model, teacher_model, tokenizer diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py index 6224369ca1f1..64271551b08c 100644 --- a/scripts/llm/gpt_distillation.py +++ b/scripts/llm/gpt_distillation.py @@ -53,17 +53,13 @@ def get_args(): parser.add_argument("--lr", type=float, default=3e-5, help="""Base LR for Cosine-Annealing scheduler""") parser.add_argument("--min_lr", type=float, default=2e-7, help="""Minimum LR for Cosine-Annealing scheduler""") parser.add_argument("--warmup_steps", type=int, default=50, help="""Number of scheduler warmup steps""") - parser.add_argument("--val_check_interval", type=int, default=100, help="""Run validation every _ steps""") + parser.add_argument("--val_check_interval", type=int, default=100, help="""Validate + checkpoint every _ steps""") parser.add_argument("--limit_val_batches", type=int, default=32, help="""Number of batches per validation stage""") parser.add_argument("--log_interval", type=int, default=10, help="""Write to log every _ steps""") - parser.add_argument("--cicd_run", action="store_true", help="Used only when called by NeMo CI") return parser.parse_args() -######################################################## - - if __name__ == "__main__": args = get_args() @@ -87,16 +83,11 @@ def get_args(): ) ## Load both models and combine into an aggregate module - if args.cicd_run: - from nemo.collections.llm.distillation.utils import load_cicd_models - - _student_model, _teacher_model, tokenizer = load_cicd_models(args.student_path) - else: - _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") - _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") + _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model") + _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model") - tokenizer = 
getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) - assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." + tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None) + assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included." model = distill.DistillationGPTModel( _student_model.config, diff --git a/tests/collections/llm/gpt_distillation.py b/tests/collections/llm/gpt_distillation.py new file mode 100644 index 000000000000..4f8e0c1cef37 --- /dev/null +++ b/tests/collections/llm/gpt_distillation.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from argparse import ArgumentParser + +from lightning.pytorch.loggers import TensorBoardLogger +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.common.tokenizers.huggingface import AutoTokenizer +from nemo.collections.llm import distillation as distill +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from tests.collections.llm.common import Llama3ConfigCI + +# Suppress lengthy HF warning +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def get_args(): + """Parse the command line arguments.""" + parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""") + + parser.add_argument("--name", type=str, required=True, help="""Experiment name""") + parser.add_argument("--teacher_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--student_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""") + parser.add_argument("--tp_size", type=int, default=1, help="""Tensor parallel size""") + parser.add_argument("--cp_size", type=int, default=1, help="""Context parallel size""") + parser.add_argument("--pp_size", type=int, default=1, help="""Pipeline parallel size""") + parser.add_argument("--precision", type=str, default="bf16-mixed", help="""Datatype for models and optimizer""") + parser.add_argument("--devices", type=int, default=1, help="""Number of GPUs to use per node""") + parser.add_argument("--num_nodes", type=int, default=1, help="""Number of nodes to use""") + parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""") + parser.add_argument("--max_steps", type=int, required=True, help="""Number of global batches to process""") + parser.add_argument("--gbs", type=int, required=True, help="""Global Batch Size""") + parser.add_argument("--mbs", type=int, required=True, help="""Micro-batch Size""") + parser.add_argument("--data_paths", nargs="+", required=True, help="""List of tokenized data paths to load from""") + parser.add_argument("--split", type=str, default="99,1,0", 
help="""Train,Val,Test ratios to split data""") + parser.add_argument("--index_mapping_dir", type=str, default=None, help="""Folder to write cached data indices""") + parser.add_argument("--seq_length", type=int, required=True, help="""Number of tokens per input sample""") + parser.add_argument("--lr", type=float, default=3e-5, help="""Base LR for Cosine-Annealing scheduler""") + parser.add_argument("--min_lr", type=float, default=2e-7, help="""Minimum LR for Cosine-Annealing scheduler""") + parser.add_argument("--warmup_steps", type=int, default=50, help="""Number of scheduler warmup steps""") + parser.add_argument("--val_check_interval", type=int, default=100, help="""Validate + checkpoint every _ steps""") + parser.add_argument("--limit_val_batches", type=int, default=32, help="""Number of batches per validation stage""") + parser.add_argument("--log_interval", type=int, default=10, help="""Write to log every _ steps""") + + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + + ## Initialize the strategy and trainer + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + context_parallel_size=args.cp_size, + sequence_parallel=(args.tp_size > 1), + ) + trainer = nl.Trainer( + devices=args.devices, + num_nodes=args.num_nodes, + max_steps=args.max_steps, + log_every_n_steps=args.log_interval, + val_check_interval=args.val_check_interval, + limit_val_batches=args.limit_val_batches, + strategy=strategy, + accelerator="gpu", + plugins=nl.MegatronMixedPrecision(precision=args.precision), + ) + + ## Load both models and combine into an aggregate module + # NOTE: Special model and tokenizer for CI runs only + tokenizer = AutoTokenizer("gpt2") + _student_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) + _teacher_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer) + + model = distill.DistillationGPTModel( + _student_model.config, + _teacher_model.config, + teacher_ckpt_path=args.teacher_path, + tokenizer=tokenizer, + ) + # TODO(aanoosheh): Replace spec with modelopt one + model.__io__ = _student_model.__io__ # HACK: model saves and restores as original class + + # Set up dataset + data = llm.PreTrainingDataModule( + paths=args.data_paths, + seq_length=args.seq_length, + global_batch_size=args.gbs, + micro_batch_size=args.mbs, + split=args.split, + index_mapping_dir=args.index_mapping_dir, + tokenizer=tokenizer, + ) + + ## Set up optimizer + opt_config = OptimizerConfig( + optimizer="adam", + lr=args.lr, + bf16=("bf16" in args.precision), + use_distributed_optimizer=True, + ) + sched = CosineAnnealingScheduler( + max_steps=args.max_steps, + warmup_steps=args.warmup_steps, + constant_steps=0, + min_lr=args.min_lr, + ) + opt = nl.MegatronOptimizerModule(opt_config, sched) + + # Set up checkpointing and logging + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + save_top_k=1, + every_n_train_steps=args.val_check_interval, + ) + logger = nl.NeMoLogger( + name=args.name, + log_dir=args.log_dir, + ckpt=checkpoint_callback, + tensorboard=TensorBoardLogger(os.path.join(args.log_dir, args.name)), + update_logger_directory=False, + ) + + # Set up resume and/or restore functionality + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + restore_config=nl.RestoreConfig(path=args.student_path), + ) + + # Run + llm.train( + model=model, + data=data, + optim=opt, + tokenizer="model", + trainer=trainer, + log=logger, + resume=resume, + ) From 
358b8102e8b31017f68bcacf8b4714f218efc420 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 29 Jan 2025 08:32:22 -0800 Subject: [PATCH 20/22] Update cicd dependency Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 50c7321fe65b..7ee02d5c7949 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5063,7 +5063,7 @@ jobs: rm -rf /tmp/nemo2_ptq_engine L2_NeMo_2_Distill_Llama3_TP1PP2: - needs: [cicd-test-container-setup] + needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.cicd-test-container-setup.outputs.all == 'true' with: From 47d9021e8e82259c09dc9e21e28a36eb7c798f81 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 29 Jan 2025 08:34:46 -0800 Subject: [PATCH 21/22] Update cicd thing #2 Signed-off-by: Asha Anoosheh --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 7ee02d5c7949..656df82fd3e7 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -5065,7 +5065,7 @@ jobs: L2_NeMo_2_Distill_Llama3_TP1PP2: needs: [pre-flight, cicd-test-container-build] uses: ./.github/workflows/_test_template.yml - if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.cicd-test-container-setup.outputs.all == 'true' + if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.pre-flight.outputs.all == 'true' with: RUNNER: self-hosted-azure SCRIPT: | From e297474a33e819a8111ff8836b9eb782d1e225c1 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 4 Feb 2025 14:10:42 -0800 Subject: [PATCH 22/22] Fix new linting errors Signed-off-by: Asha Anoosheh --- nemo/collections/llm/gpt/model/base.py | 8 +++++--- nemo/lightning/megatron_parallel.py | 27 ++++++++++++++------------ nemo/utils/model_utils.py | 25 +++++++++++++++++------- 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 62389099e807..4af9c6c1263b 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -40,7 +40,7 @@ # TODO: Clean this up with a getter and install instructions _grad_accum_fusion_available = True try: - import fused_weight_gradient_mlp_cuda + import fused_weight_gradient_mlp_cuda # noqa: F401 # pylint: disable=unused-import except ImportError: _grad_accum_fusion_available = False @@ -237,7 +237,8 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC # TP, CP group to the TE modules. # Deep iterate but skip self to avoid infinite recursion. 
if HAVE_TE and self.use_transformer_engine_full_layer_spec: - # Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py + # Copied from: + # https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py if parallel_state.get_tensor_model_parallel_world_size() > 1: for index, child in enumerate(model.modules()): if index == 0: @@ -414,7 +415,8 @@ def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_thres vocab_size = self.config.vocab_size else: raise ValueError( - 'Unable to find vocab size. Either pass in a tokenizer with vocab size, or set vocab size in the model config' + 'Unable to find vocab size.' + ' Either pass in a tokenizer with vocab size, or set vocab size in the model config' ) inference_wrapper_config = InferenceWrapperConfig( diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 9b21e9052b93..7d0b5ba31f0c 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -257,9 +257,12 @@ def forward( Args: data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): The input data for the model. forward_only (bool, optional): If True, only perform the forward pass. Defaults to True. - data_step (Optional[Callable[[Iterator[DataT]], DataT]], optional): Function to process the data. Defaults to None. - forward_step (Optional[Callable[[nn.Module, DataT], Tensor]], optional): Function to perform the forward pass. Defaults to None. - loss_reduction (Optional[MegatronLossReduction[DataT, Any]], optional): Function to reduce the loss. Defaults to None. + data_step (Optional[Callable[[Iterator[DataT]], DataT]], optional): Function to process the data. + Defaults to None. + forward_step (Optional[Callable[[nn.Module, DataT], Tensor]], optional): Function to perform the + forward pass. Defaults to None. + loss_reduction (Optional[MegatronLossReduction[DataT, Any]], optional): Function to reduce the + loss. Defaults to None. seq_length (Optional[int], optional): Sequence length for the model. Defaults to None. micro_batch_size (Optional[int], optional): Size of the micro batch. Defaults to None. num_microbatches (Optional[int], optional): Number of microbatches. Defaults to None. @@ -604,14 +607,15 @@ def init_model_parallel(self): msg = ( f" > number of parameters on (tensor, pipeline) model parallel rank " - f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): " + f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): " # pylint: disable=line-too-long f"{num_params}" ) logging.info(msg) if num_params != num_trainable_params: logging.info( - f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)" + " > number of trainable parameters: " + f"{num_trainable_params} ({num_trainable_params / num_params:.2%} of total)" ) if self.convert_module_fn: @@ -903,8 +907,8 @@ class CallbackConnector: Each of these methods corresponds to a specific stage in the model's operation. You can define these methods in your callback functions to perform specific actions at these stages. - There is no need for the class to be a subclass of a specific parent class. As long as the class contains the methods outlined above, - it can be used as a callback. + There is no need for the class to be a subclass of a specific parent class. 
+ As long as the class contains the methods outlined above, it can be used as a callback. """ def __init__(self, callbacks=None) -> None: @@ -1133,7 +1137,8 @@ class MegatronStep(Generic[ModelT, DataT]): micro_batch_size (Optional[int]): Size of each micro-batch. seq_length (Optional[int]): Sequence length for the current step. num_microbatches (Optional[int]): Number of micro-batches in this step. - decoder_seq_length (Optional[int]): Sequence length of decoder (used only in encoder-decoder style models) for the current step. + decoder_seq_length (Optional[int]): Sequence length of decoder (used only in + encoder-decoder style models) for the current step. Type Parameters: ModelT: The type of the model being used. @@ -1705,9 +1710,7 @@ def __init__(self, validation_step: bool = False, val_drop_last: bool = True) -> def forward( self, batch: Dict[str, torch.Tensor], forward_out: torch.Tensor ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """Taken from: - https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 . - """ + """Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .""" # pylint: disable=line-too-long from megatron.core import parallel_state from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group @@ -1745,7 +1748,7 @@ def forward( return loss_for_ub * cp_size, {"avg": reduced_loss} def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor: - """Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" + """Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 .""" # pylint: disable=line-too-long if losses_reduced_per_micro_batch: if "avg" in losses_reduced_per_micro_batch[0]: loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch] diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index a64e8d7b4a22..974263eb7f68 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -24,13 +24,15 @@ from enum import Enum from functools import lru_cache from pathlib import Path -from typing import List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union import wrapt from nemo.utils import AppState, logging -from nemo.utils.data_utils import resolve_cache_dir # imported for compatibility: model_utils.resolve_cache_dir() -from nemo.utils.data_utils import is_datastore_path +from nemo.utils.data_utils import ( # imported for compatibility: model_utils.resolve_cache_dir() # noqa: F401 # pylint: disable=unused-import,line-too-long + is_datastore_path, + resolve_cache_dir, +) # TODO @blisc: Perhaps refactor instead of import guarding @@ -43,6 +45,12 @@ except ModuleNotFoundError: _HAS_HYDRA = False +if TYPE_CHECKING: + import lightning.pytorch as pl + + from nemo.core.classes import ModelPT, PretrainedModelInfo + from nemo.core.config.modelPT import NemoConfig + MODEL_CONFIG = "model_config.yaml" _VAL_TEST_FASTPATH_KEY = 'ds_item' @@ -346,7 +354,8 @@ def resolve_validation_dataloaders(model: 'ModelPT'): if len(ds_names) > 0: if len(ds_names) != len(ds_values): raise ValueError( - f"Number of names ({len(ds_names)}) does not match number of datasets ({len(ds_values)}). 
Got {ds_names} and {ds_values}" + f"Number of names ({len(ds_names)}) does not match number of " + f"datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" ) model._validation_names = [parse_dataset_as_name(n) for n in ds_names] else: @@ -440,7 +449,8 @@ def resolve_test_dataloaders(model: 'ModelPT'): if len(ds_names) > 0: if len(ds_names) != len(ds_values): raise ValueError( - f"Number of names ({len(ds_names)}) does not match number of datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" + f"Number of names ({len(ds_names)}) does not match number of " + f"datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" ) model._test_names = [parse_dataset_as_name(n) for n in ds_names] else: @@ -652,7 +662,8 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op return True, msg else: msg = ( - f"Lib {lib_name} version ({lib_ver}) is not {operator.__name__} than required version {checked_version}.\n" + f"Lib {lib_name} version ({lib_ver}) is not {operator.__name__} " + f"than required version {checked_version}.\n" f"Please upgrade the lib using either pip or conda to the latest version." ) return False, msg @@ -696,7 +707,7 @@ def inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=False): if app_state.pipeline_model_parallel_size is None or app_state.pipeline_model_parallel_size == 1: filepath = f'{dirname}/mp_rank_{app_state.tensor_model_parallel_rank:02d}{fsdp_shard}/{basename}' else: - filepath = f'{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_{app_state.pipeline_model_parallel_rank:03d}/{basename}' + filepath = f'{dirname}/tp_rank_{app_state.tensor_model_parallel_rank:02d}_pp_rank_{app_state.pipeline_model_parallel_rank:03d}/{basename}' # pylint: disable=line-too-long return filepath else: fsdp_shard = f'/fsdp_shard_{app_state.data_parallel_rank:05d}' if fsdp_sharded_ckpt else ''
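
For readers following the distillation changes above, here is a minimal, framework-free sketch of the temperature-scaled logit distillation loss that the patch series wires into Megatron through LogitsKLLoss and _DistillationLossReduction. It is an illustration under stated assumptions, not the NeMo/ModelOpt implementation: it omits tensor-, pipeline-, and context-parallel reductions, the ModelOpt DistillationModel plumbing, and any loss balancing, and the function name kd_logits_loss is hypothetical. The temperature-squared scaling follows the common Hinton-style convention; whether the ModelOpt criterion applies the same scaling is not shown in the patch.

    # Sketch only: temperature-scaled logit distillation with masked reduction.
    # Assumes single-GPU tensors; parallel all-reduces from the patch are omitted.
    import torch
    import torch.nn.functional as F


    def kd_logits_loss(
        student_logits: torch.Tensor,   # [batch, seq, vocab]
        teacher_logits: torch.Tensor,   # [batch, seq, vocab]
        loss_mask: torch.Tensor,        # [batch, seq], 1.0 where tokens count
        temperature: float = 1.0,
        reverse: bool = False,
    ) -> torch.Tensor:
        """Per-token KL divergence between teacher and student, averaged over unmasked tokens."""
        t = temperature
        student_logp = F.log_softmax(student_logits / t, dim=-1)
        teacher_logp = F.log_softmax(teacher_logits.detach() / t, dim=-1)  # teacher gets no gradient
        if reverse:
            # KL(student || teacher): mode-seeking variant, mirroring the `reverse` flag above.
            per_token = torch.sum(student_logp.exp() * (student_logp - teacher_logp), dim=-1)
        else:
            # KL(teacher || student): the usual distillation direction.
            per_token = torch.sum(teacher_logp.exp() * (teacher_logp - student_logp), dim=-1)
        per_token = per_token * (t * t)  # conventional temperature-squared scaling
        mask = loss_mask.float()
        return (per_token * mask).sum() / mask.sum().clamp(min=1.0)


    if __name__ == "__main__":
        b, s, v = 2, 8, 32
        student = torch.randn(b, s, v, requires_grad=True)
        teacher = torch.randn(b, s, v)
        mask = torch.ones(b, s)
        loss = kd_logits_loss(student, teacher, mask, temperature=2.0)
        loss.backward()
        print(f"kd loss: {loss.item():.4f}")

In the actual scripts, this per-token loss is produced inside the ModelOpt DistillationModel and then passed through the masked, context-parallel-aware reduction shown in _DistillationLossReduction._masked_token_loss rather than the simple masked mean used here.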