Pipeline-parallel support for Knowledge Distillation (NeMo 2) #11766
Merged: ko3n1g merged 23 commits into NVIDIA-NeMo:main from AAnoosheh:aanoosheh/pp-distillation-nemo2 on Feb 6, 2025.
Changes from all commits
Commits (23):
4b426db  First draft of distill script port to 2.0 (AAnoosheh)
bc43237  Pipeline-parallel changes (AAnoosheh)
3f7b9c6  Basic distillation running (AAnoosheh)
d87e64f  Add CLI args (AAnoosheh)
12019bf  Most fixes (AAnoosheh)
c3f5fb2  Fix callbacks in PP loop (AAnoosheh)
8770138  More fixes (AAnoosheh)
758132d  Rework checkpoint loading (AAnoosheh)
01864d3  Resolve seemingly remaining bugs (AAnoosheh)
e0cc0bc  Refactor into multiple files (AAnoosheh)
2438a85  Integration test (AAnoosheh)
d644382  Clean up strings (AAnoosheh)
a6a9f07  Appease linter (AAnoosheh)
eff0d90  Remediate failing tests (AAnoosheh)
7650d8f  Update CICD model definition (AAnoosheh)
140412b  Divert TB logger to same log_dir (AAnoosheh)
3d33570  Load CICD model specially (AAnoosheh)
94e6b04  Fix SP flag (AAnoosheh)
1513acc  Move test into own script (AAnoosheh)
358b810  Update cicd dependency (AAnoosheh)
47d9021  Update cicd thing #2 (AAnoosheh)
e297474  Fix new linting errors (AAnoosheh)
b8dda2c  Merge branch 'main' into aanoosheh/pp-distillation-nemo2 (yashaswikarnati)
@@ -0,0 +1,18 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loss import LogitsKLLoss
from .model import DistillationGPTModel

__all__ = ["LogitsKLLoss", "DistillationGPTModel"]
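The loss module below (the second new file in this diff) implements `LogitsKLLoss`. As a reading aid, in notation of my own choosing ($z^S$, $z^T$ for the student and teacher logits over vocabulary $\mathcal{V}$, and $\tau$ for the temperature), the default `reverse=False` path computes, per sequence position and batch element,

$$
p^{T}=\operatorname{softmax}\!\big(z^{T}/\tau\big),\qquad
p^{S}=\operatorname{softmax}\!\big(z^{S}/\tau\big),\qquad
\mathcal{L}=\sum_{v\in\mathcal{V}} p^{T}_{v}\big(\log p^{T}_{v}-\log p^{S}_{v}\big)=\mathrm{KL}\big(p^{T}\,\Vert\,p^{S}\big).
$$

With `reverse=True` the two distributions swap roles. Under tensor parallelism the vocabulary dimension is sharded, so the stabilizing max and the softmax denominator are each all-reduced over the TP group before the per-shard log-probabilities are formed; the gradient-carrying reduction is the `all_reduce_autograd` helper at the bottom of the file.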
@@ -0,0 +1,184 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta
from typing import TYPE_CHECKING, Tuple

import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from torch import Tensor
from torch.nn.modules.loss import _Loss

if TYPE_CHECKING:
    from megatron.core.transformer.transformer_config import TransformerConfig


class BaseLoss(_Loss, metaclass=ABCMeta):
    """Abstract base class for Megatron distillation losses."""

    def __init__(self, model_config: "TransformerConfig"):
        """
        Constructor.

        Args:
            model_config: MCore transformer config.
        """
        super().__init__()
        self._config = model_config

    def pre_forward(self, predictions: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """Prepares inputs safely for loss computation."""
        if isinstance(predictions, tuple):
            # `ColumnParallelLinear` returns bias too
            predictions, targets = predictions[0], targets[0]
        targets = targets.detach()

        return predictions, targets

    def post_forward(self, loss: Tensor, tp_reduce: bool = False) -> Tensor:
        """Reshapes tensor from [s, b] to [b, s] for upcoming loss masking."""
        loss = loss.transpose(0, 1).contiguous()
        return loss, tp_reduce


class LogitsKLLoss(BaseLoss):
    """Calculates KL-divergence loss between two logits tensors without reducing the sequence dim."""

    def __init__(self, model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False):
        """
        Constructor.

        Args:
            model_config: MCore transformer config.
            temperature: Divide tensors by this value prior to calculating loss.
            reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher).
        """
        super().__init__(model_config)
        self._temperature = temperature
        self._reverse = reverse

    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """
        Forward function.

        Args:
            predictions: Student model tensors (size [s, b, h])
            targets: Teacher model tensors (size [s, b, h])

        Returns:
            KLD loss of tensors (size [b, s])
        """
        predictions, targets = self.pre_forward(predictions, targets)

        # Division by temp should happen prior to finding max for both student and teacher.
        # Currently we don't use temperature in any of our runs (temp=1.0).
        output_teacher = targets.float() / self._temperature
        output_student = predictions.float() / self._temperature

        # Compute local softmax, then reweight to compute the global softmax.
        if self._config.tensor_model_parallel_size > 1:

            # Maximum value along vocab dimension across all GPUs.
            teacher_logits_max, _ = torch.max(output_teacher, dim=-1)
            torch.distributed.all_reduce(
                teacher_logits_max,
                op=torch.distributed.ReduceOp.MAX,
                group=parallel_state.get_tensor_model_parallel_group(),
            )
            output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1)

            denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1)
            # We can't use the standard reduction function here since the computation
            # that follows it isn't identical across TP ranks.
            denom_teacher = all_reduce_autograd(denom_teacher, group=parallel_state.get_tensor_model_parallel_group())

            # Maximum value along vocab dimension across all GPUs.
            student_logits_max, _ = torch.max(output_student, dim=-1)
            torch.distributed.all_reduce(
                student_logits_max,
                op=torch.distributed.ReduceOp.MAX,
                group=parallel_state.get_tensor_model_parallel_group(),
            )
            output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach()

            denom_student = torch.sum(torch.exp(output_student), dim=-1)
            denom_student = all_reduce_autograd(denom_student, group=parallel_state.get_tensor_model_parallel_group())

            slen, bsz, sharded_vocab_size = output_student.shape
            student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand(
                slen, bsz, sharded_vocab_size
            )
            teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand(
                slen, bsz, sharded_vocab_size
            )

            if self._reverse:
                loss = torch.sum(
                    F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True),
                    dim=-1,
                )
            else:
                loss = torch.sum(
                    F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True),
                    dim=-1,
                )

        else:
            if self._reverse:
                loss = torch.sum(
                    F.kl_div(
                        F.log_softmax(output_teacher, dim=-1),
                        F.softmax(output_student, dim=-1),
                        reduction="none",
                    ),
                    dim=-1,
                )
            else:
                loss = torch.sum(
                    F.kl_div(
                        F.log_softmax(output_student, dim=-1),
                        F.softmax(output_teacher, dim=-1),
                        reduction="none",
                    ),
                    dim=-1,
                )

        return self.post_forward(loss, tp_reduce=True)


class _AllReduce(torch.autograd.Function):
    """Implementation from old PyTorch `torch.distributed.nn.parallel`."""

    @staticmethod
    def forward(ctx, op, group, tensor):
        # pylint: disable=C0116
        ctx.group, ctx.op = group, op
        tensor = tensor.clone()
        torch.distributed.all_reduce(tensor, op=op, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # pylint: disable=C0116
        return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output))


def all_reduce_autograd(tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD):
    """Custom all-reduce function.

    Needed instead of other all-reduce functions available when the computation following
    the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators.
    """
    return _AllReduce.apply(op, group, tensor)
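For a quick local sanity check of the loss outside a parallel job, a minimal sketch along the following lines should work on CPU with tensor parallelism disabled. It assumes torch and megatron-core are installed; the import path and the `_ToyConfig` stand-in are illustrative only (on the non-TP branch the loss reads nothing from the config except `tensor_model_parallel_size`):

```python
import torch

# Hypothetical import path -- adjust to wherever this loss module lands in the NeMo tree.
from nemo.collections.llm.distill.loss import LogitsKLLoss


class _ToyConfig:
    """Stand-in for a Megatron TransformerConfig; only this field is read on the non-TP branch."""
    tensor_model_parallel_size = 1


loss_fn = LogitsKLLoss(_ToyConfig(), temperature=1.0, reverse=False)

s, b, v = 8, 2, 32  # [sequence, batch, vocab], matching the [s, b, h] layout in the docstrings
student_logits = torch.randn(s, b, v, requires_grad=True)
teacher_logits = torch.randn(s, b, v)

loss, tp_reduce = loss_fn(student_logits, teacher_logits)  # loss comes back as [b, s]
print(loss.shape, tp_reduce)  # torch.Size([2, 8]) True
loss.mean().backward()        # gradients reach only the student; targets are detached
```

The returned `tp_reduce` flag is presumably consumed by the masking/reduction logic in the accompanying `model.py`, which is not part of this hunk.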
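The custom `_AllReduce` / `all_reduce_autograd` pair exists because the softmax denominators must carry gradients back through the TP reduction. A single-process check (world size 1, gloo backend, hypothetical import path again) shows the gradient flowing back through the collective:

```python
import os

import torch
import torch.distributed as dist

# Hypothetical import path, as above.
from nemo.collections.llm.distill.loss import all_reduce_autograd

# Single-process "distributed" setup so the collective is effectively an identity op.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = all_reduce_autograd(x, group=dist.group.WORLD)  # sum-reduce; identity with world_size=1
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.]) -- the gradient passed back through the all-reduce

dist.destroy_process_group()
```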