From e14aa12d4b2ac7cdc9e4e3440b4a4d6d3c388fab Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 20:22:42 -0700 Subject: [PATCH 01/11] add validation and checkpointing to sft Signed-off-by: ashors1 --- examples/configs/sft.yaml | 16 +- examples/run_sft.py | 10 +- nemo_reinforcer/algorithms/sft.py | 385 ++++++++++++++++++--- nemo_reinforcer/models/policy/hf_policy.py | 46 ++- 4 files changed, 392 insertions(+), 65 deletions(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index abe1b7ed98..44a4b15d9c 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,9 +1,17 @@ # SFT Algorithm Configuration sft: - num_steps: 100 - #val_period: 10 - #val_at_start: true - #checkpoint_dir: "results/sft" + num_steps: 20 + val_period: 1 + val_batches: 8 + val_at_start: true + +checkpointing: + enabled: true + checkpoint_dir: "results/sft" ## TODO: get checkpointing to work with relative paths + metric_name: "val_loss" + higher_is_better: false + keep_top_k: 3 + save_period: 1 policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" diff --git a/examples/run_sft.py b/examples/run_sft.py index f8649a9484..a299614abd 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -69,21 +69,27 @@ def main(): ( policy, cluster, - dataloader, + train_dataloader, + val_dataloader, tokenizer, loss_fn, master_config, logger, sft_task_spec, + checkpointer, + sft_save_state, ) = setup(config) sft_train( policy, - dataloader, + train_dataloader, + val_dataloader, tokenizer, loss_fn, master_config, logger, sft_task_spec, + checkpointer, + sft_save_state, ) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 1f948b1b5c..aec9b2f9db 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Tuple, TypedDict +import os +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, TypedDict -from torch.utils.data import DataLoader +import torch +from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer - from nemo_reinforcer.algorithms.loss_functions import ( NLLLoss, ) @@ -29,14 +31,32 @@ ) from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_reinforcer.models.interfaces import PolicyInterface from nemo_reinforcer.models.policy.hf_policy import HfPolicy from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.utils.checkpoint import CheckpointManager, CheckpointingConfig from nemo_reinforcer.utils.logger import Logger, LoggerConfig from nemo_reinforcer.utils.timer import Timer +class SFTSaveState(TypedDict): + step: int + val_loss: float + consumed_samples: int + + +def _default_sft_save_state() -> SFTSaveState: + return { + "step": 0, + "consumed_samples": 0, + } + + class SFTConfig(TypedDict): num_steps: int + val_period: int + val_at_start: bool + checkpoint_dir: str class MasterConfig(TypedDict): @@ -45,8 +65,12 @@ class MasterConfig(TypedDict): sft: SFTConfig logger: LoggerConfig cluster: ClusterConfig + checkpointing: CheckpointingConfig +# ======================================================= +# Data Processing +# ======================================================= def sft_preprocessor( datum_dict: Dict[str, Any], task_data_spec: TaskDataSpec, @@ -80,16 +104,22 @@ def sft_preprocessor( return output +# ======================================================= +# Setup & Initialization +# ======================================================= def setup( master_config: MasterConfig, ) -> Tuple[ HfPolicy, RayVirtualCluster, - DataLoader, + StatefulDataLoader, + Optional[StatefulDataLoader], AutoTokenizer, NLLLoss, MasterConfig, Logger, + TaskDataSpec, + SFTSaveState, ]: """Main entry point for running SFT algorithm. @@ -101,8 +131,33 @@ def setup( data_config = master_config["data"] logger_config = master_config["logger"] cluster_config = master_config["cluster"] + sft_config = master_config["sft"] + + # ========================== + # Checkpointing + # ========================== + checkpointer = CheckpointManager(master_config["checkpointing"]) + last_checkpoint_path = checkpointer.get_latest_checkpoint_path() + sft_save_state: Optional[SFTSaveState] = checkpointer.load_training_info( + last_checkpoint_path + ) + # config validation checks + if master_config["checkpointing"]["enabled"]: + assert master_config["checkpointing"]["save_period"] > 0 + assert ( + master_config["checkpointing"]["save_period"] + % master_config["sft"]["val_period"] + == 0 + ), ( + f"Checkpointing save period {master_config['checkpointing']['save_period']} " + f"must be a multiple of validation period {master_config['sft']['val_period']}" + f", or we won't know what metric to save!" + ) - ## TODO: unify this with grpo + # ========================== + # Data + # ========================== + print("\n▶ Setting up data...") data_cls = data_config["dataset_name"] if data_cls == "open_assistant": data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant") @@ -110,27 +165,58 @@ def setup( data = hf_datasets.SquadDataset() else: raise ValueError(f"Unknown dataset class: {data_cls}") + print( + f" ✓ Training and validation dataset loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively." + ) - base_dataset = data.formatted_ds["train"] + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] sft_task_spec = data.task_spec tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) - dataset = AllTaskProcessedDataset( - base_dataset, + train_dataset = AllTaskProcessedDataset( + train_dataset, tokenizer, sft_task_spec, sft_preprocessor, max_seq_length=data_config["max_input_seq_length"], ) - dataloader = DataLoader( - dataset, + train_dataloader = StatefulDataLoader( + train_dataset, batch_size=policy_config["train_global_batch_size"], - shuffle=False, + shuffle=True, collate_fn=rl_collate_fn, ## TODO: change this for sft! or make it more general ) + if last_checkpoint_path is not None: + dataloader_state_dict = torch.load( + os.path.join(last_checkpoint_path, "train_dataloader.pt") + ) + train_dataloader.load_state_dict(dataloader_state_dict) + + val_dataset = AllTaskProcessedDataset( + val_dataset, + tokenizer, + sft_task_spec, + sft_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + ## TODO: support different batch sizes for train and val + val_dataloader = StatefulDataLoader( + val_dataset, + batch_size=policy_config["train_global_batch_size"], + shuffle=False, + collate_fn=rl_collate_fn, + drop_last=True, + ) + + # ========================== + # Cluster + # ========================== + print("\n▶ Setting up compute cluster...") cluster = RayVirtualCluster( name="sft_cluster", bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] @@ -139,72 +225,275 @@ def setup( num_gpus_per_node=cluster_config["gpus_per_node"], max_colocated_worker_groups=1, ) - - policy = HfPolicy(cluster=cluster, config=policy_config) + print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + + # ========================== + # Training + # ========================== + print("\n▶ Setting up model...") + policy = HfPolicy( + cluster=cluster, + config=policy_config, + weights_path=Path(last_checkpoint_path) / "policy.pt" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + if last_checkpoint_path + else None, + init_optimizer=True, + ) loss_fn = NLLLoss() + print(f" ✓ Model initialized") logger = Logger(logger_config) + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n") + return ( policy, cluster, - dataloader, + train_dataloader, + val_dataloader, tokenizer, loss_fn, master_config, logger, sft_task_spec, + checkpointer, + sft_save_state, ) -def sft_train( - policy, dataloader, tokenizer, loss_fn, master_config, logger, sft_task_spec +# ======================================================= +# Training & Validation +# ======================================================= +def validate( + policy: PolicyInterface, + val_dataloader: StatefulDataLoader, + tokenizer, + loss_fn, + step: int, + master_config: MasterConfig, + sft_task_spec: TaskDataSpec, + val_batches: int, ): - # Run basic sft training + """Run validation on the validation dataset.""" + if val_dataloader is None: + print(" ⚠️ No validation dataloader provided, skipping validation") + return + timer = Timer() - policy.prepare_for_training() + with timer.time("total_validation_time"): + print(f"▶ Starting validation at step {step}...") - for step, batch in enumerate(dataloader): - timer.start("sft_train_step") + # Show a progress indicator for validation + # val_total = len(val_dataloader) - timer.start("data_processing") - ## add loss mask based on role to every message - add_loss_mask_to_message_log( - batch["message_log"], - roles_to_train_on=["assistant"], - ) + val_metrics = {"val_loss": 0.0} - cat_and_padded, input_lengths = batched_message_log_to_flat_message( - batch["message_log"], - pad_value_dict={"token_ids": tokenizer.eos_token_id}, - ) + for batch_idx, val_batch in enumerate(val_dataloader): + ## add loss mask based on role to every message + add_loss_mask_to_message_log( + val_batch["message_log"], + roles_to_train_on=["assistant"], + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + val_batch["message_log"], + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + + val_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": val_batch["loss_multiplier"], + } + ) + + ## just run model fwd + val_results = policy.train(val_data, loss_fn, eval_mode=True) + val_metrics["val_loss"] += float(val_results["loss"]) + + if val_batches > 0 and batch_idx >= val_batches: + break + + val_metrics["val_loss"] /= val_batches + + # Calculate validation metrics + policy.prepare_for_training() + + # Get timing metrics + timing_metrics = timer.get_timing_metrics(reduction_op="sum") + validation_time = timing_metrics.get("total_validation_time", 0) + + # Print summary of validation results + print("\n📊 Validation Results:") + print(f" • Validation loss: {val_metrics['val_loss']:.4f}") + + # Print timing information + print("\n ⏱️ Validation Timing:") + validation_time = timing_metrics.get("total_validation_time", 0) + print(f" • Total validation time: {validation_time:.2f}s") + + # Make sure to reset the timer after validation + timer.reset() + + return val_metrics, timing_metrics - train_data: BatchedDataDict = BatchedDataDict( - { - "input_ids": cat_and_padded["token_ids"], - "input_lengths": input_lengths, - "token_mask": cat_and_padded["token_loss_mask"], - "sample_mask": batch["loss_multiplier"], - } + +def sft_train( + policy, + train_dataloader, + val_dataloader, + tokenizer, + loss_fn, + master_config, + logger, + sft_task_spec, + checkpointer, + sft_save_state, +): + # Run basic sft training + timer = Timer() + + if sft_save_state is None: + sft_save_state = _default_sft_save_state() + step = 0 + else: + step = ( + sft_save_state["step"] + 1 + ) # N+1 because the checkpoint is _after_ SFT iteration N + + # Validation configuration + val_period = master_config["sft"]["val_period"] + val_at_start = master_config["sft"]["val_at_start"] + + # Run validation at the start if configured + if val_at_start and step == 0: + print("\n🔍 Running initial validation...") + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=0, + master_config=master_config, + sft_task_spec=sft_task_spec, + val_batches=master_config["sft"]["val_batches"], ) - timer.stop("data_processing") - ## train_data.to("cpu") - train_results = policy.train(train_data, loss_fn) - timer.stop("sft_train_step") + logger.log_metrics(val_metrics, step, prefix="validation") + logger.log_metrics(validation_timings, step, prefix="timing/validation") + + policy.prepare_for_training() + + for batch in train_dataloader: + print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") + + with timer.time("total_step_time"): + # Prepare batch and generate responses + print("▶ Preparing batch...") + with timer.time("data_processing"): + ## add loss mask based on role to every message + add_loss_mask_to_message_log( + batch["message_log"], + roles_to_train_on=["assistant"], + ) + + cat_and_padded, input_lengths = batched_message_log_to_flat_message( + batch["message_log"], + pad_value_dict={"token_ids": tokenizer.eos_token_id}, + ) + + train_data: BatchedDataDict = BatchedDataDict( + { + "input_ids": cat_and_padded["token_ids"], + "input_lengths": input_lengths, + "token_mask": cat_and_padded["token_loss_mask"], + "sample_mask": batch["loss_multiplier"], + } + ) + + ## train_data.to("cpu") + print("▶ Taking a training step...") + train_results = policy.train(train_data, loss_fn) + + # Run validation if it's a validation step + if val_period > 0 and (step + 1) % val_period == 0: + val_metrics, validation_timings = validate( + policy, + val_dataloader, + tokenizer, + loss_fn, + step=step + 1, + master_config=master_config, + sft_task_spec=sft_task_spec, + val_batches=master_config["sft"]["val_batches"], + ) + logger.log_metrics( + validation_timings, step + 1, prefix="timing/validation" + ) + logger.log_metrics(val_metrics, step + 1, prefix="validation") + + ## Checkpointing + sft_save_state["consumed_samples"] += master_config["policy"][ + "train_global_batch_size" + ] + if ( + master_config["checkpointing"]["enabled"] + and (step + 1) % master_config["checkpointing"]["save_period"] == 0 + ): # +1 because step is 0-indexed + sft_save_state["step"] = step + sft_save_state["val_loss"] = val_metrics["val_loss"] + with timer.time("checkpointing"): + print(f"Saving checkpoint for step {step + 1}...") + checkpoint_path = checkpointer.init_tmp_checkpoint( + step + 1, sft_save_state, master_config + ) + policy.save_checkpoint( + os.path.join(checkpoint_path, "policy.pt"), + os.path.join(checkpoint_path, "policy_optimizer.pt"), + offload_to_cpu=False, + ) + torch.save( + train_dataloader.state_dict(), + os.path.join(checkpoint_path, "train_dataloader.pt"), + ) + checkpointer.finalize_checkpoint(checkpoint_path) + losses = train_results["loss"] timing_metrics = timer.get_timing_metrics(reduction_op="sum") - print(f"Step {step} completed. Loss: {losses[-1].item()}") + metrics = { + "loss": losses.numpy(), + } + + print("\n📊 Training Results:") + print(f" • Loss: {float(metrics['loss']):.4f}") + print("\n⏱️ Timing:") + # Display total time first, separately + total_time = timing_metrics.get("total_step_time", 0) + print(f" • Total step time: {total_time:.2f}s") + + # Display all other timing metrics + ## TODO: remove? + for k, v in sorted( + timing_metrics.items(), key=lambda item: item[1], reverse=True + ): + if k != "total_step_time": + percent = (v / total_time * 100) if total_time > 0 else 0 + print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + + logger.log_metrics(metrics, step + 1, prefix="train") + logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") - logger.log_metrics( - {"loss": losses[-1].item()}, - step, - prefix="train", - ) - logger.log_metrics(timing_metrics, step, prefix="timing/train") timer.reset() + step += 1 if step >= master_config["sft"]["num_steps"] - 1: break diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 627c221cee..ee0c2f9eb3 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -238,7 +238,12 @@ def get_gpu_info(self): }, } - def train(self, data: BatchedDataDict, loss_fn: LossFunction) -> Dict[str, Any]: + def train( + self, + data: BatchedDataDict, + loss_fn: LossFunction, + eval_mode: bool = False, + ) -> Dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" mbs = self.cfg["train_micro_batch_size"] gbs = self.cfg["train_global_batch_size"] @@ -284,16 +289,18 @@ def train(self, data: BatchedDataDict, loss_fn: LossFunction) -> Dict[str, Any]: loss, loss_metrics = loss_fn(logits, mb) # Backward pass - loss.backward() + if not eval_mode: + loss.backward() mb_losses.append(loss.item()) all_mb_metrics.append(loss_metrics) # Clip gradients - torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + if not eval_mode: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - # Update parameters - self.optimizer.step() - self.scheduler.step() + # Update parameters + self.optimizer.step() + self.scheduler.step() losses.append(torch.tensor(mb_losses).mean().item()) # Compute global loss across all ranks @@ -744,9 +751,16 @@ def move_to_cpu(self, model): return model - def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): + def save_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + offload_to_cpu: bool = True, + ): # Config to save full state dict on rank 0, offloaded to CPU - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + state_dict_config = FullStateDictConfig( + offload_to_cpu=offload_to_cpu, rank0_only=True + ) with FullyShardedDataParallel.state_dict_type( self.model, @@ -864,7 +878,9 @@ def get_reference_policy_logprobs( ) return logprobs - def train(self, data: BatchedDataDict, loss_fn: LossFunction): + def train( + self, data: BatchedDataDict, loss_fn: LossFunction, eval_mode: bool = False + ): """Train the policy on a batch of data with a given loss function.""" # Shard and replicate the batch shards = self.dp_size @@ -874,7 +890,9 @@ def train(self, data: BatchedDataDict, loss_fn: LossFunction): # Train each shard in parallel futures = self.worker_group.run_all_workers_multiple_data( - "train", sharded_data, common_kwargs={"loss_fn": loss_fn} + "train", + sharded_data, + common_kwargs={"loss_fn": loss_fn, "eval_mode": eval_mode}, ) results = self.worker_group.get_all_worker_results(futures) @@ -987,12 +1005,18 @@ def offload_after_refit(self): ) ray.get(futures) - def save_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): + def save_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + offload_to_cpu: bool = True, + ): """Save a checkpoint of the model.""" futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", weights_path, optimizer_path, + offload_to_cpu=offload_to_cpu, respect_tied_workers=True, ) ray.get(futures) From b8c7b20c6f35d28138059bff0f0d6572cb36ee66 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 20:43:21 -0700 Subject: [PATCH 02/11] cleanup, support different batch sizes for train and val Signed-off-by: ashors1 --- examples/configs/sft.yaml | 2 + examples/run_sft.py | 89 ++++++++++++++++++++-- nemo_reinforcer/algorithms/sft.py | 83 +++----------------- nemo_reinforcer/models/policy/hf_policy.py | 8 +- 4 files changed, 99 insertions(+), 83 deletions(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 44a4b15d9c..a4e2f5918e 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -3,6 +3,8 @@ sft: num_steps: 20 val_period: 1 val_batches: 8 + val_global_batch_size: 2 + val_micro_batch_size: 2 val_at_start: true checkpointing: diff --git a/examples/run_sft.py b/examples/run_sft.py index a299614abd..14ef42d399 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -20,9 +20,9 @@ from nemo_reinforcer.algorithms.sft import MasterConfig, sft_train, setup from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.utils.config import load_config from nemo_reinforcer.utils.logger import get_next_experiment_dir - def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Run SFT training with configuration") @@ -38,6 +38,77 @@ def parse_args(): return args, overrides +# ======================================================= +# Data Processing +# ======================================================= +def sft_preprocessor( + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary for SFT training.""" + message_log = get_formatted_message_log( + datum_dict["messages"], tokenizer, task_data_spec + ) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output = { + "message_log": message_log, + "length": length, + "extra_env_info": None, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + return output + +def setup_data(data_config: DataConfig, policy_config: PolicyConfig): + print("\n▶ Setting up data...") + data_cls = data_config["dataset_name"] + if data_cls == "open_assistant": + data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant") + elif data_cls == "squad": + data = hf_datasets.SquadDataset() + else: + raise ValueError(f"Unknown dataset class: {data_cls}") + print( + f" ✓ Training and validation datasets loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively." + ) + + train_dataset = data.formatted_ds["train"] + val_dataset = data.formatted_ds["validation"] + sft_task_spec = data.task_spec + + tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) + + train_dataset = AllTaskProcessedDataset( + train_dataset, + tokenizer, + sft_task_spec, + sft_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + val_dataset = AllTaskProcessedDataset( + val_dataset, + tokenizer, + sft_task_spec, + sft_preprocessor, + max_seq_length=data_config["max_input_seq_length"], + ) + + return train_dataset, val_dataset, tokenizer def main(): """Main entry point.""" @@ -47,13 +118,12 @@ def main(): if not args.config: args.config = os.path.join(os.path.dirname(__file__), "configs", "sft.yaml") - config = OmegaConf.load(args.config) + config = load_config(args.config) print(f"Loaded configuration from: {args.config}") if overrides: - override_conf = OmegaConf.from_cli() - print(f"Overrides: {override_conf}") - config = OmegaConf.merge(config, override_conf) + print(f"Overrides: {overrides}") + config = OmegaConf.merge(config, overrides) config: MasterConfig = OmegaConf.to_container(config, resolve=True) print("Applied CLI overrides") @@ -66,18 +136,23 @@ def main(): print(f"📊 Using log directory: {config['logger']['log_dir']}") init_ray() + + + # setup data + dataset, val_dataset, tokenizer = setup_data( + config["data"], config["policy"] + ) ( policy, cluster, train_dataloader, val_dataloader, - tokenizer, loss_fn, - master_config, logger, sft_task_spec, checkpointer, sft_save_state, + master_config, ) = setup(config) sft_train( policy, diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index aec9b2f9db..f04155ed0c 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -67,43 +67,6 @@ class MasterConfig(TypedDict): cluster: ClusterConfig checkpointing: CheckpointingConfig - -# ======================================================= -# Data Processing -# ======================================================= -def sft_preprocessor( - datum_dict: Dict[str, Any], - task_data_spec: TaskDataSpec, - tokenizer, - max_seq_length: int, - idx: int, -) -> DatumSpec: - """Process a datum dictionary for SFT training.""" - message_log = get_formatted_message_log( - datum_dict["messages"], tokenizer, task_data_spec - ) - - length = sum(len(m["token_ids"]) for m in message_log) - - loss_multiplier = 1.0 - if length > max_seq_length: - # make smaller and mask out - for message in message_log: - message["token_ids"] = message["token_ids"][ - : min(4, max_seq_length // len(message_log)) - ] - loss_multiplier = 0.0 - - output = { - "message_log": message_log, - "length": length, - "extra_env_info": None, - "loss_multiplier": loss_multiplier, - "idx": idx, - } - return output - - # ======================================================= # Setup & Initialization # ======================================================= @@ -157,32 +120,6 @@ def setup( # ========================== # Data # ========================== - print("\n▶ Setting up data...") - data_cls = data_config["dataset_name"] - if data_cls == "open_assistant": - data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant") - elif data_cls == "squad": - data = hf_datasets.SquadDataset() - else: - raise ValueError(f"Unknown dataset class: {data_cls}") - print( - f" ✓ Training and validation dataset loaded with {len(data.formatted_ds['train'])} and {len(data.formatted_ds['validation'])} samples, respectively." - ) - - train_dataset = data.formatted_ds["train"] - val_dataset = data.formatted_ds["validation"] - sft_task_spec = data.task_spec - - tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) - - train_dataset = AllTaskProcessedDataset( - train_dataset, - tokenizer, - sft_task_spec, - sft_preprocessor, - max_seq_length=data_config["max_input_seq_length"], - ) - train_dataloader = StatefulDataLoader( train_dataset, batch_size=policy_config["train_global_batch_size"], @@ -196,18 +133,11 @@ def setup( ) train_dataloader.load_state_dict(dataloader_state_dict) - val_dataset = AllTaskProcessedDataset( - val_dataset, - tokenizer, - sft_task_spec, - sft_preprocessor, - max_seq_length=data_config["max_input_seq_length"], - ) ## TODO: support different batch sizes for train and val val_dataloader = StatefulDataLoader( val_dataset, - batch_size=policy_config["train_global_batch_size"], + batch_size=sft_config["val_global_batch_size"], shuffle=False, collate_fn=rl_collate_fn, drop_last=True, @@ -256,13 +186,12 @@ def setup( cluster, train_dataloader, val_dataloader, - tokenizer, loss_fn, - master_config, logger, sft_task_spec, checkpointer, sft_save_state, + master_config, ) @@ -316,7 +245,13 @@ def validate( ) ## just run model fwd - val_results = policy.train(val_data, loss_fn, eval_mode=True) + val_results = policy.train( + val_data, + loss_fn, + eval_mode=True, + gbs=sft_config["val_global_batch_size"], + mbs=sft_config["val_micro_batch_size"], + ) val_metrics["val_loss"] += float(val_results["loss"]) if val_batches > 0 and batch_idx >= val_batches: diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index ee0c2f9eb3..31e8e10b25 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -243,10 +243,14 @@ def train( data: BatchedDataDict, loss_fn: LossFunction, eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, ) -> Dict[str, Any]: """Train the policy on a batch of data with a given loss function.""" - mbs = self.cfg["train_micro_batch_size"] - gbs = self.cfg["train_global_batch_size"] + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] local_gbs = gbs // torch.distributed.get_world_size() dataset_size = data.get("input_ids").shape[0] From f8ba6fec04b3d4984d9f472db3bd63275de3f265 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 20:47:22 -0700 Subject: [PATCH 03/11] cleanup Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index f04155ed0c..5ce775f8b5 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -55,9 +55,10 @@ def _default_sft_save_state() -> SFTSaveState: class SFTConfig(TypedDict): num_steps: int val_period: int + val_batches: int + val_global_batch_size: int + val_micro_batch_size: int val_at_start: bool - checkpoint_dir: str - class MasterConfig(TypedDict): policy: PolicyConfig @@ -77,7 +78,6 @@ def setup( RayVirtualCluster, StatefulDataLoader, Optional[StatefulDataLoader], - AutoTokenizer, NLLLoss, MasterConfig, Logger, @@ -124,7 +124,7 @@ def setup( train_dataset, batch_size=policy_config["train_global_batch_size"], shuffle=True, - collate_fn=rl_collate_fn, ## TODO: change this for sft! or make it more general + collate_fn=rl_collate_fn, ) if last_checkpoint_path is not None: @@ -134,7 +134,6 @@ def setup( train_dataloader.load_state_dict(dataloader_state_dict) - ## TODO: support different batch sizes for train and val val_dataloader = StatefulDataLoader( val_dataset, batch_size=sft_config["val_global_batch_size"], @@ -415,8 +414,7 @@ def sft_train( total_time = timing_metrics.get("total_step_time", 0) print(f" • Total step time: {total_time:.2f}s") - # Display all other timing metrics - ## TODO: remove? + # Display all other timing metrics (if any) for k, v in sorted( timing_metrics.items(), key=lambda item: item[1], reverse=True ): From f91cbbb7d09573a11b6ad421f02b037c66e45e27 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 21:41:15 -0700 Subject: [PATCH 04/11] bug fixes Signed-off-by: ashors1 --- examples/run_sft.py | 19 +++++++++++---- nemo_reinforcer/algorithms/sft.py | 27 ++++++++++++++-------- nemo_reinforcer/models/policy/hf_policy.py | 16 ++++++++++--- 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/examples/run_sft.py b/examples/run_sft.py index 14ef42d399..de0ef8c1ce 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -15,6 +15,7 @@ import argparse import os import pprint +from typing import Dict, Any from omegaconf import OmegaConf @@ -22,6 +23,13 @@ from nemo_reinforcer.distributed.virtual_cluster import init_ray from nemo_reinforcer.utils.config import load_config from nemo_reinforcer.utils.logger import get_next_experiment_dir +from nemo_reinforcer.data import DataConfig, hf_datasets +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset +from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec +from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log +from transformers import AutoTokenizer +from nemo_reinforcer.models.policy import PolicyConfig + def parse_args(): """Parse command line arguments.""" @@ -38,6 +46,7 @@ def parse_args(): return args, overrides + # ======================================================= # Data Processing # ======================================================= @@ -73,6 +82,7 @@ def sft_preprocessor( } return output + def setup_data(data_config: DataConfig, policy_config: PolicyConfig): print("\n▶ Setting up data...") data_cls = data_config["dataset_name"] @@ -108,7 +118,8 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig): max_seq_length=data_config["max_input_seq_length"], ) - return train_dataset, val_dataset, tokenizer + return train_dataset, val_dataset, tokenizer, sft_task_spec + def main(): """Main entry point.""" @@ -137,9 +148,8 @@ def main(): init_ray() - # setup data - dataset, val_dataset, tokenizer = setup_data( + dataset, val_dataset, tokenizer, sft_task_spec = setup_data( config["data"], config["policy"] ) ( @@ -149,11 +159,10 @@ def main(): val_dataloader, loss_fn, logger, - sft_task_spec, checkpointer, sft_save_state, master_config, - ) = setup(config) + ) = setup(config, dataset, val_dataset) sft_train( policy, train_dataloader, diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 5ce775f8b5..c2d44e6053 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -60,6 +60,7 @@ class SFTConfig(TypedDict): val_micro_batch_size: int val_at_start: bool + class MasterConfig(TypedDict): policy: PolicyConfig data: DataConfig @@ -68,16 +69,19 @@ class MasterConfig(TypedDict): cluster: ClusterConfig checkpointing: CheckpointingConfig + # ======================================================= # Setup & Initialization # ======================================================= def setup( master_config: MasterConfig, + train_dataset: AllTaskProcessedDataset, + val_dataset: AllTaskProcessedDataset, ) -> Tuple[ HfPolicy, RayVirtualCluster, StatefulDataLoader, - Optional[StatefulDataLoader], + StatefulDataLoader, NLLLoss, MasterConfig, Logger, @@ -133,7 +137,6 @@ def setup( ) train_dataloader.load_state_dict(dataloader_state_dict) - val_dataloader = StatefulDataLoader( val_dataset, batch_size=sft_config["val_global_batch_size"], @@ -187,7 +190,6 @@ def setup( val_dataloader, loss_fn, logger, - sft_task_spec, checkpointer, sft_save_state, master_config, @@ -206,6 +208,8 @@ def validate( master_config: MasterConfig, sft_task_spec: TaskDataSpec, val_batches: int, + val_batch_size: int, + val_mbs: int, ): """Run validation on the validation dataset.""" if val_dataloader is None: @@ -248,8 +252,8 @@ def validate( val_data, loss_fn, eval_mode=True, - gbs=sft_config["val_global_batch_size"], - mbs=sft_config["val_micro_batch_size"], + gbs=val_batch_size, + mbs=val_mbs, ) val_metrics["val_loss"] += float(val_results["loss"]) @@ -303,9 +307,10 @@ def sft_train( sft_save_state["step"] + 1 ) # N+1 because the checkpoint is _after_ SFT iteration N + sft_config = master_config["sft"] # Validation configuration - val_period = master_config["sft"]["val_period"] - val_at_start = master_config["sft"]["val_at_start"] + val_period = sft_config["val_period"] + val_at_start = sft_config["val_at_start"] # Run validation at the start if configured if val_at_start and step == 0: @@ -318,7 +323,9 @@ def sft_train( step=0, master_config=master_config, sft_task_spec=sft_task_spec, - val_batches=master_config["sft"]["val_batches"], + val_batches=sft_config["val_batches"], + val_batch_size=sft_config["val_global_batch_size"], + val_mbs=sft_config["val_micro_batch_size"], ) logger.log_metrics(val_metrics, step, prefix="validation") @@ -367,7 +374,9 @@ def sft_train( step=step + 1, master_config=master_config, sft_task_spec=sft_task_spec, - val_batches=master_config["sft"]["val_batches"], + val_batches=sft_config["val_batches"], + val_batch_size=sft_config["val_global_batch_size"], + val_mbs=sft_config["val_micro_batch_size"], ) logger.log_metrics( validation_timings, step + 1, prefix="timing/validation" diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 31e8e10b25..0224991059 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -883,7 +883,12 @@ def get_reference_policy_logprobs( return logprobs def train( - self, data: BatchedDataDict, loss_fn: LossFunction, eval_mode: bool = False + self, + data: BatchedDataDict, + loss_fn: LossFunction, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, ): """Train the policy on a batch of data with a given loss function.""" # Shard and replicate the batch @@ -896,7 +901,12 @@ def train( futures = self.worker_group.run_all_workers_multiple_data( "train", sharded_data, - common_kwargs={"loss_fn": loss_fn, "eval_mode": eval_mode}, + common_kwargs={ + "loss_fn": loss_fn, + "eval_mode": eval_mode, + "gbs": gbs, + "mbs": mbs, + }, ) results = self.worker_group.get_all_worker_results(futures) From 93ff8ae630d3b6058a36fe3a3f16c62f9c3d3de6 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 21:52:35 -0700 Subject: [PATCH 05/11] set more reasonable defaults for sft Signed-off-by: ashors1 --- examples/configs/sft.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index a4e2f5918e..45c9ff2fae 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,10 +1,10 @@ # SFT Algorithm Configuration sft: num_steps: 20 - val_period: 1 + val_period: 10 val_batches: 8 - val_global_batch_size: 2 - val_micro_batch_size: 2 + val_global_batch_size: 64 + val_micro_batch_size: 8 val_at_start: true checkpointing: @@ -13,12 +13,12 @@ checkpointing: metric_name: "val_loss" higher_is_better: false keep_top_k: 3 - save_period: 1 + save_period: 10 policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 8 - train_micro_batch_size: 2 + train_global_batch_size: 64 + train_micro_batch_size: 8 learning_rate: 5.0e-6 max_total_sequence_length: 1024 @@ -48,5 +48,5 @@ logger: log_dir: "tb_logs" cluster: - gpus_per_node: 1 + gpus_per_node: 2 num_nodes: 1 From d5b769e332904702e8d16a44e547ea01464477a1 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 22:00:53 -0700 Subject: [PATCH 06/11] remove old todo Signed-off-by: ashors1 --- examples/configs/sft.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 45c9ff2fae..33abd6490c 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -9,7 +9,7 @@ sft: checkpointing: enabled: true - checkpoint_dir: "results/sft" ## TODO: get checkpointing to work with relative paths + checkpoint_dir: "results/sft" metric_name: "val_loss" higher_is_better: false keep_top_k: 3 From fbd2b3faba89713a0fd4e5c595335742f586d2ea Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 22:04:12 -0700 Subject: [PATCH 07/11] remove unused imports Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index c2d44e6053..ce6d0c1174 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -13,21 +13,19 @@ # limitations under the License. import os from pathlib import Path -from typing import Any, Dict, Optional, Tuple, TypedDict +from typing import Optional, Tuple, TypedDict import torch from torchdata.stateful_dataloader import StatefulDataLoader -from transformers import AutoTokenizer from nemo_reinforcer.algorithms.loss_functions import ( NLLLoss, ) -from nemo_reinforcer.data import DataConfig, hf_datasets +from nemo_reinforcer.data import DataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn -from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec +from nemo_reinforcer.data.interfaces import TaskDataSpec from nemo_reinforcer.data.llm_message_utils import ( add_loss_mask_to_message_log, batched_message_log_to_flat_message, - get_formatted_message_log, ) from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster From 4446b2b56a7e143097e6bc7984ff8fe7a05cdcad Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 22:08:53 -0700 Subject: [PATCH 08/11] fix copyright Signed-off-by: ashors1 --- nemo_reinforcer/models/policy/hf_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index c9ce152d3c..fdf1496bce 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -1,4 +1,4 @@ -# (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 3dd006ae56bc5bd7d3ba0fe4b653766a6c7915bd Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 20 Mar 2025 23:05:12 -0700 Subject: [PATCH 09/11] fix issue sft ending one step early Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index ce6d0c1174..85604a3851 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -435,5 +435,5 @@ def sft_train( timer.reset() step += 1 - if step >= master_config["sft"]["num_steps"] - 1: + if step >= master_config["sft"]["num_steps"]: break From 3228b5ecd0c3c91fcac7d2d76cb0b91cf29548c4 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 21 Mar 2025 00:15:31 -0700 Subject: [PATCH 10/11] address comments Signed-off-by: ashors1 --- nemo_reinforcer/algorithms/sft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 85604a3851..402b5bad92 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -399,6 +399,8 @@ def sft_train( policy.save_checkpoint( os.path.join(checkpoint_path, "policy.pt"), os.path.join(checkpoint_path, "policy_optimizer.pt"), + ## NOTE: below is a workaround to avoid a bug with checkpointing + ## this should be removed once the bug is fixed offload_to_cpu=False, ) torch.save( From cf943e2c08cb746e937afb31d9d3977a10b75b27 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Fri, 21 Mar 2025 00:16:28 -0700 Subject: [PATCH 11/11] single gpu default Signed-off-by: ashors1 --- examples/configs/sft.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 33abd6490c..b938a9d321 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -3,8 +3,8 @@ sft: num_steps: 20 val_period: 10 val_batches: 8 - val_global_batch_size: 64 - val_micro_batch_size: 8 + val_global_batch_size: 32 + val_micro_batch_size: 2 val_at_start: true checkpointing: @@ -17,8 +17,8 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 64 - train_micro_batch_size: 8 + train_global_batch_size: 32 + train_micro_batch_size: 2 learning_rate: 5.0e-6 max_total_sequence_length: 1024 @@ -48,5 +48,5 @@ logger: log_dir: "tb_logs" cluster: - gpus_per_node: 2 + gpus_per_node: 1 num_nodes: 1