From b6ef6947c3f8493035f32204a8302485988d8865 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 18 Jan 2017 13:02:32 +0000 Subject: [PATCH] Fix special cases where type context is a union This helps with the new definition of `dict.get` in typeshed. The fix feels a little ad-hoc, but it seems to fix a real-world issue and doesn't seem to break anything. Fixes #2703. --- mypy/constraints.py | 62 ++++++++++++++++++--- test-data/unit/check-inference-context.test | 39 +++++++++++++ 2 files changed, 93 insertions(+), 8 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 1d1e1c332872..533a436448ec 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -2,10 +2,12 @@ from typing import Iterable, List, Optional +from mypy import experiments from mypy.types import ( CallableType, Type, TypeVisitor, UnboundType, AnyType, Void, NoneTyp, TypeVarType, Instance, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, - DeletedType, UninhabitedType, TypeType, TypeVarId, is_named_instance + DeletedType, UninhabitedType, TypeType, TypeVarId, TypeQuery, ALL_TYPES_STRATEGY, + is_named_instance ) from mypy.maptype import map_instance_to_supertype from mypy import nodes @@ -149,13 +151,23 @@ def infer_constraints(template: Type, actual: Type, # be a supertype of the potential subtype, some item of the Union # must be a supertype of it. if direction == SUBTYPE_OF and isinstance(actual, UnionType): + # If some of items is not a complete type, disregard that. + items = simplify_away_incomplete_types(actual.items) + # We infer constraints eagerly -- try to find constraints for a type + # variable if possible. This seems to help with some real-world + # use cases. return any_constraints( [infer_constraints_if_possible(template, a_item, direction) - for a_item in actual.items]) + for a_item in items], + eager=True) if direction == SUPERTYPE_OF and isinstance(template, UnionType): + # When the template is a union, we are okay with leaving some + # type variables indeterminate. This helps with some special + # cases, though this isn't very principled. return any_constraints( [infer_constraints_if_possible(t_item, actual, direction) - for t_item in template.items]) + for t_item in template.items], + eager=False) # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction)) @@ -177,12 +189,18 @@ def infer_constraints_if_possible(template: Type, actual: Type, return infer_constraints(template, actual, direction) -def any_constraints(options: List[Optional[List[Constraint]]]) -> List[Constraint]: - """Deduce what we can from a collection of constraint lists given that - at least one of the lists must be satisfied. A None element in the - list of options represents an unsatisfiable constraint and is ignored. +def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> List[Constraint]: + """Deduce what we can from a collection of constraint lists. + + It's a given that at least one of the lists must be satisfied. A + None element in the list of options represents an unsatisfiable + constraint and is ignored. Ignore empty constraint lists if eager + is true -- they are always trivially satisfiable. """ - valid_options = [option for option in options if option is not None] + if eager: + valid_options = [option for option in options if option] + else: + valid_options = [option for option in options if option is not None] if len(valid_options) == 1: return valid_options[0] # Otherwise, there are either no valid options or multiple valid options. @@ -196,6 +214,34 @@ def any_constraints(options: List[Optional[List[Constraint]]]) -> List[Constrain # every option, combine the bounds with meet/join. +def simplify_away_incomplete_types(types: List[Type]) -> List[Type]: + complete = [typ for typ in types if is_complete_type(typ)] + if complete: + return complete + else: + return types + + +def is_complete_type(typ: Type) -> bool: + """Is a type complete? + + A complete doesn't have uninhabited type components or (when not in strict + optional mode) None components. + """ + return typ.accept(CompleteTypeVisitor()) + + +class CompleteTypeVisitor(TypeQuery): + def __init__(self) -> None: + super().__init__(default=True, strategy=ALL_TYPES_STRATEGY) + + def visit_none_type(self, t: NoneTyp) -> bool: + return experiments.STRICT_OPTIONAL + + def visit_uninhabited_type(self, t: UninhabitedType) -> bool: + return False + + class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]): """Visitor class for inferring type constraints.""" diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index e22aee11ac3e..59a8f211c9f9 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -826,3 +826,42 @@ class A(Generic[T]): pass reveal_type(A()) # E: Revealed type is '__main__.A[builtins.None]' b = reveal_type(A()) # type: A[int] # E: Revealed type is '__main__.A[builtins.int]' + +[case testUnionWithGenericTypeItemContext] +from typing import TypeVar, Union, List + +T = TypeVar('T') + +def f(x: Union[T, List[int]]) -> Union[T, List[int]]: pass +reveal_type(f(1)) # E: Revealed type is 'Union[builtins.int*, builtins.list[builtins.int]]' +reveal_type(f([])) # E: Revealed type is 'builtins.list[builtins.int]' +reveal_type(f(None)) # E: Revealed type is 'builtins.list[builtins.int]' +[builtins fixtures/list.pyi] + +[case testUnionWithGenericTypeItemContextAndStrictOptional] +# flags: --strict-optional +from typing import TypeVar, Union, List + +T = TypeVar('T') + +def f(x: Union[T, List[int]]) -> Union[T, List[int]]: pass +reveal_type(f(1)) # E: Revealed type is 'Union[builtins.int*, builtins.list[builtins.int]]' +reveal_type(f([])) # E: Revealed type is 'builtins.list[builtins.int]' +reveal_type(f(None)) # E: Revealed type is 'Union[builtins.None, builtins.list[builtins.int]]' +[builtins fixtures/list.pyi] + +[case testUnionWithGenericTypeItemContextInMethod] +from typing import TypeVar, Union, List, Generic + +T = TypeVar('T') +S = TypeVar('S') + +class C(Generic[T]): + def f(self, x: Union[T, S]) -> Union[T, S]: pass + +c = C[List[int]]() +reveal_type(c.f('')) # E: Revealed type is 'Union[builtins.list[builtins.int], builtins.str*]' +reveal_type(c.f([1])) # E: Revealed type is 'builtins.list[builtins.int]' +reveal_type(c.f([])) # E: Revealed type is 'builtins.list[builtins.int]' +reveal_type(c.f(None)) # E: Revealed type is 'builtins.list[builtins.int]' +[builtins fixtures/list.pyi]