Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -21,11 +21,12 @@
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
from unittest.mock import Mock

import toml
import yaml
from pydantic import ValidationError

from cloudai.core import (
BaseInstaller,
Expand Down Expand Up @@ -145,7 +146,21 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
continue

env = CloudAIGymEnv(test_run=test_run, runner=runner.runner)
agent = agent_class(env)

try:
agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config)
except ValidationError as e:
logging.error(f"Invalid agent_config for agent '{agent_type}': ")
for error in e.errors():
logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}")
logging.error("Valid overrides: ")
for item, desc in validate_agent_overrides(agent_type).items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling validate_agent_overrides(agent_type) in error handler could raise another ValueError if agent doesn't support config, masking original validation error

Suggested change
for item, desc in validate_agent_overrides(agent_type).items():
logging.error(f"Invalid agent_config for agent '{agent_type}': ")
for error in e.errors():
logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}")

logging.error(f" - {item}: {desc}")
err = 1
continue

agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conditional passes kwargs only when non-empty, but empty dict is falsy. If agent_overrides = {}, the condition evaluates False and takes the else branch. Consider using is not None or explicit length check for clarity:

Suggested change
agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env)
agent = agent_class(env, **agent_overrides) if agent_overrides is not None else agent_class(env)

However, given validate_agent_overrides always returns a dict, the current logic works but is subtle.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


for step in range(agent.max_steps):
result = agent.select_action()
if result is None:
Expand All @@ -166,6 +181,37 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int:
return err


def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]:
"""
Validate and process agent configuration overrides.

If agent_config is empty, returns the available configuration fields for the agent type.
"""
registry = Registry()
config_class_map = {}
for agent_name, agent_class in registry.agents_map.items():
if agent_class.config:
config_class_map[agent_name] = agent_class.config

config_class = config_class_map.get(agent_type)
if not config_class:
valid_types = ", ".join(f"'{agent_name}'" for agent_name in config_class_map)
raise ValueError(
f"Agent type '{agent_type}' does not support configuration overrides. "
f"Valid agent types are: {valid_types}. "
)

if agent_config:
validated_config = config_class.model_validate(agent_config)
agent_kwargs = validated_config.model_dump(exclude_none=True)
logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}")
else:
agent_kwargs = {}
for field_name, field_info in config_class.model_fields.items():
agent_kwargs[field_name] = field_info.description
return agent_kwargs


def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None:
registry = Registry()

Expand Down
6 changes: 5 additions & 1 deletion src/cloudai/configurator/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple

from cloudai.models.agent_config import AgentConfig

from .base_gym import BaseGym

Expand All @@ -28,6 +30,8 @@ class BaseAgent(ABC):
Automatically infers parameter types from TestRun's cmd_args.
"""

config: Optional[AgentConfig] = None

def __init__(self, env: BaseGym):
"""
Initialize the agent with the environment.
Expand Down
27 changes: 27 additions & 0 deletions src/cloudai/models/agent_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class AgentConfig(BaseModel, ABC):
"""Base configuration for agent overrides."""

model_config = ConfigDict(extra="forbid")
random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility")
3 changes: 2 additions & 1 deletion src/cloudai/models/workload.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -107,6 +107,7 @@ class TestDefinition(BaseModel, ABC):
agent_steps: int = 1
agent_metrics: list[str] = Field(default=["default"])
agent_reward_function: str = "inverse"
agent_config: Optional[dict[str, Any]] = None

@property
def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:
Expand Down
Loading