diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index b87c7037b3..c67829e80e 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -15,24 +15,23 @@ import argparse import os import pprint - -from omegaconf import OmegaConf -from typing import Dict, Any +from collections import defaultdict +from typing import Any, Dict from datasets import load_dataset +from omegaconf import OmegaConf from transformers import AutoTokenizer -from collections import defaultdict from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup -from nemo_reinforcer.distributed.virtual_cluster import init_ray -from nemo_reinforcer.utils.config import load_config -from nemo_reinforcer.utils.logger import get_next_experiment_dir -from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec, LLMMessageLogType from nemo_reinforcer.data import DataConfig -from nemo_reinforcer.models.policy import PolicyConfig from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn -from nemo_reinforcer.environments.math_environment import MathEnvironment from nemo_reinforcer.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset +from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec +from nemo_reinforcer.distributed.virtual_cluster import init_ray +from nemo_reinforcer.environments.math_environment import MathEnvironment +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides +from nemo_reinforcer.utils.logger import get_next_experiment_dir def parse_args(): @@ -43,10 +42,7 @@ def parse_args(): ) # Parse known args for the script - args, remaining = parser.parse_known_args() - - # Convert remaining args to OmegaConf format - overrides = OmegaConf.from_dotlist(remaining) + args, overrides = parser.parse_known_args() return args, overrides @@ -242,7 +238,7 @@ def main(): if overrides: print(f"Overrides: {overrides}") - config = OmegaConf.merge(config, overrides) + config = parse_hydra_overrides(config, overrides) config: MasterConfig = OmegaConf.to_container(config, resolve=True) print("Applied CLI overrides") diff --git a/nemo_reinforcer/utils/config.py b/nemo_reinforcer/utils/config.py index 51418e3b1b..82c4be462b 100644 --- a/nemo_reinforcer/utils/config.py +++ b/nemo_reinforcer/utils/config.py @@ -15,6 +15,8 @@ from pathlib import Path from typing import Optional, Union +from hydra._internal.config_loader_impl import ConfigLoaderImpl +from hydra.core.override_parser.overrides_parser import OverridesParser from omegaconf import DictConfig, ListConfig, OmegaConf @@ -130,3 +132,32 @@ def load_config(config_path: Union[str, Path]) -> DictConfig: Merged config dictionary """ return load_config_with_inheritance(config_path) + + +class OverridesError(Exception): + """Custom exception for Hydra override parsing errors.""" + + pass + + +def parse_hydra_overrides(cfg: DictConfig, overrides: list[str]) -> DictConfig: + """Parse and apply Hydra overrides to an OmegaConf config. + + Args: + cfg: OmegaConf config to apply overrides to + overrides: List of Hydra override strings + + Returns: + Updated config with overrides applied + + Raises: + OverridesError: If there's an error parsing or applying overrides + """ + try: + OmegaConf.set_struct(cfg, True) + parser = OverridesParser.create() + parsed = parser.parse_overrides(overrides=overrides) + ConfigLoaderImpl._apply_overrides_to_config(overrides=parsed, cfg=cfg) + return cfg + except Exception as e: + raise OverridesError(f"Failed to parse Hydra overrides: {str(e)}") from e diff --git a/pyproject.toml b/pyproject.toml index 752f25047d..5f72e283f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "torchdata", "vllm==0.8.0", "nvidia-ml-py", + "hydra-core", ] [tool.setuptools] diff --git a/tests/unit/utils/test_config.py b/tests/unit/utils/test_config.py index 245d9e0053..b8aa5328d2 100644 --- a/tests/unit/utils/test_config.py +++ b/tests/unit/utils/test_config.py @@ -196,3 +196,61 @@ def test_interpolation(temp_config_dir): config = load_config(child_path) assert config.base_value == 43 assert config.derived.value == 43 # Interpolation uses child's base_value + + +def test_parse_hydra_overrides(): + """Test parsing and applying Hydra overrides.""" + from omegaconf import OmegaConf + + from nemo_reinforcer.utils.config import OverridesError, parse_hydra_overrides + + # Create initial config + cfg = OmegaConf.create( + { + "model": {"type": "default", "hidden_size": 768}, + "training": {"batch_size": 32, "learning_rate": 1e-4}, + } + ) + + # Test basic override + overrides = ["model.type=transformer"] + updated_cfg = parse_hydra_overrides(cfg, overrides) + assert updated_cfg.model.type == "transformer" + assert updated_cfg.model.hidden_size == 768 # Unchanged + + # Test nested override + overrides = ["model.hidden_size=1024"] + updated_cfg = parse_hydra_overrides(cfg, overrides) + assert updated_cfg.model.hidden_size == 1024 + + # Test multiple overrides + overrides = ["training.batch_size=64", "training.learning_rate=2e-4"] + updated_cfg = parse_hydra_overrides(cfg, overrides) + assert updated_cfg.training.batch_size == 64 + assert updated_cfg.training.learning_rate == 2e-4 + + # Test invalid override + overrides = ["nonexistent.key=value"] + with pytest.raises(OverridesError): + parse_hydra_overrides(cfg, overrides) + + # Test invalid syntax + overrides = ["invalid.syntax"] + with pytest.raises(OverridesError): + parse_hydra_overrides(cfg, overrides) + + # Test empty overrides + overrides = [] + updated_cfg = parse_hydra_overrides(cfg, overrides) + assert updated_cfg == cfg # Config should be unchanged + + # Test override additions and deletions + overrides = [ + "+model.num_layers=12", + "++model.type=transformer", + "~training.batch_size", + ] + updated_cfg = parse_hydra_overrides(cfg, overrides) + assert updated_cfg.model.num_layers == 12 + assert updated_cfg.model.type == "transformer" + assert "batch_size" not in updated_cfg.training