diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 7d570adddb..8711345afc 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -27,11 +27,18 @@ policy: train_global_batch_size: 32 train_micro_batch_size: 4 generation_batch_size: 32 - learning_rate: 5.0e-6 logprob_batch_size: 4 max_total_sequence_length: 512 precision: "bfloat16" + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 2dab29560a..69802553c1 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -6,11 +6,18 @@ policy: train_global_batch_size: 32 train_micro_batch_size: 1 generation_batch_size: 32 - learning_rate: 5.0e-6 logprob_batch_size: 2 max_total_sequence_length: 4096 precision: "bfloat16" + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.01 + betas: [0.9, 0.999] + eps: 1e-8 + scheduler: - name: "torch.optim.lr_scheduler.LinearLR" kwargs: diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index ef049bea88..2436795abb 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -1,11 +1,12 @@ # SFT Algorithm Configuration sft: - max_num_steps: 20 + max_num_steps: 1000 val_period: 10 val_batches: 8 - val_global_batch_size: 32 - val_micro_batch_size: 2 + val_global_batch_size: 128 + val_micro_batch_size: 1 val_at_start: true + seed: 42 checkpointing: enabled: true @@ -16,27 +17,30 @@ checkpointing: save_period: 10 policy: - model_name: "meta-llama/Llama-3.2-1B-Instruct" - train_global_batch_size: 32 - train_micro_batch_size: 2 - learning_rate: 5.0e-6 - max_total_sequence_length: 1024 + model_name: "meta-llama/Meta-Llama-3-8B" + train_global_batch_size: 128 + train_micro_batch_size: 1 + 
max_total_sequence_length: 2048 precision: "float32" + optimizer: + name: "torch.optim.AdamW" + kwargs: + lr: 5.0e-6 + weight_decay: 0.1 + betas: [0.9, 0.98] + eps: 1e-5 + scheduler: - - name: "torch.optim.lr_scheduler.LinearLR" - kwargs: - start_factor: 0.1 - end_factor: 1.0 - total_iters: 100 - - name: "torch.optim.lr_scheduler.CosineAnnealingLR" - kwargs: - T_max: 100 - - milestones: [50] + name: "torch.optim.lr_scheduler.LinearLR" + kwargs: + start_factor: 0.0196078 + end_factor: 1.0 + total_iters: 50 data: max_input_seq_length: ${policy.max_total_sequence_length} - dataset_name: "open_assistant" + dataset_name: "squad" logger: log_dir: "logs" # Base directory for all logs @@ -49,5 +53,5 @@ logger: log_dir: "tb_logs" cluster: - gpus_per_node: 1 + gpus_per_node: 8 num_nodes: 1 diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 0518e3f0c0..90230a06ab 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -160,6 +160,10 @@ def __call__( # Only compute loss on generated tokens (not input tokens) # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - loss = -torch.sum(token_logprobs * mask) + num_unmasked_tokens = torch.sum(mask) + if num_unmasked_tokens == 0: + # prevent division by zero + num_unmasked_tokens = torch.tensor(1) + loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens - return loss, {"loss": loss.item()} + return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()} diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index a8d065d37d..b216c02724 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -15,11 +15,13 @@ from pathlib import Path from typing import Optional, Tuple, TypedDict +import numpy as np import torch from torchdata.stateful_dataloader import StatefulDataLoader 
from nemo_reinforcer.algorithms.loss_functions import ( NLLLoss, ) +from nemo_reinforcer.algorithms.utils import set_seed from nemo_reinforcer.data import DataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn from nemo_reinforcer.data.interfaces import TaskDataSpec @@ -57,7 +59,7 @@ class SFTConfig(TypedDict): val_global_batch_size: int val_micro_batch_size: int val_at_start: bool - + seed: int class MasterConfig(TypedDict): policy: PolicyConfig @@ -91,6 +93,8 @@ def setup( Returns: Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger """ + set_seed(master_config["sft"]["seed"]) + # Extract individual configs for easier access policy_config = master_config["policy"] data_config = master_config["data"] @@ -176,6 +180,7 @@ def setup( print(f" āœ“ Model initialized") logger = Logger(logger_config) + logger.log_hyperparams(master_config) print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") @@ -410,11 +415,12 @@ def sft_train( checkpointer.finalize_checkpoint(checkpoint_path) losses = train_results["loss"] - timing_metrics = timer.get_timing_metrics(reduction_op="sum") - metrics = { - "loss": losses.numpy(), + "loss": train_results["loss"].numpy(), } + metrics.update(train_results["all_mb_metrics"]) + metrics = {k: np.mean(v).item() for k, v in metrics.items()} + timing_metrics = timer.get_timing_metrics(reduction_op="sum") print("\nšŸ“Š Training Results:") print(f" • Loss: {float(metrics['loss']):.4f}") diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py index a153bfb53e..a568dbcda6 100644 --- a/nemo_reinforcer/algorithms/utils.py +++ b/nemo_reinforcer/algorithms/utils.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import random import warnings from functools import wraps +import numpy as np import torch from torch.masked import as_masked_tensor @@ -120,3 +122,10 @@ def masked_mean(values, mask, dim=None): if dim is None: return values[mask.bool()].mean() return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan) + +def set_seed(seed: int): + """Set the seed for Python's random module, NumPy, and PyTorch (all CUDA devices included).""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/nemo_reinforcer/data/hf_datasets/squad.py b/nemo_reinforcer/data/hf_datasets/squad.py index a1378761d1..4ba3e87949 100644 --- a/nemo_reinforcer/data/hf_datasets/squad.py +++ b/nemo_reinforcer/data/hf_datasets/squad.py @@ -41,7 +41,7 @@ def __init__(self): original_ds = load_dataset("rajpurkar/squad") self.formatted_ds = original_ds.map(format_squad) - custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}" + custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}" super().__init__( dataset_name="squad", diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index cf932e98d6..2eb5598f6b 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -121,8 +121,10 @@ def do_fsdp(model): self._held_reference_model_params = None # register_fsdp_forward_method(self.model, "generate") if init_optimizer: - self.optimizer = 
torch.optim.AdamW( - self.model.parameters(), lr=self.cfg["learning_rate"] + optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"]) + self.optimizer = optimizer_cls( + self.model.parameters(), + **self.cfg["optimizer"]["kwargs"] ) else: self.optimizer = None @@ -285,6 +287,7 @@ def train( logits = outputs.logits loss, loss_metrics = loss_fn(logits, mb) + loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"] # Backward pass if not eval_mode: diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index c129ee11fb..fe874ecc26 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -47,6 +47,9 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0)) + # Check the metrics dictionary contains the expected values + assert metrics_dict["num_unmasked_tokens"] == 2 + assert metrics_dict["total_tokens"] == 3 ## now assume we predict the incorrect token with high probability next_token_logits = ( @@ -63,4 +66,8 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) ## loss per token is 999, and we have two unmasked tokens - torch.testing.assert_allclose(loss.cpu(), torch.tensor(1998.0)) + ## with the updated loss function, we now average the loss over unmasked tokens + torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0)) + # Check the metrics dictionary contains the expected values + assert metrics_dict["num_unmasked_tokens"] == 2 + assert metrics_dict["total_tokens"] == 3