Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5062,6 +5062,36 @@ jobs:
rm -rf /tmp/nemo2_ckpt
rm -rf /tmp/nemo2_ptq_engine

L2_NeMo_2_Distill_Llama3_TP1PP2:
needs: [pre-flight, cicd-test-container-build]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.pre-flight.outputs.test_to_run), 'L2_NeMo_2_Distill_Llama3_TP1PP2') || needs.pre-flight.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/gpt_distillation.py \
--name nemo2_llama_distill \
--teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
--tp_size 1 \
--cp_size 1 \
--pp_size 2 \
--devices 2 \
--log_dir /tmp/distill_logs \
--max_steps 5 \
--gbs 4 \
--mbs 1 \
--data_paths 1.0 /home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \
--index_mapping_dir examples/nlp/language_modeling/gpt_index_mappings \
--seq_length 2048 \
--warmup_steps 1 \
--val_check_interval 5 \
--log_interval 5 \
--limit_val_batches 2

AFTER_SCRIPT: |
rm -rf /tmp/distill_logs

L2_NeMo_2_Export_In_Framework:
needs: [pre-flight, cicd-test-container-build]
uses: ./.github/workflows/_test_template.yml
Expand Down Expand Up @@ -5321,6 +5351,7 @@ jobs:
- L2_Megatron_GPT_Reranker
- L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact
- L2_NeMo_2_PTQ_Llama2_FP8
- L2_NeMo_2_Distill_Llama3_TP1PP2
- L2_NeMo_2_Export_In_Framework
- L2_NeMo_2_jit_callback
- L2_NeMo_2_LLAVA_NEXT_MOCK_TRAINING
Expand Down
18 changes: 18 additions & 0 deletions nemo/collections/llm/distillation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loss import LogitsKLLoss
from .model import DistillationGPTModel

__all__ = ["LogitsKLLoss", "DistillationGPTModel"]
184 changes: 184 additions & 0 deletions nemo/collections/llm/distillation/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta
from typing import TYPE_CHECKING, Tuple

import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from torch import Tensor
from torch.nn.modules.loss import _Loss

if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig


class BaseLoss(_Loss, metaclass=ABCMeta):
"""Abstract base class for Megatron distillation losses."""

def __init__(self, model_config: "TransformerConfig"):
"""
Constructor.

Args:
model_config: MCore transformer config.
"""
super().__init__()
self._config = model_config

def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""Prepares inputs safely for loss computation."""
if isinstance(predictions, tuple):
# `ColumnParallelLinear` returns bias too
predictions, targets = predictions[0], targets[0]
targets = targets.detach()

return predictions, targets

def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor:
"""Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
loss = loss.transpose(0, 1).contiguous()
return loss, tp_reduce


class LogitsKLLoss(BaseLoss):
"""Calculates KL-Divergence loss between two logits tensors without reducing the sequence dim."""

def __init__(self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False):
"""
Constructor.

Args:
model_config: MCore transformer config.
temperature: Divide tensors by this value prior to calculating loss.
reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher)
"""
super().__init__(model_config)
self._temperature = temperature
self._reverse = reverse

def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
"""
Forward function.

Args:
predictions: Student model tensors (size [s, b, h])
targets: Teacher model tensors (size [s, b, h])

Returns:
KLD loss of tensors (size [b, s])
"""
predictions, targets = self.pre_forward(predictions, targets)

# Division by temp should happen prior to finding max for both student and teacher.
# Currently we don't use temperature in any of ours runs (temp=1.0)
output_teacher = targets.float() / self._temperature
output_student = predictions.float() / self._temperature

# Compute local softmax, and the reweight to compute global softmax.
if self._config.tensor_model_parallel_size > 1:

# Maximum value along vocab dimension across all GPUs.
teacher_logits_max, _ = torch.max(output_teacher, dim=-1)
torch.distributed.all_reduce(
teacher_logits_max,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1)

denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1)
# We can't use standard reduction function here since the computation
# that follows it isn't identical across TP ranks.
denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group())

# Maximum value along vocab dimension across all GPUs.
student_logits_max, _ = torch.max(output_student, dim=-1)
torch.distributed.all_reduce(
student_logits_max,
op=torch.distributed.ReduceOp.MAX,
group=parallel_state.get_tensor_model_parallel_group(),
)
output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach()

denom_student = torch.sum(torch.exp(output_student), dim=-1)
denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group())

slen, bsz, sharded_vocab_size = output_student.shape
student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand(
slen, bsz, sharded_vocab_size
)
teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand(
slen, bsz, sharded_vocab_size
)

if self._reverse:
loss = torch.sum(
F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True),
dim=-1,
)
else:
loss = torch.sum(
F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True),
dim=-1,
)

else:
if self._reverse:
loss = torch.sum(
F.kl_div(
F.log_softmax(output_teacher, dim=-1),
F.softmax(output_student, dim=-1),
reduction="none",
),
dim=-1,
)
else:
loss = torch.sum(
F.kl_div(
F.log_softmax(output_student, dim=-1),
F.softmax(output_teacher, dim=-1),
reduction="none",
),
dim=-1,
)

return self.post_forward(loss, tp_reduce=True)


class _AllReduce(torch.autograd.Function):
"""Implementation from old PyTorch `torch.distributed.nn.parallel`."""

@staticmethod
def forward(ctx, op, group, tensor):
# pylint: disable=C0116
ctx.group, ctx.op = group, op
tensor = tensor.clone()
torch.distributed.all_reduce(tensor, op=op, group=group)
return tensor

@staticmethod
def backward(ctx, grad_output):
# pylint: disable=C0116
return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output))


def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD):
"""Custom all-reduce function.

Needed instead of other all-reduce functions available when the computation following
the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators.
"""
return _AllReduce.apply(op, group, tensor)
Loading
Loading