diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..085b8997 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,129 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Setup + +```bash +# Install in development mode with all dependencies +pip install -e ".[dev]" + +# Optional: install Kubernetes support +pip install -e ".[all]" + +# Setup pre-commit hooks +pre-commit install +``` + +## Commands + +```bash +# Run all tests +pytest + +# Run specific test file +pytest tests/unit/test_error_handling.py -v + +# Run specific test class or function +pytest tests/unit/test_error_handling.py::TestErrorPatternMatching -v + +# Run tests with coverage +pytest --cov=src/madengine --cov-report=html + +# Skip slow tests +pytest -m "not slow" + +# Format code +black src/ tests/ +isort src/ tests/ + +# Lint +flake8 src/ tests/ + +# Type check +mypy src/madengine + +# Run all pre-commit checks +pre-commit run --all-files +``` + +## Architecture + +madengine is a CLI tool for running AI/ML models in local Docker, Kubernetes, and SLURM environments. The entry point is `madengine.cli.app:cli_main` (registered as the `madengine` console script). + +### Layer Structure + +**CLI Layer** (`src/madengine/cli/`) +- `app.py` — Typer app wiring, registers 5 commands: `discover`, `build`, `run`, `report`, `database` +- `commands/` — One file per command (build, run, discover, report, database) +- `constants.py` — `ExitCode` enum (`SUCCESS=0`, `FAILURE=1`, `BUILD_FAILURE=2`, `RUN_FAILURE=3`, `INVALID_ARGS=4`) + +**Orchestration Layer** (`src/madengine/orchestration/`) +- `build_orchestrator.py` — `BuildOrchestrator`: discovers models, builds Docker images, writes `build_manifest.json` +- `run_orchestrator.py` — `RunOrchestrator`: reads or triggers builds, infers deployment target, delegates to local or distributed execution + +**Core Layer** (`src/madengine/core/`) +- `context.py` — `Context` class: merges `additional_context` with system detection (GPU vendor, architecture, OS, ROCm path). Uses `ast.literal_eval()` to parse additional_context strings (not `json.loads` — pass Python dict repr, not JSON) +- `console.py` — `Console`: shell execution wrapper with live output support +- `docker.py` — Docker command wrapper + +**Execution Layer** (`src/madengine/execution/`) +- `container_runner.py` — `ContainerRunner`: runs models from manifest via `docker run`, writes results to `perf.csv` +- `docker_builder.py` — `DockerBuilder`: builds images from Dockerfiles +- `container_runner_helpers.py` — Log error pattern scanning, timeout resolution + +**Deployment Layer** (`src/madengine/deployment/`) +- `factory.py` — `DeploymentFactory`: Factory pattern, registers `SlurmDeployment` and `KubernetesDeployment` +- `base.py` — `BaseDeployment` abstract class, `DeploymentConfig` dataclass +- `kubernetes.py` / `slurm.py` — Concrete deployments; target is inferred by Convention over Configuration: presence of `"k8s"` or `"kubernetes"` key → K8s; `"slurm"` key → SLURM; neither → local +- `presets/` — JSON preset files for K8s/SLURM default configurations; auto-merged with minimal user configs +- `config_loader.py` — Loads and merges preset JSON with user-supplied config + +**Utils** (`src/madengine/utils/`) +- `discover_models.py` — `DiscoverModels`: three discovery methods: root `models.json`, `scripts/{dir}/models.json`, or `scripts/{dir}/get_models_json.py` (dynamic) +- `gpu_tool_factory.py` / `gpu_tool_manager.py` — GPU vendor abstraction (AMD/NVIDIA) +- `gpu_validator.py` — ROCm installation detection, GPU vendor detection +- `config_parser.py` — `ConfigParser`: parses `--additional-context` and tools config + +**Reporting** (`src/madengine/reporting/`) +- `update_perf_csv.py` — Writes/appends to `perf.csv` and `perf_entry.csv` +- `csv_to_html.py` / `csv_to_email.py` — Report generation + +### Key Data Flows + +1. **Build flow**: CLI → `BuildOrchestrator` → `DiscoverModels` (finds models by tags) → `DockerBuilder` (builds images) → writes `build_manifest.json` + +2. **Run flow**: CLI → `RunOrchestrator` → loads/generates `build_manifest.json` → infers target → `ContainerRunner` (local) or `DeploymentFactory` (K8s/SLURM) → writes `perf.csv` + +3. **`additional_context`**: User JSON/Python-dict string merged into `Context.ctx`. Context is parsed with `ast.literal_eval()`, so values can use Python dict syntax. Keys like `k8s`, `slurm`, `distributed`, `tools`, `pre_scripts`, `post_scripts` drive behavior. + +4. **Model definition**: Models defined in `models.json` with fields: `name`, `tags`, `dockerfile`, `scripts`, `n_gpus`, `args`, `timeout`, `skip_gpu_arch`, etc. + +5. **Script isolation**: During run, `scripts/common/` is populated from the madengine package (pre_scripts, post_scripts, tools) and cleaned up afterwards. The MAD project's own `scripts/` and `docker/` directories are preserved. + +### Deployment Target Inference + +No explicit `"deploy"` field is needed. Target is inferred from config structure: +- `"k8s"` or `"kubernetes"` key present → Kubernetes deployment +- `"slurm"` key present → SLURM deployment +- Neither → local Docker execution + +### Test Structure + +``` +tests/ +├── unit/ # Fast isolated tests with mocking +├── integration/ # End-to-end with real Docker/system calls +├── e2e/ # Full workflow tests +└── fixtures/ # Dummy models, scripts, and data for testing +``` + +Pytest config is in `pyproject.toml` under `[tool.pytest.ini_options]`. Test markers: `slow`, `integration`. + +### Code Style + +- Black formatting, 88-character line length +- isort with `profile = "black"` +- Google-style docstrings +- Type hints required for public functions +- Conventional commits: `feat:`, `fix:`, `docs:`, `test:`, `refactor:`, `style:`, `perf:`, `chore:` diff --git a/pyproject.toml b/pyproject.toml index 81fded5e..0c83f30a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,8 @@ dependencies = [ "GitPython", "jsondiff", "sqlalchemy", - "setuptools-rust", "paramiko", "tqdm", - "pytest", "typing-extensions", "pymongo", "toml", diff --git a/src/madengine/cli/app.py b/src/madengine/cli/app.py index 66d3256b..2e761f49 100644 --- a/src/madengine/cli/app.py +++ b/src/madengine/cli/app.py @@ -8,6 +8,7 @@ """ import sys +from importlib.metadata import PackageNotFoundError, version as pkg_version import typer from rich.traceback import install @@ -55,9 +56,12 @@ def main( Built with Typer and Rich for a beautiful, production-ready experience. """ if version: - # You might want to get the actual version from your package + try: + _version = pkg_version("madengine") + except PackageNotFoundError: + _version = "unknown" console.print( - "🚀 [bold cyan]madengine[/bold cyan] version [green]2.0.0[/green]" + f"🚀 [bold cyan]madengine[/bold cyan] version [green]{_version}[/green]" ) raise typer.Exit() diff --git a/src/madengine/cli/commands/run.py b/src/madengine/cli/commands/run.py index a684973d..09d90772 100644 --- a/src/madengine/cli/commands/run.py +++ b/src/madengine/cli/commands/run.py @@ -194,6 +194,9 @@ def run( # Convert -1 (default) to actual default timeout value (7200 seconds = 2 hours) if timeout == -1: timeout = 7200 + # 0 means "no timeout" per the help text — map to None so subprocess never expires + elif timeout == 0: + timeout = None try: # Check if we're doing execution-only or full workflow diff --git a/src/madengine/cli/constants.py b/src/madengine/cli/constants.py index f32eb024..b437fa30 100644 --- a/src/madengine/cli/constants.py +++ b/src/madengine/cli/constants.py @@ -5,11 +5,13 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ +from enum import IntEnum + # Exit codes -class ExitCode: +class ExitCode(IntEnum): """Exit codes for CLI commands.""" - + SUCCESS = 0 FAILURE = 1 BUILD_FAILURE = 2 diff --git a/src/madengine/cli/validators.py b/src/madengine/cli/validators.py index d99e87f7..b4e08e8b 100644 --- a/src/madengine/cli/validators.py +++ b/src/madengine/cli/validators.py @@ -395,6 +395,8 @@ def process_batch_manifest_entries( # If the model was not built (build_new=false), create an entry for it if not build_new: + # Initialize with a safe fallback so the except block can always reference it + dockerfile_matched = "unknown" # Find the model configuration by discovering models with this tag try: # Create a temporary args object to discover the model diff --git a/src/madengine/core/auth.py b/src/madengine/core/auth.py new file mode 100644 index 00000000..f8caf116 --- /dev/null +++ b/src/madengine/core/auth.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +""" +Shared authentication utilities for madengine. + +Centralises credential loading logic used by both BuildOrchestrator and +RunOrchestrator so that fixes and improvements only need to be made once. + +Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +""" + +import json +import os +import shlex +from typing import Dict, Optional + +from madengine.core.errors import ( + ConfigurationError, + create_error_context, + handle_error, +) + + +def load_credentials() -> Optional[Dict]: + """Load credentials from credential.json and environment variables. + + Precedence (highest wins): + 1. ``MAD_DOCKERHUB_USER`` / ``MAD_DOCKERHUB_PASSWORD`` environment vars + (merged into the ``dockerhub`` key of the returned dict) + 2. ``credential.json`` in the current working directory + + Returns: + Credentials dict (keyed by registry name), or ``None`` if no + credentials are found. + """ + credentials: Optional[Dict] = None + + credential_file = "credential.json" + if os.path.exists(credential_file): + try: + with open(credential_file) as f: + credentials = json.load(f) + print( + f"Loaded credentials from {credential_file}: " + f"{list(credentials.keys())}" + ) + except Exception as e: + context = create_error_context( + operation="load_credentials", + component="auth", + file_path=credential_file, + ) + handle_error( + ConfigurationError( + f"Could not load credentials: {e}", + context=context, + suggestions=[ + "Check if credential.json exists and has valid JSON format" + ], + ) + ) + + # Environment variables override / supplement file credentials + docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER") + docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD") + docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO") + + if docker_hub_user and docker_hub_password: + print("Found Docker Hub credentials in environment variables") + if credentials is None: + credentials = {} + credentials["dockerhub"] = { + "username": docker_hub_user, + "password": docker_hub_password, + } + if docker_hub_repo: + credentials["dockerhub"]["repository"] = docker_hub_repo + + return credentials + + +def login_to_registry( + registry: str, + credentials: Optional[Dict], + console, + rich_console, + raise_on_failure: bool = True, +) -> None: + """Login to a Docker registry. + + This is the single shared implementation used by both DockerBuilder + and ContainerRunner. + + Args: + registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty + for DockerHub). + credentials: Credentials dictionary keyed by registry name. + console: A ``Console`` instance for shell execution. + rich_console: A Rich ``Console`` instance for formatted output. + raise_on_failure: If ``True`` (default), raise ``RuntimeError`` on any + failure (missing key, invalid format, or docker login error). + Set to ``False`` to log and return instead, allowing the caller + to fall back to pulling public images. + """ + if not credentials: + rich_console.print( + "[yellow]No credentials provided for registry login[/yellow]" + ) + return + + registry_key = registry if registry else "dockerhub" + + # Normalise docker.io → dockerhub + if registry and registry.lower() == "docker.io": + registry_key = "dockerhub" + + if registry_key not in credentials: + error_msg = f"No credentials found for registry: {registry_key}" + if registry_key == "dockerhub": + error_msg += ( + f"\nPlease add dockerhub credentials to credential.json:\n" + "{\n" + ' "dockerhub": {\n' + ' "repository": "your-repository",\n' + ' "username": "your-dockerhub-username",\n' + ' "password": "your-dockerhub-password-or-token"\n' + " }\n" + "}" + ) + else: + error_msg += ( + f"\nPlease add {registry_key} credentials to credential.json:\n" + "{\n" + f' "{registry_key}": {{\n' + f' "repository": "your-repository",\n' + f' "username": "your-{registry_key}-username",\n' + f' "password": "your-{registry_key}-password"\n' + " }\n" + "}" + ) + rich_console.print(f"[red]{error_msg}[/red]") + if raise_on_failure: + raise RuntimeError(error_msg) + return + + creds = credentials[registry_key] + + if "username" not in creds or "password" not in creds: + error_msg = ( + f"Invalid credentials format for registry: {registry_key}" + f"\nCredentials must contain 'username' and 'password' fields" + ) + rich_console.print(f"[red]{error_msg}[/red]") + if raise_on_failure: + raise RuntimeError(error_msg) + return + + username = str(creds["username"]) + password = str(creds["password"]) + + quoted_password = shlex.quote(password) + quoted_username = shlex.quote(username) + login_command = f"printf %s {quoted_password} | docker login" + if registry and registry.lower() not in ["docker.io", "dockerhub"]: + login_command += f" {shlex.quote(str(registry))}" + login_command += f" --username {quoted_username} --password-stdin" + + try: + console.sh(login_command, secret=True) + rich_console.print( + f"[green]Successfully logged in to registry: " + f"{registry or 'DockerHub'}[/green]" + ) + except Exception as e: + rich_console.print( + f"[red]Failed to login to registry {registry}: {e}[/red]" + ) + if raise_on_failure: + raise diff --git a/src/madengine/core/context.py b/src/madengine/core/context.py index 24763588..6d8089ff 100644 --- a/src/madengine/core/context.py +++ b/src/madengine/core/context.py @@ -23,7 +23,7 @@ # third-party modules from madengine.core.console import Console from madengine.core.constants import get_rocm_path -from madengine.utils.gpu_validator import validate_rocm_installation, GPUInstallationError, GPUVendor +from madengine.utils.gpu_validator import GPUVendor from madengine.utils.gpu_tool_factory import get_gpu_tool_manager from madengine.utils.gpu_tool_manager import BaseGPUToolManager @@ -395,11 +395,8 @@ def get_gpu_vendor(self) -> str: for amd_smi_path in amd_smi_paths: if os.path.exists(amd_smi_path): try: - # Debug: log to stderr so SLURM node .err captures where we are if killed - print(f"[DEBUG] get_gpu_vendor: trying amd-smi at {amd_smi_path}", file=sys.stderr, flush=True) # Verify amd-smi actually works (180s timeout for slow GPU initialization) result = self.console.sh(f"{amd_smi_path} list > /dev/null 2>&1 && echo 'AMD' || echo ''", timeout=180) - print(f"[DEBUG] get_gpu_vendor: amd-smi returned", file=sys.stderr, flush=True) if result and result.strip() == "AMD": return "AMD" except Exception as e: @@ -409,9 +406,7 @@ def get_gpu_vendor(self) -> str: rocm_smi_path = os.path.join(self._rocm_path, "bin", "rocm-smi") if os.path.exists(rocm_smi_path): try: - print(f"[DEBUG] get_gpu_vendor: trying rocm-smi at {rocm_smi_path}", file=sys.stderr, flush=True) result = self.console.sh(f"{rocm_smi_path} --showid > /dev/null 2>&1 && echo 'AMD' || echo ''", timeout=180) - print(f"[DEBUG] get_gpu_vendor: rocm-smi returned", file=sys.stderr, flush=True) if result and result.strip() == "AMD": return "AMD" except Exception as e: @@ -439,11 +434,11 @@ def get_host_os(self) -> str: "if [ -f \"$(which apt)\" ]; then echo 'HOST_UBUNTU'; elif [ -f \"$(which yum)\" ]; then echo 'HOST_CENTOS'; elif [ -f \"$(which zypper)\" ]; then echo 'HOST_SLES'; elif [ -f \"$(which tdnf)\" ]; then echo 'HOST_AZURE'; else echo 'Unable to detect Host OS'; fi || true" ) - def get_numa_balancing(self) -> bool: + def get_numa_balancing(self) -> typing.Union[str, bool]: """Get NUMA balancing. Returns: - bool: The output of the shell command. + Union[str, bool]: The shell command output as a string, or False if the path does not exist. Raises: RuntimeError: If the NUMA balancing is not enabled or disabled. diff --git a/src/madengine/core/dataprovider.py b/src/madengine/core/dataprovider.py index c0df24a5..809c4425 100644 --- a/src/madengine/core/dataprovider.py +++ b/src/madengine/core/dataprovider.py @@ -164,8 +164,6 @@ def check_source(self, config: typing.Dict) -> bool: # get the base directory of the current file. BASE_DIR = os.path.dirname(os.path.realpath(__file__)) - print("DEBUG - BASE_DIR::", BASE_DIR) - print("DEBUG - self.config[path]::", self.config["path"]) # check if the path exists in the base directory. # if os.path.exists(BASE_DIR + "/../" + self.config["path"]): diff --git a/src/madengine/core/docker.py b/src/madengine/core/docker.py index 42f88263..f9b5c6c9 100644 --- a/src/madengine/core/docker.py +++ b/src/madengine/core/docker.py @@ -7,6 +7,7 @@ """ # built-in modules import os +import shlex import typing # user-defined modules @@ -32,7 +33,7 @@ def __init__( mounts: typing.Optional[typing.List] = None, envVars: typing.Optional[typing.Dict] = None, keep_alive: bool = False, - console: Console = Console(), + console: Console = None, ) -> None: """Constructor of the Docker class. @@ -52,13 +53,14 @@ def __init__( self.docker_sha = None self.keep_alive = keep_alive cwd = os.getcwd() - self.console = console + self.console = console if console is not None else Console() self.userid = self.console.sh("id -u") self.groupid = self.console.sh("id -g") # check if container name exists + container_name_quoted = shlex.quote(container_name) container_name_exists = self.console.sh( - "docker container ps -a | grep " + container_name + " | wc -l" + "docker container ps -a | grep " + container_name_quoted + " | wc -l" ) # if container name exists, clean it up automatically if container_name_exists != "0": @@ -67,11 +69,11 @@ def __init__( ) # Stop the container (with timeout) self.console.sh( - f"docker stop -t 1 {container_name} 2>/dev/null || true" + f"docker stop -t 1 {container_name_quoted} 2>/dev/null || true" ) # Remove the container self.console.sh( - f"docker rm -f {container_name} 2>/dev/null || true" + f"docker rm -f {container_name_quoted} 2>/dev/null || true" ) print(f"✓ Cleaned up existing container '{container_name}'") @@ -93,7 +95,7 @@ def __init__( # add envVars if envVars is not None: for evar in envVars.keys(): - command += "-e " + evar + "=" + envVars[evar] + " " + command += "-e " + evar + "=" + shlex.quote(str(envVars[evar])) + " " command += "--workdir /myworkspace/ " command += "--name " + container_name + " " @@ -123,7 +125,7 @@ def sh(self, command: str, timeout: int = 60, secret: bool = False) -> str: """ # run as root! return self.console.sh( - "docker exec " + self.docker_sha + ' bash -c "' + command + '"', + "docker exec " + self.docker_sha + " bash -c " + shlex.quote(command), timeout=timeout, secret=secret, ) diff --git a/src/madengine/core/errors.py b/src/madengine/core/errors.py index 18ba92f8..6a0757ab 100644 --- a/src/madengine/core/errors.py +++ b/src/madengine/core/errors.py @@ -7,7 +7,6 @@ """ import logging -import traceback from dataclasses import dataclass from typing import Optional, Any, Dict, List from enum import Enum @@ -16,14 +15,13 @@ from rich.console import Console from rich.panel import Panel from rich.text import Text - from rich.table import Table except ImportError: raise ImportError("Rich is required for error handling. Install with: pip install rich") class ErrorCategory(Enum): """Error category enumeration for classification.""" - + VALIDATION = "validation" CONNECTION = "connection" AUTHENTICATION = "authentication" @@ -72,25 +70,25 @@ def __init__( class ValidationError(MADEngineError): """Validation and input errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.VALIDATION, - context, + message, + ErrorCategory.VALIDATION, + context, recoverable=True, **kwargs ) -class ConnectionError(MADEngineError): +class NetworkError(MADEngineError): """Connection and network errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.CONNECTION, - context, + message, + ErrorCategory.CONNECTION, + context, recoverable=True, **kwargs ) @@ -98,12 +96,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class AuthenticationError(MADEngineError): """Authentication and credential errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.AUTHENTICATION, - context, + message, + ErrorCategory.AUTHENTICATION, + context, recoverable=True, **kwargs ) @@ -122,10 +120,6 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg ) -# Backward compatibility alias -RuntimeError = ExecutionError - - class BuildError(MADEngineError): """Build and compilation errors.""" @@ -154,12 +148,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class OrchestrationError(MADEngineError): """Distributed orchestration errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.ORCHESTRATION, - context, + message, + ErrorCategory.ORCHESTRATION, + context, recoverable=False, **kwargs ) @@ -167,12 +161,12 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class RunnerError(MADEngineError): """Distributed runner errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.RUNNER, - context, + message, + ErrorCategory.RUNNER, + context, recoverable=True, **kwargs ) @@ -180,25 +174,25 @@ def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwarg class ConfigurationError(MADEngineError): """Configuration and setup errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.CONFIGURATION, - context, + message, + ErrorCategory.CONFIGURATION, + context, recoverable=True, **kwargs ) -class TimeoutError(MADEngineError): +class DeploymentTimeoutError(MADEngineError): """Timeout and duration errors.""" - + def __init__(self, message: str, context: Optional[ErrorContext] = None, **kwargs): super().__init__( - message, - ErrorCategory.TIMEOUT, - context, + message, + ErrorCategory.TIMEOUT, + context, recoverable=True, **kwargs ) @@ -387,4 +381,4 @@ def create_error_context( phase=phase, component=component, **kwargs - ) \ No newline at end of file + ) diff --git a/src/madengine/core/timeout.py b/src/madengine/core/timeout.py index 0f72bd84..7fbdcb2e 100644 --- a/src/madengine/core/timeout.py +++ b/src/madengine/core/timeout.py @@ -7,7 +7,6 @@ """ # built-in modules import signal -import typing class Timeout: diff --git a/src/madengine/deployment/base.py b/src/madengine/deployment/base.py index 52bbe02f..00306505 100644 --- a/src/madengine/deployment/base.py +++ b/src/madengine/deployment/base.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from jinja2 import Environment, FileSystemLoader from rich.console import Console @@ -43,6 +43,7 @@ class DeploymentStatus(Enum): SUCCESS = "success" FAILED = "failed" CANCELLED = "cancelled" + UNKNOWN = "unknown" @dataclass @@ -224,7 +225,7 @@ def _monitor_until_complete(self, deployment_id: str) -> DeploymentResult: while True: status = self.monitor(deployment_id) - if status.status in [DeploymentStatus.SUCCESS, DeploymentStatus.FAILED]: + if status.status in [DeploymentStatus.SUCCESS, DeploymentStatus.FAILED, DeploymentStatus.UNKNOWN]: return status # Still running, wait and check again @@ -631,11 +632,13 @@ def _write_to_perf_csv(self, perf_data: Dict[str, Any]) -> None: row_to_write = perf_data with open(perf_csv_path, "a", newline="", encoding="utf-8") as f: - writer = csv.DictWriter(f, fieldnames=headers, extrasaction="ignore") - if not file_exists: - writer.writeheader() if file_exists and existing_header: + # File already has a header — write a plain row using csv.writer + # to preserve the exact column order captured in row_to_write csv.writer(f).writerow(row_to_write) else: + # New file — write header then the data row via DictWriter + writer = csv.DictWriter(f, fieldnames=headers, extrasaction="ignore") + writer.writeheader() writer.writerow(row_to_write) diff --git a/src/madengine/deployment/common.py b/src/madengine/deployment/common.py index 93ae1881..5b898960 100644 --- a/src/madengine/deployment/common.py +++ b/src/madengine/deployment/common.py @@ -8,6 +8,7 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ +import functools import subprocess from typing import Any, Dict, List, Optional @@ -84,6 +85,7 @@ def normalize_launcher(launcher_type: Optional[str], deployment_type: str) -> st return "docker" +@functools.lru_cache(maxsize=None) def is_rocprofv3_available() -> bool: """ Check if rocprofv3 is available on the system. diff --git a/src/madengine/deployment/config_loader.py b/src/madengine/deployment/config_loader.py index fdbf3c94..06d8a1b1 100644 --- a/src/madengine/deployment/config_loader.py +++ b/src/madengine/deployment/config_loader.py @@ -11,7 +11,6 @@ """ import json -import os from pathlib import Path from typing import Any, Callable, Dict, Optional from copy import deepcopy diff --git a/src/madengine/deployment/factory.py b/src/madengine/deployment/factory.py index 9391d3a3..e2b503d6 100644 --- a/src/madengine/deployment/factory.py +++ b/src/madengine/deployment/factory.py @@ -88,8 +88,14 @@ def register_default_deployments(): DeploymentFactory.register("k8s", KubernetesDeployment) DeploymentFactory.register("kubernetes", KubernetesDeployment) except ImportError: - # Kubernetes library not installed, skip registration - pass + import warnings + warnings.warn( + "Kubernetes deployment target is unavailable: the 'kubernetes' library is not " + "installed. Install it with: pip install madengine[kubernetes] " + "(or pip install madengine[all]).", + ImportWarning, + stacklevel=2, + ) # Auto-register on module import diff --git a/src/madengine/deployment/kubernetes.py b/src/madengine/deployment/kubernetes.py index 3430c9d0..92be5549 100644 --- a/src/madengine/deployment/kubernetes.py +++ b/src/madengine/deployment/kubernetes.py @@ -37,17 +37,13 @@ from .base import BaseDeployment, DeploymentConfig, DeploymentResult, DeploymentStatus, create_jinja_env from .common import ( - VALID_LAUNCHERS, configure_multi_node_profiling, - is_rocprofv3_available, normalize_launcher, ) from .config_loader import ConfigLoader, apply_deployment_config from .k8s_secrets import ( CONFIGMAP_MAX_BYTES, - SECRETS_STRATEGY_EXISTING, SECRETS_STRATEGY_FROM_LOCAL, - SECRETS_STRATEGY_OMIT, create_or_update_secrets_from_credentials, delete_job_secrets_if_exist, estimate_configmap_payload_bytes, @@ -58,7 +54,7 @@ ) from madengine.core.dataprovider import Data from madengine.core.context import Context -from madengine.core.errors import ConfigurationError, create_error_context +from madengine.core.errors import ConfigurationError from madengine.utils.gpu_config import resolve_runtime_gpus from madengine.utils.path_utils import get_madengine_root, scripts_base_dir_from from madengine.utils.run_details import flatten_tags_in_place, get_build_number, get_pipeline diff --git a/src/madengine/deployment/slurm.py b/src/madengine/deployment/slurm.py index a45f83d3..82afa0aa 100644 --- a/src/madengine/deployment/slurm.py +++ b/src/madengine/deployment/slurm.py @@ -1019,20 +1019,26 @@ def _check_job_completion(self, job_id: str) -> DeploymentResult: message=f"Job {job_id} failed: {status}", ) - # Fallback - assume completed - self.console.print(f"[dim yellow]Warning: Could not get status for job {job_id}, assuming success[/dim yellow]") + # sacct returned non-zero — status unknown, do not assume success + self.console.print( + f"[yellow]Warning: sacct returned non-zero for job {job_id} " + f"(exit code {result.returncode}). Status cannot be verified.[/yellow]" + ) return DeploymentResult( - status=DeploymentStatus.SUCCESS, + status=DeploymentStatus.UNKNOWN, deployment_id=job_id, - message=f"Job {job_id} completed (assumed)", + message=f"Job {job_id} status unknown: sacct exited with code {result.returncode}", ) except Exception as e: - self.console.print(f"[dim yellow]Warning: Exception checking job {job_id}: {e}[/dim yellow]") + self.console.print( + f"[yellow]Warning: Exception checking job {job_id} status: {e}. " + f"Status cannot be verified.[/yellow]" + ) return DeploymentResult( - status=DeploymentStatus.SUCCESS, + status=DeploymentStatus.UNKNOWN, deployment_id=job_id, - message=f"Job {job_id} completed (status unavailable)", + message=f"Job {job_id} status unknown: {e}", ) def _build_perf_entry_from_aggregated( diff --git a/src/madengine/execution/container_runner.py b/src/madengine/execution/container_runner.py index d5f27cf0..2b2c3d46 100644 --- a/src/madengine/execution/container_runner.py +++ b/src/madengine/execution/container_runner.py @@ -17,6 +17,7 @@ import warnings from rich.console import Console as RichConsole from contextlib import redirect_stdout, redirect_stderr +from madengine.core.auth import login_to_registry from madengine.core.console import Console from madengine.core.context import Context from madengine.core.docker import Docker @@ -389,72 +390,16 @@ def load_build_manifest( def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> None: """Login to a Docker registry for pulling images. - Args: - registry: Registry URL (e.g., "localhost:5000", "docker.io") - credentials: Optional credentials dictionary containing username/password + Delegates to :func:`madengine.core.auth.login_to_registry`. + Does not raise on failure so public images can still be pulled. """ - if not credentials: - self.rich_console.print("[yellow]No credentials provided for registry login[/yellow]") - return - - # Check if registry credentials are available - registry_key = registry if registry else "dockerhub" - - # Handle docker.io as dockerhub - if registry and registry.lower() == "docker.io": - registry_key = "dockerhub" - - if registry_key not in credentials: - error_msg = f"No credentials found for registry: {registry_key}" - if registry_key == "dockerhub": - error_msg += f"\nPlease add dockerhub credentials to credential.json:\n" - error_msg += "{\n" - error_msg += ' "dockerhub": {\n' - error_msg += ' "repository": "your-repository",\n' - error_msg += ' "username": "your-dockerhub-username",\n' - error_msg += ' "password": "your-dockerhub-password-or-token"\n' - error_msg += " }\n" - error_msg += "}" - else: - error_msg += ( - f"\nPlease add {registry_key} credentials to credential.json:\n" - ) - error_msg += "{\n" - error_msg += f' "{registry_key}": {{\n' - error_msg += f' "repository": "your-repository",\n' - error_msg += f' "username": "your-{registry_key}-username",\n' - error_msg += f' "password": "your-{registry_key}-password"\n' - error_msg += " }\n" - error_msg += "}" - print(error_msg) - raise RuntimeError(error_msg) - - creds = credentials[registry_key] - - if "username" not in creds or "password" not in creds: - error_msg = f"Invalid credentials format for registry: {registry_key}" - error_msg += f"\nCredentials must contain 'username' and 'password' fields" - print(error_msg) - raise RuntimeError(error_msg) - - # Ensure credential values are strings - username = str(creds["username"]) - password = str(creds["password"]) - - # Perform docker login - login_command = f"echo '{password}' | docker login" - - if registry and registry.lower() not in ["docker.io", "dockerhub"]: - login_command += f" {registry}" - - login_command += f" --username {username} --password-stdin" - - try: - self.console.sh(login_command, secret=True) - self.rich_console.print(f"[green]✅ Successfully logged in to registry: {registry or 'DockerHub'}[/green]") - except Exception as e: - self.rich_console.print(f"[red]❌ Failed to login to registry {registry}: {e}[/red]") - # Don't raise exception here, as public images might still be pullable + login_to_registry( + registry, + credentials, + console=self.console, + rich_console=self.rich_console, + raise_on_failure=False, + ) def pull_image( self, @@ -493,7 +438,7 @@ def pull_image( try: self.console.sh(f"docker rmi -f {registry_image} 2>/dev/null || true") print(f"✓ Removed cached image layers") - except: + except Exception: pass # It's okay if image doesn't exist try: @@ -597,14 +542,14 @@ def get_env_arg(self, run_env: typing.Dict) -> str: # Add custom environment variables if run_env: for env_arg in run_env: - env_args += f"--env {env_arg}='{str(run_env[env_arg])}' " + env_args += f"--env {env_arg}={shlex.quote(str(run_env[env_arg]))} " # Add context environment variables if "docker_env_vars" in self.context.ctx: for env_arg in self.context.ctx["docker_env_vars"].keys(): - env_args += f"--env {env_arg}='{str(self.context.ctx['docker_env_vars'][env_arg])}' " + value = self.context.ctx["docker_env_vars"][env_arg] + env_args += f"--env {env_arg}={shlex.quote(str(value))} " - print(f"Env arguments: {env_args}") return env_args def get_mount_arg(self, mount_datapaths: typing.List) -> str: diff --git a/src/madengine/execution/docker_builder.py b/src/madengine/execution/docker_builder.py index f769b85f..b900a359 100644 --- a/src/madengine/execution/docker_builder.py +++ b/src/madengine/execution/docker_builder.py @@ -8,20 +8,20 @@ """ import os +import shlex import time import json import re import typing from contextlib import redirect_stdout, redirect_stderr from rich.console import Console as RichConsole +from madengine.core.auth import login_to_registry from madengine.core.console import Console from madengine.core.context import Context from madengine.utils.ops import PythonicTee from madengine.execution.dockerfile_utils import ( - is_compilation_arch_compatible, is_target_arch_compatible_with_variable, parse_dockerfile_gpu_variables, - parse_gpu_variable_value, ) @@ -67,7 +67,7 @@ def get_context_path(self, info: typing.Dict) -> str: return "." return "./docker" - def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: + def get_build_arg(self, run_build_arg: typing.Optional[typing.Dict] = None) -> str: """Get the build arguments. Args: @@ -76,6 +76,8 @@ def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: Returns: str: The build arguments. """ + if run_build_arg is None: + run_build_arg = {} if not run_build_arg and "docker_build_arg" not in self.context.ctx: return "" @@ -84,14 +86,14 @@ def get_build_arg(self, run_build_arg: typing.Dict = {}) -> str: build_args += ( "--build-arg " + build_arg - + "='" - + self.context.ctx["docker_build_arg"][build_arg] - + "' " + + "=" + + shlex.quote(self.context.ctx["docker_build_arg"][build_arg]) + + " " ) if run_build_arg: for key, value in run_build_arg.items(): - build_args += "--build-arg " + key + "='" + value + "' " + build_args += "--build-arg " + key + "=" + shlex.quote(value) + " " return build_args @@ -248,72 +250,15 @@ def build_image( def login_to_registry(self, registry: str, credentials: typing.Dict = None) -> None: """Login to a Docker registry. - Args: - registry: Registry URL (e.g., "localhost:5000", "docker.io", or empty for DockerHub) - credentials: Optional credentials dictionary containing username/password + Delegates to :func:`madengine.core.auth.login_to_registry`. """ - if not credentials: - print("No credentials provided for registry login") - return - - # Check if registry credentials are available - registry_key = registry if registry else "dockerhub" - - # Handle docker.io as dockerhub - if registry and registry.lower() == "docker.io": - registry_key = "dockerhub" - - if registry_key not in credentials: - error_msg = f"No credentials found for registry: {registry_key}" - if registry_key == "dockerhub": - error_msg += f"\nPlease add dockerhub credentials to credential.json:\n" - error_msg += "{\n" - error_msg += ' "dockerhub": {\n' - error_msg += ' "repository": "your-repository",\n' - error_msg += ' "username": "your-dockerhub-username",\n' - error_msg += ' "password": "your-dockerhub-password-or-token"\n' - error_msg += " }\n" - error_msg += "}" - else: - error_msg += ( - f"\nPlease add {registry_key} credentials to credential.json:\n" - ) - error_msg += "{\n" - error_msg += f' "{registry_key}": {{\n' - error_msg += f' "repository": "your-repository",\n' - error_msg += f' "username": "your-{registry_key}-username",\n' - error_msg += f' "password": "your-{registry_key}-password"\n' - error_msg += " }\n" - error_msg += "}" - self.rich_console.print(f"[red]{error_msg}[/red]") - raise RuntimeError(error_msg) - - creds = credentials[registry_key] - - if "username" not in creds or "password" not in creds: - error_msg = f"Invalid credentials format for registry: {registry_key}" - error_msg += f"\nCredentials must contain 'username' and 'password' fields" - self.rich_console.print(f"[red]{error_msg}[/red]") - raise RuntimeError(error_msg) - - # Ensure credential values are strings - username = str(creds["username"]) - password = str(creds["password"]) - - # Perform docker login - login_command = f"echo '{password}' | docker login" - - if registry and registry.lower() not in ["docker.io", "dockerhub"]: - login_command += f" {registry}" - - login_command += f" --username {username} --password-stdin" - - try: - self.console.sh(login_command, secret=True) - self.rich_console.print(f"[green]✅ Successfully logged in to registry: {registry or 'DockerHub'}[/green]") - except Exception as e: - self.rich_console.print(f"[red]❌ Failed to login to registry {registry}: {e}[/red]") - raise + login_to_registry( + registry, + credentials, + console=self.console, + rich_console=self.rich_console, + raise_on_failure=True, + ) def push_image( self, @@ -604,8 +549,10 @@ def _check_dockerfile_has_gpu_variables(self, model_info: typing.Dict) -> typing def _get_dockerfiles_for_model(self, model_info: typing.Dict) -> typing.List[str]: """Get dockerfiles for a model.""" try: + # Quote the dockerfile path to prevent shell injection + dockerfile_quoted = shlex.quote(model_info["dockerfile"]) all_dockerfiles = self.console.sh( - f"ls {model_info['dockerfile']}.*" + f"ls {dockerfile_quoted}.*" ).split("\n") dockerfiles = {} diff --git a/src/madengine/orchestration/build_orchestrator.py b/src/madengine/orchestration/build_orchestrator.py index da06f91f..3f67ff29 100644 --- a/src/madengine/orchestration/build_orchestrator.py +++ b/src/madengine/orchestration/build_orchestrator.py @@ -10,7 +10,6 @@ import json import os -import shutil from pathlib import Path from typing import Dict, List, Optional @@ -20,12 +19,12 @@ from madengine.core.console import Console from madengine.core.context import Context from madengine.core.additional_context_defaults import apply_build_context_defaults +from madengine.core.auth import load_credentials from madengine.core.errors import ( BuildError, ConfigurationError, DiscoveryError, create_error_context, - handle_error, ) from madengine.utils.discover_models import DiscoverModels from madengine.execution.docker_builder import DockerBuilder @@ -104,9 +103,8 @@ def __init__(self, args, additional_context: Optional[Dict] = None): # 4. Add 'deploy' field for internal use self.additional_context = ConfigLoader.load_config(self.additional_context) except ValueError as e: - # Configuration validation error - fail fast - self.rich_console.print(f"[red]Configuration Error: {e}[/red]") - raise SystemExit(1) + # Re-raise as ConfigurationError so the CLI layer handles the exit code + raise ConfigurationError(str(e)) except Exception as e: # Other errors during config loading - warn but continue self.rich_console.print(f"[yellow]Warning: Could not apply config defaults: {e}[/yellow]") @@ -131,53 +129,7 @@ def __init__(self, args, additional_context: Optional[Dict] = None): ) # Load credentials if available - self.credentials = self._load_credentials() - - def _load_credentials(self) -> Optional[Dict]: - """Load credentials from credential.json and environment variables.""" - credentials = None - - # Try loading from file - credential_file = "credential.json" - if os.path.exists(credential_file): - try: - with open(credential_file) as f: - credentials = json.load(f) - print(f"Loaded credentials from {credential_file}: {list(credentials.keys())}") - except Exception as e: - context = create_error_context( - operation="load_credentials", - component="BuildOrchestrator", - file_path=credential_file, - ) - handle_error( - ConfigurationError( - f"Could not load credentials: {e}", - context=context, - suggestions=[ - "Check if credential.json exists and has valid JSON format" - ], - ) - ) - - # Override with environment variables if present - docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER") - docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD") - docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO") - - if docker_hub_user and docker_hub_password: - print("Found Docker Hub credentials in environment variables") - if credentials is None: - credentials = {} - - credentials["dockerhub"] = { - "username": docker_hub_user, - "password": docker_hub_password, - } - if docker_hub_repo: - credentials["dockerhub"]["repository"] = docker_hub_repo - - return credentials + self.credentials = load_credentials() def _copy_scripts(self): """[DEPRECATED] Copy common scripts to model directories. diff --git a/src/madengine/orchestration/run_orchestrator.py b/src/madengine/orchestration/run_orchestrator.py index 6725a457..681eb4c9 100644 --- a/src/madengine/orchestration/run_orchestrator.py +++ b/src/madengine/orchestration/run_orchestrator.py @@ -21,6 +21,7 @@ from rich.panel import Panel from madengine.core.console import Console +from madengine.core.auth import load_credentials from madengine.core.context import Context from madengine.core.dataprovider import Data from madengine.core.errors import ( @@ -28,7 +29,6 @@ ConfigurationError, ExecutionError, create_error_context, - handle_error, ) from madengine.core.constants import get_rocm_path from madengine.utils.session_tracker import SessionTracker @@ -554,7 +554,7 @@ def _execute_local(self, manifest_file: str, timeout: int) -> Dict: from madengine.execution.container_runner import ContainerRunner # Load credentials - credentials = self._load_credentials() + credentials = load_credentials() # Restore context from manifest if present if "context" in manifest: @@ -992,35 +992,6 @@ def ignore_cache_files(directory, files): # Note: K8s and Slurm deployments have their own script handling mechanisms # and do not rely on this local filesystem operation - def _load_credentials(self) -> Optional[Dict]: - """Load credentials from credential.json and environment.""" - credentials = None - - credential_file = "credential.json" - if os.path.exists(credential_file): - try: - with open(credential_file) as f: - credentials = json.load(f) - except Exception as e: - print(f"Warning: Could not load credentials: {e}") - - # Override with environment variables - docker_hub_user = os.environ.get("MAD_DOCKERHUB_USER") - docker_hub_password = os.environ.get("MAD_DOCKERHUB_PASSWORD") - docker_hub_repo = os.environ.get("MAD_DOCKERHUB_REPO") - - if docker_hub_user and docker_hub_password: - if credentials is None: - credentials = {} - credentials["dockerhub"] = { - "username": docker_hub_user, - "password": docker_hub_password, - } - if docker_hub_repo: - credentials["dockerhub"]["repository"] = docker_hub_repo - - return credentials - def _filter_images_by_gpu_compatibility( self, built_images: Dict, runtime_gpu_vendor: str, runtime_gpu_arch: str ) -> Dict: @@ -1133,70 +1104,4 @@ def _infer_deployment_target(self, config: Dict) -> str: else: return "local" - def _filter_images_by_dockerfile_context(self, built_images: Dict) -> Dict: - """Filter images by dockerfile context matching runtime context. - - This implements the legacy behavior where dockerfiles are filtered - at runtime based on their CONTEXT header matching the current runtime context. - - Args: - built_images: Dictionary of built images from manifest - - Returns: - Dictionary of images that match the runtime context - """ - if not self.context: - return built_images - - compatible_images = {} - - for image_name, image_info in built_images.items(): - dockerfile = image_info.get("dockerfile", "") - - if not dockerfile: - # No dockerfile info, include by default (legacy compatibility) - compatible_images[image_name] = image_info - continue - - # Check if dockerfile exists - if not os.path.exists(dockerfile): - self.rich_console.print( - f"[dim] Warning: Dockerfile {dockerfile} not found. Including by default.[/dim]" - ) - compatible_images[image_name] = image_info - continue - - # Read dockerfile context header - try: - dockerfile_context_str = self.console.sh( - f"head -n5 {dockerfile} | grep '# CONTEXT ' | sed 's/# CONTEXT //g'" - ).strip() - - if not dockerfile_context_str: - # No context header, include by default - compatible_images[image_name] = image_info - continue - - # Create a dict with this dockerfile and its context - dockerfile_dict = {dockerfile: dockerfile_context_str} - - # Use context.filter() to check if this dockerfile matches runtime context - filtered = self.context.filter(dockerfile_dict) - - if filtered: - # Dockerfile matches runtime context - compatible_images[image_name] = image_info - else: - self.rich_console.print( - f"[dim] Skipping {image_name}: dockerfile context doesn't match runtime context[/dim]" - ) - - except Exception as e: - # If we can't read the dockerfile, include it by default - self.rich_console.print( - f"[dim] Warning: Could not read context for {dockerfile}: {e}. Including by default.[/dim]" - ) - compatible_images[image_name] = image_info - - return compatible_images diff --git a/src/madengine/reporting/csv_to_email.py b/src/madengine/reporting/csv_to_email.py index 0902ef00..4b21bc17 100644 --- a/src/madengine/reporting/csv_to_email.py +++ b/src/madengine/reporting/csv_to_email.py @@ -9,7 +9,7 @@ import os import argparse import logging -from typing import List, Tuple +from typing import List, Optional, Tuple import pandas as pd @@ -60,7 +60,7 @@ def csv_to_html_section(file_path: str) -> Tuple[str, str]: def convert_directory_csvs_to_html( directory_path: str, output_file: str = "run_results.html" -) -> str: +) -> Optional[str]: """Convert all CSV files in a directory to a single HTML file. Args: diff --git a/src/madengine/reporting/update_perf_csv.py b/src/madengine/reporting/update_perf_csv.py index 0859c9c0..f298efa2 100644 --- a/src/madengine/reporting/update_perf_csv.py +++ b/src/madengine/reporting/update_perf_csv.py @@ -62,7 +62,7 @@ def flatten_tags(perf_entry: dict): The performance entry with flattened tags. """ # flatten tags to a string, if tags is a list. - if type(perf_entry["tags"]) == list: + if isinstance(perf_entry["tags"], list): perf_entry["tags"] = ",".join(str(item) for item in perf_entry["tags"]) @@ -192,6 +192,9 @@ def handle_single_result(perf_csv_df: pd.DataFrame, single_result: str) -> pd.Da AssertionError: If the number of columns in the performance csv DataFrame is not equal """ single_result_json = read_json(single_result) + # Remove non-scalar fields that are not perf.csv columns (e.g. configs list). + # See handle_exception_result for rationale. + single_result_json.pop("configs", None) perf_entry_dict_to_csv(single_result_json) single_result_df = pd.DataFrame(single_result_json, index=[0]) if perf_csv_df.empty: @@ -226,6 +229,11 @@ def handle_exception_result( AssertionError: If there is already an entry for the model in the performance csv DataFrame. """ exception_result_json = read_json(exception_result) + # Remove non-scalar fields that are not perf.csv columns (e.g. configs list) + # before constructing a single-row DataFrame with index=[0]. + # pd.DataFrame(dict_with_list_value, index=[0]) raises ValueError when any + # dict value is a list whose length != 1. + exception_result_json.pop("configs", None) perf_entry_dict_to_csv(exception_result_json) exception_result_df = pd.DataFrame(exception_result_json, index=[0]) if perf_csv_df.empty: diff --git a/src/madengine/scripts/common/tools/amd_smi_utils.py b/src/madengine/scripts/common/tools/amd_smi_utils.py index 05975257..e0e48096 100644 --- a/src/madengine/scripts/common/tools/amd_smi_utils.py +++ b/src/madengine/scripts/common/tools/amd_smi_utils.py @@ -152,7 +152,7 @@ def check_if_secondary_die(self, device: int) -> bool: avg_power = power_info.get('average_socket_power', -1) if current_power == 0 and avg_power == 0: return True - except: + except Exception: # If we can't get power info, might be secondary die return True diff --git a/src/madengine/scripts/common/tools/rocm_smi_utils.py b/src/madengine/scripts/common/tools/rocm_smi_utils.py index 92dff9f2..dd73219b 100644 --- a/src/madengine/scripts/common/tools/rocm_smi_utils.py +++ b/src/madengine/scripts/common/tools/rocm_smi_utils.py @@ -38,7 +38,7 @@ def __init__(self, mode) -> None: raise ImportError('Driver not initialized (amdgpu not found in modules)') exit(0) self.rocm6 = True - except: + except Exception: rocm_smi.initializeRsmi() def get_power(self, device: int) -> str: diff --git a/src/madengine/utils/config_parser.py b/src/madengine/utils/config_parser.py index ec988570..04e71f9c 100644 --- a/src/madengine/utils/config_parser.py +++ b/src/madengine/utils/config_parser.py @@ -184,7 +184,7 @@ def _walk_up_between( current = os.path.abspath(start_dir) stop = os.path.abspath(stop_dir) - while current.startswith(stop): + while current == stop or current.startswith(stop + os.sep): parent = os.path.dirname(current) if parent == current: # Reached root break diff --git a/src/madengine/utils/discover_models.py b/src/madengine/utils/discover_models.py index 4c3c9201..0cf0438e 100644 --- a/src/madengine/utils/discover_models.py +++ b/src/madengine/utils/discover_models.py @@ -85,6 +85,7 @@ def _setup_model_dir_if_needed(self) -> None: # Only copy if MODEL_DIR points to a different directory (not current dir) if model_dir_abs != cwd_abs: + import shlex import subprocess from pathlib import Path @@ -121,7 +122,7 @@ def _setup_model_dir_if_needed(self) -> None: copied_count = 0 for src_path, item_name, item_type in items_to_copy: try: - cmd = f"cp -vLR --preserve=all {src_path} {cwd_abs}/" + cmd = f"cp -vLR --preserve=all {shlex.quote(str(src_path))} {shlex.quote(str(cwd_abs))}/" result = subprocess.run( cmd, shell=True, capture_output=True, text=True, check=True ) @@ -216,9 +217,12 @@ def discover_models(self) -> None: custom_model_list = get_models_json.list_models() for custom_model in custom_model_list: - assert isinstance( - custom_model, CustomModel - ), "Please use or subclass madengine.utils.discover_models.CustomModel to define your custom model." + if not isinstance(custom_model, CustomModel): + raise TypeError( + "Please use or subclass " + "madengine.utils.discover_models.CustomModel " + "to define your custom model." + ) # Update model name using backslash-separated path custom_model.name = dirname + "/" + custom_model.name # Defer updating script and dockerfile paths until update_model is called diff --git a/src/madengine/utils/gpu_config.py b/src/madengine/utils/gpu_config.py index ff6aabc8..4b3c4143 100644 --- a/src/madengine/utils/gpu_config.py +++ b/src/madengine/utils/gpu_config.py @@ -14,9 +14,12 @@ Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ +import logging import warnings from typing import Dict, Any, Optional, Tuple +logger = logging.getLogger(__name__) + class GPUConfigResolver: """ @@ -157,17 +160,18 @@ def _extract_gpu_count( # Warn if multiple GPU fields found if len(found_fields) > 1: field_list = ", ".join([f"{name}={val}" for name, val in found_fields]) - print( - f"⚠️ Multiple GPU fields in {context}: {field_list}. " - f"Using {found_fields[0][0]}={found_fields[0][1]}" + logger.warning( + "Multiple GPU fields in %s: %s. Using %s=%s", + context, field_list, found_fields[0][0], found_fields[0][1], ) # Convert to int (handle string values like "8") try: return int(found_fields[0][1]) except (ValueError, TypeError): - print( - f"⚠️ Invalid GPU count in {context}: {found_fields[0][1]}. Using default." + logger.warning( + "Invalid GPU count in %s: %s. Using default.", + context, found_fields[0][1], ) return None @@ -231,10 +235,9 @@ def _validate_consistency( if is_deployment_override: # This is normal - deployment config overriding model default - # Use print instead of warnings.warn for cleaner output - print( - f"ℹ️ GPU configuration override: {sources[0][0]}={sources[0][1]} " - f"(overriding model default: {mismatch_details.split(',')[-1].strip()})" + logger.info( + "GPU configuration override: %s=%s (overriding model default: %s)", + sources[0][0], sources[0][1], mismatch_details.split(",")[-1].strip(), ) else: # Potentially unexpected mismatch - use warning for actual errors @@ -302,7 +305,7 @@ def resolve_runtime_gpus( validate=True, ) - print(f"ℹ️ Resolved GPU count: {gpu_count} (from {source})") + logger.info("Resolved GPU count: %s (from %s)", gpu_count, source) return gpu_count diff --git a/src/madengine/utils/gpu_validator.py b/src/madengine/utils/gpu_validator.py index 7014268a..8429891e 100644 --- a/src/madengine/utils/gpu_validator.py +++ b/src/madengine/utils/gpu_validator.py @@ -10,7 +10,7 @@ import subprocess import os -from typing import Dict, List, Tuple, Optional +from typing import List, Tuple, Optional from dataclasses import dataclass from enum import Enum diff --git a/src/madengine/utils/log_formatting.py b/src/madengine/utils/log_formatting.py index 31673c93..14a0eed5 100644 --- a/src/madengine/utils/log_formatting.py +++ b/src/madengine/utils/log_formatting.py @@ -9,10 +9,8 @@ """ import pandas as pd -import typing from rich.table import Table from rich.console import Console as RichConsole -from rich.text import Text def format_dataframe_for_log( @@ -82,7 +80,7 @@ def format_dataframe_for_log( header += f"📏 Shape: {df.shape[0]} rows × {df.shape[1]} columns\n" if truncated_rows: - header += f"⚠️ Display truncated: showing first {max_rows} rows\n" + header += f"⚠️ Display truncated: showing last {max_rows} rows\n" header += f"{'='*80}\n" @@ -209,33 +207,3 @@ def print_dataframe_beautiful( # Fallback to simple but nice formatting formatted_output = format_dataframe_for_log(df, title) print(formatted_output) - - -def highlight_log_section(title: str, content: str, style: str = "info") -> str: - """ - Create a highlighted log section with borders and styling. - - Args: - title: Section title - content: Section content - style: Style type ('info', 'success', 'warning', 'error') - - Returns: - str: Formatted log section - """ - styles = { - "info": {"emoji": "ℹ️", "border": "-"}, - "success": {"emoji": "✅", "border": "="}, - "warning": {"emoji": "⚠️", "border": "!"}, - "error": {"emoji": "❌", "border": "#"}, - } - - style_config = styles.get(style, styles["info"]) - emoji = style_config["emoji"] - border_char = style_config["border"] - - border = border_char * 80 - header = f"\n{border}\n{emoji} {title.upper()}\n{border}" - footer = f"{border}\n" - - return f"{header}\n{content}\n{footer}" diff --git a/src/madengine/utils/ops.py b/src/madengine/utils/ops.py index 0b8ab077..cd717fec 100644 --- a/src/madengine/utils/ops.py +++ b/src/madengine/utils/ops.py @@ -5,15 +5,12 @@ functions: PythonicTee: Class to both write and display stream, in "live" mode - find_and_replace_pattern: Find and replace a substring in a dictionary - substring_found: Check if a substring is found in the dictionary file_print: Write and flush file Copyright (c) Advanced Micro Devices, Inc. All rights reserved. """ # built-in modules import typing -import re import sys @@ -53,46 +50,6 @@ def flush(self) -> None: self.stdio.flush() -def find_and_replace_pattern( - dictionary: typing.Dict, substring: str, replacement: str -) -> typing.Dict: - """Find and replace a substring in a dictionary. - - Args: - dictionary: The dictionary. - substring: The substring to find. - replacement: The replacement string. - - Returns: - The updated dictionary. - """ - updated_dict = {} - # iterate over the dictionary, replace the substring with the replacement string. - for key, value in dictionary.items(): - updated_key = str(key).replace(substring, replacement) - updated_value = str(value).replace(substring, replacement) - updated_dict[updated_key] = updated_value - - return updated_dict - - -def substring_found(dictionary: typing.Dict, substring: str) -> bool: - """Check if a substring is found in the dictionary. - - Args: - dictionary: The dictionary. - substring: The substring to find. - - Returns: - True if the substring is found, False otherwise. - """ - # iterate over the dictionary, check if the substring is found in the key or value. - for key, value in dictionary.items(): - if substring in str(key) or substring in str(value): - return True - return False - - def file_print(write_str: str, filename: str, mode: str = "a") -> None: """Write and flush file. diff --git a/src/madengine/utils/rocm_tool_manager.py b/src/madengine/utils/rocm_tool_manager.py index 439f7da2..60870d29 100644 --- a/src/madengine/utils/rocm_tool_manager.py +++ b/src/madengine/utils/rocm_tool_manager.py @@ -199,14 +199,15 @@ def execute_command( self._log_debug(f"Command succeeded: {command[:50]}...") return stdout - # Log primary failure - self._log_warning(f"Primary command failed: {command[:50]}... Error: {stderr}") - + # Capture primary error before attempting fallback (fallback overwrites stderr) + primary_stderr = stderr + self._log_warning(f"Primary command failed: {command[:50]}... Error: {primary_stderr}") + # Try fallback if provided if fallback_command: self._log_info(f"Trying fallback command: {fallback_command[:50]}...") success, stdout, stderr = self._execute_shell_command(fallback_command, timeout) - + if success: self._log_warning("Fallback command succeeded (primary tool may be missing or misconfigured)") return stdout @@ -215,7 +216,7 @@ def execute_command( raise RuntimeError( f"Both primary and fallback commands failed.\n" f"Primary: {command}\n" - f"Primary error: {stderr}\n" + f"Primary error: {primary_stderr}\n" f"Fallback: {fallback_command}\n" f"Fallback error: {stderr}" ) diff --git a/src/madengine/utils/session_tracker.py b/src/madengine/utils/session_tracker.py index 6ddd1d92..4449e496 100644 --- a/src/madengine/utils/session_tracker.py +++ b/src/madengine/utils/session_tracker.py @@ -47,11 +47,12 @@ def start_session(self) -> int: The starting row number (number of rows in CSV before this session) """ if self.perf_csv_path.exists(): - # Count existing rows (excluding header) + # Count existing data rows (excluding header and blank lines) with open(self.perf_csv_path, 'r') as f: lines = f.readlines() + non_empty = [l for l in lines if l.strip()] # Subtract 1 for header row - self.session_start_row = max(0, len(lines) - 1) + self.session_start_row = max(0, len(non_empty) - 1) else: # No existing file, start at 0 self.session_start_row = 0 @@ -61,15 +62,6 @@ def start_session(self) -> int: return self.session_start_row - def get_session_start(self) -> Optional[int]: - """ - Get the session start row. - - Returns: - Session start row number, or None if session not started - """ - return self.session_start_row - def get_session_row_count(self) -> int: """ Get the number of rows added during this session. @@ -85,7 +77,8 @@ def get_session_row_count(self) -> int: with open(self.perf_csv_path, 'r') as f: lines = f.readlines() - current_row_count = max(0, len(lines) - 1) # Exclude header + non_empty = [l for l in lines if l.strip()] + current_row_count = max(0, len(non_empty) - 1) # Exclude header return current_row_count - self.session_start_row @@ -101,23 +94,6 @@ def _save_marker(self, start_row: int): with open(self.marker_file, 'w') as f: f.write(str(start_row)) - def load_marker(self) -> Optional[int]: - """ - Load session start marker from file. - - Uses the marker file path from this instance's perf_csv_path. - - Returns: - Session start row, or None if file doesn't exist - """ - if self.marker_file.exists(): - try: - with open(self.marker_file, 'r') as f: - return int(f.read().strip()) - except (ValueError, IOError): - return None - return None - def cleanup_marker(self): """ Remove session marker file for this instance. diff --git a/tests/integration/test_docker_integration.py b/tests/integration/test_docker_integration.py index 14041e3e..a7421d6b 100644 --- a/tests/integration/test_docker_integration.py +++ b/tests/integration/test_docker_integration.py @@ -8,6 +8,7 @@ # built-in modules import os import json +import shlex import tempfile import unittest.mock from unittest.mock import patch, MagicMock, mock_open @@ -169,8 +170,8 @@ def test_get_build_arg_with_context_args( result = builder.get_build_arg() - assert "--build-arg ARG1='value1'" in result - assert "--build-arg ARG2='value2'" in result + assert f"--build-arg ARG1={shlex.quote('value1')}" in result + assert f"--build-arg ARG2={shlex.quote('value2')}" in result @patch.object(Context, "get_gpu_vendor", return_value="AMD") @patch.object(Context, "get_system_ngpus", return_value=1) @@ -188,7 +189,7 @@ def test_get_build_arg_with_run_args( run_build_arg = {"RUNTIME_ARG": "runtime_value"} result = builder.get_build_arg(run_build_arg) - assert "--build-arg RUNTIME_ARG='runtime_value'" in result + assert f"--build-arg RUNTIME_ARG={shlex.quote('runtime_value')}" in result @patch.object(Context, "get_gpu_vendor", return_value="AMD") @patch.object(Context, "get_system_ngpus", return_value=1) @@ -207,8 +208,8 @@ def test_get_build_arg_with_both_args( run_build_arg = {"RUNTIME_ARG": "runtime_value"} result = builder.get_build_arg(run_build_arg) - assert "--build-arg CONTEXT_ARG='context_value'" in result - assert "--build-arg RUNTIME_ARG='runtime_value'" in result + assert f"--build-arg CONTEXT_ARG={shlex.quote('context_value')}" in result + assert f"--build-arg RUNTIME_ARG={shlex.quote('runtime_value')}" in result @patch.object(Context, "get_gpu_vendor", return_value="AMD") @patch.object(Context, "get_system_ngpus", return_value=1) diff --git a/tests/integration/test_errors.py b/tests/integration/test_errors.py index e325e6ac..c0a88876 100644 --- a/tests/integration/test_errors.py +++ b/tests/integration/test_errors.py @@ -126,7 +126,7 @@ def test_error_logging_integration(self): def test_error_context_serialization(self): """Error context can be serialized for logging.""" - from madengine.core.errors import RuntimeError + from madengine.core.errors import ExecutionError context = create_error_context( operation="model_execution", @@ -136,7 +136,7 @@ def test_error_context_serialization(self): node_id="worker-node-01", additional_info={"container_id": "abc123", "gpu_count": 2}, ) - error = RuntimeError("Model execution failed", context=context) + error = ExecutionError("Model execution failed", context=context) data = json.dumps(error.context.__dict__, default=str) assert "model_execution" in data and "ContainerRunner" in data and "abc123" in data @@ -184,28 +184,28 @@ def test_error_hierarchy_consistency(self): """All error types inherit MADEngineError and have context/category/recoverable.""" from madengine.core.errors import ( ValidationError, - ConnectionError, + NetworkError, AuthenticationError, - RuntimeError, + ExecutionError, BuildError, DiscoveryError, OrchestrationError, RunnerError, ConfigurationError, - TimeoutError, + DeploymentTimeoutError, ) for error_class in [ ValidationError, - ConnectionError, + NetworkError, AuthenticationError, - RuntimeError, + ExecutionError, BuildError, DiscoveryError, OrchestrationError, RunnerError, ConfigurationError, - TimeoutError, + DeploymentTimeoutError, ]: err = error_class("Test error message") assert isinstance(err, MADEngineError) @@ -245,9 +245,9 @@ def test_error_suggestions_and_recovery(self): def test_nested_error_handling(self): """Nested errors with cause chain are handled.""" - from madengine.core.errors import RuntimeError as MADRuntimeError, OrchestrationError + from madengine.core.errors import ExecutionError as MADRuntimeError, OrchestrationError, NetworkError - orig = ConnectionError("Network timeout") + orig = NetworkError("Network timeout") runtime = MADRuntimeError("Operation failed", cause=orig) final = OrchestrationError("Orchestration failed", cause=runtime) assert final.cause == runtime and runtime.cause == orig diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 00000000..bd34e14d --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,114 @@ +"""Unit tests for madengine.core.auth module.""" + +import os +from unittest.mock import mock_open, patch + +from madengine.core.auth import load_credentials + + +class TestLoadCredentials: + """Tests for load_credentials().""" + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"dockerhub": {"username": "user", "password": "pass"}}', + ) + def test_load_credentials_from_file(self, mock_file, mock_exists): + """Valid credential.json is loaded and returned.""" + result = load_credentials() + assert result is not None + assert "dockerhub" in result + assert result["dockerhub"]["username"] == "user" + assert result["dockerhub"]["password"] == "pass" + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict(os.environ, {}, clear=True) + def test_load_credentials_no_file_no_env(self, mock_exists): + """Returns None when no credential file and no env vars.""" + result = load_credentials() + assert result is None + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch("builtins.open", new_callable=mock_open, read_data="not valid json{{{") + def test_load_credentials_malformed_json(self, mock_file, mock_exists): + """Malformed credential.json is handled gracefully (returns None).""" + # The function logs the error via handle_error but does not re-raise + result = load_credentials() + # credentials should be None since the file parse failed and no env vars + assert result is None + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict( + os.environ, + {"MAD_DOCKERHUB_USER": "envuser", "MAD_DOCKERHUB_PASSWORD": "envpass"}, + clear=True, + ) + def test_load_credentials_env_vars_only(self, mock_exists): + """Credentials from env vars when no file exists.""" + result = load_credentials() + assert result is not None + assert "dockerhub" in result + assert result["dockerhub"]["username"] == "envuser" + assert result["dockerhub"]["password"] == "envpass" + assert "repository" not in result["dockerhub"] + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"dockerhub": {"username": "fileuser", "password": "filepass"}}', + ) + @patch.dict( + os.environ, + {"MAD_DOCKERHUB_USER": "envuser", "MAD_DOCKERHUB_PASSWORD": "envpass"}, + clear=True, + ) + def test_load_credentials_env_overrides_file(self, mock_file, mock_exists): + """Env vars override file credentials for dockerhub key.""" + result = load_credentials() + assert result is not None + assert result["dockerhub"]["username"] == "envuser" + assert result["dockerhub"]["password"] == "envpass" + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict( + os.environ, + { + "MAD_DOCKERHUB_USER": "envuser", + "MAD_DOCKERHUB_PASSWORD": "envpass", + "MAD_DOCKERHUB_REPO": "myrepo/images", + }, + clear=True, + ) + def test_load_credentials_env_with_repo(self, mock_exists): + """MAD_DOCKERHUB_REPO is included when set.""" + result = load_credentials() + assert result is not None + assert result["dockerhub"]["repository"] == "myrepo/images" + + @patch("madengine.core.auth.os.path.exists", return_value=False) + @patch.dict( + os.environ, + {"MAD_DOCKERHUB_USER": "envuser"}, + clear=True, + ) + def test_load_credentials_env_user_only_no_password(self, mock_exists): + """Only MAD_DOCKERHUB_USER without PASSWORD does not create dockerhub entry.""" + result = load_credentials() + # Without both user and password, dockerhub credentials are not created + assert result is None + + @patch("madengine.core.auth.os.path.exists", return_value=True) + @patch( + "builtins.open", + new_callable=mock_open, + read_data='{"custom_registry": {"token": "abc123"}}', + ) + def test_load_credentials_non_dockerhub_registry(self, mock_file, mock_exists): + """Non-dockerhub registries in credential.json are preserved.""" + result = load_credentials() + assert result is not None + assert "custom_registry" in result + assert result["custom_registry"]["token"] == "abc123" diff --git a/tests/unit/test_deployment.py b/tests/unit/test_deployment.py index a71c75e8..d51d94b9 100644 --- a/tests/unit/test_deployment.py +++ b/tests/unit/test_deployment.py @@ -85,6 +85,14 @@ def test_false_for_rocm_trace_lite(self): class TestIsRocprofv3Available: """is_rocprofv3_available (mocked subprocess).""" + def setup_method(self): + # Clear the lru_cache so each test starts with a fresh result + is_rocprofv3_available.cache_clear() + + def teardown_method(self): + # Restore clean cache state after each test + is_rocprofv3_available.cache_clear() + def test_returns_true_when_help_succeeds(self): with patch("madengine.deployment.common.subprocess.run") as m: m.return_value = MagicMock(returncode=0) diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py index 1fa808e4..45422c34 100644 --- a/tests/unit/test_error_handling.py +++ b/tests/unit/test_error_handling.py @@ -24,16 +24,15 @@ ErrorContext, MADEngineError, ValidationError, - ConnectionError, + NetworkError, AuthenticationError, ExecutionError, - RuntimeError, # Backward compatibility alias BuildError, DiscoveryError, OrchestrationError, RunnerError, ConfigurationError, - TimeoutError, + DeploymentTimeoutError, ErrorHandler, set_error_handler, get_error_handler, @@ -87,7 +86,7 @@ def test_base_madengine_error(self): @pytest.mark.parametrize("error_class,category,recoverable,message", [ (ValidationError, ErrorCategory.VALIDATION, True, "Invalid input"), - (ConnectionError, ErrorCategory.CONNECTION, True, "Connection failed"), + (NetworkError, ErrorCategory.CONNECTION, True, "Connection failed"), (BuildError, ErrorCategory.BUILD, False, "Build failed"), (RunnerError, ErrorCategory.RUNNER, True, "Runner execution failed"), (AuthenticationError, ErrorCategory.AUTHENTICATION, True, "Auth failed"), @@ -110,9 +109,9 @@ def test_error_with_cause(self): assert mad_error.cause == original_error assert str(mad_error) == "Runtime failure" - def test_backward_compatibility_alias(self): - """Test that RuntimeError alias still works.""" - error = RuntimeError("Test error") + def test_execution_error_is_mad_engine_error(self): + """Test that ExecutionError is a MADEngineError.""" + error = ExecutionError("Test error") assert isinstance(error, ExecutionError) assert isinstance(error, MADEngineError)