From ab3ed33b65f53786d214be34d147a60edaf1b887 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 5 Aug 2025 01:47:22 +0200 Subject: [PATCH 1/8] Fix enum comparison with literal values --- mypy/meet.py | 5 +++- test-data/unit/check-enum.test | 47 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/mypy/meet.py b/mypy/meet.py index 349c15e668c3..229c8a75b78d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -547,9 +547,12 @@ def _type_object_overlap(left: Type, right: Type) -> bool: right = right.fallback if isinstance(left, LiteralType) and isinstance(right, LiteralType): - if left.value == right.value: + if left.value == right.value or (left.fallback.type.is_enum ^ right.fallback.type.is_enum): # If values are the same, we still need to check if fallbacks are overlapping, # this is done below. + # Enums are more interesting: + # * if both sides are enums, they should have same values + # * if exactly one of them is a enum, fallback compatibibility is enough left = left.fallback right = right.fallback else: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 3bcf9745a801..f10975926756 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2681,3 +2681,50 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis" reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis" [builtins fixtures/enum.pyi] + +[case testEnumItemsEqualityToLiterals] +# flags: --python-version=3.11 --strict-equality +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = "b" + b = "a" + +A.a == "a" +A.a == "b" + +A.a == A.a +A.a == A.b # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + +class B(StrEnum): + a = "b" + b = "a" + +B.a == "a" +B.a == "b" + +B.a == B.a +B.a == B.b # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + +B.a == A.a # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + +class C(IntEnum): + a = 0 + +C.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") +C.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + +C.a == C.a +C.a == C.b + +class D(int, Enum): + a = 0 + +D.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") +D.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + +D.a == D.a +D.a == D.b + +D.a == C.a # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") +[builtins fixtures/dict.pyi] From 350aa53ed54d7dd3668b44487a18dff217016cf9 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 18:49:49 +0200 Subject: [PATCH 2/8] Sync reachability and comparison overlap checks --- mypy/checker.py | 12 ++- mypy/meet.py | 7 +- mypy/subtypes.py | 29 +++++ test-data/unit/check-enum.test | 186 +++++++++++++++++++++++++++++---- 4 files changed, 212 insertions(+), 22 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 68f9bd4c1383..f1ec4cbfdeaf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -146,6 +146,7 @@ find_member, infer_class_variances, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_more_precise, is_proper_subtype, @@ -6783,13 +6784,16 @@ def should_coerce_inner(typ: Type) -> bool: expr_type = coerce_to_literal(expr_type) if not is_valid_target(get_proper_type(expr_type)): continue - if target and not is_same_type(target, expr_type): + if ( + target is not None + and not is_same_type(target, expr_type) + and not is_enum_value_pair(target, expr_type) + ): # We have multiple disjoint target types. So the 'if' branch # must be unreachable. return None, {} target = expr_type possible_target_indices.append(i) - # There's nothing we can currently infer if none of the operands are valid targets, # so we end early and infer nothing. if target is None: @@ -9125,7 +9129,9 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]: # let's be conservative result.add("") elif isinstance(t, LiteralType): - result.update(_ambiguous_enum_variants([t.fallback])) + if t.fallback.type.is_enum: + result.update(_ambiguous_enum_variants([t.fallback])) + # Other literals (str, int, bool) cannot introduce any surprises elif isinstance(t, NoneType): pass else: diff --git a/mypy/meet.py b/mypy/meet.py index 229c8a75b78d..2df18f2a3cf4 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -10,6 +10,7 @@ are_parameters_compatible, find_member, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_proper_subtype, is_same_type, @@ -547,7 +548,11 @@ def _type_object_overlap(left: Type, right: Type) -> bool: right = right.fallback if isinstance(left, LiteralType) and isinstance(right, LiteralType): - if left.value == right.value or (left.fallback.type.is_enum ^ right.fallback.type.is_enum): + if ( + left.value == right.value + and left.fallback.type.is_enum == right.fallback.type.is_enum + or is_enum_value_pair(left, right) + ): # If values are the same, we still need to check if fallbacks are overlapping, # this is done below. # Enums are more interesting: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..f1be900c6232 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -37,6 +37,7 @@ from mypy.options import Options from mypy.state import state from mypy.types import ( + ELLIPSIS_TYPE_NAMES, MYPYC_NATIVE_INT_NAMES, TUPLE_LIKE_INSTANCE_NAMES, TYPED_NAMEDTUPLE_NAMES, @@ -286,6 +287,34 @@ def is_same_type( ) +def is_enum_value_pair(a: ProperType, b: ProperType) -> bool: + if not isinstance(a, LiteralType) or not isinstance(b, LiteralType): + return False + if b.fallback.type.is_enum: + a, b = b, a + if b.fallback.type.is_enum: + return False + # At this point we have a pair (non-enum literal, enum literal). + # Check that the non-enum fallback is compatible + if not is_subtype(a.fallback, b.fallback): + return False + assert isinstance(a.value, str) + enum_value = a.fallback.type.get(a.value) + return ( + enum_value is not None + and enum_value.type is not None + and isinstance(enum_value.type, Instance) + and ( + enum_value.type.last_known_value == b + # TODO: this is too lax and should only be applied for enums defined in stubs, + # but checking that strictly requires access to the checker. This function + # is needed in `is_overlapping_types` and operates on a lower level, + # so doing this properly would be more difficult. + or enum_value.type.type.fullname in ELLIPSIS_TYPE_NAMES + ) + ) + + # This is a common entry point for subtyping checks (both proper and non-proper). # Never call this private function directly, use the public versions. def _is_subtype( diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index f10975926756..18ab4ee7794d 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2690,41 +2690,191 @@ class A(str, Enum): a = "b" b = "a" -A.a == "a" -A.a == "b" +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) -A.a == A.a -A.a == A.b # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") +if A.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal['a']") + 1 + 'a' +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' class B(StrEnum): a = "b" b = "a" -B.a == "a" -B.a == "b" +if B.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal['a']") + 1 + 'a' +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") -B.a == B.a -B.a == B.b # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' -B.a == A.a # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' class C(IntEnum): a = 0 + b = 1 + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' -C.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") -C.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[1]") + 1 + 'a' -C.a == C.a -C.a == C.b +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' class D(int, Enum): a = 0 + b = 1 + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[1]") + 1 + 'a' + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' -D.a == "a" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") -D.a == "b" # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsInStub] +# flags: --python-version=3.11 --strict-equality +from mystub import A, B, C, D + +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) + +if A.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' + +if B.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' + +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' + +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' + +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' + +[file mystub.pyi] +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = ... + b = ... -D.a == D.a -D.a == D.b +class B(StrEnum): + a = ... + b = ... + +class C(int, Enum): + a = ... + b = ... + +class D(IntEnum): + a = ... + b = ... +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsWithAlias-xfail] +# flags: --python-version=3.11 --strict-equality +# TODO: mypy does not support enum member aliases now. +from enum import Enum, IntEnum + +class A(str, Enum): + a = "c" + b = a + +if A.a == A.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +class B(IntEnum): + a = 0 + b = a -D.a == C.a # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") +if B.a == B.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") [builtins fixtures/dict.pyi] From ebfc647bcbd243d191f71cfdef0cba6a1a73ce41 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 19:10:47 +0200 Subject: [PATCH 3/8] Fix selfcheck --- mypy/subtypes.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index f1be900c6232..b2233ace9831 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -287,7 +287,10 @@ def is_same_type( ) -def is_enum_value_pair(a: ProperType, b: ProperType) -> bool: +def is_enum_value_pair(a: Type, b: Type) -> bool: + a = get_proper_type(a) + b = get_proper_type(b) + if not isinstance(a, LiteralType) or not isinstance(b, LiteralType): return False if b.fallback.type.is_enum: @@ -300,18 +303,16 @@ def is_enum_value_pair(a: ProperType, b: ProperType) -> bool: return False assert isinstance(a.value, str) enum_value = a.fallback.type.get(a.value) - return ( - enum_value is not None - and enum_value.type is not None - and isinstance(enum_value.type, Instance) - and ( - enum_value.type.last_known_value == b - # TODO: this is too lax and should only be applied for enums defined in stubs, - # but checking that strictly requires access to the checker. This function - # is needed in `is_overlapping_types` and operates on a lower level, - # so doing this properly would be more difficult. - or enum_value.type.type.fullname in ELLIPSIS_TYPE_NAMES - ) + if enum_value is None or enum_value.type is None: + return False + proper_value = get_proper_type(enum_value.type) + return isinstance(proper_value, Instance) and ( + proper_value.last_known_value == b + # TODO: this is too lax and should only be applied for enums defined in stubs, + # but checking that strictly requires access to the checker. This function + # is needed in `is_overlapping_types` and operates on a lower level, + # so doing this properly would be more difficult. + or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES ) From 0001936504a9c65ba509156bd6b06b4024548f18 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 19:34:03 +0200 Subject: [PATCH 4/8] Oops --- mypy/subtypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index b2233ace9831..772fce2a5f96 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -295,9 +295,9 @@ def is_enum_value_pair(a: Type, b: Type) -> bool: return False if b.fallback.type.is_enum: a, b = b, a - if b.fallback.type.is_enum: + if b.fallback.type.is_enum or not a.fallback.type.is_enum: return False - # At this point we have a pair (non-enum literal, enum literal). + # At this point we have a pair (enum literal, non-enum literal). # Check that the non-enum fallback is compatible if not is_subtype(a.fallback, b.fallback): return False From c87c7074149944b7578f391f6bb3a2de270026fa Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 7 Aug 2025 20:46:46 +0200 Subject: [PATCH 5/8] Prevent narrowing by equality to overlapping literals (discarding enum info). --- mypy/checker.py | 24 ++++++++++++++++++++++-- test-data/unit/check-narrowing.test | 5 +++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index f1ec4cbfdeaf..958b7d75dd4c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6866,8 +6866,28 @@ def should_coerce_inner(typ: Type) -> bool: # We intentionally use 'conditional_types' directly here instead of # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. - types = conditional_types(expr_type, target_type) - partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) + yes, no = conditional_types(expr_type, target_type) + # If we encounter `enum_value == 1` checks (enum vs literal), we do not want + # to narrow the former to literal and should preserve the enum identity. + # TODO: maybe we should infer literals here? + if ( + isinstance(get_proper_type(yes), LiteralType) + and isinstance(proper_expr := get_proper_type(expr_type), Instance) + and proper_expr.type.is_enum + ): + yes_items = [] + for name in proper_expr.type.enum_members: + e = proper_expr.type.get(name) + if ( + e is not None + and isinstance(proper_e := get_proper_type(e.type), Instance) + and proper_e.last_known_value == yes + ): + name_val = LiteralType(name, fallback=proper_expr) + yes_items.append(proper_expr.copy_modified(last_known_value=name_val)) + if yes_items: + yes = UnionType.make_union(yes_items) + partial_type_maps.append(conditional_types_to_typemaps(expr, yes, no)) return reduce_conditional_maps(partial_type_maps) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7fffd3ce94e5..cd2487a6e42e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2203,6 +2203,11 @@ def f3(x: IE | IE2) -> None: else: reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + if x == 1: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + def f4(x: IE | E) -> None: if x == IE.X: reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" From ba560f23ed0379aad24a8718ea4c4d335782adec Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 28 Aug 2025 13:30:47 +0200 Subject: [PATCH 6/8] No, that was too much --- mypy/checker.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 958b7d75dd4c..f1ec4cbfdeaf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6866,28 +6866,8 @@ def should_coerce_inner(typ: Type) -> bool: # We intentionally use 'conditional_types' directly here instead of # 'self.conditional_types_with_intersection': we only compute ad-hoc # intersections when working with pure instances. - yes, no = conditional_types(expr_type, target_type) - # If we encounter `enum_value == 1` checks (enum vs literal), we do not want - # to narrow the former to literal and should preserve the enum identity. - # TODO: maybe we should infer literals here? - if ( - isinstance(get_proper_type(yes), LiteralType) - and isinstance(proper_expr := get_proper_type(expr_type), Instance) - and proper_expr.type.is_enum - ): - yes_items = [] - for name in proper_expr.type.enum_members: - e = proper_expr.type.get(name) - if ( - e is not None - and isinstance(proper_e := get_proper_type(e.type), Instance) - and proper_e.last_known_value == yes - ): - name_val = LiteralType(name, fallback=proper_expr) - yes_items.append(proper_expr.copy_modified(last_known_value=name_val)) - if yes_items: - yes = UnionType.make_union(yes_items) - partial_type_maps.append(conditional_types_to_typemaps(expr, yes, no)) + types = conditional_types(expr_type, target_type) + partial_type_maps.append(conditional_types_to_typemaps(expr, *types)) return reduce_conditional_maps(partial_type_maps) From d96243d4c96096fd8efcc994c43cbb63366bc079 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Thu, 28 Aug 2025 14:46:03 +0200 Subject: [PATCH 7/8] Pick unreachable branches more aggressively --- mypy/checker.py | 25 +++++++++++++++++-------- test-data/unit/check-enum.test | 3 +++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index f1ec4cbfdeaf..257cda60b9fc 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6507,6 +6507,7 @@ def equality_type_narrowing_helper( if operator in {"is", "is not"}: is_valid_target: Callable[[Type], bool] = is_singleton_type coerce_only_in_literal_context = False + no_custom_eq = True should_narrow_by_identity = True else: @@ -6522,14 +6523,16 @@ def has_no_custom_eq_checks(t: Type) -> bool: coerce_only_in_literal_context = True expr_types = [operand_types[i] for i in expr_indices] - should_narrow_by_identity = all( - map(has_no_custom_eq_checks, expr_types) - ) and not is_ambiguous_mix_of_enums(expr_types) + no_custom_eq = all(map(has_no_custom_eq_checks, expr_types)) + should_narrow_by_identity = not is_ambiguous_mix_of_enums(expr_types) if_map: TypeMap = {} else_map: TypeMap = {} - if should_narrow_by_identity: - if_map, else_map = self.refine_identity_comparison_expression( + if no_custom_eq: + # Try to narrow the types or at least identify unreachable blocks. + # If there's some mix of enums and values, we do not want to narrow enums + # to literals, but still want to detect unreachable branches. + if_map_optimistic, else_map_optimistic = self.refine_identity_comparison_expression( operands, operand_types, expr_indices, @@ -6537,6 +6540,14 @@ def has_no_custom_eq_checks(t: Type) -> bool: is_valid_target, coerce_only_in_literal_context, ) + if should_narrow_by_identity: + if_map = if_map_optimistic + else_map = else_map_optimistic + else: + if if_map_optimistic is None: + if_map = None + if else_map_optimistic is None: + else_map = None if if_map == {} and else_map == {}: if_map, else_map = self.refine_away_none_in_comparison( @@ -9129,9 +9140,7 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]: # let's be conservative result.add("") elif isinstance(t, LiteralType): - if t.fallback.type.is_enum: - result.update(_ambiguous_enum_variants([t.fallback])) - # Other literals (str, int, bool) cannot introduce any surprises + result.update(_ambiguous_enum_variants([t.fallback])) elif isinstance(t, NoneType): pass else: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 18ab4ee7794d..e415065a2b15 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2703,6 +2703,9 @@ if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A if A.a == A.a: 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +else: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") 1 + 'a' From 79262e75facbec39b5c7c426ef20b54dce72bfa4 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 29 Aug 2025 15:53:27 +0200 Subject: [PATCH 8/8] Handle (str, Enum) and (int, Enum) subclasses narrowing --- mypy/checker.py | 3 +- test-data/unit/check-enum.test | 105 +++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 5bd1f63fa73e..b8ea856ed229 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -9208,7 +9208,8 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]: if t.last_known_value: result.update(_ambiguous_enum_variants([t.last_known_value])) elif t.type.is_enum and any( - base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro + base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int") + for base in t.type.mro ): result.add(t.type.fullname) elif not t.type.is_enum: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index e415065a2b15..c691eadbca90 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2881,3 +2881,108 @@ class B(IntEnum): if B.a == B.b: 1 + 'a' # E: Unsupported operand types for + ("int" and "str") [builtins fixtures/dict.pyi] + + +[case testEnumNarrowingByEqualityToLiterals] +# flags: --python-version=3.11 --strict-equality +from enum import Enum, StrEnum, IntEnum + +# Every `if` block in this test should either reveal or report non-overlapping. + +class A(str, Enum): + a = "b" + b = "a" +class B(StrEnum): + a = "b" + b = "a" +class C(int, Enum): + a = 0 + b = 1 +class D(IntEnum): + a = 0 + b = 1 + +a: A +if a == A.a: + reveal_type(a) # N: Revealed type is "Literal[__main__.A.a]" +else: + reveal_type(a) # N: Revealed type is "Literal[__main__.A.b]" + +if a == "a": + reveal_type(a) # N: Revealed type is "__main__.A" +else: + reveal_type(a) # N: Revealed type is "__main__.A" + +if a == "c": + reveal_type(a) # N: Revealed type is "__main__.A" +else: + reveal_type(a) # N: Revealed type is "__main__.A" + +if a == 0: # E: Non-overlapping equality check (left operand type: "A", right operand type: "Literal[0]") + reveal_type(a) +else: + reveal_type(a) # N: Revealed type is "__main__.A" + +b: B +if b == B.a: + reveal_type(b) # N: Revealed type is "Literal[__main__.B.a]" +else: + reveal_type(b) # N: Revealed type is "Literal[__main__.B.b]" + +if b == "a": + reveal_type(b) # N: Revealed type is "__main__.B" +else: + reveal_type(b) # N: Revealed type is "__main__.B" + +if b == "c": + reveal_type(b) # N: Revealed type is "__main__.B" +else: + reveal_type(b) # N: Revealed type is "__main__.B" + +if b == 0: # E: Non-overlapping equality check (left operand type: "B", right operand type: "Literal[0]") + reveal_type(b) +else: + reveal_type(b) # N: Revealed type is "__main__.B" + +c: C +if c == C.a: + reveal_type(c) # N: Revealed type is "Literal[__main__.C.a]" +else: + reveal_type(c) # N: Revealed type is "Literal[__main__.C.b]" + +if c == 0: + reveal_type(c) # N: Revealed type is "__main__.C" +else: + reveal_type(c) # N: Revealed type is "__main__.C" + +if c == 2: + reveal_type(c) # N: Revealed type is "__main__.C" +else: + reveal_type(c) # N: Revealed type is "__main__.C" + +if c == "a": # E: Non-overlapping equality check (left operand type: "C", right operand type: "Literal['a']") + reveal_type(c) +else: + reveal_type(c) # N: Revealed type is "__main__.C" + +d: D +if d == D.a: + reveal_type(d) # N: Revealed type is "Literal[__main__.D.a]" +else: + reveal_type(d) # N: Revealed type is "Literal[__main__.D.b]" + +if d == 0: + reveal_type(d) # N: Revealed type is "__main__.D" +else: + reveal_type(d) # N: Revealed type is "__main__.D" + +if d == 2: + reveal_type(d) # N: Revealed type is "__main__.D" +else: + reveal_type(d) # N: Revealed type is "__main__.D" + +if d == "a": # E: Non-overlapping equality check (left operand type: "D", right operand type: "Literal['a']") + reveal_type(d) +else: + reveal_type(d) # N: Revealed type is "__main__.D" +[builtins fixtures/dict.pyi]