diff --git a/README.md b/README.md index cdf9404834..4dc2f7395f 100644 --- a/README.md +++ b/README.md @@ -377,7 +377,7 @@ uv run python examples/run_eval.py \ ``` > **Note:** Evaluation results may vary slightly due to various factors, such as sampling parameters, random seed, inference engine version, and inference engine settings. -Refer to `examples/configs/eval.yaml` for a full list of parameters that can be overridden. For an in-depth explanation of evaluation, refer to the [Evaluation documentation](docs/guides/eval.md). +Refer to `examples/configs/evals/eval.yaml` for a full list of parameters that can be overridden. For an in-depth explanation of evaluation, refer to the [Evaluation documentation](docs/guides/eval.md). ## Set Up Clusters diff --git a/docs/guides/eval.md b/docs/guides/eval.md index 0281bb21f7..b4f97b8c64 100644 --- a/docs/guides/eval.md +++ b/docs/guides/eval.md @@ -25,7 +25,7 @@ Once the conversion is complete, you can override the `generation.model_name` to ### Prepare the Evaluation Configuration **Override with Custom Settings** -To run the evaluation, you can use the [default configuration file](../../examples/configs/eval.yaml). Alternatively, you can specify a custom one or override some settings via the command line. +To run the evaluation, you can use the [default configuration file](../../examples/configs/evals/eval.yaml). Alternatively, you can specify a custom one or override some settings via the command line. The default configuration employs greedy sampling to evaluate Qwen2.5-Math-1.5B-Instruct on AIME-2024. @@ -42,7 +42,7 @@ We will use the `run_eval.py` script to run an evaluation using a model directly Note that the evaluation script only supports the Hugging Face format model. If you haven't converted your DCP format model, you should back to [Convert DCP to HF](#convert-dcp-to-hf-optional) and follow the guide to convert your model. 
```sh -# Run evaluation script with default config (examples/configs/eval.yaml) +# Run evaluation script with default config (examples/configs/evals/eval.yaml) uv run python examples/run_eval.py # Run evaluation script with converted model @@ -51,16 +51,22 @@ uv run python examples/run_eval.py generation.model_name=$PWD/results/grpo/hf # Run evaluation script with custom config file uv run python examples/run_eval.py --config path/to/custom_config.yaml +# Run evaluation script on one of the supported benchmarks (e.g., GPQA) +uv run python examples/run_eval.py --config examples/configs/evals/gpqa_eval.yaml + +# Run evaluation script with a local dataset that is prefetched as a csv file. +uv run python examples/run_eval.py --config examples/configs/evals/local_eval.yaml + # Override specific config values via command line # Example: Evaluation of DeepScaleR-1.5B-Preview on MATH-500 using 8 GPUs # Pass@1 accuracy averaged over 16 samples for each problem uv run python examples/run_eval.py \ + --config examples/configs/evals/math_eval.yaml \ generation.model_name=agentica-org/DeepScaleR-1.5B-Preview \ generation.temperature=0.6 \ generation.top_p=0.95 \ - generation.vllm_cfg.max_model_len=32768 \ - data.dataset_name=HuggingFaceH4/MATH-500 \ - data.dataset_key=test \ + generation.vllm_cfg.max_model_len=32768 \ + data.dataset_name="math500" \ eval.num_tests_per_prompt=16 \ cluster.gpus_per_node=8 ``` @@ -80,3 +86,12 @@ metric='pass@1' num_tests_per_prompt=1 score=0.1000 (3.0/30) ============================================================ ``` + +## List of currently supported benchmarks + +- [AIME-2024](../../nemo_rl/data/eval_datasets/aime2024.py) +- [GPQA and GPQA-diamond](../../nemo_rl/data/eval_datasets/gpqa.py) +- [MATH and MATH-500](../../nemo_rl/data/eval_datasets/math.py) +- [MMLU](../../nemo_rl/data/eval_datasets/mmlu.py) +- [MMLU-Pro](../../nemo_rl/data/eval_datasets/mmlu_pro.py) + diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 
1f63df559d..f577820a21 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -67,7 +67,7 @@ def my_data_processor( ) -> DatumSpec: ``` -We have an example of this as `math_data_processor` in [run_grpo_math.py](../../examples/run_grpo_math.py) +We have an example of this as `math_data_processor` in [processors.py](../../nemo_rl/data/processors.py) #### Putting it all together diff --git a/docs/guides/sft-openmathinstruct2.md b/docs/guides/sft-openmathinstruct2.md index 6698c12bc0..1228d42a7d 100644 --- a/docs/guides/sft-openmathinstruct2.md +++ b/docs/guides/sft-openmathinstruct2.md @@ -38,7 +38,7 @@ To evaluate on the [MATH-500 benchmark](https://huggingface.co/datasets/HuggingF ``` uv run examples/run_eval.py \ - --config=examples/configs/eval.yaml \ + --config=examples/configs/evals/eval.yaml \ generation.model_name=results/sft_openmathinstruct2/step_1855/hf \ tokenizer.name=meta-llama/Llama-3.1-8B-Instruct \ data.dataset_name=HuggingFaceH4/MATH-500 \ diff --git a/examples/configs/eval.yaml b/examples/configs/evals/eval.yaml similarity index 92% rename from examples/configs/eval.yaml rename to examples/configs/evals/eval.yaml index e880d98bc7..eab0f1db21 100644 --- a/examples/configs/eval.yaml +++ b/examples/configs/evals/eval.yaml @@ -40,10 +40,7 @@ data: max_input_seq_length: ${generation.vllm_cfg.max_model_len} # useless since we directly use prompts in evaluation prompt_file: null system_prompt_file: null - dataset_name: "HuggingFaceH4/aime_2024" - dataset_key: "train" - problem_key: "problem" - solution_key: "answer" + dataset_name: "aime2024" env: math: diff --git a/examples/configs/evals/gpqa_eval.yaml b/examples/configs/evals/gpqa_eval.yaml new file mode 100644 index 0000000000..463702d3a4 --- /dev/null +++ b/examples/configs/evals/gpqa_eval.yaml @@ -0,0 +1,15 @@ +# GPQA evaluation Configuration +defaults: "eval.yaml" + +generation: + model_name: "Qwen/Qwen2.5-7B-Instruct" + vllm_cfg: + max_model_len: 3072 + +data: + prompt_file: 
"examples/prompts/gpqa.txt" + dataset_name: "gpqa" + +env: + math: + verifier_type: "multichoice" diff --git a/examples/configs/evals/local_eval.yaml b/examples/configs/evals/local_eval.yaml new file mode 100644 index 0000000000..ad9def2112 --- /dev/null +++ b/examples/configs/evals/local_eval.yaml @@ -0,0 +1,14 @@ +# Evaluation Configuration from local files. +defaults: "eval.yaml" + +generation: + model_name: "Qwen/Qwen2.5-7B-Instruct" + +data: + prompt_file: "examples/prompts/cot.txt" + dataset_name: "local" + problem_key: "Question" + solution_key: "Answer" + split: "train" + data_paths: "https://openaipublic.blob.core.windows.net/simple-evals/math_500_test.csv" + file_format: "csv" diff --git a/examples/configs/evals/math_eval.yaml b/examples/configs/evals/math_eval.yaml new file mode 100644 index 0000000000..b42956866d --- /dev/null +++ b/examples/configs/evals/math_eval.yaml @@ -0,0 +1,9 @@ +# Math evaluation Configuration +defaults: "eval.yaml" + +generation: + model_name: "Qwen/Qwen2.5-7B-Instruct" + +data: + prompt_file: "examples/prompts/cot.txt" + dataset_name: "math" diff --git a/examples/prompts/gpqa.txt b/examples/prompts/gpqa.txt new file mode 100644 index 0000000000..04ea20d553 --- /dev/null +++ b/examples/prompts/gpqa.txt @@ -0,0 +1 @@ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. diff --git a/examples/prompts/mmlu.txt b/examples/prompts/mmlu.txt new file mode 100644 index 0000000000..04ea20d553 --- /dev/null +++ b/examples/prompts/mmlu.txt @@ -0,0 +1 @@ +Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. 
diff --git a/examples/run_eval.py b/examples/run_eval.py index 6f7f60cc44..89e2ede395 100644 --- a/examples/run_eval.py +++ b/examples/run_eval.py @@ -19,16 +19,12 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from datasets import load_dataset from omegaconf import OmegaConf -from transformers import AutoTokenizer +from transformers import AutoTokenizer, PreTrainedTokenizerBase -from examples.run_grpo_math import math_data_processor from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.data import MathDataConfig from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import TaskDataSpec -from nemo_rl.data.llm_message_utils import remap_dataset_keys +from nemo_rl.data.eval_datasets import load_eval_dataset from nemo_rl.distributed.ray_actor_environment_registry import ( get_actor_python_env, ) @@ -36,6 +32,9 @@ from nemo_rl.environments.math_environment import MathEnvironment from nemo_rl.evals.eval import MasterConfig, run_env_eval, setup from nemo_rl.models.generation import configure_generation_config +from nemo_rl.utils.config import load_config + +TokenizerType = PreTrainedTokenizerBase def parse_args(): @@ -54,28 +53,14 @@ def parse_args(): return args, overrides -def setup_data(tokenizer: AutoTokenizer, data_config: MathDataConfig, env_configs): - print("\n▶ Setting up data...") - math_task_spec = TaskDataSpec( - task_name="math", - prompt_file=data_config["prompt_file"], - system_prompt_file=data_config["system_prompt_file"], - ) +def setup_data(tokenizer: AutoTokenizer, data_config, env_configs): + print("Setting up data...") # load dataset - base_dataset = load_dataset(data_config["dataset_name"]) - if data_config["dataset_key"] is not None: - base_dataset = base_dataset[data_config["dataset_key"]] - # remap problem and solution keys - remapped_dataset = remap_dataset_keys( - base_dataset, - mapping_dict={ - data_config["problem_key"]: "problem", - data_config["solution_key"]: 
"expected_answer", - }, - ) + base_dataset = load_eval_dataset(data_config) + rekeyed_ds = base_dataset.rekeyed_ds - math_env = MathEnvironment.options( + env = MathEnvironment.options( runtime_env={ "py_executable": get_actor_python_env( "nemo_rl.environments.math_environment.MathEnvironment" @@ -84,14 +69,14 @@ def setup_data(tokenizer: AutoTokenizer, data_config: MathDataConfig, env_config ).remote(env_configs["math"]) dataset = AllTaskProcessedDataset( - dataset=remapped_dataset, + dataset=rekeyed_ds, tokenizer=tokenizer, - default_task_data_spec=math_task_spec, - task_data_processors=math_data_processor, + default_task_data_spec=base_dataset.task_spec, + task_data_processors=base_dataset.processor, max_seq_length=data_config["max_input_seq_length"], ) - return dataset, math_env, tokenizer + return dataset, env, tokenizer def main(): @@ -100,9 +85,11 @@ def main(): args, overrides = parse_args() if not args.config: - args.config = os.path.join(os.path.dirname(__file__), "configs", "eval.yaml") + args.config = os.path.join( + os.path.dirname(__file__), "configs", "evals", "eval.yaml" + ) - config = OmegaConf.load(args.config) + config = load_config(args.config) print(f"Loaded configuration from: {args.config}") if overrides: @@ -129,7 +116,7 @@ def main(): # Setup data ( dataset, - math_env, + env, tokenizer, ) = setup_data(tokenizer, config["data"], config["env"]) @@ -144,7 +131,7 @@ def main(): run_env_eval( vllm_generation, dataloader, - math_env, + env, master_config, ) diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 4a64d3c13b..673322eb61 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -16,9 +16,8 @@ import os import pprint from collections import defaultdict -from typing import Any, Optional, cast +from typing import Any, Optional -import torch from omegaconf import OmegaConf from transformers import PreTrainedTokenizerBase @@ -116,74 +115,6 @@ def hf_data_processor( return output -# Example of a generic 
math data processor -# TaskDataProcessFnCallable -def math_data_processor( - datum_dict: dict[str, Any], - task_data_spec: TaskDataSpec, - tokenizer: TokenizerType, - max_seq_length: int, - idx: int, -) -> DatumSpec: - """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment.""" - problem = datum_dict["problem"] - solution = str(datum_dict["expected_answer"]) - extra_env_info = {"ground_truth": solution} - - message_log: LLMMessageLogType = [] - - # system prompt - if task_data_spec.system_prompt: - sys_prompt: dict[str, str | torch.Tensor] = { - "role": "system", - "content": task_data_spec.system_prompt, - } - sys = tokenizer.apply_chat_template( - [cast(dict[str, str], sys_prompt)], - tokenize=False, - add_generation_prompt=False, - add_special_tokens=False, - ) - sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0] - message_log.append(sys_prompt) - - # user prompt - if task_data_spec.prompt: - problem = task_data_spec.prompt.format(problem) - user_message = {"role": "user", "content": problem} - message = tokenizer.apply_chat_template( - [user_message], - tokenize=False, - add_generation_prompt=True, - add_special_tokens=False, - ) - user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0] - user_message["content"] = message - message_log.append(user_message) - - length = sum(len(m["token_ids"]) for m in message_log) - - loss_multiplier = 1.0 - if length > max_seq_length: - # make smaller and mask out - for indiv_message in message_log: - indiv_message["token_ids"] = indiv_message["token_ids"][ - : min(4, max_seq_length // len(message_log)) - ] - loss_multiplier = 0.0 - - output: DatumSpec = { - "message_log": message_log, - "length": length, - "extra_env_info": extra_env_info, - "loss_multiplier": loss_multiplier, - "idx": idx, - } - if "task_name" in datum_dict: - output["task_name"] = datum_dict["task_name"] - return output - - def setup_data( tokenizer: 
TokenizerType, data_config: DataConfig, diff --git a/nemo_rl/data/eval_datasets/__init__.py b/nemo_rl/data/eval_datasets/__init__.py new file mode 100644 index 0000000000..2e5ba97974 --- /dev/null +++ b/nemo_rl/data/eval_datasets/__init__.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_rl.data.eval_datasets.aime2024 import AIME2024Dataset +from nemo_rl.data.eval_datasets.gpqa import GPQADataset +from nemo_rl.data.eval_datasets.local_math_dataset import LocalMathDataset +from nemo_rl.data.eval_datasets.math import MathDataset +from nemo_rl.data.eval_datasets.mmlu import MMLUDataset +from nemo_rl.data.eval_datasets.mmlu_pro import MMLUProDataset + + +def load_eval_dataset(data_config): + """Loads evaluation dataset.""" + dataset_name = data_config["dataset_name"] + if dataset_name == "mmlu": + base_dataset = MMLUDataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "aime2024": + base_dataset = AIME2024Dataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "gpqa": + base_dataset = GPQADataset( + variant="main", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "gpqa_diamond": + base_dataset = GPQADataset( + variant="diamond", + 
prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "mmlu_pro": + base_dataset = MMLUProDataset( + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "math": + base_dataset = MathDataset( + variant="math_test", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "math500": + base_dataset = MathDataset( + variant="math_500_test", + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + elif dataset_name == "local": + base_dataset = LocalMathDataset( + name=dataset_name, + data_paths=data_config["data_paths"], + problem_key=data_config["problem_key"], + solution_key=data_config["solution_key"], + file_format=data_config["file_format"], + split=data_config["split"], + prompt_file=data_config["prompt_file"], + system_prompt_file=data_config["system_prompt_file"], + ) + else: + raise ValueError(f"Unknown dataset {dataset_name}.") + return base_dataset + + +__all__ = [ + "AIME2024Dataset", + "GPQADataset", + "LocalMathDataset", + "MathDataset", + "MMLUDataset", + "MMLUProDataset", +] diff --git a/nemo_rl/data/eval_datasets/aime2024.py b/nemo_rl/data/eval_datasets/aime2024.py new file mode 100644 index 0000000000..9e585bb511 --- /dev/null +++ b/nemo_rl/data/eval_datasets/aime2024.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""AIME 2024 dataset.""" + +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class AIME2024Dataset: + def __init__( + self, + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + ): + ds = load_dataset("HuggingFaceH4/aime_2024", split="train") + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name="aime2024", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.math_data_processor + + def _rekey(self, data: dict[str, Any]): + return { + "problem": data["problem"], + "expected_answer": data["answer"], + } diff --git a/nemo_rl/data/eval_datasets/gpqa.py b/nemo_rl/data/eval_datasets/gpqa.py new file mode 100644 index 0000000000..f41efa136a --- /dev/null +++ b/nemo_rl/data/eval_datasets/gpqa.py @@ -0,0 +1,63 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""GPQA dataset and its variants.""" + +import random +from typing import Any, Literal, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class GPQADataset: + def __init__( + self, + variant: Literal["diamond", "main"] = "diamond", + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + ): + ds = load_dataset("Idavidrein/gpqa", f"gpqa_{variant}", split="train") + self._rng = random.Random() + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name=f"GPQA_{variant}", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.multichoice_qa_processor + + def _rekey(self, data: dict[str, Any]): + choices = [ + data["Correct Answer"], + data["Incorrect Answer 1"], + data["Incorrect Answer 2"], + data["Incorrect Answer 3"], + ] + permutation = self._rng.sample(range(4), 4) + choices = [choices[i] for i in permutation] + correct_index = choices.index(data["Correct Answer"]) + correct_answer = "ABCD"[correct_index] + return { + "question": data["Question"], + "options": dict( + A=choices[0], + B=choices[1], + C=choices[2], + D=choices[3], + ), + "answer": correct_answer, + } diff --git a/nemo_rl/data/eval_datasets/local_math_dataset.py b/nemo_rl/data/eval_datasets/local_math_dataset.py new file mode 100644 index 0000000000..2810899b4a --- /dev/null +++ b/nemo_rl/data/eval_datasets/local_math_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Local math dataset.""" + +from typing import Any, Literal, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class LocalMathDataset: + def __init__( + self, + data_paths: str | list[str], + problem_key: str, + solution_key: str, + name: str, + split: Optional[str] = None, + file_format: Literal["csv", "json"] = "csv", + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + ): + ds = load_dataset(file_format, data_files=data_paths) + if split is not None: + ds = ds[split] + self._problem_key = problem_key + self._solution_key = solution_key + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name=name, + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.math_data_processor + + def _rekey(self, data: dict[str, Any]): + return { + "problem": data[self._problem_key], + "expected_answer": data[self._solution_key], + } diff --git a/nemo_rl/data/eval_datasets/math.py b/nemo_rl/data/eval_datasets/math.py new file mode 100644 index 0000000000..290902657e --- /dev/null +++ b/nemo_rl/data/eval_datasets/math.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Math dataset and its variants.""" + +from typing import Any, Literal, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class MathDataset: + def __init__( + self, + variant: Literal["math_test", "math_500_test"] = "math_test", + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + ): + ds = load_dataset( + "csv", + data_files=f"https://openaipublic.blob.core.windows.net/simple-evals/{variant}.csv", + split="train", + ) + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + self.task_spec = TaskDataSpec( + task_name=f"{variant}", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.math_data_processor + + def _rekey(self, data: dict[str, Any]): + return { + "problem": data["Question"], + "expected_answer": data["Answer"], + } diff --git a/nemo_rl/data/eval_datasets/mmlu.py b/nemo_rl/data/eval_datasets/mmlu.py new file mode 100644 index 0000000000..f8b75d3b56 --- /dev/null +++ b/nemo_rl/data/eval_datasets/mmlu.py @@ -0,0 +1,56 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MMLU dataset and its variants.""" + +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class MMLUDataset: + def __init__( + self, + prompt_file: Optional[str] = None, + system_prompt_file: Optional[str] = None, + ): + ds = load_dataset( + "csv", + data_files="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + split="train", + ) + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + + self.task_spec = TaskDataSpec( + task_name="MMLU", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.multichoice_qa_processor + + def _rekey(self, data: dict[str, Any]): + return { + "question": data["Question"], + "options": dict( + A=data["A"], + B=data["B"], + C=data["C"], + D=data["D"], + ), + "answer": data["Answer"], + "subject": data["Subject"], + } diff --git a/nemo_rl/data/eval_datasets/mmlu_pro.py b/nemo_rl/data/eval_datasets/mmlu_pro.py new file mode 100644 index 0000000000..159d4d1738 --- /dev/null +++ b/nemo_rl/data/eval_datasets/mmlu_pro.py @@ -0,0 +1,44 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MMLU-Pro dataset.""" + +from typing import Any, Optional + +from datasets import load_dataset + +from nemo_rl.data import processors +from nemo_rl.data.interfaces import TaskDataSpec + + +class MMLUProDataset: + def __init__(self, prompt_file: str, system_prompt_file: Optional[str] = None): + ds = load_dataset("TIGER-Lab/MMLU-Pro", split="test") + self.rekeyed_ds = ds.map(self._rekey, remove_columns=ds.column_names) + + self.task_spec = TaskDataSpec( + task_name="MMLU-Pro", + prompt_file=prompt_file, + system_prompt_file=system_prompt_file, + ) + self.processor = processors.multichoice_qa_processor + + def _rekey(self, data: dict[str, Any]): + options = {chr(ord("A") + i): op for i, op in enumerate(data["options"])} + return { + "question": data["question"], + "options": options, + "answer": data["answer"], + "subject": data["category"], + } diff --git a/nemo_rl/data/processors.py b/nemo_rl/data/processors.py new file mode 100644 index 0000000000..67e3658882 --- /dev/null +++ b/nemo_rl/data/processors.py @@ -0,0 +1,168 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains data processors for evaluation.""" + +from typing import Any, cast + +import torch +from transformers import PreTrainedTokenizerBase + +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec + +TokenizerType = PreTrainedTokenizerBase + + +# Example of a generic math data processor +def math_data_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for the Math Environment.""" + problem = datum_dict["problem"] + solution = str(datum_dict["expected_answer"]) + extra_env_info = {"ground_truth": solution} + + message_log: LLMMessageLogType = [] + + # system prompt + if task_data_spec.system_prompt: + sys_prompt: dict[str, str | torch.Tensor] = { + "role": "system", + "content": task_data_spec.system_prompt, + } + sys = tokenizer.apply_chat_template( + [cast(dict[str, str], sys_prompt)], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0] + message_log.append(sys_prompt) + + # user prompt + if task_data_spec.prompt: + problem = task_data_spec.prompt.format(problem) + user_message = {"role": "user", "content": problem} + message = tokenizer.apply_chat_template( + [user_message], + tokenize=False, + add_generation_prompt=True, + add_special_tokens=False, + ) + user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0] + user_message["content"] = message + message_log.append(user_message) + + length = sum(len(m["token_ids"]) for m in message_log) + + loss_multiplier = 1.0 + if length > max_seq_length: + # make smaller and mask out + for indiv_message in message_log: + indiv_message["token_ids"] = indiv_message["token_ids"][ 
+ : min(4, max_seq_length // len(message_log)) + ] + loss_multiplier = 0.0 + + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": loss_multiplier, + "idx": idx, + } + if "task_name" in datum_dict: + output["task_name"] = datum_dict["task_name"] + return output + + +def _construct_multichoice_prompt( + prompt: str, question: str, options: dict[str, str] +) -> str: + """Construct prompt from question and options.""" + output = prompt + output += f"\n\nQuestion: {question}\nOptions:\n" + output += "\n".join( + [ + f"{letter}) {option}" + for letter, option in options.items() + if option is not None + ] + ) + return output + + +def multichoice_qa_processor( + datum_dict: dict[str, Any], + task_data_spec: TaskDataSpec, + tokenizer: TokenizerType, + max_seq_length: int, + idx: int, +) -> DatumSpec: + """Process a datum dictionary (directly loaded from dataset) into a DatumSpec for multiple-choice problems.""" + question = datum_dict["question"] + answer = str(datum_dict["answer"]) + options = datum_dict["options"] + extra_env_info = {"ground_truth": answer} + if "subject" in datum_dict: + extra_env_info.update({"subject": datum_dict["subject"]}) + + message_log = [] + + # system prompt + if task_data_spec.system_prompt: + sys_prompt: dict[str, str | torch.Tensor] = { + "role": "system", + "content": task_data_spec.system_prompt, + } + sys = tokenizer.apply_chat_template( + [cast(dict[str, str], sys_prompt)], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + sys_prompt["token_ids"] = tokenizer(sys, return_tensors="pt")["input_ids"][0] + message_log.append(sys_prompt) + + # user prompt + if task_data_spec.prompt: + question = _construct_multichoice_prompt( + task_data_spec.prompt, question, options + ) + user_message = {"role": "user", "content": question} + message = tokenizer.apply_chat_template( + [user_message], + tokenize=False, + add_generation_prompt=True, 
+ add_special_tokens=False, + ) + user_message["token_ids"] = tokenizer(message, return_tensors="pt")["input_ids"][0] + user_message["content"] = message + message_log.append(user_message) + + length = sum(len(m["token_ids"]) for m in message_log) + output: DatumSpec = { + "message_log": message_log, + "length": length, + "extra_env_info": extra_env_info, + "loss_multiplier": 1.0, + "idx": idx, + } + if "task_name" in datum_dict: + output["task_name"] = datum_dict["task_name"] + return output diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py index e8a47db06f..8dd5247f1c 100644 --- a/nemo_rl/environments/math_environment.py +++ b/nemo_rl/environments/math_environment.py @@ -14,6 +14,7 @@ import contextlib import io import logging +import re from typing import Any, Optional, TypedDict import ray @@ -32,11 +33,13 @@ calculate_pass_rate_per_prompt, ) from nemo_rl.environments.utils import chunk_list_to_workers +from nemo_rl.evals import answer_parsing class MathEnvConfig(TypedDict): num_workers: int stop_strings: Optional[list[str]] # Default stop strings for this env + verifier_type: Optional[str] @contextlib.contextmanager @@ -97,6 +100,39 @@ def verify( return results +@ray.remote +class MultichoiceVerifyWorker: + def verify( + self, pred_responses: list[str], ground_truths: list[str] + ) -> list[float]: + """Verify the correctness of the predicted responses against the ground truth. + + Args: + pred_responses: list[str]. The predicted responses from the LLM. + ground_truths: list[str]. The ground truth responses. + + Returns: + list[float]. The rewards for each predicted response. 
+ """ + results = [] + for response, ground_truth in zip(pred_responses, ground_truths): + response = answer_parsing.normalize_response(response) + extracted_answer = None + for answer_regex in answer_parsing.MULTILINGUAL_ANSWER_REGEXES: + regex = answer_parsing.MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format( + answer_regex + ) + match = re.search(regex, response) + if match: + extracted_answer = answer_parsing.normalize_extracted_answer( + match.group(1) + ) + break + score = 1.0 if extracted_answer == ground_truth else 0.0 + results.append(score) + return results + + class MathEnvironmentMetadata(TypedDict): ground_truth: str @@ -106,8 +142,13 @@ class MathEnvironment(EnvironmentInterface): def __init__(self, cfg: MathEnvConfig): self.cfg = cfg self.num_workers = cfg["num_workers"] + worker_cls = ( + MultichoiceVerifyWorker + if cfg.get("verifier_type", "math") == "multichoice" + else HFVerifyWorker + ) self.workers = [ - HFVerifyWorker.options( # type: ignore # (decorated with @ray.remote) + worker_cls.options( # type: ignore # (decorated with @ray.remote) runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM} ).remote() for _ in range(self.num_workers) diff --git a/nemo_rl/evals/answer_parsing.py b/nemo_rl/evals/answer_parsing.py new file mode 100644 index 0000000000..dcf020774a --- /dev/null +++ b/nemo_rl/evals/answer_parsing.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Contains utility functions for answer parsing.""" + +MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( + "(?i){}[ \t]*([A-D]|[أ-د]|[অ]|[ব]|[ড]|[ঢ]|[A]|[B]|[C]|[D])" +) +# All the different ways "Answer" is written in different languages +MULTILINGUAL_ANSWER_REGEXES = [ + "Answer\s*:", + "Answer\s*:​​​​​​", # Korean invisible character + "উত্তর\s*:", + "उत्तर\s*:", + "উত্তরঃ", + "উত্তর\s*:", + "Antwort\s*:", + "답변\s*:", + "정답\s*:", + "답\s*:", + "答案\s*:", + "答案\s*:", + "答\s*:", + "答\s*:", + "答复\s*:", + "答曰\s*:", + "الإجابة:", + "الجواب:", + "إجابة:", + "الإجابة النهائية:", + "الإجابة الصحيحة:", + "الإجابة الصحيحة هي:", + "الإجابة هي:", + "الجواب النهائي:", + "Respuesta\s*:", + "Risposta\s*:", + "答え\s*:", + "答え\s*:", + "回答\s*:", + "回答\s*:", + "解答\s*:", + "Jawaban\s*:", + "Réponse\s*:", + "Resposta\s*:", + "Jibu\s*:", + "Idahun\s*:", + "Ìdáhùn\s*:", + "Idáhùn\s*:", + "Àmọ̀nà\s*:", + "Àdáhùn\s*:", + "Ànúgọ\s*:", + "Àṣàyàn\s*:", +] + + +def normalize_extracted_answer(extracted_answer: str) -> str: + return ( + # In arabic these are the letters used for A-D in multiple choice questions + extracted_answer.replace("أ", " A") + .replace("ب", " B") + .replace("ج", " C") + .replace("د", " D") + # In Bengali these are the letters used for A-D in multiple choice questions + .replace("অ", " A") + .replace("ব", " B") + .replace("ড", " C") + .replace("ঢ", " D") + # In Japanese these are the letters sometimes used for A-D in multiple choice questions + .replace("A", " A") + .replace("B", " B") + .replace("C", " C") + .replace("D", " D") + .strip() + ) + + +def normalize_response(response: str) -> str: + """Normalize the response by removing markdown and LaTeX formatting that may prevent a match.""" + return ( + response.replace("**", "") + .replace("$\\boxed{", "") + .replace("}$", "") + .replace("\\$", "") + .replace("$\\text{", "") + .replace("$", "") + .replace("\\mathrm{", "") + .replace("\\{", "") + .replace("\\text", "") + .replace("\\(", "") + .replace("\\mathbf{", "") + 
.replace("{", "") + .replace("\\boxed", "") + ) diff --git a/tests/functional/test_converter_roundtrip.py b/tests/functional/test_converter_roundtrip.py index ea865be9b2..9679fcc724 100644 --- a/tests/functional/test_converter_roundtrip.py +++ b/tests/functional/test_converter_roundtrip.py @@ -13,20 +13,6 @@ # limitations under the License. #!/usr/bin/env python3 -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - """ Functional test for converter roundtrip functionality. diff --git a/tests/unit/data/eval_datasets/test_gpqa.py b/tests/unit/data/eval_datasets/test_gpqa.py new file mode 100644 index 0000000000..3441f11974 --- /dev/null +++ b/tests/unit/data/eval_datasets/test_gpqa.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from transformers import AutoTokenizer + +from nemo_rl.data.eval_datasets.gpqa import GPQADataset + + +@pytest.mark.skip(reason="dataset download is flaky") +def test_gpqa_dataset(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") + gpqa_dataset = GPQADataset() + + # check that the dataset is formatted correctly + for example in gpqa_dataset.rekeyed_ds.take(5): + assert "question" in example + assert "options" in example + assert "answer" in example + + ## check that applying chat template works as expected + default_templated = tokenizer.apply_chat_template( + [{"role": "user", "content": example["question"]}], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + assert ( + default_templated + == f"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['question']}<|im_end|>\n" + ) diff --git a/tests/unit/data/eval_datasets/test_math.py b/tests/unit/data/eval_datasets/test_math.py new file mode 100644 index 0000000000..3bab184f1a --- /dev/null +++ b/tests/unit/data/eval_datasets/test_math.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from transformers import AutoTokenizer + +from nemo_rl.data.eval_datasets.math import MathDataset + + +@pytest.mark.skip(reason="dataset download is flaky") +def test_math_dataset(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") + math_dataset = MathDataset() + + # check that the dataset is formatted correctly + for example in math_dataset.rekeyed_ds.take(5): + assert "problem" in example + assert "expected_answer" in example + + ## check that applying chat template works as expected + default_templated = tokenizer.apply_chat_template( + [{"role": "user", "content": example["problem"]}], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + assert ( + default_templated + == f"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['problem']}<|im_end|>\n" + ) diff --git a/tests/unit/data/eval_datasets/test_mmlu.py b/tests/unit/data/eval_datasets/test_mmlu.py new file mode 100644 index 0000000000..02c1936003 --- /dev/null +++ b/tests/unit/data/eval_datasets/test_mmlu.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +from transformers import AutoTokenizer + +from nemo_rl.data.eval_datasets.mmlu import MMLUDataset + + +@pytest.mark.skip(reason="dataset download is flaky") +def test_mmlu_dataset(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") + mmlu_dataset = MMLUDataset() + + # check that the dataset is formatted correctly + for example in mmlu_dataset.rekeyed_ds.take(5): + assert "question" in example + assert "options" in example + assert "answer" in example + assert "subject" in example + + ## check that applying chat template works as expected + default_templated = tokenizer.apply_chat_template( + [{"role": "user", "content": example["question"]}], + tokenize=False, + add_generation_prompt=False, + add_special_tokens=False, + ) + + assert ( + default_templated + == f"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['question']}<|im_end|>\n" + ) diff --git a/tests/unit/data/test_data_processor.py b/tests/unit/data/test_data_processor.py index 302dfece77..dc88bebee3 100644 --- a/tests/unit/data/test_data_processor.py +++ b/tests/unit/data/test_data_processor.py @@ -20,10 +20,10 @@ abspath = os.path.abspath(__file__) sys.path.append("/".join(abspath.split("/")[:-4])) -from examples.run_grpo_math import math_data_processor from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.datasets import AllTaskProcessedDataset from nemo_rl.data.interfaces import TaskDataSpec +from nemo_rl.data.processors import math_data_processor from nemo_rl.models.policy import TokenizerConfig basic_tokenizer_test_config: TokenizerConfig = { diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py index 386a21e2f8..b254f2ef5f 100644 --- a/tests/unit/environments/test_math_environment.py +++ b/tests/unit/environments/test_math_environment.py @@ -42,6 +42,25 @@ def math_env(): time.sleep(0.1) 
+@pytest.fixture(scope="module") +def multichoice_env(): + """Create a MathEnvironment actor for testing.""" + env = MathEnvironment.options( + runtime_env={ + "py_executable": get_actor_python_env( + "nemo_rl.environments.math_environment.MathEnvironment" + ), + "env_vars": dict(os.environ), + } + ).remote({"num_workers": 2, "verifier_type": "multichoice"}) + yield env + # Clean up the actor and wait for it to be killed + env.shutdown.remote() + ray.kill(env) + # Give some time for cleanup + time.sleep(0.1) + + @pytest.fixture def basic_test_data(): """Common test data for basic math problems.""" @@ -68,6 +87,41 @@ def basic_test_data(): } +@pytest.fixture +def basic_multichoice_test_data(): + """Common test data for basic multichoice problems.""" + return { + "message_log_batch": [ + [ + { + "role": "user", + "content": "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD", + }, + {"role": "assistant", "content": "\nAnswer: C"}, + ], + [ + { + "role": "user", + "content": "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD", + }, + {"role": "assistant", "content": "\nAnswer: B"}, + ], + [ + { + "role": "user", + "content": "Answer the following multiple choice question. 
 The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD", + }, + {"role": "assistant", "content": "\nAnswer: D"}, + ], + ], + "metadata": [ + {"ground_truth": "C"}, + {"ground_truth": "B"}, + {"ground_truth": "B"}, + ], + } + + @pytest.fixture def mixed_test_data(): """Test data with mix of correct and incorrect responses.""" @@ -148,6 +202,47 @@ def test_math_env_step_basic(math_env, basic_test_data): assert all(result.terminateds == 1.0), "All terminated flags should be 1.0" + +def test_multichoice_env_step_basic(multichoice_env, basic_multichoice_test_data): + """Test basic functionality of MathEnvironment step with multichoice verifier.""" + result = ray.get( + multichoice_env.step.remote( + basic_multichoice_test_data["message_log_batch"], + basic_multichoice_test_data["metadata"], + ) + ) + + # Check observations using field access + assert len(result.observations) == 3, ( + "Should return observations for all 3 messages" + ) + assert all(obs["role"] == "environment" for obs in result.observations), ( + "All observations should be from environment" + ) + assert all( + obs["content"] == "Environment: correct" for obs in result.observations[:2] + ), "The first two responses should be correct" + assert result.observations[2]["content"] == "Environment: incorrect", ( + "The third response should be incorrect" + ) + + # Check metadata + assert len(result.metadata) == 3, "Should return metadata for all 3 messages" + assert result.metadata == basic_multichoice_test_data["metadata"], ( + "Metadata should be unchanged" + ) + + # Check rewards and done flags + assert result.rewards.shape == (3,), "Rewards should be a tensor of shape (3,)" + assert all(result.rewards[:2] == 1.0), ( + "The first two rewards should be 1.0 for correct answers" + ) + assert result.rewards[2] == 0.0, "The third reward should be 0.0 for wrong answer" + assert result.terminateds.shape == (3,), ( + "Terminated flags should 
be a tensor of shape (3,)" + ) + assert all(result.terminateds == 1.0), "All terminated flags should be 1.0" + + def test_math_env_step_mixed(math_env, mixed_test_data): """Test MathEnvironment step with a mix of correct and incorrect responses.""" result = ray.get(