diff --git a/src/cloudai/cli/handlers.py b/src/cloudai/cli/handlers.py index d474ff421..475092451 100644 --- a/src/cloudai/cli/handlers.py +++ b/src/cloudai/cli/handlers.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,11 +21,12 @@ import signal from contextlib import contextmanager from pathlib import Path -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional from unittest.mock import Mock import toml import yaml +from pydantic import ValidationError from cloudai.core import ( BaseInstaller, @@ -145,7 +146,21 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: continue env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) - agent = agent_class(env) + + try: + agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) + except ValidationError as e: + logging.error(f"Invalid agent_config for agent '{agent_type}': ") + for error in e.errors(): + logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}") + logging.error("Valid overrides: ") + for item, desc in validate_agent_overrides(agent_type).items(): + logging.error(f" - {item}: {desc}") + err = 1 + continue + + agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env) + for step in range(agent.max_steps): result = agent.select_action() if result is None: @@ -166,6 +181,37 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: return err +def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]: + """ + Validate and process agent configuration overrides. + + If agent_config is empty, returns the available configuration fields for the agent type. + """ + registry = Registry() + config_class_map = {} + for agent_name, agent_class in registry.agents_map.items(): + if agent_class.config: + config_class_map[agent_name] = agent_class.config + + config_class = config_class_map.get(agent_type) + if not config_class: + valid_types = ", ".join(f"'{agent_name}'" for agent_name in config_class_map) + raise ValueError( + f"Agent type '{agent_type}' does not support configuration overrides. " + f"Valid agent types are: {valid_types}. " + ) + + if agent_config: + validated_config = config_class.model_validate(agent_config) + agent_kwargs = validated_config.model_dump(exclude_none=True) + logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") + else: + agent_kwargs = {} + for field_name, field_info in config_class.model_fields.items(): + agent_kwargs[field_name] = field_info.description + return agent_kwargs + + def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None: registry = Registry() diff --git a/src/cloudai/configurator/base_agent.py b/src/cloudai/configurator/base_agent.py index dbd397099..4b806a53c 100644 --- a/src/cloudai/configurator/base_agent.py +++ b/src/cloudai/configurator/base_agent.py @@ -15,7 +15,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple + +from cloudai.models.agent_config import AgentConfig from .base_gym import BaseGym @@ -28,6 +30,8 @@ class BaseAgent(ABC): Automatically infers parameter types from TestRun's cmd_args. """ + config: Optional[AgentConfig] = None + def __init__(self, env: BaseGym): """ Initialize the agent with the environment. diff --git a/src/cloudai/models/agent_config.py b/src/cloudai/models/agent_config.py new file mode 100644 index 000000000..0b04059aa --- /dev/null +++ b/src/cloudai/models/agent_config.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class AgentConfig(BaseModel, ABC): + """Base configuration for agent overrides.""" + + model_config = ConfigDict(extra="forbid") + random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") diff --git a/src/cloudai/models/workload.py b/src/cloudai/models/workload.py index 1745ae734..0a962cf59 100644 --- a/src/cloudai/models/workload.py +++ b/src/cloudai/models/workload.py @@ -1,5 +1,5 @@ # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES -# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -107,6 +107,7 @@ class TestDefinition(BaseModel, ABC): agent_steps: int = 1 agent_metrics: list[str] = Field(default=["default"]) agent_reward_function: str = "inverse" + agent_config: Optional[dict[str, Any]] = None @property def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]: