diff --git a/README.md b/README.md index edff7aba42..447c20c668 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,68 @@ uv pip install -e '.[dev,test]' # Use uv run to launch any runs. # Note that it is recommended to not activate the venv and instead use `uv run` since # it ensures consistent environment usage across different shells and sessions. +# Example: uv run python examples/run_grpo_math.py +``` + +## Quick start + +**Reminder**: Don't forget to set your HF_HOME and WANDB_API_KEY (if needed). You'll need to do a `huggingface-cli login` as well for Llama models. + +### GRPO + +We have a reference GRPO experiment config set up for math benchmarks, trained using the [OpenMathInstruct-2](https://huggingface.co/datasets/nvidia/OpenMathInstruct-2) dataset. + +#### Single GPU + +To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`: + +```sh +# Run the GRPO math example using a 1B parameter model uv run python examples/run_grpo_math.py ``` +By default, this uses the configuration in `examples/configs/grpo_math_1B.yaml`. You can customize parameters with command-line overrides: + +```sh +uv run python examples/run_grpo_math.py \ + policy.model_name="Qwen/Qwen2-1.5B" \ + checkpointing.checkpoint_dir="results/qwen1_5b_math" \ + logger.wandb_enabled=True \ + logger.wandb.name="grpo-qwen1_5b_math" \ + logger.num_val_samples_to_print=10 +``` + +#### Multi-node + +For distributed training across multiple nodes: + +Set `UV_CACHE_DIR` to a directory that can be read from all workers before running any uv run command. 
+```sh +export UV_CACHE_DIR=/path/that/all/workers/can/access/uv_cache +``` + +```sh +# Run from the root of NeMo-Reinforcer repo +NUM_ACTOR_NODES=2 +# Add a timestamp to make each job name unique +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# grpo_math_8b uses Llama-3.1-8B-Instruct model +COMMAND="uv pip install -e .; uv run ./examples/run_grpo_math.py --config examples/configs/grpo_math_8B.yaml cluster.num_nodes=2 checkpointing.checkpoint_dir='results/llama8b_2nodes' policy.train_global_batch_size=64 logger.wandb_enabled=True logger.wandb.name='grpo-llama8b_math'" \ +RAY_DEDUP_LOGS=0 \ +UV_CACHE_DIR=YOUR_UV_CACHE_DIR \ +CONTAINER=YOUR_CONTAINER \ +MOUNTS="$PWD:$PWD" \ +sbatch \ + --nodes=${NUM_ACTOR_NODES} \ + --account=YOUR_ACCOUNT \ + --job-name=YOUR_JOBNAME \ + --partition=YOUR_PARTITION \ + --time=4:0:0 \ + --gres=gpu:8 \ + ray.sub +``` + ## Cluster Start Please visit [Cluster Start](docs/cluster.md) for how to get started on Slurm or Kubernetes. diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index f1884b9f75..7d570adddb 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -1,13 +1,13 @@ # GRPO Algorithm Configuration grpo: - num_prompts_per_step: 8 + num_prompts_per_step: 32 num_generations_per_prompt: 8 max_num_steps: 1000000 normalize_rewards: true use_leave_one_out_baseline: true val_period: 10 val_at_start: true - max_val_samples: 16 + max_val_samples: 256 val_batch_size: 16 loss_fn: @@ -59,9 +59,8 @@ data: max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len prompt_file: "examples/prompts/cot.txt" system_prompt_file: null - dataset_name: "datasets/Eurus-2-RL-Data/Eurus-2-RL-Data-math_train.jsonl" - val_dataset_name: "datasets/Eurus-2-RL-Data/Eurus-2-RL-Data-math_val.jsonl" - + dataset_name: "OpenMathInstruct-2" + env: math: num_workers: 8 @@ -78,4 +77,4 @@ logger: cluster: gpus_per_node: 1 - num_nodes: 1 \ No newline at 
end of file + num_nodes: 1 diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 69f28a41b2..2dab29560a 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -36,4 +36,4 @@ policy: cluster: gpus_per_node: 8 - num_nodes: 1 \ No newline at end of file + num_nodes: 1 diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 1c30f8ea94..72c491206f 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -32,6 +32,7 @@ from nemo_reinforcer.models.policy import PolicyConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn from nemo_reinforcer.environments.math_environment import MathEnvironment +from nemo_reinforcer.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset def parse_args(): @@ -55,7 +56,58 @@ def parse_args(): # =============================================================================== -# this processor expects the datum_dict to have a 'problem' key and an 'expected_answer' key +def openinstructmath2_data_processor( + datum_dict: Dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from data/hf_datasets/openmathinstruct2.py) into a DatumSpec for the Math Environment.""" + user_message = datum_dict["messages"] + problem = user_message[0]["content"] + extra_env_info = {"ground_truth": user_message[1]["content"]} + + template = task_data_spec.custom_template + message_log: LLMMessageLogType = [] + user_message = { + "role": "user", + "content": task_data_spec.prompt.format(problem), + } + message = tokenizer.apply_chat_template( + [user_message], + chat_template=template, + tokenize=False, + add_generation_prompt=True, + add_special_tokens=False, + ) + user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0] + user_message["content"] = message + 
message_log.append(user_message) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for message in message_log: + message["token_ids"] = message["token_ids"][ + : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + "task_name": datum_dict["task_name"], + } + return output + + +# Example of a generic math data processor def math_data_processor( datum_dict: Dict[str, Any], task_data_spec: TaskDataSpec, @@ -128,36 +180,38 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs system_prompt_file=data_config["system_prompt_file"], ) - base_dataset = load_dataset("json", data_files=data_config["dataset_name"])["train"] + # Load OpenMathInstruct2Dataset using reinforcer datasets + if data_config["dataset_name"] == "OpenMathInstruct-2": + print(f"Loading nvidia/OpenMathInstruct2Dataset for training and validation") + data = OpenMathInstruct2Dataset() + else: + raise ValueError(f"No processor for dataset {data_config['dataset_name']}.") + tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"]) - task_data_processors = defaultdict(lambda: (math_task_spec, math_data_processor)) - task_data_processors["math"] = (math_task_spec, math_data_processor) + task_data_processors = defaultdict( + lambda: (math_task_spec, openinstructmath2_data_processor) + ) + task_data_processors["math"] = (math_task_spec, openinstructmath2_data_processor) math_env = MathEnvironment.options( runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} ).remote(env_configs["math"]) dataset = AllTaskProcessedDataset( - base_dataset, + data.formatted_ds["train"], tokenizer, math_task_spec, task_data_processors, max_seq_length=data_config["max_input_seq_length"], ) - if "val_dataset_name" 
in data_config and data_config["val_dataset_name"]: - val_dataset = load_dataset("json", data_files=data_config["val_dataset_name"])[ - "train" - ] - val_dataset = AllTaskProcessedDataset( - val_dataset, - tokenizer, - math_task_spec, - task_data_processors, - max_seq_length=data_config["max_input_seq_length"], - ) - else: - val_dataset = None + val_dataset = AllTaskProcessedDataset( + data.formatted_ds["validation"], + tokenizer, + math_task_spec, + task_data_processors, + max_seq_length=data_config["max_input_seq_length"], + ) task_to_env = defaultdict(lambda: math_env) task_to_env["math"] = math_env @@ -191,6 +245,10 @@ def main(): # Get the next experiment directory with incremented ID config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) print(f"📊 Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) init_ray() diff --git a/examples/run_sft.py b/examples/run_sft.py index de0ef8c1ce..950938b4da 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -145,6 +145,10 @@ def main(): config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"]) print(f"📊 Using log directory: {config['logger']['log_dir']}") + if config["checkpointing"]["enabled"]: + print( + f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}" + ) init_ray() diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index bc4f96099a..7acfccd51b 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -180,7 +180,8 @@ def setup( # Load validation dataset if provided val_dataloader = None - if "val_dataset_name" in data_config and data_config["val_dataset_name"]: + # If validation is enabled, load the validation dataloader + if grpo_config["val_period"] > 0 or grpo_config["val_at_start"]: val_dataloader = StatefulDataLoader( 
val_dataset, batch_size=grpo_config["val_batch_size"], diff --git a/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py b/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py new file mode 100644 index 0000000000..c3c5126263 --- /dev/null +++ b/nemo_reinforcer/data/hf_datasets/openmathinstruct2.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +from datasets import load_dataset +from nemo_reinforcer.data.hf_datasets.interfaces import HfDataset +from dataclasses import dataclass + + +def format_math(data): + return { + "messages": [ + { + "role": "user", + "content": data["problem"], + }, + { + "role": "assistant", + "content": data["expected_answer"], + }, + ], + # For v0.1 release, reinforcer datasets require a task_name key such that user can map a task processor per unique task. + "task_name": "math", + } + + +def prepare_openinstructmath2_dataset(split: str = "train_1M", seed=42, test_size=0.05): + """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split.""" + print( + f"WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets." 
+ ) + + # Load the original dataset + original_ds = load_dataset("nvidia/OpenMathInstruct-2", split=split) + + # Split into train and validation sets using HF's train_test_split + split_ds = original_ds.train_test_split(test_size=test_size, seed=seed) + + # Format the examples, removing original columns + train_formatted = split_ds["train"].map( + format_math, remove_columns=split_ds["train"].column_names + ) + val_formatted = split_ds["test"].map( + format_math, remove_columns=split_ds["test"].column_names + ) + + return { + "train": train_formatted, + "validation": val_formatted, + } + + +@dataclass +class OpenMathInstruct2Dataset(HfDataset): + def __init__( + self, split: str = "train_1M", seed: int = 42, test_size: float = 0.05 + ): + """Initialize the OpenMathInstruct2 dataset with train/validation split. + + Args: + seed: Random seed for reproducible splitting + test_size: Proportion of data to use for validation (0.0-1.0) + """ + # train, train_1M, train_2M, and train_5M are supported splits. + if split not in ["train", "train_1M", "train_2M", "train_5M"]: + raise ValueError( + f"Invalid split: {split}. Please use 'train', 'train_1M', 'train_2M', or 'train_5M'." + ) + + self.formatted_ds = prepare_openinstructmath2_dataset( + split=split, seed=seed, test_size=test_size + ) + + super().__init__( + dataset_name="OpenMathInstruct-2", + ) diff --git a/ray.sub b/ray.sub index dac0899a35..c08cc56b2c 100644 --- a/ray.sub +++ b/ray.sub @@ -159,7 +159,7 @@ echo "All workers connected!" 
# This driver process is responsible for launching a job on the Ray cluster CONTAINER_CWD=$(scontrol show job $SLURM_JOB_ID --json | jq -r '.jobs[].current_working_directory') if [[ -n "$COMMAND" ]]; then - srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log $COMMAND + srun --no-container-mount-home --gpus=0 --overlap --container-name=ray-head --container-workdir=$CONTAINER_CWD --nodes=1 --ntasks=1 -w "$head_node" -o $LOG_DIR/ray-driver.log bash -c "$COMMAND" else echo "[INFO]: Ray Cluster is idled, run this on the slurm head node to get a shell to the head node:" cat <$SLURM_SUBMIT_DIR/${SLURM_JOB_ID}-attach.sh