
feat: SFT convergence run changes #21

Merged
SahilJain314 merged 6 commits into main from yifu/sft
Mar 22, 2025
Conversation

@yfw
Contributor

@yfw yfw commented Mar 21, 2025

What does this PR do ?

Several changes for the SFT convergence run:

  • Updated NLLLoss to average the loss over unmasked tokens (instead of summing)
  • Updated sft.yaml to be consistent with the default NeMo 2 llama3-8b recipe
  • Added a configurable optimizer (instead of a hardcoded AdamW)
  • Added a set_seed util for reproducible runs
  • Updated the squad template to be consistent with NeMo 2 (adds a space before the answer)
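The masked-average loss change can be sketched as follows. This is a minimal sketch with illustrative names and shapes, not the actual NLLLoss implementation in nemo_reinforcer:

```python
import torch
import torch.nn.functional as F


def masked_nll_loss(logits: torch.Tensor,
                    targets: torch.Tensor,
                    mask: torch.Tensor) -> torch.Tensor:
    """Average negative log-likelihood over unmasked tokens.

    Shapes (illustrative): logits (batch, seq, vocab),
    targets (batch, seq), mask (batch, seq) with 1 marking
    tokens that contribute to the loss.
    """
    # Per-token NLL with no reduction, so the mask can be applied manually.
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        reduction="none",
    ).reshape(targets.shape)
    mask = mask.float()
    # Dividing by mask.sum() (not batch * seq_len) yields the average over
    # unmasked tokens, rather than the previous sum over tokens.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)
```

Averaging over unmasked tokens keeps the loss scale independent of how many tokens are masked out, which matters when sequence lengths and prompt lengths vary across batches.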

Changelog

  • Please update the CHANGELOG.md under next version with high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 
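The set_seed util mentioned in the changes above could look like the following minimal sketch (illustrative only; the actual helper added by this PR may seed additional sources):

```python
import os
import random

import numpy as np
import torch


def set_seed(seed: int) -> None:
    """Seed the common RNG sources so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU RNG
    torch.cuda.manual_seed_all(seed)  # no-op when CUDA is unavailable
    os.environ["PYTHONHASHSEED"] = str(seed)
```

Calling `set_seed` once at the start of a run makes data shuffling, dropout, and weight initialization repeatable across identical launches.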

Before your PR is "Ready for review"

Pre checks:

Checklist when contributing

  • TBD

Additional Information

  • Related to # (issue)

Review threads:

  • nemo_reinforcer/algorithms/sft.py (outdated)
  • nemo_reinforcer/algorithms/loss_functions.py
  • examples/configs/sft_nemo_verify.yaml (outdated)
  • nemo_reinforcer/models/policy/hf_policy.py (outdated)
  • nemo_reinforcer/algorithms/sft.py (outdated)
@ashors1
Contributor

ashors1 commented Mar 21, 2025

@yfw could you actually update the sft config with the convergence config you're using? Right now, the config settings are pretty arbitrary. I think it would be better to use a tested config.

@yfw
Contributor Author

yfw commented Mar 21, 2025

> @yfw could you actually update the sft config with the convergence config you're using? Right now, the config settings are pretty arbitrary. I think it would be better to use a tested config.

Yes, updated.

@yfw yfw changed the title (WIP) SFT convergence run changes feat: SFT convergence run changes Mar 21, 2025
@ashors1 ashors1 mentioned this pull request Mar 21, 2025
yfw added 4 commits March 21, 2025 16:21
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
To compare with NeMo 2 default llama3-8b recipe:

```
uv run examples/run_sft.py --config=examples/configs/sft.yaml \
    sft.max_num_steps=1168251 \
    sft.val_period=-1 \
    sft.val_global_batch_size=128 \
    sft.val_micro_batch_size=1 \
    sft.val_at_start=false \
    checkpointing.enabled=false \
    policy.model_name=meta-llama/Meta-Llama-3-8B \
    policy.train_global_batch_size=128 \
    policy.train_micro_batch_size=1 \
    policy.max_total_sequence_length=2048 \
    policy.optimizer.kwargs='{"lr": 5e-6, "betas": [0.9, 0.98], "eps": 1e-5, "weight_decay":0.1}' \
    policy.scheduler='{"name": "torch.optim.lr_scheduler.LinearLR", "kwargs": {"start_factor": 0.0196078, "end_factor": 1.0, "total_iters": 50}}' \
    data.dataset_name=squad \
    data.max_input_seq_length=2048 \
    cluster.gpus_per_node=8
```

Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
@SahilJain314 SahilJain314 merged commit f530ded into main Mar 22, 2025
5 checks passed
@SahilJain314 SahilJain314 deleted the yifu/sft branch March 22, 2025 00:02
KiddoZhu pushed a commit that referenced this pull request May 6, 2025
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
