From 9ef3a4663c0a41609e665b35f2cdccec7ba358d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=82=B9=E6=B0=B8=E8=B5=AB?= <1259085392@qq.com> Date: Tue, 28 Apr 2026 20:37:25 +0900 Subject: [PATCH] fix: protect desktop plugin installs with core lock --- astrbot/core/utils/core_constraints.py | 10 ++- astrbot/core/utils/desktop_core_lock.py | 108 ++++++++++++++++++++++++ astrbot/core/utils/pip_installer.py | 7 ++ tests/test_pip_installer.py | 96 +++++++++++++++++++++ 4 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 astrbot/core/utils/desktop_core_lock.py diff --git a/astrbot/core/utils/core_constraints.py b/astrbot/core/utils/core_constraints.py index b43f001227..ed353e738b 100644 --- a/astrbot/core/utils/core_constraints.py +++ b/astrbot/core/utils/core_constraints.py @@ -7,6 +7,7 @@ from packaging.requirements import Requirement +from astrbot.core.utils.desktop_core_lock import get_desktop_core_lock_constraints from astrbot.core.utils.requirements_utils import ( canonicalize_distribution_name, collect_installed_distribution_versions, @@ -93,7 +94,14 @@ def __init__(self, core_dist_name: str | None) -> None: @contextlib.contextmanager def constraints_file(self) -> Iterator[str | None]: - constraints = _get_core_constraints(self._core_dist_name) + constraints = tuple( + dict.fromkeys( + ( + *_get_core_constraints(self._core_dist_name), + *get_desktop_core_lock_constraints(), + ) + ) + ) if not constraints: yield None return diff --git a/astrbot/core/utils/desktop_core_lock.py b/astrbot/core/utils/desktop_core_lock.py new file mode 100644 index 0000000000..933c9fb307 --- /dev/null +++ b/astrbot/core/utils/desktop_core_lock.py @@ -0,0 +1,108 @@ +import json +import logging +import os +import re +from functools import lru_cache +from typing import Any + +from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime + +logger = logging.getLogger("astrbot") + +DESKTOP_CORE_LOCK_PATH_ENV = "ASTRBOT_DESKTOP_CORE_LOCK_PATH" + + +def _canonicalize_distribution_name(name: str) -> str: + return re.sub(r"[-_.]+", "-", name).strip("-").lower() + + +def _safe_requirement_pin(name: str, version: str) -> str | None: + if not name or not version: + return None + if any(char.isspace() for char in name) or any(char.isspace() for char in version): + return None + return f"{name}=={version}" + + +def _fallback_module_name(name: str) -> str: + return _canonicalize_distribution_name(name).replace("-", "_") + + +def _iter_distribution_records(data: Any): + if not isinstance(data, dict): + return + distributions = data.get("distributions", []) + if not isinstance(distributions, list): + return + for record in distributions: + if isinstance(record, dict): + yield record + + +@lru_cache(maxsize=8) +def _load_lock_data(lock_path: str) -> dict[str, Any] | None: + try: + with open(lock_path, encoding="utf-8") as file: + data = json.load(file) + except FileNotFoundError: + logger.warning("桌面端核心依赖锁不存在: %s", lock_path) + return None + except Exception as exc: + logger.warning("读取桌面端核心依赖锁失败: %s", exc) + return None + + if not isinstance(data, dict): + logger.warning("桌面端核心依赖锁格式无效: %s", lock_path) + return None + return data + + +def _resolve_lock_data() -> dict[str, Any] | None: + if not is_packaged_desktop_runtime(): + return None + + lock_path = os.environ.get(DESKTOP_CORE_LOCK_PATH_ENV, "").strip() + if not lock_path: + return None + return _load_lock_data(lock_path) + + +def get_desktop_core_lock_constraints() -> tuple[str, ...]: + data = _resolve_lock_data() + if not data: + return () + + constraints: dict[str, str] = {} + for record in _iter_distribution_records(data): + name = record.get("name") + version = record.get("version") + if not isinstance(name, str) or not isinstance(version, str): + continue + + pin = _safe_requirement_pin(name, version) + if not pin: + continue + constraints.setdefault(_canonicalize_distribution_name(name), pin) + + return tuple(constraints[key] for key in sorted(constraints)) + + +def get_desktop_core_lock_modules() -> frozenset[str]: + data = _resolve_lock_data() + if not data: + return frozenset() + + modules: set[str] = set() + for record in _iter_distribution_records(data): + name = record.get("name") + top_level_modules = record.get("top_level_modules", []) + if isinstance(top_level_modules, list): + for module_name in top_level_modules: + if isinstance(module_name, str) and module_name: + modules.add(module_name.split(".", 1)[0]) + if isinstance(name, str): + fallback = _fallback_module_name(name) + if fallback: + modules.add(fallback) + + return frozenset(modules) diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index e5f7138209..fbc1b5a7ec 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -18,6 +18,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path from astrbot.core.utils.core_constraints import CoreConstraintsProvider +from astrbot.core.utils.desktop_core_lock import get_desktop_core_lock_modules from astrbot.core.utils.requirements_utils import ( canonicalize_distribution_name as _canonicalize_distribution_name, ) @@ -811,6 +812,12 @@ def _ensure_plugin_dependencies_preferred( if not candidate_modules: return + locked_modules = get_desktop_core_lock_modules() + if locked_modules: + candidate_modules = candidate_modules.difference(locked_modules) + if not candidate_modules: + return + _ensure_preferred_modules(candidate_modules, target_site_packages) diff --git a/tests/test_pip_installer.py b/tests/test_pip_installer.py index bfddf60e1c..266d9195b0 100644 --- a/tests/test_pip_installer.py +++ b/tests/test_pip_installer.py @@ -1,6 +1,8 @@ import asyncio +import json import ntpath import threading +from pathlib import Path from unittest.mock import AsyncMock import pytest @@ -1061,6 +1063,100 @@ def test_core_constraints_file_propagates_inner_conflict_without_fake_warning( assert warning_logs == [] +@pytest.mark.asyncio +async def test_install_adds_desktop_core_lock_constraints_for_packaged_runtime( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + monkeypatch.delattr("sys.frozen", raising=False) + + lock_path = tmp_path / "runtime-core-lock.json" + lock_path.write_text( + json.dumps( + { + "version": 1, + "distributions": [ + { + "name": "desktop-only-core", + "version": "9.9.9", + "top_level_modules": ["desktop_only_core"], + } + ], + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("ASTRBOT_DESKTOP_CORE_LOCK_PATH", str(lock_path)) + + site_packages_path = tmp_path / "site-packages" + captured_constraints = [] + + async def capture_pip_args(self, args): + del self + constraints_path = args[args.index("-c") + 1] + captured_constraints.append(Path(constraints_path).read_text(encoding="utf-8")) + return 0 + + monkeypatch.setattr(PipInstaller, "_run_pip_in_process", capture_pip_args) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer.get_astrbot_site_packages_path", + lambda: str(site_packages_path), + ) + monkeypatch.setattr( + "astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred", + lambda path, requirements: None, + ) + + installer = PipInstaller("") + await installer.install(package_name="Cua") + + assert captured_constraints + assert "desktop-only-core==9.9.9" in captured_constraints[0] + + +def test_ensure_plugin_dependencies_preferred_skips_desktop_core_lock_modules( + monkeypatch, tmp_path +): + monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1") + lock_path = tmp_path / "runtime-core-lock.json" + lock_path.write_text( + json.dumps( + { + "version": 1, + "distributions": [ + { + "name": "openai", + "version": "2.32.0", + "top_level_modules": ["openai"], + } + ], + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("ASTRBOT_DESKTOP_CORE_LOCK_PATH", str(lock_path)) + + preferred_calls = [] + + monkeypatch.setattr( + pip_installer_module, + "_collect_candidate_modules", + lambda requirements, site_packages_path: {"openai", "cua_agent"}, + ) + monkeypatch.setattr( + pip_installer_module, + "_ensure_preferred_modules", + lambda modules, site_packages_path: preferred_calls.append(modules), + ) + + pip_installer_module._ensure_plugin_dependencies_preferred( + str(tmp_path / "site-packages"), + {"Cua"}, + ) + + assert preferred_calls == [{"cua_agent"}] + + def test_iter_requirement_lines_expands_nested_requirement_files(tmp_path): base_requirements = tmp_path / "base.txt" base_requirements.write_text("demo-package==1.0\n", encoding="utf-8")