From 9ce1ac72c9d8ede9549273b5ac01f21586640fcc Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 22 Aug 2023 17:15:05 +0100 Subject: [PATCH 1/5] Add some tests --- mypy/test/teststubtest.py | 33 ++++++++++++++++++++++++++++++-- test-data/unit/lib-stub/enum.pyi | 8 ++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index cd72bd9300d1..3fac26c68323 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -954,16 +954,15 @@ def fizz(self): pass @collect_cases def test_enum(self) -> Iterator[Case]: + yield Case(stub="import enum", runtime="import enum", error=None) yield Case( stub=""" - import enum class X(enum.Enum): a: int b: str c: str """, runtime=""" - import enum class X(enum.Enum): a = 1 b = "asdf" @@ -971,6 +970,36 @@ class X(enum.Enum): """, error="X.c", ) + yield Case( + stub=""" + class Flags1(enum.IntFlag): + a: int + b: int + def foo(x: Flags1 = ...) -> None: ... + """, + runtime=""" + class Flags1(enum.IntFlag): + a = 0b000000000100 + b = 0b000000001000 + def foo(x=Flags1.a|Flags1.b): pass + """, + error=None, + ) + yield Case( + stub=""" + class Flags2(enum.IntFlag): + a: int + b: int + def bar(x: Flags2 | None = None) -> None: ... + """, + runtime=""" + class Flags2(enum.IntFlag): + a = 0b000000000100 + b = 0b000000001000 + def bar(x=Flags2.a|Flags2.b): pass + """, + error="bar", + ) @collect_cases def test_decorator(self) -> Iterator[Case]: diff --git a/test-data/unit/lib-stub/enum.pyi b/test-data/unit/lib-stub/enum.pyi index 11adfc597955..ed1f64def949 100644 --- a/test-data/unit/lib-stub/enum.pyi +++ b/test-data/unit/lib-stub/enum.pyi @@ -1,3 +1,4 @@ +import sys from typing import Any, TypeVar, Union, Type, Sized, Iterator _T = TypeVar('_T') @@ -35,6 +36,13 @@ def unique(enumeration: _T) -> _T: pass class Flag(Enum): def __or__(self: _T, other: Union[int, _T]) -> _T: pass + def __and__(self: _T, other: _T) -> _T: pass + def __xor__(self: _T, other: _T) -> _T: pass + def __invert__(self) -> Self: pass + if sys.version_info >= (3, 11): + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ class IntFlag(int, Flag): From e6566f69deb2321899c76a5d65ae6c841f840351 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 22 Aug 2023 18:20:37 +0100 Subject: [PATCH 2/5] Take 2 --- mypy/test/teststubtest.py | 42 +++++++++++++++++++++++++++++++- test-data/unit/lib-stub/enum.pyi | 8 ------ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 3fac26c68323..0766ac59710e 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -64,6 +64,7 @@ def __init__(self, name: str) -> None: ... class Coroutine(Generic[_T_co, _S, _R]): ... class Iterable(Generic[_T_co]): ... +class Iterator(Iterable[_T_co]): ... class Mapping(Generic[_K, _V]): ... class Match(Generic[AnyStr]): ... class Sequence(Iterable[_T_co]): ... @@ -86,7 +87,9 @@ def __init__(self) -> None: pass def __repr__(self) -> str: pass class type: ... -class tuple(Sequence[T_co], Generic[T_co]): ... +class tuple(Sequence[T_co], Generic[T_co]): + def __ge__(self, __other: tuple[T_co, ...]) -> bool: pass + class dict(Mapping[KT, VT]): ... class function: pass @@ -105,6 +108,41 @@ def classmethod(f: T) -> T: ... def staticmethod(f: T) -> T: ... """ +stubtest_enum_stub = """ +import sys +from typing import Any, TypeVar, Iterator + +_T = TypeVar('_T') + +class EnumMeta(type): + def __len__(self) -> int: pass + def __iter__(self: type[_T]) -> Iterator[_T]: pass + def __reversed__(self: type[_T]) -> Iterator[_T]: pass + def __getitem__(self: type[_T], name: str) -> _T: pass + +class Enum(metaclass=EnumMeta): + def __new__(cls: type[_T], value: object) -> _T: pass + def __repr__(self) -> str: pass + def __str__(self) -> str: pass + def __format__(self, format_spec: str) -> str: pass + def __hash__(self) -> Any: pass + def __reduce_ex__(self, proto: Any) -> Any: pass + name: str + value: Any + +class Flag(Enum): + def __or__(self: _T, other: _T) -> _T: pass + def __and__(self: _T, other: _T) -> _T: pass + def __xor__(self: _T, other: _T) -> _T: pass + def __invert__(self: _T) -> _T: pass + if sys.version_info >= (3, 11): + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ + +class IntFlag(int, Flag): pass +""" + def run_stubtest( stub: str, runtime: str, options: list[str], config_file: str | None = None @@ -114,6 +152,8 @@ def run_stubtest( f.write(stubtest_builtins_stub) with open("typing.pyi", "w") as f: f.write(stubtest_typing_stub) + with open("enum.pyi", "w") as f: + f.write(stubtest_enum_stub) with open(f"{TEST_MODULE_NAME}.pyi", "w") as f: f.write(stub) with open(f"{TEST_MODULE_NAME}.py", "w") as f: diff --git a/test-data/unit/lib-stub/enum.pyi b/test-data/unit/lib-stub/enum.pyi index ed1f64def949..11adfc597955 100644 --- a/test-data/unit/lib-stub/enum.pyi +++ b/test-data/unit/lib-stub/enum.pyi @@ -1,4 +1,3 @@ -import sys from typing import Any, TypeVar, Union, Type, Sized, Iterator _T = TypeVar('_T') @@ -36,13 +35,6 @@ def unique(enumeration: _T) -> _T: pass class Flag(Enum): def __or__(self: _T, other: Union[int, _T]) -> _T: pass - def __and__(self: _T, other: _T) -> _T: pass - def __xor__(self: _T, other: _T) -> _T: pass - def __invert__(self) -> Self: pass - if sys.version_info >= (3, 11): - __ror__ = __or__ - __rand__ = __and__ - __rxor__ = __xor__ class IntFlag(int, Flag): From 5f113531ab8f6b28a114e7f1796d4cbf232c7be7 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 22 Aug 2023 19:36:01 +0100 Subject: [PATCH 3/5] Fix the crash --- mypy/stubtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 906a8c923b37..b2506e6dcc02 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -1553,7 +1553,7 @@ def anytype() -> mypy.types.AnyType: value: bool | int | str if isinstance(runtime, bytes): value = bytes_to_human_readable_repr(runtime) - elif isinstance(runtime, enum.Enum): + elif isinstance(runtime, enum.Enum) and isinstance(runtime.name, str): value = runtime.name elif isinstance(runtime, (bool, int, str)): value = runtime From 01f43b8f4f4e6f2c96c0155becf046f1beebcef5 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Tue, 22 Aug 2023 21:16:50 +0100 Subject: [PATCH 4/5] simplify test slightly --- mypy/test/teststubtest.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 0766ac59710e..b90982f6da42 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -139,8 +139,6 @@ def __invert__(self: _T) -> _T: pass __ror__ = __or__ __rand__ = __and__ __rxor__ = __xor__ - -class IntFlag(int, Flag): pass """ @@ -1012,30 +1010,30 @@ class X(enum.Enum): ) yield Case( stub=""" - class Flags1(enum.IntFlag): + class Flags1(enum.Flag): a: int b: int def foo(x: Flags1 = ...) -> None: ... """, runtime=""" - class Flags1(enum.IntFlag): - a = 0b000000000100 - b = 0b000000001000 + class Flags1(enum.Flag): + a = 1 + b = 2 def foo(x=Flags1.a|Flags1.b): pass """, error=None, ) yield Case( stub=""" - class Flags2(enum.IntFlag): + class Flags2(enum.Flag): a: int b: int def bar(x: Flags2 | None = None) -> None: ... """, runtime=""" - class Flags2(enum.IntFlag): - a = 0b000000000100 - b = 0b000000001000 + class Flags2(enum.Flag): + a = 1 + b = 2 def bar(x=Flags2.a|Flags2.b): pass """, error="bar", From d3b2a081821aac41e3f3d4d6acaa3e846c5ac076 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Tue, 22 Aug 2023 21:44:06 +0100 Subject: [PATCH 5/5] Also add some tests that crash on py311+ --- mypy/test/teststubtest.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index b90982f6da42..a6733a9e8bd0 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -1038,6 +1038,36 @@ def bar(x=Flags2.a|Flags2.b): pass """, error="bar", ) + yield Case( + stub=""" + class Flags3(enum.Flag): + a: int + b: int + def baz(x: Flags3 | None = ...) -> None: ... + """, + runtime=""" + class Flags3(enum.Flag): + a = 1 + b = 2 + def baz(x=Flags3(0)): pass + """, + error=None, + ) + yield Case( + stub=""" + class Flags4(enum.Flag): + a: int + b: int + def spam(x: Flags4 | None = None) -> None: ... + """, + runtime=""" + class Flags4(enum.Flag): + a = 1 + b = 2 + def spam(x=Flags4(0)): pass + """, + error="spam", + ) @collect_cases def test_decorator(self) -> Iterator[Case]: