From 976f20ed1137cf6fd74ca2ff7111f863a1018c18 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 25 Mar 2025 06:49:42 +0000 Subject: [PATCH 1/8] add evaluation implement and minial doc Signed-off-by: Yuki Huang --- docs/guides/eval.md | 33 +++++ docs/index.md | 1 + examples/configs/eval.yaml | 34 +++++ examples/run_eval.py | 128 ++++++++++++++++ examples/run_grpo_math.py | 14 +- nemo_reinforcer/data/__init__.py | 2 + nemo_reinforcer/data/datasets.py | 61 ++++++++ nemo_reinforcer/evals/eval.py | 173 ++++++++++++++++++++++ nemo_reinforcer/evals/run_env_eval.py | 13 -- nemo_reinforcer/models/generation/vllm.py | 68 +++++++++ 10 files changed, 509 insertions(+), 18 deletions(-) create mode 100644 docs/guides/eval.md create mode 100644 examples/configs/eval.yaml create mode 100644 examples/run_eval.py create mode 100644 nemo_reinforcer/evals/eval.py delete mode 100644 nemo_reinforcer/evals/run_env_eval.py diff --git a/docs/guides/eval.md b/docs/guides/eval.md new file mode 100644 index 0000000000..8ac5ab5675 --- /dev/null +++ b/docs/guides/eval.md @@ -0,0 +1,33 @@ +# Evaluation + +## Start Evaluation + +### Start Script +```sh +# To run the evaluation with default config (examples/configs/eval.yaml) +uv run python examples/run_eval.py + +# Specify a custom config file +uv run python examples/run_eval.py --config path/to/custom_config.yaml + +# Override specific config values via command line +uv run python examples/run_eval.py generation.model_name="Qwen/Qwen2.5-Math-7B-Instruct" +``` + +### Example Output + +``` +============================================================ +model_name='Qwen2.5-Math-1.5B-Instruct' dataset_name='aime_2024' +score=0.10 (3.0/30) +============================================================ +``` + +## Configuration + +An example Evaluation configuration file can be found [here](../../examples/configs/eval.yaml). + +### Prompt Template Configuration +Always remember to use the same `prompt_file` and `system_prompt_file` that were used during training. + +For open-source models, we recommend setting `prompt_file=null` and `system_prompt_file=null` to allow them to use their native chat templates. diff --git a/docs/index.md b/docs/index.md index 56cd64ac1b..0628f19953 100644 --- a/docs/index.md +++ b/docs/index.md @@ -18,6 +18,7 @@ cluster.md adding_new_models.md guides/sft.md guides/grpo.md +guides/eval.md ``` ```{toctree} diff --git a/examples/configs/eval.yaml b/examples/configs/eval.yaml new file mode 100644 index 0000000000..6837c7c7ac --- /dev/null +++ b/examples/configs/eval.yaml @@ -0,0 +1,34 @@ +# Evaluation Configuration +generation: + backend: "vllm" # only vllm is supported for evaluation + max_new_tokens: ${generation.vllm_cfg.max_model_len} + temperature: 0.0 + top_p: 1.0 + top_k: -1 # disable + num_prompts_per_step: -1 # -1 means pass all prompts at once + model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct" + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.9 + max_model_len: 2048 + vllm_kwargs: + dtype: "bfloat16" + enable_prefix_caching: True + disable_custom_all_reduce: false + +data: + max_input_seq_length: ${generation.vllm_cfg.max_model_len} # useless since we directly use prompts in evaluation + prompt_file: null + system_prompt_file: null + dataset_name: "HuggingFaceH4/aime_2024" + dataset_key: "train" + problem_key: "problem" + solution_key: "answer" + +env: + math: + num_workers: 8 + +cluster: + gpus_per_node: 1 + num_nodes: 1 diff --git a/examples/run_eval.py b/examples/run_eval.py new file mode 100644 index 0000000000..4704cedee7 --- /dev/null +++ b/examples/run_eval.py @@ -0,0 +1,128 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from datasets import load_dataset +from omegaconf import OmegaConf +from transformers import AutoTokenizer + +from examples.run_grpo_math import math_data_processor +from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset +from nemo_reinforcer.data.interfaces import TaskDataSpec +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.environments.math_environment import MathEnvironment +from nemo_reinforcer.evals.eval import MasterConfig, run_env_eval, setup +from nemo_reinforcer.models.generation.interfaces import GenerationConfig + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description="Run Evaluation with configuration") + parser.add_argument( + "--config", type=str, default=None, help="Path to YAML config file" + ) + + # Parse known args for the script + args, remaining = parser.parse_known_args() + + # Convert remaining args to OmegaConf format + overrides = OmegaConf.from_dotlist(remaining) + + return args, overrides + + +def setup_data(data_config: DataConfig, generation_config: GenerationConfig, env_configs): + print("\n▶ Setting up data...") + math_task_spec = TaskDataSpec( + task_name="math", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + + base_dataset = load_dataset(data_config["dataset_name"]) + if data_config["dataset_key"] is not None: + base_dataset = base_dataset[data_config["dataset_key"]] + tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"]) + + math_env = MathEnvironment.options( + runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} + ).remote(env_configs["math"]) + + dataset = AllTaskProcessedDataset( + dataset=base_dataset, + config=data_config, + tokenizer=tokenizer, + default_task_data_spec=math_task_spec, + task_data_processors=math_data_processor, + ) + + return dataset, math_env + + +def main(): + """Main entry point.""" + # Parse arguments + args, overrides = parse_args() + + if not args.config: + args.config = os.path.join(os.path.dirname(__file__), "configs", "eval.yaml") + + config = OmegaConf.load(args.config) + print(f"Loaded configuration from: {args.config}") + + if overrides: + override_conf = OmegaConf.from_cli() + print(f"Overrides: {override_conf}") + config = OmegaConf.merge(config, override_conf) + + config: MasterConfig = OmegaConf.to_container(config, resolve=True) + print("Applied CLI overrides") + + # Print config + print("Final config:") + pprint.pprint(config) + + # Init ray + init_ray() + + # Setup data + dataset, math_env = setup_data( + config["data"], config["generation"], config["env"] + ) + + # Setup + ( + vllm_generation, + dataloader, + master_config, + ) = setup(config, dataset) + + # Run evaluation + run_env_eval( + vllm_generation, + dataloader, + math_env, + master_config, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 7e2f3e693d..b87c7037b3 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -122,6 +122,8 @@ def math_data_processor( template = task_data_spec.custom_template message_log: LLMMessageLogType = [] + + # system prompt if task_data_spec.system_prompt: sys_message = {"role": "system", "content": task_data_spec.system_prompt} message = tokenizer.apply_chat_template( @@ -135,10 +137,11 @@ def math_data_processor( 0 ] message_log.append(sys_message) - user_message = { - "role": "user", - "content": task_data_spec.prompt.format(problem), - } + + # user prompt + if task_data_spec.prompt: + problem = task_data_spec.prompt.format(problem) + user_message = {"role": "user", "content": problem} message = tokenizer.apply_chat_template( [user_message], chat_template=template, @@ -167,8 +170,9 @@ def math_data_processor( "extra_env_info": extra_env_info, "loss_multiplier": loss_multiplier, "idx": idx, - "task_name": datum_dict["task_name"], } + if "task_name" in datum_dict: + output["task_name"] = datum_dict["task_name"] return output diff --git a/nemo_reinforcer/data/__init__.py b/nemo_reinforcer/data/__init__.py index 63aad516b2..fc996661af 100644 --- a/nemo_reinforcer/data/__init__.py +++ b/nemo_reinforcer/data/__init__.py @@ -21,3 +21,5 @@ class DataConfig(TypedDict): system_prompt_file: Optional[str] dataset_name: str val_dataset_name: Optional[str] + problem_key: str + solution_key: str diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 033cee05d7..2a03bf6b96 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -130,3 +130,64 @@ def rl_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: batch_max_length=batch_max_length, ) return output + + +def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: + """Collate function for evaluation. + + Takes a list of data samples and combines them into a single batched dictionary + for model evaluation. Each sample's message content is concatenated into a single + prompt string. + + Args: + data_batch: List of data samples with message_log, extra_env_info, and idx fields. + + Returns: + BatchedDataDict with prompts, message_log, extra_env_info, and idx fields. + + Examples: + ```{doctest} + >>> import torch + >>> from nemo_reinforcer.data.datasets import eval_collate_fn + >>> from nemo_reinforcer.data.interfaces import DatumSpec + >>> data_batch = [ + ... DatumSpec( + ... message_log=[{"role": "user", "content": "Hello", "token_ids": torch.tensor([1, 2, 3])}], + ... extra_env_info={'ground_truth': '1'}, + ... idx=0, + ... ), + ... DatumSpec( + ... message_log=[{"role": "assistant", "content": "Hi there", "token_ids": torch.tensor([4, 5, 6, 7])}], + ... extra_env_info={'ground_truth': '2'}, + ... idx=1, + ... ), + ... ] + >>> output = eval_collate_fn(data_batch) + >>> output['prompts'] + ['Hello', 'Hi there'] + >>> output['message_log'][0] + [{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}] + >>> output['message_log'][1] + [{'role': 'assistant', 'content': 'Hi there', 'token_ids': tensor([4, 5, 6, 7])}] + >>> output['extra_env_info'] + [{'ground_truth': '1'}, {'ground_truth': '2'}] + >>> output['idx'] + [0, 1] + """ + prompts = [] + for datum_spec in data_batch: + content = [message["content"] for message in datum_spec["message_log"]] + content = "\n".join(content) + prompts.append(content) + + message_log = [datum_spec["message_log"] for datum_spec in data_batch] + extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] + idx = [datum_spec["idx"] for datum_spec in data_batch] + + output = BatchedDataDict( + prompts=prompts, + message_log=message_log, + extra_env_info=extra_env_info, + idx=idx, + ) + return output diff --git a/nemo_reinforcer/evals/eval.py b/nemo_reinforcer/evals/eval.py new file mode 100644 index 0000000000..631ab34580 --- /dev/null +++ b/nemo_reinforcer/evals/eval.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Tuple, TypedDict + +import ray +from torch.utils.data import DataLoader + +from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, eval_collate_fn +from nemo_reinforcer.data.llm_message_utils import get_keys_from_message_log +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_reinforcer.environments.math_environment import MathEnvConfig +from nemo_reinforcer.models.generation.interfaces import GenerationConfig +from nemo_reinforcer.models.generation.vllm import VllmGeneration + + +# =============================================================================== +# Configuration +# =============================================================================== + + +class MasterConfig(TypedDict): + generate: GenerationConfig + data: DataConfig + math_env: MathEnvConfig + cluster: ClusterConfig + + +# =============================================================================== +# Setup & Initialization +# =============================================================================== + + +def setup( + master_config: MasterConfig, + dataset: AllTaskProcessedDataset, +) -> Tuple[ + VllmGeneration, + DataLoader, + MasterConfig, +]: + """Set up components for model evaluation. + + Initializes the VLLM model and data loader. + + Args: + master_config: Configuration settings. + dataset: Dataset to evaluate on. + + Returns: + VLLM model, data loader, and config. + """ + # Extract individual configs for easier access + generation_config = master_config["generation"] + cluster_config = master_config["cluster"] + + # ========================== + # Data + # ========================== + if generation_config["num_prompts_per_step"] == -1: + generation_config["num_prompts_per_step"] = len(dataset) + dataloader = DataLoader( + dataset, + batch_size=generation_config["num_prompts_per_step"], + shuffle=False, + collate_fn=eval_collate_fn, + ) + print(f" ✓ Evaluation dataset loaded with {len(dataset)} samples") + + # ========================== + # Cluster + # ========================== + print("\n▶ Setting up compute cluster...") + cluster = RayVirtualCluster( + name="eval_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1, + ) + print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes") + + # ========================== + # Model + # ========================== + print("\n▶ Setting up model...") + backend = generation_config["backend"] + assert backend == "vllm", "Only vLLM backend is supported for evaluation" + vllm_generation = VllmGeneration(cluster=cluster, config=generation_config) + print( + f" ✓ Using vLLM backend for generation with {generation_config['model_name']}" + ) + + print("\n" + "=" * 60) + print(" " * 18 + "SETUP COMPLETE") + print("=" * 60 + "\n") + + return ( + vllm_generation, + dataloader, + master_config, + ) + + +# =============================================================================== +# Evaluation +# =============================================================================== + + +def run_env_eval(vllm_generation, dataloader, env, master_config): + """Main entry point for running evaluation using environment. + + Generates model responses and evaluates them by env. + + Args: + vllm_generation: Model for generating responses. + dataloader: Data loader with evaluation samples. + env: Environment that scores responses. + master_config: Configuration settings. + """ + # Run evaluation loop + score, count = 0.0, 0 + for batch in dataloader: + inputs = BatchedDataDict({"prompts": batch["prompts"]}) + outputs = vllm_generation.generate_text(inputs)["texts"] + + # append to message_log + for idx, output in enumerate(outputs): + batch["message_log"][idx].append( + { + "role": "assistant", + "content": output, + } + ) + + # evaluate generations with the environment + to_env = [ + get_keys_from_message_log(batch["message_log"][i], ["role", "content"]) + for i in range(len(batch["message_log"])) + ] + _, _, rewards, _ = ray.get(env.step.remote(to_env, batch["extra_env_info"])) + + score += rewards.sum().item() + count += len(rewards) + + # Cleanup before printing results + ray.get(env.shutdown.remote()) + vllm_generation.shutdown() + + # Print results + dataset_name = os.path.basename(master_config["data"]["dataset_name"]) + model_name = os.path.basename(master_config["generation"]["model_name"]) + average_score = score / count + + print("\n" + "=" * 60) + print(f"{model_name=} {dataset_name=}") + print(f"score={average_score:.2f} ({score}/{count})") + print("=" * 60 + "\n") diff --git a/nemo_reinforcer/evals/run_env_eval.py b/nemo_reinforcer/evals/run_env_eval.py deleted file mode 100644 index 341a77c5bc..0000000000 --- a/nemo_reinforcer/evals/run_env_eval.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index cb60b7fe8c..8d232e8f97 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -200,6 +200,7 @@ def generate( Args: data: BatchedDataDict containing input_ids and input_lengths tensors + greedy: Whether to use greedy decoding instead of sampling Returns: BatchedDataDict conforming to GenerationOutputSpec: @@ -330,6 +331,37 @@ def generate( return return_data + def generate_text( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate text responses using vLLM generation. + + Args: + data: BatchedDataDict containing prompts with text strings + greedy: Whether to use greedy decoding instead of sampling + + Returns: + BatchedDataDict containing: + - texts: List of generated text responses + """ + # Read generation parameters from config + top_k = self.cfg["top_k"] if self.cfg["top_k"] is not None else -1 + sampling_params = self.SamplingParams( + temperature=self.cfg["temperature"], + top_p=self.cfg["top_p"], + top_k=top_k if not greedy else 1, + max_tokens=self.cfg["max_new_tokens"], + stop=self.cfg.get("stop_sequences", None), + ) + + # Generate outputs + outputs = self.llm.generate(data["prompts"], sampling_params) + texts = [output.outputs[0].text for output in outputs] + + # Convert to BatchedDataDict + return_data = BatchedDataDict({"texts": texts}) + return return_data + def shutdown(self): """Clean up vLLM resources.""" try: @@ -537,6 +569,42 @@ def generate( return combined + def generate_text( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate text responses using vLLM.""" + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + + # Get total batch size + batch_size = len(data["prompts"]) + + # Shard the data across the tied worker groups + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=batch_size) + future_bundle = self.worker_group.run_all_workers_multiple_data( + "generate_text", + sharded_data, + common_kwargs={"greedy": greedy}, + respect_tied_workers=True, + ) + + # Get results from the workers, respecting tied worker groups (only one result per tied worker group) + results = self.worker_group.get_all_worker_results(future_bundle) + + # Combine results from all tied worker groups + combined = BatchedDataDict.from_batches(results) + + # Verify the output has all required fields + required_keys = ["texts"] + missing_keys = [key for key in required_keys if key not in combined] + if missing_keys: + raise ValueError( + f"Missing required keys for GenerationOutputSpec: {missing_keys}" + ) + + return combined + def prepare_for_generation(self, *args, **kwargs): """Abstract method that must be implemented by subclasses.""" try: From 9c1be171fa2cc99aee17b8ceee26c5f8953c4696 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 25 Mar 2025 07:18:32 +0000 Subject: [PATCH 2/8] add remap_problem_solution Signed-off-by: Yuki Huang --- examples/run_eval.py | 21 +++++++++++----- nemo_reinforcer/data/llm_message_utils.py | 30 +++++++++++++++++++++-- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/examples/run_eval.py b/examples/run_eval.py index 4704cedee7..65bd532d32 100644 --- a/examples/run_eval.py +++ b/examples/run_eval.py @@ -27,6 +27,7 @@ from nemo_reinforcer.data import DataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset from nemo_reinforcer.data.interfaces import TaskDataSpec +from nemo_reinforcer.data.llm_message_utils import remap_problem_solution from nemo_reinforcer.distributed.virtual_cluster import init_ray from nemo_reinforcer.environments.math_environment import MathEnvironment from nemo_reinforcer.evals.eval import MasterConfig, run_env_eval, setup @@ -49,7 +50,9 @@ def parse_args(): return args, overrides -def setup_data(data_config: DataConfig, generation_config: GenerationConfig, env_configs): +def setup_data( + data_config: DataConfig, generation_config: GenerationConfig, env_configs +): print("\n▶ Setting up data...") math_task_spec = TaskDataSpec( task_name="math", @@ -57,9 +60,17 @@ def setup_data(data_config: DataConfig, generation_config: GenerationConfig, env system_prompt_file=data_config["system_prompt_file"], ) + # load dataset base_dataset = load_dataset(data_config["dataset_name"]) if data_config["dataset_key"] is not None: base_dataset = base_dataset[data_config["dataset_key"]] + # remap problem and solution keys + remapped_dataset = remap_problem_solution( + base_dataset, + problem_key=data_config["problem_key"], + solution_key=data_config["solution_key"], + ) + tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"]) math_env = MathEnvironment.options( @@ -67,11 +78,11 @@ def setup_data(data_config: DataConfig, generation_config: GenerationConfig, env ).remote(env_configs["math"]) dataset = AllTaskProcessedDataset( - dataset=base_dataset, - config=data_config, + dataset=remapped_dataset, tokenizer=tokenizer, default_task_data_spec=math_task_spec, task_data_processors=math_data_processor, + max_seq_length=data_config["max_input_seq_length"], ) return dataset, math_env @@ -104,9 +115,7 @@ def main(): init_ray() # Setup data - dataset, math_env = setup_data( - config["data"], config["generation"], config["env"] - ) + dataset, math_env = setup_data(config["data"], config["generation"], config["env"]) # Setup ( diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index 43e24fc1ce..08bcef464a 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Union - +from typing import Dict, List import torch +from datasets import Dataset from nemo_reinforcer.data.interfaces import ( LLMMessageLogType, @@ -390,3 +390,29 @@ def get_formatted_message_log( prev_formatted_message = formatted_message return message_log + + +def remap_problem_solution( + dataset: Dataset, + problem_key: str, + solution_key: str, +) -> Dataset: + """Remap the problem and solution keys in a dataset. + + Args: + dataset: The input dataset to remap keys in + problem_key: The key to map problem data to + solution_key: The key to map solution data to + + Returns: + Dataset: A new dataset with remapped keys + """ + # no need to remap if the keys are already correct + if problem_key == "problem" and solution_key == "expected_answer": + return dataset + + # return the remapped dataset + return dataset.map( + lambda x: {"problem": x[problem_key], "expected_answer": x[solution_key]}, + remove_columns=[problem_key, solution_key], + ) From 53994262cf62bb701a5799aa10177330f6ef0f2f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 25 Mar 2025 08:15:26 +0000 Subject: [PATCH 3/8] set load_format=auto to avoid dummy loading in eval Signed-off-by: Yuki Huang --- nemo_reinforcer/evals/eval.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nemo_reinforcer/evals/eval.py b/nemo_reinforcer/evals/eval.py index 631ab34580..5fa1cb5a1a 100644 --- a/nemo_reinforcer/evals/eval.py +++ b/nemo_reinforcer/evals/eval.py @@ -99,8 +99,11 @@ def setup( # Model # ========================== print("\n▶ Setting up model...") + # check backend backend = generation_config["backend"] assert backend == "vllm", "Only vLLM backend is supported for evaluation" + # initialize vllm generation + generation_config["vllm_cfg"]["load_format"] = "auto" vllm_generation = VllmGeneration(cluster=cluster, config=generation_config) print( f" ✓ Using vLLM backend for generation with {generation_config['model_name']}" From 3e2321f60f96e77a2975111d99cfeb5664783567 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Tue, 25 Mar 2025 15:37:55 +0000 Subject: [PATCH 4/8] remove useless keys; move prompt process outside Signed-off-by: Yuki Huang --- examples/configs/eval.yaml | 4 ---- nemo_reinforcer/data/datasets.py | 14 ++------------ nemo_reinforcer/evals/eval.py | 9 ++++++++- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/examples/configs/eval.yaml b/examples/configs/eval.yaml index 6837c7c7ac..a867e9617f 100644 --- a/examples/configs/eval.yaml +++ b/examples/configs/eval.yaml @@ -11,10 +11,6 @@ generation: tensor_parallel_size: 1 gpu_memory_utilization: 0.9 max_model_len: 2048 - vllm_kwargs: - dtype: "bfloat16" - enable_prefix_caching: True - disable_custom_all_reduce: false data: max_input_seq_length: ${generation.vllm_cfg.max_model_len} # useless since we directly use prompts in evaluation diff --git a/nemo_reinforcer/data/datasets.py b/nemo_reinforcer/data/datasets.py index 2a03bf6b96..8a81c85fb2 100644 --- a/nemo_reinforcer/data/datasets.py +++ b/nemo_reinforcer/data/datasets.py @@ -136,14 +136,13 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: """Collate function for evaluation. Takes a list of data samples and combines them into a single batched dictionary - for model evaluation. Each sample's message content is concatenated into a single - prompt string. + for model evaluation. Args: data_batch: List of data samples with message_log, extra_env_info, and idx fields. Returns: - BatchedDataDict with prompts, message_log, extra_env_info, and idx fields. + BatchedDataDict with message_log, extra_env_info, and idx fields. Examples: ```{doctest} @@ -163,8 +162,6 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: ... ), ... ] >>> output = eval_collate_fn(data_batch) - >>> output['prompts'] - ['Hello', 'Hi there'] >>> output['message_log'][0] [{'role': 'user', 'content': 'Hello', 'token_ids': tensor([1, 2, 3])}] >>> output['message_log'][1] @@ -174,18 +171,11 @@ def eval_collate_fn(data_batch: List[DatumSpec]) -> BatchedDataDict: >>> output['idx'] [0, 1] """ - prompts = [] - for datum_spec in data_batch: - content = [message["content"] for message in datum_spec["message_log"]] - content = "\n".join(content) - prompts.append(content) - message_log = [datum_spec["message_log"] for datum_spec in data_batch] extra_env_info = [datum_spec["extra_env_info"] for datum_spec in data_batch] idx = [datum_spec["idx"] for datum_spec in data_batch] output = BatchedDataDict( - prompts=prompts, message_log=message_log, extra_env_info=extra_env_info, idx=idx, diff --git a/nemo_reinforcer/evals/eval.py b/nemo_reinforcer/evals/eval.py index 5fa1cb5a1a..8dc5b7546c 100644 --- a/nemo_reinforcer/evals/eval.py +++ b/nemo_reinforcer/evals/eval.py @@ -139,7 +139,14 @@ def run_env_eval(vllm_generation, dataloader, env, master_config): # Run evaluation loop score, count = 0.0, 0 for batch in dataloader: - inputs = BatchedDataDict({"prompts": batch["prompts"]}) + # get input prompt from message_log + prompts = [] + for message_log in batch["message_log"]: + content = [message["content"] for message in message_log] + content = "\n".join(content) + prompts.append(content) + # generate by vllm + inputs = BatchedDataDict({"prompts": prompts}) outputs = vllm_generation.generate_text(inputs)["texts"] # append to message_log From b60c94df2064463310087f1040981b66e77e39ec Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 26 Mar 2025 04:18:17 +0000 Subject: [PATCH 5/8] add MathDataConfig; add remap_dataset_keys; fix rebase Signed-off-by: Yuki Huang --- examples/run_eval.py | 28 +++++++++++++++-------- nemo_reinforcer/data/__init__.py | 10 ++++++++ nemo_reinforcer/data/llm_message_utils.py | 17 +++++++------- nemo_reinforcer/evals/eval.py | 16 +++++++++---- 4 files changed, 49 insertions(+), 22 deletions(-) diff --git a/examples/run_eval.py b/examples/run_eval.py index 65bd532d32..586fb1d6ec 100644 --- a/examples/run_eval.py +++ b/examples/run_eval.py @@ -24,10 +24,10 @@ from transformers import AutoTokenizer from examples.run_grpo_math import math_data_processor -from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data import MathDataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset from nemo_reinforcer.data.interfaces import TaskDataSpec -from nemo_reinforcer.data.llm_message_utils import remap_problem_solution +from nemo_reinforcer.data.llm_message_utils import remap_dataset_keys from nemo_reinforcer.distributed.virtual_cluster import init_ray from nemo_reinforcer.environments.math_environment import MathEnvironment from nemo_reinforcer.evals.eval import MasterConfig, run_env_eval, setup @@ -51,7 +51,7 @@ def parse_args(): def setup_data( - data_config: DataConfig, generation_config: GenerationConfig, env_configs + data_config: MathDataConfig, generation_config: GenerationConfig, env_configs ): print("\n▶ Setting up data...") math_task_spec = TaskDataSpec( @@ -65,13 +65,19 @@ def setup_data( if data_config["dataset_key"] is not None: base_dataset = base_dataset[data_config["dataset_key"]] # remap problem and solution keys - remapped_dataset = remap_problem_solution( + remapped_dataset = remap_dataset_keys( base_dataset, - problem_key=data_config["problem_key"], - solution_key=data_config["solution_key"], + mapping_dict={ + data_config["problem_key"]: "problem", + data_config["solution_key"]: "expected_answer", + }, ) tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"]) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id math_env = MathEnvironment.options( runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} @@ -85,7 +91,7 @@ def setup_data( max_seq_length=data_config["max_input_seq_length"], ) - return dataset, math_env + return dataset, math_env, tokenizer def main(): @@ -115,14 +121,18 @@ def main(): init_ray() # Setup data - dataset, math_env = setup_data(config["data"], config["generation"], config["env"]) + ( + dataset, + math_env, + tokenizer, + ) = setup_data(config["data"], config["generation"], config["env"]) # Setup ( vllm_generation, dataloader, master_config, - ) = setup(config, dataset) + ) = setup(config, tokenizer, dataset) # Run evaluation run_env_eval( diff --git a/nemo_reinforcer/data/__init__.py b/nemo_reinforcer/data/__init__.py index fc996661af..90c61811e4 100644 --- a/nemo_reinforcer/data/__init__.py +++ b/nemo_reinforcer/data/__init__.py @@ -21,5 +21,15 @@ class DataConfig(TypedDict): system_prompt_file: Optional[str] dataset_name: str val_dataset_name: Optional[str] + + +class MathDataConfig(TypedDict): + # all fields from DataConfig + max_input_seq_length: int + prompt_file: str + system_prompt_file: Optional[str] + dataset_name: str + val_dataset_name: Optional[str] + # additional fields specific to math data problem_key: str solution_key: str diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index 08bcef464a..bd86863ab9 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -392,27 +392,26 @@ def get_formatted_message_log( return message_log -def remap_problem_solution( +def remap_dataset_keys( dataset: Dataset, - problem_key: str, - solution_key: str, + mapping_dict: Dict[str, str], ) -> Dataset: - """Remap the problem and solution keys in a dataset. + """Remap dataset keys as per mapping Args: dataset: The input dataset to remap keys in - problem_key: The key to map problem data to - solution_key: The key to map solution data to + mapping_dict: A dictionary mapping input keys to output keys Returns: Dataset: A new dataset with remapped keys """ + # no need to remap if the keys are already correct - if problem_key == "problem" and solution_key == "expected_answer": + if all(k == v for k, v in mapping_dict.items()): return dataset # return the remapped dataset return dataset.map( - lambda x: {"problem": x[problem_key], "expected_answer": x[solution_key]}, - remove_columns=[problem_key, solution_key], + lambda x: {v: x[k] for k, v in mapping_dict.items()}, + remove_columns=list(mapping_dict.keys()), ) diff --git a/nemo_reinforcer/evals/eval.py b/nemo_reinforcer/evals/eval.py index 8dc5b7546c..33d486a4d5 100644 --- a/nemo_reinforcer/evals/eval.py +++ b/nemo_reinforcer/evals/eval.py @@ -17,8 +17,9 @@ import ray from torch.utils.data import DataLoader +from transformers import AutoTokenizer -from nemo_reinforcer.data import DataConfig +from nemo_reinforcer.data import MathDataConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, eval_collate_fn from nemo_reinforcer.data.llm_message_utils import get_keys_from_message_log from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict @@ -35,8 +36,8 @@ class MasterConfig(TypedDict): generate: GenerationConfig - data: DataConfig - math_env: MathEnvConfig + data: MathDataConfig + env: MathEnvConfig cluster: ClusterConfig @@ -47,6 +48,7 @@ class MasterConfig(TypedDict): def setup( master_config: MasterConfig, + tokenizer: AutoTokenizer, dataset: AllTaskProcessedDataset, ) -> Tuple[ VllmGeneration, @@ -102,8 +104,14 @@ def setup( # check backend backend = generation_config["backend"] assert backend == "vllm", "Only vLLM backend is supported for evaluation" - # initialize vllm generation + + # set vllm config generation_config["vllm_cfg"]["load_format"] = "auto" + generation_config["vllm_cfg"]["skip_tokenizer_init"] = False + generation_config["stop_token_ids"] = [tokenizer.eos_token_id] + generation_config["pad_token"] = tokenizer.pad_token_id + + # initialize vllm generation vllm_generation = VllmGeneration(cluster=cluster, config=generation_config) print( f" ✓ Using vLLM backend for generation with {generation_config['model_name']}" From 0a21a305bc2cddfe3a7964decf5b81d34634ed5f Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 26 Mar 2025 10:15:51 +0000 Subject: [PATCH 6/8] add eval unit test Signed-off-by: Yuki Huang --- nemo_reinforcer/data/llm_message_utils.py | 3 +- nemo_reinforcer/models/generation/vllm.py | 4 +- .../models/generation/test_vllm_generation.py | 45 +++++++++++++++++-- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/nemo_reinforcer/data/llm_message_utils.py b/nemo_reinforcer/data/llm_message_utils.py index bd86863ab9..5ae8bee9a8 100644 --- a/nemo_reinforcer/data/llm_message_utils.py +++ b/nemo_reinforcer/data/llm_message_utils.py @@ -396,7 +396,7 @@ def remap_dataset_keys( dataset: Dataset, mapping_dict: Dict[str, str], ) -> Dataset: - """Remap dataset keys as per mapping + """Remap dataset keys as per mapping. Args: dataset: The input dataset to remap keys in @@ -405,7 +405,6 @@ def remap_dataset_keys( Returns: Dataset: A new dataset with remapped keys """ - # no need to remap if the keys are already correct if all(k == v for k, v in mapping_dict.items()): return dataset diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 8d232e8f97..2395b65e34 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -247,7 +247,7 @@ def generate( # Read generation parameters from config top_k = self.cfg["top_k"] if self.cfg["top_k"] is not None else -1 sampling_params = self.SamplingParams( - temperature=self.cfg["temperature"], + temperature=self.cfg["temperature"] if not greedy else 0, top_p=self.cfg["top_p"], top_k=top_k if not greedy @@ -347,7 +347,7 @@ def generate_text( # Read generation parameters from config top_k = self.cfg["top_k"] if self.cfg["top_k"] is not None else -1 sampling_params = self.SamplingParams( - temperature=self.cfg["temperature"], + temperature=self.cfg["temperature"] if not greedy else 0, top_p=self.cfg["top_p"], top_k=top_k if not greedy else 1, max_tokens=self.cfg["max_new_tokens"], diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index af8b5698e3..8c810e31dc 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -40,10 +40,14 @@ } -def configure_vllm_with_tokenizer(vllm_config, tokenizer): +def configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=False): """Apply tokenizer-specific configurations to vLLM config.""" - vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True - vllm_config["vllm_cfg"]["load_format"] = "dummy" + if is_eval: + vllm_config["vllm_cfg"]["skip_tokenizer_init"] = False + vllm_config["vllm_cfg"]["load_format"] = "auto" + else: + vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True + vllm_config["vllm_cfg"]["load_format"] = "dummy" vllm_config["pad_token"] = tokenizer.pad_token_id vllm_config["stop_token_ids"] = [tokenizer.eos_token_id] return vllm_config @@ -532,3 +536,38 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): # Clean up vllm_policy.shutdown() + + +def test_vllm_generate_text(cluster, tokenizer): + """Test that vLLM can generate text.""" + # Prepare test data + test_prompts = [ + "Hello, my name is", + "The capital of France is", + ] + test_prompts = BatchedDataDict({"prompts": test_prompts}) + + # Create separate configs for each policy + vllm_config = basic_vllm_test_config.copy() + vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=True) + + # Ensure we can get same output + assert vllm_config["model_name"] == "meta-llama/Llama-3.2-1B", ( + "Model name should be meta-llama/Llama-3.2-1B to get expected output" + ) + assert vllm_config["vllm_cfg"]["tensor_parallel_size"] == 1, ( + "Tensor parallel size should be 1 to get expected output" + ) + + # Create vLLM generation + vllm_generation = VllmGeneration(cluster, vllm_config) + + # Generate and check result + output = vllm_generation.generate_text(test_prompts, greedy=True) + assert output["texts"] == [ + " Kelsey and I am a 2018 graduate", + " Paris. The city is located in the north of", + ], "Output should be the same as the expected output" + + # Clean up + vllm_generation.shutdown() From f0a5b548370159c4469e3dc7bb8f16f21172e7be Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 26 Mar 2025 14:54:38 +0000 Subject: [PATCH 7/8] remove setting pad_token_id Signed-off-by: Yuki Huang --- examples/run_eval.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/run_eval.py b/examples/run_eval.py index 586fb1d6ec..54358ad260 100644 --- a/examples/run_eval.py +++ b/examples/run_eval.py @@ -76,8 +76,6 @@ def setup_data( tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"]) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id math_env = MathEnvironment.options( runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE} From af3928973f667db8e516fdb5cf6872d3bdb2267b Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 26 Mar 2025 16:18:14 +0000 Subject: [PATCH 8/8] remove duplicated keys in MathDataConfig Signed-off-by: Yuki Huang --- nemo_reinforcer/data/__init__.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/nemo_reinforcer/data/__init__.py b/nemo_reinforcer/data/__init__.py index 90c61811e4..09eaf35fb5 100644 --- a/nemo_reinforcer/data/__init__.py +++ b/nemo_reinforcer/data/__init__.py @@ -23,13 +23,6 @@ class DataConfig(TypedDict): val_dataset_name: Optional[str] -class MathDataConfig(TypedDict): - # all fields from DataConfig - max_input_seq_length: int - prompt_file: str - system_prompt_file: Optional[str] - dataset_name: str - val_dataset_name: Optional[str] - # additional fields specific to math data +class MathDataConfig(DataConfig): problem_key: str solution_key: str