diff --git a/cortex/cli.py b/cortex/cli.py
index 9261a816..a553086b 100644
--- a/cortex/cli.py
+++ b/cortex/cli.py
@@ -632,12 +632,292 @@ def ask(self, question: str) -> int:
             self._print_error(str(e))
             return 1
 
+    def _run_conflict_prediction(self, software: str) -> int | None:
+        """
+        Run AI-powered conflict prediction before installation.
+
+        Returns:
+            None if safe to proceed
+            int (exit code) if we should stop
+        """
+        from cortex.conflict_predictor import (
+            ConflictSeverity,
+            DependencyConflictPredictor,
+        )
+
+        cx_print("Analyzing dependencies for potential conflicts...", "info")
+
+        try:
+            predictor = DependencyConflictPredictor()
+
+            # Split software into individual packages
+            packages = software.split()
+
+            # Run prediction for each package
+            all_conflicts = []
+            for pkg in packages:
+                # Skip commands like "pip install" or "apt-get"
+                if pkg in ("pip", "pip3", "apt", "apt-get", "install", "-y", "&&"):
+                    continue
+
+                prediction = predictor.predict_conflicts(pkg)
+
+                if prediction.conflicts:
+                    all_conflicts.append(prediction)
+
+            if not all_conflicts:
+                cx_print("No conflicts predicted. Safe to proceed.", "success")
+                console.print()
+                return None  # Safe to proceed
+
+            # Display conflicts
+            console.print()
+            for prediction in all_conflicts:
+                # Format the prediction nicely (map risk level to a Rich color tag)
+                risk_colors = {
+                    ConflictSeverity.LOW: "[yellow]",
+                    ConflictSeverity.MEDIUM: "[orange1]",
+                    ConflictSeverity.HIGH: "[red]",
+                    ConflictSeverity.CRITICAL: "[bold red]",
+                }
+                risk_color = risk_colors.get(prediction.overall_risk, "")
+
+                console.print(
+                    f"{risk_color}Conflict predicted[/]: {prediction.package_name}"
+                )
+                console.print(
+                    f"  Risk Level: {risk_color}{prediction.overall_risk.value.upper()}[/]"
+                )
+                console.print(
+                    f"  Confidence: {prediction.prediction_confidence:.0%}"
+                )
+                console.print()
+
+                # Show each conflict
+                for i, conflict in enumerate(prediction.conflicts, 1):
+                    console.print(f"  {i}. {conflict.description}")
+                    console.print(
+                        f"     {conflict.conflicting_package} "
+                        f"{conflict.conflicting_version} (installed)"
+                    )
+                    console.print()
+
+                # Show top suggestions
+                if prediction.resolutions:
+                    console.print("  [bold cyan]Suggestions (ranked by safety):[/bold cyan]")
+                    for i, res in enumerate(prediction.resolutions[:3], 1):
+                        rec_tag = " [RECOMMENDED]" if res.recommended else ""
+                        console.print(f"  {i}. {res.description}{rec_tag}")
+                        if res.command:
+                            console.print(f"     [dim]$ {res.command}[/dim]")
+                    console.print()
+
+            # Check if we should stop for critical conflicts
+            critical_predictions = [
+                p for p in all_conflicts if p.overall_risk == ConflictSeverity.CRITICAL
+            ]
+
+            if critical_predictions:
+                cx_print(
+                    "Critical conflicts detected. Installation blocked.",
+                    "error",
+                )
+                cx_print(
+                    "Resolve conflicts above or use --no-predict to skip this check.",
+                    "warning",
+                )
+                return 1  # Stop with error
+
+            # For non-critical conflicts, ask user
+            high_predictions = [
+                p for p in all_conflicts if p.overall_risk == ConflictSeverity.HIGH
+            ]
+
+            if high_predictions:
+                try:
+                    response = console.input(
+                        "[bold yellow]High-risk conflicts detected. "
+                        "Proceed anyway? (y/N): [/bold yellow]"
+                    )
+                    if response.lower() not in ("y", "yes"):
+                        cx_print("Installation cancelled", "info")
+                        return 0  # User cancelled
+                except (EOFError, KeyboardInterrupt):
+                    console.print()
+                    cx_print("Installation cancelled", "info")
+                    return 0
+
+            # Medium/Low conflicts - just warn and proceed
+            return None  # Safe to proceed
+
+        except ImportError as e:
+            # Conflict predictor module not available - just warn and continue
+            self._debug(f"Conflict prediction unavailable: {e}")
+            return None
+        except Exception as e:
+            # Don't let prediction errors block installation
+            cx_print(f"Conflict prediction failed: {e}", "warning")
+            cx_print("Proceeding with installation...", "info")
+            return None
+
+    def predict(self, package: str, json_output: bool = False, verbose: bool = False) -> int:
+        """
+        Predict dependency conflicts for a package before installation.
+
+        This is the standalone 'cortex predict' command.
+        """
+        import json as json_lib
+
+        from cortex.conflict_predictor import (
+            ConflictSeverity,
+            DependencyConflictPredictor,
+        )
+
+        try:
+            predictor = DependencyConflictPredictor()
+
+            # Split packages if multiple provided
+            packages = package.split()
+            predictions = []
+
+            cx_print(f"Analyzing {len(packages)} package(s) for conflicts...", "info")
+            console.print()
+
+            for pkg in packages:
+                prediction = predictor.predict_conflicts(pkg)
+                predictions.append(prediction)
+
+            if json_output:
+                # JSON output mode
+                output = {
+                    "packages": [
+                        predictor.export_prediction_json(p) for p in predictions
+                    ],
+                    "summary": {
+                        "total_packages": len(predictions),
+                        "packages_with_conflicts": sum(
+                            1 for p in predictions if p.conflicts
+                        ),
+                        "critical_conflicts": sum(
+                            1
+                            for p in predictions
+                            if p.overall_risk == ConflictSeverity.CRITICAL
+                        ),
+                    },
+                }
+                console.print(json_lib.dumps(output, indent=2))
+                # Keep JSON output machine-parseable: return here instead of
+                # falling through to the human-readable summary below.
+                return 1 if output["summary"]["critical_conflicts"] else 0
+            else:
+                # Human-readable output
+                for prediction in predictions:
+                    if not prediction.conflicts:
+                        console.print(
+                            f"[green]No conflicts predicted[/green] for "
+                            f"[bold]{prediction.package_name}[/bold]"
+                        )
+                        console.print(
+                            f"  Confidence: {prediction.prediction_confidence:.0%}"
+                        )
+                        console.print()
+                        continue
+
+                    # Risk color coding
+                    risk_color = {
+                        ConflictSeverity.LOW: "yellow",
+                        ConflictSeverity.MEDIUM: "orange1",
+                        ConflictSeverity.HIGH: "red",
+                        ConflictSeverity.CRITICAL: "bold red",
+                    }.get(prediction.overall_risk, "white")
+
+                    console.print(
+                        f"[{risk_color}]Conflict predicted[/{risk_color}]: "
+                        f"[bold]{prediction.package_name}[/bold]"
+                    )
+                    console.print()
+
+                    # Show conflicts
+                    for i, conflict in enumerate(prediction.conflicts, 1):
+                        severity_badge = f"[{conflict.severity.value.upper()}]"
+                        console.print(f"  {i}. {severity_badge} {conflict.description}")
+                        console.print(
+                            f"     Installed: {conflict.conflicting_package} "
+                            f"{conflict.conflicting_version}"
+                        )
+                        console.print(f"     Type: {conflict.conflict_type}")
+                        console.print(
+                            f"     Confidence: {conflict.confidence:.0%}"
+                        )
+                        console.print()
+
+                    # Show suggestions
+                    if prediction.resolutions:
+                        console.print(
+                            "  [bold cyan]Suggestions (ranked by safety):[/bold cyan]"
+                        )
+                        for i, res in enumerate(prediction.resolutions[:4], 1):
+                            rec = " [RECOMMENDED]" if res.recommended else ""
+                            console.print(f"  {i}. {res.description}")
+                            if res.command:
+                                console.print(f"     [dim]$ {res.command}[/dim]")
+                            console.print(
+                                f"     Safety: {res.safety_score:.0%}{rec}"
+                            )
+                            if verbose and res.side_effects:
+                                for effect in res.side_effects:
+                                    console.print(f"       [dim]- {effect}[/dim]")
+                            console.print()
+
+                    # Show analysis details if verbose
+                    if verbose and prediction.analysis_details:
+                        console.print("  [dim]Analysis Details:[/dim]")
+                        for key, value in prediction.analysis_details.items():
+                            console.print(f"    {key}: {value}")
+                        console.print()
+
+            # Summary
+            conflicts_found = sum(1 for p in predictions if p.conflicts)
+            critical = sum(
+                1
+                for p in predictions
+                if p.overall_risk == ConflictSeverity.CRITICAL
+            )
+
+            console.print("-" * 50)
+            if conflicts_found == 0:
+                console.print(
+                    f"[bold green]All {len(predictions)} package(s) safe to install[/bold green]"
+                )
+                return 0
+            else:
+                console.print(
+                    f"[bold yellow]{conflicts_found}/{len(predictions)} package(s) "
+                    f"have potential conflicts[/bold yellow]"
+                )
+                if critical > 0:
+                    console.print(
+                        f"[bold red]{critical} CRITICAL conflict(s) detected[/bold red]"
+                    )
+                    return 1
+                return 0
+
+        except ImportError as e:
+            self._print_error(f"Conflict prediction module not available: {e}")
+            return 1
+        except Exception as e:
+            self._print_error(f"Prediction failed: {e}")
+            if verbose:
+                import traceback
+
+                traceback.print_exc()
+            return 1
For the "pytorch-cpu jupyter numpy pandas" @@ -2031,6 +2319,7 @@ def show_rich_help(): table.add_row("wizard", "Configure API key") table.add_row("status", "System status") table.add_row("install ", "Install software") + table.add_row("predict ", "Predict dependency conflicts") table.add_row("import ", "Import deps from package files") table.add_row("history", "View history") table.add_row("rollback ", "Undo installation") @@ -2144,6 +2433,29 @@ def main(): action="store_true", help="Enable parallel execution for multi-step installs", ) + install_parser.add_argument( + "--predict", + action="store_true", + help="Predict dependency conflicts before installation", + ) + install_parser.add_argument( + "--no-predict", + action="store_true", + help="Skip automatic conflict prediction", + ) + + # Predict command - standalone conflict prediction + predict_parser = subparsers.add_parser( + "predict", + help="Predict dependency conflicts before installation", + ) + predict_parser.add_argument("package", type=str, help="Package(s) to analyze") + predict_parser.add_argument( + "--json", "-j", action="store_true", help="Output as JSON" + ) + predict_parser.add_argument( + "--verbose", action="store_true", help="Show detailed analysis" + ) # Import command - import dependencies from package manager files import_parser = subparsers.add_parser( @@ -2510,6 +2822,14 @@ def main(): execute=args.execute, dry_run=args.dry_run, parallel=args.parallel, + predict=args.predict, + no_predict=args.no_predict, + ) + elif args.command == "predict": + return cli.predict( + args.package, + json_output=args.json, + verbose=args.verbose, ) elif args.command == "import": return cli.import_deps(args) diff --git a/cortex/conflict_predictor.py b/cortex/conflict_predictor.py new file mode 100644 index 00000000..e21f5fdf --- /dev/null +++ b/cortex/conflict_predictor.py @@ -0,0 +1,933 @@ +#!/usr/bin/env python3 +""" +AI-Powered Dependency Conflict Predictor + +Predicts dependency conflicts BEFORE installation starts using: +- Local dependency graph analysis +- System state from dpkg/apt +- Pip package metadata +- AI-powered conflict pattern recognition +- Confidence scoring and resolution suggestions +""" + +import json +import logging +import re +import subprocess +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class ConflictSeverity(Enum): + """Severity levels for predicted conflicts""" + + LOW = "low" # Minor version mismatch, likely compatible + MEDIUM = "medium" # Potential issues, may need attention + HIGH = "high" # Likely to cause problems + CRITICAL = "critical" # Will definitely fail + + +class ResolutionStrategy(Enum): + """Types of resolution strategies""" + + UPGRADE_PACKAGE = "upgrade_package" # Upgrade target to compatible version + DOWNGRADE_DEPENDENCY = "downgrade_dependency" # Downgrade existing package + USE_VIRTUALENV = "use_virtualenv" # Isolate in virtual environment + REMOVE_CONFLICTING = "remove_conflicting" # Remove conflicting package + INSTALL_ALTERNATIVE = "install_alternative" # Use alternative package + PIN_VERSION = "pin_version" # Pin specific compatible version + SKIP_INSTALL = "skip_install" # Skip this package entirely + + +class PackageEcosystem(Enum): + """Package ecosystems supported for conflict detection""" + + APT = "apt" # Debian/Ubuntu apt packages + PIP = "pip" # Python pip packages + NPM = "npm" # Node.js npm packages + SYSTEM = "system" # 
System-level conflicts + + +@dataclass +class InstalledPackage: + """Represents an installed package""" + + name: str + version: str + ecosystem: PackageEcosystem + source: str = "" # What installed this package + + +@dataclass +class VersionConstraint: + """Version constraint for a dependency""" + + operator: str # <, <=, ==, >=, >, !=, ~= + version: str + original: str = "" # Original constraint string + + +@dataclass +class DependencyRequirement: + """A package's dependency requirement""" + + package_name: str + constraints: list[VersionConstraint] + is_optional: bool = False + extras: list[str] = field(default_factory=list) + + +@dataclass +class PredictedConflict: + """A predicted dependency conflict""" + + package_to_install: str + package_version: str | None + conflicting_package: str + conflicting_version: str + installed_by: str # What package installed the conflicting dependency + conflict_type: str # version_mismatch, mutual_exclusion, etc. + severity: ConflictSeverity + confidence: float # 0.0 to 1.0 + description: str + ecosystem: PackageEcosystem + + +@dataclass +class ResolutionSuggestion: + """A suggested resolution for a conflict""" + + strategy: ResolutionStrategy + description: str + command: str | None + safety_score: float # 0.0 to 1.0 (higher = safer) + side_effects: list[str] + recommended: bool = False + + +@dataclass +class ConflictPrediction: + """Complete conflict prediction result""" + + package_name: str + package_version: str | None + conflicts: list[PredictedConflict] + resolutions: list[ResolutionSuggestion] + overall_risk: ConflictSeverity + can_proceed: bool + prediction_confidence: float + analysis_details: dict[str, Any] = field(default_factory=dict) + + +class DependencyConflictPredictor: + """ + AI-powered dependency conflict prediction engine. + + Analyzes dependency graphs and predicts conflicts BEFORE installation, + providing resolution suggestions ranked by safety. 
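
Note: the gate added to install() reduces to `predict or (execute and not no_predict)`. A quick sketch of the resulting behavior for the flag combinations (this just mirrors the expression above):

    def should_predict(predict: bool, execute: bool, no_predict: bool) -> bool:
        # Mirrors the gate in CortexCLI.install()
        return predict or (execute and not no_predict)

    assert should_predict(False, False, False) is False  # plan-only run: no check
    assert should_predict(True, False, False) is True    # --predict forces the check
    assert should_predict(False, True, False) is True    # --execute checks by default
    assert should_predict(False, True, True) is False    # --execute --no-predict opts out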
+ """ + + # Known version conflict patterns (pre-trained knowledge) + KNOWN_CONFLICTS = { + # Python/pip conflicts + "tensorflow": { + "numpy": {"max_version": "2.0.0", "reason": "TensorFlow <2.16 requires numpy<2.0"}, + "protobuf": {"max_version": "4.0.0", "reason": "TensorFlow requires specific protobuf"}, + }, + "torch": { + "numpy": {"min_version": "1.19.0", "reason": "PyTorch requires numpy>=1.19"}, + }, + "pandas": { + "numpy": {"min_version": "1.20.0", "reason": "Pandas 2.x requires numpy>=1.20"}, + }, + "scipy": { + "numpy": {"min_version": "1.19.0", "reason": "SciPy requires numpy>=1.19"}, + }, + # APT package conflicts + "mysql-server": { + "mariadb-server": {"mutual_exclusion": True, "reason": "Cannot have both MySQL and MariaDB"}, + }, + "mariadb-server": { + "mysql-server": {"mutual_exclusion": True, "reason": "Cannot have both MariaDB and MySQL"}, + }, + "apache2": { + "nginx": {"port_conflict": True, "reason": "Both use port 80 by default"}, + }, + "nginx": { + "apache2": {"port_conflict": True, "reason": "Both use port 80 by default"}, + }, + "python3.10": { + "python3.11": {"conflict_type": "alternative", "reason": "Different Python versions"}, + "python3.12": {"conflict_type": "alternative", "reason": "Different Python versions"}, + }, + } + + # Common transitive dependency issues + TRANSITIVE_PATTERNS = { + "grpcio": ["protobuf"], + "tensorflow": ["numpy", "protobuf", "grpcio", "h5py"], + "torch": ["numpy", "pillow", "typing-extensions"], + "pandas": ["numpy", "python-dateutil", "pytz"], + "scikit-learn": ["numpy", "scipy", "joblib"], + "matplotlib": ["numpy", "pillow", "pyparsing"], + } + + def __init__(self, llm_router=None): + """ + Initialize the conflict predictor. + + Args: + llm_router: Optional LLMRouter instance for AI-powered analysis + """ + self.llm_router = llm_router + self._installed_apt_cache: dict[str, InstalledPackage] = {} + self._installed_pip_cache: dict[str, InstalledPackage] = {} + self._refresh_caches() + + def _run_command(self, cmd: list[str], timeout: int = 30) -> tuple[bool, str, str]: + """Execute command and return success, stdout, stderr""" + try: + result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout) + return (result.returncode == 0, result.stdout, result.stderr) + except subprocess.TimeoutExpired: + return (False, "", "Command timed out") + except FileNotFoundError: + return (False, "", f"Command not found: {cmd[0]}") + except Exception as e: + return (False, "", str(e)) + + def _refresh_caches(self) -> None: + """Refresh caches of installed packages""" + self._refresh_apt_cache() + self._refresh_pip_cache() + + def _refresh_apt_cache(self) -> None: + """Parse dpkg status to get installed apt packages""" + logger.debug("Refreshing apt package cache...") + self._installed_apt_cache = {} + + success, stdout, _ = self._run_command(["dpkg-query", "-W", "-f=${Package} ${Version}\n"]) + if success: + for line in stdout.strip().split("\n"): + if line: + parts = line.split(maxsplit=1) + if len(parts) >= 2: + name, version = parts[0], parts[1] + self._installed_apt_cache[name] = InstalledPackage( + name=name, + version=version, + ecosystem=PackageEcosystem.APT, + source="dpkg", + ) + + logger.debug(f"Found {len(self._installed_apt_cache)} installed apt packages") + + def _refresh_pip_cache(self) -> None: + """Get installed pip packages using pip list""" + logger.debug("Refreshing pip package cache...") + self._installed_pip_cache = {} + + # Try pip3 first, then pip + for pip_cmd in ["pip3", "pip"]: + success, stdout, _ = 
self._run_command([pip_cmd, "list", "--format=json"]) + if success: + try: + packages = json.loads(stdout) + for pkg in packages: + name = pkg.get("name", "").lower() + version = pkg.get("version", "") + if name: + self._installed_pip_cache[name] = InstalledPackage( + name=name, + version=version, + ecosystem=PackageEcosystem.PIP, + source=pip_cmd, + ) + logger.debug(f"Found {len(self._installed_pip_cache)} installed pip packages") + return + except json.JSONDecodeError: + continue + + def _parse_version_constraint(self, constraint: str) -> VersionConstraint | None: + """Parse a version constraint string like '>=1.0,<2.0'""" + constraint = constraint.strip() + if not constraint: + return None + + # Handle operators + operators = ["<=", ">=", "==", "!=", "~=", "<", ">"] + for op in operators: + if constraint.startswith(op): + version = constraint[len(op) :].strip() + return VersionConstraint(operator=op, version=version, original=constraint) + + # No operator means exact match + return VersionConstraint(operator="==", version=constraint, original=constraint) + + def _compare_versions(self, v1: str, v2: str) -> int: + """ + Compare two version strings. + Returns: -1 if v1 < v2, 0 if equal, 1 if v1 > v2 + """ + + def normalize(v: str) -> list[int]: + parts = [] + for part in re.split(r"[.\-+]", v): + # Extract numeric prefix + match = re.match(r"(\d+)", part) + if match: + parts.append(int(match.group(1))) + return parts + + v1_parts = normalize(v1) + v2_parts = normalize(v2) + + # Pad shorter version with zeros + max_len = max(len(v1_parts), len(v2_parts)) + v1_parts.extend([0] * (max_len - len(v1_parts))) + v2_parts.extend([0] * (max_len - len(v2_parts))) + + for p1, p2 in zip(v1_parts, v2_parts): + if p1 < p2: + return -1 + if p1 > p2: + return 1 + return 0 + + def _check_version_satisfies(self, installed_version: str, constraint: VersionConstraint) -> bool: + """Check if installed version satisfies a constraint""" + cmp = self._compare_versions(installed_version, constraint.version) + + if constraint.operator == "==": + return cmp == 0 + elif constraint.operator == "!=": + return cmp != 0 + elif constraint.operator == "<": + return cmp < 0 + elif constraint.operator == "<=": + return cmp <= 0 + elif constraint.operator == ">": + return cmp > 0 + elif constraint.operator == ">=": + return cmp >= 0 + elif constraint.operator == "~=": + # Compatible release (e.g., ~=1.4 means >=1.4, <2.0) + return cmp >= 0 # Simplified + return True + + def _get_pip_package_requirements(self, package_name: str) -> list[DependencyRequirement]: + """Get requirements for a pip package using pip show""" + requirements = [] + + success, stdout, _ = self._run_command(["pip3", "show", package_name]) + if not success: + success, stdout, _ = self._run_command(["pip", "show", package_name]) + + if success: + for line in stdout.split("\n"): + if line.startswith("Requires:"): + deps = line.split(":", 1)[1].strip() + if deps: + for dep in deps.split(","): + dep = dep.strip() + if dep: + requirements.append( + DependencyRequirement( + package_name=dep.lower(), + constraints=[], + ) + ) + + return requirements + + def _get_apt_package_requirements(self, package_name: str) -> list[DependencyRequirement]: + """Get requirements for an apt package""" + requirements = [] + + success, stdout, _ = self._run_command(["apt-cache", "depends", package_name]) + if success: + for line in stdout.split("\n"): + line = line.strip() + if line.startswith("Depends:"): + dep_name = line.split(":", 1)[1].strip() + # Handle alternatives + if 
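
Note: a short illustration of the constraint helpers above. These are private methods, so this is test-style usage rather than public API:

    predictor = DependencyConflictPredictor()

    c = predictor._parse_version_constraint(">=1.19.0")
    assert (c.operator, c.version) == (">=", "1.19.0")

    assert predictor._compare_versions("2.1.0", "2.0.0") == 1
    assert predictor._check_version_satisfies("1.24.0", c) is True
    # Caveat: "~=" is approximated as ">=" here; the upper bound is not enforced.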
"|" in dep_name: + dep_name = dep_name.split("|")[0].strip() + # Remove version constraint for simple matching + dep_name = re.sub(r"\s*\(.*?\)", "", dep_name) + if dep_name: + requirements.append( + DependencyRequirement( + package_name=dep_name, + constraints=[], + ) + ) + + return requirements + + def _analyze_known_conflicts( + self, package_name: str, ecosystem: PackageEcosystem + ) -> list[PredictedConflict]: + """Check against known conflict patterns""" + conflicts = [] + pkg_lower = package_name.lower() + + if pkg_lower not in self.KNOWN_CONFLICTS: + return conflicts + + known = self.KNOWN_CONFLICTS[pkg_lower] + cache = self._installed_pip_cache if ecosystem == PackageEcosystem.PIP else self._installed_apt_cache + + for dep_name, constraint_info in known.items(): + if dep_name in cache: + installed = cache[dep_name] + + # Check for mutual exclusion + if constraint_info.get("mutual_exclusion"): + conflicts.append( + PredictedConflict( + package_to_install=package_name, + package_version=None, + conflicting_package=dep_name, + conflicting_version=installed.version, + installed_by="system", + conflict_type="mutual_exclusion", + severity=ConflictSeverity.CRITICAL, + confidence=0.95, + description=constraint_info.get("reason", "Packages are mutually exclusive"), + ecosystem=ecosystem, + ) + ) + continue + + # Check for port conflicts + if constraint_info.get("port_conflict"): + conflicts.append( + PredictedConflict( + package_to_install=package_name, + package_version=None, + conflicting_package=dep_name, + conflicting_version=installed.version, + installed_by="system", + conflict_type="port_conflict", + severity=ConflictSeverity.HIGH, + confidence=0.85, + description=constraint_info.get("reason", "Services use same port"), + ecosystem=ecosystem, + ) + ) + continue + + # Check version constraints + max_ver = constraint_info.get("max_version") + min_ver = constraint_info.get("min_version") + + if max_ver: + if self._compare_versions(installed.version, max_ver) >= 0: + conflicts.append( + PredictedConflict( + package_to_install=package_name, + package_version=None, + conflicting_package=dep_name, + conflicting_version=installed.version, + installed_by="unknown", + conflict_type="version_too_high", + severity=ConflictSeverity.HIGH, + confidence=0.9, + description=f"{constraint_info.get('reason', '')} " + f"(requires {dep_name}<{max_ver}, you have {installed.version})", + ecosystem=ecosystem, + ) + ) + + if min_ver: + if self._compare_versions(installed.version, min_ver) < 0: + conflicts.append( + PredictedConflict( + package_to_install=package_name, + package_version=None, + conflicting_package=dep_name, + conflicting_version=installed.version, + installed_by="unknown", + conflict_type="version_too_low", + severity=ConflictSeverity.MEDIUM, + confidence=0.85, + description=f"{constraint_info.get('reason', '')} " + f"(requires {dep_name}>={min_ver}, you have {installed.version})", + ecosystem=ecosystem, + ) + ) + + return conflicts + + def _analyze_transitive_conflicts( + self, package_name: str, ecosystem: PackageEcosystem + ) -> list[PredictedConflict]: + """Analyze potential transitive dependency conflicts""" + conflicts = [] + pkg_lower = package_name.lower() + + if pkg_lower not in self.TRANSITIVE_PATTERNS: + return conflicts + + transitive_deps = self.TRANSITIVE_PATTERNS[pkg_lower] + cache = self._installed_pip_cache if ecosystem == PackageEcosystem.PIP else self._installed_apt_cache + + # Check if any transitive dependencies are installed and might conflict + for dep in 
transitive_deps: + if dep in self.KNOWN_CONFLICTS.get(pkg_lower, {}): + # Already handled by known conflicts + continue + + if dep in cache: + # Check if this dep has known conflicts with other installed packages + for other_pkg, other_info in cache.items(): + if other_pkg == dep: + continue + if dep in self.KNOWN_CONFLICTS.get(other_pkg, {}): + conflicts.append( + PredictedConflict( + package_to_install=package_name, + package_version=None, + conflicting_package=dep, + conflicting_version=cache[dep].version, + installed_by=other_pkg, + conflict_type="transitive_conflict", + severity=ConflictSeverity.MEDIUM, + confidence=0.7, + description=f"Installing {package_name} may affect {dep} " + f"which is also used by {other_pkg}", + ecosystem=ecosystem, + ) + ) + + return conflicts + + def _generate_resolutions( + self, conflicts: list[PredictedConflict], package_name: str + ) -> list[ResolutionSuggestion]: + """Generate resolution suggestions for conflicts""" + resolutions: list[ResolutionSuggestion] = [] + + if not conflicts: + return resolutions + + # Group by severity + critical_conflicts = [c for c in conflicts if c.severity == ConflictSeverity.CRITICAL] + high_conflicts = [c for c in conflicts if c.severity == ConflictSeverity.HIGH] + medium_conflicts = [c for c in conflicts if c.severity == ConflictSeverity.MEDIUM] + + # Handle critical conflicts first + for conflict in critical_conflicts: + if conflict.conflict_type == "mutual_exclusion": + resolutions.append( + ResolutionSuggestion( + strategy=ResolutionStrategy.REMOVE_CONFLICTING, + description=f"Remove {conflict.conflicting_package} before installing {package_name}", + command=f"sudo apt-get remove {conflict.conflicting_package}", + safety_score=0.4, + side_effects=[ + f"Will remove {conflict.conflicting_package} and dependent packages", + "May affect running services", + ], + recommended=False, + ) + ) + resolutions.append( + ResolutionSuggestion( + strategy=ResolutionStrategy.SKIP_INSTALL, + description=f"Skip installing {package_name}, keep {conflict.conflicting_package}", + command=None, + safety_score=0.9, + side_effects=["Target package will not be installed"], + recommended=True, + ) + ) + + # Handle version conflicts + for conflict in high_conflicts + medium_conflicts: + if conflict.conflict_type in ["version_too_high", "version_too_low"]: + # Suggest virtual environment for pip packages + if conflict.ecosystem == PackageEcosystem.PIP: + resolutions.append( + ResolutionSuggestion( + strategy=ResolutionStrategy.USE_VIRTUALENV, + description=f"Create virtual environment to isolate {package_name}", + command=f"python3 -m venv .venv && source .venv/bin/activate && pip install {package_name}", + safety_score=0.95, + side_effects=["Package installed in isolated environment only"], + recommended=True, + ) + ) + + # Suggest upgrading/downgrading + if conflict.conflict_type == "version_too_high": + resolutions.append( + ResolutionSuggestion( + strategy=ResolutionStrategy.UPGRADE_PACKAGE, + description=f"Install newer version of {package_name} that supports {conflict.conflicting_package} {conflict.conflicting_version}", + command=f"pip install --upgrade {package_name}", + safety_score=0.8, + side_effects=["May get different version than expected"], + recommended=True, + ) + ) + resolutions.append( + ResolutionSuggestion( + strategy=ResolutionStrategy.DOWNGRADE_DEPENDENCY, + description=f"Downgrade {conflict.conflicting_package} to compatible version", + command=f"pip install 
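
Note: detection is driven entirely by the KNOWN_CONFLICTS / TRANSITIVE_PATTERNS tables plus the installed-package caches, so it can be exercised offline with a faked cache (the same setup TestKnownConflictPatterns uses later in this diff):

    from cortex.conflict_predictor import (
        DependencyConflictPredictor,
        InstalledPackage,
        PackageEcosystem,
    )

    predictor = DependencyConflictPredictor()
    predictor._installed_pip_cache = {
        "numpy": InstalledPackage("numpy", "2.1.0", PackageEcosystem.PIP, "pip3"),
    }
    conflicts = predictor._analyze_known_conflicts("tensorflow", PackageEcosystem.PIP)
    assert conflicts[0].conflict_type == "version_too_high"  # 2.1.0 exceeds the <2.0.0 cap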
+
+    def _generate_resolutions(
+        self, conflicts: list[PredictedConflict], package_name: str
+    ) -> list[ResolutionSuggestion]:
+        """Generate resolution suggestions for conflicts"""
+        resolutions: list[ResolutionSuggestion] = []
+
+        if not conflicts:
+            return resolutions
+
+        # Group by severity
+        critical_conflicts = [c for c in conflicts if c.severity == ConflictSeverity.CRITICAL]
+        high_conflicts = [c for c in conflicts if c.severity == ConflictSeverity.HIGH]
+        medium_conflicts = [c for c in conflicts if c.severity == ConflictSeverity.MEDIUM]
+
+        # Handle critical conflicts first
+        for conflict in critical_conflicts:
+            if conflict.conflict_type == "mutual_exclusion":
+                resolutions.append(
+                    ResolutionSuggestion(
+                        strategy=ResolutionStrategy.REMOVE_CONFLICTING,
+                        description=f"Remove {conflict.conflicting_package} before installing {package_name}",
+                        command=f"sudo apt-get remove {conflict.conflicting_package}",
+                        safety_score=0.4,
+                        side_effects=[
+                            f"Will remove {conflict.conflicting_package} and dependent packages",
+                            "May affect running services",
+                        ],
+                        recommended=False,
+                    )
+                )
+                resolutions.append(
+                    ResolutionSuggestion(
+                        strategy=ResolutionStrategy.SKIP_INSTALL,
+                        description=f"Skip installing {package_name}, keep {conflict.conflicting_package}",
+                        command=None,
+                        safety_score=0.9,
+                        side_effects=["Target package will not be installed"],
+                        recommended=True,
+                    )
+                )
+
+        # Handle version conflicts
+        for conflict in high_conflicts + medium_conflicts:
+            if conflict.conflict_type in ["version_too_high", "version_too_low"]:
+                # Suggest virtual environment for pip packages
+                if conflict.ecosystem == PackageEcosystem.PIP:
+                    resolutions.append(
+                        ResolutionSuggestion(
+                            strategy=ResolutionStrategy.USE_VIRTUALENV,
+                            description=f"Create virtual environment to isolate {package_name}",
+                            command=f"python3 -m venv .venv && source .venv/bin/activate && pip install {package_name}",
+                            safety_score=0.95,
+                            side_effects=["Package installed in isolated environment only"],
+                            recommended=True,
+                        )
+                    )
+
+                # Suggest upgrading/downgrading
+                if conflict.conflict_type == "version_too_high":
+                    resolutions.append(
+                        ResolutionSuggestion(
+                            strategy=ResolutionStrategy.UPGRADE_PACKAGE,
+                            description=f"Install newer version of {package_name} that supports "
+                            f"{conflict.conflicting_package} {conflict.conflicting_version}",
+                            command=f"pip install --upgrade {package_name}",
+                            safety_score=0.8,
+                            side_effects=["May get different version than expected"],
+                            recommended=True,
+                        )
+                    )
+                    resolutions.append(
+                        ResolutionSuggestion(
+                            strategy=ResolutionStrategy.DOWNGRADE_DEPENDENCY,
+                            description=f"Downgrade {conflict.conflicting_package} to compatible version",
+                            # Quote the spec so the shell does not treat '<' as redirection
+                            command=f"pip install '{conflict.conflicting_package}<{conflict.conflicting_version}'",
+                            safety_score=0.5,
+                            side_effects=[
+                                f"May break packages depending on {conflict.conflicting_package}"
+                            ],
+                            recommended=False,
+                        )
+                    )
+
+            if conflict.conflict_type == "port_conflict":
+                resolutions.append(
+                    ResolutionSuggestion(
+                        strategy=ResolutionStrategy.PIN_VERSION,
+                        description=f"Configure {package_name} to use a different port",
+                        command=None,
+                        safety_score=0.85,
+                        side_effects=["Requires manual configuration"],
+                        recommended=True,
+                    )
+                )
+
+        # Sort by safety score (highest first); recommended options win ties.
+        # The formatters append the [RECOMMENDED] tag from the flag, so the
+        # description itself is left untouched.
+        resolutions.sort(key=lambda r: (-r.safety_score, not r.recommended))
+
+        return resolutions
+
+    def _determine_overall_risk(self, conflicts: list[PredictedConflict]) -> ConflictSeverity:
+        """Determine overall risk level from conflicts"""
+        if not conflicts:
+            return ConflictSeverity.LOW
+
+        severities = [c.severity for c in conflicts]
+
+        if ConflictSeverity.CRITICAL in severities:
+            return ConflictSeverity.CRITICAL
+        if ConflictSeverity.HIGH in severities:
+            return ConflictSeverity.HIGH
+        if ConflictSeverity.MEDIUM in severities:
+            return ConflictSeverity.MEDIUM
+        return ConflictSeverity.LOW
+
+    def _detect_ecosystem(self, package_name: str) -> PackageEcosystem:
+        """Detect which ecosystem a package belongs to"""
+        # Check if it's a known pip package
+        pip_indicators = [
+            "numpy",
+            "pandas",
+            "tensorflow",
+            "torch",
+            "pytorch",
+            "flask",
+            "django",
+            "requests",
+            "scipy",
+            "matplotlib",
+            "scikit-learn",
+            "pillow",
+        ]
+        if package_name.lower() in pip_indicators:
+            return PackageEcosystem.PIP
+
+        # Check if it's a known apt package
+        apt_indicators = [
+            "nginx",
+            "apache2",
+            "mysql-server",
+            "mariadb-server",
+            "postgresql",
+            "redis-server",
+            "docker",
+            "nodejs",
+        ]
+        if package_name.lower() in apt_indicators:
+            return PackageEcosystem.APT
+
+        # Try to detect by checking package availability
+        success, _, _ = self._run_command(["apt-cache", "show", package_name])
+        if success:
+            return PackageEcosystem.APT
+
+        success, _, _ = self._run_command(["pip3", "show", package_name])
+        if success:
+            return PackageEcosystem.PIP
+
+        # Default to system for unknown
+        return PackageEcosystem.SYSTEM
+
+    async def predict_conflicts_async(
+        self, package_name: str, version: str | None = None
+    ) -> ConflictPrediction:
+        """
+        Async version that uses the LLM for enhanced conflict analysis.
+        """
+        # Start with local analysis
+        prediction = self.predict_conflicts(package_name, version)
+
+        # If LLM router available, enhance with AI analysis
+        if self.llm_router and prediction.conflicts:
+            try:
+                from cortex.llm_router import TaskType
+
+                # Build context for LLM
+                conflict_descriptions = [
+                    f"- {c.conflicting_package} {c.conflicting_version}: {c.description}"
+                    for c in prediction.conflicts
+                ]
+
+                prompt = f"""Analyze these potential dependency conflicts for installing {package_name}:
+
+{chr(10).join(conflict_descriptions)}
+
+Installed packages context:
+- Pip packages: {len(self._installed_pip_cache)}
+- Apt packages: {len(self._installed_apt_cache)}
+
+Provide:
+1. Risk assessment (low/medium/high/critical)
+2. Most likely cause of conflicts
+3. Best resolution approach
+4. Any additional conflicts I might have missed
+
+Be concise and actionable."""
+
+                response = await self.llm_router.acomplete(
+                    messages=[
+                        {"role": "system", "content": "You are a Linux package dependency expert."},
+                        {"role": "user", "content": prompt},
+                    ],
+                    # Use the task type introduced for this feature in llm_router.py
+                    task_type=TaskType.CONFLICT_PREDICTION,
+                    temperature=0.3,
+                    max_tokens=1000,
+                )
+
+                # Add LLM analysis to details
+                prediction.analysis_details["llm_analysis"] = response.content
+                prediction.analysis_details["llm_provider"] = response.provider.value
+
+            except Exception as e:
+                logger.warning(f"LLM analysis failed: {e}")
+                prediction.analysis_details["llm_error"] = str(e)
+
+        return prediction
+
+    def predict_conflicts(
+        self, package_name: str, version: str | None = None
+    ) -> ConflictPrediction:
+        """
+        Predict potential conflicts before installing a package.
+
+        Args:
+            package_name: Name of package to install
+            version: Optional specific version to install
+
+        Returns:
+            ConflictPrediction with all conflicts and resolutions
+        """
+        logger.info(f"Predicting conflicts for {package_name}...")
+
+        # Refresh caches for latest state
+        self._refresh_caches()
+
+        # Detect ecosystem
+        ecosystem = self._detect_ecosystem(package_name)
+        logger.debug(f"Detected ecosystem: {ecosystem.value}")
+
+        all_conflicts: list[PredictedConflict] = []
+
+        # 1. Check known conflict patterns
+        known_conflicts = self._analyze_known_conflicts(package_name, ecosystem)
+        all_conflicts.extend(known_conflicts)
+
+        # 2. Analyze transitive dependencies
+        transitive_conflicts = self._analyze_transitive_conflicts(package_name, ecosystem)
+        all_conflicts.extend(transitive_conflicts)
+
+        # 3. Generate resolutions
+        resolutions = self._generate_resolutions(all_conflicts, package_name)
+
+        # 4. Determine overall risk
+        overall_risk = self._determine_overall_risk(all_conflicts)
+
+        # 5. Calculate can_proceed
+        can_proceed = overall_risk != ConflictSeverity.CRITICAL
+
+        # 6. Calculate prediction confidence
+        if all_conflicts:
+            avg_confidence = sum(c.confidence for c in all_conflicts) / len(all_conflicts)
+        else:
+            avg_confidence = 0.9  # High confidence that there are no conflicts
+
+        prediction = ConflictPrediction(
+            package_name=package_name,
+            package_version=version,
+            conflicts=all_conflicts,
+            resolutions=resolutions,
+            overall_risk=overall_risk,
+            can_proceed=can_proceed,
+            prediction_confidence=avg_confidence,
+            analysis_details={
+                "ecosystem": ecosystem.value,
+                "installed_pip_packages": len(self._installed_pip_cache),
+                "installed_apt_packages": len(self._installed_apt_cache),
+                "known_conflicts_checked": len(self.KNOWN_CONFLICTS),
+                "transitive_patterns_checked": len(self.TRANSITIVE_PATTERNS),
+            },
+        )
+
+        return prediction
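
Note: a prediction can be serialized for tooling via export_prediction_json(); this is the per-package shape that the `cortex predict --json` summary wraps. A minimal sketch (the package name is illustrative):

    import json

    predictor = DependencyConflictPredictor()
    prediction = predictor.predict_conflicts("nginx")
    payload = predictor.export_prediction_json(prediction)
    print(json.dumps(payload, indent=2))
    assert payload["overall_risk"] in {"low", "medium", "high", "critical"}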
+
+    def predict_multiple(self, packages: list[str]) -> list[ConflictPrediction]:
+        """
+        Predict conflicts for multiple packages.
+
+        Args:
+            packages: List of package names
+
+        Returns:
+            List of ConflictPrediction objects
+        """
+        predictions = []
+        for pkg in packages:
+            prediction = self.predict_conflicts(pkg)
+            predictions.append(prediction)
+        return predictions
+
+    def format_prediction(self, prediction: ConflictPrediction) -> str:
+        """Format prediction for CLI output"""
+        lines = []
+
+        if not prediction.conflicts:
+            lines.append(f"No conflicts predicted for {prediction.package_name}")
+            lines.append(f"  Confidence: {prediction.prediction_confidence:.0%}")
+            return "\n".join(lines)
+
+        # Header with risk indicator
+        lines.append(
+            f"[{prediction.overall_risk.value.upper()}] "
+            f"Conflict predicted for {prediction.package_name}"
+        )
+        lines.append("")
+
+        # List conflicts
+        for i, conflict in enumerate(prediction.conflicts, 1):
+            severity_badge = f"[{conflict.severity.value.upper()}]"
+            lines.append(f"  {i}. {severity_badge} {conflict.description}")
+            lines.append(
+                f"     Package: {conflict.conflicting_package} {conflict.conflicting_version}"
+            )
+            lines.append(f"     Type: {conflict.conflict_type}")
+            lines.append(f"     Confidence: {conflict.confidence:.0%}")
+            lines.append("")
+
+        # Suggestions
+        if prediction.resolutions:
+            lines.append("  Suggestions (ranked by safety):")
+            for i, res in enumerate(prediction.resolutions[:4], 1):
+                recommended = " [RECOMMENDED]" if res.recommended else ""
+                lines.append(f"  {i}. {res.description}")
+                if res.command:
+                    lines.append(f"     Command: {res.command}")
+                lines.append(f"     Safety: {res.safety_score:.0%}{recommended}")
+                if res.side_effects:
+                    lines.append(f"     Note: {res.side_effects[0]}")
+                lines.append("")
+
+        return "\n".join(lines)
+
+    def export_prediction_json(self, prediction: ConflictPrediction) -> dict[str, Any]:
+        """Export prediction to a JSON-serializable dict"""
+        return {
+            "package_name": prediction.package_name,
+            "package_version": prediction.package_version,
+            "conflicts": [
+                {
+                    "package_to_install": c.package_to_install,
+                    "conflicting_package": c.conflicting_package,
+                    "conflicting_version": c.conflicting_version,
+                    "conflict_type": c.conflict_type,
+                    "severity": c.severity.value,
+                    "confidence": c.confidence,
+                    "description": c.description,
+                    "ecosystem": c.ecosystem.value,
+                }
+                for c in prediction.conflicts
+            ],
+            "resolutions": [
+                {
+                    "strategy": r.strategy.value,
+                    "description": r.description,
+                    "command": r.command,
+                    "safety_score": r.safety_score,
+                    "side_effects": r.side_effects,
+                    "recommended": r.recommended,
+                }
+                for r in prediction.resolutions
+            ],
+            "overall_risk": prediction.overall_risk.value,
+            "can_proceed": prediction.can_proceed,
+            "prediction_confidence": prediction.prediction_confidence,
+            "analysis_details": prediction.analysis_details,
+        }
+
+
+# CLI Interface
+if __name__ == "__main__":
+    import argparse
+    import sys
+
+    parser = argparse.ArgumentParser(description="Predict dependency conflicts before installation")
+    parser.add_argument("package", help="Package name to analyze")
+    parser.add_argument("--version", "-v", help="Specific version to check")
+    parser.add_argument("--json", "-j", action="store_true", help="Output as JSON")
+    parser.add_argument("--verbose", action="store_true", help="Verbose output")
+
+    args = parser.parse_args()
+
+    if args.verbose:
+        # basicConfig() already ran at import time, so adjust the root level directly
+        logging.getLogger().setLevel(logging.DEBUG)
+
+    predictor = DependencyConflictPredictor()
+    prediction = predictor.predict_conflicts(args.package, args.version)
+
+    if args.json:
+        print(json.dumps(predictor.export_prediction_json(prediction), indent=2))
+    else:
+        print(predictor.format_prediction(prediction))
+
+    if prediction.overall_risk == ConflictSeverity.CRITICAL:
+        sys.exit(1)
+    elif prediction.overall_risk == ConflictSeverity.HIGH:
+        sys.exit(2)
diff --git a/cortex/llm_router.py b/cortex/llm_router.py
index d4bb3a21..dbfbc052 100644
--- a/cortex/llm_router.py
+++ b/cortex/llm_router.py
@@ -38,6 +38,7 @@ class TaskType(Enum):
     ERROR_DEBUGGING = "error_debugging"  # Diagnosing failures
     CODE_GENERATION = "code_generation"  # Writing scripts
     DEPENDENCY_RESOLUTION = "dependency_resolution"  # Figuring out deps
+    CONFLICT_PREDICTION = "conflict_prediction"  # AI-powered conflict prediction
     CONFIGURATION = "configuration"  # System config files
     TOOL_EXECUTION = "tool_execution"  # Running system tools
 
@@ -110,6 +111,7 @@ class LLMRouter:
         TaskType.ERROR_DEBUGGING: LLMProvider.KIMI_K2,
         TaskType.CODE_GENERATION: LLMProvider.KIMI_K2,
         TaskType.DEPENDENCY_RESOLUTION: LLMProvider.KIMI_K2,
+        TaskType.CONFLICT_PREDICTION: LLMProvider.KIMI_K2,  # Technical analysis
         TaskType.CONFIGURATION: LLMProvider.KIMI_K2,
         TaskType.TOOL_EXECUTION: LLMProvider.KIMI_K2,
     }
diff --git a/tests/test_cli.py b/tests/test_cli.py
index bed29ab4..82906cf7 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -212,7 +212,9 @@ def test_main_install_command(self, mock_install):
         mock_install.return_value = 0
         result = main()
         self.assertEqual(result, 0)
-        mock_install.assert_called_once_with("docker", execute=False, dry_run=False, parallel=False)
+        mock_install.assert_called_once_with(
+            "docker", execute=False, dry_run=False, parallel=False, predict=False, no_predict=False
+        )
 
     @patch("sys.argv", ["cortex", "install", "docker", "--execute"])
     @patch("cortex.cli.CortexCLI.install")
@@ -220,7 +222,9 @@ def test_main_install_with_execute(self, mock_install):
         mock_install.return_value = 0
         result = main()
         self.assertEqual(result, 0)
-        mock_install.assert_called_once_with("docker", execute=True, dry_run=False, parallel=False)
+        mock_install.assert_called_once_with(
+            "docker", execute=True, dry_run=False, parallel=False, predict=False, no_predict=False
+        )
 
     @patch("sys.argv", ["cortex", "install", "docker", "--dry-run"])
     @patch("cortex.cli.CortexCLI.install")
@@ -228,7 +232,49 @@ def test_main_install_with_dry_run(self, mock_install):
         mock_install.return_value = 0
         result = main()
         self.assertEqual(result, 0)
-        mock_install.assert_called_once_with("docker", execute=False, dry_run=True, parallel=False)
+        mock_install.assert_called_once_with(
+            "docker", execute=False, dry_run=True, parallel=False, predict=False, no_predict=False
+        )
+
+    @patch("sys.argv", ["cortex", "install", "tensorflow", "--predict"])
+    @patch("cortex.cli.CortexCLI.install")
+    def test_main_install_with_predict(self, mock_install):
+        """Test install command with --predict flag"""
+        mock_install.return_value = 0
+        result = main()
+        self.assertEqual(result, 0)
+        mock_install.assert_called_once_with(
+            "tensorflow", execute=False, dry_run=False, parallel=False, predict=True, no_predict=False
+        )
+
+    @patch("sys.argv", ["cortex", "install", "numpy", "--execute", "--no-predict"])
+    @patch("cortex.cli.CortexCLI.install")
+    def test_main_install_with_no_predict(self, mock_install):
+        """Test install command with --no-predict flag"""
+        mock_install.return_value = 0
+        result = main()
+        self.assertEqual(result, 0)
+        mock_install.assert_called_once_with(
+            "numpy", execute=True, dry_run=False, parallel=False, predict=False, no_predict=True
+        )
+
+    @patch("sys.argv", ["cortex", "predict", "tensorflow"])
+    @patch("cortex.cli.CortexCLI.predict")
+    def test_main_predict_command(self, mock_predict):
+        """Test standalone predict command"""
+        mock_predict.return_value = 0
+        result = main()
+        self.assertEqual(result, 0)
+        mock_predict.assert_called_once_with("tensorflow", json_output=False, verbose=False)
+
+    @patch("sys.argv", ["cortex", "predict", "numpy", "--json"])
+    @patch("cortex.cli.CortexCLI.predict")
+    def test_main_predict_with_json(self, mock_predict):
+        """Test predict command with --json flag"""
+        mock_predict.return_value = 0
+        result = main()
+        self.assertEqual(result, 0)
+        mock_predict.assert_called_once_with("numpy", json_output=True, verbose=False)
 
     def test_spinner_animation(self):
         initial_idx = self.cli.spinner_idx
diff --git a/tests/test_conflict_predictor.py b/tests/test_conflict_predictor.py
new file mode 100644
index 00000000..967c76d0
--- /dev/null
+++ b/tests/test_conflict_predictor.py
@@ -0,0 +1,640 @@
+#!/usr/bin/env python3
+"""
+Tests for AI-Powered Dependency Conflict Predictor
+
+This module tests the conflict prediction system that analyzes
+dependency graphs before installation to predict and prevent conflicts.
+"""
+
+import os
+import sys
+import unittest
+from unittest.mock import patch
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from cortex.conflict_predictor import (
+    ConflictPrediction,
+    ConflictSeverity,
+    DependencyConflictPredictor,
+    InstalledPackage,
+    PackageEcosystem,
+    PredictedConflict,
+    ResolutionStrategy,
+    ResolutionSuggestion,
+    VersionConstraint,
+)
+
+
+class TestVersionConstraint(unittest.TestCase):
+    """Test version constraint parsing"""
+
+    def setUp(self):
+        self.predictor = DependencyConflictPredictor()
+
+    def test_parse_equal_constraint(self):
+        """Test parsing == constraint"""
+        constraint = self.predictor._parse_version_constraint("==1.0.0")
+
+        self.assertIsNotNone(constraint)
+        self.assertEqual(constraint.operator, "==")
+        self.assertEqual(constraint.version, "1.0.0")
+
+    def test_parse_greater_equal_constraint(self):
+        """Test parsing >= constraint"""
+        constraint = self.predictor._parse_version_constraint(">=2.0.0")
+
+        self.assertIsNotNone(constraint)
+        self.assertEqual(constraint.operator, ">=")
+        self.assertEqual(constraint.version, "2.0.0")
+
+    def test_parse_less_than_constraint(self):
+        """Test parsing < constraint"""
+        constraint = self.predictor._parse_version_constraint("<3.0")
+
+        self.assertIsNotNone(constraint)
+        self.assertEqual(constraint.operator, "<")
+        self.assertEqual(constraint.version, "3.0")
+
+    def test_parse_compatible_release_constraint(self):
+        """Test parsing ~= constraint"""
+        constraint = self.predictor._parse_version_constraint("~=1.4.2")
+
+        self.assertIsNotNone(constraint)
+        self.assertEqual(constraint.operator, "~=")
+        self.assertEqual(constraint.version, "1.4.2")
+
+    def test_parse_empty_constraint(self):
+        """Test parsing empty constraint"""
+        constraint = self.predictor._parse_version_constraint("")
+
+        self.assertIsNone(constraint)
+
+
+class TestVersionComparison(unittest.TestCase):
+    """Test version comparison logic"""
+
+    def setUp(self):
+        self.predictor = DependencyConflictPredictor()
+
+    def test_equal_versions(self):
+        """Test equal version comparison"""
+        result = self.predictor._compare_versions("1.0.0", "1.0.0")
+        self.assertEqual(result, 0)
+
+    def test_greater_version(self):
+        """Test greater version comparison"""
+        result = self.predictor._compare_versions("2.0.0", "1.0.0")
+        self.assertEqual(result, 1)
+
+    def test_lesser_version(self):
+        """Test lesser version comparison"""
+        result = self.predictor._compare_versions("1.0.0", "2.0.0")
+        self.assertEqual(result, -1)
+
+    def test_patch_version_comparison(self):
+        """Test patch version comparison"""
+        result = self.predictor._compare_versions("1.0.1", "1.0.0")
+        self.assertEqual(result, 1)
+
+    def test_minor_version_comparison(self):
+        """Test minor version comparison"""
+        result = self.predictor._compare_versions("1.1.0", "1.0.0")
+        self.assertEqual(result, 1)
+
+    def test_different_length_versions(self):
+        """Test comparing versions with different lengths"""
+        result = self.predictor._compare_versions("1.0", "1.0.0")
+        self.assertEqual(result, 0)
+
+    def test_complex_version_strings(self):
+        """Test complex version strings with suffixes"""
+        result = self.predictor._compare_versions("2.1.0-beta", "2.0.0")
+        self.assertEqual(result, 1)
+
+
+class TestVersionConstraintSatisfaction(unittest.TestCase):
+    """Test version constraint satisfaction checking"""
+
+    def setUp(self):
+        self.predictor = DependencyConflictPredictor()
+
+    def test_equal_constraint_satisfied(self):
+        """Test == constraint satisfaction"""
+        constraint = VersionConstraint(operator="==", version="1.0.0", original="==1.0.0")
+        self.assertTrue(self.predictor._check_version_satisfies("1.0.0", constraint))
+        self.assertFalse(self.predictor._check_version_satisfies("1.0.1", constraint))
+
+    def test_greater_equal_constraint_satisfied(self):
+        """Test >= constraint satisfaction"""
+        constraint = VersionConstraint(operator=">=", version="2.0.0", original=">=2.0.0")
+        self.assertTrue(self.predictor._check_version_satisfies("2.0.0", constraint))
+        self.assertTrue(self.predictor._check_version_satisfies("3.0.0", constraint))
+        self.assertFalse(self.predictor._check_version_satisfies("1.9.9", constraint))
+
+    def test_less_than_constraint_satisfied(self):
+        """Test < constraint satisfaction"""
+        constraint = VersionConstraint(operator="<", version="2.0.0", original="<2.0.0")
+        self.assertTrue(self.predictor._check_version_satisfies("1.9.9", constraint))
+        self.assertFalse(self.predictor._check_version_satisfies("2.0.0", constraint))
+        self.assertFalse(self.predictor._check_version_satisfies("2.0.1", constraint))
+
+    def test_not_equal_constraint_satisfied(self):
+        """Test != constraint satisfaction"""
+        constraint = VersionConstraint(operator="!=", version="1.5.0", original="!=1.5.0")
+        self.assertTrue(self.predictor._check_version_satisfies("1.4.0", constraint))
+        self.assertTrue(self.predictor._check_version_satisfies("1.6.0", constraint))
+        self.assertFalse(self.predictor._check_version_satisfies("1.5.0", constraint))
self.assertIn("numpy", numpy_conflict.description.lower()) + + def test_no_conflict_when_compatible(self): + """Test no conflict when versions are compatible""" + # Set numpy to a compatible version + self.predictor._installed_pip_cache["numpy"] = InstalledPackage( + name="numpy", + version="1.24.0", + ecosystem=PackageEcosystem.PIP, + source="pip3", + ) + + conflicts = self.predictor._analyze_known_conflicts( + "tensorflow", PackageEcosystem.PIP + ) + + numpy_conflict = next( + (c for c in conflicts if c.conflicting_package == "numpy"), None + ) + self.assertIsNone(numpy_conflict) + + +class TestMutualExclusionConflicts(unittest.TestCase): + """Test detection of mutually exclusive packages""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + # Mock mariadb-server being installed + self.predictor._installed_apt_cache = { + "mariadb-server": InstalledPackage( + name="mariadb-server", + version="10.6.12", + ecosystem=PackageEcosystem.APT, + source="dpkg", + ), + } + + def test_mysql_mariadb_conflict(self): + """Test MySQL/MariaDB mutual exclusion conflict""" + conflicts = self.predictor._analyze_known_conflicts( + "mysql-server", PackageEcosystem.APT + ) + + self.assertGreater(len(conflicts), 0) + mariadb_conflict = next( + (c for c in conflicts if c.conflicting_package == "mariadb-server"), None + ) + self.assertIsNotNone(mariadb_conflict) + self.assertEqual(mariadb_conflict.conflict_type, "mutual_exclusion") + self.assertEqual(mariadb_conflict.severity, ConflictSeverity.CRITICAL) + + +class TestConflictPrediction(unittest.TestCase): + """Test the main prediction functionality""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + + @patch.object(DependencyConflictPredictor, "_run_command") + def test_predict_no_conflicts(self, mock_run): + """Test prediction when no conflicts exist""" + # Mock empty caches + mock_run.return_value = (True, "", "") + self.predictor._installed_pip_cache = {} + self.predictor._installed_apt_cache = {} + + prediction = self.predictor.predict_conflicts("flask") + + self.assertIsInstance(prediction, ConflictPrediction) + self.assertEqual(len(prediction.conflicts), 0) + self.assertEqual(prediction.overall_risk, ConflictSeverity.LOW) + self.assertTrue(prediction.can_proceed) + + def test_predict_with_conflicts(self): + """Test prediction when conflicts exist""" + # Set up conflicting state + self.predictor._installed_pip_cache = { + "numpy": InstalledPackage( + name="numpy", + version="2.1.0", + ecosystem=PackageEcosystem.PIP, + source="pip3", + ), + } + + prediction = self.predictor.predict_conflicts("tensorflow") + + self.assertIsInstance(prediction, ConflictPrediction) + self.assertGreater(len(prediction.conflicts), 0) + self.assertIn( + prediction.overall_risk, + [ConflictSeverity.HIGH, ConflictSeverity.CRITICAL], + ) + + def test_predict_multiple_packages(self): + """Test predicting conflicts for multiple packages""" + self.predictor._installed_pip_cache = {} + self.predictor._installed_apt_cache = {} + + predictions = self.predictor.predict_multiple(["flask", "django", "numpy"]) + + self.assertEqual(len(predictions), 3) + for pred in predictions: + self.assertIsInstance(pred, ConflictPrediction) + + +class TestResolutionSuggestions(unittest.TestCase): + """Test resolution suggestion generation""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + + def test_virtualenv_suggestion_for_pip(self): + """Test that virtualenv is suggested for pip conflicts""" + conflicts = [ + PredictedConflict( + 
package_to_install="tensorflow", + package_version=None, + conflicting_package="numpy", + conflicting_version="2.1.0", + installed_by="pandas", + conflict_type="version_too_high", + severity=ConflictSeverity.HIGH, + confidence=0.9, + description="TensorFlow requires numpy<2.0", + ecosystem=PackageEcosystem.PIP, + ) + ] + + resolutions = self.predictor._generate_resolutions(conflicts, "tensorflow") + + self.assertGreater(len(resolutions), 0) + + # Check for virtualenv suggestion + venv_suggestions = [ + r for r in resolutions if r.strategy == ResolutionStrategy.USE_VIRTUALENV + ] + self.assertGreater(len(venv_suggestions), 0) + self.assertTrue(venv_suggestions[0].recommended) + + def test_resolution_safety_ranking(self): + """Test that resolutions are ranked by safety""" + conflicts = [ + PredictedConflict( + package_to_install="tensorflow", + package_version=None, + conflicting_package="numpy", + conflicting_version="2.1.0", + installed_by="pandas", + conflict_type="version_too_high", + severity=ConflictSeverity.HIGH, + confidence=0.9, + description="Test conflict", + ecosystem=PackageEcosystem.PIP, + ) + ] + + resolutions = self.predictor._generate_resolutions(conflicts, "tensorflow") + + # Check that resolutions are sorted by safety score (descending) + for i in range(len(resolutions) - 1): + self.assertGreaterEqual( + resolutions[i].safety_score, resolutions[i + 1].safety_score + ) + + def test_mutual_exclusion_resolution(self): + """Test resolution suggestions for mutual exclusion conflicts""" + conflicts = [ + PredictedConflict( + package_to_install="mysql-server", + package_version=None, + conflicting_package="mariadb-server", + conflicting_version="10.6.12", + installed_by="system", + conflict_type="mutual_exclusion", + severity=ConflictSeverity.CRITICAL, + confidence=0.95, + description="Cannot have both MySQL and MariaDB", + ecosystem=PackageEcosystem.APT, + ) + ] + + resolutions = self.predictor._generate_resolutions(conflicts, "mysql-server") + + # Should have remove and skip options + strategies = [r.strategy for r in resolutions] + self.assertIn(ResolutionStrategy.REMOVE_CONFLICTING, strategies) + self.assertIn(ResolutionStrategy.SKIP_INSTALL, strategies) + + +class TestRiskAssessment(unittest.TestCase): + """Test overall risk level determination""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + + def test_no_conflicts_low_risk(self): + """Test that no conflicts means low risk""" + risk = self.predictor._determine_overall_risk([]) + self.assertEqual(risk, ConflictSeverity.LOW) + + def test_critical_conflict_critical_risk(self): + """Test that critical conflicts result in critical risk""" + conflicts = [ + PredictedConflict( + package_to_install="test", + package_version=None, + conflicting_package="other", + conflicting_version="1.0", + installed_by="system", + conflict_type="test", + severity=ConflictSeverity.CRITICAL, + confidence=0.9, + description="Test", + ecosystem=PackageEcosystem.APT, + ) + ] + + risk = self.predictor._determine_overall_risk(conflicts) + self.assertEqual(risk, ConflictSeverity.CRITICAL) + + def test_mixed_severity_highest_wins(self): + """Test that highest severity determines overall risk""" + conflicts = [ + PredictedConflict( + package_to_install="test", + package_version=None, + conflicting_package="pkg1", + conflicting_version="1.0", + installed_by="system", + conflict_type="test", + severity=ConflictSeverity.LOW, + confidence=0.9, + description="Test", + ecosystem=PackageEcosystem.APT, + ), + PredictedConflict( + 
package_to_install="test", + package_version=None, + conflicting_package="pkg2", + conflicting_version="1.0", + installed_by="system", + conflict_type="test", + severity=ConflictSeverity.HIGH, + confidence=0.9, + description="Test", + ecosystem=PackageEcosystem.APT, + ), + ] + + risk = self.predictor._determine_overall_risk(conflicts) + self.assertEqual(risk, ConflictSeverity.HIGH) + + +class TestEcosystemDetection(unittest.TestCase): + """Test package ecosystem detection""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + + def test_detect_pip_package(self): + """Test detection of pip packages""" + ecosystem = self.predictor._detect_ecosystem("numpy") + self.assertEqual(ecosystem, PackageEcosystem.PIP) + + def test_detect_apt_package(self): + """Test detection of apt packages""" + ecosystem = self.predictor._detect_ecosystem("nginx") + self.assertEqual(ecosystem, PackageEcosystem.APT) + + +class TestOutputFormatting(unittest.TestCase): + """Test output formatting functions""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + + def test_format_no_conflicts(self): + """Test formatting when no conflicts""" + prediction = ConflictPrediction( + package_name="flask", + package_version=None, + conflicts=[], + resolutions=[], + overall_risk=ConflictSeverity.LOW, + can_proceed=True, + prediction_confidence=0.9, + ) + + output = self.predictor.format_prediction(prediction) + + self.assertIn("No conflicts predicted", output) + self.assertIn("flask", output) + + def test_format_with_conflicts(self): + """Test formatting with conflicts""" + prediction = ConflictPrediction( + package_name="tensorflow", + package_version=None, + conflicts=[ + PredictedConflict( + package_to_install="tensorflow", + package_version=None, + conflicting_package="numpy", + conflicting_version="2.1.0", + installed_by="pandas", + conflict_type="version_too_high", + severity=ConflictSeverity.HIGH, + confidence=0.9, + description="TensorFlow requires numpy<2.0", + ecosystem=PackageEcosystem.PIP, + ) + ], + resolutions=[ + ResolutionSuggestion( + strategy=ResolutionStrategy.USE_VIRTUALENV, + description="Use virtual environment", + command="python3 -m venv .venv", + safety_score=0.95, + side_effects=[], + recommended=True, + ) + ], + overall_risk=ConflictSeverity.HIGH, + can_proceed=True, + prediction_confidence=0.85, + ) + + output = self.predictor.format_prediction(prediction) + + self.assertIn("tensorflow", output) + self.assertIn("numpy", output) + self.assertIn("Suggestions", output) + + def test_export_json(self): + """Test JSON export of predictions""" + prediction = ConflictPrediction( + package_name="tensorflow", + package_version="2.15.0", + conflicts=[], + resolutions=[], + overall_risk=ConflictSeverity.LOW, + can_proceed=True, + prediction_confidence=0.9, + analysis_details={"test": "value"}, + ) + + exported = self.predictor.export_prediction_json(prediction) + + self.assertIsInstance(exported, dict) + self.assertEqual(exported["package_name"], "tensorflow") + self.assertEqual(exported["overall_risk"], "low") + self.assertTrue(exported["can_proceed"]) + self.assertEqual(exported["analysis_details"]["test"], "value") + + +class TestTransitiveConflicts(unittest.TestCase): + """Test transitive dependency conflict detection""" + + def setUp(self): + self.predictor = DependencyConflictPredictor() + + def test_transitive_tensorflow_conflicts(self): + """Test detection of TensorFlow's transitive dependency conflicts""" + # Set up a scenario where numpy is installed and could 
+
+
+class TestTransitiveConflicts(unittest.TestCase):
+    """Test transitive dependency conflict detection"""
+
+    def setUp(self):
+        self.predictor = DependencyConflictPredictor()
+
+    def test_transitive_tensorflow_conflicts(self):
+        """Test detection of TensorFlow's transitive dependency conflicts"""
+        # Set up a scenario where numpy is installed and could conflict
+        self.predictor._installed_pip_cache = {
+            "numpy": InstalledPackage(
+                name="numpy",
+                version="1.24.0",
+                ecosystem=PackageEcosystem.PIP,
+                source="pip3",
+            ),
+            "protobuf": InstalledPackage(
+                name="protobuf",
+                version="3.20.0",
+                ecosystem=PackageEcosystem.PIP,
+                source="pip3",
+            ),
+        }
+
+        conflicts = self.predictor._analyze_transitive_conflicts(
+            "tensorflow", PackageEcosystem.PIP
+        )
+
+        # Should detect potential transitive conflicts
+        self.assertIsInstance(conflicts, list)
+
+
+class TestDataclasses(unittest.TestCase):
+    """Test dataclass structures"""
+
+    def test_installed_package_creation(self):
+        """Test InstalledPackage dataclass"""
+        pkg = InstalledPackage(
+            name="test-pkg",
+            version="1.0.0",
+            ecosystem=PackageEcosystem.PIP,
+            source="pip3",
+        )
+
+        self.assertEqual(pkg.name, "test-pkg")
+        self.assertEqual(pkg.version, "1.0.0")
+        self.assertEqual(pkg.ecosystem, PackageEcosystem.PIP)
+
+    def test_predicted_conflict_creation(self):
+        """Test PredictedConflict dataclass"""
+        conflict = PredictedConflict(
+            package_to_install="tensorflow",
+            package_version="2.15.0",
+            conflicting_package="numpy",
+            conflicting_version="2.1.0",
+            installed_by="pandas",
+            conflict_type="version_too_high",
+            severity=ConflictSeverity.HIGH,
+            confidence=0.9,
+            description="Test conflict",
+            ecosystem=PackageEcosystem.PIP,
+        )
+
+        self.assertEqual(conflict.package_to_install, "tensorflow")
+        self.assertEqual(conflict.severity, ConflictSeverity.HIGH)
+        self.assertEqual(conflict.confidence, 0.9)
+
+    def test_resolution_suggestion_creation(self):
+        """Test ResolutionSuggestion dataclass"""
+        suggestion = ResolutionSuggestion(
+            strategy=ResolutionStrategy.USE_VIRTUALENV,
+            description="Create virtual environment",
+            command="python3 -m venv .venv",
+            safety_score=0.95,
+            side_effects=["Isolates packages"],
+            recommended=True,
+        )
+
+        self.assertEqual(suggestion.strategy, ResolutionStrategy.USE_VIRTUALENV)
+        self.assertEqual(suggestion.safety_score, 0.95)
+        self.assertTrue(suggestion.recommended)
+
+
+class TestEnums(unittest.TestCase):
+    """Test enum values"""
+
+    def test_conflict_severity_values(self):
+        """Test ConflictSeverity enum values"""
+        self.assertEqual(ConflictSeverity.LOW.value, "low")
+        self.assertEqual(ConflictSeverity.MEDIUM.value, "medium")
+        self.assertEqual(ConflictSeverity.HIGH.value, "high")
+        self.assertEqual(ConflictSeverity.CRITICAL.value, "critical")
+
+    def test_resolution_strategy_values(self):
+        """Test ResolutionStrategy enum values"""
+        self.assertEqual(ResolutionStrategy.UPGRADE_PACKAGE.value, "upgrade_package")
+        self.assertEqual(ResolutionStrategy.USE_VIRTUALENV.value, "use_virtualenv")
+        self.assertEqual(ResolutionStrategy.SKIP_INSTALL.value, "skip_install")
+
+    def test_package_ecosystem_values(self):
+        """Test PackageEcosystem enum values"""
+        self.assertEqual(PackageEcosystem.APT.value, "apt")
+        self.assertEqual(PackageEcosystem.PIP.value, "pip")
+        self.assertEqual(PackageEcosystem.NPM.value, "npm")
+
+
+if __name__ == "__main__":
+    unittest.main()
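
Note on exit codes: the standalone module entry point (e.g. `python cortex/conflict_predictor.py <package>`) exits 1 on a CRITICAL prediction and 2 on HIGH, while `cortex predict` returns 1 only when critical conflicts are found and 0 otherwise (even with non-critical conflicts). Scripts driving either interface should key off the `overall_risk` field in the JSON output rather than the exit code alone.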