Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions nonebot/dependencies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
"""

import inspect
from typing import Any, Callable, ForwardRef
from typing import Any, Callable, ForwardRef, cast
from typing_extensions import TypeAliasType

from loguru import logger

from nonebot.compat import ModelField
from nonebot.exception import TypeMisMatch
from nonebot.typing import evaluate_forwardref
from nonebot.typing import evaluate_forwardref, is_type_alias_type


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
Expand Down Expand Up @@ -46,6 +47,9 @@ def get_typed_annotation(param: inspect.Parameter, globalns: dict[str, Any]) ->
f'Unknown ForwardRef["{param.annotation}"] for parameter {param.name}'
)
return inspect.Parameter.empty
if is_type_alias_type(annotation):
# Python 3.12+ supports PEP 695 TypeAliasType
annotation = cast(TypeAliasType, annotation).__value__
return annotation


Expand Down
11 changes: 11 additions & 0 deletions nonebot/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ def is_none_type(type_: type[t.Any]) -> bool:
return type_ in NONE_TYPES


if sys.version_info < (3, 12):

def is_type_alias_type(type_: type[t.Any]) -> bool:
"""判断是否是 TypeAliasType 类型"""
return isinstance(type_, t_ext.TypeAliasType)
else:

def is_type_alias_type(type_: type[t.Any]) -> bool:
return isinstance(type_, (t.TypeAliasType, t_ext.TypeAliasType))


def evaluate_forwardref(
ref: t.ForwardRef, globalns: dict[str, t.Any], localns: dict[str, t.Any]
) -> t.Any:
Expand Down
8 changes: 5 additions & 3 deletions nonebot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ def generic_check_issubclass(

特别的:

- 如果 cls 是 `typing.TypeVar` 类型,
则会检查其 `__bound__` 或 `__constraints__`
是否是 class_or_tuple 中一个类型的子类或 None。
- 如果 cls 是 `typing.Union` 或 `types.UnionType` 类型,
则会检查其中的所有类型是否是 class_or_tuple 中一个类型的子类或 None。
- 如果 cls 是 `typing.Literal` 类型,
则会检查其中的所有值是否是 class_or_tuple 中一个类型的实例。
- 如果 cls 是 `typing.TypeVar` 类型,
则会检查其 `__bound__` 或 `__constraints__`
是否是 class_or_tuple 中一个类型的子类或 None。
- 如果 cls 是 `typing.List`、`typing.Dict` 等泛型类型,
则会检查其原始类型是否是 class_or_tuple 中一个类型的子类。
"""
# if the target is a TypeVar, we check it first
if isinstance(cls, TypeVar):
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"pygtrie >=2.4.1, <3.0.0",
"exceptiongroup >=1.2.2, <2.0.0",
"python-dotenv >=0.21.0, <2.0.0",
"typing-extensions >=4.4.0, <5.0.0",
"typing-extensions >=4.6.0, <5.0.0",
"tomli >=2.0.1, <3.0.0; python_version < '3.11'",
"pydantic >=1.10.0, <3.0.0, !=2.5.0, !=2.5.1, !=2.10.0, !=2.10.1",
]
Expand Down Expand Up @@ -129,6 +129,9 @@ pythonVersion = "3.9"
pythonPlatform = "All"
defineConstant = { PYDANTIC_V2 = true }
executionEnvironments = [
{ root = "./tests/python_3_12", pythonVersion = "3.12", extraPaths = [
"./",
] },
{ root = "./tests", extraPaths = [
"./",
] },
Expand Down
10 changes: 9 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import wraps
import os
from pathlib import Path
import sys
import threading
from typing import TYPE_CHECKING, Callable, TypeVar
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -67,7 +68,14 @@ def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
@run_once
def load_plugin(anyio_backend, nonebug_init: None) -> set["Plugin"]:
# preload global plugins
return nonebot.load_plugins(str(Path(__file__).parent / "plugins"))
plugins: set["Plugin"] = set()
plugins |= nonebot.load_plugins(str(Path(__file__).parent / "plugins"))
if sys.version_info >= (3, 12):
# preload python 3.12 plugins
plugins |= nonebot.load_plugins(
str(Path(__file__).parent / "python_3_12" / "plugins")
)
return plugins


@pytest.fixture(scope="session", autouse=True)
Expand Down
7 changes: 7 additions & 0 deletions tests/python_3_12/plugins/aliased_param/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from pathlib import Path

from nonebot import load_plugins

_sub_plugins = set()

_sub_plugins |= load_plugins(str(Path(__file__).parent))
10 changes: 10 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_arg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Annotated

from nonebot.adapters import Message
from nonebot.params import Arg

type AliasedArg = Annotated[Message, Arg()]


async def aliased_arg(key: AliasedArg) -> Message:
return key
7 changes: 7 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from nonebot.adapters import Bot

type AliasedBot = Bot


async def get_aliased_bot(b: AliasedBot) -> Bot:
return b
21 changes: 21 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_depend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Annotated

from nonebot import on_message
from nonebot.params import Depends

test_depends = on_message()

runned = []


def dependency():
runned.append(1)
return 1


type AliasedDepends = Annotated[int, Depends(dependency)]


@test_depends.handle()
async def aliased_depends(x: AliasedDepends):
return x
7 changes: 7 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from nonebot.adapters import Event

type AliasedEvent = Event


async def aliased_event(e: AliasedEvent) -> Event:
return e
5 changes: 5 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
type AliasedException = Exception


async def aliased_exc(e: AliasedException) -> Exception:
return e
7 changes: 7 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from nonebot.matcher import Matcher

type AliasedMatcher = Matcher


async def aliased_matcher(m: AliasedMatcher) -> Matcher:
return m
7 changes: 7 additions & 0 deletions tests/python_3_12/plugins/aliased_param/param_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from nonebot.typing import T_State

type AliasedState = T_State


async def aliased_state(x: AliasedState) -> T_State:
return x
3 changes: 3 additions & 0 deletions tests/python_3_12/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[tool.ruff]
extend = "../pyproject.toml"
target-version = "py312"
121 changes: 121 additions & 0 deletions tests/test_param.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from contextlib import suppress
import re
import sys

from exceptiongroup import BaseExceptionGroup
from nonebug import App
Expand Down Expand Up @@ -156,6 +157,22 @@ async def test_depend(app: App):
ctx.should_return(1)


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_depend(app: App):
from python_3_12.plugins.aliased_param.param_depend import aliased_depends, runned

async with app.test_dependent(aliased_depends, allow_types=[DependParam]) as ctx:
ctx.should_return(1)

assert len(runned) == 1
assert runned[0] == 1

runned.clear()


@pytest.mark.anyio
async def test_bot(app: App):
from plugins.param.param_bot import (
Expand Down Expand Up @@ -221,6 +238,19 @@ async def test_bot(app: App):
app.test_dependent(not_bot, allow_types=[BotParam])


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_bot(app: App):
from python_3_12.plugins.aliased_param.param_bot import get_aliased_bot

async with app.test_dependent(get_aliased_bot, allow_types=[BotParam]) as ctx:
bot = ctx.create_bot()
ctx.pass_params(bot=bot)
ctx.should_return(bot)


@pytest.mark.anyio
async def test_event(app: App):
from plugins.param.param_event import (
Expand Down Expand Up @@ -310,6 +340,21 @@ async def test_event(app: App):
ctx.should_return(fake_event.is_tome())


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_event(app: App):
from python_3_12.plugins.aliased_param.param_event import aliased_event

fake_message = FakeMessage("text")
fake_event = make_fake_event(_message=fake_message)()

async with app.test_dependent(aliased_event, allow_types=[EventParam]) as ctx:
ctx.pass_params(event=fake_event)
ctx.should_return(fake_event)


@pytest.mark.anyio
async def test_state(app: App):
from plugins.param.param_state import (
Expand Down Expand Up @@ -461,6 +506,37 @@ async def test_state(app: App):
ctx.should_return(fake_state[KEYWORD_KEY])


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_state(app: App):
from python_3_12.plugins.aliased_param.param_state import aliased_state

fake_message = FakeMessage("text")
fake_matched = re.match(r"\[cq:(?P<type>.*?),(?P<arg>.*?)\]", "[cq:test,arg=value]")
fake_state = {
PREFIX_KEY: {
CMD_KEY: ("cmd",),
RAW_CMD_KEY: "/cmd",
CMD_START_KEY: "/",
CMD_ARG_KEY: fake_message,
CMD_WHITESPACE_KEY: " ",
},
SHELL_ARGV: ["-h"],
SHELL_ARGS: {"help": True},
REGEX_MATCHED: fake_matched,
STARTSWITH_KEY: "startswith",
ENDSWITH_KEY: "endswith",
FULLMATCH_KEY: "fullmatch",
KEYWORD_KEY: "keyword",
}

async with app.test_dependent(aliased_state, allow_types=[StateParam]) as ctx:
ctx.pass_params(state=fake_state)
ctx.should_return(fake_state)


@pytest.mark.anyio
async def test_matcher(app: App):
from plugins.param.param_matcher import (
Expand Down Expand Up @@ -573,6 +649,20 @@ async def test_matcher(app: App):
ctx.should_return(False)


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_matcher(app: App):
from python_3_12.plugins.aliased_param.param_matcher import aliased_matcher

fake_matcher = Matcher()

async with app.test_dependent(aliased_matcher, allow_types=[MatcherParam]) as ctx:
ctx.pass_params(matcher=fake_matcher)
ctx.should_return(fake_matcher)


@pytest.mark.anyio
async def test_arg(app: App):
from plugins.param.param_arg import (
Expand Down Expand Up @@ -642,11 +732,28 @@ async def test_arg(app: App):
ctx.should_return(message.extract_plain_text())


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_arg(app: App):
from python_3_12.plugins.aliased_param.param_arg import aliased_arg

matcher = Matcher()
message = FakeMessage("text")
matcher.set_arg("key", message)

async with app.test_dependent(aliased_arg, allow_types=[ArgParam]) as ctx:
ctx.pass_params(matcher=matcher)
ctx.should_return(message)


@pytest.mark.anyio
async def test_exception(app: App):
from plugins.param.param_exception import exc, legacy_exc

exception = ValueError("test")

async with app.test_dependent(exc, allow_types=[ExceptionParam]) as ctx:
ctx.pass_params(exception=exception)
ctx.should_return(exception)
Expand All @@ -656,6 +763,20 @@ async def test_exception(app: App):
ctx.should_return(exception)


@pytest.mark.anyio
@pytest.mark.skipif(
sys.version_info < (3, 12), reason="TypeAlias requires Python 3.12 or higher"
)
async def test_aliased_exception(app: App):
from python_3_12.plugins.aliased_param.param_exception import aliased_exc

exception = ValueError("test")

async with app.test_dependent(aliased_exc, allow_types=[ExceptionParam]) as ctx:
ctx.pass_params(exception=exception)
ctx.should_return(exception)


@pytest.mark.anyio
async def test_default(app: App):
from plugins.param.param_default import default
Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading