From a49d3b87380536b4a79894c64133558fd6ceea40 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 25 Aug 2022 14:42:48 +0100 Subject: [PATCH 1/2] Add type inference for class object vs generic protocol --- mypy/constraints.py | 31 +++++++++++++++++++++++++++-- test-data/unit/check-protocols.test | 29 +++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 9e28ce503b6c..59b9c830d1e9 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -541,7 +541,29 @@ def visit_instance(self, template: Instance) -> list[Constraint]: template.type.inferring.pop() return res if isinstance(actual, CallableType) and actual.fallback is not None: + if actual.is_type_obj() and template.type.is_protocol: + ret_type = get_proper_type(actual.ret_type) + if isinstance(ret_type, TupleType): + ret_type = mypy.typeops.tuple_fallback(ret_type) + if isinstance(ret_type, Instance): + if self.direction == SUBTYPE_OF: + subtype = template + else: + subtype = ret_type + return self.infer_constraints_from_protocol_members( + ret_type, template, subtype, template, class_obj=True + ) actual = actual.fallback + if isinstance(actual, TypeType) and template.type.is_protocol: + if isinstance(actual.item, Instance): + if self.direction == SUBTYPE_OF: + subtype = template + else: + subtype = actual.item + return self.infer_constraints_from_protocol_members( + actual.item, template, subtype, template, class_obj=True + ) + if isinstance(actual, Overloaded) and actual.fallback is not None: actual = actual.fallback if isinstance(actual, TypedDictType): @@ -740,7 +762,12 @@ def visit_instance(self, template: Instance) -> list[Constraint]: return [] def infer_constraints_from_protocol_members( - self, instance: Instance, template: Instance, subtype: Type, protocol: Instance + self, + instance: Instance, + template: Instance, + subtype: Type, + protocol: Instance, + class_obj: bool = False, ) -> list[Constraint]: """Infer constraints for situations where either 'template' or 'instance' is a protocol. @@ -750,7 +777,7 @@ def infer_constraints_from_protocol_members( """ res = [] for member in protocol.type.protocol_members: - inst = mypy.subtypes.find_member(member, instance, subtype) + inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj) temp = mypy.subtypes.find_member(member, template, subtype) if inst is None or temp is None: return [] # See #11020 diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index 90276ebae972..fef26733167e 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -3517,3 +3517,32 @@ test(c) # E: Argument 1 to "test" has incompatible type "Type[C]"; expected "P" # N: def [T] foo(arg: T) -> T \ # N: Got: \ # N: def [T] foo(self: T) -> Union[T, int] + +[case testProtocolClassObjectInference] +from typing import Any, Protocol, TypeVar + +T = TypeVar("T", contravariant=True) +class P(Protocol[T]): + def foo(self, obj: T) -> int: ... + +class B: + def foo(self) -> int: ... + +S = TypeVar("S") +def test(arg: P[S]) -> S: ... +reveal_type(test(B)) # N: Revealed type is "__main__.B" + +[case testProtocolTypeTypeInference] +from typing import Any, Protocol, TypeVar, Type + +T = TypeVar("T", contravariant=True) +class P(Protocol[T]): + def foo(self, obj: T) -> int: ... + +class B: + def foo(self) -> int: ... + +S = TypeVar("S") +def test(arg: P[S]) -> S: ... +b: Type[B] +reveal_type(test(b)) # N: Revealed type is "__main__.B" From 6b505510f19a641d14acb79f6d0847b5f4e89126 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 25 Aug 2022 15:50:20 +0100 Subject: [PATCH 2/2] Fix tests --- mypy/constraints.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 59b9c830d1e9..e0742a33e9e8 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -550,8 +550,10 @@ def visit_instance(self, template: Instance) -> list[Constraint]: subtype = template else: subtype = ret_type - return self.infer_constraints_from_protocol_members( - ret_type, template, subtype, template, class_obj=True + res.extend( + self.infer_constraints_from_protocol_members( + ret_type, template, subtype, template, class_obj=True + ) ) actual = actual.fallback if isinstance(actual, TypeType) and template.type.is_protocol: @@ -560,8 +562,10 @@ def visit_instance(self, template: Instance) -> list[Constraint]: subtype = template else: subtype = actual.item - return self.infer_constraints_from_protocol_members( - actual.item, template, subtype, template, class_obj=True + res.extend( + self.infer_constraints_from_protocol_members( + actual.item, template, subtype, template, class_obj=True + ) ) if isinstance(actual, Overloaded) and actual.fallback is not None: @@ -737,6 +741,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]: ) instance.type.inferring.pop() return res + if res: + return res + if isinstance(actual, AnyType): return self.infer_against_any(template.args, actual) if (