From eee023b5a08f79feb1afd55528c2e79a10afd8b3 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 8 Aug 2025 08:04:57 -0700 Subject: [PATCH 1/4] update Signed-off-by: Qidong Su --- 3rdparty/NeMo-workspace/NeMo | 2 +- examples/configs/grpo-deepscaler-1.5b-16K.yaml | 1 + examples/configs/grpo-deepscaler-1.5b-24K.yaml | 5 +---- nemo_rl/models/policy/dtensor_policy_worker.py | 14 ++++++++++---- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 8ddf438734..4b7ded58d8 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 +Subproject commit 4b7ded58d804bf3470499c6cfa385c6fa915879d diff --git a/examples/configs/grpo-deepscaler-1.5b-16K.yaml b/examples/configs/grpo-deepscaler-1.5b-16K.yaml index 866b365da4..570fecb1b9 100644 --- a/examples/configs/grpo-deepscaler-1.5b-16K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-16K.yaml @@ -8,6 +8,7 @@ loss_fn: policy: max_total_sequence_length: 16384 + logprob_batch_size: 2 dtensor_cfg: enabled: true diff --git a/examples/configs/grpo-deepscaler-1.5b-24K.yaml b/examples/configs/grpo-deepscaler-1.5b-24K.yaml index 52d1ed2018..4da4a1674d 100644 --- a/examples/configs/grpo-deepscaler-1.5b-24K.yaml +++ b/examples/configs/grpo-deepscaler-1.5b-24K.yaml @@ -8,6 +8,7 @@ loss_fn: policy: max_total_sequence_length: 24576 + logprob_batch_size: 2 dtensor_cfg: enabled: true @@ -47,7 +48,3 @@ policy: # For most cases, use "dummy" to load the initial weights, since they will be overwritten during refit # For Gemma models, we need to use "auto" due to a vllm bug load_format: dummy - -cluster: - gpus_per_node: 8 - num_nodes: 4 diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 18ae23d95b..dd265e9aea 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -720,6 +720,7 @@ def train( logits = self.model.lm_head(outputs.last_hidden_state) else: logits = outputs.logits + del outputs # Apply temperature scaling logits = self._apply_temperature_scaling(logits) @@ -786,6 +787,7 @@ def train( global_valid_seqs, global_valid_toks, ) + del logits # skip the update for dummy batches if mb_idx < iterator_len: @@ -1044,8 +1046,9 @@ def get_logprobs( placements=[Shard(sequence_dim), Shard(-1)], ) + logits = logits.to(torch.float32) token_logprobs = get_logprobs_from_vocab_parallel_logits( - logits.to(torch.float32), + logits, input_ids_dtensor, seq_index_tensor, ) @@ -1053,8 +1056,9 @@ def get_logprobs( assert token_logprobs.shape[1] == seq_len - 1 else: if isinstance(logits, DTensor): + logits = logits.to(torch.float32) token_logprobs = get_logprobs_from_vocab_parallel_logits( - logits.to(torch.float32), input_ids + logits, input_ids ) else: # Extract logprobs for each token in the sequence by gathering the logprob @@ -1064,15 +1068,17 @@ def get_logprobs( # token_ids: [batch_size, sequence_length] - actual tokens # Output shape: [batch_size, sequence_length] - logprob of each token given previous # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length - + outputs.logits = outputs.logits.to(torch.float32) log_probs = torch.nn.functional.log_softmax( - outputs.logits.to(torch.float32), dim=-1 + outputs.logits, dim=-1 ) next_tokens = input_ids[:, 1:] log_probs = log_probs[:, :-1] token_logprobs = log_probs.gather( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) + + del outputs, logits token_logprobs = torch.cat( [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 From 9040aa2f373534d051ece54fdd37fcc7b7cc1540 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 8 Aug 2025 10:53:48 -0700 Subject: [PATCH 2/4] fix Signed-off-by: Qidong Su --- nemo_rl/models/policy/dtensor_policy_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index dd265e9aea..8867245f9c 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1068,9 +1068,9 @@ def get_logprobs( # token_ids: [batch_size, sequence_length] - actual tokens # Output shape: [batch_size, sequence_length] - logprob of each token given previous # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length - outputs.logits = outputs.logits.to(torch.float32) + logits = outputs.logits.to(torch.float32) log_probs = torch.nn.functional.log_softmax( - outputs.logits, dim=-1 + logits, dim=-1 ) next_tokens = input_ids[:, 1:] log_probs = log_probs[:, :-1] From 325b7b4d118d716a20fb22306167fa9301ff0481 Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 8 Aug 2025 11:35:18 -0700 Subject: [PATCH 3/4] fix Signed-off-by: Qidong Su --- 3rdparty/NeMo-workspace/NeMo | 2 +- nemo_rl/models/policy/dtensor_policy_worker.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 4b7ded58d8..8ddf438734 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 4b7ded58d804bf3470499c6cfa385c6fa915879d +Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 8867245f9c..fe8efeb483 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -1069,15 +1069,13 @@ def get_logprobs( # Output shape: [batch_size, sequence_length] - logprob of each token given previous # We get logprob of token[t+1] from logits[t], prepending 0 to maintain sequence length logits = outputs.logits.to(torch.float32) - log_probs = torch.nn.functional.log_softmax( - logits, dim=-1 - ) + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) next_tokens = input_ids[:, 1:] log_probs = log_probs[:, :-1] token_logprobs = log_probs.gather( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) - + del outputs, logits token_logprobs = torch.cat( From 988d2e9b127ccbdef3bfdcd9d733da1e8d12c11a Mon Sep 17 00:00:00 2001 From: Qidong Su Date: Fri, 8 Aug 2025 15:47:08 -0700 Subject: [PATCH 4/4] update nemo Signed-off-by: Qidong Su --- 3rdparty/NeMo-workspace/NeMo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/NeMo-workspace/NeMo b/3rdparty/NeMo-workspace/NeMo index 8ddf438734..aaefedd1d1 160000 --- a/3rdparty/NeMo-workspace/NeMo +++ b/3rdparty/NeMo-workspace/NeMo @@ -1 +1 @@ -Subproject commit 8ddf4387344c6423763ec9ee0c9a755cbb5d8d35 +Subproject commit aaefedd1d13f4ccd5cd06a19e06f1df33589a235