Skip to content
186 changes: 184 additions & 2 deletions typing_extensions/src_py3/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from unittest import TestCase, main, skipUnless
from typing import TypeVar, Optional
from typing import T, KT, VT # Not in __all__.
from typing import Tuple, List
from typing import Tuple, List, Dict
from typing import Generic
from typing import get_type_hints
from typing import no_type_check
Expand Down Expand Up @@ -65,6 +65,8 @@
# Protocols are hard to backport to the original version of typing 3.5.0
HAVE_PROTOCOLS = sys.version_info[:3] != (3, 5, 0)

# Not backported to older versions yet
HAVE_ANNOTATED = PEP_560

class BaseTestCase(TestCase):
def assertIsSubclass(self, cls, class_or_tuple, msg=None):
Expand Down Expand Up @@ -1458,6 +1460,182 @@ def test_total(self):
self.assertEqual(Options.__total__, False)


if HAVE_ANNOTATED:
from typing_extensions import Annotated, get_type_hints

class AnnotatedTests(BaseTestCase):

def test_repr(self):
self.assertEqual(
repr(Annotated[int, 4, 5]),
"typing_extensions.Annotated[int, 4, 5]"
)

def test_flatten(self):
A = Annotated[Annotated[int, 4], 5]
self.assertEqual(A, Annotated[int, 4, 5])
self.assertEqual(A.__metadata__, (4, 5))
self.assertEqual(A.__origin__, int)

def test_hash_eq(self):
self.assertEqual(len({Annotated[int, 4, 5], Annotated[int, 4, 5]}), 1)
self.assertNotEqual(Annotated[int, 4, 5], Annotated[int, 5, 4])
self.assertNotEqual(Annotated[int, 4, 5], Annotated[str, 4, 5])
self.assertNotEqual(Annotated[int, 4], Annotated[int, 4, 4])
self.assertEqual(
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
)

def test_instantiate(self):
Comment thread
till-varoquaux marked this conversation as resolved.
class C:
classvar = 4

def __init__(self, x):
self.x = x

def __eq__(self, other):
if not isinstance(other, C):
return NotImplemented
return other.x == self.x

A = Annotated[C, "a decoration"]
a = A(5)
c = C(5)
self.assertEqual(a, c)
self.assertEqual(a.x, c.x)
self.assertEqual(A.classvar, C.classvar)

MyCount = Annotated[typing_extensions.Counter[T], "my decoration"]
self.assertEqual(MyCount([4, 4, 5]), {4: 2, 5: 1})
self.assertEqual(MyCount[int]([4, 4, 5]), {4: 2, 5: 1})


def test_cannot_subclass(self):
with self.assertRaises(TypeError):
class C(Annotated):
pass

def test_pickle(self):
Comment thread
till-varoquaux marked this conversation as resolved.
samples = [typing.Any, typing.Union[int, str],
typing.Optional[str], Tuple[int, ...],
typing.Callable[[str], bytes]]

for t in samples:
x = Annotated[t, "a"]

for prot in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(protocol=prot, type=t):
pickled = pickle.dumps(x, prot)
restored = pickle.loads(pickled)
self.assertEqual(x, restored)

global _Annotated_test_G

class _Annotated_test_G(Generic[T]):
x = 1

G = Annotated[_Annotated_test_G[int], "A decoration"]
G.foo = 42
G.bar = 'abc'

for proto in range(pickle.HIGHEST_PROTOCOL + 1):
z = pickle.dumps(G, proto)
x = pickle.loads(z)
Comment thread
till-varoquaux marked this conversation as resolved.
self.assertEqual(x.foo, 42)
self.assertEqual(x.bar, 'abc')
self.assertEqual(x.x, 1)

def test_subst(self):
dec = "a decoration"

S = Annotated[T, dec]
self.assertEqual(S[int], Annotated[int, dec])

L = Annotated[List[T], dec]
self.assertEqual(L[int], Annotated[List[int], dec])
with self.assertRaises(TypeError):
L[int, int]

D = Annotated[Dict[KT, VT], dec]
self.assertEqual(D[str, int], Annotated[Dict[str, int], dec])
with self.assertRaises(TypeError):
D[int]

I = Annotated[int, dec]
with self.assertRaises(TypeError):
I[None]

LI = L[int]
with self.assertRaises(TypeError):
LI[None]

def test_annotated_in_other_types(self):
X = List[Annotated[T, 5]]
self.assertEqual(X[int], List[Annotated[int, 5]])


Comment thread
till-varoquaux marked this conversation as resolved.
class GetTypeHintsTests(BaseTestCase):
def test_get_type_hints(self):
def foobar(x: List['X']): ...
X = Annotated[int, (1, 10)]
self.assertEqual(
get_type_hints(foobar, globals(), locals()),
{'x': List[int]}
)
self.assertEqual(
get_type_hints(foobar, globals(), locals(), include_extras=True),
{'x': List[Annotated[int, (1, 10)]]}
)
BA = Tuple[Annotated[T, (1, 0)], ...]
def barfoo(x: BA): ...
self.assertEqual(get_type_hints(barfoo, globals(), locals())['x'], Tuple[T, ...])
self.assertIs(
get_type_hints(barfoo, globals(), locals(), include_extras=True)['x'],
BA
)
def barfoo2(x: typing.Callable[..., Annotated[List[T], "const"]],
y: typing.Union[int, Annotated[T, "mutable"]]): ...
self.assertEqual(
get_type_hints(barfoo2, globals(), locals()),
{'x': typing.Callable[..., List[T]], 'y': typing.Union[int, T]}
)
BA2 = typing.Callable[..., List[T]]
def barfoo3(x: BA2): ...
self.assertIs(
get_type_hints(barfoo3, globals(), locals(), include_extras=True)["x"],
BA2
)

def test_get_type_hints_refs(self):

Const = Annotated[T, "Const"]

class MySet(Generic[T]):

def __ior__(self, other: "Const[MySet[T]]") -> "MySet[T]":
...

def __iand__(self, other: Const["MySet[T]"]) -> "MySet[T]":
...

self.assertEqual(
get_type_hints(MySet.__iand__, globals(), locals()),
{'other': MySet[T], 'return': MySet[T]}
)

self.assertEqual(
get_type_hints(MySet.__iand__, globals(), locals(), include_extras=True),
{'other': Const[MySet[T]], 'return': MySet[T]}
)

self.assertEqual(
get_type_hints(MySet.__ior__, globals(), locals()),
{'other': MySet[T], 'return': MySet[T]}
)



class AllTests(BaseTestCase):

def test_typing_extensions_includes_standard(self):
Expand Down Expand Up @@ -1488,8 +1666,12 @@ def test_typing_extensions_includes_standard(self):
self.assertIn('Protocol', a)
self.assertIn('runtime', a)

if HAVE_ANNOTATED:
self.assertIn('Annotated', a)
self.assertIn('get_type_hints', a)

def test_typing_extensions_defers_when_possible(self):
exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final'}
exclude = {'overload', 'Text', 'TYPE_CHECKING', 'Final', 'get_type_hints'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand Down
155 changes: 155 additions & 0 deletions typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import typing
import collections.abc as collections_abc
import operator

# After PEP 560, internal typing API was substantially reworked.
# This is especially important for Protocol class which uses internal APIs
Expand Down Expand Up @@ -139,6 +140,14 @@ def _check_methods_in_mro(C, *methods):
if HAVE_PROTOCOLS:
__all__.extend(['Protocol', 'runtime'])

# Annotations were implemented under tight time constraints; this keeps the
# implementation simple for now
HAVE_ANNOTATED = PEP_560

if HAVE_ANNOTATED:
__all__.extend(['Annotated', 'get_type_hints'])


# TODO
if hasattr(typing, 'NoReturn'):
NoReturn = typing.NoReturn
Expand Down Expand Up @@ -1595,3 +1604,149 @@ class Point2D(TypedDict):
The class syntax is only supported in Python 3.6+, while two other
syntax forms work for Python 2.7 and 3.2+
"""


if HAVE_ANNOTATED:
class _Annotated(typing._GenericAlias, _root=True):
Comment thread
till-varoquaux marked this conversation as resolved.
"""Runtime representation of an annotated type.

At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't'
with extra annotations. The alias behaves like a normal typing alias,
instantiating is the same as instantiating the underlying type, binding
it to types is also the same.
"""
def __init__(self, origin, metadata):
if isinstance(origin, _Annotated):
metadata = origin.__metadata__ + metadata
origin = origin.__origin__
super().__init__(origin, origin)
self.__metadata__ = metadata

def copy_with(self, params):
assert len(params) == 1
new_type = params[0]
return _Annotated(new_type, self.__metadata__)

def __repr__(self):
return "typing_extensions.Annotated[{}, {}]".format(
typing._type_repr(self.__origin__),
", ".join(repr(a) for a in self.__metadata__)
)

def __reduce__(self):
return operator.getitem, (
Annotated, (self.__origin__,) + self.__metadata__
)

def __eq__(self, other):
if not isinstance(other, _Annotated):
return NotImplemented
if self.__origin__ != other.__origin__:
return False
return self.__metadata__ == other.__metadata__

def __hash__(self):
return hash((self.__origin__, self.__metadata__))


class Annotated:
Comment thread
till-varoquaux marked this conversation as resolved.
"""Add context specific metadata to a type.

Example: Annotated[int, runtime_check.Unsigned] indicates to the
hypothetical runtime_check module that this type is an unsigned int.
Every other consumer of this type can ignore this metadata and treat
this type as int.

The first argument to Annotated must be a valid type (and will be in
the __origin__ field), the remaining arguments are kept as a tuple in
the __extra__ field.

Details:

- It's an error to call `Annotated` with less than two arguments.
- Nested Annotated are flattened::

Annotated[Annotated[int, Ann1, Ann2], Ann3] == Annotated[int, Ann1, Ann2, Ann3]

- Instantiating an annotated type is equivalent to instantiating the
underlying type::

Annotated[C, Ann1](5) == C(5)
Comment thread
till-varoquaux marked this conversation as resolved.

- Annotated can be used as a generic type alias::

Optimized = Annotated[T, runtime.Optimize]
Optimized[int] == Annotated[int, runtime.Optimize]

OptimizedList = Annotated[List[T], runtime.Optimize]
OptimizedList[int] == Annotated[List[int], runtime.Optimize]
"""

__slots__ = ()

def __new__(cls, *args, **kwargs):
raise TypeError("Type Annotated cannot be instantiated.")

@typing._tp_cache
def __class_getitem__(cls, params):
if not isinstance(params, tuple) or len(params) < 2:
raise TypeError("Annotated[...] should be used "
"with at least two arguments (a type and an "
"annotation).")
msg = "Annotated[t, ...]: t must be a type."
origin = typing._type_check(params[0], msg)
metadata = tuple(params[1:])
return _Annotated(origin, metadata)

Comment thread
till-varoquaux marked this conversation as resolved.
def __init_subclass__(cls, *args, **kwargs):
raise TypeError("Cannot inherit from Annotated")

def _strip_annotations(t):
"""Strips the annotations from a given type.
"""
if isinstance(t, _Annotated):
return _strip_annotations(t.__origin__)
if isinstance(t, typing._GenericAlias):
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
if stripped_args == t.__args__:
return t
res = t.copy_with(stripped_args)
res._special = t._special
return res
return t

def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
"""Return type hints for an object.

This is often the same as obj.__annotations__, but it handles
forward references encoded as string literals, adds Optional[t] if a
default value equal to None is set and recursively replaces all
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').

The argument may be a module, class, method, or function. The annotations
are returned as a dictionary. For classes, annotations include also
inherited members.

TypeError is raised if the argument is not of a type that can contain
annotations, and an empty dictionary is returned if no annotations are
present.

BEWARE -- the behavior of globalns and localns is counterintuitive
(unless you are familiar with how eval() and exec() work). The
search order is locals first, then globals.

- If no dict arguments are passed, an attempt is made to use the
globals from obj (or the respective module's globals for classes),
and these are also used as the locals. If the object does not appear
to have globals, an empty dictionary is used.

- If one dict argument is passed, it is used for both globals and
locals.

- If two dict arguments are passed, they specify globals and
locals, respectively.
"""
hint = typing.get_type_hints(obj, globalns=globalns, localns=localns)
if include_extras:
return hint
return {k: _strip_annotations(t) for k, t in hint.items()}