
feat: add on-policy distillation algorithm #1006

Merged
terrykong merged 50 commits into NVIDIA-NeMo:main from shuo-nvidia:feat-distillation
Sep 29, 2025

Conversation

@zpqiu
Contributor

@zpqiu zpqiu commented Aug 28, 2025

What does this PR do ?

This PR implements the on-policy distillation algorithm used in Qwen3, a knowledge distillation approach in which the student model generates on-policy sequences from sampled prompts. The student is then fine-tuned by aligning its output logits with those of a larger teacher model (such as Qwen3-32B or Qwen3-235B-A22B) to minimize the KL divergence. Because the teacher's knowledge is applied to the student's own generated trajectories, this yields lightweight models whose performance approaches or surpasses that of larger models, at lower computational cost than reinforcement learning.
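The core objective can be sketched as follows. This is a minimal single-token illustration of forward KL between teacher and student distributions, not the PR's actual `DistillationLossFn`; the function names and the plain-list representation are assumptions for illustration only:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def forward_kl(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) for one token position.

    In the on-policy setting the student first samples a trajectory;
    a loss like this is then applied at every generated position.
    """
    p = softmax([x / temperature for x in teacher_logits])
    q = softmax([x / temperature for x in student_logits])
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))

# Identical distributions give zero divergence; mismatched ones are positive.
assert abs(forward_kl([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])) < 1e-12
assert forward_kl([3.0, 0.0, 0.0], [0.0, 0.0, 3.0]) > 0.0
```

In the actual implementation the student is updated by backpropagating this divergence through its logits, while the teacher's top-k logits are treated as fixed targets.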

Issues

Closes #910

Usage

uv run python examples/run_distillation_math.py --config examples/configs/distillation_math.yaml

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

Experiments on Math Domain

  • Off-policy distillation with LMLoss (i.e., SFT the student model on a distillation dataset generated by other models):
    • 1.7B-Base
      • Dataset: nvidia/OpenMathReasoning (split=cot, generation_model=DeepSeek-R1)
    • 4B-Base:
      • Dataset: nvidia/AceReason-1.1-SFT (category=math)
      • bsz 512, steps: 400
  • Off-policy distillation with KDLoss (i.e., fine-tune the student model with a KL loss against the teacher model on a distillation dataset generated by other models):
    • 4B-Base:
      • Dataset: nvidia/AceReason-1.1-SFT (category=math)
      • bsz 512, steps: 400
      • test on the best ckpt (step=200)
  • On-policy distillation:
    • Teacher model: Qwen3-32B (thinking mode)
    • Dataset: DeepScaler
    • Max-len: 32K
  • Evaluation:
    • Dataset: AIME 2025
    • Metric: Avg@16
| Student Model | Original | Off-policy-SFT (LMLoss) | Off-policy-SFT (KDLoss) | On-policy-KD | Off-policy-SFT (LMLoss) + On-policy-KD |
|---|---|---|---|---|---|
| Qwen3-1.7B-Base | 1.67 | 5.42 | - | 11.25 | 9.79 |
| Qwen3-4B-Base | 2.71 | 24.58 | 30.42 | 28.96 | 47.71 |
| Qwen3-4B-Instruct-2507 | 46.88 | - | - | 61.04 | - |
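For reference, the Avg@16 metric above is the mean accuracy over 16 sampled generations per problem. A minimal sketch (the data and function name are hypothetical, for illustration only):

```python
def avg_at_k(per_problem_correct_counts, k=16):
    """Mean accuracy (in %) when sampling k generations per problem.

    per_problem_correct_counts: number of correct generations (out of k)
    for each evaluation problem.
    """
    assert all(0 <= c <= k for c in per_problem_correct_counts)
    per_problem_acc = [c / k for c in per_problem_correct_counts]
    return 100.0 * sum(per_problem_acc) / len(per_problem_acc)

# Two problems: 8/16 and 0/16 generations correct.
print(avg_at_k([8, 0]))  # 25.0
```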

cc: @terrykong @sharathts

Summary by CodeRabbit

  • New Features

    • End-to-end teacher→student distillation: configs, example runner, setup/launch, training/validation, checkpointing, and vLLM-backed generation for varied recipes.
  • Performance

    • Distributed top-k/logprob and entropy primitives plus numerical-stability improvements to support large sharded training.
  • Documentation

    • Many new, templated distillation config recipes (single/multi-node, dynamic batching, seqpack, long-context).
  • Tests

    • New unit and functional tests and runner scripts validating distillation, loss behavior, config validation, and metrics.

@zpqiu zpqiu linked an issue Aug 28, 2025 that may be closed by this pull request
@zpqiu zpqiu force-pushed the feat-distillation branch 2 times, most recently from d2ec99c to 0bfecd2 on September 5, 2025 09:30
@xxman-google
Contributor

Hi, is this PR implementing the on-policy distillation described in the Qwen3 paper?
"On-policy Distillation: In this phase, the student model generates on-policy sequences for
fine-tuning. Specifically, prompts are sampled, and the student model produces responses in
either /think or /no think mode. The student model is then fine-tuned by aligning its logits
with those of a teacher model (Qwen3-32B or Qwen3-235B-A22B) to minimize the KL divergence"

What is the status of this PR? What can I try for now?

@zpqiu
Contributor Author

zpqiu commented Sep 6, 2025

> Hi, is this PR implementing the on-policy distillation described in the Qwen3 paper? "On-policy Distillation: In this phase, the student model generates on-policy sequences for fine-tuning. Specifically, prompts are sampled, and the student model produces responses in either /think or /no think mode. The student model is then fine-tuned by aligning its logits with those of a teacher model (Qwen3-32B or Qwen3-235B-A22B) to minimize the KL divergence"
>
> What is the status of this PR? What can I try for now?

Hi @xxman-google , we are glad that you are interested in our PR. Yes, we aim to implement on-policy distillation as described in Qwen3.

The basic functionalities are now in place, supporting TP, CP, sequence packing, and dynamic batching. You can start a test experiment with uv run examples/run_distillation_math.py.

In the first commit, we conducted some preliminary experiments using Qwen3-32B as the teacher model, and we observed improvements on mathematical tasks, as shown in the figure below:
(screenshot: training curves, August 27, 2025)

In the latest commits, we have mainly been refining the functionalities. The experiments are based on a single-node environment and small-scale models. We are currently perfecting the test cases and validating the latest version of the code on the 32B model.

@coderabbitai
Contributor

coderabbitai Bot commented Sep 17, 2025

📝 Walkthrough

Walkthrough

Adds an end‑to‑end on‑policy distillation pipeline: new distillation configs/recipes, an example runner, a Distillation algorithm (setup/train/validate), a distributed top‑k-aware DistillationLossFn and distributed utilities, policy top‑k APIs/workers, batched-data handling updates, and unit/functional test scripts.

Changes

Cohort / File(s) Summary
Core algorithm
nemo_rl/algorithms/distillation.py
New distillation module: types (DistillationConfig, DistillationSaveState, MasterConfig), setup(), distillation_train(), validate(), generation/policy orchestration, logger/checkpointer integration, and save‑state handling.
Loss & numeric fixes
nemo_rl/algorithms/loss_functions.py, tests/unit/algorithms/test_loss_functions.py
Adds DistillationLossConfig/DistillationLossDataDict and DistillationLossFn with TP/CP top‑k support (forward/reverse/mixed KL, temperature, zero_outside_topk); casts logits to float32 in existing losses; unit tests for distillation loss.
Distributed model utilities
nemo_rl/distributed/model_utils.py
Adds chunked distributed autograd ops and helpers: ChunkedDistributedGatherLogprob, ChunkedDistributedEntropy, distributed_vocab_topk, gather_logits_at_global_indices for TP/CP-aware top‑k/logprob gathering.
Batched data handling
nemo_rl/distributed/batched_data_dict.py
from_batches extended to handle 3D tensors (pad along sequence dim then concat) to preserve per‑position features such as top‑k logits/indices.
Policy interfaces & top‑k API
nemo_rl/models/policy/interfaces.py, nemo_rl/models/policy/lm_policy.py
Adds TopkLogitsOutputSpec and Policy.get_topk_logits abstract API; lm_policy aggregates per‑worker top‑k outputs into a single BatchedDataDict.
DTensor workers (v1 & v2)
nemo_rl/models/policy/dtensor_policy_worker.py, ..._v2.py
Adds get_topk_logits implementations computing distributed global top‑k (TP/CP aware), gathering topk_logits/topk_indices, handling dynamic batching and seqpack; train() now forwards context_parallel_group when CP active.
Configs & recipes
examples/configs/distillation_math.yaml, examples/configs/recipes/llm/*.yaml
New base distillation config and multiple recipe YAMLs (Qwen distillation variants) using POLICY_BASE/DTENSOR_BASE anchors, vLLM colocated generation, dynamic batching/seqpack options, optimizers, schedulers, logging and cluster settings.
Example runner
examples/run_distillation_math.py
New CLI example: arg parsing, tokenizer resolver, hf_data_processor, setup_data for math datasets, Ray env integration, config loading, setup() invocation, and invoking distillation_train.
Tests — unit & config validation
tests/unit/algorithms/test_distillation.py, tests/unit/test_config_validation.py
Unit tests for distillation loop (max steps) and validate(); config validation extended to validate distillation and optional loss_fn.
Tests — functional & suites
tests/functional/distillation.sh, tests/test_suites/llm/*.sh
Functional runner and multiple suite scripts to launch recipe configs, convert TB logs to JSON, and run gated metric checks (loss, GPU memory, timing).
Test runner helpers invoked
tests/json_dump_tb_logs.py, tests/check_metrics.py
Existing helpers are invoked by new test scripts (no direct changes in this diff).
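The from_batches change for 3D tensors described above can be sketched as follows (a sketch only, using NumPy rather than torch; the assumed shape is [B, S, K], mirroring but not reproducing the PR's code):

```python
import numpy as np

def concat_3d_batches(tensors, pad_value=0):
    """Concatenate [B, S, K] arrays from different microbatches.

    Sequence lengths (dim 1) may differ, so right-pad them to the max
    length before concatenating on the batch dim; the feature dim K
    (e.g. top-k logits/indices) is preserved intact.
    """
    assert all(t.ndim == 3 for t in tensors)
    assert len({t.shape[2] for t in tensors}) == 1, "K must match"
    max_s = max(t.shape[1] for t in tensors)
    padded = [
        np.pad(t, ((0, 0), (0, max_s - t.shape[1]), (0, 0)),
               constant_values=pad_value)
        for t in tensors
    ]
    return np.concatenate(padded, axis=0)

a = np.ones((2, 3, 4))   # B=2, S=3, K=4
b = np.ones((1, 5, 4))   # B=1, S=5, K=4
out = concat_3d_batches([a, b])
print(out.shape)  # (3, 5, 4)
```

Flattening such a tensor with a 2D-only code path would merge the S and K dims, which is exactly the bug the PR's new branch avoids.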

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as CLI
  participant Run as run_distillation_math.py
  participant Setup as setup()
  participant Student as Student Policy
  participant Teacher as Teacher Policy
  participant Gen as Generation (vLLM)
  participant Data as DataLoader
  participant Loss as DistillationLossFn
  participant Log as Logger/Checkpoint

  CLI->>Run: parse config & overrides
  Run->>Setup: load config, tokenizer, datasets, envs
  Setup->>Student: init student policy + optimizer
  Setup->>Teacher: init teacher policy (inference)
  Setup-->>Run: return components

  loop training step
    Run->>Data: sample prompts
    Run->>Gen: generate responses
    Gen-->>Run: message_log
    Run->>Teacher: get_topk_logits(k)
    Teacher-->>Run: teacher_topk_logits & indices
    Run->>Loss: prepare data dict (student logits + teacher topk)
    Run->>Student: train_step via loss_fn
    Student-->>Run: loss & metrics
    Run->>Log: log metrics & timings
    alt checkpoint interval
      Run->>Log: save checkpoint (model, opt, tokenizer, state)
    end
    alt validation interval
      Run->>Gen: validation rollouts
      Gen->>Log: val metrics
    end
  end

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120–180 minutes

Possibly related issues

Possibly related PRs

Suggested reviewers

  • terrykong

Poem

A rabbit taps keys with a scholar’s delight,
Fetches top‑k carrots under bfloat16 light.
Teacher whispers logits, student learns the tune,
Checkpoints hum softly beneath the moon.
Hops through TP and CP — distilled and bright.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title clearly reflects the main change by announcing the addition of an on-policy distillation algorithm and follows a conventional commit prefix style. It is concise, specific, and avoids extraneous information. Therefore, it meets the requirements for a clear and descriptive pull request title.


@zpqiu zpqiu marked this pull request as ready for review September 17, 2025 13:27
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/unit/algorithms/test_loss_functions.py (1)

1414-1858: Remove duplicate distillation test definitions.

The entire distillation test suite (from setup_distillation_test_data through all test functions) appears to be duplicated in the patch. This creates redundant code that will make maintenance harder.

Remove the duplicate block of distillation tests. Keep only one set of the test functions. The duplication appears to start at line 1414 and continues through the end of the file.

🧹 Nitpick comments (51)
nemo_rl/distributed/batched_data_dict.py (3)

134-159: 3D path: validate shapes/devices and guard against silent shape drift.

Add basic consistency checks (dtype, device, feature dim) and avoid negative padding. This prevents subtle crashes when a worker emits a different k (top‑k) or dtype.

-                if list_of_tensors[0].ndim == 3:
-                    # 对三维张量,只在序列维度(这里是第1维)补 pad,保留特征维度
-                    max_seq_len = max(tensor.shape[1] for tensor in list_of_tensors)
+                if list_of_tensors[0].ndim == 3:
+                    # For 3D tensors shaped [B, S, K], pad only along the sequence dim (dim=1), keep feature dim intact.
+                    base = list_of_tensors[0]
+                    # sanity checks across workers/microbatches
+                    assert all(t.ndim == 3 for t in list_of_tensors), "Expected 3D tensors"
+                    assert all(t.dtype == base.dtype for t in list_of_tensors), "Mismatched dtypes in from_batches"
+                    assert all(t.device == base.device for t in list_of_tensors), "Mismatched devices in from_batches"
+                    assert len({t.shape[2] for t in list_of_tensors}) == 1, "Last dim (K) must match across tensors"
+                    max_seq_len = max(tensor.shape[1] for tensor in list_of_tensors)
                     padded_tensors = []
                     for tensor in list_of_tensors:
-                        # 在第1维补 pad 到 max_seq_len
-                        pad_length = max_seq_len - tensor.shape[1]
+                        # Right-pad dim=1 up to max_seq_len
+                        pad_length = max(0, max_seq_len - tensor.shape[1])
                         padded = torch.nn.functional.pad(
                             tensor,
-                            (0, 0, 0, pad_length),  # 只补最后二个维度(序列长度)
+                            (0, 0, 0, pad_length),  # pad last two dims, second-last is sequence
                             mode="constant",
                             value=pad_value,
                         )
                         padded_tensors.append(padded)
-                    tensor_or_list = torch.cat(padded_tensors, dim=0)  # 在批次维度拼接
+                    tensor_or_list = torch.cat(padded_tensors, dim=0)  # concat on batch dim

134-148: Replace non-ASCII punctuation and translate inline comments to pass Ruff (RUF003) and aid maintainability.

Current comments use full‑width punctuation and Chinese; convert to concise English.

-                # bug出在这里,当train_data中包含3维tensor时,下面的flatten()会错误展平后两个维度。
-                # grpo在调用这个函数时只会传入二维tensor,但是sft和dpo这里还需要检查有无三维特殊情况。
-                # 现在添加了如下if语句处理蒸馏中的三维情况(即teacher topk logits和indices),else后为原代码。
+                # Bug: when train_data contains a 3D tensor, the flatten() below incorrectly flattens the last two dims.
+                # GRPO sends only 2D tensors here; SFT/DPO/distillation may produce 3D tensors (e.g., teacher top‑k logits/indices).
+                # Handle 3D tensors below; the else branch preserves the original 2D behavior.

151-159: Masking semantics for padded top‑k indices.

If these 3D tensors hold token indices, padding with 0 may produce “in‑range” ids. Ensure downstream losses mask padded positions (via input_lengths) before using indices/logits, or switch pad_value for index tensors to an out‑of‑range sentinel and mask accordingly.
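One way to implement the sentinel suggestion (a sketch; the use of -1 as the out-of-range sentinel and the function name are assumptions, not the PR's actual choices):

```python
import numpy as np

SENTINEL = -1  # out-of-range token id marking padded positions

def mask_padded_topk(topk_indices, input_lengths):
    """Replace top-k indices at padded positions with a sentinel.

    topk_indices: [B, S, K] int array; input_lengths: [B] valid lengths.
    Downstream losses can then skip positions where index == SENTINEL,
    instead of silently treating pad id 0 as a real vocab token.
    """
    b, s, _ = topk_indices.shape
    pos = np.arange(s)[None, :]                            # [1, S]
    pad_mask = pos >= np.asarray(input_lengths)[:, None]   # [B, S]
    out = topk_indices.copy()
    out[pad_mask] = SENTINEL
    return out

idx = np.zeros((1, 4, 2), dtype=np.int64)
masked = mask_padded_topk(idx, input_lengths=[2])
print(masked[0, :, 0])  # [ 0  0 -1 -1]
```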

nemo_rl/distributed/model_utils.py (3)

258-381: ChunkedDistributedGatherLogprob: math and TP-reduction look correct; consider memory trim.

ctx.save_for_backward stores full logits and indices; with long S this is heavy. Option: save only global_indices and recompute logits from upstream (or re‑compute log_softmax from the forward input without saving the fp16->fp32 casted tensor) by toggling torch.set_grad_enabled+recompute. If you keep saving logits, document memory impact and gate with a flag.


828-897: distributed_vocab_topk: correct global top‑k; tighten a couple details.

  • K_eff can be simplified to min(k, V_total).
  • Consider using log_softmax only if ties/temperature matter; logits are fine for ordering.
  • Prefer named args for all_gather for clarity.
-    V_total = V_local * world_size
-    K_eff = int(min(k, max(1, V_total)))
+    V_total = V_local * world_size
+    K_eff = min(k, V_total)
@@
-        torch.distributed.all_gather(gathered_vals, local_vals, group=tp_group)
-        torch.distributed.all_gather(gathered_idx, local_idx_global, group=tp_group)
+        torch.distributed.all_gather(tensor_list=gathered_vals, tensor=local_vals, group=tp_group)
+        torch.distributed.all_gather(tensor_list=gathered_idx, tensor=local_idx_global, group=tp_group)
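The local-top-k → all-gather → global-re-top-k pattern being reviewed here can be simulated without torch (a single-process sketch under the assumption that each rank holds a contiguous vocab shard; names are illustrative only):

```python
import numpy as np

def simulated_distributed_topk(local_logits_per_rank, k):
    """Simulate a TP-sharded global top-k over the vocab dim.

    Each 'rank' takes a local top-k, shifts local indices by its vocab
    offset, 'all-gathers' the candidates, then takes the final top-k
    over the world_size * k gathered entries.
    """
    cand_vals, cand_idx = [], []
    offset = 0
    for shard in local_logits_per_rank:       # one entry per rank
        order = np.argsort(shard)[::-1][:k]   # local top-k
        cand_vals.append(shard[order])
        cand_idx.append(order + offset)       # convert to global ids
        offset += shard.shape[0]
    vals = np.concatenate(cand_vals)
    idx = np.concatenate(cand_idx)
    top = np.argsort(vals)[::-1][:k]          # global re-top-k
    return vals[top], idx[top]

shards = [np.array([0.1, 5.0, 0.2]), np.array([4.0, 0.3, 6.0])]
vals, idx = simulated_distributed_topk(shards, k=2)
print(idx)  # [5 1] -> global ids of logits 6.0 and 5.0
```

This also shows why ordering is unaffected by log_softmax, as the comment notes: softmax is monotone in the logits, so top-k over raw logits and over log-probs select the same indices.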

984-1055: ChunkedDistributedEntropy: docstring vs. sign.

You return sum_v p log p (negative entropy). Either rename to “log_softmax_self_cross_term” in the doc or note that it’s negative entropy. Gradient formula matches sum p log p.

-    """Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.
+    """Compute H_all = sum_v p_v log p_v (negative entropy) across TP with chunking over sequence.
tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.sh (2)

16-18: Quote expansions and guard cd; fix SC2068/SC2164.

-cd $PROJECT_ROOT
+cd "$PROJECT_ROOT" || exit 1
@@
-    $@ \
+    "$@" \

Also applies to: 28-28


5-11: Unused config knobs — export or remove.

NUM_NODES/NUM_RUNS/NUM_MINUTES are unused here. If common.env relies on them, export; otherwise drop to silence SC2034.

-NUM_NODES=1
-STEPS_PER_RUN=100
-MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
-NUM_MINUTES=300
+export NUM_NODES=1
+export STEPS_PER_RUN=100
+export MAX_STEPS=100
+export NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+export NUM_MINUTES=300
tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh (2)

16-18: Quote expansions and guard cd; fix SC2068/SC2164.

-cd $PROJECT_ROOT
+cd "$PROJECT_ROOT" || exit 1
@@
-    $@ \
+    "$@" \

Also applies to: 28-28


5-11: Unused config knobs — export or remove.

Same as the 1n8g script.

-NUM_NODES=2
-STEPS_PER_RUN=100
-MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
-NUM_MINUTES=300
+export NUM_NODES=2
+export STEPS_PER_RUN=100
+export MAX_STEPS=100
+export NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+export NUM_MINUTES=300
tests/functional/distillation.sh (2)

43-44: Fix command argument expansion for proper handling of additional arguments.

The shellcheck warning is correct - unquoted array expansion can cause issues with arguments containing spaces or special characters.

Apply this fix to properly quote the array expansion:

-    $@ \
+    "$@" \

49-49: Consider adding proper error handling for metrics validation.

The script continues even if the metrics validation fails. Consider whether you want to exit with a non-zero status code when the assertion fails to signal test failure to CI/CD systems.

If test failure propagation is important, consider adding:

 uv run tests/check_metrics.py $JSON_METRICS \
-  'data["train/loss"]["3"] < 10.0'
+  'data["train/loss"]["3"] < 10.0' || exit 1
tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh (3)

28-29: Fix command argument expansion for proper handling of additional arguments.

Same issue as in the functional test - unquoted array expansion can cause problems with arguments containing spaces.

Apply this fix:

-    $@ \
+    "$@" \

16-16: Add error handling for directory change operation.

Shellcheck correctly identifies that cd without error handling can lead to commands being executed in the wrong directory if the cd fails.

Apply this fix to handle potential cd failure:

-cd $PROJECT_ROOT
+cd $PROJECT_ROOT || exit 1

6-10: Consider exporting or using the unused configuration variables.

The variables NUM_NODES, NUM_RUNS, and NUM_MINUTES are defined but never used in the script. Either use them in the distillation command or remove them to avoid confusion.

If these are intended for future use or documentation purposes, consider adding a comment:

# Configuration for multi-node setup (for reference)
NUM_NODES=2  # Used by cluster orchestration
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # For multi-run scenarios  
NUM_MINUTES=3000  # Time budget for the experiment
tests/unit/algorithms/test_loss_functions.py (1)

1806-1806: Use idiomatic boolean comparisons instead of equality checks.

Static analysis correctly identifies non-idiomatic boolean comparisons in the assertions.

Apply these fixes:

-    assert loss_fn.zero_outside_topk == False
+    assert not loss_fn.zero_outside_topk
-    assert loss_fn.zero_outside_topk == True
+    assert loss_fn.zero_outside_topk

Also applies to: 1821-1821

tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.sh (3)

28-29: Fix command argument expansion for proper handling of additional arguments.

Same issue as the other test scripts - unquoted array expansion can cause problems.

Apply this fix:

-    $@ \
+    "$@" \

16-16: Add error handling for directory change operation.

Same as the 2n8g script - cd without error handling is risky.

Apply this fix:

-cd $PROJECT_ROOT
+cd $PROJECT_ROOT || exit 1

6-10: Configuration inconsistency between nodes and comment.

The comment says "4 nodes (4n8g)" but NUM_NODES=4 is defined but never used in the script.

Either use the NUM_NODES variable in your orchestration or add a comment explaining why it's defined but not used:

NUM_NODES=4  # Expected by cluster orchestrator (not used directly in this script)
tests/unit/algorithms/test_distillation.py (3)

89-91: Remove unused self parameter from nested functions.

The static analysis correctly identifies that the self parameter in the nested functions train_iter and val_iter is unused since these are not methods of a class.

Apply this fix to remove the unused parameter:

-    def train_iter(self):
+    def train_iter():
         return iter([mock_batch] * 10)

-    train_dataloader.__iter__ = train_iter
+    train_dataloader.__iter__ = lambda: train_iter()
-    def val_iter(self):
+    def val_iter():
         return iter([mock_batch] * 10)

-    val_dataloader.__iter__ = val_iter  
+    val_dataloader.__iter__ = lambda: val_iter()

Also applies to: 97-101


214-214: Consider adding more comprehensive validation assertions.

The test only verifies that train is called 5 times but doesn't validate other important aspects like teacher inference calls or data processing.

Consider adding more assertions to ensure the distillation workflow is functioning correctly:

# Verify teacher was called for top-k logits
assert mock_components["teacher_policy"].get_topk_logits.call_count == 5

# Verify proper preparation methods were called
assert mock_components["teacher_policy"].prepare_for_lp_inference.call_count == 5
assert mock_components["student_policy"].prepare_for_training.call_count == 5

56-56: Document the None value for student_generation.

Setting student_generation = None is a clever way to avoid Ray-related refit issues in tests, but this deserves more explanation for future maintainers.

Expand the comment to be more descriptive:

# Set student_generation to None to trigger the NEED_REFIT = False path
# in distillation_train. This avoids calling refit_policy_generation which
# would require Ray to be properly initialized. In this case, student_policy
# will be used directly as the generation interface (Megatron backend).
student_generation = None
tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh (1)

1-42: Fix shell script issues for production readiness.

Several shell scripting issues need to be addressed:

  1. Variables declared but not used
  2. Missing error handling for cd command
  3. Improper quoting of array expansion

Apply this diff to fix the issues:

 #!/bin/bash
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
 source $SCRIPT_DIR/common.env

 # ===== BEGIN CONFIG =====
-NUM_NODES=1
+export NUM_NODES=1  # Export if used by common.env or child processes
 STEPS_PER_RUN=100
 MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
-NUM_MINUTES=300
+# NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Unused - remove or export if needed
+export NUM_MINUTES=300  # Export if used by scheduling/monitoring
 # ===== END CONFIG =====

 exit_if_max_steps_reached

 # Run the experiment
-cd $PROJECT_ROOT
+cd $PROJECT_ROOT || exit 1
 uv run examples/run_distillation_math.py \
     --config $CONFIG_PATH \
     distillation.max_num_steps=$MAX_STEPS \
     logger.log_dir=$LOG_DIR \
     logger.wandb_enabled=True \
     logger.wandb.project=nemo-rl-distillation \
     logger.wandb.name=$EXP_NAME \
     logger.monitor_gpus=True \
     logger.tensorboard_enabled=True \
     checkpointing.enabled=True \
     checkpointing.checkpoint_dir=$CKPT_DIR \
-    $@ \
+    "$@" \
     2>&1 | tee $RUN_LOG
tests/test_distillation_simple.py (2)

140-153: Remove unused self parameter in nested functions.

The nested functions train_iter and val_iter don't use self and should be refactored.

Convert to lambda functions for clarity:

-    def train_iter(self):
-        return iter([mock_batch] * 10)
-
-    train_dataloader.__iter__ = train_iter
+    train_dataloader.__iter__ = lambda: iter([mock_batch] * 10)
     train_dataloader.__len__ = MagicMock(return_value=10)

     val_dataloader = MagicMock(spec=StatefulDataLoader)
-
-    def val_iter(self):
-        return iter([mock_batch] * 5)
-
-    val_dataloader.__iter__ = val_iter
+    val_dataloader.__iter__ = lambda: iter([mock_batch] * 5)
     val_dataloader.__len__ = MagicMock(return_value=5)

158-172: Remove unused variables.

Several variables are created but never used in the test, making it unclear what is being tested.

Either use these variables in assertions or remove them:

-    loss_fn = DistillationLossFn(
+    # Test loss function initialization
+    loss_fn = DistillationLossFn(
         {
             "temperature": 1.0,
             "alpha": 0.5,
             "kl_type": "forward",
             "mixed_kl_weight": 0.5,
             "zero_outside_topk": False,
         }
     )
+    assert loss_fn.temperature == 1.0
+    assert loss_fn.alpha == 0.5

-    logger = MagicMock()
-    checkpointer = MagicMock()
+    # These would be used in actual distillation training
+    # logger = MagicMock()
+    # checkpointer = MagicMock()
nemo_rl/models/policy/dtensor_policy_worker.py (2)

1185-1524: Well-structured distributed top-k implementation with minor improvements needed.

The get_topk_logits method is well-implemented with proper support for TP, CP, dynamic batching, and sequence packing. The distributed top-k computation and gathering logic is correct.

A few minor improvements:

  1. Remove duplicate code block (lines 1452-1489 duplicate lines 1413-1450)
  2. Fix unused variable at line 1335
  3. Remove duplicate model.eval() at line 1213

Apply this diff:

@@ -1210,7 +1210,6 @@ class DTensorPolicyWorker:
         self.model.eval()
         out_topk_vals = []
         out_topk_idx = []
-        self.model.eval()
 
         with unshard_fsdp2_model(self.model), torch.no_grad():
@@ -1333,13 +1332,6 @@ class DTensorPolicyWorker:
                     # IMPORTANT: do not apply generation temperature scaling here for teacher top-k
 
                     if self.cp_size > 1:
-                        seq_index_tensor = (
-                            DTensor.from_local(
-                                seq_index,
-                                device_mesh=self.cp_mesh,
-                                placements=[Shard(1)],
-                            )
-                            .full_tensor()
-                            .squeeze(0)
-                        )
+                        # seq_index_tensor would be used if we needed CP-aware indexing
+                        # Currently handled internally by distributed_vocab_topk
 
                         if isinstance(logits, DTensor):
@@ -1450,40 +1442,6 @@ class DTensorPolicyWorker:
                     batch_size = original_batch_size
                     seq_len = original_seq_len
 
-                # Handle sequence packing unpacking
-                if self.enable_seq_packing:
-                    # Unpack top-k results from packed format back to original batch format
-                    # ... duplicate block removed ...
-
                 # Keep only real sequence tokens (mask padded positions)

1243-1245: Remove unused loop variable.

The batch_idx variable is not used within the loop body.

-            for batch_idx, lp_batch in enumerate(
+            for _, lp_batch in enumerate(
                 itertools.chain(mb_iterator, dummy_iterator)
             ):
nemo_rl/models/policy/lm_policy.py (1)

388-395: Consider adding strict parameter to zip for safety.

While the current implementation ensures matching lengths through the worker dispatch mechanism, adding strict=True would make this guarantee explicit.

         all_topk_indices = [wb["topk_indices"] for wb in worker_batches]
 
         stacked: BatchedDataDict[TopkLogitsOutputSpec] = BatchedDataDict()
         stacked["topk_logits"] = torch.cat(all_topk_logits, dim=0)
         stacked["topk_indices"] = torch.cat(all_topk_indices, dim=0)

Note: The zip() warning from static analysis at line 1501 in dtensor_policy_worker.py should also be addressed:

-        for vals, idx in zip(out_topk_vals, out_topk_idx):
+        for vals, idx in zip(out_topk_vals, out_topk_idx, strict=True):
examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.yaml (1)

86-93: Optional: Cap generation length to avoid pathological rollouts

max_new_tokens: ${..max_total_sequence_length} can make on-policy rollouts as long as the full train length (8192). Consider capping generation (e.g., 1024–2048) for faster iterations unless long reasoning chains are explicitly required.

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1)

62-69: Optimizer defaults OK for small steps

AdamW settings look fine for short runs; if you extend training, consider scaling the scheduler's warmup length with max_num_steps.
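One way to make warmup track the run length — a hypothetical helper for illustration, not part of the repo:

```python
def scaled_warmup_steps(max_num_steps: int, warmup_frac: float = 0.1, minimum: int = 10) -> int:
    # Warmup grows with the total step budget instead of staying a fixed constant.
    return max(minimum, int(max_num_steps * warmup_frac))

assert scaled_warmup_steps(100) == 10
assert scaled_warmup_steps(1000) == 100
assert scaled_warmup_steps(50) == 10  # floor keeps very short runs stable
```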

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml (1)

86-101: Optional: Consider a smaller generation cap for throughput

max_new_tokens: 32768 will be very slow. If you don't need full-length generations during training, cap to a smaller value and reserve 32k for eval.

examples/configs/distillation_math.yaml (3)

64-69: Comment contradicts formula for make_sequence_length_divisible_by

Comment says “must be divisible by 2*cp”, but the expression enforces 2 * tp * cp. Align the comment or the expression.

Apply one of:

  • Update comment:
-    # must be divisible by 2*cp
+    # must be divisible by 2*tp*cp
  • Or change the formula if only 2*cp is intended:
-    make_sequence_length_divisible_by: ${mul:${mul:${.dtensor_cfg.tensor_parallel_size}, ${.dtensor_cfg.context_parallel_size}}, 2}
+    make_sequence_length_divisible_by: ${mul:${.dtensor_cfg.context_parallel_size}, 2}

171-171: Add trailing newline (yamllint new-line-at-end-of-file)

Ends without a newline; add one to satisfy linters.

-    num_nodes: 1
+    num_nodes: 1
+

42-50: Style consistency: Booleans use capitalized YAML literals

Minor: elsewhere you use lowercase false/true. Consider normalizing for consistency.

examples/run_distillation_math.py (5)

78-96: Over-verbose prompt-file path handling, unused variable, and blind except

  • content is unused.
  • Blind except Exception and re-raising without from.
  • This block still always raises when task_data_spec.prompt is None; reading the file adds no value. TaskDataSpec already loads/validates files.

Apply this simplification:

-    if task_data_spec.prompt is None:
-        if task_data_spec.prompt_file:
-            if os.path.exists(task_data_spec.prompt_file):
-                try:
-                    with open(task_data_spec.prompt_file, "r", encoding="utf-8") as f:
-                        content = f.read()
-                except Exception as e:
-                    raise ValueError(f"Failed to read file: {e}")
-            else:
-                raise ValueError(
-                    f"Prompt file does not exist: {task_data_spec.prompt_file}"
-                )
-
-        raise ValueError(
-            f"TaskDataSpec.prompt is None. This usually means the prompt file "
-            f"'{task_data_spec.prompt_file}' could not be loaded or is empty. "
-            f"Current working directory: {os.getcwd()}, "
-            f"Absolute prompt file path: {os.path.abspath(task_data_spec.prompt_file) if task_data_spec.prompt_file else 'None'}"
-        )
+    if task_data_spec.prompt is None:
+        pf = task_data_spec.prompt_file
+        raise FileNotFoundError(f"Prompt is missing (prompt_file={pf})")

104-118: Type mismatch: apply_chat_template returns str, not list[str]

message is a string. Update the annotation and name to avoid confusion.

-    message: list[str] = tokenizer.apply_chat_template(  # type: ignore
+    rendered: str = tokenizer.apply_chat_template(  # type: ignore
         [user_message],
         tokenize=False,
         add_generation_prompt=True,
         add_special_tokens=False,
     )
-    user_message["token_ids"] = tokenizer(
-        message,
+    user_message["token_ids"] = tokenizer(
+        rendered,
         return_tensors="pt",
         add_special_tokens=False,
     )["input_ids"][0]
-    user_message["content"] = message
+    user_message["content"] = rendered

253-259: Hardcoded seed ignores config

setup_data(..., 42) ignores distillation.seed. Use the configured seed.

-    ) = setup_data(tokenizer, config["data"], config["env"], 42)
+    ) = setup_data(
+        tokenizer,
+        config["data"],
+        config["env"],
+        seed=config["distillation"]["seed"],
+    )

130-140: Truncation policy silently zeroes loss after clipping

Setting loss_multiplier = 0.0 when truncation occurs silently masks training on that sample. Consider logging a counter or sampling fewer long prompts to reduce silent sample drops.

Would you like a small metric hook to log “truncated_samples” per step?
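A minimal sketch of such a counter (hypothetical names; the real metrics plumbing lives in the trainer):

```python
def count_truncated(loss_multipliers: list[float]) -> int:
    # A sample whose loss_multiplier was zeroed by the truncation policy
    # contributes nothing to the gradient; count these so drops are visible.
    return sum(1 for m in loss_multipliers if m == 0.0)

batch_multipliers = [1.0, 0.0, 1.0, 0.0, 1.0]
metrics = {"truncated_samples": count_truncated(batch_multipliers)}
assert metrics["truncated_samples"] == 2
```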


187-195: Ray actor runtime_env depends on registry entry

This call will raise if the env is not registered. If you expect users to run locally without pre-registration, add a fallback to PY_EXECUTABLES.SYSTEM.

Would you like a guarded fallback here?

examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.yaml (3)

86-101: Cap generation length or add stop criteria to avoid runaway 32K token generations

max_new_tokens equals the full model context (32768) with no stop_strings/stop_token_ids. This can lead to extreme runtimes/OOM when prompts are short. Consider capping or wiring task-specific stops.

Apply one of:

-    max_new_tokens: ${..max_total_sequence_length}
+    # Option A: cap hard
+    max_new_tokens: 8192
+    # Option B: make configurable via env var fallback
+    max_new_tokens: ${oc.env:MAX_NEW_TOKENS,8192}
+    # Also add task-specific stops if available
+    stop_strings:
+      - "Answer:"
+      - "</final_answer>"

95-101: Consider raising vLLM memory utilization for throughput

gpu_memory_utilization: 0.6 is conservative and may underutilize GPUs on long runs. 0.80–0.90 is typical when KV cache is the dominant consumer.

-      gpu_memory_utilization: 0.6
+      gpu_memory_utilization: 0.85

49-61: Sequence packing remains disabled while dynamic batching is enabled

Given long-seq runs, enabling packing can reduce padding costs during training/logprob passes.

-  sequence_packing:
-    enabled: false
+  sequence_packing:
+    enabled: true
     train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
     logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}}
     algorithm: "modified_first_fit_decreasing"
     sequence_length_round: 64
nemo_rl/algorithms/loss_functions.py (3)

885-902: Guard against missing teacher inputs with a clear error

teacher_topk_logits/teacher_topk_indices are accessed unconditionally. Add validation to fail fast with actionable messaging.

-        teacher_topk_logits = data["teacher_topk_logits"]  # [B, S, k]
-        teacher_topk_indices = data["teacher_topk_indices"]  # [B, S, k]
+        if "teacher_topk_logits" not in data or "teacher_topk_indices" not in data:
+            raise KeyError(
+                "DistillationLossFn requires 'teacher_topk_logits' and 'teacher_topk_indices' in data."
+            )
+        teacher_topk_logits = data["teacher_topk_logits"]  # [B, S, k]
+        teacher_topk_indices = data["teacher_topk_indices"]  # [B, S, k]

930-976: Support for zero_outside_topk already exists under TP/CP; keep it documented

The implementation handles TP/CP via chunked ops; make this explicit in class docstring to avoid restrictive external assertions.

 class DistillationLossFn(LossFunction):
-    """Distillation loss function."""
+    """Distillation loss function.
+
+    Supports TP/CP via ChunkedDistributedGatherLogprob and ChunkedDistributedEntropy.
+    When zero_outside_topk=True, reverse/mixed KL include the missing-mass correction.
+    """

1002-1011: Move teacher tensors to the student device once and preserve dtype

Tiny micro-optimization: store device/dtype before branching.

-        if self.zero_outside_topk:
-            teacher_topk_logits = teacher_topk_logits.to(
-                student_topk_logprobs.device, dtype=student_topk_logprobs.dtype
-            )
-        else:
-            teacher_topk_logits = teacher_topk_logits.to(
-                student_topk_logits.device, dtype=student_topk_logits.dtype
-            )
+        tgt_device = (student_topk_logprobs if self.zero_outside_topk else student_topk_logits).device
+        tgt_dtype = (student_topk_logprobs if self.zero_outside_topk else student_topk_logits).dtype
+        teacher_topk_logits = teacher_topk_logits.to(tgt_device, dtype=tgt_dtype)
examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml (2)

62-63: Divisibility set to 2; consider 8 or 16 for tensor core friendliness

A higher multiple often aligns better with attention kernels on modern GPUs.

-  make_sequence_length_divisible_by: 2
+  make_sequence_length_divisible_by: 8
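The padding cost of a higher multiple is small; a quick sketch of the rounding involved (hypothetical helper):

```python
def pad_to_multiple(seq_len: int, multiple: int) -> int:
    # Round seq_len up to the nearest multiple (what divisibility padding does).
    return ((seq_len + multiple - 1) // multiple) * multiple

assert pad_to_multiple(1000, 8) == 1000  # already aligned, no padding
assert pad_to_multiple(1001, 8) == 1008  # at most multiple-1 pad tokens
assert pad_to_multiple(1001, 2) == 1002  # a divisor of 2 pads less but aligns worse
```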

86-113: Same generation safeguards as long recipe

Recommend adding stop criteria and/or lower max_new_tokens here as well.

-    max_new_tokens: ${..max_total_sequence_length}
+    max_new_tokens: 4096
+    stop_strings:
+      - "</final_answer>"
nemo_rl/algorithms/distillation.py (4)

590-596: Duplicate offload call

offload_after_refit() is invoked twice back‑to‑back. Keep one.

-                with timer.time("logprob_inference_prep"):
-                    teacher_policy.offload_after_refit()
-
-                print("▶ Preparing for training...")
-                with timer.time("training_prep"):
-                    teacher_policy.offload_after_refit()
+                with timer.time("logprob_inference_prep"):
+                    teacher_policy.offload_after_refit()
+
+                print("▶ Preparing for training...")
+                with timer.time("training_prep"):

136-143: Return signature/docstring drift

Docstring mentions clusters but function returns (student_policy, teacher_policy, student_generation, dataloader, val_dataloader, loss_fn, logger, checkpointer, save_state, master_config). Update docstring to avoid confusion.

-    Returns:
-        tuple of student_policy, student_generation,
-        (train_cluster, inference_cluster), train_dataloader, val_dataloader,
-        loss_fn, logger, checkpointer, distillation_save_state, master_config
+    Returns (in order):
+        student_policy, teacher_policy, student_generation,
+        train_dataloader, val_dataloader, loss_fn, logger,
+        checkpointer, distillation_save_state, master_config

666-685: Checkpoint metric presence warning: add stacklevel and explicit metric fallback

Noise reduction in logs and clearer fallback.

-                            warnings.warn(
-                                f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                                "Saving most recent k checkpoints instead."
-                            )
+                            warnings.warn(
+                                (
+                                    "Checkpoint metric "
+                                    f"{master_config['checkpointing']['metric_name']} not found in save state; "
+                                    "saving most recent k checkpoints instead."
+                                ),
+                                stacklevel=2,
+                            )

857-868: Nit: use an explicit conversion flag in the f-string

Prefer {e!s} over str(e) inside f-strings (Ruff RUF010).

-        except Exception as e:
-            print(f"\n  ⚠️ Error displaying message samples: {str(e)}")
+        except Exception as e:
+            print(f"\n  ⚠️ Error displaying message samples: {e!s}")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5a9f7ac and eae6290.

📒 Files selected for processing (25)
  • examples/configs/distillation_math.yaml (1 hunks)
  • examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml (1 hunks)
  • examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1 hunks)
  • examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml (1 hunks)
  • examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.yaml (1 hunks)
  • examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml (1 hunks)
  • examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.yaml (1 hunks)
  • examples/run_distillation_math.py (1 hunks)
  • nemo_rl/algorithms/distillation.py (1 hunks)
  • nemo_rl/algorithms/loss_functions.py (6 hunks)
  • nemo_rl/distributed/batched_data_dict.py (1 hunks)
  • nemo_rl/distributed/model_utils.py (2 hunks)
  • nemo_rl/models/policy/dtensor_policy_worker.py (3 hunks)
  • nemo_rl/models/policy/interfaces.py (2 hunks)
  • nemo_rl/models/policy/lm_policy.py (3 hunks)
  • tests/functional/distillation.sh (1 hunks)
  • tests/test_distillation_simple.py (1 hunks)
  • tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh (1 hunks)
  • tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.sh (1 hunks)
  • tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh (1 hunks)
  • tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh (1 hunks)
  • tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh (1 hunks)
  • tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.sh (1 hunks)
  • tests/unit/algorithms/test_distillation.py (1 hunks)
  • tests/unit/algorithms/test_loss_functions.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (16)
tests/unit/algorithms/test_distillation.py (4)
nemo_rl/algorithms/distillation.py (3)
  • _default_distillation_save_state (92-97)
  • distillation_train (390-773)
  • validate (776-888)
nemo_rl/algorithms/loss_functions.py (1)
  • DistillationLossFn (849-1104)
nemo_rl/data/interfaces.py (1)
  • DatumSpec (32-40)
nemo_rl/distributed/batched_data_dict.py (1)
  • BatchedDataDict (75-857)
tests/test_distillation_simple.py (2)
nemo_rl/algorithms/loss_functions.py (1)
  • DistillationLossFn (849-1104)
tests/unit/algorithms/test_distillation.py (2)
  • train_iter (89-90)
  • val_iter (97-98)
examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.yaml (1)
tests/unit/models/generation/test_vllm_generation.py (2)
  • test_vllm_weight_update_and_prefix_cache_reset (1046-1158)
  • test_vllm_megatron_weight_update_with_packing (1778-1832)
nemo_rl/distributed/model_utils.py (1)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)
  • get_logprobs (835-1131)
nemo_rl/algorithms/loss_functions.py (3)
nemo_rl/distributed/model_utils.py (6)
  • ChunkedDistributedEntropy (984-1055)
  • ChunkedDistributedGatherLogprob (258-381)
  • _get_tokens_on_this_cp_rank (669-703)
  • allgather_cp_sharded_tensor (706-709)
  • from_parallel_logits_to_logprobs (469-541)
  • gather_logits_at_global_indices (899-981)
nemo_rl/algorithms/interfaces.py (2)
  • LossFunction (28-70)
  • LossType (23-25)
nemo_rl/algorithms/utils.py (1)
  • masked_mean (128-140)
nemo_rl/models/policy/interfaces.py (4)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
  • get_topk_logits (1186-1524)
nemo_rl/models/policy/lm_policy.py (1)
  • get_topk_logits (336-400)
nemo_rl/distributed/batched_data_dict.py (1)
  • BatchedDataDict (75-857)
nemo_rl/models/generation/interfaces.py (1)
  • GenerationDatumSpec (127-158)
nemo_rl/models/policy/dtensor_policy_worker.py (3)
nemo_rl/distributed/model_utils.py (3)
  • allgather_cp_sharded_tensor (706-709)
  • distributed_vocab_topk (829-896)
  • get_logprobs_from_vocab_parallel_logits (778-825)
nemo_rl/utils/nsys.py (1)
  • wrap_with_nvtx_name (82-94)
nemo_rl/distributed/batched_data_dict.py (3)
  • BatchedDataDict (75-857)
  • to (822-829)
  • size (811-820)
tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh (1)
tests/unit/test_recipes_and_test_suites.py (1)
  • test_dry_run_does_not_fail_and_prints_total_gpu_hours (193-221)
nemo_rl/algorithms/distillation.py (12)
nemo_rl/algorithms/grpo.py (2)
  • _should_use_async_rollouts (406-420)
  • refit_policy_generation (423-491)
nemo_rl/algorithms/loss_functions.py (3)
  • DistillationLossConfig (831-836)
  • DistillationLossDataDict (839-846)
  • DistillationLossFn (849-1104)
nemo_rl/algorithms/utils.py (1)
  • set_seed (143-148)
nemo_rl/data/datasets.py (2)
  • AllTaskProcessedDataset (36-131)
  • rl_collate_fn (134-178)
nemo_rl/data/llm_message_utils.py (2)
  • batched_message_log_to_flat_message (233-390)
  • get_keys_from_message_log (126-138)
nemo_rl/distributed/batched_data_dict.py (5)
  • BatchedDataDict (75-857)
  • repeat_interleave (721-742)
  • size (811-820)
  • get_multimodal_dict (88-99)
  • to (822-829)
nemo_rl/distributed/virtual_cluster.py (2)
  • ClusterConfig (32-34)
  • RayVirtualCluster (168-410)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • VllmGeneration (47-784)
nemo_rl/models/policy/interfaces.py (6)
  • ColocatablePolicyInterface (126-155)
  • prepare_refit_info (142-143)
  • get_topk_logits (78-87)
  • offload_after_refit (138-139)
  • prepare_for_training (110-111)
  • train (90-107)
nemo_rl/models/policy/lm_policy.py (9)
  • Policy (56-696)
  • finish_generation (554-556)
  • prepare_refit_info (562-571)
  • prepare_for_generation (539-541)
  • prepare_for_lp_inference (548-552)
  • get_topk_logits (336-400)
  • offload_after_refit (650-653)
  • prepare_for_training (543-546)
  • train (402-494)
nemo_rl/utils/checkpoint.py (5)
  • CheckpointManager (55-269)
  • get_latest_checkpoint_path (238-251)
  • load_training_info (253-269)
  • init_tmp_checkpoint (87-126)
  • finalize_checkpoint (128-157)
nemo_rl/utils/logger.py (3)
  • Logger (710-933)
  • print_message_log_samples (1022-1219)
  • log_batched_dict_as_jsonl (804-828)
examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.yaml (1)
tests/unit/models/policy/test_dtensor_worker.py (1)
  • policy_setup (134-147)
examples/run_distillation_math.py (13)
nemo_rl/algorithms/distillation.py (3)
  • MasterConfig (100-111)
  • distillation_train (390-773)
  • setup (119-382)
nemo_rl/algorithms/utils.py (1)
  • get_tokenizer (151-267)
nemo_rl/data/__init__.py (1)
  • DataConfig (18-35)
nemo_rl/data/datasets.py (1)
  • AllTaskProcessedDataset (36-131)
nemo_rl/data/hf_datasets/deepscaler.py (1)
  • DeepScalerDataset (67-78)
nemo_rl/data/hf_datasets/openmathinstruct2.py (1)
  • OpenMathInstruct2Dataset (77-105)
nemo_rl/data/interfaces.py (3)
  • DatumSpec (32-40)
  • TaskDataProcessFnCallable (89-100)
  • TaskDataSpec (53-86)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
  • get_actor_python_env (43-58)
nemo_rl/distributed/virtual_cluster.py (1)
  • init_ray (75-161)
nemo_rl/environments/math_environment.py (1)
  • MathEnvironment (222-372)
nemo_rl/utils/config.py (2)
  • load_config (74-137)
  • parse_hydra_overrides (146-166)
nemo_rl/utils/logger.py (1)
  • get_next_experiment_dir (1222-1256)
nemo_rl/models/generation/__init__.py (1)
  • configure_generation_config (24-45)
tests/unit/algorithms/test_loss_functions.py (1)
nemo_rl/algorithms/loss_functions.py (1)
  • DistillationLossFn (849-1104)
nemo_rl/distributed/batched_data_dict.py (2)
nemo_rl/data/llm_message_utils.py (1)
  • batched_message_log_to_flat_message (233-390)
tests/unit/distributed/test_batched_data_dict.py (2)
  • test_sequence_packing_microbatch_boundaries (550-683)
  • test_sequence_packing_basic (246-327)
examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1)
nemo_rl/models/policy/dtensor_policy_worker_v2.py (1)
  • __init__ (108-436)
tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh (1)
tests/unit/test_recipes_and_test_suites.py (1)
  • test_dry_run_does_not_fail_and_prints_total_gpu_hours (193-221)
nemo_rl/models/policy/lm_policy.py (3)
nemo_rl/models/policy/interfaces.py (2)
  • TopkLogitsOutputSpec (37-41)
  • get_topk_logits (78-87)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
  • get_topk_logits (1186-1524)
nemo_rl/distributed/batched_data_dict.py (1)
  • BatchedDataDict (75-857)
🪛 Shellcheck (0.10.0)
tests/functional/distillation.sh

[error] 43-43: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 28-28: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 28-28: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 28-28: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 28-28: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 28-28: Double quote array expansions to avoid re-splitting elements.

(SC2068)

tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh

[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).

(SC2034)


[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


[error] 28-28: Double quote array expansions to avoid re-splitting elements.

(SC2068)

🪛 Ruff (0.12.2)
tests/unit/algorithms/test_distillation.py

89-89: Unused function argument: self

(ARG001)


97-97: Unused function argument: self

(ARG001)

tests/test_distillation_simple.py

1-1: Shebang is present but file is not executable

(EXE001)


140-140: Unused function argument: self

(ARG001)


148-148: Unused function argument: self

(ARG001)


158-158: Local variable loss_fn is assigned to but never used

Remove assignment to unused variable loss_fn

(F841)


168-168: Local variable logger is assigned to but never used

Remove assignment to unused variable logger

(F841)


169-169: Local variable checkpointer is assigned to but never used

Remove assignment to unused variable checkpointer

(F841)


171-171: Local variable task_to_env is assigned to but never used

Remove assignment to unused variable task_to_env

(F841)


172-172: Local variable val_task_to_env is assigned to but never used

Remove assignment to unused variable val_task_to_env

(F841)


188-188: Consider moving this statement to an else block

(TRY300)


190-190: Do not catch blind exception: Exception

(BLE001)

nemo_rl/algorithms/loss_functions.py

865-865: Unused method argument: global_valid_seqs

(ARG002)


1055-1055: Avoid specifying long messages outside the exception class

(TRY003)


1073-1073: Avoid specifying long messages outside the exception class

(TRY003)

nemo_rl/models/policy/dtensor_policy_worker.py

1243-1243: Loop control variable batch_idx not used within loop body

Rename unused batch_idx to _batch_idx

(B007)


1335-1335: Local variable seq_index_tensor is assigned to but never used

Remove assignment to unused variable seq_index_tensor

(F841)


1501-1501: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

nemo_rl/algorithms/distillation.py

266-269: Avoid specifying long messages outside the exception class

(TRY003)


271-274: Avoid specifying long messages outside the exception class

(TRY003)


534-534: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


535-535: Loop control variable j not used within loop body

Rename unused j to _j

(B007)


660-660: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


866-866: Do not catch blind exception: Exception

(BLE001)


867-867: Use explicit conversion flag

Replace with conversion flag

(RUF010)

examples/run_distillation_math.py

83-83: Local variable content is assigned to but never used

Remove assignment to unused variable content

(F841)


84-84: Do not catch blind exception: Exception

(BLE001)


85-85: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


87-89: Avoid specifying long messages outside the exception class

(TRY003)


91-96: Avoid specifying long messages outside the exception class

(TRY003)


106-106: Do not catch blind exception: Exception

(BLE001)


107-107: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


107-107: Avoid specifying long messages outside the exception class

(TRY003)


180-180: Avoid specifying long messages outside the exception class

(TRY003)

tests/unit/algorithms/test_loss_functions.py

1806-1806: Avoid equality comparisons to False; use not loss_fn.zero_outside_topk: for false checks

Replace with not loss_fn.zero_outside_topk

(E712)


1821-1821: Avoid equality comparisons to True; use loss_fn.zero_outside_topk: for truth checks

Replace with loss_fn.zero_outside_topk

(E712)

nemo_rl/distributed/batched_data_dict.py

134-134: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


134-134: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


135-135: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


136-136: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


136-136: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)


136-136: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


138-138: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


138-138: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


138-138: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)


138-138: Comment contains ambiguous ， (FULLWIDTH COMMA). Did you mean , (COMMA)?

(RUF003)


146-146: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


146-146: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

🪛 GitHub Actions: Copyright check
tests/test_distillation_simple.py

[error] 1-1: Found files with missing copyright: path= ./tests/test_distillation_simple.py

🪛 YAMLlint (1.37.1)
examples/configs/distillation_math.yaml

[error] 171-171: no new line character at the end of file

(new-line-at-end-of-file)

🔇 Additional comments (14)
nemo_rl/distributed/model_utils.py (1)

899-982: gather_logits_at_global_indices: CP padding/gathering mirrors existing logprob path; confirm 2×CP chunking expectation.

_get_tokens_on_this_cp_rank assumes inputs are divisible by 2*cp_size. Here pad_len uses cp_size only (same as from_parallel_logits_to_logprobs). Please confirm seq length invariants (and tests) cover this for global_indices; otherwise compute pad_len to satisfy 2*cp_size divisibility.

nemo_rl/models/policy/interfaces.py (1)

37-42: TopkLogitsOutputSpec looks good.

Shape contracts are explicit and match downstream usage.

nemo_rl/models/policy/dtensor_policy_worker.py (1)

799-801: LGTM! Clean implementation of context parallel group passing.

The addition of the context_parallel_group parameter to the loss function when CP is enabled is correct and follows the established pattern for distributed training.

nemo_rl/models/policy/lm_policy.py (1)

336-400: Clean implementation of top-k logits API.

The get_topk_logits method properly handles sharding, dispatching to workers, and result aggregation. The manual tensor concatenation approach correctly preserves the 3D tensor shape for top-k results.

tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh (1)

1-42: Fix duplicate shell-script issues and confirm aggressive loss target

  • Apply the same fixes as the 1.7B test script to remove duplicated shell-script problems.
  • Files to update:
    • tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh
    • tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh
  • Both files assert 'data["train/loss"]["200"] < 1.0' — verify from past run logs that <1.0 at step 200 is achievable for these configs; if not, relax the threshold to 1.5 to match the 1.7B baseline.
examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1)

40-47: LGTM: Teacher CP with no sequence packing/parallel is valid

Teacher uses CP=2 with sequence_packing.disabled and sequence_parallel=false in the base; this matches DTensor worker constraints.

Also applies to: 55-61, 114-121

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml (2)

37-45: LGTM: Long-context recipe respects CP/SP constraints

Sequence packing disabled; teacher CP=2 without sequence_parallel. Looks consistent for long-context runs.

Also applies to: 55-61, 114-121


74-85: Scheduler shape matches 1k steps

Linear warmup 100 then cosine to 900 fits max_num_steps: 1000. LGTM.

examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml (1)

40-47: LGTM: Conservative TP/CP settings with packing disabled

Teacher CP=2, packing and sequence_parallel disabled; aligns with worker constraints.

Also applies to: 55-61, 114-121

examples/run_distillation_math.py (3)

246-251: Warn-only when generation config missing

OK to proceed, but most setups require generation. You may want to error out if on-policy is intended.

Do you ever run this script off-policy? If not, prefer raising when policy.generation is missing.


239-243: Tokenizer source: pass through tokenizer.chat_template via config

If a custom chat_template is provided in config, ensure get_tokenizer applies it (it does per utils). Just confirming this is intentional.


262-272: Good: Passing tokenizer into setup/distillation_train

This unblocks token-level loss paths and stop-string wiring.

nemo_rl/algorithms/loss_functions.py (1)

160-174: Casting logits to float32 for stability: good

Prevents numerical issues in softmax/log-softmax across losses.
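The underlying hazard is log-sum-exp overflow in reduced precision; the max-subtraction identity that stable softmax implementations rely on can be sketched in plain Python (illustrative, not the project code):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax via max subtraction."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

# exp(1000.0) alone would overflow; the shifted form does not
out = log_softmax([1000.0, 999.0, 998.0])
assert abs(sum(math.exp(v) for v in out) - 1.0) < 1e-12
```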

examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml (1)

74-85: Scheduler chain: ensure milestones pairs with SequentialLR

If your scheduler builder expects a SequentialLR wrapper, the lone - milestones: [20] entry is fine. If not, this list item without a name may be ignored. Please confirm the parser behavior.

Would you like me to patch the scheduler builder to explicitly wrap LinearLR + ConstantLR in SequentialLR(milestones=[20])?
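For reference, the schedule a SequentialLR wrap of LinearLR + ConstantLR with milestones=[20] would produce, sketched as closed-form math (the start factor and base LR below are illustrative values, not taken from the recipe):

```python
def lr_at(step: int, base_lr: float = 2e-5,
          milestone: int = 20, start_factor: float = 0.1) -> float:
    """LinearLR ramp over `milestone` steps, then ConstantLR at base_lr."""
    if step < milestone:
        frac = (step + 1) / milestone  # linear interpolation of the factor
        return base_lr * (start_factor + (1.0 - start_factor) * frac)
    return base_lr

assert lr_at(0) < lr_at(10) < lr_at(25) == 2e-5
```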

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (1)
tests/test_distillation_simple.py (1)

1-13: Header fixed; previous compliance issue resolved.

Copyright header added and the stray shebang removed.

🧹 Nitpick comments (6)
tests/test_distillation_simple.py (6)

21-21: Avoid mutating sys.path in tests.

Rely on standard import paths; this can mask packaging issues and cause brittle CI.

Apply this diff to remove the path hack:

-sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

31-35: Replace prints/returns with real assertions.

Ensure the test fails if imports break.

 def test_basic_imports():
-    """Test whether basic imports work"""
-    print("✅ Basic import test passed")
-    return True
+    """Basic import smoke test"""
+    assert DistillationLossFn is not None
+    assert torch is not None

149-165: Remove unused StatefulDataLoader mocks (and the import).

They’re unused and the current __iter__ stubs would be called without self binding if used. Also drops Ruff ARG001 warnings.

-from torchdata.stateful_dataloader import StatefulDataLoader
@@
-    # Create mock dataloaders
-    train_dataloader = MagicMock(spec=StatefulDataLoader)
-
-    def train_iter(self):
-        return iter([mock_batch] * 10)
-
-    train_dataloader.__iter__ = train_iter
-    train_dataloader.__len__ = MagicMock(return_value=10)
-
-    val_dataloader = MagicMock(spec=StatefulDataLoader)
-
-    def val_iter(self):
-        return iter([mock_batch] * 5)
-
-    val_dataloader.__iter__ = val_iter
-    val_dataloader.__len__ = MagicMock(return_value=5)

170-187: Exercise DistillationLossFn and assert outputs; drop unused locals.

Uses your mock batch to compute a real loss and removes unused variables flagged by Ruff.

     loss_fn = DistillationLossFn(
         {
             "temperature": 1.0,
             "alpha": 0.5,
             "kl_type": "forward",
             "mixed_kl_weight": 0.5,
             "zero_outside_topk": False,
         }
     )
-
-    logger = MagicMock()
-    checkpointer = MagicMock()
-
-    task_to_env = {"math": MagicMock()}
-    val_task_to_env = {"math": MagicMock()}
-
-    print("✅ Mock component creation test passed")
-    return True
+    # Compute loss with mock data
+    student_logits = student_generation.generate.return_value["logits"]
+    V = student_logits.shape[-1]
+    # Ensure teacher indices are within student vocab size
+    assert int(mock_batch["teacher_topk_indices"].max()) < V
+    global_valid_seqs = mock_batch["sample_mask"].sum().to(dtype=torch.int64)
+    global_valid_toks = mock_batch["token_mask"][:, 1:].sum().to(dtype=torch.int64)
+    loss, metrics = loss_fn(
+        student_logits,
+        mock_batch,
+        global_valid_seqs=global_valid_seqs,
+        global_valid_toks=global_valid_toks,
+    )
+    assert torch.isfinite(loss)
+    assert "loss" in metrics and "num_valid_samples" in metrics
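For context, the forward KL this loss reduces per token position, as a plain-Python sketch over a single small vocabulary (temperature handling simplified; not the project's `DistillationLossFn`):

```python
import math

def forward_kl(teacher_logits, student_logits, tau=1.0):
    """KL(p_teacher || p_student) for one token position."""
    def log_softmax(xs):
        m = max(xs)
        lse = m + math.log(sum(math.exp(x - m) for x in xs))
        return [x - lse for x in xs]
    t_lp = log_softmax([x / tau for x in teacher_logits])
    s_lp = log_softmax([x / tau for x in student_logits])
    return sum(math.exp(t) * (t - s) for t, s in zip(t_lp, s_lp))

# KL is zero when the distributions match, positive otherwise
assert forward_kl([1.0, 2.0], [1.0, 2.0]) < 1e-12
assert forward_kl([1.0, 2.0], [2.0, 1.0]) > 0.0
```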

190-212: Drop custom main harness in pytest tests.

Pytest will collect and run these; the try/except triggers Ruff TRY300/BLE001. Removing reduces noise.

-def main():
-    """Run all tests"""
-    print("Starting simplified distillation tests...")
-
-    try:
-        test_basic_imports()
-        test_config_creation()
-        test_mock_components()
-
-        print("\n🎉 All tests passed! Distillation test configuration is correct.")
-        return 0
-
-    except Exception as e:
-        print(f"\n❌ Test failed: {e}")
-        import traceback
-
-        traceback.print_exc()
-        return 1
-
-
-if __name__ == "__main__":
-    sys.exit(main())
+# Pytest will discover and run test_* functions automatically.

145-147: Optional: Add a second path test with zero_outside_topk=True.

Covers the log-prob path and entropy correction branch in DistillationLossFn.

I can draft a small test_distillation_loss_zero_outside_topk that reuses mock_batch with loss_fn = DistillationLossFn({..., "zero_outside_topk": True}) and asserts finite loss. Want me to add it?

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between eae6290 and c330f2d.

📒 Files selected for processing (1)
  • tests/test_distillation_simple.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/test_distillation_simple.py (1)
nemo_rl/algorithms/loss_functions.py (1)
  • DistillationLossFn (849-1104)
🪛 Ruff (0.12.2)
tests/test_distillation_simple.py

152-152: Unused function argument: self

(ARG001)


160-160: Unused function argument: self

(ARG001)


170-170: Local variable loss_fn is assigned to but never used

Remove assignment to unused variable loss_fn

(F841)


180-180: Local variable logger is assigned to but never used

Remove assignment to unused variable logger

(F841)


181-181: Local variable checkpointer is assigned to but never used

Remove assignment to unused variable checkpointer

(F841)


183-183: Local variable task_to_env is assigned to but never used

Remove assignment to unused variable task_to_env

(F841)


184-184: Local variable val_task_to_env is assigned to but never used

Remove assignment to unused variable val_task_to_env

(F841)


200-200: Consider moving this statement to an else block

(TRY300)


202-202: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
tests/test_distillation_simple.py (1)

37-95: Config shape checks look good.

Covers key sections and nested generation.colocated.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (1)
nemo_rl/algorithms/distillation.py (1)

242-273: Non‑colocated cluster resource math creates train_nodes/train_gpus_per_node = 0 (fatal).

  • Single-node: defaults set inference_nodes=1, which leaves train_nodes=0.
  • Multi-node: defaults set inference_gpus_per_node=gpus_per_node and then compute train_gpus_per_node = gpus_per_node - inference_gpus_per_node = 0.

This will fail cluster creation.

Minimal safe fix:

  • Disallow non‑colocated when num_nodes == 1.
  • In multi‑node mode, allocate full GPUs per node for each cluster and split nodes (not GPUs within a node). Guard that inference_nodes < num_nodes.
@@
-        inference_resources = generation_config["colocated"]["resources"]
-        inference_gpus_per_node = inference_resources["gpus_per_node"]
-        inference_nodes = inference_resources["num_nodes"]
+        inference_resources = generation_config["colocated"]["resources"]
+        inference_gpus_per_node = inference_resources.get("gpus_per_node")
+        inference_nodes = inference_resources.get("num_nodes")
@@
-        if cluster_config["num_nodes"] == 1:
-            if inference_gpus_per_node is None:
-                inference_gpus_per_node = cluster_config["gpus_per_node"] // 2
-            if inference_nodes is None:
-                inference_nodes = 1
+        if cluster_config["num_nodes"] == 1:
+            raise ValueError(
+                "Non-colocated inference requires num_nodes >= 2. "
+                "Set policy.generation.colocated.enabled: true for single-node runs."
+            )
         else:
             if inference_gpus_per_node is None:
                 inference_gpus_per_node = cluster_config["gpus_per_node"]
             if inference_nodes is None:
-                inference_nodes = cluster_config["num_nodes"] // 2
+                inference_nodes = max(1, cluster_config["num_nodes"] // 2)
@@
-        if inference_gpus_per_node > cluster_config["gpus_per_node"]:
+        if inference_gpus_per_node > cluster_config["gpus_per_node"]:
             raise ValueError(
                 f"Inference GPUs per node ({inference_gpus_per_node}) cannot be greater than "
                 f"total GPUs per node ({cluster_config['gpus_per_node']})"
             )
-        if inference_nodes > cluster_config["num_nodes"]:
+        if inference_nodes >= cluster_config["num_nodes"]:
             raise ValueError(
-                f"Inference nodes ({inference_nodes}) cannot be greater than "
-                f"total nodes ({cluster_config['num_nodes']})"
+                "Non-colocated mode must reserve at least 1 node for training: "
+                f"inference_nodes={inference_nodes}, total_nodes={cluster_config['num_nodes']}"
             )
@@
-        train_gpus_per_node = cluster_config["gpus_per_node"] - inference_gpus_per_node
-        train_nodes = cluster_config["num_nodes"] - inference_nodes
+        # Split by nodes, not by GPUs within a node.
+        train_gpus_per_node = cluster_config["gpus_per_node"]
+        train_nodes = cluster_config["num_nodes"] - inference_nodes

Also applies to: 275-291
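The proposed node-split arithmetic as a standalone sketch (hypothetical helper name; the real logic lives inside `setup()`):

```python
from typing import Optional

def split_noncolocated(num_nodes: int, gpus_per_node: int,
                       inference_nodes: Optional[int] = None) -> dict:
    """Split whole nodes between inference and training clusters."""
    if num_nodes == 1:
        raise ValueError("Non-colocated inference requires num_nodes >= 2")
    if inference_nodes is None:
        inference_nodes = max(1, num_nodes // 2)
    if inference_nodes >= num_nodes:
        raise ValueError("Must reserve at least 1 node for training")
    # Each cluster keeps full GPUs on its own nodes; no within-node GPU split.
    return {"inference": (inference_nodes, gpus_per_node),
            "train": (num_nodes - inference_nodes, gpus_per_node)}

assert split_noncolocated(4, 8) == {"inference": (2, 8), "train": (2, 8)}
assert split_noncolocated(2, 8)["train"] == (1, 8)
```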

🧹 Nitpick comments (5)
nemo_rl/algorithms/distillation.py (5)

528-539: Drop unused enumerate indices.

Removes B007 warnings and clarifies intent.

-                    for i, message_log in enumerate(repeated_batch["message_log"]):
-                        for j, message in enumerate(message_log):
+                    for message_log in repeated_batch["message_log"]:
+                        for message in message_log:

654-658: Add stacklevel to warning for correct source attribution.

-                            warnings.warn(
+                            warnings.warn(
                                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
                                 "Saving most recent k checkpoints instead."
-                            )
+                            , stacklevel=2)

739-748: Compute tokens/sec per GPU using training GPU count.

Using all cluster GPUs undercounts in non‑colocated mode. Prefer num_ranks from train results.

-            total_num_gpus = (
-                master_config["cluster"]["num_nodes"]
-                * master_config["cluster"]["gpus_per_node"]
-            )
+            total_num_gpus = int(
+                train_results.get("num_ranks", 
+                    master_config["cluster"]["num_nodes"] * master_config["cluster"]["gpus_per_node"]
+                )
+            )

849-863: Narrow overly broad exception and use explicit f-string conversion.

Prefer a narrower exception (or at least log via logger) and use !s formatter.

-        except Exception as e:
-            print(f"\n  ⚠️ Error displaying message samples: {str(e)}")
+        except Exception as e:
+            print(f"\n  ⚠️ Error displaying message samples: {e!s}")
             print("  ⚠️ Continuing validation without displaying samples...")

320-323: Nit: Duplicate “Setting up models…” banner.

Second banner is confusing; consider “▶ Setting up teacher model…”.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c330f2d and 596d3eb.

📒 Files selected for processing (1)
  • nemo_rl/algorithms/distillation.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/algorithms/distillation.py (5)
nemo_rl/algorithms/grpo.py (2)
  • _should_use_async_rollouts (406-420)
  • refit_policy_generation (423-491)
nemo_rl/algorithms/loss_functions.py (3)
  • DistillationLossConfig (831-836)
  • DistillationLossDataDict (839-846)
  • DistillationLossFn (849-1104)
nemo_rl/data/datasets.py (1)
  • rl_collate_fn (134-178)
nemo_rl/data/llm_message_utils.py (1)
  • batched_message_log_to_flat_message (233-390)
nemo_rl/distributed/batched_data_dict.py (4)
  • BatchedDataDict (75-857)
  • repeat_interleave (721-742)
  • size (811-820)
  • to (822-829)
🪛 Ruff (0.12.2)
nemo_rl/algorithms/distillation.py

260-263: Avoid specifying long messages outside the exception class

(TRY003)


265-268: Avoid specifying long messages outside the exception class

(TRY003)


528-528: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


529-529: Loop control variable j not used within loop body

Rename unused j to _j

(B007)


654-654: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


860-860: Do not catch blind exception: Exception

(BLE001)


861-861: Use explicit conversion flag

Replace with conversion flag

(RUF010)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/distillation.py (1)

153-156: Good assert for required generation config.

Keeps misconfiguration from slipping through.

@github-actions

⚠️ File Consistency Check

Check based on commit: a80fb54 (PR #1006 from feat-distillation)

This is a test comment


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
examples/run_distillation_math.py (1)

203-214: Guard validation split; fallback to 'test' when no 'validation'.

-    val_dataset: Optional[AllTaskProcessedDataset] = None
-    if data.formatted_ds["validation"]:
-        val_dataset = AllTaskProcessedDataset(
-            data.formatted_ds["validation"],
-            tokenizer,
-            math_task_spec,
-            task_data_processors,
-            max_seq_length=data_config["max_input_seq_length"],
-        )
-    else:
-        val_dataset = None
+    val_dataset: Optional[AllTaskProcessedDataset] = None
+    if "validation" in data.formatted_ds and data.formatted_ds["validation"]:
+        val_src = data.formatted_ds["validation"]
+    elif "test" in data.formatted_ds and data.formatted_ds["test"]:
+        val_src = data.formatted_ds["test"]
+    else:
+        val_src = None
+    if val_src is not None:
+        val_dataset = AllTaskProcessedDataset(
+            val_src,
+            tokenizer,
+            math_task_spec,
+            task_data_processors,
+            max_seq_length=data_config["max_input_seq_length"],
+        )
nemo_rl/algorithms/distillation.py (1)

246-276: Prevent 0-GPU/0-node train cluster in non-colocated mode (single- and multi-node).

When num_nodes==1, train_gpus_per_node can become 0. For num_nodes>1, train_nodes can become 0 after subtraction. Guard both and give actionable errors.

Apply:

         inference_resources = generation_config["colocated"]["resources"]
         inference_gpus_per_node = inference_resources["gpus_per_node"]
         inference_nodes = inference_resources["num_nodes"]

         # validate and configure resources
         if cluster_config["num_nodes"] == 1:
             assert inference_gpus_per_node > 0, (
                 "policy.generation.colocated.resources.gpus_per_node must be > 0 "
                 "when cluster.num_nodes = 1 and inference is non-colocated, "
                 f"but got {inference_gpus_per_node}."
             )
             assert inference_nodes is None or inference_nodes == 1, (
                 "policy.generation.colocated.resources.num_nodes must be 1 or set to null "
                 "when cluster.num_nodes = 1 and inference is non-colocated, "
                 f"but got {inference_nodes}."
             )
-            inference_nodes = 1
-            train_gpus_per_node -= inference_gpus_per_node
+            inference_nodes = 1
+            train_nodes = 1
+            remaining_gpus = cluster_config["gpus_per_node"] - inference_gpus_per_node
+            if remaining_gpus <= 0:
+                raise ValueError(
+                    "Non-colocated mode (single node) requires reserving >=1 GPU for training: "
+                    f"cluster.gpus_per_node={cluster_config['gpus_per_node']}, "
+                    f"inference_gpus_per_node={inference_gpus_per_node}."
+                )
+            train_gpus_per_node = remaining_gpus
         else:
             assert inference_nodes > 0, (
                 "policy.generation.colocated.resources.num_nodes must be > 0 "
                 "when cluster.num_nodes > 1 and inference is non-colocated, "
                 f"but got {inference_nodes}."
             )
             assert (
                 inference_gpus_per_node is None
                 or inference_gpus_per_node == cluster_config["gpus_per_node"]
             ), (
                 "policy.generation.colocated.resources.gpus_per_node must be equal to cluster.gpus_per_node or set to null "
                 "when cluster.num_nodes > 1 and inference is non-colocated, "
                 f"but got {inference_gpus_per_node}."
             )
             inference_gpus_per_node = cluster_config["gpus_per_node"]
-            train_nodes -= inference_nodes
+            train_nodes -= inference_nodes
+            if train_nodes <= 0:
+                raise ValueError(
+                    "Non-colocated mode (multi-node) requires at least 1 train node; "
+                    f"decrease inference_nodes (got inference_nodes={inference_nodes})."
+                )
🧹 Nitpick comments (11)
nemo_rl/algorithms/distillation.py (7)

138-142: Align setup() docstring with actual return tuple (include teacher_policy).

Docstring omits teacher_policy and order. Update for correctness.

-    Returns:
-        tuple of student_policy, student_generation,
-        train_dataloader, val_dataloader,
-        loss_fn, logger, checkpointer, distillation_save_state, master_config
+    Returns:
+        tuple of (
+            student_policy,
+            teacher_policy,
+            student_generation,
+            train_dataloader,
+            val_dataloader,
+            loss_fn,
+            logger,
+            checkpointer,
+            distillation_save_state,
+            master_config,
+        )

299-301: Clarify setup logs; avoid duplicate “Setting up models…” message.

Differentiate student vs. teacher for easier debugging.

-print("\n▶ Setting up models...")
+print("\n▶ Setting up student model...")
...
-print("\n▶ Setting up models...")
+print("\n▶ Setting up teacher model...")

Also applies to: 323-325


530-541: Remove unused loop indices (ruff B007) and simplify loops.

-                    for i, message_log in enumerate(repeated_batch["message_log"]):
-                        for j, message in enumerate(message_log):
+                    for message_log in repeated_batch["message_log"]:
+                        for message in message_log:

652-657: warnings.warn without stacklevel; set stacklevel=2 for correct source.

-                            warnings.warn(
+                            warnings.warn(
                                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
                                 "Saving most recent k checkpoints instead."
-                            )
+                            , stacklevel=2)

743-747: Guard tokens/sec against total_time==0.

-            metrics.update(
-                {
-                    "tokens_per_sec_per_gpu": metrics["total_num_tokens"]
-                    / total_time
-                    / total_num_gpus
-                }
-            )
+            tps = (
+                metrics["total_num_tokens"] / total_time / total_num_gpus
+                if total_time > 0
+                else 0.0
+            )
+            metrics.update({"tokens_per_sec_per_gpu": tps})

789-796: Don’t drop remainder validation samples; ceil the batch count.

-        max_batches = (
-            master_config["distillation"]["max_val_samples"]
-            // master_config["distillation"]["val_batch_size"]
-        )
+        max_batches = (
+            master_config["distillation"]["max_val_samples"]
+            + master_config["distillation"]["val_batch_size"]
+            - 1
+        ) // master_config["distillation"]["val_batch_size"]
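The ceil-division idiom used in the suggestion, standalone:

```python
def ceil_div(a: int, b: int) -> int:
    # Equivalent to math.ceil(a / b) for positive integers, without floats
    return (a + b - 1) // b

assert ceil_div(100, 32) == 4  # 3 full batches plus 1 remainder batch
assert ceil_div(96, 32) == 3   # exact multiple: no extra batch
```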

851-865: Narrow exception scope for sample printing and use !s; prefer warning.

-        try:
+        try:
             print_message_log_samples(
                 all_message_logs,
                 total_rewards,
                 num_samples=min(
                     master_config["logger"]["num_val_samples_to_print"],
                     len(all_message_logs),
                 ),
                 step=step,
             )
-        except Exception as e:
-            print(f"\n  ⚠️ Error displaying message samples: {str(e)}")
-            print("  ⚠️ Continuing validation without displaying samples...")
+        except (KeyError, IndexError, TypeError) as e:
+            warnings.warn(f"Error displaying message samples: {e!s}", stacklevel=2)
examples/run_distillation_math.py (4)

78-97: Simplify prompt check; remove unused read and chain errors.

The file read is unused and the code raises regardless. Keep a clear validation.

-    # safety check: ensure prompt exists
-    if task_data_spec.prompt is None:
-        if task_data_spec.prompt_file:
-            if os.path.exists(task_data_spec.prompt_file):
-                try:
-                    with open(task_data_spec.prompt_file, "r", encoding="utf-8") as f:
-                        content = f.read()
-                except Exception as e:
-                    raise ValueError(f"Failed to read file: {e}")
-            else:
-                raise ValueError(
-                    f"Prompt file does not exist: {task_data_spec.prompt_file}"
-                )
-
-        raise ValueError(
-            f"TaskDataSpec.prompt is None. This usually means the prompt file "
-            f"'{task_data_spec.prompt_file}' could not be loaded or is empty. "
-            f"Current working directory: {os.getcwd()}, "
-            f"Absolute prompt file path: {os.path.abspath(task_data_spec.prompt_file) if task_data_spec.prompt_file else 'None'}"
-        )
+    # safety check: ensure prompt exists
+    if not task_data_spec.prompt:
+        abs_prompt = (
+            os.path.abspath(task_data_spec.prompt_file)
+            if task_data_spec.prompt_file
+            else "None"
+        )
+        raise ValueError(
+            f"TaskDataSpec.prompt is None. Ensure a valid prompt_file (resolved path: {abs_prompt})."
+        )

104-119: Use exception chaining and correct apply_chat_template typing.

-    try:
-        formatted_content = task_data_spec.prompt.format(problem)
-    except Exception as e:
-        raise ValueError(f"Failed to format prompt: {e}")
+    try:
+        formatted_content = task_data_spec.prompt.format(problem)
+    except Exception as e:
+        raise ValueError("Failed to format prompt") from e
@@
-    message: list[str] = tokenizer.apply_chat_template(  # type: ignore
+    chat_text: str = tokenizer.apply_chat_template(  # type: ignore
         [user_message],
         tokenize=False,
         add_generation_prompt=True,
         add_special_tokens=False,
     )
@@
-    user_message["token_ids"] = tokenizer(
-        message,
+    user_message["token_ids"] = tokenizer(
+        chat_text,
         return_tensors="pt",
         add_special_tokens=False,
     )["input_ids"][0]
-    user_message["content"] = message
+    user_message["content"] = chat_text

Also applies to: 120-126


147-148: Avoid KeyError for task_name; use spec default.

-        "task_name": datum_dict["task_name"],
+        "task_name": task_data_spec.task_name or "math",

258-258: Use configured data.seed instead of hard-coded 42.

-    ) = setup_data(tokenizer, config["data"], config["env"], 42)
+    ) = setup_data(
+        tokenizer,
+        config["data"],
+        config["env"],
+        config["data"].get("seed", 42),
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 596d3eb and a80fb54.

📒 Files selected for processing (2)
  • examples/run_distillation_math.py (1 hunks)
  • nemo_rl/algorithms/distillation.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/run_distillation_math.py (12)
nemo_rl/algorithms/distillation.py (3)
  • MasterConfig (100-111)
  • distillation_train (386-765)
  • setup (119-378)
nemo_rl/algorithms/utils.py (1)
  • get_tokenizer (151-267)
nemo_rl/data/__init__.py (1)
  • DataConfig (18-35)
nemo_rl/data/datasets.py (1)
  • AllTaskProcessedDataset (36-131)
nemo_rl/data/hf_datasets/deepscaler.py (1)
  • DeepScalerDataset (67-78)
nemo_rl/data/hf_datasets/openmathinstruct2.py (1)
  • OpenMathInstruct2Dataset (77-105)
nemo_rl/data/interfaces.py (3)
  • DatumSpec (32-40)
  • TaskDataProcessFnCallable (89-100)
  • TaskDataSpec (53-86)
nemo_rl/distributed/ray_actor_environment_registry.py (1)
  • get_actor_python_env (43-58)
nemo_rl/distributed/virtual_cluster.py (1)
  • init_ray (75-161)
nemo_rl/environments/math_environment.py (1)
  • MathEnvironment (222-372)
nemo_rl/utils/logger.py (1)
  • get_next_experiment_dir (1222-1256)
nemo_rl/models/generation/__init__.py (1)
  • configure_generation_config (24-45)
nemo_rl/algorithms/distillation.py (11)
nemo_rl/algorithms/grpo.py (2)
  • _should_use_async_rollouts (406-420)
  • refit_policy_generation (423-491)
nemo_rl/algorithms/loss_functions.py (3)
  • DistillationLossConfig (831-836)
  • DistillationLossDataDict (839-846)
  • DistillationLossFn (849-1104)
nemo_rl/data/datasets.py (1)
  • rl_collate_fn (134-178)
nemo_rl/data/llm_message_utils.py (1)
  • batched_message_log_to_flat_message (233-390)
nemo_rl/distributed/batched_data_dict.py (3)
  • BatchedDataDict (75-857)
  • size (811-820)
  • to (822-829)
nemo_rl/experience/rollouts.py (2)
  • run_async_multi_turn_rollout (751-895)
  • run_multi_turn_rollout (316-522)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
  • VllmGeneration (47-784)
nemo_rl/models/policy/lm_policy.py (8)
  • Policy (56-696)
  • finish_generation (554-556)
  • prepare_refit_info (562-571)
  • prepare_for_generation (539-541)
  • prepare_for_lp_inference (548-552)
  • get_topk_logits (336-400)
  • offload_after_refit (650-653)
  • train (402-494)
nemo_rl/utils/checkpoint.py (6)
  • CheckpointingConfig (35-52)
  • CheckpointManager (55-269)
  • get_latest_checkpoint_path (238-251)
  • load_training_info (253-269)
  • init_tmp_checkpoint (87-126)
  • finalize_checkpoint (128-157)
nemo_rl/utils/logger.py (4)
  • Logger (710-933)
  • LoggerConfig (69-79)
  • print_message_log_samples (1022-1219)
  • log_batched_dict_as_jsonl (804-828)
nemo_rl/utils/timer.py (7)
  • TimeoutChecker (264-321)
  • Timer (22-248)
  • start_iterations (310-311)
  • time (110-123)
  • mark_iteration (313-321)
  • check_save (284-308)
  • get_timing_metrics (196-233)
🪛 Ruff (0.12.2)
examples/run_distillation_math.py

83-83: Local variable content is assigned to but never used

Remove assignment to unused variable content

(F841)


84-84: Do not catch blind exception: Exception

(BLE001)


85-85: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


85-85: Avoid specifying long messages outside the exception class

(TRY003)


87-89: Avoid specifying long messages outside the exception class

(TRY003)


91-96: Avoid specifying long messages outside the exception class

(TRY003)


106-106: Do not catch blind exception: Exception

(BLE001)


107-107: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


107-107: Avoid specifying long messages outside the exception class

(TRY003)


180-180: Avoid specifying long messages outside the exception class

(TRY003)

nemo_rl/algorithms/distillation.py

530-530: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


531-531: Loop control variable j not used within loop body

Rename unused j to _j

(B007)


652-652: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


862-862: Do not catch blind exception: Exception

(BLE001)


863-863: Use explicit conversion flag

Replace with conversion flag

(RUF010)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Post submodule check comment / Comment on PR

@zpqiu zpqiu added the r0.4.0 label Sep 18, 2025
@terrykong terrykong requested a review from yuki-97 September 18, 2025 01:02
@terrykong
Collaborator

terrykong commented Sep 18, 2025

@zpqiu please let @yuki-97 know when to review!

@zpqiu
Contributor Author

zpqiu commented Sep 18, 2025

@zpqiu please let @yuki-97 know when to review!

Okay, sure. We'll resolve the comments above today and then ask @yuki-97 to review.

@zpqiu zpqiu requested review from a team as code owners September 18, 2025 03:37
@zpqiu zpqiu removed request for a team September 18, 2025 03:38
@zpqiu zpqiu added the CI:L1 Run doctests, unit tests, and functional tests label Sep 29, 2025
@github-actions

ℹ️ File Consistency Check

Check based on commit: b9695fd (PR #1006 from feat-distillation)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>


Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>


@zpqiu zpqiu added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Sep 29, 2025
@terrykong terrykong merged commit 17ea9ab into NVIDIA-NeMo:main Sep 29, 2025
40 of 42 checks passed
@euronymous-aithal
Contributor

@zpqiu is it possible to write a quick blog post on this feature?

@zpqiu
Contributor Author

zpqiu commented Oct 1, 2025

@zpqiu is it possible to write a quick blog post on this feature?

Sure, will do after the holiday.

PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: shuo_nvidia <shuoyang@nvidia.com>
Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
Signed-off-by: shuo-nvidia <shuoyang@nvidia.com>
Co-authored-by: shuo_nvidia <shuoyang@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: shuo_nvidia <shuoyang@nvidia.com>
Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com>
Signed-off-by: shuo-nvidia <shuoyang@nvidia.com>
Co-authored-by: shuo_nvidia <shuoyang@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>

Labels

algorithm CI:L1 Run doctests, unit tests, and functional tests r0.4.0


Development

Successfully merging this pull request may close these issues.

[Feature Request] On-policy distillation support

7 participants