diff --git a/mypy/checker.py b/mypy/checker.py index 63e128f78310..e37313bdc27b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -161,6 +161,7 @@ find_member, infer_class_variances, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_more_precise, is_proper_subtype, @@ -6628,6 +6629,7 @@ def equality_type_narrowing_helper( if operator in {"is", "is not"}: is_valid_target: Callable[[Type], bool] = is_singleton_type coerce_only_in_literal_context = False + no_custom_eq = True should_narrow_by_identity = True else: @@ -6643,14 +6645,16 @@ def has_no_custom_eq_checks(t: Type) -> bool: coerce_only_in_literal_context = True expr_types = [operand_types[i] for i in expr_indices] - should_narrow_by_identity = all( - map(has_no_custom_eq_checks, expr_types) - ) and not is_ambiguous_mix_of_enums(expr_types) + no_custom_eq = all(map(has_no_custom_eq_checks, expr_types)) + should_narrow_by_identity = not is_ambiguous_mix_of_enums(expr_types) if_map: TypeMap = {} else_map: TypeMap = {} - if should_narrow_by_identity: - if_map, else_map = self.refine_identity_comparison_expression( + if no_custom_eq: + # Try to narrow the types or at least identify unreachable blocks. + # If there's some mix of enums and values, we do not want to narrow enums + # to literals, but still want to detect unreachable branches. + if_map_optimistic, else_map_optimistic = self.refine_identity_comparison_expression( operands, operand_types, expr_indices, @@ -6658,6 +6662,14 @@ def has_no_custom_eq_checks(t: Type) -> bool: is_valid_target, coerce_only_in_literal_context, ) + if should_narrow_by_identity: + if_map = if_map_optimistic + else_map = else_map_optimistic + else: + if if_map_optimistic is None: + if_map = None + if else_map_optimistic is None: + else_map = None if if_map == {} and else_map == {}: if_map, else_map = self.refine_away_none_in_comparison( @@ -6905,13 +6917,16 @@ def should_coerce_inner(typ: Type) -> bool: expr_type = coerce_to_literal(expr_type) if not is_valid_target(get_proper_type(expr_type)): continue - if target and not is_same_type(target, expr_type): + if ( + target is not None + and not is_same_type(target, expr_type) + and not is_enum_value_pair(target, expr_type) + ): # We have multiple disjoint target types. So the 'if' branch # must be unreachable. return None, {} target = expr_type possible_target_indices.append(i) - # There's nothing we can currently infer if none of the operands are valid targets, # so we end early and infer nothing. if target is None: @@ -9291,7 +9306,8 @@ def _ambiguous_enum_variants(types: list[Type]) -> set[str]: if t.last_known_value: result.update(_ambiguous_enum_variants([t.last_known_value])) elif t.type.is_enum and any( - base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro + base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int") + for base in t.type.mro ): result.add(t.type.fullname) elif not t.type.is_enum: diff --git a/mypy/meet.py b/mypy/meet.py index 1cb291ff90d5..073fc1756cba 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -10,6 +10,7 @@ are_parameters_compatible, find_member, is_callable_compatible, + is_enum_value_pair, is_equivalent, is_proper_subtype, is_same_type, @@ -559,9 +560,16 @@ def _type_object_overlap(left: Type, right: Type) -> bool: right = right.fallback if isinstance(left, LiteralType) and isinstance(right, LiteralType): - if left.value == right.value: + if ( + left.value == right.value + and left.fallback.type.is_enum == right.fallback.type.is_enum + or is_enum_value_pair(left, right) + ): # If values are the same, we still need to check if fallbacks are overlapping, # this is done below. + # Enums are more interesting: + # * if both sides are enums, they should have same values + # * if exactly one of them is a enum, fallback compatibibility is enough left = left.fallback right = right.fallback else: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..772fce2a5f96 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -37,6 +37,7 @@ from mypy.options import Options from mypy.state import state from mypy.types import ( + ELLIPSIS_TYPE_NAMES, MYPYC_NATIVE_INT_NAMES, TUPLE_LIKE_INSTANCE_NAMES, TYPED_NAMEDTUPLE_NAMES, @@ -286,6 +287,35 @@ def is_same_type( ) +def is_enum_value_pair(a: Type, b: Type) -> bool: + a = get_proper_type(a) + b = get_proper_type(b) + + if not isinstance(a, LiteralType) or not isinstance(b, LiteralType): + return False + if b.fallback.type.is_enum: + a, b = b, a + if b.fallback.type.is_enum or not a.fallback.type.is_enum: + return False + # At this point we have a pair (enum literal, non-enum literal). + # Check that the non-enum fallback is compatible + if not is_subtype(a.fallback, b.fallback): + return False + assert isinstance(a.value, str) + enum_value = a.fallback.type.get(a.value) + if enum_value is None or enum_value.type is None: + return False + proper_value = get_proper_type(enum_value.type) + return isinstance(proper_value, Instance) and ( + proper_value.last_known_value == b + # TODO: this is too lax and should only be applied for enums defined in stubs, + # but checking that strictly requires access to the checker. This function + # is needed in `is_overlapping_types` and operates on a lower level, + # so doing this properly would be more difficult. + or proper_value.type.fullname in ELLIPSIS_TYPE_NAMES + ) + + # This is a common entry point for subtyping checks (both proper and non-proper). # Never call this private function directly, use the public versions. def _is_subtype( diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 3bcf9745a801..c691eadbca90 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2681,3 +2681,308 @@ reveal_type(Wrapper.Nested.FOO) # N: Revealed type is "Literal[__main__.Wrapper reveal_type(Wrapper.Nested.FOO.value) # N: Revealed type is "builtins.ellipsis" reveal_type(Wrapper.Nested.FOO._value_) # N: Revealed type is "builtins.ellipsis" [builtins fixtures/enum.pyi] + +[case testEnumItemsEqualityToLiterals] +# flags: --python-version=3.11 --strict-equality +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = "b" + b = "a" + +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) + +if A.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal['a']") + 1 + 'a' +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +else: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' + +class B(StrEnum): + a = "b" + b = "a" + +if B.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal['a']") + 1 + 'a' +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' + +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' + +class C(IntEnum): + a = 0 + b = 1 + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' + +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[1]") + 1 + 'a' + +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' + +class D(int, Enum): + a = 0 + b = 1 + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[1]") + 1 + 'a' + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' + +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsInStub] +# flags: --python-version=3.11 --strict-equality +from mystub import A, B, C, D + +# Every `if` block in this test should have an error on exactly one of two lines. +# Either it is reachable (and thus overlapping) or unreachable (and non-overlapping) + +if A.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if A.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[0]") + 1 + 'a' + +if A.a == A.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if A.a == A.b: # E: Non-overlapping equality check (left operand type: "Literal[A.a]", right operand type: "Literal[A.b]") + 1 + 'a' + +if B.a == "a": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == "b": + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if B.a == 0: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[0]") + 1 + 'a' + +if B.a == B.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if B.a == B.b: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[B.b]") + 1 + 'a' + +if B.a == A.a: # E: Non-overlapping equality check (left operand type: "Literal[B.a]", right operand type: "Literal[A.a]") + 1 + 'a' + +if C.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['a']") + 1 + 'a' +if C.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal['b']") + 1 + 'a' + +if C.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if C.a == C.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if C.a == C.b: # E: Non-overlapping equality check (left operand type: "Literal[C.a]", right operand type: "Literal[C.b]") + 1 + 'a' + +if D.a == "a": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['a']") + 1 + 'a' +if D.a == "b": # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal['b']") + 1 + 'a' + +if D.a == 0: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == 1: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +if D.a == D.a: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +if D.a == D.b: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[D.b]") + 1 + 'a' + +if D.a == C.a: # E: Non-overlapping equality check (left operand type: "Literal[D.a]", right operand type: "Literal[C.a]") + 1 + 'a' + +[file mystub.pyi] +from enum import Enum, StrEnum, IntEnum + +class A(str, Enum): + a = ... + b = ... + +class B(StrEnum): + a = ... + b = ... + +class C(int, Enum): + a = ... + b = ... + +class D(IntEnum): + a = ... + b = ... +[builtins fixtures/dict.pyi] + + +[case testEnumItemsEqualityToLiteralsWithAlias-xfail] +# flags: --python-version=3.11 --strict-equality +# TODO: mypy does not support enum member aliases now. +from enum import Enum, IntEnum + +class A(str, Enum): + a = "c" + b = a + +if A.a == A.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") + +class B(IntEnum): + a = 0 + b = a + +if B.a == B.b: + 1 + 'a' # E: Unsupported operand types for + ("int" and "str") +[builtins fixtures/dict.pyi] + + +[case testEnumNarrowingByEqualityToLiterals] +# flags: --python-version=3.11 --strict-equality +from enum import Enum, StrEnum, IntEnum + +# Every `if` block in this test should either reveal or report non-overlapping. + +class A(str, Enum): + a = "b" + b = "a" +class B(StrEnum): + a = "b" + b = "a" +class C(int, Enum): + a = 0 + b = 1 +class D(IntEnum): + a = 0 + b = 1 + +a: A +if a == A.a: + reveal_type(a) # N: Revealed type is "Literal[__main__.A.a]" +else: + reveal_type(a) # N: Revealed type is "Literal[__main__.A.b]" + +if a == "a": + reveal_type(a) # N: Revealed type is "__main__.A" +else: + reveal_type(a) # N: Revealed type is "__main__.A" + +if a == "c": + reveal_type(a) # N: Revealed type is "__main__.A" +else: + reveal_type(a) # N: Revealed type is "__main__.A" + +if a == 0: # E: Non-overlapping equality check (left operand type: "A", right operand type: "Literal[0]") + reveal_type(a) +else: + reveal_type(a) # N: Revealed type is "__main__.A" + +b: B +if b == B.a: + reveal_type(b) # N: Revealed type is "Literal[__main__.B.a]" +else: + reveal_type(b) # N: Revealed type is "Literal[__main__.B.b]" + +if b == "a": + reveal_type(b) # N: Revealed type is "__main__.B" +else: + reveal_type(b) # N: Revealed type is "__main__.B" + +if b == "c": + reveal_type(b) # N: Revealed type is "__main__.B" +else: + reveal_type(b) # N: Revealed type is "__main__.B" + +if b == 0: # E: Non-overlapping equality check (left operand type: "B", right operand type: "Literal[0]") + reveal_type(b) +else: + reveal_type(b) # N: Revealed type is "__main__.B" + +c: C +if c == C.a: + reveal_type(c) # N: Revealed type is "Literal[__main__.C.a]" +else: + reveal_type(c) # N: Revealed type is "Literal[__main__.C.b]" + +if c == 0: + reveal_type(c) # N: Revealed type is "__main__.C" +else: + reveal_type(c) # N: Revealed type is "__main__.C" + +if c == 2: + reveal_type(c) # N: Revealed type is "__main__.C" +else: + reveal_type(c) # N: Revealed type is "__main__.C" + +if c == "a": # E: Non-overlapping equality check (left operand type: "C", right operand type: "Literal['a']") + reveal_type(c) +else: + reveal_type(c) # N: Revealed type is "__main__.C" + +d: D +if d == D.a: + reveal_type(d) # N: Revealed type is "Literal[__main__.D.a]" +else: + reveal_type(d) # N: Revealed type is "Literal[__main__.D.b]" + +if d == 0: + reveal_type(d) # N: Revealed type is "__main__.D" +else: + reveal_type(d) # N: Revealed type is "__main__.D" + +if d == 2: + reveal_type(d) # N: Revealed type is "__main__.D" +else: + reveal_type(d) # N: Revealed type is "__main__.D" + +if d == "a": # E: Non-overlapping equality check (left operand type: "D", right operand type: "Literal['a']") + reveal_type(d) +else: + reveal_type(d) # N: Revealed type is "__main__.D" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 00d33c86414f..ab89295f9c9a 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2203,6 +2203,11 @@ def f3(x: IE | IE2) -> None: else: reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + if x == 1: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + else: + reveal_type(x) # N: Revealed type is "Union[__main__.IE, __main__.IE2]" + def f4(x: IE | E) -> None: if x == IE.X: reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]"