feat: add on-policy distillation algorithm #1006
Conversation
Hi, is this PR implementing the on-policy distillation described in the Qwen3 paper? What is the status of this PR? What can I try for now?
Hi @xxman-google, we are glad that you are interested in our PR. Yes, we aim to implement on-policy distillation as described in Qwen3. The basic functionality is now in place, supporting TP, CP, sequence packing, and dynamic batching; you can try it with the provided command. In the first commit, we conducted preliminary experiments using Qwen3-32B as the teacher model and observed improvements on mathematical tasks, as shown in the figure below. In the latest commits, we have mainly been refining the functionality. The experiments so far are based on a single-node environment and small-scale models; we are currently finishing the test cases and validating the latest version of the code on the 32B model.
📝 Walkthrough

Adds an end-to-end on-policy distillation pipeline: new distillation configs/recipes, an example runner, a Distillation algorithm (setup/train/validate), a distributed top-k-aware DistillationLossFn and distributed utilities, policy top-k APIs/workers, batched-data handling updates, and unit/functional test scripts.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant CLI as CLI
    participant Run as run_distillation_math.py
    participant Setup as setup()
    participant Student as Student Policy
    participant Teacher as Teacher Policy
    participant Gen as Generation (vLLM)
    participant Data as DataLoader
    participant Loss as DistillationLossFn
    participant Log as Logger/Checkpoint
    CLI->>Run: parse config & overrides
    Run->>Setup: load config, tokenizer, datasets, envs
    Setup->>Student: init student policy + optimizer
    Setup->>Teacher: init teacher policy (inference)
    Setup-->>Run: return components
    loop training step
        Run->>Data: sample prompts
        Run->>Gen: generate responses
        Gen-->>Run: message_log
        Run->>Teacher: get_topk_logits(k)
        Teacher-->>Run: teacher_topk_logits & indices
        Run->>Loss: prepare data dict (student logits + teacher topk)
        Run->>Student: train_step via loss_fn
        Student-->>Run: loss & metrics
        Run->>Log: log metrics & timings
        alt checkpoint interval
            Run->>Log: save checkpoint (model, opt, tokenizer, state)
        end
        alt validation interval
            Run->>Gen: validation rollouts
            Gen->>Log: val metrics
        end
    end
```
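The training loop above boils down to pulling the student's distribution toward the teacher's top-k slice. A minimal, self-contained sketch of that core step, using toy logits and a renormalized top-k forward KL (names like `forward_kl_on_topk` are illustrative, not this PR's API):

```python
import math

def softmax(logits):
    # numerically stable softmax over a list of floats
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def topk(logits, k):
    # return (values, indices) of the k largest logits
    idx = sorted(range(len(logits)), key=lambda i: -logits[i])[:k]
    return [logits[i] for i in idx], idx

def forward_kl_on_topk(teacher_logits, student_logits, k):
    # KL(teacher || student) restricted to the teacher's top-k vocabulary
    # slice, with both slices renormalized -- one simple variant of
    # top-k distillation, not necessarily the exact loss in this PR.
    t_vals, idx = topk(teacher_logits, k)
    s_vals = [student_logits[i] for i in idx]
    p = softmax(t_vals)
    q = softmax(s_vals)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

teacher = [2.0, 1.0, 0.5, -1.0]
student = [1.8, 1.1, 0.4, -0.9]
loss = forward_kl_on_topk(teacher, student, k=3)
print(loss >= 0.0)  # KL is non-negative
```

A matched student (identical logits) drives the loss to zero, which is the fixed point the training loop converges toward.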
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120–180 minutes

Possibly related issues
Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip

👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example:

```yaml
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment
          variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be
          documented in the "Breaking Change" section of the PR description and in
          CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported
          from package entry points or explicitly marked as internal).
```
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unit/algorithms/test_loss_functions.py (1)
1414-1858: Remove duplicate distillation test definitions.

The entire distillation test suite (from `setup_distillation_test_data` through all test functions) appears to be duplicated in the patch, creating redundant code that will make maintenance harder. Remove the duplicate block and keep only one set of the test functions. The duplication starts at line 1414 and continues through the end of the file.
🧹 Nitpick comments (51)
nemo_rl/distributed/batched_data_dict.py (3)
134-159: 3D path: validate shapes/devices and guard against silent shape drift.

Add basic consistency checks (dtype, device, feature dim) and avoid negative padding. This prevents subtle crashes when a worker emits a different k (top-k) or dtype.

```diff
-        if list_of_tensors[0].ndim == 3:
-            # 对三维张量,只在序列维度(这里是第1维)补 pad,保留特征维度
-            max_seq_len = max(tensor.shape[1] for tensor in list_of_tensors)
+        if list_of_tensors[0].ndim == 3:
+            # For 3D tensors shaped [B, S, K], pad only along the sequence dim (dim=1), keep feature dim intact.
+            base = list_of_tensors[0]
+            # sanity checks across workers/microbatches
+            assert all(t.ndim == 3 for t in list_of_tensors), "Expected 3D tensors"
+            assert all(t.dtype == base.dtype for t in list_of_tensors), "Mismatched dtypes in from_batches"
+            assert all(t.device == base.device for t in list_of_tensors), "Mismatched devices in from_batches"
+            assert len({t.shape[2] for t in list_of_tensors}) == 1, "Last dim (K) must match across tensors"
+            max_seq_len = max(tensor.shape[1] for tensor in list_of_tensors)
             padded_tensors = []
             for tensor in list_of_tensors:
-                # 在第1维补 pad 到 max_seq_len
-                pad_length = max_seq_len - tensor.shape[1]
+                # Right-pad dim=1 up to max_seq_len
+                pad_length = max(0, max_seq_len - tensor.shape[1])
                 padded = torch.nn.functional.pad(
                     tensor,
-                    (0, 0, 0, pad_length),  # 只补最后二个维度(序列长度)
+                    (0, 0, 0, pad_length),  # pad last two dims; second-last is sequence
                     mode="constant",
                     value=pad_value,
                 )
                 padded_tensors.append(padded)
-            tensor_or_list = torch.cat(padded_tensors, dim=0)  # 在批次维度拼接
+            tensor_or_list = torch.cat(padded_tensors, dim=0)  # concat on batch dim
```
134-148: Replace non-ASCII punctuation and translate inline comments to pass Ruff (RUF003) and aid maintainability.

The current comments use full-width punctuation and Chinese; convert them to concise English.

```diff
-        # bug出在这里,当train_data中包含3维tensor时,下面的flatten()会错误展平后两个维度。
-        # grpo在调用这个函数时只会传入二维tensor,但是sft和dpo这里还需要检查有无三维特殊情况。
-        # 现在添加了如下if语句处理蒸馏中的三维情况(即teacher topk logits和indices),else后为原代码。
+        # Bug: when train_data contains a 3D tensor, the flatten() below incorrectly flattens the last two dims.
+        # GRPO sends only 2D tensors here; SFT/DPO/distillation may produce 3D tensors (e.g., teacher top-k logits/indices).
+        # Handle 3D tensors below; the else branch preserves the original 2D behavior.
```
151-159: Masking semantics for padded top-k indices.

If these 3D tensors hold token indices, padding with 0 may produce "in-range" ids. Ensure downstream losses mask padded positions (via `input_lengths`) before using indices/logits, or switch `pad_value` for index tensors to an out-of-range sentinel and mask accordingly.
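The sentinel-and-mask option can be sketched in miniature (pure Python; `PAD_IDX` and the helpers are illustrative, not this PR's code):

```python
# Pad top-k index rows with an out-of-range sentinel, then mask padded
# positions before averaging a per-position loss.
PAD_IDX = -1  # sentinel that can never be a valid vocab id

def pad_rows(rows, pad_len, pad_value):
    # right-pad each row of indices up to pad_len
    return [row + [pad_value] * (pad_len - len(row)) for row in rows]

def masked_mean(losses, indices):
    # only positions whose top-k index is real contribute to the mean
    valid = [l for l, i in zip(losses, indices) if i != PAD_IDX]
    return sum(valid) / len(valid)

indices = pad_rows([[5, 9], [3]], pad_len=2, pad_value=PAD_IDX)
losses = [[0.2, 0.4], [0.6, 123.0]]  # 123.0 sits on a padded slot
per_row = [masked_mean(l, i) for l, i in zip(losses, indices)]
print(per_row)  # the padded slot's garbage loss is excluded
```

Padding with 0 instead of a sentinel would silently treat token id 0 as real, which is exactly the hazard the comment above flags.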
nemo_rl/distributed/model_utils.py (3)
258-381: ChunkedDistributedGatherLogprob: math and TP reduction look correct; consider trimming memory.

`ctx.save_for_backward` stores the full logits and indices; with long S this is heavy. Option: save only `global_indices` and recompute the logits (or the log_softmax) in backward instead of saving the fp16→fp32 cast tensor. If you keep saving logits, document the memory impact and gate it with a flag.
828-897: distributed_vocab_topk: correct global top‑k; tighten a couple details.
- K_eff can be simplified to min(k, V_total).
- Consider using log_softmax only if ties/temperature matter; logits are fine for ordering.
- Prefer named args for all_gather for clarity.
```diff
-    V_total = V_local * world_size
-    K_eff = int(min(k, max(1, V_total)))
+    V_total = V_local * world_size
+    K_eff = min(k, V_total)
@@
-    torch.distributed.all_gather(gathered_vals, local_vals, group=tp_group)
-    torch.distributed.all_gather(gathered_idx, local_idx_global, group=tp_group)
+    torch.distributed.all_gather(tensor_list=gathered_vals, tensor=local_vals, group=tp_group)
+    torch.distributed.all_gather(tensor_list=gathered_idx, tensor=local_idx_global, group=tp_group)
```
984-1055: ChunkedDistributedEntropy: docstring vs. sign.

You return sum_v p_v log p_v (negative entropy). Either rename it to "log_softmax_self_cross_term" in the doc or note that it is negative entropy. The gradient formula matches sum_v p_v log p_v.

```diff
-    """Compute H_all = sum_v p_v log p_v across TP with chunking over sequence.
+    """Compute H_all = sum_v p_v log p_v (negative entropy) across TP with chunking over sequence.
```

tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.sh (2)
16-18: Quote expansions and guard `cd`; fixes SC2068/SC2164.

```diff
-cd $PROJECT_ROOT
+cd "$PROJECT_ROOT" || exit 1
@@
-    $@ \
+    "$@" \
```

Also applies to: 28-28
5-11: Unused config knobs — export or remove.

NUM_NODES/NUM_RUNS/NUM_MINUTES are unused here. If common.env relies on them, export them; otherwise drop them to silence SC2034.

```diff
-NUM_NODES=1
-STEPS_PER_RUN=100
-MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
-NUM_MINUTES=300
+export NUM_NODES=1
+export STEPS_PER_RUN=100
+export MAX_STEPS=100
+export NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
+export NUM_MINUTES=300
```

tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh (2)
16-18: Quote expansions and guard `cd`; fixes SC2068/SC2164.

```diff
-cd $PROJECT_ROOT
+cd "$PROJECT_ROOT" || exit 1
@@
-    $@ \
+    "$@" \
```

Also applies to: 28-28

5-11: Unused config knobs — export or remove (same as the 1n8g script).

```diff
-NUM_NODES=2
-STEPS_PER_RUN=100
-MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
-NUM_MINUTES=300
+export NUM_NODES=2
+export STEPS_PER_RUN=100
+export MAX_STEPS=100
+export NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
+export NUM_MINUTES=300
```

tests/functional/distillation.sh (2)
43-44: Fix command argument expansion for proper handling of additional arguments.

The shellcheck warning is correct: unquoted array expansion can mishandle arguments containing spaces or special characters. Apply this fix to properly quote the expansion:

```diff
-    $@ \
+    "$@" \
```
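The difference is easy to demonstrate with a toy forwarding function (illustrative bash, not the repo's script):

```shell
#!/bin/bash
# Count the arguments a downstream command actually receives when arguments
# are forwarded unquoted ($@) versus quoted ("$@").
count_args() { echo "$#"; }

forward_unquoted() { count_args $@; }    # subject to word splitting
forward_quoted()   { count_args "$@"; }  # preserves argument boundaries

forward_unquoted "a b" "c"   # word-splits "a b" into two words: 3 args
forward_quoted   "a b" "c"   # original boundaries kept: 2 args
```

This is exactly why a config override such as `logger.wandb.name="my run"` would break with the unquoted form.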
49-49: Consider adding proper error handling for metrics validation.

The script continues even if the metrics validation fails. If test-failure propagation to CI/CD matters, exit with a non-zero status when the assertion fails:

```diff
 uv run tests/check_metrics.py $JSON_METRICS \
-    'data["train/loss"]["3"] < 10.0'
+    'data["train/loss"]["3"] < 10.0' || exit 1
```

tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh (3)
28-29: Fix command argument expansion for proper handling of additional arguments.

Same issue as in the functional test — unquoted array expansion can cause problems with arguments containing spaces. Apply this fix:

```diff
-    $@ \
+    "$@" \
```
16-16: Add error handling for the directory change.

Shellcheck correctly flags that `cd` without error handling can lead to commands being executed in the wrong directory if the `cd` fails:

```diff
-cd $PROJECT_ROOT
+cd $PROJECT_ROOT || exit 1
```
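Relatedly, exit-status propagation through the `tee` pipelines in these scripts can be sanity-checked with a toy snippet (illustrative bash; `fake_check` is a stand-in for the real command):

```shell
#!/bin/bash
# set -o pipefail makes a pipeline's exit status reflect the failing command,
# so a non-zero check is no longer masked by a successful `tee`.
set -o pipefail

fake_check() { return 3; }   # stand-in for the real metrics check

if fake_check | tee /dev/null; then
  echo "status=pass"
else
  echo "status=fail"       # reached: pipefail surfaces fake_check's rc=3
fi
```

Without `set -o pipefail`, the pipeline's status is `tee`'s (0) and the failure is silently swallowed — the same hazard as the unguarded `cd`.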
6-10: Consider exporting or using the unused configuration variables.

The variables NUM_NODES, NUM_RUNS, and NUM_MINUTES are defined but never used in the script. Either use them in the distillation command or remove them to avoid confusion. If they are intended for future use or documentation, add a comment:

```bash
# Configuration for multi-node setup (for reference)
NUM_NODES=2        # Used by cluster orchestration
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # For multi-run scenarios
NUM_MINUTES=3000   # Time budget for the experiment
```

tests/unit/algorithms/test_loss_functions.py (1)
1806-1806: Use idiomatic boolean assertions instead of equality comparisons.

Static analysis correctly flags the non-idiomatic boolean comparisons. Apply these fixes:

```diff
-    assert loss_fn.zero_outside_topk == False
+    assert not loss_fn.zero_outside_topk
```

```diff
-    assert loss_fn.zero_outside_topk == True
+    assert loss_fn.zero_outside_topk
```

Also applies to: 1821-1821
tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.sh (3)
28-29: Fix command argument expansion (same unquoted-`$@` issue as the other test scripts):

```diff
-    $@ \
+    "$@" \
```

16-16: Add error handling for the directory change (same as the 2n8g script):

```diff
-cd $PROJECT_ROOT
+cd $PROJECT_ROOT || exit 1
```
6-10: Configuration inconsistency between nodes and comment.

The comment says "4 nodes (4n8g)", yet `NUM_NODES=4` is defined but never used in the script. Either use NUM_NODES in your orchestration or document why it is defined but unused:

```bash
NUM_NODES=4  # Expected by cluster orchestrator (not used directly in this script)
```

tests/unit/algorithms/test_distillation.py (3)
89-91: Remove the unused `self` parameter from the nested functions.

Static analysis correctly flags that `self` is unused in the nested `train_iter` and `val_iter`, since they are not methods of a class. Note that mock invokes a plain function assigned to a dunder with the mock instance as its first argument, so a zero-argument replacement would fail; a `side_effect` that absorbs `*args` handles both cases:

```diff
-    def train_iter(self):
-        return iter([mock_batch] * 10)
-
-    train_dataloader.__iter__ = train_iter
+    train_dataloader.__iter__ = MagicMock(side_effect=lambda *args: iter([mock_batch] * 10))
```

```diff
-    def val_iter(self):
-        return iter([mock_batch] * 10)
-
-    val_dataloader.__iter__ = val_iter
+    val_dataloader.__iter__ = MagicMock(side_effect=lambda *args: iter([mock_batch] * 10))
```

Also applies to: 97-101
214-214: Consider adding more comprehensive validation assertions.

The test only verifies that `train` is called 5 times; it doesn't validate other important aspects like teacher inference calls or data preparation. Consider adding assertions such as:

```python
# Verify teacher was called for top-k logits
assert mock_components["teacher_policy"].get_topk_logits.call_count == 5

# Verify proper preparation methods were called
assert mock_components["teacher_policy"].prepare_for_lp_inference.call_count == 5
assert mock_components["student_policy"].prepare_for_training.call_count == 5
```
56-56: Document the None value for student_generation.

Setting `student_generation = None` is a clever way to avoid Ray-related refit issues in tests, but it deserves more explanation for future maintainers. Expand the comment:

```python
# Set student_generation to None to trigger the NEED_REFIT = False path
# in distillation_train. This avoids calling refit_policy_generation, which
# would require Ray to be properly initialized. In this case, student_policy
# will be used directly as the generation interface (Megatron backend).
student_generation = None
```

tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh (1)
1-42: Fix shell script issues for production readiness.

Several shell scripting issues need to be addressed:
- Variables declared but not used
- Missing error handling for the `cd` command
- Improper quoting of the `$@` expansion

Apply this diff:

```diff
 #!/bin/bash
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
 source $SCRIPT_DIR/common.env

 # ===== BEGIN CONFIG =====
-NUM_NODES=1
+export NUM_NODES=1  # Export if used by common.env or child processes
 STEPS_PER_RUN=100
 MAX_STEPS=100
-NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
-NUM_MINUTES=300
+# NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Unused - remove or export if needed
+export NUM_MINUTES=300  # Export if used by scheduling/monitoring
 # ===== END CONFIG =====

 exit_if_max_steps_reached

 # Run the experiment
-cd $PROJECT_ROOT
+cd $PROJECT_ROOT || exit 1
 uv run examples/run_distillation_math.py \
     --config $CONFIG_PATH \
     distillation.max_num_steps=$MAX_STEPS \
     logger.log_dir=$LOG_DIR \
     logger.wandb_enabled=True \
     logger.wandb.project=nemo-rl-distillation \
     logger.wandb.name=$EXP_NAME \
     logger.monitor_gpus=True \
     logger.tensorboard_enabled=True \
     checkpointing.enabled=True \
     checkpointing.checkpoint_dir=$CKPT_DIR \
-    $@ \
+    "$@" \
     2>&1 | tee $RUN_LOG
```

tests/test_distillation_simple.py (2)
140-153: Remove the unused `self` parameter in the nested functions.

The nested `train_iter` and `val_iter` don't use `self`. Convert them to lambdas for clarity; since mock passes the instance to plain functions assigned to dunders, the lambdas should absorb that argument:

```diff
-    def train_iter(self):
-        return iter([mock_batch] * 10)
-
-    train_dataloader.__iter__ = train_iter
+    train_dataloader.__iter__ = lambda *args: iter([mock_batch] * 10)
     train_dataloader.__len__ = MagicMock(return_value=10)

     val_dataloader = MagicMock(spec=StatefulDataLoader)
-
-    def val_iter(self):
-        return iter([mock_batch] * 5)
-
-    val_dataloader.__iter__ = val_iter
+    val_dataloader.__iter__ = lambda *args: iter([mock_batch] * 5)
     val_dataloader.__len__ = MagicMock(return_value=5)
```
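For context on the mock wiring: in current CPython, a callable assigned to a MagicMock dunder ends up invoked with the mock instance as its first argument, so a `side_effect` (or lambda) that absorbs `*args` works regardless of that detail. A self-contained sketch:

```python
from unittest.mock import MagicMock

data = [1, 2, 3]
loader = MagicMock()
# side_effect with *args works whether or not the instance is passed through,
# and builds a fresh iterator on every __iter__ call (so the mock is
# re-iterable, unlike a single stored iterator).
loader.__iter__ = MagicMock(side_effect=lambda *args: iter(data))
loader.__len__ = MagicMock(return_value=len(data))

print(list(loader))   # [1, 2, 3]
print(list(loader))   # still [1, 2, 3] on a second pass
print(len(loader))    # 3
```

Using `MagicMock(return_value=iter(data))` instead would exhaust the single iterator after the first epoch — a common source of "dataloader is empty on step 2" test bugs.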
158-172: Remove unused variables.

Several variables are created but never used in the test, making it unclear what is being tested. Either use these variables in assertions or remove them:

```diff
-    loss_fn = DistillationLossFn(
+    # Test loss function initialization
+    loss_fn = DistillationLossFn(
         {
             "temperature": 1.0,
             "alpha": 0.5,
             "kl_type": "forward",
             "mixed_kl_weight": 0.5,
             "zero_outside_topk": False,
         }
     )
+    assert loss_fn.temperature == 1.0
+    assert loss_fn.alpha == 0.5

-    logger = MagicMock()
-    checkpointer = MagicMock()
+    # These would be used in actual distillation training
+    # logger = MagicMock()
+    # checkpointer = MagicMock()
```

nemo_rl/models/policy/dtensor_policy_worker.py (2)
1185-1524: Well-structured distributed top-k implementation with minor improvements needed.

The `get_topk_logits` method is well implemented, with proper support for TP, CP, dynamic batching, and sequence packing; the distributed top-k computation and gathering logic is correct. A few minor improvements:
- Remove the duplicate code block (lines 1452-1489 duplicate lines 1413-1450)
- Fix the unused variable at line 1335
- Remove the duplicate `model.eval()` at line 1213

Apply this diff:

```diff
@@ -1210,7 +1210,6 @@ class DTensorPolicyWorker:
         self.model.eval()
         out_topk_vals = []
         out_topk_idx = []
-        self.model.eval()

         with unshard_fsdp2_model(self.model), torch.no_grad():
@@ -1333,13 +1332,6 @@ class DTensorPolicyWorker:
             # IMPORTANT: do not apply generation temperature scaling here for teacher top-k
             if self.cp_size > 1:
-                seq_index_tensor = (
-                    DTensor.from_local(
-                        seq_index,
-                        device_mesh=self.cp_mesh,
-                        placements=[Shard(1)],
-                    )
-                    .full_tensor()
-                    .squeeze(0)
-                )
+                # seq_index_tensor would be used if we needed CP-aware indexing
+                # Currently handled internally by distributed_vocab_topk

             if isinstance(logits, DTensor):
@@ -1450,40 +1442,6 @@ class DTensorPolicyWorker:
             batch_size = original_batch_size
             seq_len = original_seq_len

-            # Handle sequence packing unpacking
-            if self.enable_seq_packing:
-                # Unpack top-k results from packed format back to original batch format
-                # ... duplicate block removed ...
-                # Keep only real sequence tokens (mask padded positions)
```
1243-1245: Remove the unused loop variable.

`batch_idx` is not used within the loop body:

```diff
-            for batch_idx, lp_batch in enumerate(
+            for _, lp_batch in enumerate(
                 itertools.chain(mb_iterator, dummy_iterator)
             ):
```
388-395: Consider adding the `strict` parameter to `zip` for safety.

While the current implementation ensures matching lengths through the worker dispatch mechanism, adding `strict=True` would make that guarantee explicit.

```python
all_topk_indices = [wb["topk_indices"] for wb in worker_batches]
stacked: BatchedDataDict[TopkLogitsOutputSpec] = BatchedDataDict()
stacked["topk_logits"] = torch.cat(all_topk_logits, dim=0)
stacked["topk_indices"] = torch.cat(all_topk_indices, dim=0)
```

Note: the `zip()` warning from static analysis at line 1501 in dtensor_policy_worker.py should also be addressed:

```diff
-        for vals, idx in zip(out_topk_vals, out_topk_idx):
+        for vals, idx in zip(out_topk_vals, out_topk_idx, strict=True):
```

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.yaml (1)
86-93: Optional: Cap generation length to avoid pathological rollouts.

`max_new_tokens: ${..max_total_sequence_length}` can make on-policy rollouts as long as the full train length (8192). Consider capping generation (e.g., 1024-2048) for faster iteration unless long reasoning chains are explicitly required.

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1)
62-69: Optimizer defaults OK for small step counts.

AdamW settings look fine for short runs; consider scaling scheduler warmup with `max_num_steps` if you extend training.

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml (1)
86-101: Optional: Consider a smaller generation cap for throughput.

`max_new_tokens: 32768` will be very slow. If you don't need full-length generations during training, cap to a smaller value and reserve 32k for eval.

examples/configs/distillation_math.yaml (3)
64-69: Comment contradicts formula for make_sequence_length_divisible_by.

The comment says "must be divisible by 2*cp", but the expression enforces `2 * tp * cp`. Align the comment or the expression. Apply one of:

Update the comment:

```diff
-    # must be divisible by 2*cp
+    # must be divisible by 2*tp*cp
```

Or change the formula if only `2*cp` is intended:

```diff
-    make_sequence_length_divisible_by: ${mul:${mul:${.dtensor_cfg.tensor_parallel_size}, ${.dtensor_cfg.context_parallel_size}}, 2}
+    make_sequence_length_divisible_by: ${mul:${.dtensor_cfg.context_parallel_size}, 2}
```
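Whichever interpretation is intended, the resulting padding behavior is easy to check numerically (toy sketch; the helper names are illustrative, not the config machinery):

```python
# Assumed semantics: pad each sequence length up to a multiple of 2 * tp * cp.
def make_divisible_by(tp: int, cp: int) -> int:
    return 2 * tp * cp

def pad_seq_len(seq_len: int, multiple: int) -> int:
    # round seq_len up to the next multiple
    return ((seq_len + multiple - 1) // multiple) * multiple

m = make_divisible_by(tp=2, cp=2)    # 2 * 2 * 2 = 8
print(pad_seq_len(1000, m))           # 1000 is already a multiple of 8
print(pad_seq_len(1001, m))           # rounds up to 1008
```

With the `2*cp`-only reading the multiple would be 4 instead of 8, which is why the comment and expression need to agree.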
171-171: Add trailing newline (yamllint new-line-at-end-of-file).

The file ends without a newline; add one to satisfy linters.

```diff
-  num_nodes: 1
+  num_nodes: 1
+
```
42-50: Style consistency: booleans use capitalized YAML literals.

Minor: elsewhere you use lowercase `false`/`true`. Consider normalizing for consistency.

examples/run_distillation_math.py (5)
78-96: Over-verbose prompt-file path handling, unused variable, and blind except.

- `content` is unused.
- Blind `except Exception`, re-raising without `from`.
- This block still always raises when `task_data_spec.prompt` is None; reading the file adds no value, since TaskDataSpec already loads/validates files.

Apply this simplification:

```diff
-    if task_data_spec.prompt is None:
-        if task_data_spec.prompt_file:
-            if os.path.exists(task_data_spec.prompt_file):
-                try:
-                    with open(task_data_spec.prompt_file, "r", encoding="utf-8") as f:
-                        content = f.read()
-                except Exception as e:
-                    raise ValueError(f"Failed to read file: {e}")
-            else:
-                raise ValueError(
-                    f"Prompt file does not exist: {task_data_spec.prompt_file}"
-                )
-
-        raise ValueError(
-            f"TaskDataSpec.prompt is None. This usually means the prompt file "
-            f"'{task_data_spec.prompt_file}' could not be loaded or is empty. "
-            f"Current working directory: {os.getcwd()}, "
-            f"Absolute prompt file path: {os.path.abspath(task_data_spec.prompt_file) if task_data_spec.prompt_file else 'None'}"
-        )
+    if task_data_spec.prompt is None:
+        pf = task_data_spec.prompt_file
+        raise FileNotFoundError(f"Prompt is missing (prompt_file={pf})")
```
104-118: Type mismatch: apply_chat_template returns str, not list[str].

`message` is a string. Update the annotation and name to avoid confusion:

```diff
-    message: list[str] = tokenizer.apply_chat_template(  # type: ignore
+    rendered: str = tokenizer.apply_chat_template(  # type: ignore
         [user_message],
         tokenize=False,
         add_generation_prompt=True,
         add_special_tokens=False,
     )
-    user_message["token_ids"] = tokenizer(
-        message,
+    user_message["token_ids"] = tokenizer(
+        rendered,
         return_tensors="pt",
         add_special_tokens=False,
     )["input_ids"][0]
-    user_message["content"] = message
+    user_message["content"] = rendered
```
253-259: Hardcoded seed ignores config.

`setup_data(..., 42)` ignores `distillation.seed`. Use the configured seed:

```diff
-    ) = setup_data(tokenizer, config["data"], config["env"], 42)
+    ) = setup_data(
+        tokenizer,
+        config["data"],
+        config["env"],
+        seed=config["distillation"]["seed"],
+    )
```
130-140: Truncation policy silently zeroes loss after clipping.

Setting `loss_multiplier = 0.0` when truncation occurs masks training on that sample. Consider logging a counter, or sampling fewer long prompts, to reduce silent drops. Would you like a small metric hook to log "truncated_samples" per step?
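Such a counter hook could look roughly like this (pure-Python sketch; `apply_truncation_policy` and the metric name are hypothetical, not this PR's code):

```python
from collections import Counter

# Hypothetical metric hook: count samples whose prompt was clipped instead of
# letting the zeroed loss weight hide them.
metrics = Counter()

def apply_truncation_policy(token_len: int, max_len: int) -> float:
    """Return the sample's loss multiplier, recording truncations."""
    if token_len > max_len:
        metrics["truncated_samples"] += 1
        return 0.0  # mirrors the existing behavior: drop from the loss
    return 1.0

weights = [apply_truncation_policy(n, max_len=8) for n in (5, 12, 9, 3)]
print(weights, metrics["truncated_samples"])  # two samples were clipped
```

Logging `metrics["truncated_samples"]` each step makes a data/length misconfiguration visible instead of silently shrinking the effective batch.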
187-195: Ray actor runtime_env depends on registry entry.

This call will raise if the env is not registered. If you expect users to run locally without pre-registration, add a fallback to PY_EXECUTABLES.SYSTEM. Would you like a guarded fallback here?
examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.yaml (3)
86-101: Cap generation length or add stop criteria to avoid runaway 32K-token generations.

`max_new_tokens` equals the full model context (32768) with no `stop_strings`/`stop_token_ids`. This can lead to extreme runtimes/OOM when prompts are short. Consider capping or wiring task-specific stops. Apply one of:

```diff
-  max_new_tokens: ${..max_total_sequence_length}
+  # Option A: cap hard
+  max_new_tokens: 8192
+  # Option B: make configurable via env var fallback
+  max_new_tokens: ${oc.env:MAX_NEW_TOKENS,8192}
+  # Also add task-specific stops if available
+  stop_strings:
+    - "Answer:"
+    - "</final_answer>"
```
95-101: Consider raising vLLM memory utilization for throughput.

`gpu_memory_utilization: 0.6` is conservative and may underutilize GPUs on long runs; 0.80-0.90 is typical when KV cache is the dominant consumer.

```diff
-    gpu_memory_utilization: 0.6
+    gpu_memory_utilization: 0.85
```
49-61: Sequence packing remains disabled while dynamic batching is enabled.

Given long-sequence runs, enabling packing can reduce padding costs during training/logprob passes.

```diff
-  sequence_packing:
-    enabled: false
+  sequence_packing:
+    enabled: true
     train_mb_tokens: ${mul:${..max_total_sequence_length}, ${..train_micro_batch_size}}
     logprob_mb_tokens: ${mul:${..max_total_sequence_length}, ${..logprob_batch_size}}
     algorithm: "modified_first_fit_decreasing"
     sequence_length_round: 64
```

nemo_rl/algorithms/loss_functions.py (3)
885-902: Guard against missing teacher inputs with a clear error.

`teacher_topk_logits`/`teacher_topk_indices` are accessed unconditionally. Add validation to fail fast with an actionable message:

```diff
-        teacher_topk_logits = data["teacher_topk_logits"]  # [B, S, k]
-        teacher_topk_indices = data["teacher_topk_indices"]  # [B, S, k]
+        if "teacher_topk_logits" not in data or "teacher_topk_indices" not in data:
+            raise KeyError(
+                "DistillationLossFn requires 'teacher_topk_logits' and 'teacher_topk_indices' in data."
+            )
+        teacher_topk_logits = data["teacher_topk_logits"]  # [B, S, k]
+        teacher_topk_indices = data["teacher_topk_indices"]  # [B, S, k]
```
930-976: Support for zero_outside_topk already exists under TP/CP; keep it documented.

The implementation handles TP/CP via chunked ops; make this explicit in the class docstring to avoid restrictive external assertions.

```diff
 class DistillationLossFn(LossFunction):
-    """Distillation loss function."""
+    """Distillation loss function.
+
+    Supports TP/CP via ChunkedDistributedGatherLogprob and ChunkedDistributedEntropy.
+    When zero_outside_topk=True, reverse/mixed KL include the missing-mass correction.
+    """
```
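The "missing-mass correction" referenced in that docstring can be illustrated numerically (pure-Python toy, not this PR's implementation): the teacher's top-k probabilities, taken from the full softmax, always sum to less than 1, and the leftover mass is what a correction term has to account for.

```python
import math

def softmax(xs):
    # numerically stable softmax
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [v / z for v in e]

# Full teacher distribution over a tiny 5-token vocabulary.
full_teacher = softmax([3.0, 2.0, 1.0, 0.0, -1.0])

k = 3
topk_mass = sum(sorted(full_teacher, reverse=True)[:k])
missing_mass = 1.0 - topk_mass  # probability living outside the top-k slice
print(0.0 < missing_mass < 1.0)  # some mass is always left over
```

Treating the top-k slice as the whole distribution (i.e., ignoring `missing_mass`) biases reverse/mixed KL, which is why the `zero_outside_topk=True` path folds it back in.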
1002-1011: Move teacher tensors to the student device once and preserve dtype.

Tiny micro-optimization: resolve device/dtype before branching.

```diff
-        if self.zero_outside_topk:
-            teacher_topk_logits = teacher_topk_logits.to(
-                student_topk_logprobs.device, dtype=student_topk_logprobs.dtype
-            )
-        else:
-            teacher_topk_logits = teacher_topk_logits.to(
-                student_topk_logits.device, dtype=student_topk_logits.dtype
-            )
+        tgt_device = (student_topk_logprobs if self.zero_outside_topk else student_topk_logits).device
+        tgt_dtype = (student_topk_logprobs if self.zero_outside_topk else student_topk_logits).dtype
+        teacher_topk_logits = teacher_topk_logits.to(tgt_device, dtype=tgt_dtype)
```

examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml (2)
62-63: Divisibility set to 2; consider 8 or 16 for tensor-core friendliness.

A higher multiple often aligns better with attention kernels on modern GPUs.

```diff
-  make_sequence_length_divisible_by: 2
+  make_sequence_length_divisible_by: 8
```
86-113: Same generation safeguards as the long recipe.

Recommend adding stop criteria and/or a lower `max_new_tokens` here as well.

```diff
-  max_new_tokens: ${..max_total_sequence_length}
+  max_new_tokens: 4096
+  stop_strings:
+    - "</final_answer>"
```

nemo_rl/algorithms/distillation.py (4)
590-596: Duplicate offload call.

`offload_after_refit()` is invoked twice back-to-back; keep one.

```diff
-    with timer.time("logprob_inference_prep"):
-        teacher_policy.offload_after_refit()
-
-    print("▶ Preparing for training...")
-    with timer.time("training_prep"):
-        teacher_policy.offload_after_refit()
+    with timer.time("logprob_inference_prep"):
+        teacher_policy.offload_after_refit()
+
+    print("▶ Preparing for training...")
+    with timer.time("training_prep"):
```
136-143: Return signature/docstring drift.

The docstring mentions clusters, but the function returns `(student_policy, teacher_policy, student_generation, dataloader, val_dataloader, loss_fn, logger, checkpointer, save_state, master_config)`. Update the docstring to avoid confusion:

```diff
-    Returns:
-        tuple of student_policy, student_generation,
-        (train_cluster, inference_cluster), train_dataloader, val_dataloader,
-        loss_fn, logger, checkpointer, distillation_save_state, master_config
+    Returns (in order):
+        student_policy, teacher_policy, student_generation,
+        train_dataloader, val_dataloader, loss_fn, logger,
+        checkpointer, distillation_save_state, master_config
```
666-685: Checkpoint-metric warning: add stacklevel and an explicit fallback message.

Reduces log noise and clarifies the fallback.

```diff
-                warnings.warn(
-                    f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                    "Saving most recent k checkpoints instead."
-                )
+                warnings.warn(
+                    (
+                        "Checkpoint metric "
+                        f"{master_config['checkpointing']['metric_name']} not found in save state; "
+                        "saving most recent k checkpoints instead."
+                    ),
+                    stacklevel=2,
+                )
```
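The warning flow and the effect of `stacklevel` can be exercised in isolation (small stdlib sketch; `warn_metric_fallback` is hypothetical):

```python
import warnings

def warn_metric_fallback():
    # stacklevel=2 attributes the warning to this function's caller,
    # which is usually where the misconfiguration actually lives.
    warnings.warn(
        "checkpoint metric not found in save state; saving most recent k checkpoints",
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # don't dedupe repeated warnings
    warn_metric_fallback()

print(len(caught), caught[0].category.__name__)  # 1 UserWarning
```

With the default `stacklevel=1` the reported file/line would point inside the helper itself, making the log entry harder to act on.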
857-868: Broad-exception message formatting nit.

Use the explicit conversion flag in the f-string.

```diff
-        except Exception as e:
-            print(f"\n  ⚠️ Error displaying message samples: {str(e)}")
+        except Exception as e:
+            print(f"\n  ⚠️ Error displaying message samples: {e!s}")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (25)
- examples/configs/distillation_math.yaml (1 hunks)
- examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml (1 hunks)
- examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1 hunks)
- examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml (1 hunks)
- examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.yaml (1 hunks)
- examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml (1 hunks)
- examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.yaml (1 hunks)
- examples/run_distillation_math.py (1 hunks)
- nemo_rl/algorithms/distillation.py (1 hunks)
- nemo_rl/algorithms/loss_functions.py (6 hunks)
- nemo_rl/distributed/batched_data_dict.py (1 hunks)
- nemo_rl/distributed/model_utils.py (2 hunks)
- nemo_rl/models/policy/dtensor_policy_worker.py (3 hunks)
- nemo_rl/models/policy/interfaces.py (2 hunks)
- nemo_rl/models/policy/lm_policy.py (3 hunks)
- tests/functional/distillation.sh (1 hunks)
- tests/test_distillation_simple.py (1 hunks)
- tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh (1 hunks)
- tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.sh (1 hunks)
- tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh (1 hunks)
- tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh (1 hunks)
- tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh (1 hunks)
- tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.sh (1 hunks)
- tests/unit/algorithms/test_distillation.py (1 hunks)
- tests/unit/algorithms/test_loss_functions.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (16)
tests/unit/algorithms/test_distillation.py (4)
- nemo_rl/algorithms/distillation.py (3): `_default_distillation_save_state` (92-97), `distillation_train` (390-773), `validate` (776-888)
- nemo_rl/algorithms/loss_functions.py (1): `DistillationLossFn` (849-1104)
- nemo_rl/data/interfaces.py (1): `DatumSpec` (32-40)
- nemo_rl/distributed/batched_data_dict.py (1): `BatchedDataDict` (75-857)

tests/test_distillation_simple.py (2)
- nemo_rl/algorithms/loss_functions.py (1): `DistillationLossFn` (849-1104)
- tests/unit/algorithms/test_distillation.py (2): `train_iter` (89-90), `val_iter` (97-98)

examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.yaml (1)
- tests/unit/models/generation/test_vllm_generation.py (2): `test_vllm_weight_update_and_prefix_cache_reset` (1046-1158), `test_vllm_megatron_weight_update_with_packing` (1778-1832)

nemo_rl/distributed/model_utils.py (1)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (1): `get_logprobs` (835-1131)

nemo_rl/algorithms/loss_functions.py (3)
- nemo_rl/distributed/model_utils.py (6): `ChunkedDistributedEntropy` (984-1055), `ChunkedDistributedGatherLogprob` (258-381), `_get_tokens_on_this_cp_rank` (669-703), `allgather_cp_sharded_tensor` (706-709), `from_parallel_logits_to_logprobs` (469-541), `gather_logits_at_global_indices` (899-981)
- nemo_rl/algorithms/interfaces.py (2): `LossFunction` (28-70), `LossType` (23-25)
- nemo_rl/algorithms/utils.py (1): `masked_mean` (128-140)

nemo_rl/models/policy/interfaces.py (4)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): `get_topk_logits` (1186-1524)
- nemo_rl/models/policy/lm_policy.py (1): `get_topk_logits` (336-400)
- nemo_rl/distributed/batched_data_dict.py (1): `BatchedDataDict` (75-857)
- nemo_rl/models/generation/interfaces.py (1): `GenerationDatumSpec` (127-158)

nemo_rl/models/policy/dtensor_policy_worker.py (3)
- nemo_rl/distributed/model_utils.py (3): `allgather_cp_sharded_tensor` (706-709), `distributed_vocab_topk` (829-896), `get_logprobs_from_vocab_parallel_logits` (778-825)
- nemo_rl/utils/nsys.py (1): `wrap_with_nvtx_name` (82-94)
- nemo_rl/distributed/batched_data_dict.py (3): `BatchedDataDict` (75-857), `to` (822-829), `size` (811-820)

tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh (1)
- tests/unit/test_recipes_and_test_suites.py (1): `test_dry_run_does_not_fail_and_prints_total_gpu_hours` (193-221)
nemo_rl/algorithms/distillation.py (12)
- nemo_rl/algorithms/grpo.py (2): `_should_use_async_rollouts` (406-420), `refit_policy_generation` (423-491)
- nemo_rl/algorithms/loss_functions.py (3): `DistillationLossConfig` (831-836), `DistillationLossDataDict` (839-846), `DistillationLossFn` (849-1104)
- nemo_rl/algorithms/utils.py (1): `set_seed` (143-148)
- nemo_rl/data/datasets.py (2): `AllTaskProcessedDataset` (36-131), `rl_collate_fn` (134-178)
- nemo_rl/data/llm_message_utils.py (2): `batched_message_log_to_flat_message` (233-390), `get_keys_from_message_log` (126-138)
- nemo_rl/distributed/batched_data_dict.py (5): `BatchedDataDict` (75-857), `repeat_interleave` (721-742), `size` (811-820), `get_multimodal_dict` (88-99), `to` (822-829)
- nemo_rl/distributed/virtual_cluster.py (2): `ClusterConfig` (32-34), `RayVirtualCluster` (168-410)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): `VllmGeneration` (47-784)
- nemo_rl/models/policy/interfaces.py (6): `ColocatablePolicyInterface` (126-155), `prepare_refit_info` (142-143), `get_topk_logits` (78-87), `offload_after_refit` (138-139), `prepare_for_training` (110-111), `train` (90-107)
- nemo_rl/models/policy/lm_policy.py (9): `Policy` (56-696), `finish_generation` (554-556), `prepare_refit_info` (562-571), `prepare_for_generation` (539-541), `prepare_for_lp_inference` (548-552), `get_topk_logits` (336-400), `offload_after_refit` (650-653), `prepare_for_training` (543-546), `train` (402-494)
- nemo_rl/utils/checkpoint.py (5): `CheckpointManager` (55-269), `get_latest_checkpoint_path` (238-251), `load_training_info` (253-269), `init_tmp_checkpoint` (87-126), `finalize_checkpoint` (128-157)
- nemo_rl/utils/logger.py (3): `Logger` (710-933), `print_message_log_samples` (1022-1219), `log_batched_dict_as_jsonl` (804-828)

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.yaml (1)
- tests/unit/models/policy/test_dtensor_worker.py (1): `policy_setup` (134-147)

examples/run_distillation_math.py (13)
- nemo_rl/algorithms/distillation.py (3): `MasterConfig` (100-111), `distillation_train` (390-773), `setup` (119-382)
- nemo_rl/algorithms/utils.py (1): `get_tokenizer` (151-267)
- nemo_rl/data/__init__.py (1): `DataConfig` (18-35)
- nemo_rl/data/datasets.py (1): `AllTaskProcessedDataset` (36-131)
- nemo_rl/data/hf_datasets/deepscaler.py (1): `DeepScalerDataset` (67-78)
- nemo_rl/data/hf_datasets/openmathinstruct2.py (1): `OpenMathInstruct2Dataset` (77-105)
- nemo_rl/data/interfaces.py (3): `DatumSpec` (32-40), `TaskDataProcessFnCallable` (89-100), `TaskDataSpec` (53-86)
- nemo_rl/distributed/ray_actor_environment_registry.py (1): `get_actor_python_env` (43-58)
- nemo_rl/distributed/virtual_cluster.py (1): `init_ray` (75-161)
- nemo_rl/environments/math_environment.py (1): `MathEnvironment` (222-372)
- nemo_rl/utils/config.py (2): `load_config` (74-137), `parse_hydra_overrides` (146-166)
- nemo_rl/utils/logger.py (1): `get_next_experiment_dir` (1222-1256)
- nemo_rl/models/generation/__init__.py (1): `configure_generation_config` (24-45)

tests/unit/algorithms/test_loss_functions.py (1)
- nemo_rl/algorithms/loss_functions.py (1): `DistillationLossFn` (849-1104)

nemo_rl/distributed/batched_data_dict.py (2)
- nemo_rl/data/llm_message_utils.py (1): `batched_message_log_to_flat_message` (233-390)
- tests/unit/distributed/test_batched_data_dict.py (2): `test_sequence_packing_microbatch_boundaries` (550-683), `test_sequence_packing_basic` (246-327)

examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (1): `__init__` (108-436)

tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh (1)
- tests/unit/test_recipes_and_test_suites.py (1): `test_dry_run_does_not_fail_and_prints_total_gpu_hours` (193-221)

nemo_rl/models/policy/lm_policy.py (3)
- nemo_rl/models/policy/interfaces.py (2): `TopkLogitsOutputSpec` (37-41), `get_topk_logits` (78-87)
- nemo_rl/models/policy/dtensor_policy_worker.py (1): `get_topk_logits` (1186-1524)
- nemo_rl/distributed/batched_data_dict.py (1): `BatchedDataDict` (75-857)
🪛 Shellcheck (0.10.0)
tests/functional/distillation.sh
[error] 43-43: Double quote array expansions to avoid re-splitting elements.
(SC2068)
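The SC2068 and SC2164 findings throughout these scripts come down to two generic shell pitfalls. A minimal standalone demonstration (plain shell behavior, not tied to the repo's scripts):

```shell
# SC2068 in practice: unquoted $@ re-splits elements that contain spaces.
set -- "path with spaces" "plain"

count_args() { echo "$#"; }

unquoted=$(count_args $@)    # word splitting -> 4 arguments
quoted=$(count_args "$@")    # elements preserved -> 2 arguments

echo "unquoted=$unquoted quoted=$quoted"

# SC2164: guard cd so a failure doesn't silently continue in the wrong directory.
cd /tmp || exit 1
```

Quoting `"$@"` is almost always what these test scripts want when forwarding overrides to the runner.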
tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-4n8g-fsdp2tp4-long.v1.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
🪛 Ruff (0.12.2)
tests/unit/algorithms/test_distillation.py
89-89: Unused function argument: self
(ARG001)
97-97: Unused function argument: self
(ARG001)
tests/test_distillation_simple.py
1-1: Shebang is present but file is not executable
(EXE001)
140-140: Unused function argument: self
(ARG001)
148-148: Unused function argument: self
(ARG001)
158-158: Local variable loss_fn is assigned to but never used
Remove assignment to unused variable loss_fn
(F841)
168-168: Local variable logger is assigned to but never used
Remove assignment to unused variable logger
(F841)
169-169: Local variable checkpointer is assigned to but never used
Remove assignment to unused variable checkpointer
(F841)
171-171: Local variable task_to_env is assigned to but never used
Remove assignment to unused variable task_to_env
(F841)
172-172: Local variable val_task_to_env is assigned to but never used
Remove assignment to unused variable val_task_to_env
(F841)
188-188: Consider moving this statement to an else block
(TRY300)
190-190: Do not catch blind exception: Exception
(BLE001)
nemo_rl/algorithms/loss_functions.py
865-865: Unused method argument: global_valid_seqs
(ARG002)
1055-1055: Avoid specifying long messages outside the exception class
(TRY003)
1073-1073: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/models/policy/dtensor_policy_worker.py
1243-1243: Loop control variable batch_idx not used within loop body
Rename unused batch_idx to _batch_idx
(B007)
1335-1335: Local variable seq_index_tensor is assigned to but never used
Remove assignment to unused variable seq_index_tensor
(F841)
1501-1501: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
nemo_rl/algorithms/distillation.py
266-269: Avoid specifying long messages outside the exception class
(TRY003)
271-274: Avoid specifying long messages outside the exception class
(TRY003)
534-534: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
535-535: Loop control variable j not used within loop body
Rename unused j to _j
(B007)
660-660: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
866-866: Do not catch blind exception: Exception
(BLE001)
867-867: Use explicit conversion flag
Replace with conversion flag
(RUF010)
examples/run_distillation_math.py
83-83: Local variable content is assigned to but never used
Remove assignment to unused variable content
(F841)
84-84: Do not catch blind exception: Exception
(BLE001)
85-85: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
87-89: Avoid specifying long messages outside the exception class
(TRY003)
91-96: Avoid specifying long messages outside the exception class
(TRY003)
106-106: Do not catch blind exception: Exception
(BLE001)
107-107: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
107-107: Avoid specifying long messages outside the exception class
(TRY003)
180-180: Avoid specifying long messages outside the exception class
(TRY003)
tests/unit/algorithms/test_loss_functions.py
1806-1806: Avoid equality comparisons to False; use not loss_fn.zero_outside_topk: for false checks
Replace with not loss_fn.zero_outside_topk
(E712)
1821-1821: Avoid equality comparisons to True; use loss_fn.zero_outside_topk: for truth checks
Replace with loss_fn.zero_outside_topk
(E712)
nemo_rl/distributed/batched_data_dict.py
134-134: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
134-134: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
135-135: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
136-136: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
136-136: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
136-136: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
138-138: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
138-138: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
138-138: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
138-138: Comment contains ambiguous , (FULLWIDTH COMMA). Did you mean , (COMMA)?
(RUF003)
146-146: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
146-146: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
🪛 GitHub Actions: Copyright check
tests/test_distillation_simple.py
[error] 1-1: Found files with missing copyright: path= ./tests/test_distillation_simple.py
🪛 YAMLlint (1.37.1)
examples/configs/distillation_math.yaml
[error] 171-171: no new line character at the end of file
(new-line-at-end-of-file)
🔇 Additional comments (14)
nemo_rl/distributed/model_utils.py (1)
899-982: gather_logits_at_global_indices: CP padding/gathering mirrors the existing logprob path; confirm the 2×CP chunking expectation. `_get_tokens_on_this_cp_rank` assumes inputs are divisible by 2*cp_size, but here `pad_len` uses cp_size only (same as `from_parallel_logits_to_logprobs`). Please confirm the sequence-length invariants (and tests) cover this for `global_indices`; otherwise compute `pad_len` to satisfy 2*cp_size divisibility.
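To make the divisibility concern concrete, here is a hypothetical padding helper (names are assumptions, not the repo's API): load-balanced context parallelism shards each sequence into two chunks per rank, so lengths must be padded to a multiple of 2 * cp_size, not just cp_size.

```python
def cp_pad_len(seq_len: int, cp_size: int, chunks_per_rank: int = 2) -> int:
    """Padding needed so seq_len divides evenly into chunks_per_rank * cp_size chunks."""
    multiple = chunks_per_rank * cp_size
    # (-seq_len) % multiple is the distance up to the next multiple (0 if aligned)
    return (-seq_len) % multiple

# With cp_size=2: a length-10 sequence needs 2 pad tokens to reach 12 (a multiple of 4),
# whereas padding only to a multiple of cp_size would accept 10 and break 2-chunk sharding.
```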
nemo_rl/models/policy/interfaces.py (1)
37-42: TopkLogitsOutputSpec looks good. Shape contracts are explicit and match downstream usage.
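As context for how such top-k outputs are typically consumed, here is a minimal sketch of a forward-KL distillation loss restricted to the teacher's top-k entries. The shapes follow the spec above ([B, T, V] student logits, [B, T, k] teacher values/indices), but the function name and the renormalization-over-k choice are assumptions for illustration, not the repo's exact `DistillationLossFn` math:

```python
import torch
import torch.nn.functional as F

def topk_forward_kl(student_logits, teacher_topk_logits, teacher_topk_indices):
    # Gather the student's logits at the teacher's top-k vocab indices: [B, T, k]
    student_topk = torch.gather(student_logits, dim=-1, index=teacher_topk_indices)
    # Renormalize both distributions over the k retained entries, then take
    # forward KL per position: sum_k p_t * (log p_t - log p_s)
    log_p_t = F.log_softmax(teacher_topk_logits, dim=-1)
    log_p_s = F.log_softmax(student_topk, dim=-1)
    return (log_p_t.exp() * (log_p_t - log_p_s)).sum(dim=-1).mean()

# Usage with random tensors standing in for model outputs:
B, T, V, k = 2, 4, 32, 8
teacher_logits = torch.randn(B, T, V)
topk_vals, topk_idx = teacher_logits.topk(k, dim=-1)
loss = topk_forward_kl(torch.randn(B, T, V), topk_vals, topk_idx)
```

The gather-then-renormalize step is why the worker only needs to ship [B, T, k] teacher tensors instead of the full vocab.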
nemo_rl/models/policy/dtensor_policy_worker.py (1)
799-801: LGTM! Clean implementation of context parallel group passing. The addition of the `context_parallel_group` parameter to the loss function when CP is enabled is correct and follows the established pattern for distributed training.

nemo_rl/models/policy/lm_policy.py (1)

336-400: Clean implementation of the top-k logits API. The `get_topk_logits` method properly handles sharding, dispatching to workers, and result aggregation. The manual tensor concatenation approach correctly preserves the 3D tensor shape for top-k results.

tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh (1)
1-42: Fix duplicate shell-script issues and confirm aggressive loss target
- Apply the same fixes as the 1.7B test script to remove duplicated shell-script problems.
- Files to update:
- tests/test_suites/llm/distillation-qwen3-32b-to-4b-instruct-2n8g-fsdp2tp2sp.v1.sh
- tests/test_suites/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.sh
- Both files assert 'data["train/loss"]["200"] < 1.0' — verify from past run logs that <1.0 at step 200 is achievable for these configs; if not, relax the threshold to 1.5 to match the 1.7B baseline.
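For reference, the quoted assertion has this shape; the JSON layout (metric name, then step keyed as a string) is an assumption based on the assertion text, not a confirmed log format:

```python
import json

# Stand-in for the run's metric log file contents
metrics_json = '{"train/loss": {"100": 1.42, "200": 0.87}}'
data = json.loads(metrics_json)

# The check the test suites perform at step 200
assert data["train/loss"]["200"] < 1.0
```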
examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-1n8g-fsdp2tp2-dynamicbatch.v1.yaml (1)
40-47: LGTM: Teacher CP with no sequence packing/parallel is valid. Teacher uses CP=2 with sequence_packing disabled and sequence_parallel=false in the base; this matches DTensor worker constraints.
Also applies to: 55-61, 114-121
examples/configs/recipes/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.yaml (2)
37-45: LGTM: Long-context recipe respects CP/SP constraints. Sequence packing disabled; teacher CP=2 without sequence_parallel. Looks consistent for long-context runs.
Also applies to: 55-61, 114-121
74-85: Scheduler shape matches 1k steps. Linear warmup for 100 steps then cosine to 900 fits `max_num_steps: 1000`. LGTM.

examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml (1)
40-47: LGTM: Conservative TP/CP settings with packing disabled. Teacher CP=2, packing and sequence_parallel disabled; aligns with worker constraints.
Also applies to: 55-61, 114-121
examples/run_distillation_math.py (3)
246-251: Warn-only when generation config is missing. OK to proceed, but most setups require generation; you may want to error out if on-policy is intended. Do you ever run this script off-policy? If not, prefer raising when `policy.generation` is missing.
239-243: Tokenizer source: pass through `tokenizer.chat_template` via config. If a custom chat_template is provided in config, ensure `get_tokenizer` applies it (it does per utils). Just confirming this is intentional.
262-272: Good: passing tokenizer into setup/distillation_train. This unblocks token-level loss paths and stop-string wiring.
nemo_rl/algorithms/loss_functions.py (1)
160-174: Casting logits to float32 for stability: good. Prevents numerical issues in softmax/log-softmax across losses.
examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml (1)
74-85: Scheduler chain: ensure `milestones` pairs with `SequentialLR`. If your scheduler builder expects a `SequentialLR` wrapper, the lone `- milestones: [20]` entry is fine. If not, this list item without a `name` may be ignored. Please confirm the parser behavior. Would you like me to patch the scheduler builder to explicitly wrap `LinearLR` + `ConstantLR` in `SequentialLR(milestones=[20])`?
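For context, the explicit wrapping being proposed looks like this in plain PyTorch — a standalone sketch of `torch.optim.lr_scheduler.SequentialLR`, not the repo's scheduler builder:

```python
import torch
from torch.optim.lr_scheduler import SequentialLR, LinearLR, ConstantLR

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=1.0)

# Warm up linearly for 20 steps, then hold the base LR constant.
sched = SequentialLR(
    opt,
    schedulers=[
        LinearLR(opt, start_factor=0.1, total_iters=20),  # ramps 0.1 -> 1.0 over 20 steps
        ConstantLR(opt, factor=1.0),                      # holds base LR afterwards
    ],
    milestones=[20],  # switch schedulers at step 20
)

lrs = []
for _ in range(25):
    lrs.append(opt.param_groups[0]["lr"])
    opt.step()
    sched.step()
```

Without the `SequentialLR` wrapper, a bare `milestones` entry in the YAML has nothing to bind to, which is the reviewer's concern.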
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/test_distillation_simple.py (1)
1-13: Header fixed; previous compliance issue resolved. Copyright header added and the stray shebang removed.
🧹 Nitpick comments (6)
tests/test_distillation_simple.py (6)
21-21: Avoid mutating sys.path in tests. Rely on standard import paths; this can mask packaging issues and cause brittle CI.

Apply this diff to remove the path hack:

```diff
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
```
31-35: Replace prints/returns with real assertions. Ensure the test fails if imports break.

```diff
 def test_basic_imports():
-    """Test whether basic imports work"""
-    print("✅ Basic import test passed")
-    return True
+    """Basic import smoke test"""
+    assert DistillationLossFn is not None
+    assert torch is not None
```
149-165: Remove unused StatefulDataLoader mocks (and the import). They’re unused and the current `__iter__` stubs would be called without `self` binding if used. Also drops Ruff ARG001 warnings.

```diff
-from torchdata.stateful_dataloader import StatefulDataLoader
@@
-    # Create mock dataloaders
-    train_dataloader = MagicMock(spec=StatefulDataLoader)
-
-    def train_iter(self):
-        return iter([mock_batch] * 10)
-
-    train_dataloader.__iter__ = train_iter
-    train_dataloader.__len__ = MagicMock(return_value=10)
-
-    val_dataloader = MagicMock(spec=StatefulDataLoader)
-
-    def val_iter(self):
-        return iter([mock_batch] * 5)
-
-    val_dataloader.__iter__ = val_iter
-    val_dataloader.__len__ = MagicMock(return_value=5)
```
170-187: Exercise DistillationLossFn and assert outputs; drop unused locals. Uses your mock batch to compute a real loss and removes unused variables flagged by Ruff.

```diff
     loss_fn = DistillationLossFn(
         {
             "temperature": 1.0,
             "alpha": 0.5,
             "kl_type": "forward",
             "mixed_kl_weight": 0.5,
             "zero_outside_topk": False,
         }
     )
-
-    logger = MagicMock()
-    checkpointer = MagicMock()
-
-    task_to_env = {"math": MagicMock()}
-    val_task_to_env = {"math": MagicMock()}
-
-    print("✅ Mock component creation test passed")
-    return True
+    # Compute loss with mock data
+    student_logits = student_generation.generate.return_value["logits"]
+    V = student_logits.shape[-1]
+    # Ensure teacher indices are within student vocab size
+    assert int(mock_batch["teacher_topk_indices"].max()) < V
+    global_valid_seqs = mock_batch["sample_mask"].sum().to(dtype=torch.int64)
+    global_valid_toks = mock_batch["token_mask"][:, 1:].sum().to(dtype=torch.int64)
+    loss, metrics = loss_fn(
+        student_logits,
+        mock_batch,
+        global_valid_seqs=global_valid_seqs,
+        global_valid_toks=global_valid_toks,
+    )
+    assert torch.isfinite(loss)
+    assert "loss" in metrics and "num_valid_samples" in metrics
```
190-212: Drop the custom main harness in pytest tests. Pytest will collect and run these; the try/except triggers Ruff TRY300/BLE001. Removing it reduces noise.

```diff
-def main():
-    """Run all tests"""
-    print("Starting simplified distillation tests...")
-
-    try:
-        test_basic_imports()
-        test_config_creation()
-        test_mock_components()
-
-        print("\n🎉 All tests passed! Distillation test configuration is correct.")
-        return 0
-
-    except Exception as e:
-        print(f"\n❌ Test failed: {e}")
-        import traceback
-
-        traceback.print_exc()
-        return 1
-
-
-if __name__ == "__main__":
-    sys.exit(main())
+# Pytest will discover and run test_* functions automatically.
```
145-147: Optional: add a second path test with zero_outside_topk=True. Covers the log-prob path and entropy correction branch in DistillationLossFn.

I can draft a small `test_distillation_loss_zero_outside_topk` that reuses `mock_batch` with `loss_fn = DistillationLossFn({..., "zero_outside_topk": True})` and asserts finite loss. Want me to add it?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tests/test_distillation_simple.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/test_distillation_simple.py (1)
nemo_rl/algorithms/loss_functions.py (1)
DistillationLossFn(849-1104)
🪛 Ruff (0.12.2)
tests/test_distillation_simple.py
152-152: Unused function argument: self
(ARG001)
160-160: Unused function argument: self
(ARG001)
170-170: Local variable loss_fn is assigned to but never used
Remove assignment to unused variable loss_fn
(F841)
180-180: Local variable logger is assigned to but never used
Remove assignment to unused variable logger
(F841)
181-181: Local variable checkpointer is assigned to but never used
Remove assignment to unused variable checkpointer
(F841)
183-183: Local variable task_to_env is assigned to but never used
Remove assignment to unused variable task_to_env
(F841)
184-184: Local variable val_task_to_env is assigned to but never used
Remove assignment to unused variable val_task_to_env
(F841)
200-200: Consider moving this statement to an else block
(TRY300)
202-202: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
tests/test_distillation_simple.py (1)
37-95: Config shape checks look good. Covers key sections and nested generation.colocated.
Actionable comments posted: 3
♻️ Duplicate comments (1)
nemo_rl/algorithms/distillation.py (1)
242-273: Non‑colocated cluster resource math creates train_nodes/train_gpus_per_node = 0 (fatal).
- Single-node: defaults set inference_nodes=1 ⇒ train_nodes=0.
- Multi-node: defaults set inference_gpus_per_node=gpus_per_node and then compute train_gpus_per_node = gpus_per_node - inference_gpus_per_node = 0.

This will fail cluster creation. Minimal safe fix:

- Disallow non‑colocated when `num_nodes == 1`.
- In multi‑node mode, allocate full GPUs per node for each cluster and split nodes (not GPUs within a node). Guard that `inference_nodes < num_nodes`.

```diff
@@
-    inference_resources = generation_config["colocated"]["resources"]
-    inference_gpus_per_node = inference_resources["gpus_per_node"]
-    inference_nodes = inference_resources["num_nodes"]
+    inference_resources = generation_config["colocated"]["resources"]
+    inference_gpus_per_node = inference_resources.get("gpus_per_node")
+    inference_nodes = inference_resources.get("num_nodes")
@@
-    if cluster_config["num_nodes"] == 1:
-        if inference_gpus_per_node is None:
-            inference_gpus_per_node = cluster_config["gpus_per_node"] // 2
-        if inference_nodes is None:
-            inference_nodes = 1
+    if cluster_config["num_nodes"] == 1:
+        raise ValueError(
+            "Non-colocated inference requires num_nodes >= 2. "
+            "Set policy.generation.colocated.enabled: true for single-node runs."
+        )
     else:
         if inference_gpus_per_node is None:
             inference_gpus_per_node = cluster_config["gpus_per_node"]
         if inference_nodes is None:
-            inference_nodes = cluster_config["num_nodes"] // 2
+            inference_nodes = max(1, cluster_config["num_nodes"] // 2)
@@
     if inference_gpus_per_node > cluster_config["gpus_per_node"]:
         raise ValueError(
             f"Inference GPUs per node ({inference_gpus_per_node}) cannot be greater than "
             f"total GPUs per node ({cluster_config['gpus_per_node']})"
         )
-    if inference_nodes > cluster_config["num_nodes"]:
+    if inference_nodes >= cluster_config["num_nodes"]:
         raise ValueError(
-            f"Inference nodes ({inference_nodes}) cannot be greater than "
-            f"total nodes ({cluster_config['num_nodes']})"
+            "Non-colocated mode must reserve at least 1 node for training: "
+            f"inference_nodes={inference_nodes}, total_nodes={cluster_config['num_nodes']}"
         )
@@
-    train_gpus_per_node = cluster_config["gpus_per_node"] - inference_gpus_per_node
-    train_nodes = cluster_config["num_nodes"] - inference_nodes
+    # Split by nodes, not by GPUs within a node.
+    train_gpus_per_node = cluster_config["gpus_per_node"]
+    train_nodes = cluster_config["num_nodes"] - inference_nodes
```

Also applies to: 275-291
🧹 Nitpick comments (5)
nemo_rl/algorithms/distillation.py (5)
528-539: Drop unused enumerate indices. Removes B007 warnings and clarifies intent.

```diff
-        for i, message_log in enumerate(repeated_batch["message_log"]):
-            for j, message in enumerate(message_log):
+        for message_log in repeated_batch["message_log"]:
+            for message in message_log:
```
654-658: Add stacklevel to the warning for correct source attribution.

```diff
         warnings.warn(
             f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
             "Saving most recent k checkpoints instead."
-        )
+        , stacklevel=2)
```
739-748: Compute tokens/sec per GPU using the training GPU count. Using all cluster GPUs undercounts in non‑colocated mode. Prefer `num_ranks` from train results.

```diff
-    total_num_gpus = (
-        master_config["cluster"]["num_nodes"]
-        * master_config["cluster"]["gpus_per_node"]
-    )
+    total_num_gpus = int(
+        train_results.get(
+            "num_ranks",
+            master_config["cluster"]["num_nodes"] * master_config["cluster"]["gpus_per_node"],
+        )
+    )
```
849-863: Narrow the overly broad exception and use an explicit f-string conversion. Prefer a narrower exception (or at least log via logger) and use the `!s` formatter.

```diff
-        except Exception as e:
-            print(f"\n  ⚠️ Error displaying message samples: {str(e)}")
+        except Exception as e:
+            print(f"\n  ⚠️ Error displaying message samples: {e!s}")
             print("  ⚠️ Continuing validation without displaying samples...")
```
320-323: Nit: duplicate "Setting up models…" banner. The second banner is confusing; consider "▶ Setting up teacher model…".
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
nemo_rl/algorithms/distillation.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
nemo_rl/algorithms/distillation.py (5)
- nemo_rl/algorithms/grpo.py (2): `_should_use_async_rollouts` (406-420), `refit_policy_generation` (423-491)
- nemo_rl/algorithms/loss_functions.py (3): `DistillationLossConfig` (831-836), `DistillationLossDataDict` (839-846), `DistillationLossFn` (849-1104)
- nemo_rl/data/datasets.py (1): `rl_collate_fn` (134-178)
- nemo_rl/data/llm_message_utils.py (1): `batched_message_log_to_flat_message` (233-390)
- nemo_rl/distributed/batched_data_dict.py (4): `BatchedDataDict` (75-857), `repeat_interleave` (721-742), `size` (811-820), `to` (822-829)
🪛 Ruff (0.12.2)
nemo_rl/algorithms/distillation.py
260-263: Avoid specifying long messages outside the exception class
(TRY003)
265-268: Avoid specifying long messages outside the exception class
(TRY003)
528-528: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
529-529: Loop control variable j not used within loop body
Rename unused j to _j
(B007)
654-654: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
860-860: Do not catch blind exception: Exception
(BLE001)
861-861: Use explicit conversion flag
Replace with conversion flag
(RUF010)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/distillation.py (1)
153-156: Good assert for the required generation config. Keeps misconfiguration from slipping through.
Actionable comments posted: 0
♻️ Duplicate comments (2)
examples/run_distillation_math.py (1)
203-214: Guard the validation split; fall back to 'test' when there is no 'validation'.

```diff
-    val_dataset: Optional[AllTaskProcessedDataset] = None
-    if data.formatted_ds["validation"]:
-        val_dataset = AllTaskProcessedDataset(
-            data.formatted_ds["validation"],
-            tokenizer,
-            math_task_spec,
-            task_data_processors,
-            max_seq_length=data_config["max_input_seq_length"],
-        )
-    else:
-        val_dataset = None
+    val_dataset: Optional[AllTaskProcessedDataset] = None
+    if "validation" in data.formatted_ds and data.formatted_ds["validation"]:
+        val_src = data.formatted_ds["validation"]
+    elif "test" in data.formatted_ds and data.formatted_ds["test"]:
+        val_src = data.formatted_ds["test"]
+    else:
+        val_src = None
+    if val_src is not None:
+        val_dataset = AllTaskProcessedDataset(
+            val_src,
+            tokenizer,
+            math_task_spec,
+            task_data_processors,
+            max_seq_length=data_config["max_input_seq_length"],
+        )
```

nemo_rl/algorithms/distillation.py (1)
246-276: Prevent a 0-GPU/0-node train cluster in non-colocated mode (single- and multi-node). When num_nodes==1, train_gpus_per_node can become 0. For num_nodes>1, train_nodes can become 0 after subtraction. Guard both and give actionable errors.

Apply:
```diff
     inference_resources = generation_config["colocated"]["resources"]
     inference_gpus_per_node = inference_resources["gpus_per_node"]
     inference_nodes = inference_resources["num_nodes"]

     # validate and configure resources
     if cluster_config["num_nodes"] == 1:
         assert inference_gpus_per_node > 0, (
             "policy.generation.colocated.resources.gpus_per_node must be > 0 "
             "when cluster.num_nodes = 1 and inference is non-colocated, "
             f"but got {inference_gpus_per_node}."
         )
         assert inference_nodes is None or inference_nodes == 1, (
             "policy.generation.colocated.resources.num_nodes must be 1 or set to null "
             "when cluster.num_nodes = 1 and inference is non-colocated, "
             f"but got {inference_nodes}."
         )
-        inference_nodes = 1
-        train_gpus_per_node -= inference_gpus_per_node
+        inference_nodes = 1
+        train_nodes = 1
+        remaining_gpus = cluster_config["gpus_per_node"] - inference_gpus_per_node
+        if remaining_gpus <= 0:
+            raise ValueError(
+                "Non-colocated mode (single node) requires reserving >=1 GPU for training: "
+                f"cluster.gpus_per_node={cluster_config['gpus_per_node']}, "
+                f"inference_gpus_per_node={inference_gpus_per_node}."
+            )
+        train_gpus_per_node = remaining_gpus
     else:
         assert inference_nodes > 0, (
             "policy.generation.colocated.resources.num_nodes must be > 0 "
             "when cluster.num_nodes > 1 and inference is non-colocated, "
             f"but got {inference_nodes}."
         )
         assert (
             inference_gpus_per_node is None
             or inference_gpus_per_node == cluster_config["gpus_per_node"]
         ), (
             "policy.generation.colocated.resources.gpus_per_node must be equal to cluster.gpus_per_node or set to null "
             "when cluster.num_nodes > 1 and inference is non-colocated, "
             f"but got {inference_gpus_per_node}."
         )
         inference_gpus_per_node = cluster_config["gpus_per_node"]
         train_nodes -= inference_nodes
+        if train_nodes <= 0:
+            raise ValueError(
+                "Non-colocated mode (multi-node) requires at least 1 train node; "
+                f"decrease inference_nodes (got inference_nodes={inference_nodes})."
+            )
```
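The intent of those guards can be condensed into a standalone sketch (a hypothetical helper, not the repo's API): in non-colocated mode, split whole nodes between inference and training rather than carving GPUs out of a node.

```python
def plan_noncolocated_split(num_nodes: int, gpus_per_node: int, inference_nodes: int) -> dict:
    """Illustrative node split for non-colocated inference vs. training."""
    if num_nodes < 2:
        raise ValueError("non-colocated inference requires num_nodes >= 2")
    if not 1 <= inference_nodes < num_nodes:
        raise ValueError("must reserve at least 1 node for training")
    # Both clusters receive full nodes, so each keeps all GPUs per node.
    return {
        "train_nodes": num_nodes - inference_nodes,
        "train_gpus_per_node": gpus_per_node,
        "inference_nodes": inference_nodes,
        "inference_gpus_per_node": gpus_per_node,
    }
```

This makes the failure modes explicit: a 1-node cluster or an inference allocation that consumes every node is rejected up front instead of surfacing later as a 0-GPU cluster error.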
🧹 Nitpick comments (11)
nemo_rl/algorithms/distillation.py (7)
138-142: Align the setup() docstring with the actual return tuple (include teacher_policy). The docstring omits teacher_policy and the return order. Update for correctness.
```diff
-    Returns:
-        tuple of student_policy, student_generation,
-        train_dataloader, val_dataloader,
-        loss_fn, logger, checkpointer, distillation_save_state, master_config
+    Returns:
+        tuple of (
+            student_policy,
+            teacher_policy,
+            student_generation,
+            train_dataloader,
+            val_dataloader,
+            loss_fn,
+            logger,
+            checkpointer,
+            distillation_save_state,
+            master_config,
+        )
```
299-301: Clarify setup logs; avoid the duplicated “Setting up models…” message. Differentiate student vs. teacher for easier debugging.
```diff
-print("\n▶ Setting up models...")
+print("\n▶ Setting up student model...")
 ...
-print("\n▶ Setting up models...")
+print("\n▶ Setting up teacher model...")
```

Also applies to: 323-325
530-541: Remove unused loop indices (ruff B007) and simplify the loops.

```diff
-        for i, message_log in enumerate(repeated_batch["message_log"]):
-            for j, message in enumerate(message_log):
+        for message_log in repeated_batch["message_log"]:
+            for message in message_log:
```
652-657: warnings.warn without stacklevel; set stacklevel=2 so the warning points at the caller.

```diff
-        warnings.warn(
-            f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-            "Saving most recent k checkpoints instead."
-        )
+        warnings.warn(
+            f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
+            "Saving most recent k checkpoints instead.",
+            stacklevel=2,
+        )
```
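A minimal, runnable illustration of the stacklevel suggestion (helper name and abbreviated message are hypothetical): with `stacklevel=2`, Python attributes the warning to the helper's caller rather than to the helper itself, which is where the misconfiguration actually lives.

```python
import warnings

def _warn_checkpoint_fallback():
    # stacklevel=2 reports the caller's line, not this warn() call's line
    warnings.warn(
        "metric not found in save state; keeping most recent k checkpoints",
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _warn_checkpoint_fallback()  # the warning is attributed to this line
```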
743-747: Guard tokens/sec against total_time == 0.

```diff
-    metrics.update(
-        {
-            "tokens_per_sec_per_gpu": metrics["total_num_tokens"]
-            / total_time
-            / total_num_gpus
-        }
-    )
+    tps = (
+        metrics["total_num_tokens"] / total_time / total_num_gpus
+        if total_time > 0
+        else 0.0
+    )
+    metrics.update({"tokens_per_sec_per_gpu": tps})
```
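For illustration, the guarded computation can be factored into a tiny helper (hypothetical name `tokens_per_sec_per_gpu`; the real code works on the metrics dict inline):

```python
def tokens_per_sec_per_gpu(total_tokens: int, total_time: float, num_gpus: int) -> float:
    # report 0.0 instead of raising ZeroDivisionError when timing is degenerate
    if total_time <= 0 or num_gpus <= 0:
        return 0.0
    return total_tokens / total_time / num_gpus
```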
789-796: Don’t drop remainder validation samples; ceil the batch count.

```diff
-    max_batches = (
-        master_config["distillation"]["max_val_samples"]
-        // master_config["distillation"]["val_batch_size"]
-    )
+    max_batches = (
+        master_config["distillation"]["max_val_samples"]
+        + master_config["distillation"]["val_batch_size"]
+        - 1
+    ) // master_config["distillation"]["val_batch_size"]
```
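The ceil-division fix can be checked in isolation (hypothetical helper name `num_val_batches`): floor division silently discards any final partial batch, while `(n + b - 1) // b` keeps it.

```python
def num_val_batches(max_val_samples: int, val_batch_size: int) -> int:
    # ceil division: the final partial batch still counts as a batch
    return (max_val_samples + val_batch_size - 1) // val_batch_size
```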
851-865: Narrow the exception scope for sample printing, use !s, and prefer a warning over print.

```diff
-        try:
+        try:
             print_message_log_samples(
                 all_message_logs,
                 total_rewards,
                 num_samples=min(
                     master_config["logger"]["num_val_samples_to_print"],
                     len(all_message_logs),
                 ),
                 step=step,
             )
-        except Exception as e:
-            print(f"\n  ⚠️ Error displaying message samples: {str(e)}")
-            print("  ⚠️ Continuing validation without displaying samples...")
+        except (KeyError, IndexError, TypeError) as e:
+            warnings.warn(f"Error displaying message samples: {e!s}", stacklevel=2)
```

examples/run_distillation_math.py (4)
78-97: Simplify the prompt check; remove the unused read and chain errors. The file read is unused and the code raises regardless. Keep a clear validation.
```diff
-    # safety check: ensure prompt exists
-    if task_data_spec.prompt is None:
-        if task_data_spec.prompt_file:
-            if os.path.exists(task_data_spec.prompt_file):
-                try:
-                    with open(task_data_spec.prompt_file, "r", encoding="utf-8") as f:
-                        content = f.read()
-                except Exception as e:
-                    raise ValueError(f"Failed to read file: {e}")
-            else:
-                raise ValueError(
-                    f"Prompt file does not exist: {task_data_spec.prompt_file}"
-                )
-
-        raise ValueError(
-            f"TaskDataSpec.prompt is None. This usually means the prompt file "
-            f"'{task_data_spec.prompt_file}' could not be loaded or is empty. "
-            f"Current working directory: {os.getcwd()}, "
-            f"Absolute prompt file path: {os.path.abspath(task_data_spec.prompt_file) if task_data_spec.prompt_file else 'None'}"
-        )
+    # safety check: ensure prompt exists
+    if not task_data_spec.prompt:
+        abs_prompt = (
+            os.path.abspath(task_data_spec.prompt_file)
+            if task_data_spec.prompt_file
+            else "None"
+        )
+        raise ValueError(
+            f"TaskDataSpec.prompt is None. Ensure a valid prompt_file (resolved path: {abs_prompt})."
+        )
```
104-119: Use exception chaining and correct apply_chat_template typing.- try: - formatted_content = task_data_spec.prompt.format(problem) - except Exception as e: - raise ValueError(f"Failed to format prompt: {e}") + try: + formatted_content = task_data_spec.prompt.format(problem) + except Exception as e: + raise ValueError("Failed to format prompt") from e @@ - message: list[str] = tokenizer.apply_chat_template( # type: ignore + chat_text: str = tokenizer.apply_chat_template( # type: ignore [user_message], tokenize=False, add_generation_prompt=True, add_special_tokens=False, ) @@ - user_message["token_ids"] = tokenizer( - message, + user_message["token_ids"] = tokenizer( + chat_text, return_tensors="pt", add_special_tokens=False, )["input_ids"][0] - user_message["content"] = message + user_message["content"] = chat_textAlso applies to: 120-126
147-148: Avoid a KeyError for task_name; fall back to the spec default.

```diff
-        "task_name": datum_dict["task_name"],
+        "task_name": task_data_spec.task_name or "math",
```
258-258: Use the configured data.seed instead of the hard-coded 42.

```diff
-    ) = setup_data(tokenizer, config["data"], config["env"], 42)
+    ) = setup_data(
+        tokenizer,
+        config["data"],
+        config["env"],
+        config["data"].get("seed", 42),
+    )
```
📜 Review details
📒 Files selected for processing (2)
- examples/run_distillation_math.py (1 hunks)
- nemo_rl/algorithms/distillation.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/run_distillation_math.py (12)
- nemo_rl/algorithms/distillation.py (3): MasterConfig (100-111), distillation_train (386-765), setup (119-378)
- nemo_rl/algorithms/utils.py (1): get_tokenizer (151-267)
- nemo_rl/data/__init__.py (1): DataConfig (18-35)
- nemo_rl/data/datasets.py (1): AllTaskProcessedDataset (36-131)
- nemo_rl/data/hf_datasets/deepscaler.py (1): DeepScalerDataset (67-78)
- nemo_rl/data/hf_datasets/openmathinstruct2.py (1): OpenMathInstruct2Dataset (77-105)
- nemo_rl/data/interfaces.py (3): DatumSpec (32-40), TaskDataProcessFnCallable (89-100), TaskDataSpec (53-86)
- nemo_rl/distributed/ray_actor_environment_registry.py (1): get_actor_python_env (43-58)
- nemo_rl/distributed/virtual_cluster.py (1): init_ray (75-161)
- nemo_rl/environments/math_environment.py (1): MathEnvironment (222-372)
- nemo_rl/utils/logger.py (1): get_next_experiment_dir (1222-1256)
- nemo_rl/models/generation/__init__.py (1): configure_generation_config (24-45)

nemo_rl/algorithms/distillation.py (11)
- nemo_rl/algorithms/grpo.py (2): _should_use_async_rollouts (406-420), refit_policy_generation (423-491)
- nemo_rl/algorithms/loss_functions.py (3): DistillationLossConfig (831-836), DistillationLossDataDict (839-846), DistillationLossFn (849-1104)
- nemo_rl/data/datasets.py (1): rl_collate_fn (134-178)
- nemo_rl/data/llm_message_utils.py (1): batched_message_log_to_flat_message (233-390)
- nemo_rl/distributed/batched_data_dict.py (3): BatchedDataDict (75-857), size (811-820), to (822-829)
- nemo_rl/experience/rollouts.py (2): run_async_multi_turn_rollout (751-895), run_multi_turn_rollout (316-522)
- nemo_rl/models/generation/vllm/vllm_generation.py (1): VllmGeneration (47-784)
- nemo_rl/models/policy/lm_policy.py (8): Policy (56-696), finish_generation (554-556), prepare_refit_info (562-571), prepare_for_generation (539-541), prepare_for_lp_inference (548-552), get_topk_logits (336-400), offload_after_refit (650-653), train (402-494)
- nemo_rl/utils/checkpoint.py (6): CheckpointingConfig (35-52), CheckpointManager (55-269), get_latest_checkpoint_path (238-251), load_training_info (253-269), init_tmp_checkpoint (87-126), finalize_checkpoint (128-157)
- nemo_rl/utils/logger.py (4): Logger (710-933), LoggerConfig (69-79), print_message_log_samples (1022-1219), log_batched_dict_as_jsonl (804-828)
- nemo_rl/utils/timer.py (7): TimeoutChecker (264-321), Timer (22-248), start_iterations (310-311), time (110-123), mark_iteration (313-321), check_save (284-308), get_timing_metrics (196-233)
🪛 Ruff (0.12.2)
examples/run_distillation_math.py
83-83: Local variable content is assigned to but never used
Remove assignment to unused variable content
(F841)
84-84: Do not catch blind exception: Exception
(BLE001)
85-85: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
87-89: Avoid specifying long messages outside the exception class
(TRY003)
91-96: Avoid specifying long messages outside the exception class
(TRY003)
106-106: Do not catch blind exception: Exception
(BLE001)
107-107: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
107-107: Avoid specifying long messages outside the exception class
(TRY003)
180-180: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/algorithms/distillation.py
530-530: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
531-531: Loop control variable j not used within loop body
Rename unused j to _j
(B007)
652-652: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
862-862: Do not catch blind exception: Exception
(BLE001)
863-863: Use explicit conversion flag
Replace with conversion flag
(RUF010)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Post submodule check comment / Comment on PR
ℹ️ File Consistency Check — based on commit b9695fd (PR #1006). ✅ DTensor Policy Worker Synchronization Check: both DTensor policy worker files were modified in this PR:
Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning. |
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com>
@zpqiu is it possible to write a quick blog on this feature?
Sure, will do after the holiday.
Signed-off-by: shuo_nvidia <shuoyang@nvidia.com> Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com> Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com> Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com> Signed-off-by: shuo-nvidia <shuoyang@nvidia.com> Co-authored-by: shuo_nvidia <shuoyang@nvidia.com>
Signed-off-by: shuo_nvidia <shuoyang@nvidia.com> Signed-off-by: alexchiu <qiuzhaopeng@foxmail.com> Signed-off-by: Zhaopeng Qiu <alexq@nvidia.com> Signed-off-by: Zhaopeng Qiu <qiuzhaopeng@foxmail.com> Signed-off-by: shuo-nvidia <shuoyang@nvidia.com> Co-authored-by: shuo_nvidia <shuoyang@nvidia.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>

What does this PR do ?
This PR implements the on-policy distillation algorithm used in Qwen3, a knowledge-distillation approach in which the student model generates on-policy sequences from sampled prompts. The student is then fine-tuned by aligning its output logits with those of a larger teacher model (such as Qwen3-32B or Qwen3-235B-A22B), minimizing the KL divergence between the two distributions. Because the teacher's knowledge is applied to the student's own generated trajectories, this enables efficient fine-tuning: the resulting lightweight models approach or surpass larger models while costing less compute than reinforcement learning.
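As a toy, per-token illustration of the objective described above (not the PR's actual `DistillationLossFn`, which operates on batched tensors with distributed top-k support), forward KL between teacher and student restricted to the teacher's top-k tokens might be sketched as:

```python
import math

def softmax(logits):
    # numerically stable softmax over a plain list of floats
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def topk_distillation_loss(teacher_logits, student_logits, k):
    """Forward KL(teacher || student) over the teacher's top-k tokens,
    with both distributions renormalized on that shared support."""
    topk = sorted(range(len(teacher_logits)),
                  key=lambda i: teacher_logits[i], reverse=True)[:k]
    t = softmax([teacher_logits[i] for i in topk])
    s = softmax([student_logits[i] for i in topk])
    return sum(ti * math.log(ti / si) for ti, si in zip(t, s))
```

The loss is zero when the student matches the teacher and positive otherwise, so minimizing it pulls the student's next-token distribution toward the teacher's on the student's own rollouts.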
Issues
#910
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
Experiments on Math Domain
cc: @terrykong @sharathts
Summary by CodeRabbit
New Features
Performance
Documentation
Tests