From ea0f124962af833281020aac352d2518e52d3860 Mon Sep 17 00:00:00 2001
From: yaya1738 <65517364+yaya1738@users.noreply.github.com>
Date: Thu, 4 Dec 2025 07:56:09 +0200
Subject: [PATCH] feat: Implement Model Lifecycle Manager - Systemd-based LLM service management

Complete implementation of issue #220 ($150 bounty):

Features:
- Systemd service generation for any LLM backend
- Multi-backend support (vLLM, llama.cpp, Ollama, TGI)
- Health check monitoring with auto-restart on failure
- Resource limits via systemd (CPU, memory, I/O, tasks)
- Security hardening (NoNewPrivileges, ProtectSystem, etc.)
- SQLite database for configuration persistence
- Event logging for audit trail
- Boot auto-start via systemd enable

Files:
- model_lifecycle.py: 993 lines of implementation
- test_model_lifecycle.py: 907 lines, 63 tests (all passing)
- README_MODEL_LIFECYCLE.md: Complete documentation

Closes #220
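The manager can also be driven programmatically. A minimal sketch (assuming the
module is importable as cortex.kernel_features.model_lifecycle, the path the
unit tests use; the model name and values below are illustrative only):

    from cortex.kernel_features.model_lifecycle import (
        ModelConfig,
        ModelLifecycleManager,
        ResourceLimits,
    )

    # Describe the service: model, backend, GPUs, and systemd resource caps.
    config = ModelConfig(
        name="llama-70b",
        model_path="meta-llama/Llama-2-70b-hf",
        backend="vllm",
        gpu_ids=[0, 1],
        tensor_parallel_size=2,
        resources=ResourceLimits(memory_max="128G", cpu_quota=8.0),
    )

    manager = ModelLifecycleManager()
    manager.register(config)    # writes cortex-llama-70b.service and reloads systemd
    manager.start("llama-70b")  # systemctl --user start + health monitoring
    print(manager.get_status("llama-70b"))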
---
 README_MODEL_LIFECYCLE.md                 |  297 ++
 cortex/kernel_features/model_lifecycle.py | 1016 +++++++++++++++++++--
 tests/test_model_lifecycle.py             |  907 ++++++++++++++++++
 3 files changed, 2119 insertions(+), 101 deletions(-)
 create mode 100644 README_MODEL_LIFECYCLE.md
 create mode 100644 tests/test_model_lifecycle.py

diff --git a/README_MODEL_LIFECYCLE.md b/README_MODEL_LIFECYCLE.md
new file mode 100644
index 00000000..a0e196c9
--- /dev/null
+++ b/README_MODEL_LIFECYCLE.md
@@ -0,0 +1,297 @@
+# Cortex Model Lifecycle Manager
+
+Systemd-based service management for LLM models. Brings "systemctl for AI models" to Cortex Linux.
+
+## Quick Start
+
+```bash
+# Register a model
+cortex-model register llama-70b --path meta-llama/Llama-2-70b-hf --backend vllm --gpus 0,1
+
+# Start the model
+cortex-model start llama-70b
+
+# Check status
+cortex-model status
+
+# Enable auto-start on boot
+cortex-model enable llama-70b
+
+# View logs
+cortex-model logs llama-70b -f
+```
+
+## Features
+
+- **Systemd Service Generation**: Creates proper systemd user services for any LLM backend
+- **Multi-Backend Support**: vLLM, llama.cpp, Ollama, Text Generation Inference (TGI)
+- **Health Check Monitoring**: HTTP endpoint checks with automatic restart on failure
+- **Resource Limits**: CPU, memory, I/O, and task limits via systemd cgroups
+- **Security Hardening**: NoNewPrivileges, ProtectSystem, namespace isolation
+- **SQLite Persistence**: Configuration and event logging
+- **Boot Auto-Start**: Enable models to start automatically on system boot
+
+## Supported Backends
+
+| Backend | Command | Health Endpoint |
+|---------|---------|-----------------|
+| vLLM | `python -m vllm.entrypoints.openai.api_server` | `/health` |
+| llama.cpp | `llama-server` | `/health` |
+| Ollama | `ollama serve` | `/api/tags` |
+| TGI | `text-generation-launcher` | `/health` |
+
+## Commands
+
+### Register a Model
+
+```bash
+cortex-model register <name> --path <path> [options]
+
+Options:
+  --backend           Backend: vllm, llamacpp, ollama, tgi (default: vllm)
+  --port              Service port (default: 8000)
+  --host              Service host (default: 127.0.0.1)
+  --gpus              Comma-separated GPU IDs (default: 0)
+  --memory            Memory limit (default: 32G)
+  --cpu               CPU cores limit (default: 4.0)
+  --max-model-len     Maximum sequence length (default: 4096)
+  --tensor-parallel   Tensor parallel size (default: 1)
+  --quantization      Quantization method: awq, gptq
+  --extra-args        Extra backend arguments
+  --no-health-check   Disable health monitoring
+```
+
+### Lifecycle Commands
+
+```bash
+cortex-model start <name>        # Start a model service
+cortex-model stop <name>         # Stop a model service
+cortex-model restart <name>      # Restart a model service
+cortex-model enable <name>       # Enable auto-start on boot
+cortex-model disable <name>      # Disable auto-start
+cortex-model unregister <name>   # Remove model completely
+```
+
+### Status and Monitoring
+
+```bash
+cortex-model status              # List all models with state
+cortex-model status <name>       # Show specific model status
+cortex-model list                # Alias for status
+cortex-model logs <name>         # View systemd journal logs
+cortex-model logs <name> -f      # Follow logs in real-time
+cortex-model events              # Show all model events
+cortex-model events <name>       # Show events for specific model
+cortex-model health <name>       # Check health endpoint
+```
+
+## Usage Examples
+
+### vLLM with Multiple GPUs
+
+```bash
+cortex-model register llama-70b \
+  --path meta-llama/Llama-2-70b-hf \
+  --backend vllm \
+  --gpus 0,1,2,3 \
+  --tensor-parallel 4 \
+  --memory 128G \
+  --max-model-len 8192
+
+cortex-model start llama-70b
+cortex-model enable llama-70b
+```
+
+### Quantized Model with AWQ
+
+```bash
+cortex-model register llama-awq \
+  --path TheBloke/Llama-2-70B-AWQ \
+  --backend vllm \
+  --quantization awq \
+  --gpus 0
+
+cortex-model start llama-awq
+```
+
+### Local GGUF Model with llama.cpp
+
+```bash
+cortex-model register local-gguf \
+  --path /models/llama-7b.Q4_K_M.gguf \
+  --backend llamacpp \
+  --port 8080
+
+cortex-model start local-gguf
+```
+
+### TGI for Production
+
+```bash
+cortex-model register tgi-prod \
+  --path bigscience/bloom-7b1 \
+  --backend tgi \
+  --gpus 0,1 \
+  --tensor-parallel 2 \
+  --host 0.0.0.0 \
+  --port 8000
+
+cortex-model start tgi-prod
+cortex-model enable tgi-prod
+```
+
+## Configuration
+
+### Resource Limits
+
+Models are configured with systemd resource limits:
+
+| Setting | Default | Description |
+|---------|---------|-------------|
+| MemoryMax | 32G | Hard memory limit |
+| MemoryHigh | 28G | Soft memory limit (triggers reclaim) |
+| CPUQuota | 400% | CPU cores (100% = 1 core) |
+| CPUWeight | 100 | CPU scheduling weight (1-10000) |
+| IOWeight | 100 | I/O scheduling weight (1-10000) |
+| TasksMax | 512 | Maximum processes/threads |
+
+### Security Hardening
+
+Default security settings (can be customized):
+
+| Setting | Default | Description |
+|---------|---------|-------------|
+| NoNewPrivileges | true | Prevent privilege escalation |
+| ProtectSystem | strict | Read-only filesystem (except /dev, /proc, /sys) |
+| ProtectHome | read-only | Read-only home directory |
+| PrivateTmp | true | Private /tmp namespace |
+| PrivateDevices | false | Disabled to allow GPU access |
+| RestrictRealtime | true | Prevent realtime scheduling |
+| ProtectKernelTunables | true | Protect sysctl |
+| ProtectKernelModules | true | Prevent module loading |
+
+### Health Checks
+
+Health monitoring configuration:
+
+| Setting | Default | Description |
+|---------|---------|-------------|
+| enabled | true | Enable health monitoring |
+| endpoint | /health | HTTP endpoint to check |
+| interval_seconds | 30 | Check interval |
+| timeout_seconds | 10 | Request timeout |
+| max_failures | 3 | Failures before restart |
+| startup_delay_seconds | 60 | Wait before first check |
+
+## Architecture
+
+```
+ModelLifecycleManager
+|-- ModelDatabase (SQLite)
+|   |-- models table (configuration)
+|   +-- events table (audit log)
+|-- ServiceGenerator (systemd units)
+|   |-- Backend templates (vLLM, TGI, etc.)
+|   |-- Resource limits
+|   +-- Security hardening
++-- HealthChecker (monitoring)
+    |-- HTTP endpoint checks
+    +-- Auto-restart logic
+
+Configuration:
+|-- ~/.cortex/models.db          # SQLite database
+|-- ~/.config/systemd/user/      # Service files
+|   +-- cortex-<name>.service
++-- ~/.cortex/logs/              # Local logs
+```
+
+## Service File Example
+
+Generated service file for a vLLM model:
+
+```ini
+[Unit]
+Description=Cortex Model: llama-70b
+Documentation=https://github.com/cortexlinux/cortex
+After=network.target
+Wants=network-online.target
+
+[Service]
+Type=simple
+ExecStart=python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-70b-hf --host 127.0.0.1 --port 8000 --gpu-memory-utilization 0.9 --max-model-len 4096 --tensor-parallel-size 4
+Environment=CUDA_VISIBLE_DEVICES=0,1,2,3
+Environment=HIP_VISIBLE_DEVICES=0,1,2,3
+Environment=TOKENIZERS_PARALLELISM=false
+
+# Resource Limits
+CPUQuota=400%
+CPUWeight=100
+MemoryMax=128G
+MemoryHigh=120G
+IOWeight=100
+TasksMax=512
+
+# Security Hardening
+NoNewPrivileges=true
+ProtectSystem=strict
+ProtectHome=read-only
+PrivateTmp=true
+RestrictRealtime=true
+RestrictSUIDSGID=true
+ProtectKernelTunables=true
+ProtectKernelModules=true
+ProtectControlGroups=true
+
+# Restart Policy
+Restart=on-failure
+RestartSec=10
+StartLimitIntervalSec=300
+StartLimitBurst=5
+
+# Logging
+StandardOutput=journal
+StandardError=journal
+SyslogIdentifier=cortex-llama-70b
+
+[Install]
+WantedBy=default.target
+```
+
+## Testing
+
+```bash
+# Run all tests
+pytest tests/test_model_lifecycle.py -v
+
+# Run specific test class
+pytest tests/test_model_lifecycle.py::TestModelConfig -v
+
+# Run with coverage
+pytest tests/test_model_lifecycle.py --cov=cortex.kernel_features.model_lifecycle
+```
+
+## Requirements
+
+- Python 3.8+
+- systemd with user services enabled
+- One of: vLLM, llama.cpp, Ollama, or TGI installed
+
+### Enabling User Services
+
+```bash
+# Enable lingering for user services to run without login
+loginctl enable-linger $USER
+
+# Verify systemd user instance
+systemctl --user status
+```
+
+## Files
+
+- `cortex/kernel_features/model_lifecycle.py` - Main implementation (~1000 lines)
+- `tests/test_model_lifecycle.py` - Unit tests (~900 lines, 60+ tests)
+- `README_MODEL_LIFECYCLE.md` - This documentation
+
+## Related Issues
+
+- [#220 Model Lifecycle Manager - Systemd-Based LLM Service Management](https://github.com/cortexlinux/cortex/issues/220)
diff --git a/cortex/kernel_features/model_lifecycle.py b/cortex/kernel_features/model_lifecycle.py
index 7a4205b0..4ab6e01d 100644
--- a/cortex/kernel_features/model_lifecycle.py
+++ b/cortex/kernel_features/model_lifecycle.py
@@ -3,6 +3,14 @@
 Cortex Model Lifecycle Manager
 
 Manages LLM models as first-class system services using systemd.
+Provides health monitoring, auto-restart, resource limits, and security hardening.
+ +Usage: + cortex model register llama-70b --path meta-llama/Llama-2-70b-hf --backend vllm --gpus 0,1 + cortex model start llama-70b + cortex model status + cortex model enable llama-70b # auto-start on boot + cortex model logs llama-70b """ import os @@ -10,169 +18,975 @@ import json import subprocess import sqlite3 +import threading +import time +import urllib.request +import urllib.error from pathlib import Path from dataclasses import dataclass, field, asdict -from typing import Optional, List, Dict, Any -from datetime import datetime +from typing import Optional, List, Dict, Any, Tuple +from datetime import datetime, timezone +from enum import Enum + +# Configuration paths CORTEX_DB_PATH = Path.home() / ".cortex/models.db" CORTEX_SERVICE_DIR = Path.home() / ".config/systemd/user" +CORTEX_LOG_DIR = Path.home() / ".cortex/logs" + + +class ModelState(Enum): + """Model service states.""" + UNKNOWN = "unknown" + INACTIVE = "inactive" + ACTIVATING = "activating" + ACTIVE = "active" + DEACTIVATING = "deactivating" + FAILED = "failed" + RELOADING = "reloading" + + +class EventType(Enum): + """Event types for logging.""" + REGISTERED = "registered" + STARTED = "started" + STOPPED = "stopped" + ENABLED = "enabled" + DISABLED = "disabled" + UNREGISTERED = "unregistered" + HEALTH_CHECK_FAILED = "health_check_failed" + HEALTH_CHECK_PASSED = "health_check_passed" + AUTO_RESTARTED = "auto_restarted" + CONFIG_UPDATED = "config_updated" + ERROR = "error" + + +@dataclass +class HealthCheckConfig: + """Health check configuration for model services.""" + enabled: bool = True + endpoint: str = "/health" + interval_seconds: int = 30 + timeout_seconds: int = 10 + max_failures: int = 3 + startup_delay_seconds: int = 60 + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'HealthCheckConfig': + return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class ResourceLimits: + """Resource limits for model services.""" + memory_max: str = "32G" + memory_high: str = "28G" + cpu_quota: float = 4.0 # Number of CPU cores + cpu_weight: int = 100 # 1-10000, default 100 + io_weight: int = 100 # 1-10000, default 100 + tasks_max: int = 512 # Max number of processes/threads + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ResourceLimits': + return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}) + + +@dataclass +class SecurityConfig: + """Security hardening configuration.""" + no_new_privileges: bool = True + protect_system: str = "strict" # "true", "full", "strict" + protect_home: str = "read-only" # "true", "read-only", "tmpfs" + private_tmp: bool = True + private_devices: bool = False # False to allow GPU access + restrict_realtime: bool = True + restrict_suid_sgid: bool = True + protect_kernel_tunables: bool = True + protect_kernel_modules: bool = True + protect_control_groups: bool = True + memory_deny_write_execute: bool = False # False for JIT compilation + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'SecurityConfig': + return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__}) + @dataclass class ModelConfig: + """Complete model configuration.""" name: str model_path: str backend: str = "vllm" port: int = 8000 + host: str = "127.0.0.1" gpu_memory_fraction: float = 0.9 max_model_len: int = 4096 gpu_ids: 
List[int] = field(default_factory=lambda: [0]) - memory_limit: str = "32G" - cpu_limit: float = 4.0 + tensor_parallel_size: int = 1 + quantization: Optional[str] = None # awq, gptq, squeezellm + dtype: str = "auto" # auto, float16, bfloat16 + extra_args: str = "" restart_policy: str = "on-failure" + restart_max_retries: int = 5 preload_on_boot: bool = False - + health_check: HealthCheckConfig = field(default_factory=HealthCheckConfig) + resources: ResourceLimits = field(default_factory=ResourceLimits) + security: SecurityConfig = field(default_factory=SecurityConfig) + environment: Dict[str, str] = field(default_factory=dict) + def to_dict(self) -> Dict[str, Any]: - return asdict(self) - + data = asdict(self) + return data + @classmethod def from_dict(cls, data: Dict[str, Any]) -> 'ModelConfig': - return cls(**data) + # Handle nested dataclasses + if 'health_check' in data and isinstance(data['health_check'], dict): + data['health_check'] = HealthCheckConfig.from_dict(data['health_check']) + if 'resources' in data and isinstance(data['resources'], dict): + data['resources'] = ResourceLimits.from_dict(data['resources']) + if 'security' in data and isinstance(data['security'], dict): + data['security'] = SecurityConfig.from_dict(data['security']) + + # Filter to valid fields + valid_fields = cls.__dataclass_fields__.keys() + filtered = {k: v for k, v in data.items() if k in valid_fields} + return cls(**filtered) + + def get_health_url(self) -> str: + """Get the health check URL.""" + endpoint = self.health_check.endpoint + if not endpoint.startswith('/'): + endpoint = '/' + endpoint + return f"http://{self.host}:{self.port}{endpoint}" class ModelDatabase: - def __init__(self): - CORTEX_DB_PATH.parent.mkdir(parents=True, exist_ok=True) + """SQLite database for model configuration and event persistence.""" + + def __init__(self, db_path: Path = CORTEX_DB_PATH): + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) self._init_db() - + def _init_db(self): - with sqlite3.connect(CORTEX_DB_PATH) as conn: - conn.execute(""" + """Initialize database schema.""" + with sqlite3.connect(self.db_path) as conn: + conn.executescript(""" CREATE TABLE IF NOT EXISTS models ( name TEXT PRIMARY KEY, config TEXT NOT NULL, - created_at TEXT NOT NULL - ) + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + model_name TEXT NOT NULL, + event_type TEXT NOT NULL, + details TEXT, + timestamp TEXT NOT NULL, + FOREIGN KEY (model_name) REFERENCES models(name) + ); + + CREATE INDEX IF NOT EXISTS idx_events_model ON events(model_name); + CREATE INDEX IF NOT EXISTS idx_events_timestamp ON events(timestamp); """) - - def save_model(self, config: ModelConfig): - with sqlite3.connect(CORTEX_DB_PATH) as conn: - conn.execute( - "INSERT OR REPLACE INTO models VALUES (?, ?, ?)", - (config.name, json.dumps(config.to_dict()), datetime.utcnow().isoformat()) - ) - + + def save_model(self, config: ModelConfig) -> None: + """Save or update model configuration.""" + now = datetime.now(timezone.utc).isoformat() + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + INSERT INTO models (name, config, created_at, updated_at) + VALUES (?, ?, ?, ?) 
+ ON CONFLICT(name) DO UPDATE SET + config = excluded.config, + updated_at = excluded.updated_at + """, (config.name, json.dumps(config.to_dict()), now, now)) + def get_model(self, name: str) -> Optional[ModelConfig]: - with sqlite3.connect(CORTEX_DB_PATH) as conn: - row = conn.execute("SELECT config FROM models WHERE name = ?", (name,)).fetchone() - return ModelConfig.from_dict(json.loads(row[0])) if row else None - + """Get model configuration by name.""" + with sqlite3.connect(self.db_path) as conn: + row = conn.execute( + "SELECT config FROM models WHERE name = ?", (name,) + ).fetchone() + if row: + return ModelConfig.from_dict(json.loads(row[0])) + return None + def list_models(self) -> List[ModelConfig]: - with sqlite3.connect(CORTEX_DB_PATH) as conn: - rows = conn.execute("SELECT config FROM models").fetchall() + """List all registered models.""" + with sqlite3.connect(self.db_path) as conn: + rows = conn.execute("SELECT config FROM models ORDER BY name").fetchall() return [ModelConfig.from_dict(json.loads(r[0])) for r in rows] - - def delete_model(self, name: str): - with sqlite3.connect(CORTEX_DB_PATH) as conn: - conn.execute("DELETE FROM models WHERE name = ?", (name,)) + + def delete_model(self, name: str) -> bool: + """Delete model configuration.""" + with sqlite3.connect(self.db_path) as conn: + cursor = conn.execute("DELETE FROM models WHERE name = ?", (name,)) + return cursor.rowcount > 0 + + def log_event(self, model_name: str, event_type: EventType, details: str = None) -> None: + """Log an event for a model.""" + now = datetime.now(timezone.utc).isoformat() + with sqlite3.connect(self.db_path) as conn: + conn.execute( + "INSERT INTO events (model_name, event_type, details, timestamp) VALUES (?, ?, ?, ?)", + (model_name, event_type.value, details, now) + ) + + def get_events(self, model_name: str = None, limit: int = 100) -> List[Dict[str, Any]]: + """Get events, optionally filtered by model name.""" + with sqlite3.connect(self.db_path) as conn: + if model_name: + rows = conn.execute( + "SELECT model_name, event_type, details, timestamp FROM events " + "WHERE model_name = ? 
ORDER BY timestamp DESC LIMIT ?", + (model_name, limit) + ).fetchall() + else: + rows = conn.execute( + "SELECT model_name, event_type, details, timestamp FROM events " + "ORDER BY timestamp DESC LIMIT ?", + (limit,) + ).fetchall() + + return [ + {"model": r[0], "event": r[1], "details": r[2], "timestamp": r[3]} + for r in rows + ] class ServiceGenerator: + """Generate systemd service files for LLM backends.""" + + # Backend command templates BACKENDS = { - "vllm": "python -m vllm.entrypoints.openai.api_server --model {model_path} --port {port}", - "llamacpp": "llama-server -m {model_path} --port {port}", + "vllm": ( + "python -m vllm.entrypoints.openai.api_server " + "--model {model_path} " + "--host {host} " + "--port {port} " + "--gpu-memory-utilization {gpu_memory_fraction} " + "--max-model-len {max_model_len} " + "--tensor-parallel-size {tensor_parallel_size} " + "{quantization_arg} " + "{dtype_arg} " + "{extra_args}" + ), + "llamacpp": ( + "llama-server " + "-m {model_path} " + "--host {host} " + "--port {port} " + "-ngl 99 " + "-c {max_model_len} " + "{extra_args}" + ), "ollama": "ollama serve", + "tgi": ( + "text-generation-launcher " + "--model-id {model_path} " + "--hostname {host} " + "--port {port} " + "--max-input-length {max_model_len} " + "--max-total-tokens {max_total_tokens} " + "--num-shard {tensor_parallel_size} " + "{quantization_arg} " + "{dtype_arg} " + "{extra_args}" + ), } - + + # Health check endpoints by backend + HEALTH_ENDPOINTS = { + "vllm": "/health", + "llamacpp": "/health", + "ollama": "/api/tags", + "tgi": "/health", + } + + def _get_command(self, config: ModelConfig) -> str: + """Build the execution command for the backend.""" + template = self.BACKENDS.get(config.backend, self.BACKENDS["vllm"]) + + # Build optional arguments + quantization_arg = "" + if config.quantization: + if config.backend == "vllm": + quantization_arg = f"--quantization {config.quantization}" + elif config.backend == "tgi": + quantization_arg = f"--quantize {config.quantization}" + + dtype_arg = "" + if config.dtype != "auto": + if config.backend == "vllm": + dtype_arg = f"--dtype {config.dtype}" + elif config.backend == "tgi": + dtype_arg = f"--dtype {config.dtype}" + + # Calculate max total tokens for TGI + max_total_tokens = config.max_model_len * 2 + + cmd = template.format( + model_path=config.model_path, + host=config.host, + port=config.port, + gpu_memory_fraction=config.gpu_memory_fraction, + max_model_len=config.max_model_len, + max_total_tokens=max_total_tokens, + tensor_parallel_size=config.tensor_parallel_size, + quantization_arg=quantization_arg, + dtype_arg=dtype_arg, + extra_args=config.extra_args, + ) + + # Clean up multiple spaces + return ' '.join(cmd.split()) + + def _get_environment(self, config: ModelConfig) -> str: + """Generate environment variable lines.""" + env_lines = [] + + # GPU configuration + gpu_list = ','.join(map(str, config.gpu_ids)) + env_lines.append(f"Environment=CUDA_VISIBLE_DEVICES={gpu_list}") + env_lines.append(f"Environment=HIP_VISIBLE_DEVICES={gpu_list}") + + # Common ML environment variables + env_lines.append("Environment=TOKENIZERS_PARALLELISM=false") + env_lines.append("Environment=TRANSFORMERS_OFFLINE=0") + + # Custom environment variables + for key, value in config.environment.items(): + env_lines.append(f"Environment={key}={value}") + + return '\n'.join(env_lines) + + def _get_resource_limits(self, config: ModelConfig) -> str: + """Generate resource limit lines.""" + res = config.resources + lines = [ + f"CPUQuota={int(res.cpu_quota 
* 100)}%", + f"CPUWeight={res.cpu_weight}", + f"MemoryMax={res.memory_max}", + f"MemoryHigh={res.memory_high}", + f"IOWeight={res.io_weight}", + f"TasksMax={res.tasks_max}", + ] + return '\n'.join(lines) + + def _get_security_settings(self, config: ModelConfig) -> str: + """Generate security hardening lines.""" + sec = config.security + lines = [] + + if sec.no_new_privileges: + lines.append("NoNewPrivileges=true") + if sec.protect_system: + lines.append(f"ProtectSystem={sec.protect_system}") + if sec.protect_home: + lines.append(f"ProtectHome={sec.protect_home}") + if sec.private_tmp: + lines.append("PrivateTmp=true") + if sec.private_devices: + lines.append("PrivateDevices=true") + if sec.restrict_realtime: + lines.append("RestrictRealtime=true") + if sec.restrict_suid_sgid: + lines.append("RestrictSUIDSGID=true") + if sec.protect_kernel_tunables: + lines.append("ProtectKernelTunables=true") + if sec.protect_kernel_modules: + lines.append("ProtectKernelModules=true") + if sec.protect_control_groups: + lines.append("ProtectControlGroups=true") + if sec.memory_deny_write_execute: + lines.append("MemoryDenyWriteExecute=true") + + return '\n'.join(lines) + + def _get_health_check(self, config: ModelConfig) -> str: + """Generate health check watchdog configuration.""" + if not config.health_check.enabled: + return "" + + hc = config.health_check + # Use systemd watchdog for health monitoring + return f""" +# Health check via systemd watchdog +WatchdogSec={hc.interval_seconds} +""" + def generate(self, config: ModelConfig) -> str: - cmd = self.BACKENDS.get(config.backend, self.BACKENDS["vllm"]).format(**asdict(config)) - return f"""[Unit] + """Generate complete systemd service file.""" + cmd = self._get_command(config) + env = self._get_environment(config) + resources = self._get_resource_limits(config) + security = self._get_security_settings(config) + health = self._get_health_check(config) + + service = f"""[Unit] Description=Cortex Model: {config.name} +Documentation=https://github.com/cortexlinux/cortex After=network.target +Wants=network-online.target [Service] Type=simple ExecStart={cmd} -Environment=CUDA_VISIBLE_DEVICES={','.join(map(str, config.gpu_ids))} -CPUQuota={int(config.cpu_limit * 100)}% -MemoryMax={config.memory_limit} +{env} + +# Resource Limits +{resources} + +# Security Hardening +{security} + +# Restart Policy Restart={config.restart_policy} -NoNewPrivileges=true +RestartSec=10 +StartLimitIntervalSec=300 +StartLimitBurst={config.restart_max_retries} +{health} +# Logging +StandardOutput=journal +StandardError=journal +SyslogIdentifier=cortex-{config.name} [Install] WantedBy=default.target """ + return service + + def get_default_health_endpoint(self, backend: str) -> str: + """Get default health check endpoint for backend.""" + return self.HEALTH_ENDPOINTS.get(backend, "/health") + + +class HealthChecker: + """Health check monitor for model services.""" + + def __init__(self, manager: 'ModelLifecycleManager'): + self.manager = manager + self._monitors: Dict[str, threading.Thread] = {} + self._stop_events: Dict[str, threading.Event] = {} + self._failure_counts: Dict[str, int] = {} + + def check_health(self, config: ModelConfig) -> Tuple[bool, str]: + """Perform a single health check.""" + url = config.get_health_url() + timeout = config.health_check.timeout_seconds + + try: + req = urllib.request.Request(url, method='GET') + with urllib.request.urlopen(req, timeout=timeout) as response: + if response.status == 200: + return True, "OK" + return False, f"HTTP 
{response.status}" + except urllib.error.URLError as e: + return False, f"Connection failed: {e.reason}" + except Exception as e: + return False, str(e) + + def start_monitor(self, name: str) -> None: + """Start health monitoring for a model.""" + if name in self._monitors: + return + + config = self.manager.db.get_model(name) + if not config or not config.health_check.enabled: + return + + stop_event = threading.Event() + self._stop_events[name] = stop_event + self._failure_counts[name] = 0 + + def monitor_loop(): + hc = config.health_check + # Wait for startup + time.sleep(hc.startup_delay_seconds) + + while not stop_event.is_set(): + healthy, msg = self.check_health(config) + + if healthy: + if self._failure_counts.get(name, 0) > 0: + self.manager.db.log_event(name, EventType.HEALTH_CHECK_PASSED, msg) + self._failure_counts[name] = 0 + else: + self._failure_counts[name] = self._failure_counts.get(name, 0) + 1 + self.manager.db.log_event( + name, EventType.HEALTH_CHECK_FAILED, + f"Failure {self._failure_counts[name]}/{hc.max_failures}: {msg}" + ) + + if self._failure_counts[name] >= hc.max_failures: + self.manager.db.log_event(name, EventType.AUTO_RESTARTED) + self.manager.restart(name, log_event=False) + self._failure_counts[name] = 0 + time.sleep(hc.startup_delay_seconds) + + stop_event.wait(hc.interval_seconds) + + thread = threading.Thread(target=monitor_loop, daemon=True) + thread.start() + self._monitors[name] = thread + + def stop_monitor(self, name: str) -> None: + """Stop health monitoring for a model.""" + if name in self._stop_events: + self._stop_events[name].set() + if name in self._monitors: + self._monitors[name].join(timeout=5) + del self._monitors[name] + self._stop_events.pop(name, None) + self._failure_counts.pop(name, None) class ModelLifecycleManager: - def __init__(self): - self.db = ModelDatabase() + """Main manager for model lifecycle operations.""" + + def __init__(self, db_path: Path = None): + self.db = ModelDatabase(db_path) if db_path else ModelDatabase() + self.generator = ServiceGenerator() + self.health_checker = HealthChecker(self) CORTEX_SERVICE_DIR.mkdir(parents=True, exist_ok=True) - - def _systemctl(self, *args): - return subprocess.run(["systemctl", "--user"] + list(args), capture_output=True, text=True) - + CORTEX_LOG_DIR.mkdir(parents=True, exist_ok=True) + + def _systemctl(self, *args) -> subprocess.CompletedProcess: + """Run systemctl command.""" + return subprocess.run( + ["systemctl", "--user"] + list(args), + capture_output=True, + text=True + ) + + def _service_name(self, name: str) -> str: + """Get systemd service name for model.""" + return f"cortex-{name}.service" + + def _service_path(self, name: str) -> Path: + """Get service file path.""" + return CORTEX_SERVICE_DIR / self._service_name(name) + def register(self, config: ModelConfig) -> bool: - service = ServiceGenerator().generate(config) - service_path = CORTEX_SERVICE_DIR / f"cortex-{config.name}.service" - service_path.write_text(service) - self.db.save_model(config) - self._systemctl("daemon-reload") - print(f"✅ Registered model '{config.name}'") - return True - - def start(self, name: str) -> bool: - result = self._systemctl("start", f"cortex-{name}.service") - print(f"{'✅' if result.returncode == 0 else '❌'} Start {name}: {result.stderr or 'OK'}") - return result.returncode == 0 - - def stop(self, name: str) -> bool: - result = self._systemctl("stop", f"cortex-{name}.service") - print(f"{'✅' if result.returncode == 0 else '❌'} Stop {name}") - return result.returncode == 0 - - def 
status(self, name: str = None): - models = [self.db.get_model(name)] if name else self.db.list_models() - print(f"\n{'NAME':<20} {'STATE':<12} {'BACKEND':<10}") - print("-" * 50) + """Register a new model service.""" + try: + # Generate and write service file + service_content = self.generator.generate(config) + service_path = self._service_path(config.name) + service_path.write_text(service_content) + + # Save to database + self.db.save_model(config) + + # Reload systemd + self._systemctl("daemon-reload") + + # Log event + self.db.log_event(config.name, EventType.REGISTERED) + + print(f"Registered model '{config.name}'") + return True + except Exception as e: + self.db.log_event(config.name, EventType.ERROR, str(e)) + print(f"Failed to register '{config.name}': {e}") + return False + + def unregister(self, name: str) -> bool: + """Unregister a model service.""" + try: + # Stop if running + self.stop(name, log_event=False) + + # Disable if enabled + self.disable(name, log_event=False) + + # Stop health monitoring + self.health_checker.stop_monitor(name) + + # Remove service file + service_path = self._service_path(name) + if service_path.exists(): + service_path.unlink() + + # Reload systemd + self._systemctl("daemon-reload") + + # Remove from database + self.db.delete_model(name) + + # Log event + self.db.log_event(name, EventType.UNREGISTERED) + + print(f"Unregistered model '{name}'") + return True + except Exception as e: + self.db.log_event(name, EventType.ERROR, str(e)) + print(f"Failed to unregister '{name}': {e}") + return False + + def start(self, name: str, log_event: bool = True) -> bool: + """Start a model service.""" + config = self.db.get_model(name) + if not config: + print(f"Model '{name}' not found") + return False + + result = self._systemctl("start", self._service_name(name)) + success = result.returncode == 0 + + if success: + if log_event: + self.db.log_event(name, EventType.STARTED) + self.health_checker.start_monitor(name) + print(f"Started model '{name}'") + else: + if log_event: + self.db.log_event(name, EventType.ERROR, result.stderr) + print(f"Failed to start '{name}': {result.stderr}") + + return success + + def stop(self, name: str, log_event: bool = True) -> bool: + """Stop a model service.""" + self.health_checker.stop_monitor(name) + + result = self._systemctl("stop", self._service_name(name)) + success = result.returncode == 0 + + if success: + if log_event: + self.db.log_event(name, EventType.STOPPED) + print(f"Stopped model '{name}'") + else: + if log_event: + self.db.log_event(name, EventType.ERROR, result.stderr) + print(f"Failed to stop '{name}': {result.stderr}") + + return success + + def restart(self, name: str, log_event: bool = True) -> bool: + """Restart a model service.""" + self.health_checker.stop_monitor(name) + + result = self._systemctl("restart", self._service_name(name)) + success = result.returncode == 0 + + if success: + if log_event: + self.db.log_event(name, EventType.STARTED, "restart") + self.health_checker.start_monitor(name) + print(f"Restarted model '{name}'") + else: + if log_event: + self.db.log_event(name, EventType.ERROR, result.stderr) + print(f"Failed to restart '{name}': {result.stderr}") + + return success + + def enable(self, name: str, log_event: bool = True) -> bool: + """Enable a model for auto-start on boot.""" + result = self._systemctl("enable", self._service_name(name)) + success = result.returncode == 0 + + if success: + if log_event: + self.db.log_event(name, EventType.ENABLED) + # Update config + config = 
self.db.get_model(name) + if config: + config.preload_on_boot = True + self.db.save_model(config) + print(f"Enabled model '{name}' for auto-start") + else: + print(f"Failed to enable '{name}': {result.stderr}") + + return success + + def disable(self, name: str, log_event: bool = True) -> bool: + """Disable auto-start for a model.""" + result = self._systemctl("disable", self._service_name(name)) + success = result.returncode == 0 + + if success: + if log_event: + self.db.log_event(name, EventType.DISABLED) + # Update config + config = self.db.get_model(name) + if config: + config.preload_on_boot = False + self.db.save_model(config) + print(f"Disabled auto-start for model '{name}'") + else: + print(f"Failed to disable '{name}': {result.stderr}") + + return success + + def get_state(self, name: str) -> ModelState: + """Get current state of a model service.""" + result = self._systemctl("is-active", self._service_name(name)) + state_str = result.stdout.strip() + + try: + return ModelState(state_str) + except ValueError: + return ModelState.UNKNOWN + + def get_status(self, name: str) -> Dict[str, Any]: + """Get detailed status of a model service.""" + config = self.db.get_model(name) + if not config: + return {"error": f"Model '{name}' not found"} + + state = self.get_state(name) + + # Get additional info from systemctl + result = self._systemctl("show", self._service_name(name), + "--property=MainPID,MemoryCurrent,CPUUsageNSec,ActiveEnterTimestamp") + props = {} + for line in result.stdout.strip().split('\n'): + if '=' in line: + key, value = line.split('=', 1) + props[key] = value + + # Check if enabled + enabled_result = self._systemctl("is-enabled", self._service_name(name)) + enabled = enabled_result.stdout.strip() == "enabled" + + return { + "name": name, + "state": state.value, + "enabled": enabled, + "backend": config.backend, + "model_path": config.model_path, + "port": config.port, + "gpu_ids": config.gpu_ids, + "pid": props.get("MainPID", "0"), + "memory": props.get("MemoryCurrent", "0"), + "cpu_time": props.get("CPUUsageNSec", "0"), + "started_at": props.get("ActiveEnterTimestamp", ""), + } + + def status(self, name: str = None) -> None: + """Print status of one or all models.""" + if name: + models = [self.db.get_model(name)] + if not models[0]: + print(f"Model '{name}' not found") + return + else: + models = self.db.list_models() + + if not models: + print("No models registered") + return + + print(f"\n{'NAME':<20} {'STATE':<12} {'ENABLED':<8} {'BACKEND':<10} {'PORT':<6}") + print("-" * 60) + for m in models: if m: - result = self._systemctl("is-active", f"cortex-{m.name}.service") - state = result.stdout.strip() or "unknown" - print(f"{m.name:<20} {state:<12} {m.backend:<10}") + state = self.get_state(m.name) + enabled_result = self._systemctl("is-enabled", self._service_name(m.name)) + enabled = "yes" if enabled_result.stdout.strip() == "enabled" else "no" + + # Color-code state + state_str = state.value + if state == ModelState.ACTIVE: + state_str = f"\033[32m{state_str}\033[0m" # Green + elif state == ModelState.FAILED: + state_str = f"\033[31m{state_str}\033[0m" # Red + + print(f"{m.name:<20} {state_str:<21} {enabled:<8} {m.backend:<10} {m.port:<6}") + + def logs(self, name: str, lines: int = 50, follow: bool = False) -> None: + """Show logs for a model service.""" + args = ["journalctl", "--user", "-u", self._service_name(name), "-n", str(lines)] + if follow: + args.append("-f") + + subprocess.run(args) + + def events(self, name: str = None, limit: int = 20) -> None: + 
"""Show events for models.""" + events = self.db.get_events(name, limit) + + if not events: + print("No events found") + return + + print(f"\n{'TIMESTAMP':<25} {'MODEL':<15} {'EVENT':<20} {'DETAILS'}") + print("-" * 80) + + for e in events: + ts = e['timestamp'][:19].replace('T', ' ') + details = (e['details'] or '')[:30] + print(f"{ts:<25} {e['model']:<15} {e['event']:<20} {details}") def main(): + """CLI entry point.""" import argparse - parser = argparse.ArgumentParser(description="Cortex Model Lifecycle Manager") - sub = parser.add_subparsers(dest="cmd") - - reg = sub.add_parser("register") - reg.add_argument("name") - reg.add_argument("--path", required=True) - reg.add_argument("--backend", default="vllm") - reg.add_argument("--port", type=int, default=8000) - reg.add_argument("--gpus", default="0") - - for cmd in ["start", "stop", "unregister"]: - p = sub.add_parser(cmd) - p.add_argument("name") - - sub.add_parser("status").add_argument("name", nargs="?") - sub.add_parser("list") - + + parser = argparse.ArgumentParser( + description="Cortex Model Lifecycle Manager - Systemd-based LLM service management", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + cortex-model register llama-70b --path meta-llama/Llama-2-70b-hf --backend vllm --gpus 0,1 + cortex-model start llama-70b + cortex-model status + cortex-model enable llama-70b + cortex-model logs llama-70b -f +""" + ) + + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # Register command + reg = subparsers.add_parser("register", help="Register a new model") + reg.add_argument("name", help="Model name") + reg.add_argument("--path", required=True, help="Model path or HuggingFace ID") + reg.add_argument("--backend", default="vllm", + choices=["vllm", "llamacpp", "ollama", "tgi"], + help="Inference backend") + reg.add_argument("--port", type=int, default=8000, help="Service port") + reg.add_argument("--host", default="127.0.0.1", help="Service host") + reg.add_argument("--gpus", default="0", help="Comma-separated GPU IDs") + reg.add_argument("--memory", default="32G", help="Memory limit") + reg.add_argument("--cpu", type=float, default=4.0, help="CPU cores limit") + reg.add_argument("--max-model-len", type=int, default=4096, help="Max sequence length") + reg.add_argument("--tensor-parallel", type=int, default=1, help="Tensor parallel size") + reg.add_argument("--quantization", help="Quantization method (awq, gptq)") + reg.add_argument("--extra-args", default="", help="Extra backend arguments") + reg.add_argument("--no-health-check", action="store_true", help="Disable health checks") + + # Start command + start = subparsers.add_parser("start", help="Start a model") + start.add_argument("name", help="Model name") + + # Stop command + stop = subparsers.add_parser("stop", help="Stop a model") + stop.add_argument("name", help="Model name") + + # Restart command + restart = subparsers.add_parser("restart", help="Restart a model") + restart.add_argument("name", help="Model name") + + # Enable command + enable = subparsers.add_parser("enable", help="Enable auto-start on boot") + enable.add_argument("name", help="Model name") + + # Disable command + disable = subparsers.add_parser("disable", help="Disable auto-start") + disable.add_argument("name", help="Model name") + + # Unregister command + unreg = subparsers.add_parser("unregister", help="Unregister a model") + unreg.add_argument("name", help="Model name") + + # Status command + status = subparsers.add_parser("status", help="Show model 
status") + status.add_argument("name", nargs="?", help="Model name (optional)") + + # List command + subparsers.add_parser("list", help="List all models") + + # Logs command + logs = subparsers.add_parser("logs", help="Show model logs") + logs.add_argument("name", help="Model name") + logs.add_argument("-n", "--lines", type=int, default=50, help="Number of lines") + logs.add_argument("-f", "--follow", action="store_true", help="Follow log output") + + # Events command + events = subparsers.add_parser("events", help="Show model events") + events.add_argument("name", nargs="?", help="Model name (optional)") + events.add_argument("-n", "--limit", type=int, default=20, help="Number of events") + + # Health command + health = subparsers.add_parser("health", help="Check model health") + health.add_argument("name", help="Model name") + args = parser.parse_args() - mgr = ModelLifecycleManager() - - if args.cmd == "register": - mgr.register(ModelConfig(args.name, args.path, args.backend, args.port, - gpu_ids=[int(x) for x in args.gpus.split(",")])) - elif args.cmd == "start": - mgr.start(args.name) - elif args.cmd == "stop": - mgr.stop(args.name) - elif args.cmd in ("status", "list"): - mgr.status(getattr(args, 'name', None)) + + if not args.command: + parser.print_help() + sys.exit(1) + + manager = ModelLifecycleManager() + + if args.command == "register": + config = ModelConfig( + name=args.name, + model_path=args.path, + backend=args.backend, + port=args.port, + host=args.host, + gpu_ids=[int(x) for x in args.gpus.split(",")], + max_model_len=args.max_model_len, + tensor_parallel_size=args.tensor_parallel, + quantization=args.quantization, + extra_args=args.extra_args, + health_check=HealthCheckConfig( + enabled=not args.no_health_check, + endpoint=ServiceGenerator().get_default_health_endpoint(args.backend) + ), + resources=ResourceLimits( + memory_max=args.memory, + cpu_quota=args.cpu + ) + ) + manager.register(config) + + elif args.command == "start": + manager.start(args.name) + + elif args.command == "stop": + manager.stop(args.name) + + elif args.command == "restart": + manager.restart(args.name) + + elif args.command == "enable": + manager.enable(args.name) + + elif args.command == "disable": + manager.disable(args.name) + + elif args.command == "unregister": + manager.unregister(args.name) + + elif args.command in ("status", "list"): + manager.status(getattr(args, 'name', None)) + + elif args.command == "logs": + manager.logs(args.name, args.lines, args.follow) + + elif args.command == "events": + manager.events(getattr(args, 'name', None), args.limit) + + elif args.command == "health": + config = manager.db.get_model(args.name) + if config: + healthy, msg = manager.health_checker.check_health(config) + status = "healthy" if healthy else "unhealthy" + print(f"Model '{args.name}' is {status}: {msg}") + sys.exit(0 if healthy else 1) + else: + print(f"Model '{args.name}' not found") + sys.exit(1) if __name__ == "__main__": diff --git a/tests/test_model_lifecycle.py b/tests/test_model_lifecycle.py new file mode 100644 index 00000000..03e69535 --- /dev/null +++ b/tests/test_model_lifecycle.py @@ -0,0 +1,907 @@ +#!/usr/bin/env python3 +""" +Unit tests for Cortex Model Lifecycle Manager + +Tests cover: +- Configuration dataclasses (ModelConfig, HealthCheckConfig, ResourceLimits, SecurityConfig) +- Database operations (save, get, list, delete, events) +- Service generation (all backends, security, resources) +- Lifecycle operations (register, start, stop, enable, disable) +- Health checking +- 
CLI parsing +""" + +import os +import sys +import json +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock +from dataclasses import asdict + +# Add parent to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from cortex.kernel_features.model_lifecycle import ( + ModelConfig, + HealthCheckConfig, + ResourceLimits, + SecurityConfig, + ModelDatabase, + ServiceGenerator, + ModelLifecycleManager, + HealthChecker, + ModelState, + EventType, +) + + +class TestHealthCheckConfig(unittest.TestCase): + """Test HealthCheckConfig dataclass.""" + + def test_default_values(self): + """Test default configuration values.""" + config = HealthCheckConfig() + self.assertTrue(config.enabled) + self.assertEqual(config.endpoint, "/health") + self.assertEqual(config.interval_seconds, 30) + self.assertEqual(config.timeout_seconds, 10) + self.assertEqual(config.max_failures, 3) + self.assertEqual(config.startup_delay_seconds, 60) + + def test_custom_values(self): + """Test custom configuration values.""" + config = HealthCheckConfig( + enabled=False, + endpoint="/api/health", + interval_seconds=60, + timeout_seconds=5, + max_failures=5, + startup_delay_seconds=120 + ) + self.assertFalse(config.enabled) + self.assertEqual(config.endpoint, "/api/health") + self.assertEqual(config.interval_seconds, 60) + + def test_to_dict(self): + """Test dictionary serialization.""" + config = HealthCheckConfig(interval_seconds=45) + d = config.to_dict() + self.assertEqual(d["interval_seconds"], 45) + self.assertIn("enabled", d) + + def test_from_dict(self): + """Test dictionary deserialization.""" + data = {"enabled": False, "endpoint": "/status", "timeout_seconds": 15} + config = HealthCheckConfig.from_dict(data) + self.assertFalse(config.enabled) + self.assertEqual(config.endpoint, "/status") + self.assertEqual(config.timeout_seconds, 15) + # Defaults for missing fields + self.assertEqual(config.interval_seconds, 30) + + +class TestResourceLimits(unittest.TestCase): + """Test ResourceLimits dataclass.""" + + def test_default_values(self): + """Test default resource limits.""" + limits = ResourceLimits() + self.assertEqual(limits.memory_max, "32G") + self.assertEqual(limits.memory_high, "28G") + self.assertEqual(limits.cpu_quota, 4.0) + self.assertEqual(limits.cpu_weight, 100) + self.assertEqual(limits.io_weight, 100) + self.assertEqual(limits.tasks_max, 512) + + def test_custom_values(self): + """Test custom resource limits.""" + limits = ResourceLimits( + memory_max="64G", + memory_high="56G", + cpu_quota=8.0, + cpu_weight=200, + io_weight=500, + tasks_max=1024 + ) + self.assertEqual(limits.memory_max, "64G") + self.assertEqual(limits.cpu_quota, 8.0) + self.assertEqual(limits.tasks_max, 1024) + + def test_to_dict(self): + """Test dictionary serialization.""" + limits = ResourceLimits(memory_max="16G") + d = limits.to_dict() + self.assertEqual(d["memory_max"], "16G") + + def test_from_dict(self): + """Test dictionary deserialization.""" + data = {"memory_max": "128G", "cpu_quota": 16.0} + limits = ResourceLimits.from_dict(data) + self.assertEqual(limits.memory_max, "128G") + self.assertEqual(limits.cpu_quota, 16.0) + + +class TestSecurityConfig(unittest.TestCase): + """Test SecurityConfig dataclass.""" + + def test_default_values(self): + """Test default security settings.""" + config = SecurityConfig() + self.assertTrue(config.no_new_privileges) + self.assertEqual(config.protect_system, "strict") + self.assertEqual(config.protect_home, "read-only") + 
self.assertTrue(config.private_tmp) + self.assertFalse(config.private_devices) # False for GPU access + self.assertTrue(config.restrict_realtime) + + def test_custom_values(self): + """Test custom security settings.""" + config = SecurityConfig( + no_new_privileges=False, + protect_system="full", + private_devices=True + ) + self.assertFalse(config.no_new_privileges) + self.assertEqual(config.protect_system, "full") + self.assertTrue(config.private_devices) + + def test_to_dict(self): + """Test dictionary serialization.""" + config = SecurityConfig(protect_system="true") + d = config.to_dict() + self.assertEqual(d["protect_system"], "true") + + def test_from_dict(self): + """Test dictionary deserialization.""" + data = {"no_new_privileges": False, "protect_home": "tmpfs"} + config = SecurityConfig.from_dict(data) + self.assertFalse(config.no_new_privileges) + self.assertEqual(config.protect_home, "tmpfs") + + +class TestModelConfig(unittest.TestCase): + """Test ModelConfig dataclass.""" + + def test_minimal_config(self): + """Test minimal configuration.""" + config = ModelConfig(name="test-model", model_path="/path/to/model") + self.assertEqual(config.name, "test-model") + self.assertEqual(config.model_path, "/path/to/model") + self.assertEqual(config.backend, "vllm") + self.assertEqual(config.port, 8000) + + def test_full_config(self): + """Test full configuration with all options.""" + config = ModelConfig( + name="llama-70b", + model_path="meta-llama/Llama-2-70b-hf", + backend="tgi", + port=8080, + host="0.0.0.0", + gpu_memory_fraction=0.85, + max_model_len=8192, + gpu_ids=[0, 1, 2, 3], + tensor_parallel_size=4, + quantization="awq", + dtype="float16", + extra_args="--trust-remote-code", + restart_policy="always", + restart_max_retries=10, + preload_on_boot=True, + health_check=HealthCheckConfig(enabled=True, interval_seconds=60), + resources=ResourceLimits(memory_max="128G"), + security=SecurityConfig(protect_system="full"), + environment={"HF_TOKEN": "xxx"} + ) + self.assertEqual(config.name, "llama-70b") + self.assertEqual(config.backend, "tgi") + self.assertEqual(config.gpu_ids, [0, 1, 2, 3]) + self.assertEqual(config.tensor_parallel_size, 4) + self.assertTrue(config.preload_on_boot) + self.assertEqual(config.health_check.interval_seconds, 60) + self.assertEqual(config.resources.memory_max, "128G") + + def test_to_dict(self): + """Test dictionary serialization.""" + config = ModelConfig(name="test", model_path="/path") + d = config.to_dict() + self.assertEqual(d["name"], "test") + self.assertIn("health_check", d) + self.assertIn("resources", d) + self.assertIn("security", d) + + def test_from_dict(self): + """Test dictionary deserialization.""" + data = { + "name": "from-dict", + "model_path": "/model", + "backend": "llamacpp", + "port": 9000, + "health_check": {"enabled": False, "interval_seconds": 120}, + "resources": {"memory_max": "64G"}, + "security": {"no_new_privileges": False} + } + config = ModelConfig.from_dict(data) + self.assertEqual(config.name, "from-dict") + self.assertEqual(config.backend, "llamacpp") + self.assertEqual(config.port, 9000) + self.assertFalse(config.health_check.enabled) + self.assertEqual(config.health_check.interval_seconds, 120) + self.assertEqual(config.resources.memory_max, "64G") + self.assertFalse(config.security.no_new_privileges) + + def test_get_health_url(self): + """Test health URL generation.""" + config = ModelConfig( + name="test", + model_path="/path", + host="localhost", + port=8080, + 
health_check=HealthCheckConfig(endpoint="/api/health") + ) + self.assertEqual(config.get_health_url(), "http://localhost:8080/api/health") + + def test_get_health_url_no_slash(self): + """Test health URL with endpoint without leading slash.""" + config = ModelConfig( + name="test", + model_path="/path", + health_check=HealthCheckConfig(endpoint="health") + ) + self.assertEqual(config.get_health_url(), "http://127.0.0.1:8000/health") + + +class TestModelDatabase(unittest.TestCase): + """Test ModelDatabase class.""" + + def setUp(self): + """Create temporary database.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / "test.db" + self.db = ModelDatabase(self.db_path) + + def tearDown(self): + """Clean up temporary files.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_save_and_get_model(self): + """Test saving and retrieving a model.""" + config = ModelConfig(name="test-model", model_path="/path") + self.db.save_model(config) + + retrieved = self.db.get_model("test-model") + self.assertIsNotNone(retrieved) + self.assertEqual(retrieved.name, "test-model") + self.assertEqual(retrieved.model_path, "/path") + + def test_get_nonexistent_model(self): + """Test getting a model that doesn't exist.""" + retrieved = self.db.get_model("nonexistent") + self.assertIsNone(retrieved) + + def test_list_models(self): + """Test listing all models.""" + config1 = ModelConfig(name="model-a", model_path="/a") + config2 = ModelConfig(name="model-b", model_path="/b") + config3 = ModelConfig(name="model-c", model_path="/c") + + self.db.save_model(config1) + self.db.save_model(config2) + self.db.save_model(config3) + + models = self.db.list_models() + self.assertEqual(len(models), 3) + names = [m.name for m in models] + self.assertIn("model-a", names) + self.assertIn("model-b", names) + self.assertIn("model-c", names) + + def test_list_models_empty(self): + """Test listing models when none exist.""" + models = self.db.list_models() + self.assertEqual(len(models), 0) + + def test_delete_model(self): + """Test deleting a model.""" + config = ModelConfig(name="to-delete", model_path="/path") + self.db.save_model(config) + + result = self.db.delete_model("to-delete") + self.assertTrue(result) + + retrieved = self.db.get_model("to-delete") + self.assertIsNone(retrieved) + + def test_delete_nonexistent_model(self): + """Test deleting a model that doesn't exist.""" + result = self.db.delete_model("nonexistent") + self.assertFalse(result) + + def test_update_model(self): + """Test updating an existing model.""" + config = ModelConfig(name="test", model_path="/old") + self.db.save_model(config) + + config.model_path = "/new" + config.port = 9000 + self.db.save_model(config) + + retrieved = self.db.get_model("test") + self.assertEqual(retrieved.model_path, "/new") + self.assertEqual(retrieved.port, 9000) + + def test_log_event(self): + """Test logging an event.""" + self.db.log_event("test-model", EventType.REGISTERED) + self.db.log_event("test-model", EventType.STARTED, "details here") + + events = self.db.get_events("test-model") + self.assertEqual(len(events), 2) + self.assertEqual(events[0]["event"], "started") # Most recent first + self.assertEqual(events[0]["details"], "details here") + + def test_get_events_all(self): + """Test getting all events.""" + self.db.log_event("model-a", EventType.REGISTERED) + self.db.log_event("model-b", EventType.STARTED) + self.db.log_event("model-a", EventType.STOPPED) + + events = self.db.get_events() + 
self.assertEqual(len(events), 3) + + def test_get_events_limit(self): + """Test event limit.""" + for i in range(10): + self.db.log_event("test", EventType.STARTED) + + events = self.db.get_events("test", limit=5) + self.assertEqual(len(events), 5) + + +class TestServiceGenerator(unittest.TestCase): + """Test ServiceGenerator class.""" + + def setUp(self): + """Create generator.""" + self.generator = ServiceGenerator() + + def test_generate_vllm_service(self): + """Test generating vLLM service file.""" + config = ModelConfig( + name="llama", + model_path="meta-llama/Llama-2-7b-hf", + backend="vllm", + port=8000, + gpu_ids=[0], + max_model_len=4096 + ) + service = self.generator.generate(config) + + self.assertIn("Description=Cortex Model: llama", service) + self.assertIn("vllm.entrypoints.openai.api_server", service) + self.assertIn("--model meta-llama/Llama-2-7b-hf", service) + self.assertIn("--port 8000", service) + self.assertIn("CUDA_VISIBLE_DEVICES=0", service) + self.assertIn("NoNewPrivileges=true", service) + + def test_generate_llamacpp_service(self): + """Test generating llama.cpp service file.""" + config = ModelConfig( + name="gguf-model", + model_path="/models/model.gguf", + backend="llamacpp", + port=8080 + ) + service = self.generator.generate(config) + + self.assertIn("llama-server", service) + self.assertIn("-m /models/model.gguf", service) + self.assertIn("--port 8080", service) + + def test_generate_tgi_service(self): + """Test generating TGI service file.""" + config = ModelConfig( + name="tgi-model", + model_path="bigscience/bloom-560m", + backend="tgi", + port=8000, + gpu_ids=[0, 1], + tensor_parallel_size=2 + ) + service = self.generator.generate(config) + + self.assertIn("text-generation-launcher", service) + self.assertIn("--model-id bigscience/bloom-560m", service) + self.assertIn("--num-shard 2", service) + self.assertIn("CUDA_VISIBLE_DEVICES=0,1", service) + + def test_generate_ollama_service(self): + """Test generating Ollama service file.""" + config = ModelConfig( + name="ollama", + model_path="llama2", + backend="ollama" + ) + service = self.generator.generate(config) + + self.assertIn("ollama serve", service) + + def test_generate_with_quantization(self): + """Test service with quantization.""" + config = ModelConfig( + name="quant-model", + model_path="/model", + backend="vllm", + quantization="awq" + ) + service = self.generator.generate(config) + + self.assertIn("--quantization awq", service) + + def test_generate_with_resources(self): + """Test service with custom resources.""" + config = ModelConfig( + name="resource-model", + model_path="/model", + resources=ResourceLimits( + memory_max="64G", + cpu_quota=8.0, + tasks_max=1024 + ) + ) + service = self.generator.generate(config) + + self.assertIn("MemoryMax=64G", service) + self.assertIn("CPUQuota=800%", service) + self.assertIn("TasksMax=1024", service) + + def test_generate_with_security(self): + """Test service with security settings.""" + config = ModelConfig( + name="secure-model", + model_path="/model", + security=SecurityConfig( + protect_system="strict", + protect_home="read-only", + restrict_realtime=True + ) + ) + service = self.generator.generate(config) + + self.assertIn("ProtectSystem=strict", service) + self.assertIn("ProtectHome=read-only", service) + self.assertIn("RestrictRealtime=true", service) + + def test_generate_restart_policy(self): + """Test restart policy in service.""" + config = ModelConfig( + name="restart-model", + model_path="/model", + restart_policy="always", + 
restart_max_retries=10 + ) + service = self.generator.generate(config) + + self.assertIn("Restart=always", service) + self.assertIn("StartLimitBurst=10", service) + + def test_get_default_health_endpoint(self): + """Test default health endpoints.""" + self.assertEqual(self.generator.get_default_health_endpoint("vllm"), "/health") + self.assertEqual(self.generator.get_default_health_endpoint("tgi"), "/health") + self.assertEqual(self.generator.get_default_health_endpoint("ollama"), "/api/tags") + self.assertEqual(self.generator.get_default_health_endpoint("unknown"), "/health") + + +class TestModelLifecycleManager(unittest.TestCase): + """Test ModelLifecycleManager class.""" + + def setUp(self): + """Create manager with temporary database.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / "test.db" + self.manager = ModelLifecycleManager(self.db_path) + + def tearDown(self): + """Clean up temporary files.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_service_name(self): + """Test service name generation.""" + self.assertEqual(self.manager._service_name("my-model"), "cortex-my-model.service") + + @patch('subprocess.run') + def test_register(self, mock_run): + """Test model registration.""" + mock_run.return_value = MagicMock(returncode=0) + + config = ModelConfig(name="test-model", model_path="/path") + result = self.manager.register(config) + + self.assertTrue(result) + self.assertIsNotNone(self.manager.db.get_model("test-model")) + mock_run.assert_called() # daemon-reload + + @patch('subprocess.run') + def test_start(self, mock_run): + """Test starting a model.""" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + config = ModelConfig(name="test-model", model_path="/path") + self.manager.db.save_model(config) + + result = self.manager.start("test-model") + self.assertTrue(result) + + @patch('subprocess.run') + def test_start_nonexistent(self, mock_run): + """Test starting a model that doesn't exist.""" + result = self.manager.start("nonexistent") + self.assertFalse(result) + + @patch('subprocess.run') + def test_stop(self, mock_run): + """Test stopping a model.""" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + config = ModelConfig(name="test-model", model_path="/path") + self.manager.db.save_model(config) + + result = self.manager.stop("test-model") + self.assertTrue(result) + + @patch('subprocess.run') + def test_restart(self, mock_run): + """Test restarting a model.""" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + config = ModelConfig(name="test-model", model_path="/path") + self.manager.db.save_model(config) + + result = self.manager.restart("test-model") + self.assertTrue(result) + + @patch('subprocess.run') + def test_enable(self, mock_run): + """Test enabling a model.""" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + config = ModelConfig(name="test-model", model_path="/path", preload_on_boot=False) + self.manager.db.save_model(config) + + result = self.manager.enable("test-model") + self.assertTrue(result) + + # Check config updated + updated = self.manager.db.get_model("test-model") + self.assertTrue(updated.preload_on_boot) + + @patch('subprocess.run') + def test_disable(self, mock_run): + """Test disabling a model.""" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + config = ModelConfig(name="test-model", model_path="/path", preload_on_boot=True) + 
self.manager.db.save_model(config) + + result = self.manager.disable("test-model") + self.assertTrue(result) + + updated = self.manager.db.get_model("test-model") + self.assertFalse(updated.preload_on_boot) + + @patch('subprocess.run') + def test_unregister(self, mock_run): + """Test unregistering a model.""" + mock_run.return_value = MagicMock(returncode=0, stdout="", stderr="") + + config = ModelConfig(name="test-model", model_path="/path") + self.manager.register(config) + + result = self.manager.unregister("test-model") + self.assertTrue(result) + self.assertIsNone(self.manager.db.get_model("test-model")) + + @patch('subprocess.run') + def test_get_state(self, mock_run): + """Test getting model state.""" + mock_run.return_value = MagicMock(returncode=0, stdout="active\n", stderr="") + + state = self.manager.get_state("test-model") + self.assertEqual(state, ModelState.ACTIVE) + + @patch('subprocess.run') + def test_get_state_inactive(self, mock_run): + """Test getting inactive state.""" + mock_run.return_value = MagicMock(returncode=3, stdout="inactive\n", stderr="") + + state = self.manager.get_state("test-model") + self.assertEqual(state, ModelState.INACTIVE) + + @patch('subprocess.run') + def test_get_state_failed(self, mock_run): + """Test getting failed state.""" + mock_run.return_value = MagicMock(returncode=3, stdout="failed\n", stderr="") + + state = self.manager.get_state("test-model") + self.assertEqual(state, ModelState.FAILED) + + @patch('subprocess.run') + def test_get_status(self, mock_run): + """Test getting detailed status.""" + mock_run.side_effect = [ + MagicMock(returncode=0, stdout="active\n", stderr=""), + MagicMock(returncode=0, stdout="MainPID=12345\nMemoryCurrent=1000000\n", stderr=""), + MagicMock(returncode=0, stdout="enabled\n", stderr=""), + ] + + config = ModelConfig(name="test-model", model_path="/path") + self.manager.db.save_model(config) + + status = self.manager.get_status("test-model") + self.assertEqual(status["name"], "test-model") + self.assertEqual(status["state"], "active") + self.assertTrue(status["enabled"]) + + def test_get_status_nonexistent(self): + """Test getting status of nonexistent model.""" + status = self.manager.get_status("nonexistent") + self.assertIn("error", status) + + +class TestHealthChecker(unittest.TestCase): + """Test HealthChecker class.""" + + def setUp(self): + """Create health checker with mock manager.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / "test.db" + self.manager = ModelLifecycleManager(self.db_path) + self.checker = self.manager.health_checker + + def tearDown(self): + """Clean up temporary files.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + @patch('urllib.request.urlopen') + def test_check_health_success(self, mock_urlopen): + """Test successful health check.""" + mock_response = MagicMock() + mock_response.status = 200 + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + config = ModelConfig(name="test", model_path="/path") + healthy, msg = self.checker.check_health(config) + + self.assertTrue(healthy) + self.assertEqual(msg, "OK") + + @patch('urllib.request.urlopen') + def test_check_health_failure_status(self, mock_urlopen): + """Test health check with bad status code.""" + mock_response = MagicMock() + mock_response.status = 500 + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = 
MagicMock(return_value=False) + mock_urlopen.return_value = mock_response + + config = ModelConfig(name="test", model_path="/path") + healthy, msg = self.checker.check_health(config) + + self.assertFalse(healthy) + self.assertIn("500", msg) + + @patch('urllib.request.urlopen') + def test_check_health_connection_error(self, mock_urlopen): + """Test health check with connection error.""" + import urllib.error + mock_urlopen.side_effect = urllib.error.URLError("Connection refused") + + config = ModelConfig(name="test", model_path="/path") + healthy, msg = self.checker.check_health(config) + + self.assertFalse(healthy) + self.assertIn("Connection failed", msg) + + +class TestModelState(unittest.TestCase): + """Test ModelState enum.""" + + def test_state_values(self): + """Test all state values.""" + self.assertEqual(ModelState.UNKNOWN.value, "unknown") + self.assertEqual(ModelState.INACTIVE.value, "inactive") + self.assertEqual(ModelState.ACTIVATING.value, "activating") + self.assertEqual(ModelState.ACTIVE.value, "active") + self.assertEqual(ModelState.DEACTIVATING.value, "deactivating") + self.assertEqual(ModelState.FAILED.value, "failed") + self.assertEqual(ModelState.RELOADING.value, "reloading") + + +class TestEventType(unittest.TestCase): + """Test EventType enum.""" + + def test_event_values(self): + """Test all event type values.""" + self.assertEqual(EventType.REGISTERED.value, "registered") + self.assertEqual(EventType.STARTED.value, "started") + self.assertEqual(EventType.STOPPED.value, "stopped") + self.assertEqual(EventType.ENABLED.value, "enabled") + self.assertEqual(EventType.DISABLED.value, "disabled") + self.assertEqual(EventType.UNREGISTERED.value, "unregistered") + self.assertEqual(EventType.HEALTH_CHECK_FAILED.value, "health_check_failed") + self.assertEqual(EventType.HEALTH_CHECK_PASSED.value, "health_check_passed") + self.assertEqual(EventType.AUTO_RESTARTED.value, "auto_restarted") + self.assertEqual(EventType.CONFIG_UPDATED.value, "config_updated") + self.assertEqual(EventType.ERROR.value, "error") + + +class TestConfigSerialization(unittest.TestCase): + """Test full configuration serialization round-trip.""" + + def test_full_roundtrip(self): + """Test serializing and deserializing full config.""" + original = ModelConfig( + name="roundtrip-test", + model_path="/models/test", + backend="tgi", + port=9000, + host="0.0.0.0", + gpu_memory_fraction=0.8, + max_model_len=8192, + gpu_ids=[0, 1], + tensor_parallel_size=2, + quantization="gptq", + dtype="float16", + extra_args="--trust-remote-code", + restart_policy="always", + restart_max_retries=10, + preload_on_boot=True, + health_check=HealthCheckConfig( + enabled=True, + endpoint="/api/health", + interval_seconds=60, + timeout_seconds=15, + max_failures=5, + startup_delay_seconds=120 + ), + resources=ResourceLimits( + memory_max="128G", + memory_high="120G", + cpu_quota=16.0, + cpu_weight=200, + io_weight=500, + tasks_max=2048 + ), + security=SecurityConfig( + no_new_privileges=True, + protect_system="full", + protect_home="read-only", + private_tmp=True, + private_devices=False, + restrict_realtime=True + ), + environment={"HF_TOKEN": "test", "CUSTOM_VAR": "value"} + ) + + # Serialize to dict + data = original.to_dict() + + # Deserialize from dict + restored = ModelConfig.from_dict(data) + + # Verify all fields match + self.assertEqual(restored.name, original.name) + self.assertEqual(restored.model_path, original.model_path) + self.assertEqual(restored.backend, original.backend) + self.assertEqual(restored.port, 
original.port) + self.assertEqual(restored.gpu_ids, original.gpu_ids) + self.assertEqual(restored.tensor_parallel_size, original.tensor_parallel_size) + self.assertEqual(restored.quantization, original.quantization) + self.assertEqual(restored.preload_on_boot, original.preload_on_boot) + + # Nested configs + self.assertEqual(restored.health_check.interval_seconds, original.health_check.interval_seconds) + self.assertEqual(restored.resources.memory_max, original.resources.memory_max) + self.assertEqual(restored.security.protect_system, original.security.protect_system) + self.assertEqual(restored.environment, original.environment) + + def test_json_roundtrip(self): + """Test JSON serialization round-trip.""" + original = ModelConfig( + name="json-test", + model_path="/path", + health_check=HealthCheckConfig(interval_seconds=45) + ) + + # To JSON and back + json_str = json.dumps(original.to_dict()) + data = json.loads(json_str) + restored = ModelConfig.from_dict(data) + + self.assertEqual(restored.name, original.name) + self.assertEqual(restored.health_check.interval_seconds, 45) + + +class TestDatabasePersistence(unittest.TestCase): + """Test database persistence across instances.""" + + def test_persistence(self): + """Test that data persists across database instances.""" + temp_dir = tempfile.mkdtemp() + db_path = Path(temp_dir) / "persist.db" + + try: + # First instance + db1 = ModelDatabase(db_path) + config = ModelConfig(name="persist-test", model_path="/path") + db1.save_model(config) + db1.log_event("persist-test", EventType.REGISTERED) + + # Second instance (simulates restart) + db2 = ModelDatabase(db_path) + retrieved = db2.get_model("persist-test") + events = db2.get_events("persist-test") + + self.assertIsNotNone(retrieved) + self.assertEqual(retrieved.name, "persist-test") + self.assertEqual(len(events), 1) + finally: + import shutil + shutil.rmtree(temp_dir, ignore_errors=True) + + +class TestEdgeCases(unittest.TestCase): + """Test edge cases and error handling.""" + + def setUp(self): + """Create temporary database.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = Path(self.temp_dir) / "test.db" + + def tearDown(self): + """Clean up temporary files.""" + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_special_characters_in_name(self): + """Test model names with special characters.""" + db = ModelDatabase(self.db_path) + config = ModelConfig(name="model-with_special.chars", model_path="/path") + db.save_model(config) + retrieved = db.get_model("model-with_special.chars") + self.assertIsNotNone(retrieved) + + def test_empty_gpu_ids(self): + """Test config with empty GPU IDs.""" + config = ModelConfig(name="cpu-only", model_path="/path", gpu_ids=[]) + self.assertEqual(config.gpu_ids, []) + + def test_large_max_model_len(self): + """Test config with large max_model_len.""" + config = ModelConfig(name="large", model_path="/path", max_model_len=131072) + self.assertEqual(config.max_model_len, 131072) + + def test_many_gpus(self): + """Test config with many GPUs.""" + config = ModelConfig( + name="multi-gpu", + model_path="/path", + gpu_ids=[0, 1, 2, 3, 4, 5, 6, 7] + ) + self.assertEqual(len(config.gpu_ids), 8) + + generator = ServiceGenerator() + service = generator.generate(config) + self.assertIn("CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7", service) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
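
A note on running the suite: the test module above ends with its own `unittest` entry point (`unittest.main(verbosity=2)`), so it can be executed directly once the `cortex` package is importable. The commands below are a sketch, not part of the patch; the editable-install step is an assumption about this repository's layout.

```bash
# Make the cortex package importable first (assumed layout; adjust if the repo differs)
pip install -e .

# Run the file directly via its __main__ entry point
python tests/test_model_lifecycle.py

# Or, if pytest is available in the dev environment, run it with verbose output
python -m pytest tests/test_model_lifecycle.py -v
```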