33 changes: 33 additions & 0 deletions docs/guides/eval.md
@@ -0,0 +1,33 @@
# Evaluation

## Start Evaluation

### Start Script
```sh
# Run the evaluation with the default config (examples/configs/eval.yaml)
uv run python examples/run_eval.py

# Specify a custom config file
uv run python examples/run_eval.py --config path/to/custom_config.yaml

# Override specific config values via command line
uv run python examples/run_eval.py generation.model_name="Qwen/Qwen2.5-Math-7B-Instruct"
```

### Example Output

```
============================================================
model_name='Qwen2.5-Math-1.5B-Instruct' dataset_name='aime_2024'
score=0.10 (3.0/30)
============================================================
```

## Configuration

An example evaluation configuration file can be found [here](../../examples/configs/eval.yaml).

### Prompt Template Configuration
Always use the same `prompt_file` and `system_prompt_file` that were used during training.

For open-source models, we recommend setting `prompt_file=null` and `system_prompt_file=null` to allow them to use their native chat templates.
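
For example, if training used custom template files, point evaluation at the same files (the paths below are illustrative):

```sh
uv run python examples/run_eval.py \
    data.prompt_file=path/to/train_prompt.txt \
    data.system_prompt_file=path/to/train_system_prompt.txt
```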
1 change: 1 addition & 0 deletions docs/index.md
@@ -18,6 +18,7 @@ cluster.md
adding_new_models.md
guides/sft.md
guides/grpo.md
guides/eval.md
```

```{toctree}
30 changes: 30 additions & 0 deletions examples/configs/eval.yaml
@@ -0,0 +1,30 @@
# Evaluation Configuration
generation:
  backend: "vllm" # only vllm is supported for evaluation
  max_new_tokens: ${generation.vllm_cfg.max_model_len}
  temperature: 0.0
  top_p: 1.0
  top_k: -1 # disable
  num_prompts_per_step: -1 # -1 means pass all prompts at once
  model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct"
  vllm_cfg:
    tensor_parallel_size: 1
    gpu_memory_utilization: 0.9
    max_model_len: 2048

data:
  max_input_seq_length: ${generation.vllm_cfg.max_model_len} # unused: evaluation passes the prompts through directly
  prompt_file: null
  system_prompt_file: null
  dataset_name: "HuggingFaceH4/aime_2024"
  dataset_key: "train"
  problem_key: "problem"
  solution_key: "answer"

env:
  math:
    num_workers: 8

cluster:
  gpus_per_node: 1
  num_nodes: 1
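
The `data` section generalizes to other Hugging Face datasets that expose a problem/answer pair of columns. As a sketch (the dataset name, split, and column keys below are illustrative):

```sh
uv run python examples/run_eval.py \
    data.dataset_name="HuggingFaceH4/MATH-500" \
    data.dataset_key="test" \
    data.problem_key="problem" \
    data.solution_key="answer"
```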
145 changes: 145 additions & 0 deletions examples/run_eval.py
@@ -0,0 +1,145 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import pprint
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datasets import load_dataset
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from examples.run_grpo_math import math_data_processor
from nemo_reinforcer.data import MathDataConfig
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
from nemo_reinforcer.data.interfaces import TaskDataSpec
from nemo_reinforcer.data.llm_message_utils import remap_dataset_keys
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.environments.math_environment import MathEnvironment
from nemo_reinforcer.evals.eval import MasterConfig, run_env_eval, setup
from nemo_reinforcer.models.generation.interfaces import GenerationConfig


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Run Evaluation with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # Parse known args for the script
    args, remaining = parser.parse_known_args()

    # Convert remaining args to OmegaConf format
    overrides = OmegaConf.from_dotlist(remaining)

    return args, overrides


def setup_data(
    data_config: MathDataConfig, generation_config: GenerationConfig, env_configs
):
    print("\n▶ Setting up data...")
    math_task_spec = TaskDataSpec(
        task_name="math",
        prompt_file=data_config["prompt_file"],
        system_prompt_file=data_config["system_prompt_file"],
    )

    # load dataset
    base_dataset = load_dataset(data_config["dataset_name"])
    if data_config["dataset_key"] is not None:
        base_dataset = base_dataset[data_config["dataset_key"]]
    # remap problem and solution keys
    remapped_dataset = remap_dataset_keys(
        base_dataset,
        mapping_dict={
            data_config["problem_key"]: "problem",
            data_config["solution_key"]: "expected_answer",
        },
    )

    tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    math_env = MathEnvironment.options(
        runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE}
    ).remote(env_configs["math"])

    dataset = AllTaskProcessedDataset(
        dataset=remapped_dataset,
        tokenizer=tokenizer,
        default_task_data_spec=math_task_spec,
        task_data_processors=math_data_processor,
        max_seq_length=data_config["max_input_seq_length"],
    )

    return dataset, math_env, tokenizer


def main():
    """Main entry point."""
    # Parse arguments
    args, overrides = parse_args()

    if not args.config:
        args.config = os.path.join(os.path.dirname(__file__), "configs", "eval.yaml")

    config = OmegaConf.load(args.config)
    print(f"Loaded configuration from: {args.config}")
    # Apply the dotlist overrides already collected by parse_args (re-parsing
    # sys.argv via OmegaConf.from_cli would pick up the --config flag as well)
    if overrides:
        print(f"Overrides: {overrides}")
        config = OmegaConf.merge(config, overrides)
        print("Applied CLI overrides")

    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    # Print config
    print("Final config:")
    pprint.pprint(config)

    # Init ray
    init_ray()

    # Setup data
    (
        dataset,
        math_env,
        tokenizer,
    ) = setup_data(config["data"], config["generation"], config["env"])

    # Setup
    (
        vllm_generation,
        dataloader,
        master_config,
    ) = setup(config, tokenizer, dataset)

    # Run evaluation
    run_env_eval(
        vllm_generation,
        dataloader,
        math_env,
        master_config,
    )


if __name__ == "__main__":
    main()
14 changes: 9 additions & 5 deletions examples/run_grpo_math.py
@@ -122,6 +122,8 @@ def math_data_processor(

    template = task_data_spec.custom_template
    message_log: LLMMessageLogType = []

    # system prompt
    if task_data_spec.system_prompt:
        sys_message = {"role": "system", "content": task_data_spec.system_prompt}
        message = tokenizer.apply_chat_template(
@@ -135,10 +137,11 @@
            0
        ]
        message_log.append(sys_message)
    user_message = {
        "role": "user",
        "content": task_data_spec.prompt.format(problem),
    }

    # user prompt
    if task_data_spec.prompt:
        problem = task_data_spec.prompt.format(problem)
    user_message = {"role": "user", "content": problem}
    message = tokenizer.apply_chat_template(
        [user_message],
        chat_template=template,
@@ -167,8 +170,9 @@
"extra_env_info": extra_env_info,
"loss_multiplier": loss_multiplier,
"idx": idx,
"task_name": datum_dict["task_name"],
}
if "task_name" in datum_dict:
output["task_name"] = datum_dict["task_name"]
return output


5 changes: 5 additions & 0 deletions nemo_reinforcer/data/__init__.py
@@ -21,3 +21,8 @@ class DataConfig(TypedDict):
    system_prompt_file: Optional[str]
    dataset_name: str
    val_dataset_name: Optional[str]


class MathDataConfig(DataConfig):
    problem_key: str
    solution_key: str
51 changes: 51 additions & 0 deletions nemo_reinforcer/data/datasets.py
@@ -130,3 +130,54 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
        batch_max_length=batch_max_length,
    )
    return output


def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict:
    """Collate function for evaluation.

    Takes a list of data samples and combines them into a single batched dictionary
    for model evaluation.

    Args:
        data_batch: List of data samples with message_log, extra_env_info, and idx fields.

    Returns:
        BatchedDataDict with message_log, extra_env_info, and idx fields.

    Examples:
    ```{doctest}
    >>> import torch
    >>> from nemo_reinforcer.data.datasets import eval_collate_fn
    >>> from nemo_reinforcer.data.interfaces import DatumSpec
    >>> data_batch = [
    ...     DatumSpec(
    ...         message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}],
    ...         extra_env_info={'ground_truth': '1'},
    ...         idx=0,
    ...     ),
    ...     DatumSpec(
    ...         message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}],
    ...         extra_env_info={'ground_truth': '2'},
    ...         idx=1,
    ...     ),
    ... ]
    >>> output = eval_collate_fn(data_batch)
    >>> output['message_log'][0]
    [{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}]
    >>> output['message_log'][1]
    [{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}]
    >>> output['extra_env_info']
    [{'ground_truth': '1'}, {'ground_truth': '2'}]
    >>> output['idx']
    [0, 1]
    ```
    """
    message_log = [datum_spec["message_log"] for datum_spec in data_batch]
    extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch]
    idx = [datum_spec["idx"] for datum_spec in data_batch]

    output = BatchedDataDict(
        message_log=message_log,
        extra_env_info=extra_env_info,
        idx=idx,
    )
    return output
28 changes: 26 additions & 2 deletions nemo_reinforcer/data/llm_message_utils.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Union

from typing import Dict, List

import torch
from datasets import Dataset

from nemo_reinforcer.data.interfaces import (
    LLMMessageLogType,
@@ -390,3 +390,27 @@ def get_formatted_message_log(
        prev_formatted_message = formatted_message

    return message_log


def remap_dataset_keys(
    dataset: Dataset,
    mapping_dict: Dict[str, str],
) -> Dataset:
    """Remap dataset keys as per mapping.

    Args:
        dataset: The input dataset to remap keys in
        mapping_dict: A dictionary mapping input keys to output keys

    Returns:
        Dataset: A new dataset with remapped keys
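
    Examples:
    A minimal usage sketch (the dataset contents are illustrative):
    ```{doctest}
    >>> from datasets import Dataset
    >>> ds = Dataset.from_dict({"question": ["1+1"], "answer": ["2"]})
    >>> remapped = remap_dataset_keys(ds, {"question": "problem", "answer": "expected_answer"})
    >>> sorted(remapped.column_names)
    ['expected_answer', 'problem']
    ```
    """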
"""
    # no need to remap if the keys are already correct
    if all(k == v for k, v in mapping_dict.items()):
        return dataset

    # return the remapped dataset
    return dataset.map(
        lambda x: {v: x[k] for k, v in mapping_dict.items()},
        remove_columns=list(mapping_dict.keys()),
    )