From 0f863195d631b64b4f90e86e680fcdcd1149e820 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Wed, 16 Apr 2025 12:06:20 -0700 Subject: [PATCH 1/4] FSDP2 SFT changes Signed-off-by: Yi-Fu Wu --- examples/configs/sft.yaml | 4 ++++ nemo_reinforcer/algorithms/loss_functions.py | 18 ++++++++++++------ nemo_reinforcer/algorithms/sft.py | 7 +++++++ .../models/policy/dtensor_policy_worker.py | 2 +- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 785b6e0d2e..28126b526c 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -47,6 +47,10 @@ policy: weight_decay: 0.1 betas: [0.9, 0.98] eps: 1e-5 + # when using DTensor, we need to set foreach + # and fused to False + foreach: False + fused: False data: max_input_seq_length: ${policy.max_total_sequence_length} diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d674e7deb0..e163672c7a 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -168,14 +168,20 @@ def __call__( sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits + next_token_logits = next_token_logits.to(torch.float32) # Gather the logprobs for the actual next tokens - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) + if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + token_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"] + ) + else: + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token + next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + logprobs = 
next_token_logprobs[:, :-1] # Remove last position's logits + token_logprobs = logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) # Only compute loss on generated tokens (not input tokens) # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 45f4f08575..e6a6b3f418 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -237,6 +237,7 @@ def validate( val_metrics = {"val_loss": 0.0} + policy.prepare_for_training() for batch_idx, val_batch in enumerate(val_dataloader): ## add loss mask based on role to every message add_loss_mask_to_message_log( @@ -247,6 +248,9 @@ def validate( cat_and_padded, input_lengths = batched_message_log_to_flat_message( val_batch["message_log"], pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], ) val_data: BatchedDataDict = BatchedDataDict( @@ -358,6 +362,9 @@ def sft_train( cat_and_padded, input_lengths = batched_message_log_to_flat_message( batch["message_log"], pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], ) train_data: BatchedDataDict = BatchedDataDict( diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index c967a53c97..a520eac543 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -348,7 +348,7 @@ def train( local_loss = torch.tensor(losses, device="cuda") global_loss = torch.zeros_like(local_loss) torch.distributed.all_reduce(local_loss) - global_loss = local_loss / self.dp_size + global_loss = local_loss / (self.dp_size * self.tp_size) # Aggregate metrics across all microbatches mb_metrics = 
defaultdict(list) From e2f01c548d91f0aae50e776acc085ccfdc325fde Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Thu, 17 Apr 2025 12:56:37 -0700 Subject: [PATCH 2/4] Fix Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/models/policy/dtensor_policy_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index a520eac543..eec422e640 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -321,6 +321,7 @@ def train( mb_losses.append(loss.item()) all_mb_metrics.append(loss_metrics) + grad_norm = None if not eval_mode: with torch.no_grad(): grad_norm = get_grad_norm( From 417d5ddb0481f9da3bf620a9bc8f1b0276f7d6e8 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Thu, 17 Apr 2025 12:58:05 -0700 Subject: [PATCH 3/4] ruff Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index e163672c7a..a3c4632dd4 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -177,7 +177,9 @@ def __call__( ) else: next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits, dim=-1 + ) logprobs = next_token_logprobs[:, :-1] # Remove last position's logits token_logprobs = logprobs.gather( dim=-1, index=next_tokens.unsqueeze(-1) From 4f5fac44839104d874791f3f78acabdb0e20949f Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 21 Apr 2025 12:34:14 -0700 Subject: [PATCH 4/4] Address comments Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 2 +- nemo_reinforcer/models/policy/dtensor_policy_worker.py 
| 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index a3c4632dd4..ef5a698678 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -112,7 +112,7 @@ def __call__( next_token_logprobs = torch.nn.functional.log_softmax( next_token_logits, dim=-1 ) - next_tokens = data["input_ids"][:, 1:] # Skip first token + next_tokens = data.get("input_ids")[:, 1:].cuda() # Skip first token curr_logprobs = next_token_logprobs.gather( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index eec422e640..a7c7f717fb 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -348,8 +348,8 @@ def train( with torch.no_grad(): local_loss = torch.tensor(losses, device="cuda") global_loss = torch.zeros_like(local_loss) - torch.distributed.all_reduce(local_loss) - global_loss = local_loss / (self.dp_size * self.tp_size) + torch.distributed.all_reduce(local_loss, group=self.dp_mesh.get_group()) + global_loss = local_loss / self.dp_size # Aggregate metrics across all microbatches mb_metrics = defaultdict(list)