-
Notifications
You must be signed in to change notification settings - Fork 42
Agent params #792
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: agent-dev
Are you sure you want to change the base?
Agent params #792
Changes from all commits
1012aa0
bd551e4
c1ecf25
9e7d6c3
e7fee0c
cfb38f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||
| # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES | ||||||
| # Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||
| # Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|
|
@@ -21,11 +21,12 @@ | |||||
| import signal | ||||||
| from contextlib import contextmanager | ||||||
| from pathlib import Path | ||||||
| from typing import Callable, List, Optional | ||||||
| from typing import Any, Callable, List, Optional | ||||||
| from unittest.mock import Mock | ||||||
|
|
||||||
| import toml | ||||||
| import yaml | ||||||
| from pydantic import ValidationError | ||||||
|
|
||||||
| from cloudai.core import ( | ||||||
| BaseInstaller, | ||||||
|
|
@@ -145,7 +146,21 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: | |||||
| continue | ||||||
|
|
||||||
| env = CloudAIGymEnv(test_run=test_run, runner=runner.runner) | ||||||
| agent = agent_class(env) | ||||||
|
|
||||||
| try: | ||||||
| agent_overrides = validate_agent_overrides(agent_type, test_run.test.agent_config) | ||||||
| except ValidationError as e: | ||||||
| logging.error(f"Invalid agent_config for agent '{agent_type}': ") | ||||||
| for error in e.errors(): | ||||||
| logging.error(f" - {'.'.join(str(var_name) for var_name in error['loc'])}: {error['msg']}") | ||||||
| logging.error("Valid overrides: ") | ||||||
| for item, desc in validate_agent_overrides(agent_type).items(): | ||||||
| logging.error(f" - {item}: {desc}") | ||||||
| err = 1 | ||||||
| continue | ||||||
|
|
||||||
| agent = agent_class(env, **agent_overrides) if agent_overrides else agent_class(env) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Conditional passes kwargs only when non-empty, but empty dict is falsy. If
Suggested change
However, given Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||||||
|
|
||||||
| for step in range(agent.max_steps): | ||||||
| result = agent.select_action() | ||||||
| if result is None: | ||||||
|
|
@@ -166,6 +181,37 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace) -> int: | |||||
| return err | ||||||
|
|
||||||
|
|
||||||
| def validate_agent_overrides(agent_type: str, agent_config: Optional[dict[str, Any]] = None) -> dict[str, Any]: | ||||||
| """ | ||||||
| Validate and process agent configuration overrides. | ||||||
|
|
||||||
| If agent_config is empty, returns the available configuration fields for the agent type. | ||||||
| """ | ||||||
| registry = Registry() | ||||||
| config_class_map = {} | ||||||
| for agent_name, agent_class in registry.agents_map.items(): | ||||||
| if agent_class.config: | ||||||
| config_class_map[agent_name] = agent_class.config | ||||||
|
|
||||||
| config_class = config_class_map.get(agent_type) | ||||||
| if not config_class: | ||||||
| valid_types = ", ".join(f"'{agent_name}'" for agent_name in config_class_map) | ||||||
| raise ValueError( | ||||||
| f"Agent type '{agent_type}' does not support configuration overrides. " | ||||||
| f"Valid agent types are: {valid_types}. " | ||||||
| ) | ||||||
|
|
||||||
| if agent_config: | ||||||
| validated_config = config_class.model_validate(agent_config) | ||||||
| agent_kwargs = validated_config.model_dump(exclude_none=True) | ||||||
| logging.info(f"Applying agent config overrides for '{agent_type}': {agent_kwargs}") | ||||||
| else: | ||||||
| agent_kwargs = {} | ||||||
| for field_name, field_info in config_class.model_fields.items(): | ||||||
| agent_kwargs[field_name] = field_info.description | ||||||
| return agent_kwargs | ||||||
|
|
||||||
|
|
||||||
| def generate_reports(system: System, test_scenario: TestScenario, result_dir: Path) -> None: | ||||||
| registry = Registry() | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES | ||
| # Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from abc import ABC | ||
| from typing import Optional | ||
|
|
||
| from pydantic import BaseModel, ConfigDict, Field | ||
|
|
||
|
|
||
| class AgentConfig(BaseModel, ABC): | ||
| """Base configuration for agent overrides.""" | ||
|
|
||
| model_config = ConfigDict(extra="forbid") | ||
| random_seed: Optional[int] = Field(default=None, description="Random seed for reproducibility") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
calling
validate_agent_overrides(agent_type)in error handler could raise anotherValueErrorif agent doesn't support config, masking original validation error