diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 28126b526c..d6e38c300e 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,6 +1,10 @@ # SFT Algorithm Configuration sft: + ## total number of steps to train will equal + ## min((max_num_epochs * len(train_dataloader)), max_num_steps) + max_num_epochs: 1 max_num_steps: 60 + val_period: 10 val_batches: 8 val_global_batch_size: 32 diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index f0dc551026..a42967aa31 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -42,20 +42,25 @@ class SFTSaveState(TypedDict): - step: int + epoch: int # Track current epoch + step: int # Track step within current epoch + total_steps: int # Track total number of steps across all epochs val_loss: float consumed_samples: int def _default_sft_save_state() -> SFTSaveState: return { + "epoch": 0, "step": 0, + "total_steps": 0, "consumed_samples": 0, } class SFTConfig(TypedDict): max_num_steps: int + max_num_epochs: int val_period: int val_batches: int val_global_batch_size: int @@ -141,6 +146,7 @@ def setup( batch_size=policy_config["train_global_batch_size"], shuffle=True, collate_fn=rl_collate_fn, + drop_last=True, ) if last_checkpoint_path is not None: @@ -333,17 +339,22 @@ def sft_train( if sft_save_state is None: sft_save_state = _default_sft_save_state() - step = 0 + current_epoch = 0 + current_step = 0 + total_steps = 0 else: - step = sft_save_state["step"] + current_epoch = sft_save_state["epoch"] + current_step = sft_save_state["step"] + total_steps = sft_save_state["total_steps"] sft_config = master_config["sft"] # Validation configuration val_period = sft_config["val_period"] val_at_start = sft_config["val_at_start"] + max_num_epochs = sft_config["max_num_epochs"] # Run validation at the start if configured - if val_at_start and step == 0: + if val_at_start and total_steps == 0: print("\n🔍 Running initial validation...") val_metrics, validation_timings = validate( policy, @@ -358,134 +369,156 @@ def sft_train( val_mbs=sft_config["val_micro_batch_size"], ) - logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation") + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") policy.prepare_for_training() - for batch in train_dataloader: - print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") - - with timer.time("total_step_time"): - # Prepare batch and generate responses - print("▶ Preparing batch...") - with timer.time("data_processing"): - ## add loss mask based on role to every message - add_loss_mask_to_message_log( - batch["message_log"], - roles_to_train_on=["assistant"], - ) + while ( + current_epoch < max_num_epochs + and total_steps < master_config["sft"]["max_num_steps"] + ): + print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") - cat_and_padded, input_lengths = batched_message_log_to_flat_message( - batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - make_sequence_length_divisible_by=master_config["policy"][ - "make_sequence_length_divisible_by" - ], - ) + for batch in train_dataloader: + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['sft']['max_num_steps'])} {'=' * 25}" + ) - train_data: BatchedDataDict = BatchedDataDict( - { - "input_ids": cat_and_padded["token_ids"], - "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": batch["loss_multiplier"], - } - ) + with timer.time("total_step_time"): + # Prepare batch and generate responses + print("▶ Preparing batch...") + with timer.time("data_processing"): + ## add loss mask based on role to every message + add_loss_mask_to_message_log( + batch["message_log"], + roles_to_train_on=["assistant"], + ) - ## train_data.to("cpu") - print("▶ Taking a training step...") - train_results = policy.train(train_data, loss_fn) - - # Run validation if it's a validation step - if val_period > 0 and (step + 1) % val_period == 0: - val_metrics, validation_timings = validate( - policy, - val_dataloader, - tokenizer, - loss_fn, - step=step + 1, - master_config=master_config, - sft_task_spec=sft_task_spec, - val_batches=sft_config["val_batches"], - val_batch_size=sft_config["val_global_batch_size"], - val_mbs=sft_config["val_micro_batch_size"], - ) - logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" - ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") - - ## Checkpointing - sft_save_state["consumed_samples"] += master_config["policy"][ - "train_global_batch_size" - ] - if ( - master_config["checkpointing"]["enabled"] - and (step + 1) % master_config["checkpointing"]["save_period"] == 0 - ): # +1 because step is 0-indexed - is_last_checkpoint = ( - min(len(train_dataloader), master_config["sft"]["max_num_steps"]) - - (step + 1) - < master_config["checkpointing"]["save_period"] - ) + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + make_sequence_length_divisible_by=master_config["policy"][ + "make_sequence_length_divisible_by" + ], + ) - sft_save_state["step"] = step + 1 - sft_save_state["val_loss"] = val_metrics["val_loss"] - with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") - checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, sft_save_state, master_config + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } ) - policy.save_checkpoint( - weights_path=os.path.join(checkpoint_path, "policy", "weights"), - optimizer_path=os.path.join( - checkpoint_path, "policy", "optimizer" - ), - tokenizer_path=os.path.join( - checkpoint_path, "policy", "tokenizer" - ), - save_hf=is_last_checkpoint, + print("▶ Taking a training step...") + train_results = policy.train(train_data, loss_fn) + + # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=total_steps + 1, + master_config=master_config, + sft_task_spec=sft_task_spec, + val_batches=sft_config["val_batches"], + val_batch_size=sft_config["val_global_batch_size"], + val_mbs=sft_config["val_micro_batch_size"], ) - torch.save( - train_dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" ) - checkpointer.finalize_checkpoint(checkpoint_path) - - losses = train_results["loss"] - metrics = { - "loss": train_results["loss"].numpy(), - } - metrics.update(train_results["all_mb_metrics"]) - for k, v in metrics.items(): - if k == "num_valid_samples": - metrics[k] = np.sum(v).item() - else: - metrics[k] = np.mean(v).item() - timing_metrics = timer.get_timing_metrics(reduction_op="sum") - - print("\n📊 Training Results:") - print(f" • Loss: {float(metrics['loss']):.4f}") - print("\n⏱️ Timing:") - # Display total time first, separately - total_time = timing_metrics.get("total_step_time", 0) - print(f" • Total step time: {total_time:.2f}s") - - # Display all other timing metrics (if any) - for k, v in sorted( - timing_metrics.items(), key=lambda item: item[1], reverse=True - ): - if k != "total_step_time": - percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)") - - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") - - timer.reset() - step += 1 - - if step >= master_config["sft"]["max_num_steps"]: - break + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + + ## Checkpointing + sft_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + if ( + master_config["checkpointing"]["enabled"] + and (total_steps + 1) + % master_config["checkpointing"]["save_period"] + == 0 + ): # +1 because step is 0-indexed + is_last_checkpoint = ( + min( + len(train_dataloader) * max_num_epochs, + master_config["sft"]["max_num_steps"], + ) + - (total_steps + 1) + < master_config["checkpointing"]["save_period"] + ) + + sft_save_state["step"] = (current_step + 1) % len(train_dataloader) + sft_save_state["total_steps"] = total_steps + 1 + sft_save_state["epoch"] = current_epoch + sft_save_state["val_loss"] = val_metrics["val_loss"] + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {total_steps + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, sft_save_state, master_config + ) + + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), + save_hf=is_last_checkpoint, + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + losses = train_results["loss"] + metrics = { + "loss": train_results["loss"].numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + for k, v in metrics.items(): + if k == "num_valid_samples": + metrics[k] = np.sum(v).item() + else: + metrics[k] = np.mean(v).item() + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print("\n📊 Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + print("\n⏱️ Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + + # Display all other timing metrics (if any) + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + + if total_steps >= master_config["sft"]["max_num_steps"]: + return + + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py new file mode 100644 index 0000000000..3b63c9a8b5 --- /dev/null +++ b/tests/unit/algorithms/test_sft.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from unittest.mock import MagicMock +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from nemo_reinforcer.algorithms.sft import sft_train, _default_sft_save_state +from nemo_reinforcer.algorithms.loss_functions import NLLLoss + + +@pytest.fixture +def mock_components(): + # Create mock components + policy = MagicMock() + policy.train.return_value = {"loss": torch.tensor(0.5), "all_mb_metrics": {}} + + # Create a proper message log structure with token_ids + mock_batch = { + "message_log": [[{"token_ids": torch.tensor([1, 2, 3]), "role": "assistant"}]], + "loss_multiplier": torch.tensor(1.0), + } + + # Create mock dataloader with 10 batches that can be iterated multiple times + train_dataloader = MagicMock(spec=StatefulDataLoader) + + def train_iter(self): + return iter([mock_batch] * 10) + + train_dataloader.__iter__ = train_iter + train_dataloader.__len__ = MagicMock(return_value=10) + + val_dataloader = MagicMock(spec=StatefulDataLoader) + + def val_iter(self): + return iter([mock_batch] * 10) + + val_dataloader.__iter__ = val_iter + val_dataloader.__len__ = MagicMock(return_value=10) + + tokenizer = MagicMock() + tokenizer.pad_token_id = 0 + + loss_fn = NLLLoss() + logger = MagicMock() + checkpointer = MagicMock() + sft_task_spec = MagicMock() + + # Create mock master config + master_config = { + "sft": { + "max_num_steps": 5, + "max_num_epochs": 2, + "val_period": 100, + "val_batches": 1, + "val_global_batch_size": 1, + "val_micro_batch_size": 1, + "val_at_start": False, + }, + "policy": { + "train_global_batch_size": 1, + "make_sequence_length_divisible_by": 8, + }, + "checkpointing": {"enabled": False}, + } + + return { + "policy": policy, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + "tokenizer": tokenizer, + "loss_fn": loss_fn, + "logger": logger, + "checkpointer": checkpointer, + "sft_task_spec": sft_task_spec, + "master_config": master_config, + } + + +def test_exit_on_max_steps(mock_components): + """Test that training loop exits when max_num_steps is reached""" + # Set max steps to 12, which is less than len(train_dataloader) * max_num_epochs + mock_components["master_config"]["sft"]["max_num_steps"] = 12 + + sft_save_state = _default_sft_save_state() + + # Run training + sft_train( + mock_components["policy"], + mock_components["train_dataloader"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["loss_fn"], + mock_components["master_config"], + mock_components["logger"], + mock_components["sft_task_spec"], + mock_components["checkpointer"], + sft_save_state, + ) + + # Verify we only trained for 12 steps + assert mock_components["policy"].train.call_count == 12 + + +def test_exit_on_max_epochs(mock_components): + """Test that training loop exits when max_num_epochs is reached""" + # Set max epochs to 2 and max steps to a large number + mock_components["master_config"]["sft"]["max_num_epochs"] = 2 + mock_components["master_config"]["sft"]["max_num_steps"] = 100 + + sft_save_state = _default_sft_save_state() + + # Run training + sft_train( + mock_components["policy"], + mock_components["train_dataloader"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["loss_fn"], + mock_components["master_config"], + mock_components["logger"], + mock_components["sft_task_spec"], + mock_components["checkpointer"], + sft_save_state, + ) + + # Verify we trained for exactly two epochs (20 batches) + assert mock_components["policy"].train.call_count == 20