From 61527042f8bbe08c5a1c0a28a7cefca813dfb3c6 Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 04:00:45 +0800 Subject: [PATCH 01/11] =?UTF-8?q?:sparkles:=20=E5=85=81=E8=AE=B8=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E4=BB=8E=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E4=B8=AD?= =?UTF-8?q?=E8=AF=BB=E5=8F=96=E9=85=8D=E7=BD=AE=E9=A1=B9=E8=80=8C=E4=B8=8D?= =?UTF-8?q?=E9=9C=80=E8=A6=81=E5=9C=A8envfile=E4=B8=AD=E5=A3=B0=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/config.py | 68 +++++++++++++++++++++++++++++++---- nonebot/plugin/__init__.py | 9 +++-- tests/test_plugin/test_get.py | 31 +++++++++++++++- 3 files changed, 99 insertions(+), 9 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index 5a1fa13238d1..50546d88f083 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -14,14 +14,15 @@ """ import abc -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from datetime import timedelta +from functools import lru_cache from ipaddress import IPv4Address import json import os from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union -from typing_extensions import TypeAlias, get_args, get_origin +from typing_extensions import TypeAlias, get_args, get_origin, override from dotenv import dotenv_values from pydantic import BaseModel, Field @@ -46,6 +47,8 @@ ENV_FILE_SENTINEL = Path("") +cachable_dotenv_values = lru_cache(maxsize=20)(dotenv_values) + class SettingsError(ValueError): ... @@ -79,7 +82,7 @@ def __repr__(self) -> str: return f"InitSettingsSource(init_kwargs={self.init_kwargs!r})" -class DotEnvSettingsSource(BaseSettingsSource): +class EnvSettingsSource(BaseSettingsSource): def __init__( self, settings_cls: type["BaseSettings"], @@ -110,6 +113,11 @@ def __init__( else self.config.get("env_nested_delimiter", None) ) + @abc.abstractmethod + def get_setting_fields(self) -> Iterable[ModelField]: + """获取配置类的字段信息""" + raise NotImplementedError + def _apply_case_sensitive(self, var_name: str) -> str: return var_name if self.case_sensitive else var_name.lower() @@ -130,7 +138,8 @@ def _parse_env_vars( } def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]: - file_vars = dotenv_values(file_path, encoding=self.env_file_encoding) + file_vars = cachable_dotenv_values(file_path, encoding=self.env_file_encoding) + logger.trace(f"Loaded env file '{file_path}': {file_vars}") return self._parse_env_vars(file_vars) def _read_env_files(self) -> dict[str, Optional[str]]: @@ -209,8 +218,8 @@ def __call__(self) -> dict[str, Any]: env_file_vars = self._read_env_files() env_vars = {**env_file_vars, **env_vars} - for field in model_fields(self.settings_cls): - field_name = field.name + for field in self.get_setting_fields(): + field_name = self._parse_field_name(field) env_name = self._apply_case_sensitive(field_name) # try get values from env vars @@ -283,6 +292,53 @@ def __call__(self) -> dict[str, Any]: return d + def _parse_field_name(self, field: ModelField) -> str: + return field.field_info.alias or field.name + + +class DotEnvSettingsSource(EnvSettingsSource): + def __init__( + self, + settings_cls: type["BaseSettings"], + env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, + env_file_encoding: Optional[str] = None, + case_sensitive: Optional[bool] = None, + env_nested_delimiter: Optional[str] = None, + ) -> None: + super().__init__( + settings_cls, + env_file, + env_file_encoding, + case_sensitive, + env_nested_delimiter, + ) + + @override + def get_setting_fields(self) -> Iterable[ModelField]: + return model_fields(self.settings_cls) + + +class PluginEnvSettingsSource(EnvSettingsSource): + def __init__( + self, + config_cls: type[BaseModel], + driver_config: "Config", + env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, + ) -> None: + setting_config: "SettingsConfig" = model_config(driver_config.__class__) + super().__init__( + BaseSettings, + env_file=env_file, + env_file_encoding=setting_config.get("env_file_encoding", "utf-8"), + case_sensitive=setting_config.get("case_sensitive", False), + env_nested_delimiter=setting_config.get("env_nested_delimiter", None), + ) + self.config_cls = config_cls + + @override + def get_setting_fields(self) -> Iterable[ModelField]: + return model_fields(self.config_cls) + if PYDANTIC_V2: # pragma: pydantic-v2 diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index faad08bbd61a..2cb295ec8b4b 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -46,7 +46,8 @@ from pydantic import BaseModel from nonebot import get_driver -from nonebot.compat import model_dump, type_validate_python +from nonebot.compat import type_validate_python +from nonebot.config import PluginEnvSettingsSource C = TypeVar("C", bound=BaseModel) @@ -172,7 +173,11 @@ def get_available_plugin_names() -> set[str]: def get_plugin_config(config: type[C]) -> C: """从全局配置获取当前插件需要的配置项。""" - return type_validate_python(config, model_dump(get_driver().config)) + driver = get_driver() + env_settings = PluginEnvSettingsSource( + config, driver.config, env_file=(".env", f".env.{driver.env}") + ) + return type_validate_python(config, env_settings()) from .load import inherit_supported_adapters as inherit_supported_adapters diff --git a/tests/test_plugin/test_get.py b/tests/test_plugin/test_get.py index 5b3def6a5813..bfa6c6292bbb 100644 --- a/tests/test_plugin/test_get.py +++ b/tests/test_plugin/test_get.py @@ -1,6 +1,8 @@ -from pydantic import BaseModel +from pydantic import BaseModel, Field +import pytest import nonebot +from nonebot.compat import model_dump from nonebot.plugin import PluginManager, _managers @@ -67,3 +69,30 @@ class Config(BaseModel): config = nonebot.get_plugin_config(Config) assert isinstance(config, Config) assert config.plugin_config == 1 + + +def test_plugin_load_env_config(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("TEST_CONFIG_ONE", "no_dummy_val") + monkeypatch.setenv("TEST_CONFIG__TWO", "two") + monkeypatch.setenv("TEST_CFG_THREE", "33") + + class CfgTwo(BaseModel): + two: str = "dummy_val" + + class Config(BaseModel): + test_config_one: str = "dummy_val" + test_config: CfgTwo = Field(default_factory=CfgTwo) + test_config_three: int = Field(alias="TEST_CFG_THREE", default=3) + + global_config = nonebot.get_driver().config + assert "test_config_one" not in model_dump(global_config) + assert "TEST_CONFIG_ONE" not in model_dump(global_config) + assert "test_config" not in model_dump(global_config) + assert "TEST_CONFIG" not in model_dump(global_config) + assert "test_config_three" not in model_dump(global_config) + assert "TEST_CFG_THREE" not in model_dump(global_config) + + config = nonebot.get_plugin_config(Config) + assert config.test_config_one == "no_dummy_val" + assert config.test_config.two == "two" + assert config.test_config_three == 33 From e5b127d9696fe40734a4f2779aafb28104892516 Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 15:40:35 +0800 Subject: [PATCH 02/11] =?UTF-8?q?:bug:=20PluginEnvSettingSource=20?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E7=88=B6=E7=B1=BB=E6=97=B6=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=80=BC=E5=BA=94=E8=AF=A5=E4=B8=BA=20None?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index 50546d88f083..366271896854 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -329,8 +329,8 @@ def __init__( super().__init__( BaseSettings, env_file=env_file, - env_file_encoding=setting_config.get("env_file_encoding", "utf-8"), - case_sensitive=setting_config.get("case_sensitive", False), + env_file_encoding=setting_config.get("env_file_encoding", None), + case_sensitive=setting_config.get("case_sensitive", None), env_nested_delimiter=setting_config.get("env_nested_delimiter", None), ) self.config_cls = config_cls From cec5ca7d4eafdbcd9b7a0de62fd387b0f4de106a Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 15:42:50 +0800 Subject: [PATCH 03/11] =?UTF-8?q?:bug:=20DotEnvSettingsSource=20=E4=B8=8D?= =?UTF-8?q?=E9=9C=80=E8=A6=81=E8=A6=86=E5=86=99=20`=5F=5Finit=5F=5F`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/config.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index 366271896854..900e1d772fed 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -297,22 +297,6 @@ def _parse_field_name(self, field: ModelField) -> str: class DotEnvSettingsSource(EnvSettingsSource): - def __init__( - self, - settings_cls: type["BaseSettings"], - env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, - env_file_encoding: Optional[str] = None, - case_sensitive: Optional[bool] = None, - env_nested_delimiter: Optional[str] = None, - ) -> None: - super().__init__( - settings_cls, - env_file, - env_file_encoding, - case_sensitive, - env_nested_delimiter, - ) - @override def get_setting_fields(self) -> Iterable[ModelField]: return model_fields(self.settings_cls) From d3fc5b0e497f0ab2b010f1560b4e91bce6007fb4 Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 16:50:10 +0800 Subject: [PATCH 04/11] =?UTF-8?q?:bug:=20=E7=9B=B4=E6=8E=A5=E4=BC=A0?= =?UTF-8?q?=E5=85=A5=E7=9A=84=E9=85=8D=E7=BD=AE=E9=A1=B9=E4=BC=98=E5=85=88?= =?UTF-8?q?=E4=BA=8E=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E7=94=9F=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/plugin/__init__.py | 5 +++-- tests/test_plugin/test_get.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index 2cb295ec8b4b..401e1a5ae217 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -46,7 +46,7 @@ from pydantic import BaseModel from nonebot import get_driver -from nonebot.compat import type_validate_python +from nonebot.compat import model_dump, type_validate_python from nonebot.config import PluginEnvSettingsSource C = TypeVar("C", bound=BaseModel) @@ -177,7 +177,8 @@ def get_plugin_config(config: type[C]) -> C: env_settings = PluginEnvSettingsSource( config, driver.config, env_file=(".env", f".env.{driver.env}") ) - return type_validate_python(config, env_settings()) + config_vars = {**env_settings(), **model_dump(driver.config)} + return type_validate_python(config, config_vars) from .load import inherit_supported_adapters as inherit_supported_adapters diff --git a/tests/test_plugin/test_get.py b/tests/test_plugin/test_get.py index bfa6c6292bbb..28906346ab60 100644 --- a/tests/test_plugin/test_get.py +++ b/tests/test_plugin/test_get.py @@ -75,6 +75,7 @@ def test_plugin_load_env_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("TEST_CONFIG_ONE", "no_dummy_val") monkeypatch.setenv("TEST_CONFIG__TWO", "two") monkeypatch.setenv("TEST_CFG_THREE", "33") + monkeypatch.setenv("CONFIG_FROM_INIT", "impossible") class CfgTwo(BaseModel): two: str = "dummy_val" @@ -83,6 +84,7 @@ class Config(BaseModel): test_config_one: str = "dummy_val" test_config: CfgTwo = Field(default_factory=CfgTwo) test_config_three: int = Field(alias="TEST_CFG_THREE", default=3) + config_from_init: str = "dummy_val" global_config = nonebot.get_driver().config assert "test_config_one" not in model_dump(global_config) @@ -96,3 +98,4 @@ class Config(BaseModel): assert config.test_config_one == "no_dummy_val" assert config.test_config.two == "two" assert config.test_config_three == 33 + assert config.config_from_init == "init" From 1093050ec4ec795ea08d313b6c4a294e7bbaca4d Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 16:50:49 +0800 Subject: [PATCH 05/11] =?UTF-8?q?:memo:=20=E6=9B=B4=E6=96=B0=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E9=85=8D=E7=BD=AE=E7=9B=B8=E5=85=B3=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- website/docs/appendices/config.mdx | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/website/docs/appendices/config.mdx b/website/docs/appendices/config.mdx index 482b2acbf133..0d054ce19431 100644 --- a/website/docs/appendices/config.mdx +++ b/website/docs/appendices/config.mdx @@ -84,7 +84,7 @@ export CUSTOM_CONFIG='config in environment variables' 那最终 NoneBot 所读取的内容为环境变量中的内容,即 `config in environment variables`。 :::caution 注意 -NoneBot 不会自发读取未被定义的配置项的环境变量,如果需要读取某一环境变量需要在 dotenv 配置文件中进行声明。 +如果一个环境变量既不是 NoneBot 的[**内置配置项**](#内置配置项),也不是任何插件所定义的[**插件配置**](#插件配置),那么 NoneBot 不会自发读取该环境变量,需要在 dotenv 配置文件中先行声明。 ::: ### dotenv 配置文件 @@ -242,11 +242,17 @@ weather = on_command( 这种方式可以简洁、高效地读取配置项,同时也可以设置默认值或者在运行时对配置项进行合法性检查,防止由于配置项导致的插件出错等情况出现。 -:::tip 提示 +:::tip 可配置的事件响应优先级 发布插件应该为自身的事件响应器提供可配置的优先级,以便插件使用者可以自定义多个插件间的响应顺序。 ::: -由于插件配置项是从全局配置中读取的,通常我们需要在配置项名称前面添加前缀名,以防止配置项冲突。例如在上方的示例中,我们就添加了配置项前缀 `weather_`。但是这样会导致在使用配置项时过长的变量名,因此我们可以使用 `pydantic` 的 `alias` 或者通过配置 scope 来简化配置项名称。这里我们以 scope 配置为例: +:::tip 插件配置获取逻辑 +无论是否在 dotenv 文件中声明了插件配置项,使用 `get_plugin_config` 获取插件配置模型中定义的配置项时都遵循[**配置项的加载**](#配置项的加载)一节中的优先级顺序进行读取。 +::: + +### 配置 scope + +由于插件配置项是从全局配置和环境变量中读取的,通常我们需要在配置项名称前面添加前缀名,以防止配置项冲突。例如在上方的示例中,我们就添加了配置项前缀 `weather_`。但是这样会导致在使用配置项时过长的变量名,因此我们可以使用 `pydantic` 的 `alias` 或者通过配置 scope 来简化配置项名称。这里我们以 scope 配置为例: ```python title=weather/config.py from pydantic import BaseModel From 3774024d5b8752d1dafb343a359410b70b8820b8 Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 21:56:19 +0800 Subject: [PATCH 06/11] =?UTF-8?q?:recycle:=20Plugin=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E6=97=B6=E4=B8=8D=E5=86=8D?= =?UTF-8?q?=E8=AF=BB=E5=8F=96envfile?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/config.py | 169 ++++++++++++++++++++++++++----------- nonebot/plugin/__init__.py | 11 ++- 2 files changed, 123 insertions(+), 57 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index 900e1d772fed..def15455aee2 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -14,9 +14,10 @@ """ import abc +from collections import ChainMap from collections.abc import Iterable, Mapping from datetime import timedelta -from functools import lru_cache +from functools import cache from ipaddress import IPv4Address import json import os @@ -35,6 +36,7 @@ PydanticUndefined, PydanticUndefinedType, model_config, + model_dump, model_fields, ) from nonebot.log import logger @@ -47,8 +49,6 @@ ENV_FILE_SENTINEL = Path("") -cachable_dotenv_values = lru_cache(maxsize=20)(dotenv_values) - class SettingsError(ValueError): ... @@ -86,22 +86,10 @@ class EnvSettingsSource(BaseSettingsSource): def __init__( self, settings_cls: type["BaseSettings"], - env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, - env_file_encoding: Optional[str] = None, case_sensitive: Optional[bool] = None, env_nested_delimiter: Optional[str] = None, ) -> None: super().__init__(settings_cls) - self.env_file = ( - env_file - if env_file is not ENV_FILE_SENTINEL - else self.config.get("env_file", (".env",)) - ) - self.env_file_encoding = ( - env_file_encoding - if env_file_encoding is not None - else self.config.get("env_file_encoding", "utf-8") - ) self.case_sensitive = ( case_sensitive if case_sensitive is not None @@ -113,11 +101,26 @@ def __init__( else self.config.get("env_nested_delimiter", None) ) + @property + @abc.abstractmethod + def config_id(self) -> str: + raise NotImplementedError + + @abc.abstractmethod + def get_env_vars(self) -> Mapping[str, Optional[str]]: + """获取环境变量映射""" + raise NotImplementedError + @abc.abstractmethod def get_setting_fields(self) -> Iterable[ModelField]: """获取配置类的字段信息""" raise NotImplementedError + @abc.abstractmethod + def get_remain_config(self, used_env_vars: set[str]) -> Iterable[str]: + """获取剩余的用户自定义配置项名称""" + raise NotImplementedError + def _apply_case_sensitive(self, var_name: str) -> str: return var_name if self.case_sensitive else var_name.lower() @@ -137,26 +140,6 @@ def _parse_env_vars( self._apply_case_sensitive(key): value for key, value in env_vars.items() } - def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]: - file_vars = cachable_dotenv_values(file_path, encoding=self.env_file_encoding) - logger.trace(f"Loaded env file '{file_path}': {file_vars}") - return self._parse_env_vars(file_vars) - - def _read_env_files(self) -> dict[str, Optional[str]]: - env_files = self.env_file - if env_files is None: - return {} - - if isinstance(env_files, (str, os.PathLike)): - env_files = [env_files] - - dotenv_vars: dict[str, Optional[str]] = {} - for env_file in env_files: - env_path = Path(env_file).expanduser() - if env_path.is_file(): - dotenv_vars.update(self._read_env_file(env_path)) - return dotenv_vars - def _next_field( self, field: Optional[ModelField], key: str ) -> Optional[ModelField]: @@ -171,8 +154,8 @@ def _next_field( def _explode_env_vars( self, field: ModelField, - env_vars: dict[str, Optional[str]], - env_file_vars: dict[str, Optional[str]], + env_vars: Mapping[str, Optional[str]], + used_env_vars: set[str], ) -> dict[str, Any]: if self.env_nested_delimiter is None: return {} @@ -183,8 +166,8 @@ def _explode_env_vars( if not env_name.startswith(prefix): continue - # delete from file vars when used - env_file_vars.pop(env_name, None) + # record vars when used + used_env_vars.add(env_name) _, *keys, last_key = env_name.split(self.env_nested_delimiter) env_var = result @@ -214,9 +197,8 @@ def __call__(self) -> dict[str, Any]: d: dict[str, Any] = {} - env_vars = self._parse_env_vars(os.environ) - env_file_vars = self._read_env_files() - env_vars = {**env_file_vars, **env_vars} + env_vars = self.get_env_vars() + used_env_vars = set[str]() for field in self.get_setting_fields(): field_name = self._parse_field_name(field) @@ -224,16 +206,15 @@ def __call__(self) -> dict[str, Any]: # try get values from env vars env_val = env_vars.get(env_name, PydanticUndefined) - # delete from file vars when used - if env_name in env_file_vars: - del env_file_vars[env_name] + # record vars when used + used_env_vars.add(env_name) is_complex, allow_parse_failure = self._field_is_complex(field) if is_complex: if isinstance(env_val, PydanticUndefinedType): # field is complex but no value found so far, try explode_env_vars if env_val_built := self._explode_env_vars( - field, env_vars, env_file_vars + field, env_vars, used_env_vars ): d[field_name] = env_val_built elif env_val is None: @@ -254,7 +235,7 @@ def __call__(self) -> dict[str, Any]: # try explode_env_vars to find more sub-values d[field_name] = deep_update( env_val, - self._explode_env_vars(field, env_vars, env_file_vars), + self._explode_env_vars(field, env_vars, used_env_vars), ) else: d[field_name] = env_val @@ -264,7 +245,7 @@ def __call__(self) -> dict[str, Any]: d[field_name] = env_val # remain user custom config - for env_name in env_file_vars: + for env_name in self.get_remain_config(used_env_vars): env_val = env_vars[env_name] if env_val and (val_striped := env_val.strip()): # there's a value, decode that as JSON @@ -290,6 +271,7 @@ def __call__(self) -> dict[str, Any]: elif not nested_keys: d[env_name] = env_val + logger.debug(f"{self.config_id} loaded config from env: {d}") return d def _parse_field_name(self, field: ModelField) -> str: @@ -297,32 +279,117 @@ def _parse_field_name(self, field: ModelField) -> str: class DotEnvSettingsSource(EnvSettingsSource): + def __init__( + self, + settings_cls: type["BaseSettings"], + env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, + env_file_encoding: Optional[str] = None, + case_sensitive: Optional[bool] = None, + env_nested_delimiter: Optional[str] = None, + ) -> None: + super().__init__(settings_cls, case_sensitive, env_nested_delimiter) + self.env_file = ( + env_file + if env_file is not ENV_FILE_SENTINEL + else self.config.get("env_file", (".env",)) + ) + self.env_file_encoding = ( + env_file_encoding + if env_file_encoding is not None + else self.config.get("env_file_encoding", "utf-8") + ) + + def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]: + file_vars = dotenv_values(file_path, encoding=self.env_file_encoding) + logger.warning(f"Loaded env file '{file_path}': {file_vars}") + return self._parse_env_vars(file_vars) + + @cache + def _read_env_files(self) -> dict[str, Optional[str]]: + env_files = self.env_file + if env_files is None: + return {} + + if isinstance(env_files, (str, os.PathLike)): + env_files = [env_files] + + dotenv_vars: dict[str, Optional[str]] = {} + for env_file in env_files: + env_path = Path(env_file).expanduser() + if env_path.is_file(): + dotenv_vars.update(self._read_env_file(env_path)) + return dotenv_vars + + @property + @override + def config_id(self) -> str: + return ( + f"{self.__class__.__name__}" + f"({self.settings_cls.__module__}.{self.settings_cls.__name__})" + ) + @override def get_setting_fields(self) -> Iterable[ModelField]: return model_fields(self.settings_cls) + @override + def get_env_vars(self) -> Mapping[str, Optional[str]]: + env_vars = self._parse_env_vars(os.environ) + env_file_vars = self._read_env_files() + return ChainMap(env_vars, env_file_vars) + + @override + def get_remain_config(self, used_env_vars: set[str]) -> Iterable[str]: + return ( + env_var + for env_var in self._read_env_files() + if env_var not in used_env_vars + ) + class PluginEnvSettingsSource(EnvSettingsSource): def __init__( self, config_cls: type[BaseModel], driver_config: "Config", - env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, ) -> None: setting_config: "SettingsConfig" = model_config(driver_config.__class__) super().__init__( BaseSettings, - env_file=env_file, - env_file_encoding=setting_config.get("env_file_encoding", None), case_sensitive=setting_config.get("case_sensitive", None), env_nested_delimiter=setting_config.get("env_nested_delimiter", None), ) self.config_cls = config_cls + self.driver_config = model_dump(driver_config) + + @property + @override + def config_id(self) -> str: + return ( + f"{self.__class__.__name__}" + f"({self.config_cls.__module__}.{self.config_cls.__name__})" + ) @override def get_setting_fields(self) -> Iterable[ModelField]: return model_fields(self.config_cls) + @override + def get_env_vars(self) -> Mapping[str, Optional[str]]: + env_vars = self._parse_env_vars(os.environ) + return ChainMap(self.driver_config, env_vars) + + @override + def get_remain_config(self, used_env_vars: set[str]) -> Iterable[str]: + return ( + name + for name in ( + self._apply_case_sensitive(self._parse_field_name(f)) + for f in model_fields(self.config_cls) + ) + if name not in used_env_vars + ) + if PYDANTIC_V2: # pragma: pydantic-v2 diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index 401e1a5ae217..bc2d424e09db 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -39,6 +39,7 @@ """ from contextvars import ContextVar +from functools import cache from itertools import chain from types import ModuleType from typing import Optional, TypeVar @@ -46,7 +47,7 @@ from pydantic import BaseModel from nonebot import get_driver -from nonebot.compat import model_dump, type_validate_python +from nonebot.compat import type_validate_python from nonebot.config import PluginEnvSettingsSource C = TypeVar("C", bound=BaseModel) @@ -171,14 +172,12 @@ def get_available_plugin_names() -> set[str]: return {*chain.from_iterable(manager.available_plugins for manager in _managers)} +@cache def get_plugin_config(config: type[C]) -> C: """从全局配置获取当前插件需要的配置项。""" driver = get_driver() - env_settings = PluginEnvSettingsSource( - config, driver.config, env_file=(".env", f".env.{driver.env}") - ) - config_vars = {**env_settings(), **model_dump(driver.config)} - return type_validate_python(config, config_vars) + env_setting = PluginEnvSettingsSource(config, driver.config) + return type_validate_python(config, env_setting()) from .load import inherit_supported_adapters as inherit_supported_adapters From 164b066ddd67a2bd8e82ec16392fae58fc8b2fb2 Mon Sep 17 00:00:00 2001 From: Azide Date: Sun, 7 Sep 2025 22:10:28 +0800 Subject: [PATCH 07/11] =?UTF-8?q?:bug:=20=E7=BB=99=20`functools.cache`=20?= =?UTF-8?q?=E7=9A=84=E7=B1=BB=E5=9E=8B=E6=A0=87=E6=B3=A8=E6=93=A6=E5=B1=81?= =?UTF-8?q?=E8=82=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nonebot/config.py | 3 +-- nonebot/plugin/__init__.py | 2 +- nonebot/utils.py | 14 +++++++++++++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index def15455aee2..e9bdeec41b4a 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -17,7 +17,6 @@ from collections import ChainMap from collections.abc import Iterable, Mapping from datetime import timedelta -from functools import cache from ipaddress import IPv4Address import json import os @@ -41,7 +40,7 @@ ) from nonebot.log import logger from nonebot.typing import origin_is_union -from nonebot.utils import deep_update, lenient_issubclass, type_is_complex +from nonebot.utils import cache, deep_update, lenient_issubclass, type_is_complex DOTENV_TYPE: TypeAlias = Union[ Path, str, list[Union[Path, str]], tuple[Union[Path, str], ...] diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index bc2d424e09db..1d74668131ef 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -39,7 +39,6 @@ """ from contextvars import ContextVar -from functools import cache from itertools import chain from types import ModuleType from typing import Optional, TypeVar @@ -49,6 +48,7 @@ from nonebot import get_driver from nonebot.compat import type_validate_python from nonebot.config import PluginEnvSettingsSource +from nonebot.utils import cache C = TypeVar("C", bound=BaseModel) diff --git a/nonebot/utils.py b/nonebot/utils.py index 5869fdbe3df9..3fa906b6c9a5 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -12,7 +12,7 @@ import contextlib from contextlib import AbstractContextManager, asynccontextmanager import dataclasses -from functools import partial, wraps +from functools import lru_cache, partial, wraps import importlib import inspect import json @@ -345,3 +345,15 @@ def log(level: str, message: str, exception: Optional[Exception] = None): ) return log + + +def cache(user_function: Callable[P, R], /) -> Callable[P, R]: + """等价于 `functools.cache`。为了更好的类型提示而进行重新实现 + + 参数: + user_function: 需要使用缓存的待装饰函数 + + 返回: + 被装饰的函数 + """ + return lru_cache(maxsize=None)(user_function) # pyright: ignore[reportReturnType] From 78e7882a93a69aa49077d28574f30781f14e76f2 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 18 Oct 2025 08:23:12 +0000 Subject: [PATCH 08/11] :recycle: re-impl the logic --- nonebot/config.py | 216 ++++++++--------------------- nonebot/plugin/__init__.py | 22 ++- nonebot/utils.py | 14 +- tests/test_plugin/test_get.py | 6 +- website/docs/appendices/config.mdx | 4 +- 5 files changed, 78 insertions(+), 184 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index e9bdeec41b4a..1676a1a44451 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -14,15 +14,14 @@ """ import abc -from collections import ChainMap -from collections.abc import Iterable, Mapping +from collections.abc import Mapping from datetime import timedelta from ipaddress import IPv4Address import json import os from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union -from typing_extensions import TypeAlias, get_args, get_origin, override +from typing_extensions import TypeAlias, get_args, get_origin from dotenv import dotenv_values from pydantic import BaseModel, Field @@ -35,12 +34,11 @@ PydanticUndefined, PydanticUndefinedType, model_config, - model_dump, model_fields, ) from nonebot.log import logger from nonebot.typing import origin_is_union -from nonebot.utils import cache, deep_update, lenient_issubclass, type_is_complex +from nonebot.utils import deep_update, lenient_issubclass, type_is_complex DOTENV_TYPE: TypeAlias = Union[ Path, str, list[Union[Path, str]], tuple[Union[Path, str], ...] @@ -53,7 +51,7 @@ class SettingsError(ValueError): ... class BaseSettingsSource(abc.ABC): - def __init__(self, settings_cls: type["BaseSettings"]) -> None: + def __init__(self, settings_cls: type[BaseModel]) -> None: self.settings_cls = settings_cls @property @@ -69,7 +67,7 @@ class InitSettingsSource(BaseSettingsSource): __slots__ = ("init_kwargs",) def __init__( - self, settings_cls: type["BaseSettings"], init_kwargs: dict[str, Any] + self, settings_cls: type[BaseModel], init_kwargs: dict[str, Any] ) -> None: self.init_kwargs = init_kwargs super().__init__(settings_cls) @@ -81,14 +79,26 @@ def __repr__(self) -> str: return f"InitSettingsSource(init_kwargs={self.init_kwargs!r})" -class EnvSettingsSource(BaseSettingsSource): +class DotEnvSettingsSource(BaseSettingsSource): def __init__( self, - settings_cls: type["BaseSettings"], + settings_cls: type[BaseModel], + env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, + env_file_encoding: Optional[str] = None, case_sensitive: Optional[bool] = None, env_nested_delimiter: Optional[str] = None, ) -> None: super().__init__(settings_cls) + self.env_file = ( + env_file + if env_file is not ENV_FILE_SENTINEL + else self.config.get("env_file", (".env",)) + ) + self.env_file_encoding = ( + env_file_encoding + if env_file_encoding is not None + else self.config.get("env_file_encoding", "utf-8") + ) self.case_sensitive = ( case_sensitive if case_sensitive is not None @@ -100,26 +110,6 @@ def __init__( else self.config.get("env_nested_delimiter", None) ) - @property - @abc.abstractmethod - def config_id(self) -> str: - raise NotImplementedError - - @abc.abstractmethod - def get_env_vars(self) -> Mapping[str, Optional[str]]: - """获取环境变量映射""" - raise NotImplementedError - - @abc.abstractmethod - def get_setting_fields(self) -> Iterable[ModelField]: - """获取配置类的字段信息""" - raise NotImplementedError - - @abc.abstractmethod - def get_remain_config(self, used_env_vars: set[str]) -> Iterable[str]: - """获取剩余的用户自定义配置项名称""" - raise NotImplementedError - def _apply_case_sensitive(self, var_name: str) -> str: return var_name if self.case_sensitive else var_name.lower() @@ -139,6 +129,25 @@ def _parse_env_vars( self._apply_case_sensitive(key): value for key, value in env_vars.items() } + def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]: + file_vars = dotenv_values(file_path, encoding=self.env_file_encoding) + return self._parse_env_vars(file_vars) + + def _read_env_files(self) -> dict[str, Optional[str]]: + env_files = self.env_file + if env_files is None: + return {} + + if isinstance(env_files, (str, os.PathLike)): + env_files = [env_files] + + dotenv_vars: dict[str, Optional[str]] = {} + for env_file in env_files: + env_path = Path(env_file).expanduser() + if env_path.is_file(): + dotenv_vars.update(self._read_env_file(env_path)) + return dotenv_vars + def _next_field( self, field: Optional[ModelField], key: str ) -> Optional[ModelField]: @@ -153,8 +162,8 @@ def _next_field( def _explode_env_vars( self, field: ModelField, - env_vars: Mapping[str, Optional[str]], - used_env_vars: set[str], + env_vars: dict[str, Optional[str]], + env_file_vars: dict[str, Optional[str]], ) -> dict[str, Any]: if self.env_nested_delimiter is None: return {} @@ -165,8 +174,8 @@ def _explode_env_vars( if not env_name.startswith(prefix): continue - # record vars when used - used_env_vars.add(env_name) + # delete from file vars when used + env_file_vars.pop(env_name, None) _, *keys, last_key = env_name.split(self.env_nested_delimiter) env_var = result @@ -196,24 +205,26 @@ def __call__(self) -> dict[str, Any]: d: dict[str, Any] = {} - env_vars = self.get_env_vars() - used_env_vars = set[str]() + env_vars = self._parse_env_vars(os.environ) + env_file_vars = self._read_env_files() + env_vars = {**env_file_vars, **env_vars} - for field in self.get_setting_fields(): - field_name = self._parse_field_name(field) + for field in model_fields(self.settings_cls): + field_name = field.name env_name = self._apply_case_sensitive(field_name) # try get values from env vars env_val = env_vars.get(env_name, PydanticUndefined) - # record vars when used - used_env_vars.add(env_name) + # delete from file vars when used + if env_name in env_file_vars: + del env_file_vars[env_name] is_complex, allow_parse_failure = self._field_is_complex(field) if is_complex: if isinstance(env_val, PydanticUndefinedType): # field is complex but no value found so far, try explode_env_vars if env_val_built := self._explode_env_vars( - field, env_vars, used_env_vars + field, env_vars, env_file_vars ): d[field_name] = env_val_built elif env_val is None: @@ -234,7 +245,7 @@ def __call__(self) -> dict[str, Any]: # try explode_env_vars to find more sub-values d[field_name] = deep_update( env_val, - self._explode_env_vars(field, env_vars, used_env_vars), + self._explode_env_vars(field, env_vars, env_file_vars), ) else: d[field_name] = env_val @@ -244,7 +255,7 @@ def __call__(self) -> dict[str, Any]: d[field_name] = env_val # remain user custom config - for env_name in self.get_remain_config(used_env_vars): + for env_name in env_file_vars: env_val = env_vars[env_name] if env_val and (val_striped := env_val.strip()): # there's a value, decode that as JSON @@ -270,125 +281,8 @@ def __call__(self) -> dict[str, Any]: elif not nested_keys: d[env_name] = env_val - logger.debug(f"{self.config_id} loaded config from env: {d}") return d - def _parse_field_name(self, field: ModelField) -> str: - return field.field_info.alias or field.name - - -class DotEnvSettingsSource(EnvSettingsSource): - def __init__( - self, - settings_cls: type["BaseSettings"], - env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, - env_file_encoding: Optional[str] = None, - case_sensitive: Optional[bool] = None, - env_nested_delimiter: Optional[str] = None, - ) -> None: - super().__init__(settings_cls, case_sensitive, env_nested_delimiter) - self.env_file = ( - env_file - if env_file is not ENV_FILE_SENTINEL - else self.config.get("env_file", (".env",)) - ) - self.env_file_encoding = ( - env_file_encoding - if env_file_encoding is not None - else self.config.get("env_file_encoding", "utf-8") - ) - - def _read_env_file(self, file_path: Path) -> dict[str, Optional[str]]: - file_vars = dotenv_values(file_path, encoding=self.env_file_encoding) - logger.warning(f"Loaded env file '{file_path}': {file_vars}") - return self._parse_env_vars(file_vars) - - @cache - def _read_env_files(self) -> dict[str, Optional[str]]: - env_files = self.env_file - if env_files is None: - return {} - - if isinstance(env_files, (str, os.PathLike)): - env_files = [env_files] - - dotenv_vars: dict[str, Optional[str]] = {} - for env_file in env_files: - env_path = Path(env_file).expanduser() - if env_path.is_file(): - dotenv_vars.update(self._read_env_file(env_path)) - return dotenv_vars - - @property - @override - def config_id(self) -> str: - return ( - f"{self.__class__.__name__}" - f"({self.settings_cls.__module__}.{self.settings_cls.__name__})" - ) - - @override - def get_setting_fields(self) -> Iterable[ModelField]: - return model_fields(self.settings_cls) - - @override - def get_env_vars(self) -> Mapping[str, Optional[str]]: - env_vars = self._parse_env_vars(os.environ) - env_file_vars = self._read_env_files() - return ChainMap(env_vars, env_file_vars) - - @override - def get_remain_config(self, used_env_vars: set[str]) -> Iterable[str]: - return ( - env_var - for env_var in self._read_env_files() - if env_var not in used_env_vars - ) - - -class PluginEnvSettingsSource(EnvSettingsSource): - def __init__( - self, - config_cls: type[BaseModel], - driver_config: "Config", - ) -> None: - setting_config: "SettingsConfig" = model_config(driver_config.__class__) - super().__init__( - BaseSettings, - case_sensitive=setting_config.get("case_sensitive", None), - env_nested_delimiter=setting_config.get("env_nested_delimiter", None), - ) - self.config_cls = config_cls - self.driver_config = model_dump(driver_config) - - @property - @override - def config_id(self) -> str: - return ( - f"{self.__class__.__name__}" - f"({self.config_cls.__module__}.{self.config_cls.__name__})" - ) - - @override - def get_setting_fields(self) -> Iterable[ModelField]: - return model_fields(self.config_cls) - - @override - def get_env_vars(self) -> Mapping[str, Optional[str]]: - env_vars = self._parse_env_vars(os.environ) - return ChainMap(self.driver_config, env_vars) - - @override - def get_remain_config(self, used_env_vars: set[str]) -> Iterable[str]: - return ( - name - for name in ( - self._apply_case_sensitive(self._parse_field_name(f)) - for f in model_fields(self.config_cls) - ) - if name not in used_env_vars - ) - if PYDANTIC_V2: # pragma: pydantic-v2 @@ -446,6 +340,10 @@ def __init__( ) ) + __settings_self__._env_file = _env_file + __settings_self__._env_file_encoding = _env_file_encoding + __settings_self__._env_nested_delimiter = _env_nested_delimiter + def _settings_build_values( self, init_kwargs: dict[str, Any], diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index 1d74668131ef..7375cb3c9cb6 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -46,9 +46,8 @@ from pydantic import BaseModel from nonebot import get_driver -from nonebot.compat import type_validate_python -from nonebot.config import PluginEnvSettingsSource -from nonebot.utils import cache +from nonebot.compat import model_dump, type_validate_python +from nonebot.config import DotEnvSettingsSource, InitSettingsSource C = TypeVar("C", bound=BaseModel) @@ -172,12 +171,21 @@ def get_available_plugin_names() -> set[str]: return {*chain.from_iterable(manager.available_plugins for manager in _managers)} -@cache def get_plugin_config(config: type[C]) -> C: """从全局配置获取当前插件需要的配置项。""" - driver = get_driver() - env_setting = PluginEnvSettingsSource(config, driver.config) - return type_validate_python(config, env_setting()) + global_config = get_driver().config + return type_validate_python( + config, + { + **DotEnvSettingsSource( + config, + env_file=global_config._env_file, + env_file_encoding=global_config._env_file_encoding, + env_nested_delimiter=global_config._env_nested_delimiter, + )(), + **InitSettingsSource(config, model_dump(global_config))(), + }, + ) from .load import inherit_supported_adapters as inherit_supported_adapters diff --git a/nonebot/utils.py b/nonebot/utils.py index 3fa906b6c9a5..5869fdbe3df9 100644 --- a/nonebot/utils.py +++ b/nonebot/utils.py @@ -12,7 +12,7 @@ import contextlib from contextlib import AbstractContextManager, asynccontextmanager import dataclasses -from functools import lru_cache, partial, wraps +from functools import partial, wraps import importlib import inspect import json @@ -345,15 +345,3 @@ def log(level: str, message: str, exception: Optional[Exception] = None): ) return log - - -def cache(user_function: Callable[P, R], /) -> Callable[P, R]: - """等价于 `functools.cache`。为了更好的类型提示而进行重新实现 - - 参数: - user_function: 需要使用缓存的待装饰函数 - - 返回: - 被装饰的函数 - """ - return lru_cache(maxsize=None)(user_function) # pyright: ignore[reportReturnType] diff --git a/tests/test_plugin/test_get.py b/tests/test_plugin/test_get.py index 28906346ab60..1c75ea024d3e 100644 --- a/tests/test_plugin/test_get.py +++ b/tests/test_plugin/test_get.py @@ -77,13 +77,13 @@ def test_plugin_load_env_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("TEST_CFG_THREE", "33") monkeypatch.setenv("CONFIG_FROM_INIT", "impossible") - class CfgTwo(BaseModel): + class SubConfig(BaseModel): two: str = "dummy_val" class Config(BaseModel): test_config_one: str = "dummy_val" - test_config: CfgTwo = Field(default_factory=CfgTwo) - test_config_three: int = Field(alias="TEST_CFG_THREE", default=3) + test_config: SubConfig = Field(default_factory=SubConfig) + test_config_three: int = Field(default=3, alias="TEST_CFG_THREE") config_from_init: str = "dummy_val" global_config = nonebot.get_driver().config diff --git a/website/docs/appendices/config.mdx b/website/docs/appendices/config.mdx index 0d054ce19431..e423368c0c7d 100644 --- a/website/docs/appendices/config.mdx +++ b/website/docs/appendices/config.mdx @@ -250,9 +250,9 @@ weather = on_command( 无论是否在 dotenv 文件中声明了插件配置项,使用 `get_plugin_config` 获取插件配置模型中定义的配置项时都遵循[**配置项的加载**](#配置项的加载)一节中的优先级顺序进行读取。 ::: -### 配置 scope +### 避免插件配置名称冲突 -由于插件配置项是从全局配置和环境变量中读取的,通常我们需要在配置项名称前面添加前缀名,以防止配置项冲突。例如在上方的示例中,我们就添加了配置项前缀 `weather_`。但是这样会导致在使用配置项时过长的变量名,因此我们可以使用 `pydantic` 的 `alias` 或者通过配置 scope 来简化配置项名称。这里我们以 scope 配置为例: +由于插件配置项是从全局配置和环境变量中读取的,通常我们需要在配置项名称前面添加前缀名,以防止配置项冲突。例如在上方的示例中,我们就添加了配置项前缀 `weather_`。但是这样会导致使用配置项时变量名过长,此时我们可以使用 `pydantic` 的 `alias` 或者通过配置 scope 来简化配置项名称。这里我们以 scope 配置为例: ```python title=weather/config.py from pydantic import BaseModel From fb758090182b372d9e537282f076769ef4468752 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 18 Oct 2025 08:31:39 +0000 Subject: [PATCH 09/11] :bug: use deep update logic --- nonebot/config.py | 7 ++++--- nonebot/plugin/__init__.py | 17 +++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index 1676a1a44451..b2d1b8677115 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -344,16 +344,17 @@ def __init__( __settings_self__._env_file_encoding = _env_file_encoding __settings_self__._env_nested_delimiter = _env_nested_delimiter + @classmethod def _settings_build_values( - self, + cls, init_kwargs: dict[str, Any], env_file: Optional[DOTENV_TYPE] = None, env_file_encoding: Optional[str] = None, env_nested_delimiter: Optional[str] = None, ) -> dict[str, Any]: - init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs) + init_settings = InitSettingsSource(cls, init_kwargs=init_kwargs) env_settings = DotEnvSettingsSource( - self.__class__, + cls, env_file=env_file, env_file_encoding=env_file_encoding, env_nested_delimiter=env_nested_delimiter, diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index 7375cb3c9cb6..a0fe1b85642f 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -47,7 +47,7 @@ from nonebot import get_driver from nonebot.compat import model_dump, type_validate_python -from nonebot.config import DotEnvSettingsSource, InitSettingsSource +from nonebot.config import BaseSettings C = TypeVar("C", bound=BaseModel) @@ -176,15 +176,12 @@ def get_plugin_config(config: type[C]) -> C: global_config = get_driver().config return type_validate_python( config, - { - **DotEnvSettingsSource( - config, - env_file=global_config._env_file, - env_file_encoding=global_config._env_file_encoding, - env_nested_delimiter=global_config._env_nested_delimiter, - )(), - **InitSettingsSource(config, model_dump(global_config))(), - }, + BaseSettings._settings_build_values( + model_dump(global_config), + env_file=global_config._env_file, + env_file_encoding=global_config._env_file_encoding, + env_nested_delimiter=global_config._env_nested_delimiter, + ), ) From dfb0f50dd12593c9e10f4cf30e19ac50da23d463 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 18 Oct 2025 09:41:33 +0000 Subject: [PATCH 10/11] :bug: fix alias field --- nonebot/config.py | 95 ++++++++++++++++++++++------------- nonebot/plugin/__init__.py | 1 + tests/.env.example | 1 + tests/test_config.py | 3 ++ tests/test_plugin/test_get.py | 31 +++++------- 5 files changed, 76 insertions(+), 55 deletions(-) diff --git a/nonebot/config.py b/nonebot/config.py index b2d1b8677115..dc2dba94ebc3 100644 --- a/nonebot/config.py +++ b/nonebot/config.py @@ -83,32 +83,16 @@ class DotEnvSettingsSource(BaseSettingsSource): def __init__( self, settings_cls: type[BaseModel], - env_file: Optional[DOTENV_TYPE] = ENV_FILE_SENTINEL, - env_file_encoding: Optional[str] = None, - case_sensitive: Optional[bool] = None, + env_file: Optional[DOTENV_TYPE], + env_file_encoding: str, + case_sensitive: Optional[bool] = False, env_nested_delimiter: Optional[str] = None, ) -> None: super().__init__(settings_cls) - self.env_file = ( - env_file - if env_file is not ENV_FILE_SENTINEL - else self.config.get("env_file", (".env",)) - ) - self.env_file_encoding = ( - env_file_encoding - if env_file_encoding is not None - else self.config.get("env_file_encoding", "utf-8") - ) - self.case_sensitive = ( - case_sensitive - if case_sensitive is not None - else self.config.get("case_sensitive", False) - ) - self.env_nested_delimiter = ( - env_nested_delimiter - if env_nested_delimiter is not None - else self.config.get("env_nested_delimiter", None) - ) + self.env_file = env_file + self.env_file_encoding = env_file_encoding + self.case_sensitive = case_sensitive + self.env_nested_delimiter = env_nested_delimiter def _apply_case_sensitive(self, var_name: str) -> str: return var_name if self.case_sensitive else var_name.lower() @@ -212,12 +196,33 @@ def __call__(self) -> dict[str, Any]: for field in model_fields(self.settings_cls): field_name = field.name env_name = self._apply_case_sensitive(field_name) + alias_name = field.field_info.alias + alias_env_name = ( + None if alias_name is None else self._apply_case_sensitive(alias_name) + ) + + # pydantic use alias name to validate if exist + if alias_name is not None: + field_name = alias_name # try get values from env vars env_val = env_vars.get(env_name, PydanticUndefined) + alias_env_val = ( + PydanticUndefined + if alias_env_name is None + else env_vars.get(alias_env_name, PydanticUndefined) + ) + # alias env value has higher priority + env_val = ( + env_val + if isinstance(alias_env_val, PydanticUndefinedType) + else alias_env_val + ) # delete from file vars when used if env_name in env_file_vars: del env_file_vars[env_name] + if alias_env_name is not None and alias_env_name in env_file_vars: + del env_file_vars[alias_env_name] is_complex, allow_parse_failure = self._field_is_complex(field) if is_complex: @@ -331,30 +336,48 @@ def __init__( _env_nested_delimiter: Optional[str] = None, **values: Any, ) -> None: + settings_config = model_config(__settings_self__.__class__) + env_file = ( + _env_file + if _env_file is not ENV_FILE_SENTINEL + else settings_config.get("env_file", (".env",)) + ) + env_file_encoding = ( + _env_file_encoding + if _env_file_encoding is not None + else settings_config.get("env_file_encoding", "utf-8") + ) + env_nested_delimiter = ( + _env_nested_delimiter + if _env_nested_delimiter is not None + else settings_config.get("env_nested_delimiter", None) + ) + super().__init__( **__settings_self__._settings_build_values( + __settings_self__.__class__, values, - env_file=_env_file, - env_file_encoding=_env_file_encoding, - env_nested_delimiter=_env_nested_delimiter, + env_file=env_file, + env_file_encoding=env_file_encoding, + env_nested_delimiter=env_nested_delimiter, ) ) - __settings_self__._env_file = _env_file - __settings_self__._env_file_encoding = _env_file_encoding - __settings_self__._env_nested_delimiter = _env_nested_delimiter + __settings_self__._env_file = env_file + __settings_self__._env_file_encoding = env_file_encoding + __settings_self__._env_nested_delimiter = env_nested_delimiter - @classmethod + @staticmethod def _settings_build_values( - cls, + settings_cls: type[BaseModel], init_kwargs: dict[str, Any], - env_file: Optional[DOTENV_TYPE] = None, - env_file_encoding: Optional[str] = None, - env_nested_delimiter: Optional[str] = None, + env_file: Optional[DOTENV_TYPE], + env_file_encoding: str, + env_nested_delimiter: Optional[str], ) -> dict[str, Any]: - init_settings = InitSettingsSource(cls, init_kwargs=init_kwargs) + init_settings = InitSettingsSource(settings_cls, init_kwargs=init_kwargs) env_settings = DotEnvSettingsSource( - cls, + settings_cls, env_file=env_file, env_file_encoding=env_file_encoding, env_nested_delimiter=env_nested_delimiter, diff --git a/nonebot/plugin/__init__.py b/nonebot/plugin/__init__.py index a0fe1b85642f..fde938ff2de0 100644 --- a/nonebot/plugin/__init__.py +++ b/nonebot/plugin/__init__.py @@ -177,6 +177,7 @@ def get_plugin_config(config: type[C]) -> C: return type_validate_python( config, BaseSettings._settings_build_values( + config, model_dump(global_config), env_file=global_config._env_file, env_file_encoding=global_config._env_file_encoding, diff --git a/tests/.env.example b/tests/.env.example index 5d416fd5aedc..d6d5b109a86c 100644 --- a/tests/.env.example +++ b/tests/.env.example @@ -10,6 +10,7 @@ NESTED__C__C=3 NESTED__COMPLEX=[1, 2, 3] NESTED_INNER__A=1 NESTED_INNER__B=2 +ALIAS_SIMPLE=aliased_simple OTHER_SIMPLE=simple OTHER_NESTED={"a": 1} OTHER_NESTED__B=2 diff --git a/tests/test_config.py b/tests/test_config.py index dc775744ef42..914cc888c8c8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -37,6 +37,7 @@ class Config( # pyright: ignore[reportIncompatibleVariableOverride] complex_union: Union[int, list[int]] = 1 nested: Simple = Simple() nested_inner: Simple = Simple() + aliased_simple: str = Field(alias="alias_simple") class ExampleWithoutDelimiter(Example): @@ -85,6 +86,8 @@ def test_config_with_env(): with pytest.raises(AttributeError): config.nested_inner__b + assert config.aliased_simple == "aliased_simple" + assert config.common_config == "common" assert config.other_simple == "simple" diff --git a/tests/test_plugin/test_get.py b/tests/test_plugin/test_get.py index 1c75ea024d3e..d389f91a9ee1 100644 --- a/tests/test_plugin/test_get.py +++ b/tests/test_plugin/test_get.py @@ -2,7 +2,6 @@ import pytest import nonebot -from nonebot.compat import model_dump from nonebot.plugin import PluginManager, _managers @@ -71,31 +70,25 @@ class Config(BaseModel): assert config.plugin_config == 1 -def test_plugin_load_env_config(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("TEST_CONFIG_ONE", "no_dummy_val") - monkeypatch.setenv("TEST_CONFIG__TWO", "two") - monkeypatch.setenv("TEST_CFG_THREE", "33") +def test_get_plugin_config_with_env(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("PLUGIN_CONFIG_ONE", "no_dummy_val") + monkeypatch.setenv("PLUGIN_SUB_CONFIG__TWO", "two") + monkeypatch.setenv("PLUGIN_CFG_THREE", "33") monkeypatch.setenv("CONFIG_FROM_INIT", "impossible") class SubConfig(BaseModel): two: str = "dummy_val" class Config(BaseModel): - test_config_one: str = "dummy_val" - test_config: SubConfig = Field(default_factory=SubConfig) - test_config_three: int = Field(default=3, alias="TEST_CFG_THREE") + plugin_config: int + plugin_config_one: str = "dummy_val" + plugin_sub_config: SubConfig = Field(default_factory=SubConfig) + plugin_config_three: int = Field(default=3, alias="plugin_cfg_three") config_from_init: str = "dummy_val" - global_config = nonebot.get_driver().config - assert "test_config_one" not in model_dump(global_config) - assert "TEST_CONFIG_ONE" not in model_dump(global_config) - assert "test_config" not in model_dump(global_config) - assert "TEST_CONFIG" not in model_dump(global_config) - assert "test_config_three" not in model_dump(global_config) - assert "TEST_CFG_THREE" not in model_dump(global_config) - config = nonebot.get_plugin_config(Config) - assert config.test_config_one == "no_dummy_val" - assert config.test_config.two == "two" - assert config.test_config_three == 33 + assert config.plugin_config == 1 + assert config.plugin_config_one == "no_dummy_val" + assert config.plugin_sub_config.two == "two" + assert config.plugin_config_three == 33 assert config.config_from_init == "init" From 7ec52244da966f0f86d9c2d98b3cf62f4f0997c2 Mon Sep 17 00:00:00 2001 From: Ju4tCode <42488585+yanyongyu@users.noreply.github.com> Date: Sat, 18 Oct 2025 10:02:58 +0000 Subject: [PATCH 11/11] :white_check_mark: fix test case --- tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_config.py b/tests/test_config.py index 914cc888c8c8..3ad9bbf42d4e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -37,7 +37,7 @@ class Config( # pyright: ignore[reportIncompatibleVariableOverride] complex_union: Union[int, list[int]] = 1 nested: Simple = Simple() nested_inner: Simple = Simple() - aliased_simple: str = Field(alias="alias_simple") + aliased_simple: str = Field(default="", alias="alias_simple") class ExampleWithoutDelimiter(Example):