Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@
import argparse
import os
import pprint

from omegaconf import OmegaConf
from typing import Dict, Any
from collections import defaultdict
from typing import Any, Dict

from datasets import load_dataset
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from collections import defaultdict

from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.utils.config import load_config
from nemo_reinforcer.utils.logger import get_next_experiment_dir
from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec, LLMMessageLogType
from nemo_reinforcer.data import DataConfig
from nemo_reinforcer.models.policy import PolicyConfig
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_reinforcer.environments.math_environment import MathEnvironment
from nemo_reinforcer.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec
Comment thread
hemildesai marked this conversation as resolved.
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.environments.math_environment import MathEnvironment
from nemo_reinforcer.models.policy import PolicyConfig
from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides
from nemo_reinforcer.utils.logger import get_next_experiment_dir


def parse_args():
Expand All @@ -43,10 +42,7 @@ def parse_args():
)

# Parse known args for the script
args, remaining = parser.parse_known_args()

# Convert remaining args to OmegaConf format
overrides = OmegaConf.from_dotlist(remaining)
args, overrides = parser.parse_known_args()

return args, overrides

Expand Down Expand Up @@ -242,7 +238,7 @@ def main():

if overrides:
print(f"Overrides: {overrides}")
config = OmegaConf.merge(config, overrides)
config = parse_hydra_overrides(config, overrides)

config: MasterConfig = OmegaConf.to_container(config, resolve=True)
print("Applied CLI overrides")
Expand Down
31 changes: 31 additions & 0 deletions nemo_reinforcer/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from pathlib import Path
from typing import Optional, Union

from hydra._internal.config_loader_impl import ConfigLoaderImpl
from hydra.core.override_parser.overrides_parser import OverridesParser
from omegaconf import DictConfig, ListConfig, OmegaConf


Expand Down Expand Up @@ -130,3 +132,32 @@ def load_config(config_path: Union[str, Path]) -> DictConfig:
Merged config dictionary
"""
return load_config_with_inheritance(config_path)


class OverridesError(Exception):
    """Signal that Hydra override strings could not be parsed or applied."""


def parse_hydra_overrides(cfg: DictConfig, overrides: list[str]) -> DictConfig:
    """Parse Hydra override strings and apply them to an OmegaConf config.

    ``cfg`` is modified in place and the same object is returned. Struct mode
    is enabled on ``cfg`` only while the overrides are being applied; the
    struct setting ``cfg`` had on entry is restored before this function
    returns or raises, so the caller's config is not left locked.

    Args:
        cfg: OmegaConf config to apply overrides to (mutated in place).
        overrides: List of Hydra override strings, e.g. ``"a.b=1"``,
            ``"+a.c=2"``, ``"++a.b=3"`` or ``"~a.b"``.

    Returns:
        The same ``cfg`` object with the overrides applied.

    Raises:
        OverridesError: If there's an error parsing or applying overrides.
            The original exception is chained as ``__cause__``.
    """
    try:
        # Remember the caller's struct setting so that enabling it below does
        # not leak out of this function.
        previous_struct = OmegaConf.is_struct(cfg)
        OmegaConf.set_struct(cfg, True)
        try:
            parsed = OverridesParser.create().parse_overrides(overrides=overrides)
            # NOTE(review): this is a private Hydra helper and may change
            # between hydra-core releases.
            ConfigLoaderImpl._apply_overrides_to_config(overrides=parsed, cfg=cfg)
        finally:
            # Restore the flag on success and on failure alike.
            OmegaConf.set_struct(cfg, previous_struct)
    except Exception as e:
        # Single boundary: surface every parsing/applying failure as one type.
        raise OverridesError(f"Failed to parse Hydra overrides: {e}") from e
    return cfg
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"torchdata",
"vllm==0.8.0",
"nvidia-ml-py",
"hydra-core",
]

[tool.setuptools]
Expand Down
58 changes: 58 additions & 0 deletions tests/unit/utils/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,61 @@ def test_interpolation(temp_config_dir):
config = load_config(child_path)
assert config.base_value == 43
assert config.derived.value == 43 # Interpolation uses child's base_value


def test_parse_hydra_overrides():
    """Test parsing and applying Hydra overrides.

    Every case starts from a freshly built config so that no assertion can be
    satisfied by a mutation left behind by an earlier case.
    """
    from omegaconf import OmegaConf

    from nemo_reinforcer.utils.config import OverridesError, parse_hydra_overrides

    def make_cfg():
        # parse_hydra_overrides mutates and returns its argument, so each case
        # needs its own config instance.
        return OmegaConf.create(
            {
                "model": {"type": "default", "hidden_size": 768},
                "training": {"batch_size": 32, "learning_rate": 1e-4},
            }
        )

    # Test basic override
    updated_cfg = parse_hydra_overrides(make_cfg(), ["model.type=transformer"])
    assert updated_cfg.model.type == "transformer"
    assert updated_cfg.model.hidden_size == 768  # Unchanged

    # Test nested override
    updated_cfg = parse_hydra_overrides(make_cfg(), ["model.hidden_size=1024"])
    assert updated_cfg.model.hidden_size == 1024

    # Test multiple overrides
    updated_cfg = parse_hydra_overrides(
        make_cfg(),
        ["training.batch_size=64", "training.learning_rate=2e-4"],
    )
    assert updated_cfg.training.batch_size == 64
    assert updated_cfg.training.learning_rate == 2e-4

    # Test invalid override
    with pytest.raises(OverridesError):
        parse_hydra_overrides(make_cfg(), ["nonexistent.key=value"])

    # Test invalid syntax
    with pytest.raises(OverridesError):
        parse_hydra_overrides(make_cfg(), ["invalid.syntax"])

    # Test empty overrides: compare against a snapshot taken beforehand,
    # because the returned object is the very config that was passed in and
    # comparing the two directly would always succeed.
    cfg = make_cfg()
    snapshot = OmegaConf.to_container(cfg, resolve=True)
    updated_cfg = parse_hydra_overrides(cfg, [])
    assert OmegaConf.to_container(updated_cfg, resolve=True) == snapshot

    # Test override additions and deletions
    updated_cfg = parse_hydra_overrides(
        make_cfg(),
        [
            "+model.num_layers=12",
            "++model.type=transformer",
            "~training.batch_size",
        ],
    )
    assert updated_cfg.model.num_layers == 12
    assert updated_cfg.model.type == "transformer"
    assert "batch_size" not in updated_cfg.training