From 8727d89b01dcbd2a0ec65ada2a5366b544b077ee Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Fri, 21 Mar 2025 11:09:48 -0700 Subject: [PATCH 1/4] Pull changes from gitlab Signed-off-by: Yi-Fu Wu --- examples/configs/sft_nemo_verify.yaml | 40 ++++++++++++++++++++ nemo_reinforcer/algorithms/loss_functions.py | 8 +++- nemo_reinforcer/algorithms/sft.py | 21 ++++++++-- nemo_reinforcer/algorithms/utils.py | 9 +++++ nemo_reinforcer/models/policy/hf_policy.py | 15 ++++++-- 5 files changed, 84 insertions(+), 9 deletions(-) create mode 100644 examples/configs/sft_nemo_verify.yaml diff --git a/examples/configs/sft_nemo_verify.yaml b/examples/configs/sft_nemo_verify.yaml new file mode 100644 index 0000000000..d5c4ebc051 --- /dev/null +++ b/examples/configs/sft_nemo_verify.yaml @@ -0,0 +1,40 @@ +# SFT Algorithm Configuration +sft: + num_steps: 1168251 + seed: 42 + +policy: + model_name: "meta-llama/Meta-Llama-3-8B" + train_global_batch_size: 128 + train_micro_batch_size: 1 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + + scheduler: + name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.0196078 + end_factor: 1.0 + total_iters: 50 + +data: + max_input_seq_length: 2048 + dataset_name: "squad" + +logger: + wandb_enabled: true + tensorboard_enabled: false + wandb: + project: "sft-dev" + name: "sft-dev-logger" + tensorboard: + log_dir: "tb_logs" + +cluster: + gpus_per_node: 8 + num_nodes: 1 \ No newline at end of file diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 0518e3f0c0..cc7e2f77a5 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -160,6 +160,10 @@ def __call__( # Only compute loss on generated tokens (not input tokens) # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - loss = -torch.sum(token_logprobs * mask) + num_unmasked_tokens = torch.sum(mask) + if num_unmasked_tokens == 0: + # prevent division by zero + num_unmasked_tokens = 1 + loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens - return loss, {"loss": loss.item()} + return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()} diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index a8d065d37d..e3aa2a9792 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -15,11 +15,13 @@ from pathlib import Path from typing import Optional, Tuple, TypedDict +import numpy as np import torch from torchdata.stateful_dataloader import StatefulDataLoader from nemo_reinforcer.algorithms.loss_functions import ( NLLLoss, ) +from nemo_reinforcer.algorithms.utils import set_seed from nemo_reinforcer.data import DataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn from nemo_reinforcer.data.interfaces import TaskDataSpec @@ -57,7 +59,7 @@ class SFTConfig(TypedDict): val_global_batch_size: int val_micro_batch_size: int val_at_start: bool - + seed: int class MasterConfig(TypedDict): policy: PolicyConfig @@ -91,6 +93,8 @@ def setup( Returns: Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger """ + set_seed(master_config["sft"]["seed"]) + # Extract individual configs for easier access policy_config = master_config["policy"] data_config = master_config["data"] @@ -176,6 +180,7 @@ def setup( print(f" āœ“ Model initialized") logger = Logger(logger_config) + logger.log_hyperparams(master_config) print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") @@ -331,6 +336,7 @@ def sft_train( policy.prepare_for_training() + consumed_samples = 0 for batch in train_dataloader: print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") @@ -361,6 +367,7 @@ def sft_train( ## train_data.to("cpu") print("ā–¶ Taking a training step...") train_results = policy.train(train_data, loss_fn) + consumed_samples += train_data["input_ids"].shape[0] # Run validation if it's a validation step if val_period > 0 and (step + 1) % val_period == 0: @@ -410,11 +417,12 @@ def sft_train( checkpointer.finalize_checkpoint(checkpoint_path) losses = train_results["loss"] - timing_metrics = timer.get_timing_metrics(reduction_op="sum") - metrics = { - "loss": losses.numpy(), + "loss": train_results["loss"].numpy(), } + metrics.update(train_results["all_mb_metrics"]) + metrics = {k: np.mean(v).item() for k, v in metrics.items()} + timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\nšŸ“Š Training Results:") print(f" • Loss: {float(metrics['loss']):.4f}") @@ -434,6 +442,11 @@ def sft_train( logger.log_metrics(metrics, step + 1, prefix="train") logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") + # To match NeMo2 + logger.log_metrics({"reduced_train_loss": metrics['loss']}, step) + logger.log_metrics({"lr": metrics['lr']}, step) + logger.log_metrics({"consumed_samples": consumed_samples}, step) + timer.reset() step += 1 diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index a153bfb53e..8315294d0d 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random import warnings from functools import wraps +import numpy as np import torch from torch.masked import as_masked_tensor @@ -120,3 +122,10 @@ def masked_mean(values, mask, dim=None): if dim is None: return values[mask.bool()].mean() return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) + +def set_seed(seed: int): + """Sets the seed for python, numpy, and pytorch.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) \ No newline at end of file diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 6f44574c71..a44d677e05 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -115,9 +115,17 @@ def do_fsdp(model): self._held_reference_model_params = None # register_fsdp_forward_method(self.model, "generate") if init_optimizer: - self.optimizer = torch.optim.AdamW( - self.model.parameters(), lr=self.cfg["learning_rate"] - ) + if "optimizer" in self.cfg: + optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), + **self.cfg["optimizer"]["kwargs"] + ) + else: + self.optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=self.cfg["learning_rate"], + ) else: self.optimizer = None @@ -278,6 +286,7 @@ def train( logits = outputs.logits loss, loss_metrics = loss_fn(logits, mb) + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] # Backward pass if not eval_mode: From cf38855cfe7e930db224c728def3f41f104e077c Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Fri, 21 Mar 2025 11:14:45 -0700 Subject: [PATCH 2/4] Whitespace Signed-off-by: Yi-Fu Wu --- examples/configs/sft_nemo_verify.yaml | 2 +- nemo_reinforcer/algorithms/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/sft_nemo_verify.yaml b/examples/configs/sft_nemo_verify.yaml index d5c4ebc051..3b41d7f9e2 100644 --- a/examples/configs/sft_nemo_verify.yaml +++ b/examples/configs/sft_nemo_verify.yaml @@ -37,4 +37,4 @@ logger: cluster: gpus_per_node: 8 - num_nodes: 1 \ No newline at end of file + num_nodes: 1 diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index 8315294d0d..a568dbcda6 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -128,4 +128,4 @@ def set_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) \ No newline at end of file + torch.cuda.manual_seed_all(seed) From 8841a6a6e23a81457b9b8246df21f9643b143bcb Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Fri, 21 Mar 2025 15:34:07 -0700 Subject: [PATCH 3/4] Address PR comments To compare with NeMo 2 default llama3-8b recipe: ``` uv run examples/run_sft.py --config=examples/configs/sft.yaml \ sft.max_num_steps=1168251 \ sft.val_period=-1 \ sft.val_global_batch_size=128 \ sft.val_micro_batch_size=1 \ sft.val_at_start=false \ checkpointing.enabled=false \ policy.model_name=meta-llama/Meta-Llama-3-8B \ policy.train_global_batch_size=128 \ policy.train_micro_batch_size=1 \ policy.max_total_sequence_length=2048 \ policy.optimizer.kwargs='{"lr": 5e-6, "betas": [0.9, 0.98], "eps": 1e-5, "weight_decay":0.1}' \ policy.scheduler='{"name": "torch.optim.lr_scheduler.LinearLR", "kwargs": {"start_factor": 0.0196078, "end_factor": 1.0, "total_iters": 50}}' \ data.dataset_name=squad \ data.max_input_seq_length=2048 \ cluster.gpus_per_node=8 ``` Signed-off-by: Yi-Fu Wu --- examples/configs/grpo_math_1B.yaml | 9 ++++- examples/configs/grpo_math_8B.yaml | 9 ++++- examples/configs/sft.yaml | 23 ++++++----- examples/configs/sft_nemo_verify.yaml | 40 -------------------- nemo_reinforcer/algorithms/loss_functions.py | 2 +- nemo_reinforcer/algorithms/sft.py | 7 ---- nemo_reinforcer/data/hf_datasets/squad.py | 2 +- nemo_reinforcer/models/policy/hf_policy.py | 16 +++----- 8 files changed, 36 insertions(+), 72 deletions(-) delete mode 100644 examples/configs/sft_nemo_verify.yaml diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 684e24e7f5..5816d8cc4e 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -27,10 +27,17 @@ policy: train_global_batch_size: 32 train_micro_batch_size: 4 generation_batch_size: 32 - learning_rate: 5.0e-6 logprob_batch_size: 4 max_total_sequence_length: 512 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 2415ee8cd8..0d6db87165 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -6,10 +6,17 @@ policy: train_global_batch_size: 32 train_micro_batch_size: 1 generation_batch_size: 32 - learning_rate: 5.0e-6 logprob_batch_size: 2 max_total_sequence_length: 4096 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 38db492017..b99ed0ec7d 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -19,19 +19,22 @@ policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" train_global_batch_size: 32 train_micro_batch_size: 2 - learning_rate: 5.0e-6 max_total_sequence_length: 1024 + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 100 - - name: "torch.optim.lr_scheduler.CosineAnnealingLR" - kwargs: - T_max: 100 - - milestones: [50] + name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.1 + end_factor: 1.0 + total_iters: 100 data: max_input_seq_length: ${policy.max_total_sequence_length} diff --git a/examples/configs/sft_nemo_verify.yaml b/examples/configs/sft_nemo_verify.yaml deleted file mode 100644 index 3b41d7f9e2..0000000000 --- a/examples/configs/sft_nemo_verify.yaml +++ /dev/null @@ -1,40 +0,0 @@ -# SFT Algorithm Configuration -sft: - num_steps: 1168251 - seed: 42 - -policy: - model_name: "meta-llama/Meta-Llama-3-8B" - train_global_batch_size: 128 - train_micro_batch_size: 1 - optimizer: - name: "torch.optim.AdamW" - kwargs: - lr: 5.0e-6 - weight_decay: 0.1 - betas: [0.9, 0.98] - eps: 1e-5 - - scheduler: - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.0196078 - end_factor: 1.0 - total_iters: 50 - -data: - max_input_seq_length: 2048 - dataset_name: "squad" - -logger: - wandb_enabled: true - tensorboard_enabled: false - wandb: - project: "sft-dev" - name: "sft-dev-logger" - tensorboard: - log_dir: "tb_logs" - -cluster: - gpus_per_node: 8 - num_nodes: 1 diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index cc7e2f77a5..90230a06ab 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -163,7 +163,7 @@ def __call__( num_unmasked_tokens = torch.sum(mask) if num_unmasked_tokens == 0: # prevent division by zero - num_unmasked_tokens = 1 + num_unmasked_tokens = torch.tensor(1) loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()} diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index e3aa2a9792..b216c02724 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -336,7 +336,6 @@ def sft_train( policy.prepare_for_training() - consumed_samples = 0 for batch in train_dataloader: print(f"\n{'=' * 25} Step {step + 1}/{len(train_dataloader)} {'=' * 25}") @@ -367,7 +366,6 @@ def sft_train( ## train_data.to("cpu") print("ā–¶ Taking a training step...") train_results = policy.train(train_data, loss_fn) - consumed_samples += train_data["input_ids"].shape[0] # Run validation if it's a validation step if val_period > 0 and (step + 1) % val_period == 0: @@ -442,11 +440,6 @@ def sft_train( logger.log_metrics(metrics, step + 1, prefix="train") logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") - # To match NeMo2 - logger.log_metrics({"reduced_train_loss": metrics['loss']}, step) - logger.log_metrics({"lr": metrics['lr']}, step) - logger.log_metrics({"consumed_samples": consumed_samples}, step) - timer.reset() step += 1 diff --git a/nemo_reinforcer/data/hf_datasets/squad.py b/nemo_reinforcer/data/hf_datasets/squad.py index a1378761d1..4ba3e87949 100644 --- a/nemo_reinforcer/data/hf_datasets/squad.py +++ b/nemo_reinforcer/data/hf_datasets/squad.py @@ -41,7 +41,7 @@ def __init__(self): original_ds = load_dataset("rajpurkar/squad") self.formatted_ds = original_ds.map(format_squad) - custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}" + custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" super().__init__( dataset_name="squad", diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index a44d677e05..ff36f788e6 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -115,17 +115,11 @@ def do_fsdp(model): self._held_reference_model_params = None # register_fsdp_forward_method(self.model, "generate") if init_optimizer: - if "optimizer" in self.cfg: - optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) - self.optimizer = optimizer_cls( - self.model.parameters(), - **self.cfg["optimizer"]["kwargs"] - ) - else: - self.optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=self.cfg["learning_rate"], - ) + optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), + **self.cfg["optimizer"]["kwargs"] + ) else: self.optimizer = None From 4daf25e876edbe5c89894f9713b3913a3ad779ba Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Fri, 21 Mar 2025 16:12:48 -0700 Subject: [PATCH 4/4] Update sft.yaml and nll loss test Signed-off-by: Yi-Fu Wu --- examples/configs/sft.yaml | 29 ++++++++++---------- tests/unit/algorithms/test_loss_functions.py | 9 +++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index b99ed0ec7d..7828bcfb1b 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,11 +1,12 @@ # SFT Algorithm Configuration sft: - max_num_steps: 20 + max_num_steps: 1000 val_period: 10 val_batches: 8 - val_global_batch_size: 32 - val_micro_batch_size: 2 + val_global_batch_size: 128 + val_micro_batch_size: 1 val_at_start: true + seed: 42 checkpointing: enabled: true @@ -16,29 +17,29 @@ checkpointing: save_period: 10 policy: - model_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 32 - train_micro_batch_size: 2 - max_total_sequence_length: 1024 + model_name: "meta-llama/Meta-Llama-3-8B" + train_global_batch_size: 128 + train_micro_batch_size: 1 + max_total_sequence_length: 2048 optimizer: name: "torch.optim.AdamW" kwargs: lr: 5.0e-6 - weight_decay: 0.01 - betas: [0.9, 0.999] - eps: 1e-8 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 scheduler: name: "torch.optim.lr_scheduler.LinearLR" kwargs: - start_factor: 0.1 + start_factor: 0.0196078 end_factor: 1.0 - total_iters: 100 + total_iters: 50 data: max_input_seq_length: ${policy.max_total_sequence_length} - dataset_name: "open_assistant" + dataset_name: "squad" logger: log_dir: "logs" # Base directory for all logs @@ -51,5 +52,5 @@ logger: log_dir: "tb_logs" cluster: - gpus_per_node: 1 + gpus_per_node: 8 num_nodes: 1 diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 02e7a072b6..d509a296a8 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -43,6 +43,9 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0)) + # Check the metrics dictionary contains the expected values + assert metrics_dict["num_unmasked_tokens"] == 2 + assert metrics_dict["total_tokens"] == 3 ## now assume we predict the incorrect token with high probability next_token_logits = ( @@ -59,4 +62,8 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) ## loss per token is 999, and we have two unmasked tokens - torch.testing.assert_allclose(loss.cpu(), torch.tensor(1998.0)) + ## with the updated loss function, we now average the loss over unmasked tokens + torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0)) + # Check the metrics dictionary contains the expected values + assert metrics_dict["num_unmasked_tokens"] == 2 + assert metrics_dict["total_tokens"] == 3