From cd7b0f93c89f6ff3497ee7b59a6f2d60a4f4c48c Mon Sep 17 00:00:00 2001 From: ashors1 Date: Sun, 13 Apr 2025 17:36:57 -0700 Subject: [PATCH 1/6] support multi-epoch training Signed-off-by: ashors1 --- examples/configs/sft.yaml | 1 + nemo_reinforcer/algorithms/sft.py | 267 ++++++++++++--------- nemo_reinforcer/models/policy/hf_policy.py | 2 +- 3 files changed, 151 insertions(+), 119 deletions(-) mode change 100644 => 100755 nemo_reinforcer/algorithms/sft.py diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index bb8467165f..dccd39d5f4 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,5 +1,6 @@ # SFT Algorithm Configuration sft: + max_num_epochs: 1 max_num_steps: 60 val_period: 10 val_batches: 8 diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py old mode 100644 new mode 100755 index b5bb41aec5..583eca36ee --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -40,20 +40,25 @@ class SFTSaveState(TypedDict): - step: int + epoch: int # Track current epoch + step: int # Track step within current epoch + total_steps: int # Track total number of steps across all epochs val_loss: float consumed_samples: int def _default_sft_save_state() -> SFTSaveState: return { + "epoch": 0, "step": 0, + "total_steps": 0, "consumed_samples": 0, } class SFTConfig(TypedDict): max_num_steps: int + max_num_epochs: int val_period: int val_batches: int val_global_batch_size: int @@ -138,6 +143,7 @@ def setup( batch_size=policy_config["train_global_batch_size"], shuffle=True, collate_fn=rl_collate_fn, + drop_last=True, ) if last_checkpoint_path is not None: @@ -309,17 +315,22 @@ def sft_train( if sft_save_state is None: sft_save_state = _default_sft_save_state() - step = 0 + current_epoch = 0 + current_step = 0 + total_steps = 0 else: - step = sft_save_state["step"] + current_epoch = sft_save_state["epoch"] + current_step = sft_save_state["step"] + total_steps = sft_save_state["total_steps"] sft_config = master_config["sft"] # Validation configuration val_period = sft_config["val_period"] val_at_start = sft_config["val_at_start"] + max_num_epochs = sft_config["max_num_epochs"] # Run validation at the start if configured - if val_at_start and step == 0: + if val_at_start and total_steps == 0: print("\n🔍 Running initial validation...") val_metrics, validation_timings = validate( policy, @@ -334,124 +345,144 @@ def sft_train( val_mbs=sft_config["val_micro_batch_size"], ) - logger.log_metrics(val_metrics, step, prefix="validation") - logger.log_metrics(validation_timings, step, prefix="timing/validation") + logger.log_metrics(val_metrics, total_steps, prefix="validation") + logger.log_metrics(validation_timings, total_steps, prefix="timing/validation") policy.prepare_for_training() - for batch in train_dataloader: - print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") - - with timer.time("total_step_time"): - # Prepare batch and generate responses - print("▶ Preparing batch...") - with timer.time("data_processing"): - ## add loss mask based on role to every message - add_loss_mask_to_message_log( - batch["message_log"], - roles_to_train_on=["assistant"], - ) - - cat_and_padded, input_lengths = batched_message_log_to_flat_message( - batch["message_log"], - pad_value_dict={"token_ids": tokenizer.pad_token_id}, - ) - - train_data: BatchedDataDict = BatchedDataDict( - { - "input_ids": cat_and_padded["token_ids"], - "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": batch["loss_multiplier"], - } - ) - - ## train_data.to("cpu") - print("▶ Taking a training step...") - train_results = policy.train(train_data, loss_fn) - - # Run validation if it's a validation step - if val_period > 0 and (step + 1) % val_period == 0: - val_metrics, validation_timings = validate( - policy, - val_dataloader, - tokenizer, - loss_fn, - step=step + 1, - master_config=master_config, - sft_task_spec=sft_task_spec, - val_batches=sft_config["val_batches"], - val_batch_size=sft_config["val_global_batch_size"], - val_mbs=sft_config["val_micro_batch_size"], - ) - logger.log_metrics( - validation_timings, step + 1, prefix="timing/validation" - ) - logger.log_metrics(val_metrics, step + 1, prefix="validation") - - ## Checkpointing - sft_save_state["consumed_samples"] += master_config["policy"][ - "train_global_batch_size" - ] - if ( - master_config["checkpointing"]["enabled"] - and (step + 1) % master_config["checkpointing"]["save_period"] == 0 - ): # +1 because step is 0-indexed - is_last_checkpoint = ( - min(len(train_dataloader), master_config["sft"]["max_num_steps"]) - - (step + 1) - < master_config["checkpointing"]["save_period"] - ) - - sft_save_state["step"] = step + 1 - sft_save_state["val_loss"] = val_metrics["val_loss"] - with timer.time("checkpointing"): - print(f"Saving checkpoint for step {step + 1}...") - checkpoint_path = checkpointer.init_tmp_checkpoint( - step + 1, sft_save_state, master_config + while current_epoch < max_num_epochs: + print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") + + for batch in train_dataloader: + print( + f"\n{'=' * 25} Step {current_step + 1}/{min(len(train_dataloader), master_config['sft']['max_num_steps'])} {'=' * 25}" + ) + + with timer.time("total_step_time"): + # Prepare batch and generate responses + print("▶ Preparing batch...") + with timer.time("data_processing"): + ## add loss mask based on role to every message + add_loss_mask_to_message_log( + batch["message_log"], + roles_to_train_on=["assistant"], + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) + + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } ) - policy.save_checkpoint( - weights_path=os.path.join(checkpoint_path, "policy", "weights"), - optimizer_path=os.path.join( - checkpoint_path, "policy", "optimizer" - ), - save_hf=is_last_checkpoint, + ## train_data.to("cpu") + print("▶ Taking a training step...") + train_results = policy.train(train_data, loss_fn) + + # Run validation if it's a validation step + if val_period > 0 and (total_steps + 1) % val_period == 0: + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=total_steps + 1, + master_config=master_config, + sft_task_spec=sft_task_spec, + val_batches=sft_config["val_batches"], + val_batch_size=sft_config["val_global_batch_size"], + val_mbs=sft_config["val_micro_batch_size"], + ) + logger.log_metrics( + validation_timings, total_steps + 1, prefix="timing/validation" ) - torch.save( - train_dataloader.state_dict(), - os.path.join(checkpoint_path, "train_dataloader.pt"), + logger.log_metrics( + val_metrics, total_steps + 1, prefix="validation" + ) + + ## Checkpointing + sft_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + if ( + master_config["checkpointing"]["enabled"] + and (total_steps + 1) + % master_config["checkpointing"]["save_period"] + == 0 + ): # +1 because step is 0-indexed + is_last_checkpoint = ( + min( + len(train_dataloader) * max_num_epochs, + master_config["sft"]["max_num_steps"], + ) + - (total_steps + 1) + < master_config["checkpointing"]["save_period"] ) - checkpointer.finalize_checkpoint(checkpoint_path) - - losses = train_results["loss"] - metrics = { - "loss": train_results["loss"].numpy(), - } - metrics.update(train_results["all_mb_metrics"]) - metrics = {k: np.mean(v).item() for k, v in metrics.items()} - timing_metrics = timer.get_timing_metrics(reduction_op="sum") - - print("\n📊 Training Results:") - print(f" • Loss: {float(metrics['loss']):.4f}") - print("\n⏱️ Timing:") - # Display total time first, separately - total_time = timing_metrics.get("total_step_time", 0) - print(f" • Total step time: {total_time:.2f}s") - - # Display all other timing metrics (if any) - for k, v in sorted( - timing_metrics.items(), key=lambda item: item[1], reverse=True - ): - if k != "total_step_time": - percent = (v / total_time * 100) if total_time > 0 else 0 - print(f" • {k}: {v:.2f}s ({percent:.1f}%)") - - logger.log_metrics(metrics, step + 1, prefix="train") - logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") - - timer.reset() - step += 1 - - if step >= master_config["sft"]["max_num_steps"]: - break + + sft_save_state["step"] = (current_step + 1) % len(train_dataloader) + sft_save_state["total_steps"] = total_steps + 1 + sft_save_state["epoch"] = current_epoch + sft_save_state["val_loss"] = val_metrics["val_loss"] + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {total_steps + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + total_steps + 1, sft_save_state, master_config + ) + + policy.save_checkpoint( + weights_path=os.path.join( + checkpoint_path, "policy", "weights" + ), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + save_hf=is_last_checkpoint, + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + + losses = train_results["loss"] + metrics = { + "loss": train_results["loss"].numpy(), + } + metrics.update(train_results["all_mb_metrics"]) + metrics = {k: np.mean(v).item() for k, v in metrics.items()} + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + + print("\n📊 Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + print("\n⏱️ Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + + # Display all other timing metrics (if any) + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") + + timer.reset() + current_step += 1 + total_steps += 1 + + if total_steps >= master_config["sft"]["max_num_steps"]: + break + + current_epoch += 1 + current_step = 0 # Reset step counter for new epoch diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 051e56e23f..a8b4b18fc1 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -955,7 +955,7 @@ def train( # Shard and replicate the batch shards = self.dp_size sharded_data = data.shard_by_batch_size( - shards, batch_size=self.cfg["train_global_batch_size"] + shards, batch_size=gbs or self.cfg["train_global_batch_size"] ) # Train each shard in parallel From f6aa8990b66bb9b5b2a99c779c659930bfea65b4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 14 Apr 2025 09:15:28 -0700 Subject: [PATCH 2/6] small fix Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) mode change 100755 => 100644 nemo_reinforcer/algorithms/sft.py diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py old mode 100755 new mode 100644 index 583eca36ee..e69ecdf87b --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -482,7 +482,7 @@ def sft_train( total_steps += 1 if total_steps >= master_config["sft"]["max_num_steps"]: - break + return current_epoch += 1 current_step = 0 # Reset step counter for new epoch From 3992443a978e61704fb47a6e4116ae50feef6a9e Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 09:57:14 -0700 Subject: [PATCH 3/6] add tests, clean up Signed-off-by: ashors1 --- examples/configs/sft.yaml | 3 + nemo_reinforcer/algorithms/sft.py | 6 +- nemo_reinforcer/models/policy/hf_policy.py | 10 +- tests/unit/algorithms/test_sft.py | 122 +++++++++++++++++++++ 4 files changed, 134 insertions(+), 7 deletions(-) create mode 100644 tests/unit/algorithms/test_sft.py diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index dccd39d5f4..9c1edeaed0 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,7 +1,10 @@ # SFT Algorithm Configuration sft: + ## total number of steps to train will equal + ## min((max_num_epochs * len(train_dataloader)), max_num_steps) max_num_epochs: 1 max_num_steps: 60 + val_period: 10 val_batches: 8 val_global_batch_size: 32 diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index e69ecdf87b..c16c3f85e3 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -350,7 +350,10 @@ def sft_train( policy.prepare_for_training() - while current_epoch < max_num_epochs: + while ( + current_epoch < max_num_epochs + and total_steps < master_config["sft"]["max_num_steps"] + ): print(f"\n{'=' * 25} Epoch {current_epoch + 1}/{max_num_epochs} {'=' * 25}") for batch in train_dataloader: @@ -382,7 +385,6 @@ def sft_train( } ) - ## train_data.to("cpu") print("▶ Taking a training step...") train_results = policy.train(train_data, loss_fn) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index a8b4b18fc1..9dc223fa94 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -952,11 +952,11 @@ def train( mbs: Optional[int] = None, ): """Train the policy on a batch of data with a given loss function.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] # Shard and replicate the batch shards = self.dp_size - sharded_data = data.shard_by_batch_size( - shards, batch_size=gbs or self.cfg["train_global_batch_size"] - ) + sharded_data = data.shard_by_batch_size(shards, batch_size=batch_size) # Train each shard in parallel futures = self.worker_group.run_all_workers_multiple_data( @@ -965,8 +965,8 @@ def train( common_kwargs={ "loss_fn": loss_fn, "eval_mode": eval_mode, - "gbs": gbs, - "mbs": mbs, + "gbs": batch_size, + "mbs": micro_batch_size, }, ) results = self.worker_group.get_all_worker_results(futures) diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py new file mode 100644 index 0000000000..29fb6601a2 --- /dev/null +++ b/tests/unit/algorithms/test_sft.py @@ -0,0 +1,122 @@ +import pytest +from unittest.mock import MagicMock +import torch +from torchdata.stateful_dataloader import StatefulDataLoader +from nemo_reinforcer.algorithms.sft import sft_train, _default_sft_save_state +from nemo_reinforcer.algorithms.loss_functions import NLLLoss + + +@pytest.fixture +def mock_components(): + # Create mock components + policy = MagicMock() + policy.train.return_value = {"loss": torch.tensor(0.5), "all_mb_metrics": {}} + + # Create a proper message log structure with token_ids + mock_batch = { + "message_log": [[{"token_ids": torch.tensor([1, 2, 3]), "role": "assistant"}]], + "loss_multiplier": torch.tensor(1.0), + } + + # Create mock dataloader with 10 batches that can be iterated multiple times + train_dataloader = MagicMock(spec=StatefulDataLoader) + + def train_iter(self): + return iter([mock_batch] * 10) + + train_dataloader.__iter__ = train_iter + train_dataloader.__len__ = MagicMock(return_value=10) + + val_dataloader = MagicMock(spec=StatefulDataLoader) + + def val_iter(self): + return iter([mock_batch] * 10) + + val_dataloader.__iter__ = val_iter + val_dataloader.__len__ = MagicMock(return_value=10) + + tokenizer = MagicMock() + tokenizer.pad_token_id = 0 + + loss_fn = NLLLoss() + logger = MagicMock() + checkpointer = MagicMock() + sft_task_spec = MagicMock() + + # Create mock master config + master_config = { + "sft": { + "max_num_steps": 5, + "max_num_epochs": 2, + "val_period": 100, + "val_batches": 1, + "val_global_batch_size": 1, + "val_micro_batch_size": 1, + "val_at_start": False, + }, + "policy": {"train_global_batch_size": 1}, + "checkpointing": {"enabled": False}, + } + + return { + "policy": policy, + "train_dataloader": train_dataloader, + "val_dataloader": val_dataloader, + "tokenizer": tokenizer, + "loss_fn": loss_fn, + "logger": logger, + "checkpointer": checkpointer, + "sft_task_spec": sft_task_spec, + "master_config": master_config, + } + + +def test_exit_on_max_steps(mock_components): + """Test that training loop exits when max_num_steps is reached""" + # Set max steps to 12, which is less than len(train_dataloader) * max_num_epochs + mock_components["master_config"]["sft"]["max_num_steps"] = 12 + + sft_save_state = _default_sft_save_state() + + # Run training + sft_train( + mock_components["policy"], + mock_components["train_dataloader"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["loss_fn"], + mock_components["master_config"], + mock_components["logger"], + mock_components["sft_task_spec"], + mock_components["checkpointer"], + sft_save_state, + ) + + # Verify we only trained for 12 steps + assert mock_components["policy"].train.call_count == 12 + + +def test_exit_on_max_epochs(mock_components): + """Test that training loop exits when max_num_epochs is reached""" + # Set max epochs to 2 and max steps to a large number + mock_components["master_config"]["sft"]["max_num_epochs"] = 2 + mock_components["master_config"]["sft"]["max_num_steps"] = 100 + + sft_save_state = _default_sft_save_state() + + # Run training + sft_train( + mock_components["policy"], + mock_components["train_dataloader"], + mock_components["val_dataloader"], + mock_components["tokenizer"], + mock_components["loss_fn"], + mock_components["master_config"], + mock_components["logger"], + mock_components["sft_task_spec"], + mock_components["checkpointer"], + sft_save_state, + ) + + # Verify we trained for exactly two epochs (20 batches) + assert mock_components["policy"].train.call_count == 20 From fd00d6c386c359698ace87861ac4e7f5134b7820 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 16 Apr 2025 10:00:08 -0700 Subject: [PATCH 4/6] add copyright header Signed-off-by: ashors1 --- tests/unit/algorithms/test_sft.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py index 29fb6601a2..6660765b81 100644 --- a/tests/unit/algorithms/test_sft.py +++ b/tests/unit/algorithms/test_sft.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest from unittest.mock import MagicMock import torch From e9c985eb9112ff4f2a85bacd2208386bb9b3c029 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Mon, 21 Apr 2025 17:50:44 -0700 Subject: [PATCH 5/6] fix issue with rebase Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 169f3ca81a..8693d7edab 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -455,6 +455,9 @@ def sft_train( optimizer_path=os.path.join( checkpoint_path, "policy", "optimizer" ), + tokenizer_path=os.path.join( + checkpoint_path, "policy", "tokenizer" + ), save_hf=is_last_checkpoint, ) torch.save( From 85a70fc557d89e13d78d837cf57185215c0d8276 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Tue, 22 Apr 2025 09:52:28 -0700 Subject: [PATCH 6/6] fix sft test following rebase Signed-off-by: ashors1 --- tests/unit/algorithms/test_sft.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py index 6660765b81..3b63c9a8b5 100644 --- a/tests/unit/algorithms/test_sft.py +++ b/tests/unit/algorithms/test_sft.py @@ -68,7 +68,10 @@ def val_iter(self): "val_micro_batch_size": 1, "val_at_start": False, }, - "policy": {"train_global_batch_size": 1}, + "policy": { + "train_global_batch_size": 1, + "make_sequence_length_divisible_by": 8, + }, "checkpointing": {"enabled": False}, }