Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,18 @@ policy:
train_global_batch_size: 32
train_micro_batch_size: 4
generation_batch_size: 32
learning_rate: 5.0e-6
logprob_batch_size: 4
max_total_sequence_length: 512
precision: "bfloat16"

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-6
weight_decay: 0.01
betas: [0.9, 0.999]
eps: 1e-8

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
Expand Down
9 changes: 8 additions & 1 deletion examples/configs/grpo_math_8B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ policy:
train_global_batch_size: 32
train_micro_batch_size: 1
generation_batch_size: 32
learning_rate: 5.0e-6
logprob_batch_size: 2
max_total_sequence_length: 4096
precision: "bfloat16"

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-6
weight_decay: 0.01
betas: [0.9, 0.999]
eps: 1e-8

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
Expand Down
42 changes: 23 additions & 19 deletions examples/configs/sft.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# SFT Algorithm Configuration
sft:
max_num_steps: 20
max_num_steps: 1000
val_period: 10
val_batches: 8
val_global_batch_size: 32
val_micro_batch_size: 2
val_global_batch_size: 128
val_micro_batch_size: 1
val_at_start: true
seed: 42

checkpointing:
enabled: true
Expand All @@ -16,27 +17,30 @@ checkpointing:
save_period: 10

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
train_global_batch_size: 32
train_micro_batch_size: 2
learning_rate: 5.0e-6
max_total_sequence_length: 1024
model_name: "meta-llama/Meta-Llama-3-8B"
train_global_batch_size: 128
train_micro_batch_size: 1
max_total_sequence_length: 2048
precision: "float32"

optimizer:
name: "torch.optim.AdamW"
kwargs:
lr: 5.0e-6
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5

scheduler:
- name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
start_factor: 0.1
end_factor: 1.0
total_iters: 100
- name: "torch.optim.lr_scheduler.CosineAnnealingLR"
kwargs:
T_max: 100
- milestones: [50]
name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
start_factor: 0.0196078
end_factor: 1.0
total_iters: 50

data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "open_assistant"
dataset_name: "squad"

logger:
log_dir: "logs" # Base directory for all logs
Expand All @@ -49,5 +53,5 @@ logger:
log_dir: "tb_logs"

cluster:
gpus_per_node: 1
gpus_per_node: 8
num_nodes: 1
8 changes: 6 additions & 2 deletions nemo_reinforcer/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def __call__(

# Only compute loss on generated tokens (not input tokens)
# by applying the token_loss_mask (shifted by 1 since we're predicting next tokens)
loss = -torch.sum(token_logprobs * mask)
num_unmasked_tokens = torch.sum(mask)
Comment thread
SahilJain314 marked this conversation as resolved.
if num_unmasked_tokens == 0:
# prevent division by zero
num_unmasked_tokens = torch.tensor(1)
loss = -torch.sum(token_logprobs * mask) / num_unmasked_tokens

return loss, {"loss": loss.item()}
return loss, {"loss": loss.item(), "num_unmasked_tokens": num_unmasked_tokens.item(), "total_tokens": mask.numel()}
14 changes: 10 additions & 4 deletions nemo_reinforcer/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
from pathlib import Path
from typing import Optional, Tuple, TypedDict

import numpy as np
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from nemo_reinforcer.algorithms.loss_functions import (
NLLLoss,
)
from nemo_reinforcer.algorithms.utils import set_seed
from nemo_reinforcer.data import DataConfig
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_reinforcer.data.interfaces import TaskDataSpec
Expand Down Expand Up @@ -57,7 +59,7 @@ class SFTConfig(TypedDict):
val_global_batch_size: int
val_micro_batch_size: int
val_at_start: bool

seed: int

class MasterConfig(TypedDict):
policy: PolicyConfig
Expand Down Expand Up @@ -91,6 +93,8 @@ def setup(
Returns:
Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, master_config, logger
"""
set_seed(master_config["sft"]["seed"])

# Extract individual configs for easier access
policy_config = master_config["policy"]
data_config = master_config["data"]
Expand Down Expand Up @@ -176,6 +180,7 @@ def setup(
print(f" ✓ Model initialized")

logger = Logger(logger_config)
logger.log_hyperparams(master_config)

print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
Expand Down Expand Up @@ -410,11 +415,12 @@ def sft_train(
checkpointer.finalize_checkpoint(checkpoint_path)

losses = train_results["loss"]
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

metrics = {
"loss": losses.numpy(),
"loss": train_results["loss"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
metrics = {k: np.mean(v).item() for k, v in metrics.items()}
timing_metrics = timer.get_timing_metrics(reduction_op="sum")

print("\n📊 Training Results:")
print(f" • Loss: {float(metrics['loss']):.4f}")
Expand Down
9 changes: 9 additions & 0 deletions nemo_reinforcer/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import warnings
from functools import wraps

import numpy as np
import torch
from torch.masked import as_masked_tensor

Expand Down Expand Up @@ -120,3 +122,10 @@ def masked_mean(values, mask, dim=None):
if dim is None:
return values[mask.bool()].mean()
return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)

def set_seed(seed: int):
    """Seed every RNG the training code may draw from.

    Applies the same seed to python's ``random`` module, numpy's global
    generator, torch's CPU generator, and all visible CUDA devices so that
    runs are reproducible end to end.
    """
    # Apply the seed uniformly across all RNG backends in one pass.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)
2 changes: 1 addition & 1 deletion nemo_reinforcer/data/hf_datasets/squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self):
original_ds = load_dataset("rajpurkar/squad")
self.formatted_ds = original_ds.map(format_squad)

custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}"
custom_template = "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"

super().__init__(
dataset_name="squad",
Expand Down
7 changes: 5 additions & 2 deletions nemo_reinforcer/models/policy/hf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def do_fsdp(model):
self._held_reference_model_params = None
# register_fsdp_forward_method(self.model, "generate")
if init_optimizer:
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=self.cfg["learning_rate"]
optimizer_cls = import_class_from_path(self.cfg["optimizer"]["name"])
self.optimizer = optimizer_cls(
self.model.parameters(),
**self.cfg["optimizer"]["kwargs"]
)
else:
self.optimizer = None
Expand Down Expand Up @@ -285,6 +287,7 @@ def train(
logits = outputs.logits

loss, loss_metrics = loss_fn(logits, mb)
loss_metrics["lr"] = self.optimizer.param_groups[0]["lr"]

# Backward pass
if not eval_mode:
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/algorithms/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def test_nll_loss():
)
loss, metrics_dict = loss_fn(next_token_logits, data)
torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0))
# Check the metrics dictionary contains the expected values
assert metrics_dict["num_unmasked_tokens"] == 2
assert metrics_dict["total_tokens"] == 3

## now assume we predict the incorrect token with high probability
next_token_logits = (
Expand All @@ -63,4 +66,8 @@ def test_nll_loss():
)
loss, metrics_dict = loss_fn(next_token_logits, data)
## loss per token is 999, and we have two unmasked tokens
torch.testing.assert_allclose(loss.cpu(), torch.tensor(1998.0))
## with the updated loss function, we now average the loss over unmasked tokens
torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0))
# Check the metrics dictionary contains the expected values
assert metrics_dict["num_unmasked_tokens"] == 2
assert metrics_dict["total_tokens"] == 3