From e668396755461ef54ba9a6e3648e1381e3b3bead Mon Sep 17 00:00:00 2001 From: jakkdl Date: Tue, 25 Jun 2024 18:49:04 +0200 Subject: [PATCH 1/3] add PosArgT typing to run() --- src/trio/_core/_run.py | 4 ++-- src/trio/_core/_tests/test_run.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 1512fdf954..c6faabaf72 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -2184,8 +2184,8 @@ def setup_runner( def run( - async_fn: Callable[..., Awaitable[RetT]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]], + *args: Unpack[PosArgT], clock: Clock | None = None, instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index d54d9f1813..c4639c4342 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -76,7 +76,7 @@ async def trivial(x: T) -> T: with pytest.raises(TypeError): # Missing an argument - _core.run(trivial) + _core.run(trivial) # type: ignore[arg-type] with pytest.raises(TypeError): # Not an async function From 89763050474ec047f782560b0ee1ddc6e77675e7 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jun 2024 14:36:54 +0200 Subject: [PATCH 2/3] add type tests --- src/trio/_core/_tests/type_tests/run.py | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/trio/_core/_tests/type_tests/run.py diff --git a/src/trio/_core/_tests/type_tests/run.py b/src/trio/_core/_tests/type_tests/run.py new file mode 100644 index 0000000000..9b16f5c9d5 --- /dev/null +++ b/src/trio/_core/_tests/type_tests/run.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Sequence, overload + +import trio +from typing_extensions import assert_type + + +async def sleep_sort(values: Sequence[float]) -> list[float]: + return [1] + + +async def has_optional(arg: int | None = None) -> int: + return 5 + + +@overload +async def foo_overloaded(arg: int) -> str: ... + + +@overload +async def foo_overloaded(arg: str) -> int: ... + + +async def foo_overloaded(arg: int | str) -> int | str: + if isinstance(arg, str): + return 5 + return "hello" + + +v = trio.run( + sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojump_threshold=0) +) +assert_type(v, list[float]) +trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type] +trio.run(sleep_sort) # type: ignore[arg-type] + +r = trio.run(has_optional) +assert_type(r, int) +r = trio.run(has_optional, 5) +trio.run(has_optional, 7, 8) # type: ignore[arg-type] +trio.run(has_optional, "hello") # type: ignore[arg-type] + + +assert_type(trio.run(foo_overloaded, 5), str) +assert_type(trio.run(foo_overloaded, ""), int) From 945a3b5bd9130a42bda3c82614fa477b651304be Mon Sep 17 00:00:00 2001 From: jakkdl Date: Wed, 26 Jun 2024 14:48:54 +0200 Subject: [PATCH 3/3] fix pyright --- src/trio/_core/_tests/type_tests/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/trio/_core/_tests/type_tests/run.py b/src/trio/_core/_tests/type_tests/run.py index 9b16f5c9d5..c121ce6c7a 100644 --- a/src/trio/_core/_tests/type_tests/run.py +++ b/src/trio/_core/_tests/type_tests/run.py @@ -31,7 +31,7 @@ async def foo_overloaded(arg: int | str) -> int | str: v = trio.run( sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojump_threshold=0) ) -assert_type(v, list[float]) +assert_type(v, "list[float]") trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type] trio.run(sleep_sort) # type: ignore[arg-type]