59 changes: 59 additions & 0 deletions README.md
@@ -47,9 +47,68 @@ uv pip install -e '.[dev,test]'
# Use uv run to launch any runs.
# Note that it is recommended to not activate the venv and instead use `uv run` since
# it ensures consistent environment usage across different shells and sessions.
# Example: uv run python examples/run_grpo_math.py
```

## Quick start

**Reminder**: Set `HF_HOME` and `WANDB_API_KEY` (if needed). You'll also need to run `huggingface-cli login` to access Llama models.

### GRPO

We provide a reference GRPO experiment configuration for math benchmarks, trained on the [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset.

#### Single GPU

To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`:

```sh
# Run the GRPO math example using a 1B parameter model
uv run python examples/run_grpo_math.py
```

By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides:

```sh
uv run python examples/run_grpo_math.py \
policy.model_name="Qwen/Qwen2-1.5B" \
checkpointing.checkpoint_dir="results/qwen1_5b_math" \
logger.wandb_enabled=True \
logger.wandb.name="grpo-qwen1_5b_math" \
logger.num_val_samples_to_print=10
```
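The dotted override keys above are merged into the nested YAML config before training starts. A minimal sketch of that merge logic (the real CLI parsing in `examples/run_grpo_math.py` may differ; the names here are illustrative):

```python
# Sketch of dot-notation config overrides: walk the nested dict along the
# dotted path, creating intermediate dicts as needed, and set the leaf value.
def apply_override(config: dict, dotted_key: str, value) -> None:
    keys = dotted_key.split(".")
    node = config
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value


config = {"policy": {"model_name": "meta-llama/Llama-3.2-1B-Instruct"}, "logger": {}}
apply_override(config, "policy.model_name", "Qwen/Qwen2-1.5B")
apply_override(config, "logger.wandb_enabled", True)
```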

#### Multi-node

For distributed training across multiple nodes:

Before running any `uv run` command, set `UV_CACHE_DIR` to a directory that all workers can read:
```sh
export UV_CACHE_DIR=/path/that/all/workers/can/access/uv_cache
```

```sh
# Run from the root of NeMo-Reinforcer repo
NUM_ACTOR_NODES=2
# Add a timestamp to make each job name unique
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# grpo_math_8b uses Llama-3.1-8B-Instruct model
COMMAND="uv pip install -e .; uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' policy.train_global_batch_size=64 logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \
RAY_DEDUP_LOGS=0 \
UV_CACHE_DIR=YOUR_UV_CACHE_DIR \
CONTAINER=YOUR_CONTAINER \
MOUNTS="$PWD:$PWD" \
sbatch \
--nodes=${NUM_ACTOR_NODES} \
--account=YOUR_ACCOUNT \
--job-name=YOUR_JOBNAME \
--partition=YOUR_PARTITION \
--time=4:0:0 \
--gres=gpu:8 \
ray.sub
```

## Cluster Start

Please visit [Cluster Start](docs/cluster.md) for how to get started on Slurm or Kubernetes.
11 changes: 5 additions & 6 deletions examples/configs/grpo_math_1B.yaml
@@ -1,13 +1,13 @@
# GRPO Algorithm Configuration
grpo:
num_prompts_per_step: 8
num_prompts_per_step: 32
num_generations_per_prompt: 8
max_num_steps: 1000000
normalize_rewards: true
use_leave_one_out_baseline: true
val_period: 10
val_at_start: true
max_val_samples: 16
max_val_samples: 256
val_batch_size: 16

loss_fn:
@@ -59,9 +59,8 @@ data:
max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
prompt_file: "examples/prompts/cot.txt"
system_prompt_file: null
dataset_name: "datasets/Eurus-2-RL-Data/Eurus-2-RL-Data-math_train.jsonl"
val_dataset_name: "datasets/Eurus-2-RL-Data/Eurus-2-RL-Data-math_val.jsonl"

dataset_name: "OpenMathInstruct-2"

env:
math:
num_workers: 8
@@ -78,4 +77,4 @@ logger:

cluster:
gpus_per_node: 1
num_nodes: 1
num_nodes: 1
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_8B.yaml
@@ -36,4 +36,4 @@ policy:

cluster:
gpus_per_node: 8
num_nodes: 1
num_nodes: 1
94 changes: 76 additions & 18 deletions examples/run_grpo_math.py
@@ -32,6 +32,7 @@
from nemo_reinforcer.models.policy import PolicyConfig
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_reinforcer.environments.math_environment import MathEnvironment
from nemo_reinforcer.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset


def parse_args():
@@ -55,7 +56,58 @@ def parse_args():
# ===============================================================================


# this processor expects the datum_dict to have a 'messages' key containing a
# user turn (the problem) and an assistant turn (the expected answer)
def openinstructmath2_data_processor(
datum_dict: Dict[str, Any],
task_data_spec: TaskDataSpec,
tokenizer,
max_seq_length: int,
idx: int,
) -> DatumSpec:
"""Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Math Environment."""
user_message = datum_dict["messages"]
problem = user_message[0]["content"]
extra_env_info = {"ground_truth": user_message[1]["content"]}

template = task_data_spec.custom_template
message_log: LLMMessageLogType = []
user_message = {
"role": "user",
"content": task_data_spec.prompt.format(problem),
}
message = tokenizer.apply_chat_template(
[user_message],
chat_template=template,
tokenize=False,
add_generation_prompt=True,
add_special_tokens=False,
)
user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0]
user_message["content"] = message
message_log.append(user_message)

length = sum(len(m["token_ids"]) for m in message_log)

loss_multiplier = 1.0
if length > max_seq_length:
# make smaller and mask out
for message in message_log:
message["token_ids"] = message["token_ids"][
: min(4, max_seq_length // len(message_log))
]
loss_multiplier = 0.0

output = {
"message_log": message_log,
"length": length,
"extra_env_info": extra_env_info,
"loss_multiplier": loss_multiplier,
"idx": idx,
"task_name": datum_dict["task_name"],
}
return output
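The over-length handling above keeps only a small stub of each message and zeroes `loss_multiplier`, so batch shapes stay valid while the sample contributes no gradient. The effect can be sketched in isolation (plain lists stand in for tokenized messages; the function name is illustrative):

```python
# Illustrative sketch of the over-length masking used in the processor:
# samples whose total token count exceeds max_seq_length are truncated to a
# tiny per-message stub and their loss multiplier is set to 0.0.
def mask_overlong(message_log, max_seq_length):
    length = sum(len(m["token_ids"]) for m in message_log)
    loss_multiplier = 1.0
    if length > max_seq_length:
        for message in message_log:
            message["token_ids"] = message["token_ids"][
                : min(4, max_seq_length // len(message_log))
            ]
        loss_multiplier = 0.0
    return message_log, loss_multiplier


log = [{"token_ids": list(range(10))}]  # 10 tokens, over the limit below
log, mult = mask_overlong(log, max_seq_length=6)
```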


# Example of a generic math data processor
def math_data_processor(
datum_dict: Dict[str, Any],
task_data_spec: TaskDataSpec,
@@ -128,36 +180,38 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs
system_prompt_file=data_config["system_prompt_file"],
)

base_dataset = load_dataset("json", data_files=data_config["dataset_name"])["train"]
# Load OpenMathInstruct2Dataset using reinforcer datasets
if data_config["dataset_name"] == "OpenMathInstruct-2":
print(f"Loading nvidia/OpenMathInstruct2Dataset for training and validation")
data = OpenMathInstruct2Dataset()
else:
raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])

task_data_processors = defaultdict(lambda: (math_task_spec, math_data_processor))
task_data_processors["math"] = (math_task_spec, math_data_processor)
task_data_processors = defaultdict(
lambda: (math_task_spec, openinstructmath2_data_processor)
)
task_data_processors["math"] = (math_task_spec, openinstructmath2_data_processor)

math_env = MathEnvironment.options(
runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE}
).remote(env_configs["math"])
dataset = AllTaskProcessedDataset(
base_dataset,
data.formatted_ds["train"],
tokenizer,
math_task_spec,
task_data_processors,
max_seq_length=data_config["max_input_seq_length"],
)

if "val_dataset_name" in data_config and data_config["val_dataset_name"]:
val_dataset = load_dataset("json", data_files=data_config["val_dataset_name"])[
"train"
]
val_dataset = AllTaskProcessedDataset(
val_dataset,
tokenizer,
math_task_spec,
task_data_processors,
max_seq_length=data_config["max_input_seq_length"],
)
else:
val_dataset = None
val_dataset = AllTaskProcessedDataset(
data.formatted_ds["validation"],
tokenizer,
math_task_spec,
task_data_processors,
max_seq_length=data_config["max_input_seq_length"],
)

task_to_env = defaultdict(lambda: math_env)
task_to_env["math"] = math_env
@@ -191,6 +245,10 @@ def main():
# Get the next experiment directory with incremented ID
config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
print(f"📊 Using log directory: {config['logger']['log_dir']}")
if config["checkpointing"]["enabled"]:
print(
f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
)

init_ray()

4 changes: 4 additions & 0 deletions examples/run_sft.py
@@ -145,6 +145,10 @@ def main():

config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
print(f"📊 Using log directory: {config['logger']['log_dir']}")
if config["checkpointing"]["enabled"]:
print(
f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
)

init_ray()

3 changes: 2 additions & 1 deletion nemo_reinforcer/algorithms/grpo.py
@@ -180,7 +180,8 @@ def setup(

# Load validation dataset if provided
val_dataloader = None
if "val_dataset_name" in data_config and data_config["val_dataset_name"]:
# If validation is enabled, load the validation dataloader
if grpo_config["val_period"] > 0 or grpo_config["val_at_start"]:
val_dataloader = StatefulDataLoader(
val_dataset,
batch_size=grpo_config["val_batch_size"],
87 changes: 87 additions & 0 deletions nemo_reinforcer/data/hf_datasets/openmathinstruct2.py
@@ -0,0 +1,87 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from datasets import load_dataset
from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset
from dataclasses import dataclass


def format_math(data):
return {
"messages": [
{
"role": "user",
"content": data["problem"],
},
{
"role": "assistant",
"content": data["expected_answer"],
},
],
# For v0.1 release, reinforcer datasets require a task_name key such that user can map a task processor per unique task.
"task_name": "math",
}
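Applied to a raw dataset row, `format_math` produces the two-turn chat format plus the `task_name` key used to route samples to a task processor. Exercising it on a hand-written record shaped like an OpenMathInstruct-2 row (the field values here are made up):

```python
# format_math as defined above, exercised standalone on a sample record.
def format_math(data):
    return {
        "messages": [
            {"role": "user", "content": data["problem"]},
            {"role": "assistant", "content": data["expected_answer"]},
        ],
        "task_name": "math",
    }


sample = {"problem": "What is 2 + 2?", "expected_answer": "4"}
formatted = format_math(sample)
```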


def prepare_openinstructmath2_dataset(split: str = "train_1M", seed=42, test_size=0.05):
"""Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
print(
f"WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets."
)

# Load the original dataset
original_ds = load_dataset("nvidia/OpenMathInstruct-2", split=split)

# Split into train and validation sets using HF's train_test_split
split_ds = original_ds.train_test_split(test_size=test_size, seed=seed)

# Format the examples, removing original columns
train_formatted = split_ds["train"].map(
format_math, remove_columns=split_ds["train"].column_names
)
val_formatted = split_ds["test"].map(
format_math, remove_columns=split_ds["test"].column_names
)

return {
"train": train_formatted,
"validation": val_formatted,
}
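The `test_size=0.05` split semantics can be illustrated without downloading the dataset. A seeded shuffle-and-slice over indices approximates the sizing behavior of `train_test_split` (HF's implementation shuffles differently, so the exact membership will differ; only the proportions are illustrated):

```python
import random


# Approximation of a seeded train/validation split over n examples.
# Only the sizing logic mirrors datasets.train_test_split; actual
# shuffling internals differ.
def split_indices(n: int, test_size: float = 0.05, seed: int = 42):
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_test = int(round(n * test_size))
    return indices[n_test:], indices[:n_test]


train_idx, val_idx = split_indices(1000, test_size=0.05, seed=42)
```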


@dataclass
class OpenMathInstruct2Dataset(HfDataset):
def __init__(
self, split: str = "train_1M", seed: int = 42, test_size: float = 0.05
):
"""Initialize the OpenMathInstruct2 dataset with train/validation split.

Args:
split: Dataset split to load ('train', 'train_1M', 'train_2M', or 'train_5M')
seed: Random seed for reproducible splitting
test_size: Proportion of data to use for validation (0.0-1.0)
"""
# train, train_1M, train_2M, and train_5M are supported splits.
if split not in ["train", "train_1M", "train_2M", "train_5M"]:
raise ValueError(
f"Invalid split: {split}. Please use 'train', 'train_1M', 'train_2M', or 'train_5M'."
)

self.formatted_ds = prepare_openinstructmath2_dataset(
split=split, seed=seed, test_size=test_size
)

super().__init__(
dataset_name="OpenMathInstruct-2",
)
2 changes: 1 addition & 1 deletion ray.sub
@@ -159,7 +159,7 @@ echo "All workers connected!"
# This driver process is responsible for launching a job on the Ray cluster
CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory')
if [[ -n "$COMMAND" ]]; then
srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log $COMMAND
srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND"
else
echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:"
cat <<EOF >$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh