diff --git a/typing_extensions/src_py3/test_typing_extensions.py b/typing_extensions/src_py3/test_typing_extensions.py index 173a782a5..3e62577ed 100644 --- a/typing_extensions/src_py3/test_typing_extensions.py +++ b/typing_extensions/src_py3/test_typing_extensions.py @@ -9,7 +9,7 @@ from unittest import TestCase, main, skipUnless from typing import TypeVar, Optional from typing import T, KT, VT # Not in __all__. -from typing import Tuple, List +from typing import Tuple, List, Dict from typing import Generic from typing import get_type_hints from typing import no_type_check @@ -65,6 +65,8 @@ # Protocols are hard to backport to the original version of typing 3.5.0 HAVE_PROTOCOLS = sys.version_info[:3] != (3, 5, 0) +# Not backported to older versions yet +HAVE_ANNOTATED = PEP_560 class BaseTestCase(TestCase): def assertIsSubclass(self, cls, class_or_tuple, msg=None): @@ -1458,6 +1460,182 @@ def test_total(self): self.assertEqual(Options.__total__, False) +if HAVE_ANNOTATED: + from typing_extensions import Annotated, get_type_hints + + class AnnotatedTests(BaseTestCase): + + def test_repr(self): + self.assertEqual( + repr(Annotated[int, 4, 5]), + "typing_extensions.Annotated[int, 4, 5]" + ) + + def test_flatten(self): + A = Annotated[Annotated[int, 4], 5] + self.assertEqual(A, Annotated[int, 4, 5]) + self.assertEqual(A.__metadata__, (4, 5)) + self.assertEqual(A.__origin__, int) + + def test_hash_eq(self): + self.assertEqual(len({Annotated[int, 4, 5], Annotated[int, 4, 5]}), 1) + self.assertNotEqual(Annotated[int, 4, 5], Annotated[int, 5, 4]) + self.assertNotEqual(Annotated[int, 4, 5], Annotated[str, 4, 5]) + self.assertNotEqual(Annotated[int, 4], Annotated[int, 4, 4]) + self.assertEqual( + {Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]}, + {Annotated[int, 4, 5], Annotated[T, 4, 5]} + ) + + def test_instantiate(self): + class C: + classvar = 4 + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + if not isinstance(other, C): + return NotImplemented + return other.x == self.x + + A = Annotated[C, "a decoration"] + a = A(5) + c = C(5) + self.assertEqual(a, c) + self.assertEqual(a.x, c.x) + self.assertEqual(A.classvar, C.classvar) + + MyCount = Annotated[typing_extensions.Counter[T], "my decoration"] + self.assertEqual(MyCount([4, 4, 5]), {4: 2, 5: 1}) + self.assertEqual(MyCount[int]([4, 4, 5]), {4: 2, 5: 1}) + + + def test_cannot_subclass(self): + with self.assertRaises(TypeError): + class C(Annotated): + pass + + def test_pickle(self): + samples = [typing.Any, typing.Union[int, str], + typing.Optional[str], Tuple[int, ...], + typing.Callable[[str], bytes]] + + for t in samples: + x = Annotated[t, "a"] + + for prot in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(protocol=prot, type=t): + pickled = pickle.dumps(x, prot) + restored = pickle.loads(pickled) + self.assertEqual(x, restored) + + global _Annotated_test_G + + class _Annotated_test_G(Generic[T]): + x = 1 + + G = Annotated[_Annotated_test_G[int], "A decoration"] + G.foo = 42 + G.bar = 'abc' + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(G, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + + def test_subst(self): + dec = "a decoration" + + S = Annotated[T, dec] + self.assertEqual(S[int], Annotated[int, dec]) + + L = Annotated[List[T], dec] + self.assertEqual(L[int], Annotated[List[int], dec]) + with self.assertRaises(TypeError): + L[int, int] + + D = Annotated[Dict[KT, VT], dec] + self.assertEqual(D[str, int], Annotated[Dict[str, int], dec]) + with self.assertRaises(TypeError): + D[int] + + I = Annotated[int, dec] + with self.assertRaises(TypeError): + I[None] + + LI = L[int] + with self.assertRaises(TypeError): + LI[None] + + def test_annotated_in_other_types(self): + X = List[Annotated[T, 5]] + self.assertEqual(X[int], List[Annotated[int, 5]]) + + + class GetTypeHintsTests(BaseTestCase): + def test_get_type_hints(self): + def foobar(x: List['X']): ... + X = Annotated[int, (1, 10)] + self.assertEqual( + get_type_hints(foobar, globals(), locals()), + {'x': List[int]} + ) + self.assertEqual( + get_type_hints(foobar, globals(), locals(), include_extras=True), + {'x': List[Annotated[int, (1, 10)]]} + ) + BA = Tuple[Annotated[T, (1, 0)], ...] + def barfoo(x: BA): ... + self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...]) + self.assertIs( + get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'], + BA + ) + def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]], + y: typing.Union[int, Annotated[T, "mutable"]]): ... + self.assertEqual( + get_type_hints(barfoo2, globals(), locals()), + {'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]} + ) + BA2 = typing.Callable[..., List[T]] + def barfoo3(x: BA2): ... + self.assertIs( + get_type_hints(barfoo3, globals(), locals(), include_extras=True)["x"], + BA2 + ) + + def test_get_type_hints_refs(self): + + Const = Annotated[T, "Const"] + + class MySet(Generic[T]): + + def __ior__(self, other: "Const[MySet[T]]") -> "MySet[T]": + ... + + def __iand__(self, other: Const["MySet[T]"]) -> "MySet[T]": + ... + + self.assertEqual( + get_type_hints(MySet.__iand__, globals(), locals()), + {'other': MySet[T], 'return': MySet[T]} + ) + + self.assertEqual( + get_type_hints(MySet.__iand__, globals(), locals(), include_extras=True), + {'other': Const[MySet[T]], 'return': MySet[T]} + ) + + self.assertEqual( + get_type_hints(MySet.__ior__, globals(), locals()), + {'other': MySet[T], 'return': MySet[T]} + ) + + + class AllTests(BaseTestCase): def test_typing_extensions_includes_standard(self): @@ -1488,8 +1666,12 @@ def test_typing_extensions_includes_standard(self): self.assertIn('Protocol', a) self.assertIn('runtime', a) + if HAVE_ANNOTATED: + self.assertIn('Annotated', a) + self.assertIn('get_type_hints', a) + def test_typing_extensions_defers_when_possible(self): - exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final'} + exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final', 'get_type_hints'} for item in typing_extensions.__all__: if item not in exclude and hasattr(typing, item): self.assertIs( diff --git a/typing_extensions/src_py3/typing_extensions.py b/typing_extensions/src_py3/typing_extensions.py index 956ef6138..3d50e1f1b 100644 --- a/typing_extensions/src_py3/typing_extensions.py +++ b/typing_extensions/src_py3/typing_extensions.py @@ -4,6 +4,7 @@ import sys import typing import collections.abc as collections_abc +import operator # After PEP 560, internal typing API was substantially reworked. # This is especially important for Protocol class which uses internal APIs @@ -139,6 +140,14 @@ def _check_methods_in_mro(C, *methods): if HAVE_PROTOCOLS: __all__.extend(['Protocol', 'runtime']) +# Annotations were implemented under tight time constraints; this keeps the +# implementation simple for now +HAVE_ANNOTATED = PEP_560 + +if HAVE_ANNOTATED: + __all__.extend(['Annotated', 'get_type_hints']) + + # TODO if hasattr(typing, 'NoReturn'): NoReturn = typing.NoReturn @@ -1595,3 +1604,149 @@ class Point2D(TypedDict): The class syntax is only supported in Python 3.6+, while two other syntax forms work for Python 2.7 and 3.2+ """ + + +if HAVE_ANNOTATED: + class _Annotated(typing._GenericAlias, _root=True): + """Runtime representation of an annotated type. + + At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' + with extra annotations. The alias behaves like a normal typing alias, + instantiating is the same as instantiating the underlying type, binding + it to types is also the same. + """ + def __init__(self, origin, metadata): + if isinstance(origin, _Annotated): + metadata = origin.__metadata__ + metadata + origin = origin.__origin__ + super().__init__(origin, origin) + self.__metadata__ = metadata + + def copy_with(self, params): + assert len(params) == 1 + new_type = params[0] + return _Annotated(new_type, self.__metadata__) + + def __repr__(self): + return "typing_extensions.Annotated[{}, {}]".format( + typing._type_repr(self.__origin__), + ", ".join(repr(a) for a in self.__metadata__) + ) + + def __reduce__(self): + return operator.getitem, ( + Annotated, (self.__origin__,) + self.__metadata__ + ) + + def __eq__(self, other): + if not isinstance(other, _Annotated): + return NotImplemented + if self.__origin__ != other.__origin__: + return False + return self.__metadata__ == other.__metadata__ + + def __hash__(self): + return hash((self.__origin__, self.__metadata__)) + + + class Annotated: + """Add context specific metadata to a type. + + Example: Annotated[int, runtime_check.Unsigned] indicates to the + hypothetical runtime_check module that this type is an unsigned int. + Every other consumer of this type can ignore this metadata and treat + this type as int. + + The first argument to Annotated must be a valid type (and will be in + the __origin__ field), the remaining arguments are kept as a tuple in + the __extra__ field. + + Details: + + - It's an error to call `Annotated` with less than two arguments. + - Nested Annotated are flattened:: + + Annotated[Annotated[int, Ann1, Ann2], Ann3] == Annotated[int, Ann1, Ann2, Ann3] + + - Instantiating an annotated type is equivalent to instantiating the + underlying type:: + + Annotated[C, Ann1](5) == C(5) + + - Annotated can be used as a generic type alias:: + + Optimized = Annotated[T, runtime.Optimize] + Optimized[int] == Annotated[int, runtime.Optimize] + + OptimizedList = Annotated[List[T], runtime.Optimize] + OptimizedList[int] == Annotated[List[int], runtime.Optimize] + """ + + __slots__ = () + + def __new__(cls, *args, **kwargs): + raise TypeError("Type Annotated cannot be instantiated.") + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation).") + msg = "Annotated[t, ...]: t must be a type." + origin = typing._type_check(params[0], msg) + metadata = tuple(params[1:]) + return _Annotated(origin, metadata) + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError("Cannot inherit from Annotated") + + def _strip_annotations(t): + """Strips the annotations from a given type. + """ + if isinstance(t, _Annotated): + return _strip_annotations(t.__origin__) + if isinstance(t, typing._GenericAlias): + stripped_args = tuple(_strip_annotations(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + res = t.copy_with(stripped_args) + res._special = t._special + return res + return t + + def get_type_hints(obj, globalns=None, localns=None, include_extras=False): + """Return type hints for an object. + + This is often the same as obj.__annotations__, but it handles + forward references encoded as string literals, adds Optional[t] if a + default value equal to None is set and recursively replaces all + 'Annotated[T, ...]' with 'T' (unless 'include_extras=True'). + + The argument may be a module, class, method, or function. The annotations + are returned as a dictionary. For classes, annotations include also + inherited members. + + TypeError is raised if the argument is not of a type that can contain + annotations, and an empty dictionary is returned if no annotations are + present. + + BEWARE -- the behavior of globalns and localns is counterintuitive + (unless you are familiar with how eval() and exec() work). The + search order is locals first, then globals. + + - If no dict arguments are passed, an attempt is made to use the + globals from obj (or the respective module's globals for classes), + and these are also used as the locals. If the object does not appear + to have globals, an empty dictionary is used. + + - If one dict argument is passed, it is used for both globals and + locals. + + - If two dict arguments are passed, they specify globals and + locals, respectively. + """ + hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) + if include_extras: + return hint + return {k: _strip_annotations(t) for k, t in hint.items()}