From 665a42c2543d00865c753c776a86b56e24e71df9 Mon Sep 17 00:00:00 2001 From: Jonas yang Date: Wed, 25 Jun 2025 15:11:02 +0800 Subject: [PATCH 1/6] Optimize get logprobs for CP enabled FSDP2 case. Signed-off-by: Jonas yang --- nemo_rl/algorithms/loss_functions.py | 3 +- nemo_rl/distributed/model_utils.py | 45 ++++++++++++++++--- nemo_rl/models/dtensor/parallelize.py | 27 ++++++++--- .../models/policy/dtensor_policy_worker.py | 31 ++++++++----- 4 files changed, 84 insertions(+), 22 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 1078da5fa3..7a054c223d 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -122,6 +122,7 @@ def __call__( prev_logprobs = data["prev_logprobs"][:, 1:] generation_logprobs = data["generation_logprobs"][:, 1:] reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] + seq_index = data.get("seq_index", None) mask = token_mask * sample_mask.unsqueeze(-1) @@ -151,7 +152,7 @@ def __call__( ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): curr_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"] + next_token_logits, data["input_ids"], seq_index=seq_index ) else: next_token_logits_wo_last = next_token_logits[ diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index f1e2e6ac81..0dc2442b0c 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch +from torch.distributed.tensor import DTensor, distribute_tensor @torch.no_grad() @@ -121,11 +122,12 @@ def backward( def from_parallel_logits_to_logprobs( vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, + target: torch.Tensor | DTensor, vocab_start_index: int, vocab_end_index: int, - group: torch.distributed.ProcessGroup, + tp_group: torch.distributed.ProcessGroup, inference_only: bool = False, + seq_index: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Get log probabilities from TP sharded vocab logits. @@ -136,8 +138,10 @@ def from_parallel_logits_to_logprobs( NOTE: Must be the unmodified targets as this function will shift them internally. vocab_start_index (int): Starting vocabulary index for this worker's partition. vocab_end_index (int): Ending vocabulary index for this worker's partition. - group (torch.distributed.ProcessGroup): Process group for distributed communication. + tp_group (torch.distributed.ProcessGroup): Process group for distributed communication. inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False. + seq_index (Optional[torch.Tensor]): Sequence index tensor with shape [seq_len]. + It is only provided for cp sharded logits. It represents how tensor is sharded across the sequence dimension. Returns: torch.Tensor: Log probabilities tensor with shape [batch_size, seq_len-1]. 
@@ -145,13 +149,42 @@ def from_parallel_logits_to_logprobs( Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 """ - target = target.roll(shifts=-1, dims=-1) + cp_mesh = None + cp_placements = None + target_shape = torch.Size(target.shape) + sorted_indices = None + + if seq_index is not None: + assert isinstance(target, DTensor), ( + "target must be a DTensor if seq_index is provided" + ) + cp_mesh = target.device_mesh + cp_placements = target.placements + _, sorted_indices = torch.sort(seq_index) + # Recover the original order of the target + target = target.full_tensor()[:, sorted_indices] + target = target.roll(shifts=-1, dims=-1)[:, seq_index] + + # Reshard + target = distribute_tensor(target, cp_mesh, cp_placements) + target = target.to_local() + else: + target = target.roll(shifts=-1, dims=-1) + probs: torch.Tensor = DistributedLogprob.apply( # type: ignore vocab_parallel_logits, target, vocab_start_index, vocab_end_index, - group, + tp_group, inference_only, ).contiguous() + + if seq_index is not None: + # probs is sharded on the sequence dimension. + # Get full sequence tensor, vocab dim has been reduced already. + probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements) + probs = probs_dtensor.full_tensor()[:, sorted_indices] + assert probs.shape == target_shape + return probs[:, :-1] diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index fb9c720c20..fe64f9e2e8 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -616,7 +616,9 @@ def get_grad_norm( def get_logprobs_from_vocab_parallel_logits( - vocab_parallel_logits: DTensor, input_ids: torch.Tensor + vocab_parallel_logits: DTensor, + input_ids: torch.Tensor | DTensor, + seq_index: Optional[torch.Tensor] = None, ): """Computes log probabilities from vocabulary-parallel logits. @@ -632,16 +634,31 @@ def get_logprobs_from_vocab_parallel_logits( Returns: torch.Tensor: Log probabilities for the given input IDs. 
""" - tp_mesh = vocab_parallel_logits.device_mesh - tp_rank: int = tp_mesh.get_local_rank() + device_mesh = vocab_parallel_logits.device_mesh + if seq_index is not None: + assert "cp" in device_mesh.mesh_dim_names, ( + "seq_index must be provided for cp sharded logits" + ) + + cp_size = 1 + tp_size = 1 + + tp_group = device_mesh.get_group("tp") + tp_rank = tp_group.rank() + tp_size = tp_group.size() + + if "cp" in device_mesh.mesh_dim_names: + cp_group = device_mesh.get_group("cp") + cp_size = cp_group.size() - vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_mesh.size() + vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_size return from_parallel_logits_to_logprobs( vocab_parallel_logits.to_local(), input_ids, vocab_interval_per_rank * tp_rank, (tp_rank + 1) * vocab_interval_per_rank, - tp_mesh.get_group(), + tp_group, inference_only=not torch.is_grad_enabled(), + seq_index=seq_index, ) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 46e1e8a52a..68a1c97496 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -580,7 +580,8 @@ def train( .full_tensor() .squeeze(0) ) - _, sorted_indices = torch.sort(seq_index_dtensor) + + mb["seq_index"] = seq_index_dtensor for tensor_name in mb: current_tensor = mb[tensor_name] @@ -593,18 +594,28 @@ def train( current_tensor, device_mesh=self.cp_mesh, placements=[Shard(sequence_dim)], - ).full_tensor()[:, sorted_indices] + ) break if isinstance(logits, DTensor): - logits = logits.full_tensor() - - logits_dtensor = DTensor.from_local( - logits, - device_mesh=self.cp_mesh, - placements=[Shard(sequence_dim)], - ) - logits = logits_dtensor.full_tensor()[:, sorted_indices] + # Must be tp sharded + assert ( + logits.device_mesh.ndim == 1 + and logits.device_mesh.mesh_dim_names[0] == "tp" + ), "logits must be tp sharded" + local_logits = logits.to_local() + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim + logits = DTensor.from_local( + local_logits, + device_mesh=self.device_mesh["cp", "tp"], + placements=[Shard(sequence_dim), Shard(-1)], + ) + else: + logits = DTensor.from_local( + logits, + device_mesh=self.device_mesh[("cp", "tp")], + placements=[Shard(sequence_dim), Shard(-1)], + ) loss, loss_metrics = loss_fn( logits, mb, global_valid_seqs, global_valid_toks From 519e7b89568ee49d0817794a622546f1335f3586 Mon Sep 17 00:00:00 2001 From: Jonas yang Date: Wed, 25 Jun 2025 15:13:39 +0800 Subject: [PATCH 2/6] Remove unused code. Signed-off-by: Jonas yang --- nemo_rl/models/dtensor/parallelize.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index fe64f9e2e8..370624a163 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -640,17 +640,12 @@ def get_logprobs_from_vocab_parallel_logits( "seq_index must be provided for cp sharded logits" ) - cp_size = 1 tp_size = 1 tp_group = device_mesh.get_group("tp") tp_rank = tp_group.rank() tp_size = tp_group.size() - if "cp" in device_mesh.mesh_dim_names: - cp_group = device_mesh.get_group("cp") - cp_size = cp_group.size() - vocab_interval_per_rank = vocab_parallel_logits.shape[-1] // tp_size return from_parallel_logits_to_logprobs( From ef33ca7e0ee3c62235d3ef629c76b4611ce491e6 Mon Sep 17 00:00:00 2001 From: Jonas yang Date: Wed, 25 Jun 2025 21:16:49 +0800 Subject: [PATCH 3/6] Fix CI. 
Signed-off-by: Jonas yang --- nemo_rl/algorithms/loss_functions.py | 6 +++--- nemo_rl/models/policy/megatron_policy_worker.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 7a054c223d..1bf472d830 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -147,7 +147,7 @@ def __call__( data["input_ids"], vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - group=vocab_parallel_group, + tp_group=vocab_parallel_group, inference_only=False, ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): @@ -333,7 +333,7 @@ def __call__( data["input_ids"], vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - group=vocab_parallel_group, + tp_group=vocab_parallel_group, inference_only=False, ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): @@ -481,7 +481,7 @@ def _preference_loss( data["input_ids"], vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - group=vocab_parallel_group, + tp_group=vocab_parallel_group, inference_only=False, ) elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 89eb263674..ec8f78dc8d 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -947,7 +947,7 @@ def collection_fn(output_tensor): target=input_ids, vocab_start_index=tp_rank * output_tensor.shape[-1], vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - group=tp_grp, + tp_group=tp_grp, inference_only=True, ) From 023b6e43113d3244195ecc8dd80cb7ae83fef21a Mon Sep 17 00:00:00 2001 From: Jonas yang Date: Tue, 1 Jul 2025 21:20:24 +0800 Subject: [PATCH 4/6] Fix review. Signed-off-by: Jonas yang --- nemo_rl/models/policy/dtensor_policy_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 68a1c97496..793b9e5a96 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -603,12 +603,12 @@ def train( logits.device_mesh.ndim == 1 and logits.device_mesh.mesh_dim_names[0] == "tp" ), "logits must be tp sharded" - local_logits = logits.to_local() + # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim - logits = DTensor.from_local( - local_logits, + logits = logits.redistribute( device_mesh=self.device_mesh["cp", "tp"], placements=[Shard(sequence_dim), Shard(-1)], + async_op=True, ) else: logits = DTensor.from_local( From 23a99d4f441a5c41818500c943248c42069e5852 Mon Sep 17 00:00:00 2001 From: Jonas yang Date: Tue, 1 Jul 2025 22:01:42 +0800 Subject: [PATCH 5/6] Revert review feedback. 
Signed-off-by: Jonas yang --- nemo_rl/models/policy/dtensor_policy_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 793b9e5a96..d9d28faa6d 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -605,10 +605,10 @@ def train( ), "logits must be tp sharded" # CP is implicitly sharded on the seq dim, so we need to redistribute to the tp dim - logits = logits.redistribute( - device_mesh=self.device_mesh["cp", "tp"], + logits = DTensor.from_local( + logits.to_local(), + device_mesh=self.device_mesh[("cp", "tp")], placements=[Shard(sequence_dim), Shard(-1)], - async_op=True, ) else: logits = DTensor.from_local( From ec72dddf5cd553e42127395b3afdf732e2c4ccf2 Mon Sep 17 00:00:00 2001 From: Jonas yang Date: Wed, 2 Jul 2025 11:27:30 +0800 Subject: [PATCH 6/6] Follow up review to build cp branch. Signed-off-by: Jonas yang --- nemo_rl/distributed/model_utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 0dc2442b0c..31ac71cc23 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -149,15 +149,15 @@ def from_parallel_logits_to_logprobs( Taken from: https://github.com/NVIDIA/NeMo-Aligner/blob/9faab404f21994a7eb1d6ed5890b76152b941636/nemo_aligner/utils/distributed.py#L354 """ - cp_mesh = None - cp_placements = None - target_shape = torch.Size(target.shape) - sorted_indices = None - - if seq_index is not None: - assert isinstance(target, DTensor), ( - "target must be a DTensor if seq_index is provided" - ) + cp_size = 1 + + if isinstance(target, DTensor) and "cp" in target.device_mesh.mesh_dim_names: + cp_dim_index = target.device_mesh.mesh_dim_names.index("cp") + cp_size = target.device_mesh.shape[cp_dim_index] + + if cp_size > 1: + assert seq_index is not None, "seq_index must be provided for cp sharded logits" + target_shape = torch.Size(target.shape) cp_mesh = target.device_mesh cp_placements = target.placements _, sorted_indices = torch.sort(seq_index) @@ -180,7 +180,7 @@ def from_parallel_logits_to_logprobs( inference_only, ).contiguous() - if seq_index is not None: + if cp_size > 1: # probs is sharded on the sequence dimension. # Get full sequence tensor, vocab dim has been reduced already. probs_dtensor = DTensor.from_local(probs, cp_mesh, cp_placements)
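
A minimal single-process sketch (illustration only, not part of the patch) of the target reordering that from_parallel_logits_to_logprobs performs when context parallelism (CP) shards the sequence dimension. The concrete seq_index layout, cp_size=2, and the helper variable names below are assumptions chosen for the example; in the worker the real seq_index comes from the CP load-balancing logic, and the sharding/gathering is done through DTensor rather than plain slicing.

import torch

batch_size, seq_len, cp_size = 2, 8, 2

# Assumed zigzag layout for cp_size=2:
# seq_index[j] = original sequence position stored at CP-ordered position j.
seq_index = torch.tensor([0, 1, 6, 7,   # positions owned by CP rank 0
                          2, 3, 4, 5])  # positions owned by CP rank 1

original = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)
target_cp = original[:, seq_index]  # targets as the CP-sharded DTensor stores them

# Same three steps as the patch: recover the original order, roll left by one
# so position i holds the "next token", then re-apply the CP ordering.
_, sorted_indices = torch.sort(seq_index)
recovered = target_cp[:, sorted_indices]
shifted = recovered.roll(shifts=-1, dims=-1)
shifted_cp = shifted[:, seq_index]

# Each CP rank would then keep only its local slice of shifted_cp before the
# TP-sharded logprob computation; the resulting probs are later gathered and
# restored to the original order with the same sorted_indices.
per_rank = shifted_cp.chunk(cp_size, dim=-1)
print([t.tolist() for t in per_rank])

The detour through the original order matters because rolling the CP-ordered tensor directly would pair each position with whatever token happens to sit next in the shuffled layout, mixing targets across CP shard boundaries instead of aligning each logit with its true next token.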