diff --git a/mypy/checker.py b/mypy/checker.py index 85a2759185813..51137b7ad5fb6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1387,16 +1387,12 @@ def check_method_override_for_base_with_name( if isinstance(original_type, AnyType) or isinstance(typ, AnyType): pass elif isinstance(original_type, FunctionLike) and isinstance(typ, FunctionLike): - if (isinstance(base_attr.node, (FuncDef, OverloadedFuncDef, Decorator)) - and not is_static(base_attr.node)): - bound = bind_self(original_type, self.scope.active_self_type()) - else: - bound = original_type - original = map_type_from_supertype(bound, defn.info, base) + original = self.bind_and_map_method(base_attr, original_type, + defn.info, base) # Check that the types are compatible. # TODO overloaded signatures self.check_override(typ, - cast(FunctionLike, original), + original, defn.name(), name, base.name(), @@ -1415,6 +1411,23 @@ def check_method_override_for_base_with_name( defn.name(), name, base.name(), context) return False + def bind_and_map_method(self, sym: SymbolTableNode, typ: FunctionLike, + sub_info: TypeInfo, super_info: TypeInfo) -> FunctionLike: + """Bind self-type and map type variables for a method. + + Arguments: + sym: a symbol that points to method definition + typ: method type on the definition + sub_info: class where the method is used + super_info: class where the method was defined + """ + if (isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) + and not is_static(sym.node)): + bound = bind_self(typ, self.scope.active_self_type()) + else: + bound = typ + return cast(FunctionLike, map_type_from_supertype(bound, sub_info, super_info)) + def get_op_other_domain(self, tp: FunctionLike) -> Optional[Type]: if isinstance(tp, CallableType): if tp.arg_kinds and tp.arg_kinds[0] == ARG_POS: @@ -1628,22 +1641,35 @@ def determine_type_of_class_member(self, sym: SymbolTableNode) -> Optional[Type] return None def check_compatibility(self, name: str, base1: TypeInfo, - base2: TypeInfo, ctx: Context) -> None: + base2: TypeInfo, ctx: TypeInfo) -> None: """Check if attribute name in base1 is compatible with base2 in multiple inheritance. Assume base1 comes before base2 in the MRO, and that base1 and base2 don't have a direct subclass relationship (i.e., the compatibility requirement only derives from multiple inheritance). + + This check verifies that a definition taken from base1 (and mapped to the current + class ctx), is type compatible with the definition taken from base2 (also mapped), so + that unsafe subclassing like this can be detected: + class A(Generic[T]): + def foo(self, x: T) -> None: ... + + class B: + def foo(self, x: str) -> None: ... + + class C(B, A[int]): ... # this is unsafe because... + + x: A[int] = C() + x.foo # ...runtime type is (str) -> None, while static type is (int) -> None """ if name in ('__init__', '__new__', '__init_subclass__'): # __init__ and friends can be incompatible -- it's a special case. return - first = base1[name] - second = base2[name] + first = base1.names[name] + second = base2.names[name] first_type = self.determine_type_of_class_member(first) second_type = self.determine_type_of_class_member(second) - # TODO: What if some classes are generic? if (isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike)): if first_type.is_type_obj() and second_type.is_type_obj(): @@ -1652,8 +1678,9 @@ def check_compatibility(self, name: str, base1: TypeInfo, ok = is_subtype(left=fill_typevars_with_any(first_type.type_object()), right=fill_typevars_with_any(second_type.type_object())) else: - first_sig = bind_self(first_type) - second_sig = bind_self(second_type) + # First bind/map method types when necessary. + first_sig = self.bind_and_map_method(first, first_type, ctx, base1) + second_sig = self.bind_and_map_method(second, second_type, ctx, base2) ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True) elif first_type and second_type: ok = is_equivalent(first_type, second_type) diff --git a/test-data/unit/check-multiple-inheritance.test b/test-data/unit/check-multiple-inheritance.test index d87b4a5c48193..dc69a4187cead 100644 --- a/test-data/unit/check-multiple-inheritance.test +++ b/test-data/unit/check-multiple-inheritance.test @@ -624,3 +624,44 @@ class Mixin2: class A(Mixin1, Mixin2): pass [out] + +[case testGenericMultipleOverrideRemap] +from typing import TypeVar, Generic, Tuple + +K = TypeVar('K') +V = TypeVar('V') +T = TypeVar('T') + +class ItemsView(Generic[K, V]): + def __iter__(self) -> Tuple[K, V]: ... + +class Sequence(Generic[T]): + def __iter__(self) -> T: ... + +# Override compatible between bases. +class OrderedItemsView(ItemsView[K, V], Sequence[Tuple[K, V]]): + def __iter__(self) -> Tuple[K, V]: ... + +class OrderedItemsViewDirect(ItemsView[K, V], Sequence[Tuple[K, V]]): + pass + +[case testGenericMultipleOverrideReplace] +from typing import TypeVar, Generic, Union + +T = TypeVar('T') + +class A(Generic[T]): + def foo(self, x: T) -> None: ... + +class B(A[T]): ... + +class C1: + def foo(self, x: str) -> None: ... + +class C2: + def foo(self, x: Union[str, int]) -> None: ... + +class D1(B[str], C1): ... +class D2(B[Union[int, str]], C2): ... +class D3(C2, B[str]): ... +class D4(B[str], C2): ... # E: Definition of "foo" in base class "A" is incompatible with definition in base class "C2"