From de21b02b541a1e07ee1f9e9d1fa3824eaaf134f5 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 20 Apr 2017 00:40:33 +0200 Subject: [PATCH 01/18] Initial ideas for runtime implementation of protocols --- src/test_typing.py | 16 ++++--- src/typing.py | 108 +++++++++++++++++++++++++++++++-------------- 2 files changed, 87 insertions(+), 37 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index b3cabda39..c9901c441 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -570,9 +570,8 @@ def test_reversible(self): self.assertIsSubclass(list, typing.Reversible) self.assertNotIsSubclass(int, typing.Reversible) - def test_protocol_instance_type_error(self): - with self.assertRaises(TypeError): - isinstance(0, typing.SupportsAbs) + def test_protocol_instance(self): + self.assertIsInstance(0, typing.SupportsAbs) class C1(typing.SupportsInt): def __int__(self) -> int: return 42 @@ -580,6 +579,13 @@ class C2(C1): pass c = C2() self.assertIsInstance(c, C1) + class C3: + def __int__(self) -> int: + return 42 + class C4(C3): + pass + c = C4() + self.assertIsInstance(c, typing.SupportsInt) class GenericTests(BaseTestCase): @@ -682,7 +688,7 @@ def test_new_repr_complex(self): def test_new_repr_bare(self): T = TypeVar('T') self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]') - self.assertEqual(repr(typing._Protocol[T]), 'typing.Protocol[~T]') + self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]') class C(typing.Dict[Any, Any]): ... # this line should just work repr(C.__mro__) @@ -978,7 +984,7 @@ def test_fail_with_bare_generic(self): with self.assertRaises(TypeError): Tuple[Generic[T]] with self.assertRaises(TypeError): - List[typing._Protocol] + List[typing.Protocol] with self.assertRaises(TypeError): isinstance(1, Generic) diff --git a/src/typing.py b/src/typing.py index 645bc6f8a..db77c50c7 100644 --- a/src/typing.py +++ b/src/typing.py @@ -26,6 +26,7 @@ 'ClassVar', 'Generic', 'Optional', + 'Protocol', 'Tuple', 'Type', 'TypeVar', @@ -88,6 +89,7 @@ 'no_type_check', 'no_type_check_decorator', 'overload', + 'runtime', 'Text', 'TYPE_CHECKING', ] @@ -373,7 +375,7 @@ def _type_check(arg, msg): if ( type(arg).__name__ in ('_Union', '_Optional') and not getattr(arg, '__origin__', None) or - isinstance(arg, TypingMeta) and _gorg(arg) in (Generic, _Protocol) + isinstance(arg, TypingMeta) and _gorg(arg) in (Generic, Protocol) ): raise TypeError("Plain %s is not valid as type argument" % arg) return arg @@ -924,6 +926,9 @@ def _no_slots_copy(dct): return dict_copy +Protocol = object() + + class GenericMeta(TypingMeta, abc.ABCMeta): """Metaclass for generic types. @@ -967,10 +972,11 @@ def __new__(cls, name, bases, namespace, if base is Generic: raise TypeError("Cannot inherit from plain Generic") if (isinstance(base, GenericMeta) and - base.__origin__ is Generic): + base.__origin__ in (Generic, Protocol)): if gvars is not None: raise TypeError( - "Cannot inherit from Generic[...] multiple types.") + "Cannot inherit from Generic[...] or" + " Protocol[...] multiple types.") gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -980,8 +986,10 @@ def __new__(cls, name, bases, namespace, if not tvarset <= gvarset: raise TypeError( "Some type variables (%s) " - "are not listed in Generic[%s]" % + "are not listed in %s[%s]" % (", ".join(str(t) for t in tvars if t not in gvarset), + "Generic" if any(b.__origin__ is Generic + for b in bases) else "Protocol", ", ".join(str(g) for g in gvars))) tvars = gvars @@ -1123,25 +1131,21 @@ def __getitem__(self, params): "Parameter list to %s[...] cannot be empty" % _qualname(self)) msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) - if self is Generic: + if self in (Generic, Protocol): # Generic can only be subscripted with unique type variables. if not all(isinstance(p, TypeVar) for p in params): raise TypeError( - "Parameters to Generic[...] must all be type variables") + "Parameters to %r[...] must all be type variables" % self) if len(set(params)) != len(params): raise TypeError( - "Parameters to Generic[...] must all be unique") + "Parameters to %r[...] must all be unique" % self) tvars = params args = params elif self in (Tuple, Callable): tvars = _type_vars(params) args = params - elif self is _Protocol: - # _Protocol is internal, don't check anything. - tvars = params - args = params - elif self.__origin__ in (Generic, _Protocol): - # Can't subscript Generic[...] or _Protocol[...]. + elif self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. raise TypeError("Cannot subscript already-subscripted %s" % repr(self)) else: @@ -1634,25 +1638,39 @@ def utf8(value): class _ProtocolMeta(GenericMeta): - """Internal metaclass for _Protocol. + """Internal metaclass for Protocol. - This exists so _Protocol classes can be generic without deriving + This exists so Protocol classes can be generic without deriving from Generic. """ + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol or + getattr(b, '__origin__', None) is Protocol + for b in cls.__bases__) + if cls._is_protocol: + for base in cls.__mro__[1:]: + if not (base in (type, object) or base._is_protocol): + raise TypeError('Protocols can only inherit from other protocols,' + ' got %r' % base) + def __instancecheck__(self, obj): - if _Protocol not in self.__bases__: - return super().__instancecheck__(obj) - raise TypeError("Protocols cannot be used with isinstance().") + return issubclass(obj.__class__, self) def __subclasscheck__(self, cls): if not self._is_protocol: # No structural checks since this isn't a protocol. return NotImplemented - - if self is _Protocol: - # Every class is a subclass of the empty protocol. - return True + if not getattr(self, '_is_runtime_protocol', None): + raise TypeError('Instance and class checks can only be used with' + ' @runtime protocols') + if self.__origin__ is not None: + if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: + raise TypeError("Parameterized generics cannot be used with class " + "or instance checks") + return False # Find all attributes defined in the protocol. attrs = self._get_protocol_attrs() @@ -1666,7 +1684,7 @@ def _get_protocol_attrs(self): # Get all Protocol base classes. protocol_bases = [] for c in self.__mro__: - if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol': + if getattr(c, '_is_protocol', False) and c.__name__ != 'Protocol': protocol_bases.append(c) # Get attributes included in protocol. @@ -1684,6 +1702,7 @@ def _get_protocol_attrs(self): attr != '__annotations__' and attr != '__weakref__' and attr != '_is_protocol' and + attr != '_is_runtime_protocol' and attr != '__dict__' and attr != '__args__' and attr != '__slots__' and @@ -1700,8 +1719,8 @@ def _get_protocol_attrs(self): return attrs -class _Protocol(metaclass=_ProtocolMeta): - """Internal base class for protocol classes. +class Protocol(metaclass=_ProtocolMeta): + """Base class for protocol classes. This implements a simple-minded structural issubclass check (similar but more general than the one-offs in collections.abc @@ -1712,6 +1731,24 @@ class _Protocol(metaclass=_ProtocolMeta): _is_protocol = True + def __new__(cls, *args, **kwds): + if _geqv(cls, Protocol): + raise TypeError("Type Protocol cannot be instantiated; " + "it can be used only as a base class") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + +def runtime(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + """ + if not getattr(cls, '_is_protocol', None): + raise TypeError('@runtime can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + return cls + # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. @@ -1755,7 +1792,8 @@ class Iterator(Iterable[T_co], extra=collections_abc.Iterator): __slots__ = () -class SupportsInt(_Protocol): +@runtime +class SupportsInt(Protocol): __slots__ = () @abstractmethod @@ -1763,7 +1801,8 @@ def __int__(self) -> int: pass -class SupportsFloat(_Protocol): +@runtime +class SupportsFloat(Protocol): __slots__ = () @abstractmethod @@ -1771,7 +1810,8 @@ def __float__(self) -> float: pass -class SupportsComplex(_Protocol): +@runtime +class SupportsComplex(Protocol): __slots__ = () @abstractmethod @@ -1779,7 +1819,8 @@ def __complex__(self) -> complex: pass -class SupportsBytes(_Protocol): +@runtime +class SupportsBytes(Protocol): __slots__ = () @abstractmethod @@ -1787,7 +1828,8 @@ def __bytes__(self) -> bytes: pass -class SupportsAbs(_Protocol[T_co]): +@runtime +class SupportsAbs(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1795,7 +1837,8 @@ def __abs__(self) -> T_co: pass -class SupportsRound(_Protocol[T_co]): +@runtime +class SupportsRound(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1807,7 +1850,8 @@ def __round__(self, ndigits: int = 0) -> T_co: class Reversible(Iterable[T_co], extra=collections_abc.Reversible): __slots__ = () else: - class Reversible(_Protocol[T_co]): + @runtime + class Reversible(Protocol[T_co]): __slots__ = () @abstractmethod From 6a31a93b96007d1a60b34df469d93d3d4f62fddf Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Apr 2017 01:01:52 +0200 Subject: [PATCH 02/18] Uniform class and instance checks (same as collections.abc) --- src/typing.py | 48 ++++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/typing.py b/src/typing.py index db77c50c7..705b0acb3 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1648,37 +1648,33 @@ def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) if not cls.__dict__.get('_is_protocol', None): cls._is_protocol = any(b is Protocol or - getattr(b, '__origin__', None) is Protocol + isinstance(b, _ProtocolMeta) and b.__origin__ is Protocol for b in cls.__bases__) if cls._is_protocol: for base in cls.__mro__[1:]: - if not (base in (type, object) or base._is_protocol): + if not (base in (type, object) or base._is_protocol or + isinstance(base, GenericMeta) and base.__origin__ is Generic): raise TypeError('Protocols can only inherit from other protocols,' ' got %r' % base) - - def __instancecheck__(self, obj): - return issubclass(obj.__class__, self) - - def __subclasscheck__(self, cls): - if not self._is_protocol: - # No structural checks since this isn't a protocol. - return NotImplemented - if not getattr(self, '_is_runtime_protocol', None): - raise TypeError('Instance and class checks can only be used with' - ' @runtime protocols') - if self.__origin__ is not None: - if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: - raise TypeError("Parameterized generics cannot be used with class " - "or instance checks") - return False - - # Find all attributes defined in the protocol. - attrs = self._get_protocol_attrs() - - for attr in attrs: - if not any(attr in d.__dict__ for d in cls.__mro__): - return False - return True + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + cls.__init__ = _no_init + + def __protohook__(other): + if not getattr(cls, '_is_runtime_protocol', None): + raise TypeError('Instance and class checks can only be used with' + ' @runtime protocols') + for attr in cls._get_protocol_attrs(): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + else: + return NotImplemented + return True + cls.__subclasshook__ = __protohook__ def _get_protocol_attrs(self): # Get all Protocol base classes. From f1072994dad1cd1183aa43dca9e0600a46f011b1 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Apr 2017 01:07:51 +0200 Subject: [PATCH 03/18] Fix lint --- src/typing.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/typing.py b/src/typing.py index 705b0acb3..780d56e36 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1648,7 +1648,8 @@ def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) if not cls.__dict__.get('_is_protocol', None): cls._is_protocol = any(b is Protocol or - isinstance(b, _ProtocolMeta) and b.__origin__ is Protocol + isinstance(b, _ProtocolMeta) and + b.__origin__ is Protocol for b in cls.__bases__) if cls._is_protocol: for base in cls.__mro__[1:]: @@ -1656,12 +1657,13 @@ def __init__(cls, *args, **kwargs): isinstance(base, GenericMeta) and base.__origin__ is Generic): raise TypeError('Protocols can only inherit from other protocols,' ' got %r' % base) + def _no_init(self, *args, **kwargs): if type(self)._is_protocol: raise TypeError('Protocols cannot be instantiated') cls.__init__ = _no_init - def __protohook__(other): + def _proto_hook(other): if not getattr(cls, '_is_runtime_protocol', None): raise TypeError('Instance and class checks can only be used with' ' @runtime protocols') @@ -1674,7 +1676,7 @@ def __protohook__(other): else: return NotImplemented return True - cls.__subclasshook__ = __protohook__ + cls.__subclasshook__ = _proto_hook def _get_protocol_attrs(self): # Get all Protocol base classes. From 8e07b885f0159e513e01afdee8b3ea06bdc0fa6e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Apr 2017 01:31:31 +0200 Subject: [PATCH 04/18] Simplify code using absence of non-protocol bases + some support for PY 3.6 --- src/typing.py | 41 +++++++++++++---------------------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/src/typing.py b/src/typing.py index 780d56e36..17ae1bea9 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1653,7 +1653,7 @@ def __init__(cls, *args, **kwargs): for b in cls.__bases__) if cls._is_protocol: for base in cls.__mro__[1:]: - if not (base in (type, object) or base._is_protocol or + if not (base is object or base._is_protocol or isinstance(base, GenericMeta) and base.__origin__ is Generic): raise TypeError('Protocols can only inherit from other protocols,' ' got %r' % base) @@ -1679,41 +1679,26 @@ def _proto_hook(other): cls.__subclasshook__ = _proto_hook def _get_protocol_attrs(self): - # Get all Protocol base classes. - protocol_bases = [] - for c in self.__mro__: - if getattr(c, '_is_protocol', False) and c.__name__ != 'Protocol': - protocol_bases.append(c) - - # Get attributes included in protocol. attrs = set() - for base in protocol_bases: - for attr in base.__dict__.keys(): + for base in self.__mro__[:-1]: # without object + if base.__name__ == 'Protocol': + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): # Include attributes not defined in any non-protocol bases. for c in self.__mro__: if (c is not base and attr in c.__dict__ and not getattr(c, '_is_protocol', False)): break else: - if (not attr.startswith('_abc_') and - attr != '__abstractmethods__' and - attr != '__annotations__' and - attr != '__weakref__' and - attr != '_is_protocol' and - attr != '_is_runtime_protocol' and - attr != '__dict__' and - attr != '__args__' and - attr != '__slots__' and - attr != '_get_protocol_attrs' and - attr != '__next_in_mro__' and - attr != '__parameters__' and - attr != '__origin__' and - attr != '__orig_bases__' and - attr != '__extra__' and - attr != '__tree_hash__' and - attr != '__module__'): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', '_get_protocol_attrs', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__module__')): attrs.add(attr) - return attrs From e3db3de4f8f95ccf9bc3e3450e443657592cbce3 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Apr 2017 17:55:05 +0200 Subject: [PATCH 05/18] Better support for isinstance(); add some tests --- src/test_typing.py | 279 ++++++++++++++++++++++++++++++++++++++++++++- src/typing.py | 44 ++++--- 2 files changed, 306 insertions(+), 17 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index c9901c441..2bd2577da 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -13,6 +13,7 @@ from typing import Tuple, List, MutableMapping from typing import Callable from typing import Generic, ClassVar, GenericMeta +from typing import Protocol, runtime from typing import cast from typing import get_type_hints from typing import no_type_check, no_type_check_decorator @@ -525,8 +526,284 @@ def get(self, key: str, default=None): return default +PY36 = sys.version_info[:2] >= (3, 6) + +PY36_PROTOCOL_TESTS = """ +class Coordinate(Protocol): + x: int + y: int + +@runtime +class Point(Coordinate, Protocol): + label: str + +class MyPoint: + x: int + y: int + label: str + +class BadPoint: + z: str + +class XAxis(Protocol): + x: int + +class YAxis(Protocol): + y: int + +@runtime +class Position(XAxis, YAxis, Position): + pass + +@runtime +class Proto(Protocol): + attr: int + def meth(self, arg: str) -> int: + ... + +class Concrete(Proto): + pass + +class Other: + attr: int + def meth(self, arg: str) -> int: + if arg == 'this': + return 1 + return 0 +""" + +if PY36: + exec(PY36_PROTOCOL_TESTS) +else: + # fake names for the sake of static analysis + Coordinate = Point = MyPoint = BadPoint = object + XAxis = YAxis = Position = Proto = Concrete = Other = object + + class ProtocolTests(BaseTestCase): + def test_basic_protocol(self): + @runtime + class P(Protocol): + def meth(self): + pass + class C: pass + class D: + def meth(self): + pass + self.assertIsSubclass(D, P) + self.assertIsInstance(D(), P) + self.assertNotIsSubclass(C, P) + self.assertNotIsInstance(C(), P) + + def test_everything_implements_empty_protocol(self): + @runtime + class Empty(Protocol): pass + class C: pass + for thing in (object, type, tuple, C): + self.assertIsSubclass(thing, Empty) + for thing in (object(), 1, (), typing): + self.assertIsInstance(thing, Empty) + + def test_no_inheritance_from_nominal(self): + class C: pass + class BP(Protocol): pass + with self.assertRaises(TypeError): + class P(C, Protocol): + pass + with self.assertRaises(TypeError): + class P(Protocol, C): + pass + with self.assertRaises(TypeError): + class P(BP, C, Protocol): + pass + class D(BP, C): pass + class E(C, BP): pass + self.assertNotIsInstance(D(), E) + self.assertNotIsInstance(E(), D) + + def test_no_instantiation(self): + class P(Protocol): pass + with self.assertRaises(TypeError): + P() + class C(P): pass + self.assertIsInstance(C(), C) + T = TypeVar('T') + class PG(Protocol[T]): pass + with self.assertRaises(TypeError): + PG() + with self.assertRaises(TypeError): + PG[int]() + with self.assertRaises(TypeError): + PG[T]() + class CG(PG[T]): pass + self.assertIsInstance(CG[int](), CG) + + def test_subprotocols_extending(self): + class P1(Protocol): + def meth1(self): + pass + @runtime + class P2(P1, Protocol): + def meth2(self): + pass + class C: + def meth1(self): + pass + def meth2(self): + pass + class C1: + def meth1(self): + pass + class C2: + def meth2(self): + pass + self.assertNotIsInstance(C1(), P2) + self.assertNotIsInstance(C2(), P2) + self.assertNotIsSubclass(C1, P2) + self.assertNotIsSubclass(C2, P2) + self.assertIsInstance(C(), P2) + self.assertIsSubclass(C, P2) + + def test_subprotocols_merging(self): + class P1(Protocol): + def meth1(self): + pass + class P2(Protocol): + def meth2(self): + pass + @runtime + class P(P1, P2, Protocol): + pass + class C: + def meth1(self): + pass + def meth2(self): + pass + class C1: + def meth1(self): + pass + class C2: + def meth2(self): + pass + self.assertNotIsInstance(C1(), P) + self.assertNotIsInstance(C2(), P) + self.assertNotIsSubclass(C1, P) + self.assertNotIsSubclass(C2, P) + self.assertIsInstance(C(), P) + self.assertIsSubclass(C, P) + + def test_protocols_issubclass(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class BadP(Protocol): + x = 1 + class BadPG(Protocol[T]): + x = 1 + class C: + x = 1 + self.assertIsSubclass(C, P) + self.assertIsSubclass(C, PG) + self.assertIsSubclass(BadP, PG) + self.assertIsSubclass(PG[int], PG) + self.assertIsSubclass(BadPG[int], P) + self.assertIsSubclass(BadPG[T], PG) + with self.assertRaises(TypeError): + issubclass(C, PG[T]) + with self.assertRaises(TypeError): + issubclass(C, PG[C]) + with self.assertRaises(TypeError): + issubclass(C, BadP) + with self.assertRaises(TypeError): + issubclass(C, BadPG) + with self.assertRaises(TypeError): + issubclass(P, PG[T]) + with self.assertRaises(TypeError): + issubclass(PG, PG[int]) + + @skipUnless(PY36, 'Python 3.6 required') + def test_protocols_issubclass_py36(self): + pass + + def test_protocols_isinstance(self): + T = TypeVar('T') + @runtime + class P(Protocol): + def meth(x): ... + @runtime + class PG(Protocol[T]): + def meth(x): ... + class BadP(Protocol): + def meth(x): ... + class BadPG(Protocol[T]): + def meth(x): ... + class C: + def meth(x): ... + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), PG) + with self.assertRaises(TypeError): + isinstance(C(), PG[T]) + with self.assertRaises(TypeError): + isinstance(C(), PG[C]) + with self.assertRaises(TypeError): + isinstance(C(), BadP) + with self.assertRaises(TypeError): + isinstance(C(), BadPG) + + @skipUnless(PY36, 'Python 3.6 required') + def test_protocols_isinstance_py36(self): + pass + + def test_protocols_isinstance_init(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class C: + def __init__(self, x): + self.x = x + self.assertIsInstance(C(1), P) + self.assertIsInstance(C(1), PG) + + def test_protocols_support_register(self): + pass + + def test_none_blocks_implementation(self): + pass + + def test_custom_subclasshook(self): + pass + + def test_non_protocol_subclasses(self): + # check both runtime and non-runtime + pass + + def test_defining_generic_protocols(self): + pass + + def test_protocols_bad_subscripts(self): + pass + + def test_generic_protocols_repr(self): + pass + + def test_generic_protocols_special_from_generic(self): + pass + + def test_generic_protocols_special_from_protocol(self): + pass + + def test_runtime_deco(self): + pass + def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) self.assertNotIsSubclass(str, typing.SupportsInt) @@ -1570,8 +1847,6 @@ def __anext__(self) -> T_a: asyncio = None AwaitableWrapper = AsyncIteratorWrapper = object -PY36 = sys.version_info[:2] >= (3, 6) - PY36_TESTS = """ from test import ann_module, ann_module2, ann_module3 diff --git a/src/typing.py b/src/typing.py index 17ae1bea9..222331118 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1653,7 +1653,8 @@ def __init__(cls, *args, **kwargs): for b in cls.__bases__) if cls._is_protocol: for base in cls.__mro__[1:]: - if not (base is object or base._is_protocol or + if not (base is object or + isinstance(base, _ProtocolMeta) and base._is_protocol or isinstance(base, GenericMeta) and base.__origin__ is Generic): raise TypeError('Protocols can only inherit from other protocols,' ' got %r' % base) @@ -1663,21 +1664,34 @@ def _no_init(self, *args, **kwargs): raise TypeError('Protocols cannot be instantiated') cls.__init__ = _no_init - def _proto_hook(other): - if not getattr(cls, '_is_runtime_protocol', None): - raise TypeError('Instance and class checks can only be used with' - ' @runtime protocols') - for attr in cls._get_protocol_attrs(): - for base in other.__mro__: - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - else: - return NotImplemented - return True + def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + if not cls.__dict__.get('_is_runtime_protocol', None): + print(cls) + raise TypeError('Instance and class checks can only be used with' + ' @runtime protocols') + for attr in cls._get_protocol_attrs(): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: cls.__subclasshook__ = _proto_hook + def __instancecheck__(self, instance): + # We need this function for situations where attributes are assigned in __init__ + if issubclass(instance.__class__, self): + return True + if self._is_protocol: + return all(hasattr(instance, attr) and getattr(instance, attr) is not None + for attr in self._get_protocol_attrs()) + return False + def _get_protocol_attrs(self): attrs = set() for base in self.__mro__[:-1]: # without object @@ -1726,7 +1740,7 @@ def runtime(cls): can be used with isinstance() and issubclass(). Raise TypeError if applied to a non-protocol class. """ - if not getattr(cls, '_is_protocol', None): + if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: raise TypeError('@runtime can be only applied to protocol classes,' ' got %r' % cls) cls._is_runtime_protocol = True From e580511940d83a258c02d5c3011230b8fcb3e0bd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Apr 2017 18:02:42 +0200 Subject: [PATCH 06/18] Minor fixes --- src/test_typing.py | 2 +- src/typing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 2bd2577da..798abe56b 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -552,7 +552,7 @@ class YAxis(Protocol): y: int @runtime -class Position(XAxis, YAxis, Position): +class Position(XAxis, YAxis, Protocol): pass @runtime diff --git a/src/typing.py b/src/typing.py index 222331118..01cd2f775 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1710,7 +1710,7 @@ def _get_protocol_attrs(self): '_is_protocol', '_is_runtime_protocol', '__dict__', '__args__', '__slots__', '_get_protocol_attrs', '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', + '__orig_bases__', '__extra__', '__tree_hash__', '__module__')): attrs.add(attr) return attrs From d99d9ef8537d8564b4585c8dff12cdaaac478b1c Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Apr 2017 21:32:23 +0200 Subject: [PATCH 07/18] Add more tests; improve docstrings --- src/test_typing.py | 184 +++++++++++++++++++++++++++++++++++++++++---- src/typing.py | 28 +++++-- 2 files changed, 193 insertions(+), 19 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 798abe56b..872810817 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -774,35 +774,191 @@ def __init__(self, x): self.assertIsInstance(C(1), PG) def test_protocols_support_register(self): - pass + @runtime + class P(Protocol): + x = 1 + class PM(Protocol): + def meth(self): pass + class D(PM): pass + class C: pass + D.register(C) + P.register(C) + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), D) def test_none_blocks_implementation(self): - pass - - def test_custom_subclasshook(self): - pass + @runtime + class P(Protocol): + x = 1 + class A: + x = 1 + class B(A): + x = None + class C: + def __init__(self): + self.x = None + self.assertNotIsInstance(B(), P) + self.assertNotIsInstance(C(), P) def test_non_protocol_subclasses(self): - # check both runtime and non-runtime - pass + class P(Protocol): + x = 1 + @runtime + class PR(Protocol): + def meth(self): pass + class NonP(P): + x = 1 + class NonPR(PR): pass + class C: + x = 1 + class D: + def meth(self): pass + self.assertNotIsInstance(C(), NonP) + self.assertNotIsInstance(D(), NonPR) + self.assertNotIsSubclass(C, NonP) + self.assertNotIsSubclass(D, NonPR) + self.assertIsInstance(NonPR(), PR) + self.assertIsSubclass(NonPR, PR) + + def test_custom_subclasshook(self): + class P(Protocol): + x = 1 + class OKClass: pass + class BadClass: + x = 1 + class C(P): + @classmethod + def __subclasshook__(cls, other): + return other.__name__.startswith("OK") + self.assertIsInstance(OKClass(), C) + self.assertNotIsInstance(BadClass(), C) + self.assertIsSubclass(OKClass, C) + self.assertNotIsSubclass(BadClass, C) def test_defining_generic_protocols(self): - pass + T = TypeVar('T') + S = TypeVar('S') + @runtime + class PR(Protocol[T, S]): + def meth(self): pass + class P(PR[int, T], Protocol[T]): + y = 1 + self.assertIsSubclass(PR[int, T], PR) + self.assertIsSubclass(P[str], PR) + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + P[int, str] + with self.assertRaises(TypeError): + PR[int, 1] + with self.assertRaises(TypeError): + PR[int, ClassVar] + class C(PR[int, T]): pass + self.assertIsInstance(C[str](), C) + + def test_init_called(self): + T = TypeVar('T') + class P(Protocol[T]): pass + class C(P[T]): + def __init__(self): + self.test = 'OK' + self.assertEqual(C[int]().test, 'OK') def test_protocols_bad_subscripts(self): - pass + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + class P(Protocol[T, T]): pass + with self.assertRaises(TypeError): + class P(Protocol[int]): pass + with self.assertRaises(TypeError): + class P(Protocol[T], Protocol[S]): pass + with self.assertRaises(TypeError): + class P(typing.Mapping[T, S], Protocol[T]): pass def test_generic_protocols_repr(self): - pass + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertTrue(repr(P).endswith('P')) + self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) + self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + + def test_generic_protocols_eq(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertEqual(P, P) + self.assertEqual(P[int, T], P[int, T]) + self.assertEqual(P[T, T][Tuple[T, S]][int, str], + P[Tuple[int, str], Tuple[int, str]]) def test_generic_protocols_special_from_generic(self): - pass + T = TypeVar('T') + class P(Protocol[T]): pass + self.assertEqual(P.__parameters__, (T,)) + self.assertIs(P.__args__, None) + self.assertIs(P.__origin__, None) + self.assertEqual(P[int].__parameters__, ()) + self.assertEqual(P[int].__args__, (int,)) + self.assertIs(P[int].__origin__, P) def test_generic_protocols_special_from_protocol(self): - pass + @runtime + class PR(Protocol): + x = 1 + class P(Protocol): + def meth(self): + pass + T = TypeVar('T') + class PG(Protocol[T]): + x = 1 + def meth(self): + pass + self.assertTrue(P._is_protocol) + self.assertTrue(PR._is_protocol) + self.assertTrue(PG._is_protocol) + with self.assertRaises(AttributeError): + self.assertFalse(P._is_runtime_protocol) + self.assertTrue(PR._is_runtime_protocol) + self.assertTrue(PG[int]._is_protocol) + self.assertEqual(P._get_protocol_attrs(), {'meth'}) + self.assertEqual(PR._get_protocol_attrs(), {'x'}) + self.assertEqual(frozenset(PG._get_protocol_attrs()), + frozenset({'x', 'meth'})) + self.assertEqual(frozenset(PG[int]._get_protocol_attrs()), + frozenset({'x', 'meth'})) + + def test_no_runtime_deco_on_nominal(self): + with self.assertRaises(TypeError): + @runtime + class C: pass + + def test_protocols_pickleable(self): + global P, CP # pickle wants to reference the class by name + T = TypeVar('T') - def test_runtime_deco(self): - pass + @runtime + class P(Protocol[T]): + x = 1 + class CP(P[int]): + pass + + c = CP() + c.foo = 42 + c.bar = 'abc' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(c, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) + s = pickle.dumps(P) + D = pickle.loads(s) + class E: + x = 1 + self.assertIsInstance(E(), D) def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) diff --git a/src/typing.py b/src/typing.py index 01cd2f775..a8c4e0a99 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1668,7 +1668,6 @@ def _proto_hook(other): if not cls.__dict__.get('_is_protocol', None): return NotImplemented if not cls.__dict__.get('_is_runtime_protocol', None): - print(cls) raise TypeError('Instance and class checks can only be used with' ' @runtime protocols') for attr in cls._get_protocol_attrs(): @@ -1717,11 +1716,27 @@ def _get_protocol_attrs(self): class Protocol(metaclass=_ProtocolMeta): - """Base class for protocol classes. + """Base class for protocol classes. Protocol classes are defined as:: - This implements a simple-minded structural issubclass check - (similar but more general than the one-offs in collections.abc - such as Hashable). + class Proto(Protocol[T]): + def meth(self) -> T: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto[int]) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with @typing.runtime + act as simple-minded runtime protocols that checks only the presence of + given attributes, ignoring their type signatures. """ __slots__ = () @@ -1739,6 +1754,9 @@ def runtime(cls): """Mark a protocol class as a runtime protocol, so that it can be used with isinstance() and issubclass(). Raise TypeError if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. """ if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: raise TypeError('@runtime can be only applied to protocol classes,' From 831e1a7f65bf7a08eb6a0b6e299545bbc50fcb27 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 00:35:36 +0200 Subject: [PATCH 08/18] Add support for extending collections protocols (+ tests) --- src/test_typing.py | 69 +++++++++++++++ src/typing.py | 211 ++++++++++++++++++++++++--------------------- 2 files changed, 180 insertions(+), 100 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 872810817..3bd2cbb97 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -1003,6 +1003,75 @@ def test_reversible(self): self.assertIsSubclass(list, typing.Reversible) self.assertNotIsSubclass(int, typing.Reversible) + def test_collection_protocols(self): + T = TypeVar('T') + class C(typing.Callable[[T], T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__call__', 'x'})) + if hasattr(typing, 'Awaitable'): + class C(typing.Awaitable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__await__', 'x'})) + class C(typing.Iterable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', 'x'})) + class C(typing.Iterator[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', '__next__', 'x'})) + if hasattr(typing, 'AsyncIterable'): + class C(typing.AsyncIterable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__aiter__', 'x'})) + if hasattr(typing, 'AsyncIterator'): + class C(typing.AsyncIterator[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__aiter__', '__anext__', 'x'})) + class C(typing.Hashable, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__hash__', 'x'})) + class C(typing.Sized, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', 'x'})) + class C(typing.Container[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__contains__', 'x'})) + if hasattr(collections_abc, 'Reversible'): + class C(typing.Reversible[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', 'x'})) + if hasattr(typing, 'Collection'): + class C(typing.Collection[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', '__iter__', '__contains__', 'x'})) + class C(typing.Sequence[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', 'count', 'index', 'x'})) + class C(typing.MutableSequence[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', '__setitem__', '__delitem__', + '__iadd__', 'count', 'index', 'extend', 'clear', + 'insert', 'append', 'remove', 'pop', 'reverse', 'x'})) + class C(typing.Mapping[T, int], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', 'items', 'keys', 'values', 'get', 'x'})) + class C(typing.MutableMapping[int, T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', '__setitem__', '__delitem__', 'items', + 'keys', 'values', 'get', 'clear', 'pop', 'popitem', + 'update', 'setdefault', 'x'})) + if hasattr(typing, 'ContextManager'): + class C(typing.ContextManager[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__enter__', '__exit__', 'x'})) + if hasattr(typing, 'AsyncContextManager'): + class C(typing.AsyncContextManager[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__aenter__', '__aexit__', 'x'})) + def test_protocol_instance(self): self.assertIsInstance(0, typing.SupportsAbs) class C1(typing.SupportsInt): diff --git a/src/typing.py b/src/typing.py index a8c4e0a99..6b0e2f82d 100644 --- a/src/typing.py +++ b/src/typing.py @@ -926,9 +926,6 @@ def _no_slots_copy(dct): return dict_copy -Protocol = object() - - class GenericMeta(TypingMeta, abc.ABCMeta): """Metaclass for generic types. @@ -1200,8 +1197,10 @@ def __setattr__(self, attr, value): super(GenericMeta, _gorg(self)).__setattr__(attr, value) -# Prevent checks for Generic to crash when defining Generic. +# Prevent checks for Generic, etc. to crash when defining Generic. Generic = None +Protocol = object() +Callable = object() def _generic_new(base_cls, cls, *args, **kwds): @@ -1314,83 +1313,6 @@ def __new__(cls, *args, **kwds): return _generic_new(tuple, cls, *args, **kwds) -class CallableMeta(GenericMeta): - """Metaclass for Callable (internal).""" - - def __repr__(self): - if self.__origin__ is None: - return super().__repr__() - return self._tree_repr(self._subs_tree()) - - def _tree_repr(self, tree): - if _gorg(self) is not Callable: - return super()._tree_repr(tree) - # For actual Callable (not its subclass) we override - # super()._tree_repr() for nice formatting. - arg_list = [] - for arg in tree[1:]: - if not isinstance(arg, tuple): - arg_list.append(_type_repr(arg)) - else: - arg_list.append(arg[0]._tree_repr(arg)) - if arg_list[0] == '...': - return repr(tree[0]) + '[..., %s]' % arg_list[1] - return (repr(tree[0]) + - '[[%s], %s]' % (', '.join(arg_list[:-1]), arg_list[-1])) - - def __getitem__(self, parameters): - """A thin wrapper around __getitem_inner__ to provide the latter - with hashable arguments to improve speed. - """ - - if self.__origin__ is not None or not _geqv(self, Callable): - return super().__getitem__(parameters) - if not isinstance(parameters, tuple) or len(parameters) != 2: - raise TypeError("Callable must be used as " - "Callable[[arg, ...], result].") - args, result = parameters - if args is Ellipsis: - parameters = (Ellipsis, result) - else: - if not isinstance(args, list): - raise TypeError("Callable[args, result]: args must be a list." - " Got %.100r." % (args,)) - parameters = (tuple(args), result) - return self.__getitem_inner__(parameters) - - @_tp_cache - def __getitem_inner__(self, parameters): - args, result = parameters - msg = "Callable[args, result]: result must be a type." - result = _type_check(result, msg) - if args is Ellipsis: - return super().__getitem__((_TypingEllipsis, result)) - msg = "Callable[[arg, ...], result]: each arg must be a type." - args = tuple(_type_check(arg, msg) for arg in args) - parameters = args + (result,) - return super().__getitem__(parameters) - - -class Callable(extra=collections_abc.Callable, metaclass=CallableMeta): - """Callable type; Callable[[int], str] is a function of (int) -> str. - - The subscription syntax must always be used with exactly two - values: the argument list and the return type. The argument list - must be a list of types or ellipsis; the return type must be a single type. - - There is no syntax to indicate optional or keyword arguments, - such function types are rarely used as callback types. - """ - - __slots__ = () - - def __new__(cls, *args, **kwds): - if _geqv(cls, Callable): - raise TypeError("Type Callable cannot be instantiated; " - "use a non-abstract subclass instead") - return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) - - class _ClassVar(_FinalTypingBase, _root=True): """Special type construct to mark class variables. @@ -1637,6 +1559,18 @@ def utf8(value): return _overload_dummy +def _collection_protocol(cls): + # Selected set of collections ABCs that are considered protocols. + qname = cls.__qualname__ + return (qname in ('Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', + 'AbstractContextManager', 'ContextManager', + 'AbstractAsyncContextManager', 'AsyncContextManager',) and + cls.__module__ in ('collections.abc', 'typing', 'contextlib')) + + class _ProtocolMeta(GenericMeta): """Internal metaclass for Protocol. @@ -1653,9 +1587,10 @@ def __init__(cls, *args, **kwargs): for b in cls.__bases__) if cls._is_protocol: for base in cls.__mro__[1:]: - if not (base is object or - isinstance(base, _ProtocolMeta) and base._is_protocol or - isinstance(base, GenericMeta) and base.__origin__ is Generic): + if not (base in (object, Generic, Callable) or + isinstance(base, TypingMeta) and base._is_protocol or + isinstance(base, GenericMeta) and base.__origin__ is Generic or + _collection_protocol(base)): raise TypeError('Protocols can only inherit from other protocols,' ' got %r' % base) @@ -1668,6 +1603,8 @@ def _proto_hook(other): if not cls.__dict__.get('_is_protocol', None): return NotImplemented if not cls.__dict__.get('_is_runtime_protocol', None): + if sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']: + return NotImplemented raise TypeError('Instance and class checks can only be used with' ' @runtime protocols') for attr in cls._get_protocol_attrs(): @@ -1683,7 +1620,7 @@ def _proto_hook(other): cls.__subclasshook__ = _proto_hook def __instancecheck__(self, instance): - # We need this function for situations where attributes are assigned in __init__ + # We need this method for situations where attributes are assigned in __init__ if issubclass(instance.__class__, self): return True if self._is_protocol: @@ -1694,24 +1631,20 @@ def __instancecheck__(self, instance): def _get_protocol_attrs(self): attrs = set() for base in self.__mro__[:-1]: # without object - if base.__name__ == 'Protocol': + if base.__name__ in ('Protocol', 'Generic'): continue annotations = getattr(base, '__annotations__', {}) for attr in list(base.__dict__.keys()) + list(annotations.keys()): - # Include attributes not defined in any non-protocol bases. - for c in self.__mro__: - if (c is not base and attr in c.__dict__ and - not getattr(c, '_is_protocol', False)): - break - else: - if (not attr.startswith('_abc_') and attr not in ( - '__abstractmethods__', '__annotations__', '__weakref__', - '_is_protocol', '_is_runtime_protocol', '__dict__', - '__args__', '__slots__', '_get_protocol_attrs', - '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', - '__module__')): - attrs.add(attr) + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', '_get_protocol_attrs', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker') and + getattr(base, attr, object()) is not None): + attrs.add(attr) return attrs @@ -1765,6 +1698,83 @@ def runtime(cls): return cls +class CallableMeta(_ProtocolMeta): + """Metaclass for Callable (internal).""" + + def __repr__(self): + if self.__origin__ is None: + return super().__repr__() + return self._tree_repr(self._subs_tree()) + + def _tree_repr(self, tree): + if _gorg(self) is not Callable: + return super()._tree_repr(tree) + # For actual Callable (not its subclass) we override + # super()._tree_repr() for nice formatting. + arg_list = [] + for arg in tree[1:]: + if not isinstance(arg, tuple): + arg_list.append(_type_repr(arg)) + else: + arg_list.append(arg[0]._tree_repr(arg)) + if arg_list[0] == '...': + return repr(tree[0]) + '[..., %s]' % arg_list[1] + return (repr(tree[0]) + + '[[%s], %s]' % (', '.join(arg_list[:-1]), arg_list[-1])) + + def __getitem__(self, parameters): + """A thin wrapper around __getitem_inner__ to provide the latter + with hashable arguments to improve speed. + """ + + if self.__origin__ is not None or not _geqv(self, Callable): + return super().__getitem__(parameters) + if not isinstance(parameters, tuple) or len(parameters) != 2: + raise TypeError("Callable must be used as " + "Callable[[arg, ...], result].") + args, result = parameters + if args is Ellipsis: + parameters = (Ellipsis, result) + else: + if not isinstance(args, list): + raise TypeError("Callable[args, result]: args must be a list." + " Got %.100r." % (args,)) + parameters = (tuple(args), result) + return self.__getitem_inner__(parameters) + + @_tp_cache + def __getitem_inner__(self, parameters): + args, result = parameters + msg = "Callable[args, result]: result must be a type." + result = _type_check(result, msg) + if args is Ellipsis: + return super().__getitem__((_TypingEllipsis, result)) + msg = "Callable[[arg, ...], result]: each arg must be a type." + args = tuple(_type_check(arg, msg) for arg in args) + parameters = args + (result,) + return super().__getitem__(parameters) + + +class Callable(extra=collections_abc.Callable, metaclass=CallableMeta): + """Callable type; Callable[[int], str] is a function of (int) -> str. + + The subscription syntax must always be used with exactly two + values: the argument list and the return type. The argument list + must be a list of types or ellipsis; the return type must be a single type. + + There is no syntax to indicate optional or keyword arguments, + such function types are rarely used as callback types. + """ + + __slots__ = () + + def __new__(cls, *args, **kwds): + if _geqv(cls, Callable): + raise TypeError("Type Callable cannot be instantiated; " + "use a non-abstract subclass instead") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. @@ -1794,6 +1804,7 @@ class AsyncIterable(Generic[T_co], extra=collections_abc.AsyncIterable): class AsyncIterator(AsyncIterable[T_co], extra=collections_abc.AsyncIterator): __slots__ = () + _is_protocol = True __all__.append('AsyncIterable') __all__.append('AsyncIterator') From a932e9faa38f429afbfd0acf84145b7c4b35d2df Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 00:46:23 +0200 Subject: [PATCH 09/18] Fix behaviour in different veriosns --- src/test_typing.py | 20 +++++++++++--------- src/typing.py | 14 +++++++------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 3bd2cbb97..47f8e8151 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -1038,7 +1038,7 @@ class C(typing.Container[T], Protocol[T]): x = 1 if hasattr(collections_abc, 'Reversible'): class C(typing.Reversible[T], Protocol[T]): x = 1 self.assertEqual(frozenset(C[int]._get_protocol_attrs()), - frozenset({'__reversed__', 'x'})) + frozenset({'__reversed__', '__iter__', 'x'})) if hasattr(typing, 'Collection'): class C(typing.Collection[T], Protocol[T]): x = 1 self.assertEqual(frozenset(C[int]._get_protocol_attrs()), @@ -1054,15 +1054,17 @@ class C(typing.MutableSequence[T], Protocol[T]): x = 1 '__iadd__', 'count', 'index', 'extend', 'clear', 'insert', 'append', 'remove', 'pop', 'reverse', 'x'})) class C(typing.Mapping[T, int], Protocol[T]): x = 1 - self.assertEqual(frozenset(C[int]._get_protocol_attrs()), - frozenset({'__len__', '__getitem__', '__iter__', '__contains__', - '__eq__', 'items', 'keys', 'values', 'get', 'x'})) + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', 'items', 'keys', 'values', 'get', 'x'})) class C(typing.MutableMapping[int, T], Protocol[T]): x = 1 - self.assertEqual(frozenset(C[int]._get_protocol_attrs()), - frozenset({'__len__', '__getitem__', '__iter__', '__contains__', - '__eq__', '__setitem__', '__delitem__', 'items', - 'keys', 'values', 'get', 'clear', 'pop', 'popitem', - 'update', 'setdefault', 'x'})) + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', '__setitem__', '__delitem__', 'items', + 'keys', 'values', 'get', 'clear', 'pop', 'popitem', + 'update', 'setdefault', 'x'})) if hasattr(typing, 'ContextManager'): class C(typing.ContextManager[T], Protocol[T]): x = 1 self.assertEqual(frozenset(C[int]._get_protocol_attrs()), diff --git a/src/typing.py b/src/typing.py index 6b0e2f82d..177818c3a 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1561,13 +1561,13 @@ def utf8(value): def _collection_protocol(cls): # Selected set of collections ABCs that are considered protocols. - qname = cls.__qualname__ - return (qname in ('Callable', 'Awaitable', - 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', - 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', - 'AbstractContextManager', 'ContextManager', - 'AbstractAsyncContextManager', 'AsyncContextManager',) and + name = cls.__name__ + return (name in ('Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', + 'AbstractContextManager', 'ContextManager', + 'AbstractAsyncContextManager', 'AsyncContextManager',) and cls.__module__ in ('collections.abc', 'typing', 'contextlib')) From 3e34f7adb77c93fac1b2b1c0888fed93ac63bec2 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 00:50:09 +0200 Subject: [PATCH 10/18] More compatibility fixes --- src/typing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/typing.py b/src/typing.py index 177818c3a..4900e020a 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1562,13 +1562,14 @@ def utf8(value): def _collection_protocol(cls): # Selected set of collections ABCs that are considered protocols. name = cls.__name__ - return (name in ('Callable', 'Awaitable', + return (name in ('ABC', 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', 'AbstractContextManager', 'ContextManager', 'AbstractAsyncContextManager', 'AsyncContextManager',) and - cls.__module__ in ('collections.abc', 'typing', 'contextlib')) + cls.__module__ in ('collections.abc', 'typing', 'contextlib', + '_abcol', 'abc')) class _ProtocolMeta(GenericMeta): From 8483810ce54fdce9e318db5700c19860819d2a40 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 00:53:27 +0200 Subject: [PATCH 11/18] Fix typos and lint --- src/test_typing.py | 2 +- src/typing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 47f8e8151..62cb8b7ef 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -1062,7 +1062,7 @@ class C(typing.MutableMapping[int, T], Protocol[T]): x = 1 # We use superset, since some versions also have '__ne__' self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= frozenset({'__len__', '__getitem__', '__iter__', '__contains__', - '__eq__', '__setitem__', '__delitem__', 'items', + '__eq__', '__setitem__', '__delitem__', 'items', 'keys', 'values', 'get', 'clear', 'pop', 'popitem', 'update', 'setdefault', 'x'})) if hasattr(typing, 'ContextManager'): diff --git a/src/typing.py b/src/typing.py index 4900e020a..2201ecf4a 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1569,7 +1569,7 @@ def _collection_protocol(cls): 'AbstractContextManager', 'ContextManager', 'AbstractAsyncContextManager', 'AsyncContextManager',) and cls.__module__ in ('collections.abc', 'typing', 'contextlib', - '_abcol', 'abc')) + '_abcoll', 'abc')) class _ProtocolMeta(GenericMeta): From 9e7b6740616afe44a9a9b872fe130373f9fee6a8 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 00:59:16 +0200 Subject: [PATCH 12/18] Even more compatibility fixes --- src/test_typing.py | 11 ++++++----- src/typing.py | 1 - 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 62cb8b7ef..57b771067 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -1047,12 +1047,13 @@ class C(typing.Sequence[T], Protocol[T]): x = 1 self.assertEqual(frozenset(C[int]._get_protocol_attrs()), frozenset({'__reversed__', '__contains__', '__getitem__', '__len__', '__iter__', 'count', 'index', 'x'})) + # We use superset, since Python 3.2 does not have 'clear' class C(typing.MutableSequence[T], Protocol[T]): x = 1 - self.assertEqual(frozenset(C[int]._get_protocol_attrs()), - frozenset({'__reversed__', '__contains__', '__getitem__', - '__len__', '__iter__', '__setitem__', '__delitem__', - '__iadd__', 'count', 'index', 'extend', 'clear', - 'insert', 'append', 'remove', 'pop', 'reverse', 'x'})) + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', '__setitem__', '__delitem__', + '__iadd__', 'count', 'index', 'extend', 'insert', + 'append', 'remove', 'pop', 'reverse', 'x'})) class C(typing.Mapping[T, int], Protocol[T]): x = 1 # We use superset, since some versions also have '__ne__' self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= diff --git a/src/typing.py b/src/typing.py index 2201ecf4a..c5b90aa4a 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1805,7 +1805,6 @@ class AsyncIterable(Generic[T_co], extra=collections_abc.AsyncIterable): class AsyncIterator(AsyncIterable[T_co], extra=collections_abc.AsyncIterator): __slots__ = () - _is_protocol = True __all__.append('AsyncIterable') __all__.append('AsyncIterator') From 8ee5b712abf655d851d7c8282bef505434eb522e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 13:17:33 +0200 Subject: [PATCH 13/18] Reorganize code more logically (hopefully this will also reduce diff) --- src/typing.py | 434 +++++++++++++++++++++++++------------------------- 1 file changed, 217 insertions(+), 217 deletions(-) diff --git a/src/typing.py b/src/typing.py index c5b90aa4a..f43585719 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1170,6 +1170,12 @@ def __subclasscheck__(self, cls): if self is Generic: raise TypeError("Class %r cannot be used with class " "or instance checks" % self) + if (self.__dict__.get('_is_protocol', None) and + not self.__dict__.get('_is_runtime_protocol', None)): + if sys._getframe(1).f_globals['__name__'] in ['abc', 'functools']: + return False + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") return super().__subclasscheck__(cls) def __instancecheck__(self, instance): @@ -1313,6 +1319,217 @@ def __new__(cls, *args, **kwds): return _generic_new(tuple, cls, *args, **kwds) +def _collection_protocol(cls): + # Selected set of collections ABCs that are considered protocols. + name = cls.__name__ + return (name in ('ABC', 'Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', + 'AbstractContextManager', 'ContextManager', + 'AbstractAsyncContextManager', 'AsyncContextManager',) and + cls.__module__ in ('collections.abc', 'typing', 'contextlib', + '_abcoll', 'abc')) + + +class _ProtocolMeta(GenericMeta): + """Internal metaclass for Protocol. + + This exists so Protocol classes can be generic without deriving + from Generic. + """ + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol or + isinstance(b, _ProtocolMeta) and + b.__origin__ is Protocol + for b in cls.__bases__) + if cls._is_protocol: + for base in cls.__mro__[1:]: + if not (base in (object, Generic, Callable) or + isinstance(base, TypingMeta) and base._is_protocol or + isinstance(base, GenericMeta) and base.__origin__ is Generic or + _collection_protocol(base)): + raise TypeError('Protocols can only inherit from other protocols,' + ' got %r' % base) + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + cls.__init__ = _no_init + + def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + for attr in cls._get_protocol_attrs(): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + def __instancecheck__(self, instance): + # We need this method for situations where attributes are assigned in __init__ + if issubclass(instance.__class__, self): + return True + if self._is_protocol: + return all(hasattr(instance, attr) and getattr(instance, attr) is not None + for attr in self._get_protocol_attrs()) + return False + + def _get_protocol_attrs(self): + attrs = set() + for base in self.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', '_get_protocol_attrs', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker') and + getattr(base, attr, object()) is not None): + attrs.add(attr) + return attrs + + +class Protocol(metaclass=_ProtocolMeta): + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol[T]): + def meth(self) -> T: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto[int]) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with @typing.runtime + act as simple-minded runtime protocols that checks only the presence of + given attributes, ignoring their type signatures. + """ + + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if _geqv(cls, Protocol): + raise TypeError("Type Protocol cannot be instantiated; " + "it can be used only as a base class") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + +def runtime(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: + raise TypeError('@runtime can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + return cls + + +class CallableMeta(_ProtocolMeta): + """Metaclass for Callable (internal).""" + + def __repr__(self): + if self.__origin__ is None: + return super().__repr__() + return self._tree_repr(self._subs_tree()) + + def _tree_repr(self, tree): + if _gorg(self) is not Callable: + return super()._tree_repr(tree) + # For actual Callable (not its subclass) we override + # super()._tree_repr() for nice formatting. + arg_list = [] + for arg in tree[1:]: + if not isinstance(arg, tuple): + arg_list.append(_type_repr(arg)) + else: + arg_list.append(arg[0]._tree_repr(arg)) + if arg_list[0] == '...': + return repr(tree[0]) + '[..., %s]' % arg_list[1] + return (repr(tree[0]) + + '[[%s], %s]' % (', '.join(arg_list[:-1]), arg_list[-1])) + + def __getitem__(self, parameters): + """A thin wrapper around __getitem_inner__ to provide the latter + with hashable arguments to improve speed. + """ + + if self.__origin__ is not None or not _geqv(self, Callable): + return super().__getitem__(parameters) + if not isinstance(parameters, tuple) or len(parameters) != 2: + raise TypeError("Callable must be used as " + "Callable[[arg, ...], result].") + args, result = parameters + if args is Ellipsis: + parameters = (Ellipsis, result) + else: + if not isinstance(args, list): + raise TypeError("Callable[args, result]: args must be a list." + " Got %.100r." % (args,)) + parameters = (tuple(args), result) + return self.__getitem_inner__(parameters) + + @_tp_cache + def __getitem_inner__(self, parameters): + args, result = parameters + msg = "Callable[args, result]: result must be a type." + result = _type_check(result, msg) + if args is Ellipsis: + return super().__getitem__((_TypingEllipsis, result)) + msg = "Callable[[arg, ...], result]: each arg must be a type." + args = tuple(_type_check(arg, msg) for arg in args) + parameters = args + (result,) + return super().__getitem__(parameters) + + +class Callable(extra=collections_abc.Callable, metaclass=CallableMeta): + """Callable type; Callable[[int], str] is a function of (int) -> str. + + The subscription syntax must always be used with exactly two + values: the argument list and the return type. The argument list + must be a list of types or ellipsis; the return type must be a single type. + + There is no syntax to indicate optional or keyword arguments, + such function types are rarely used as callback types. + """ + + __slots__ = () + + def __new__(cls, *args, **kwds): + if _geqv(cls, Callable): + raise TypeError("Type Callable cannot be instantiated; " + "use a non-abstract subclass instead") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + class _ClassVar(_FinalTypingBase, _root=True): """Special type construct to mark class variables. @@ -1559,223 +1776,6 @@ def utf8(value): return _overload_dummy -def _collection_protocol(cls): - # Selected set of collections ABCs that are considered protocols. - name = cls.__name__ - return (name in ('ABC', 'Callable', 'Awaitable', - 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', - 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', - 'AbstractContextManager', 'ContextManager', - 'AbstractAsyncContextManager', 'AsyncContextManager',) and - cls.__module__ in ('collections.abc', 'typing', 'contextlib', - '_abcoll', 'abc')) - - -class _ProtocolMeta(GenericMeta): - """Internal metaclass for Protocol. - - This exists so Protocol classes can be generic without deriving - from Generic. - """ - - def __init__(cls, *args, **kwargs): - super().__init__(*args, **kwargs) - if not cls.__dict__.get('_is_protocol', None): - cls._is_protocol = any(b is Protocol or - isinstance(b, _ProtocolMeta) and - b.__origin__ is Protocol - for b in cls.__bases__) - if cls._is_protocol: - for base in cls.__mro__[1:]: - if not (base in (object, Generic, Callable) or - isinstance(base, TypingMeta) and base._is_protocol or - isinstance(base, GenericMeta) and base.__origin__ is Generic or - _collection_protocol(base)): - raise TypeError('Protocols can only inherit from other protocols,' - ' got %r' % base) - - def _no_init(self, *args, **kwargs): - if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') - cls.__init__ = _no_init - - def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): - return NotImplemented - if not cls.__dict__.get('_is_runtime_protocol', None): - if sys._getframe(3).f_globals['__name__'] in ['abc', 'functools']: - return NotImplemented - raise TypeError('Instance and class checks can only be used with' - ' @runtime protocols') - for attr in cls._get_protocol_attrs(): - for base in other.__mro__: - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - else: - return NotImplemented - return True - if '__subclasshook__' not in cls.__dict__: - cls.__subclasshook__ = _proto_hook - - def __instancecheck__(self, instance): - # We need this method for situations where attributes are assigned in __init__ - if issubclass(instance.__class__, self): - return True - if self._is_protocol: - return all(hasattr(instance, attr) and getattr(instance, attr) is not None - for attr in self._get_protocol_attrs()) - return False - - def _get_protocol_attrs(self): - attrs = set() - for base in self.__mro__[:-1]: # without object - if base.__name__ in ('Protocol', 'Generic'): - continue - annotations = getattr(base, '__annotations__', {}) - for attr in list(base.__dict__.keys()) + list(annotations.keys()): - if (not attr.startswith('_abc_') and attr not in ( - '__abstractmethods__', '__annotations__', '__weakref__', - '_is_protocol', '_is_runtime_protocol', '__dict__', - '__args__', '__slots__', '_get_protocol_attrs', - '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', - '__doc__', '__subclasshook__', '__init__', '__new__', - '__module__', '_MutableMapping__marker') and - getattr(base, attr, object()) is not None): - attrs.add(attr) - return attrs - - -class Protocol(metaclass=_ProtocolMeta): - """Base class for protocol classes. Protocol classes are defined as:: - - class Proto(Protocol[T]): - def meth(self) -> T: - ... - - Such classes are primarily used with static type checkers that recognize - structural subtyping (static duck-typing), for example:: - - class C: - def meth(self) -> int: - return 0 - - def func(x: Proto[int]) -> int: - return x.meth() - - func(C()) # Passes static type check - - See PEP 544 for details. Protocol classes decorated with @typing.runtime - act as simple-minded runtime protocols that checks only the presence of - given attributes, ignoring their type signatures. - """ - - __slots__ = () - - _is_protocol = True - - def __new__(cls, *args, **kwds): - if _geqv(cls, Protocol): - raise TypeError("Type Protocol cannot be instantiated; " - "it can be used only as a base class") - return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) - - -def runtime(cls): - """Mark a protocol class as a runtime protocol, so that it - can be used with isinstance() and issubclass(). Raise TypeError - if applied to a non-protocol class. - - This allows a simple-minded structural check very similar to the - one-offs in collections.abc such as Hashable. - """ - if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: - raise TypeError('@runtime can be only applied to protocol classes,' - ' got %r' % cls) - cls._is_runtime_protocol = True - return cls - - -class CallableMeta(_ProtocolMeta): - """Metaclass for Callable (internal).""" - - def __repr__(self): - if self.__origin__ is None: - return super().__repr__() - return self._tree_repr(self._subs_tree()) - - def _tree_repr(self, tree): - if _gorg(self) is not Callable: - return super()._tree_repr(tree) - # For actual Callable (not its subclass) we override - # super()._tree_repr() for nice formatting. - arg_list = [] - for arg in tree[1:]: - if not isinstance(arg, tuple): - arg_list.append(_type_repr(arg)) - else: - arg_list.append(arg[0]._tree_repr(arg)) - if arg_list[0] == '...': - return repr(tree[0]) + '[..., %s]' % arg_list[1] - return (repr(tree[0]) + - '[[%s], %s]' % (', '.join(arg_list[:-1]), arg_list[-1])) - - def __getitem__(self, parameters): - """A thin wrapper around __getitem_inner__ to provide the latter - with hashable arguments to improve speed. - """ - - if self.__origin__ is not None or not _geqv(self, Callable): - return super().__getitem__(parameters) - if not isinstance(parameters, tuple) or len(parameters) != 2: - raise TypeError("Callable must be used as " - "Callable[[arg, ...], result].") - args, result = parameters - if args is Ellipsis: - parameters = (Ellipsis, result) - else: - if not isinstance(args, list): - raise TypeError("Callable[args, result]: args must be a list." - " Got %.100r." % (args,)) - parameters = (tuple(args), result) - return self.__getitem_inner__(parameters) - - @_tp_cache - def __getitem_inner__(self, parameters): - args, result = parameters - msg = "Callable[args, result]: result must be a type." - result = _type_check(result, msg) - if args is Ellipsis: - return super().__getitem__((_TypingEllipsis, result)) - msg = "Callable[[arg, ...], result]: each arg must be a type." - args = tuple(_type_check(arg, msg) for arg in args) - parameters = args + (result,) - return super().__getitem__(parameters) - - -class Callable(extra=collections_abc.Callable, metaclass=CallableMeta): - """Callable type; Callable[[int], str] is a function of (int) -> str. - - The subscription syntax must always be used with exactly two - values: the argument list and the return type. The argument list - must be a list of types or ellipsis; the return type must be a single type. - - There is no syntax to indicate optional or keyword arguments, - such function types are rarely used as callback types. - """ - - __slots__ = () - - def __new__(cls, *args, **kwds): - if _geqv(cls, Callable): - raise TypeError("Type Callable cannot be instantiated; " - "use a non-abstract subclass instead") - return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) - - # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. From 4a0d335ca2611564f6925bd263f16143462147cd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 25 Apr 2017 23:13:00 +0200 Subject: [PATCH 14/18] Add tests for PY36 and @abstracmethod --- src/test_typing.py | 64 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/src/test_typing.py b/src/test_typing.py index 57b771067..ee65be7ed 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -542,9 +542,6 @@ class MyPoint: y: int label: str -class BadPoint: - z: str - class XAxis(Protocol): x: int @@ -565,7 +562,7 @@ class Concrete(Proto): pass class Other: - attr: int + attr: int = 1 def meth(self, arg: str) -> int: if arg == 'this': return 1 @@ -639,6 +636,21 @@ class PG(Protocol[T]): pass class CG(PG[T]): pass self.assertIsInstance(CG[int](), CG) + def test_cannot_instantiate_abstract(self): + @runtime + class P(Protocol): + @abc.abstractmethod + def ameth(self) -> int: + raise NotImplementedError + class B(P): + pass + class C(B): + def ameth(self) -> int: + return 26 + with self.assertRaises(TypeError): + B() + self.assertIsInstance(C(), P) + def test_subprotocols_extending(self): class P1(Protocol): def meth1(self): @@ -728,7 +740,20 @@ class C: @skipUnless(PY36, 'Python 3.6 required') def test_protocols_issubclass_py36(self): - pass + class OtherPoint: + x = 1 + y = 2 + label = 'other' + class Bad: pass + self.assertNotIsSubclass(MyPoint, Point) + self.assertIsSubclass(OtherPoint, Point) + self.assertNotIsSubclass(Bad, Point) + self.assertNotIsSubclass(MyPoint, Position) + self.assertIsSubclass(OtherPoint, Position) + self.assertIsSubclass(Concrete, Proto) + self.assertIsSubclass(Other, Proto) + self.assertNotIsSubclass(Concrete, Other) + self.assertNotIsSubclass(Other, Concrete) def test_protocols_isinstance(self): T = TypeVar('T') @@ -757,7 +782,34 @@ def meth(x): ... @skipUnless(PY36, 'Python 3.6 required') def test_protocols_isinstance_py36(self): - pass + class APoint: + def __init__(self, x, y, label): + self.x = x + self.y = y + self.label = label + class BPoint: + label = 'B' + def __init__(self, x, y): + self.x = x + self.y = y + class C: + def __init__(self, attr): + self.attr = attr + def meth(self, arg): + return 0 + class Bad: pass + self.assertIsInstance(APoint(1, 2, 'A'), Point) + self.assertIsInstance(BPoint(1, 2), Point) + self.assertNotIsInstance(MyPoint(), Point) + self.assertIsInstance(BPoint(1, 2), Position) + self.assertIsInstance(Other(), Proto) + self.assertIsInstance(Concrete(), Proto) + self.assertIsInstance(C(42), Proto) + self.assertNotIsInstance(Bad(), Proto) + self.assertNotIsInstance(Bad(), Point) + self.assertNotIsInstance(Bad(), Position) + self.assertNotIsInstance(Bad(), Concrete) + self.assertNotIsInstance(Other(), Concrete) def test_protocols_isinstance_init(self): T = TypeVar('T') From 8fc9a5e9d16b850a45e6ff9b80b7ac449019ee3e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 26 Apr 2017 00:01:49 +0200 Subject: [PATCH 15/18] Backport protocols to Python 2 --- python2/test_typing.py | 457 ++++++++++++++++++++++++++++++++++++++++- python2/typing.py | 276 ++++++++++++++++--------- src/test_typing.py | 16 +- src/typing.py | 14 +- 4 files changed, 645 insertions(+), 118 deletions(-) diff --git a/python2/test_typing.py b/python2/test_typing.py index b08389971..b6ff40780 100644 --- a/python2/test_typing.py +++ b/python2/test_typing.py @@ -14,11 +14,13 @@ from typing import Tuple, List, MutableMapping from typing import Callable from typing import Generic, ClassVar, GenericMeta +from typing import Protocol, runtime from typing import cast from typing import Type from typing import NewType from typing import NamedTuple from typing import Pattern, Match +import abc import typing import weakref try: @@ -511,6 +513,391 @@ def get(self, key, default=None): class ProtocolTests(BaseTestCase): + def test_basic_protocol(self): + @runtime + class P(Protocol): + def meth(self): + pass + class C(object): pass + class D(object): + def meth(self): + pass + self.assertIsSubclass(D, P) + self.assertIsInstance(D(), P) + self.assertNotIsSubclass(C, P) + self.assertNotIsInstance(C(), P) + + def test_everything_implements_empty_protocol(self): + @runtime + class Empty(Protocol): pass + class C(object): pass + for thing in (object, type, tuple, C): + self.assertIsSubclass(thing, Empty) + for thing in (object(), 1, (), typing): + self.assertIsInstance(thing, Empty) + + def test_no_inheritance_from_nominal(self): + class C(object): pass + class BP(Protocol): pass + with self.assertRaises(TypeError): + class P(C, Protocol): + pass + with self.assertRaises(TypeError): + class P(Protocol, C): + pass + with self.assertRaises(TypeError): + class P(BP, C, Protocol): + pass + class D(BP, C): pass + class E(C, BP): pass + self.assertNotIsInstance(D(), E) + self.assertNotIsInstance(E(), D) + + def test_no_instantiation(self): + class P(Protocol): pass + with self.assertRaises(TypeError): + P() + class C(P): pass + self.assertIsInstance(C(), C) + T = TypeVar('T') + class PG(Protocol[T]): pass + with self.assertRaises(TypeError): + PG() + with self.assertRaises(TypeError): + PG[int]() + with self.assertRaises(TypeError): + PG[T]() + class CG(PG[T]): pass + self.assertIsInstance(CG[int](), CG) + + def test_cannot_instantiate_abstract(self): + @runtime + class P(Protocol): + @abc.abstractmethod + def ameth(self): + raise NotImplementedError + class B(P): + pass + class C(B): + def ameth(self): + return 26 + with self.assertRaises(TypeError): + B() + self.assertIsInstance(C(), P) + + def test_subprotocols_extending(self): + class P1(Protocol): + def meth1(self): + pass + @runtime + class P2(P1, Protocol): + def meth2(self): + pass + class C(object): + def meth1(self): + pass + def meth2(self): + pass + class C1(object): + def meth1(self): + pass + class C2(object): + def meth2(self): + pass + self.assertNotIsInstance(C1(), P2) + self.assertNotIsInstance(C2(), P2) + self.assertNotIsSubclass(C1, P2) + self.assertNotIsSubclass(C2, P2) + self.assertIsInstance(C(), P2) + self.assertIsSubclass(C, P2) + + def test_subprotocols_merging(self): + class P1(Protocol): + def meth1(self): + pass + class P2(Protocol): + def meth2(self): + pass + @runtime + class P(P1, P2, Protocol): + pass + class C(object): + def meth1(self): + pass + def meth2(self): + pass + class C1(object): + def meth1(self): + pass + class C2(object): + def meth2(self): + pass + self.assertNotIsInstance(C1(), P) + self.assertNotIsInstance(C2(), P) + self.assertNotIsSubclass(C1, P) + self.assertNotIsSubclass(C2, P) + self.assertIsInstance(C(), P) + self.assertIsSubclass(C, P) + + def test_protocols_issubclass(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class BadP(Protocol): + x = 1 + class BadPG(Protocol[T]): + x = 1 + class C(object): + x = 1 + self.assertIsSubclass(C, P) + self.assertIsSubclass(C, PG) + self.assertIsSubclass(BadP, PG) + self.assertIsSubclass(PG[int], PG) + self.assertIsSubclass(BadPG[int], P) + self.assertIsSubclass(BadPG[T], PG) + with self.assertRaises(TypeError): + issubclass(C, PG[T]) + with self.assertRaises(TypeError): + issubclass(C, PG[C]) + with self.assertRaises(TypeError): + issubclass(C, BadP) + with self.assertRaises(TypeError): + issubclass(C, BadPG) + with self.assertRaises(TypeError): + issubclass(P, PG[T]) + with self.assertRaises(TypeError): + issubclass(PG, PG[int]) + + def test_protocols_isinstance(self): + T = TypeVar('T') + @runtime + class P(Protocol): + def meth(x): pass + @runtime + class PG(Protocol[T]): + def meth(x): pass + class BadP(Protocol): + def meth(x): pass + class BadPG(Protocol[T]): + def meth(x): pass + class C(object): + def meth(x): pass + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), PG) + with self.assertRaises(TypeError): + isinstance(C(), PG[T]) + with self.assertRaises(TypeError): + isinstance(C(), PG[C]) + with self.assertRaises(TypeError): + isinstance(C(), BadP) + with self.assertRaises(TypeError): + isinstance(C(), BadPG) + + def test_protocols_isinstance_init(self): + T = TypeVar('T') + @runtime + class P(Protocol): + x = 1 + @runtime + class PG(Protocol[T]): + x = 1 + class C(object): + def __init__(self, x): + self.x = x + self.assertIsInstance(C(1), P) + self.assertIsInstance(C(1), PG) + + def test_protocols_support_register(self): + @runtime + class P(Protocol): + x = 1 + class PM(Protocol): + def meth(self): pass + class D(PM): pass + class C(object): pass + D.register(C) + P.register(C) + self.assertIsInstance(C(), P) + self.assertIsInstance(C(), D) + + def test_none_blocks_implementation(self): + @runtime + class P(Protocol): + x = 1 + class A(object): + x = 1 + class B(A): + x = None + class C(object): + def __init__(self): + self.x = None + self.assertNotIsInstance(B(), P) + self.assertNotIsInstance(C(), P) + + def test_non_protocol_subclasses(self): + class P(Protocol): + x = 1 + @runtime + class PR(Protocol): + def meth(self): pass + class NonP(P): + x = 1 + class NonPR(PR): pass + class C(object): + x = 1 + class D(object): + def meth(self): pass + self.assertNotIsInstance(C(), NonP) + self.assertNotIsInstance(D(), NonPR) + self.assertNotIsSubclass(C, NonP) + self.assertNotIsSubclass(D, NonPR) + self.assertIsInstance(NonPR(), PR) + self.assertIsSubclass(NonPR, PR) + + def test_custom_subclasshook(self): + class P(Protocol): + x = 1 + class OKClass(object): pass + class BadClass(object): + x = 1 + class C(P): + @classmethod + def __subclasshook__(cls, other): + return other.__name__.startswith("OK") + self.assertIsInstance(OKClass(), C) + self.assertNotIsInstance(BadClass(), C) + self.assertIsSubclass(OKClass, C) + self.assertNotIsSubclass(BadClass, C) + + def test_defining_generic_protocols(self): + T = TypeVar('T') + S = TypeVar('S') + @runtime + class PR(Protocol[T, S]): + def meth(self): pass + class P(PR[int, T], Protocol[T]): + y = 1 + self.assertIsSubclass(PR[int, T], PR) + self.assertIsSubclass(P[str], PR) + with self.assertRaises(TypeError): + PR[int] + with self.assertRaises(TypeError): + P[int, str] + with self.assertRaises(TypeError): + PR[int, 1] + with self.assertRaises(TypeError): + PR[int, ClassVar] + class C(PR[int, T]): pass + self.assertIsInstance(C[str](), C) + + def test_init_called(self): + T = TypeVar('T') + class P(Protocol[T]): pass + class C(P[T]): + def __init__(self): + self.test = 'OK' + self.assertEqual(C[int]().test, 'OK') + + def test_protocols_bad_subscripts(self): + T = TypeVar('T') + S = TypeVar('S') + with self.assertRaises(TypeError): + class P(Protocol[T, T]): pass + with self.assertRaises(TypeError): + class P(Protocol[int]): pass + with self.assertRaises(TypeError): + class P(Protocol[T], Protocol[S]): pass + with self.assertRaises(TypeError): + class P(typing.Mapping[T, S], Protocol[T]): pass + + def test_generic_protocols_repr(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertTrue(repr(P).endswith('P')) + self.assertTrue(repr(P[T, S]).endswith('P[~T, ~S]')) + self.assertTrue(repr(P[int, str]).endswith('P[int, str]')) + + def test_generic_protocols_eq(self): + T = TypeVar('T') + S = TypeVar('S') + class P(Protocol[T, S]): pass + self.assertEqual(P, P) + self.assertEqual(P[int, T], P[int, T]) + self.assertEqual(P[T, T][Tuple[T, S]][int, str], + P[Tuple[int, str], Tuple[int, str]]) + + def test_generic_protocols_special_from_generic(self): + T = TypeVar('T') + class P(Protocol[T]): pass + self.assertEqual(P.__parameters__, (T,)) + self.assertIs(P.__args__, None) + self.assertIs(P.__origin__, None) + self.assertEqual(P[int].__parameters__, ()) + self.assertEqual(P[int].__args__, (int,)) + self.assertIs(P[int].__origin__, P) + + def test_generic_protocols_special_from_protocol(self): + @runtime + class PR(Protocol): + x = 1 + class P(Protocol): + def meth(self): + pass + T = TypeVar('T') + class PG(Protocol[T]): + x = 1 + def meth(self): + pass + self.assertTrue(P._is_protocol) + self.assertTrue(PR._is_protocol) + self.assertTrue(PG._is_protocol) + with self.assertRaises(AttributeError): + self.assertFalse(P._is_runtime_protocol) + self.assertTrue(PR._is_runtime_protocol) + self.assertTrue(PG[int]._is_protocol) + self.assertEqual(P._get_protocol_attrs(), {'meth'}) + self.assertEqual(PR._get_protocol_attrs(), {'x'}) + self.assertEqual(frozenset(PG._get_protocol_attrs()), + frozenset({'x', 'meth'})) + self.assertEqual(frozenset(PG[int]._get_protocol_attrs()), + frozenset({'x', 'meth'})) + + def test_no_runtime_deco_on_nominal(self): + with self.assertRaises(TypeError): + @runtime + class C(object): pass + + def test_protocols_pickleable(self): + global P, CP # pickle wants to reference the class by name + T = TypeVar('T') + + @runtime + class P(Protocol[T]): + x = 1 + class CP(P[int]): + pass + + c = CP() + c.foo = 42 + c.bar = 'abc' + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(c, proto) + x = pickle.loads(z) + self.assertEqual(x.foo, 42) + self.assertEqual(x.bar, 'abc') + self.assertEqual(x.x, 1) + self.assertEqual(x.__dict__, {'foo': 42, 'bar': 'abc'}) + s = pickle.dumps(P) + D = pickle.loads(s) + class E(object): + x = 1 + self.assertIsInstance(E(), D) + def test_supports_int(self): self.assertIsSubclass(int, typing.SupportsInt) self.assertNotIsSubclass(str, typing.SupportsInt) @@ -538,9 +925,64 @@ def test_reversible(self): self.assertIsSubclass(list, typing.Reversible) self.assertNotIsSubclass(int, typing.Reversible) + def test_collection_protocols(self): + T = TypeVar('T') + class C(typing.Callable[[T], T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__call__', 'x'})) + class C(typing.Iterable[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', 'x'})) + class C(typing.Iterator[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__iter__', 'next', 'x'})) + class C(typing.Hashable, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__hash__', 'x'})) + class C(typing.Sized, Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', 'x'})) + class C(typing.Container[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__contains__', 'x'})) + if hasattr(collections_abc, 'Reversible'): + class C(typing.Reversible[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__iter__', 'x'})) + if hasattr(typing, 'Collection'): + class C(typing.Collection[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__len__', '__iter__', '__contains__', 'x'})) + class C(typing.Sequence[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', 'count', 'index', 'x'})) + # We use superset, since Python 3.2 does not have 'clear' + class C(typing.MutableSequence[T], Protocol[T]): x = 1 + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__reversed__', '__contains__', '__getitem__', + '__len__', '__iter__', '__setitem__', '__delitem__', + '__iadd__', 'count', 'index', 'extend', 'insert', + 'append', 'remove', 'pop', 'reverse', 'x'})) + class C(typing.Mapping[T, int], Protocol[T]): x = 1 + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', 'items', 'keys', 'values', 'get', 'x'})) + class C(typing.MutableMapping[int, T], Protocol[T]): x = 1 + # We use superset, since some versions also have '__ne__' + self.assertTrue(frozenset(C[int]._get_protocol_attrs()) >= + frozenset({'__len__', '__getitem__', '__iter__', '__contains__', + '__eq__', '__setitem__', '__delitem__', 'items', + 'keys', 'values', 'get', 'clear', 'pop', 'popitem', + 'update', 'setdefault', 'x'})) + if hasattr(typing, 'ContextManager'): + class C(typing.ContextManager[T], Protocol[T]): x = 1 + self.assertEqual(frozenset(C[int]._get_protocol_attrs()), + frozenset({'__enter__', '__exit__', 'x'})) + def test_protocol_instance_type_error(self): - with self.assertRaises(TypeError): - isinstance(0, typing.SupportsAbs) + isinstance(0, typing.SupportsAbs) class C1(typing.SupportsInt): def __int__(self): return 42 @@ -548,6 +990,13 @@ class C2(C1): pass c = C2() self.assertIsInstance(c, C1) + class C3(object): + def __int__(self): + return 42 + class C4(C3): + pass + c = C4() + self.assertIsInstance(c, typing.SupportsInt) class GenericTests(BaseTestCase): @@ -650,7 +1099,7 @@ def test_new_repr_complex(self): def test_new_repr_bare(self): T = TypeVar('T') self.assertEqual(repr(Generic[T]), 'typing.Generic[~T]') - self.assertEqual(repr(typing._Protocol[T]), 'typing.Protocol[~T]') + self.assertEqual(repr(typing.Protocol[T]), 'typing.Protocol[~T]') class C(typing.Dict[Any, Any]): pass # this line should just work repr(C.__mro__) @@ -934,7 +1383,7 @@ def test_fail_with_bare_generic(self): with self.assertRaises(TypeError): Tuple[Generic[T]] with self.assertRaises(TypeError): - List[typing._Protocol] + List[typing.Protocol] with self.assertRaises(TypeError): isinstance(1, Generic) diff --git a/python2/typing.py b/python2/typing.py index a5bfd3414..11d80c421 100644 --- a/python2/typing.py +++ b/python2/typing.py @@ -21,6 +21,7 @@ 'ClassVar', 'Generic', 'Optional', + 'Protocol', 'Tuple', 'Type', 'TypeVar', @@ -72,6 +73,7 @@ 'no_type_check', 'no_type_check_decorator', 'overload', + 'runtime', 'Text', 'TYPE_CHECKING', ] @@ -356,7 +358,7 @@ def _type_check(arg, msg): if ( type(arg).__name__ in ('_Union', '_Optional') and not getattr(arg, '__origin__', None) or - isinstance(arg, TypingMeta) and _gorg(arg) in (Generic, _Protocol) + isinstance(arg, TypingMeta) and _gorg(arg) in (Generic, Protocol) ): raise TypeError("Plain %s is not valid as type argument" % arg) return arg @@ -1054,10 +1056,11 @@ def __new__(cls, name, bases, namespace, if base is Generic: raise TypeError("Cannot inherit from plain Generic") if (isinstance(base, GenericMeta) and - base.__origin__ is Generic): + base.__origin__ in (Generic, Protocol)): if gvars is not None: raise TypeError( - "Cannot inherit from Generic[...] multiple types.") + "Cannot inherit from Generic[...] or" + " Protocol[...] multiple types.") gvars = base.__parameters__ if gvars is None: gvars = tvars @@ -1067,8 +1070,10 @@ def __new__(cls, name, bases, namespace, if not tvarset <= gvarset: raise TypeError( "Some type variables (%s) " - "are not listed in Generic[%s]" % + "are not listed in %s[%s]" % (", ".join(str(t) for t in tvars if t not in gvarset), + "Generic" if any(b.__origin__ is Generic + for b in bases) else "Protocol", ", ".join(str(g) for g in gvars))) tvars = gvars @@ -1215,25 +1220,21 @@ def __getitem__(self, params): "Parameter list to %s[...] cannot be empty" % _qualname(self)) msg = "Parameters to generic types must be types." params = tuple(_type_check(p, msg) for p in params) - if self is Generic: + if self in (Generic, Protocol): # Generic can only be subscripted with unique type variables. if not all(isinstance(p, TypeVar) for p in params): raise TypeError( - "Parameters to Generic[...] must all be type variables") + "Parameters to %r[...] must all be type variables", self) if len(set(params)) != len(params): raise TypeError( - "Parameters to Generic[...] must all be unique") + "Parameters to %r[...] must all be unique", self) tvars = params args = params elif self in (Tuple, Callable): tvars = _type_vars(params) args = params - elif self is _Protocol: - # _Protocol is internal, don't check anything. - tvars = params - args = params - elif self.__origin__ in (Generic, _Protocol): - # Can't subscript Generic[...] or _Protocol[...]. + elif self.__origin__ in (Generic, Protocol): + # Can't subscript Generic[...] or Protocol[...]. raise TypeError("Cannot subscript already-subscripted %s" % repr(self)) else: @@ -1261,6 +1262,12 @@ def __subclasscheck__(self, cls): if self is Generic: raise TypeError("Class %r cannot be used with class " "or instance checks" % self) + if (self.__dict__.get('_is_protocol', None) and + not self.__dict__.get('_is_runtime_protocol', None)): + if sys._getframe(1).f_globals['__name__'] in ['abc', 'functools']: + return False + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") return super(GenericMeta, self).__subclasscheck__(cls) def __instancecheck__(self, instance): @@ -1289,8 +1296,10 @@ def __setattr__(self, attr, value): super(GenericMeta, _gorg(self)).__setattr__(attr, value) -# Prevent checks for Generic to crash when defining Generic. +# Prevent checks for Generic, etc. to crash when defining Generic. Generic = None +Protocol = object() +Callable = object() def _generic_new(base_cls, cls, *args, **kwds): @@ -1406,7 +1415,150 @@ def __new__(cls, *args, **kwds): return _generic_new(tuple, cls, *args, **kwds) -class CallableMeta(GenericMeta): +def _collection_protocol(cls): + # Selected set of collections ABCs that are considered protocols. + name = cls.__name__ + return (name in ('ABC', 'Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'Sequence', 'MutableSequence', 'Mapping', 'MutableMapping', + 'AbstractContextManager', 'ContextManager', + 'AbstractAsyncContextManager', 'AsyncContextManager',) and + cls.__module__ in ('collections.abc', 'typing', 'contextlib', + '_abcoll', 'abc')) + + +class ProtocolMeta(GenericMeta): + """Internal metaclass for Protocol. + + This exists so Protocol classes can be generic without deriving + from Generic. + """ + + def __init__(cls, *args, **kwargs): + super(ProtocolMeta, cls).__init__(*args, **kwargs) + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol or + isinstance(b, ProtocolMeta) and + b.__origin__ is Protocol + for b in cls.__bases__) + if cls._is_protocol: + for base in cls.__mro__[1:]: + if not (base in (object, Generic, Callable) or + isinstance(base, TypingMeta) and base._is_protocol or + isinstance(base, GenericMeta) and base.__origin__ is Generic or + _collection_protocol(base)): + raise TypeError('Protocols can only inherit from other protocols,' + ' got %r' % base) + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + cls.__init__ = _no_init + + def _proto_hook(cls, other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + for attr in cls._get_protocol_attrs(): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = classmethod(_proto_hook) + + def __instancecheck__(self, instance): + # We need this method for situations where attributes are assigned in __init__ + if isinstance(instance, type): + # This looks like a fundamental limitation of Python 2. + # It cannot support runtime protocol metaclasses + return False + if issubclass(instance.__class__, self): + return True + if self._is_protocol: + return all(hasattr(instance, attr) and getattr(instance, attr) is not None + for attr in self._get_protocol_attrs()) + return False + + def _get_protocol_attrs(self): + attrs = set() + for base in self.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', '_get_protocol_attrs', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker', + '__metaclass__') and + getattr(base, attr, object()) is not None): + attrs.add(attr) + return attrs + + +class Protocol(object): + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol[T]): + def meth(self): + # type: () -> int + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self): + # type: () -> int + return 0 + + def func(x): + # type: (Proto[int]) -> int + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with @typing.runtime + act as simple-minded runtime protocols that checks only the presence of + given attributes, ignoring their type signatures. + """ + + __metaclass__ = ProtocolMeta + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if _geqv(cls, Protocol): + raise TypeError("Type Protocol cannot be instantiated; " + "it can be used only as a base class") + return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) + + +def runtime(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not isinstance(cls, ProtocolMeta) or not cls._is_protocol: + raise TypeError('@runtime can be only applied to protocol classes,' + ' got %r' % cls) + cls._is_runtime_protocol = True + return cls + + +class CallableMeta(ProtocolMeta): """ Metaclass for Callable.""" def __repr__(self): @@ -1597,85 +1749,6 @@ def utf8(value): return _overload_dummy -class _ProtocolMeta(GenericMeta): - """Internal metaclass for _Protocol. - - This exists so _Protocol classes can be generic without deriving - from Generic. - """ - - def __instancecheck__(self, obj): - if _Protocol not in self.__bases__: - return super(_ProtocolMeta, self).__instancecheck__(obj) - raise TypeError("Protocols cannot be used with isinstance().") - - def __subclasscheck__(self, cls): - if not self._is_protocol: - # No structural checks since this isn't a protocol. - return NotImplemented - - if self is _Protocol: - # Every class is a subclass of the empty protocol. - return True - - # Find all attributes defined in the protocol. - attrs = self._get_protocol_attrs() - - for attr in attrs: - if not any(attr in d.__dict__ for d in cls.__mro__): - return False - return True - - def _get_protocol_attrs(self): - # Get all Protocol base classes. - protocol_bases = [] - for c in self.__mro__: - if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol': - protocol_bases.append(c) - - # Get attributes included in protocol. - attrs = set() - for base in protocol_bases: - for attr in base.__dict__.keys(): - # Include attributes not defined in any non-protocol bases. - for c in self.__mro__: - if (c is not base and attr in c.__dict__ and - not getattr(c, '_is_protocol', False)): - break - else: - if (not attr.startswith('_abc_') and - attr != '__abstractmethods__' and - attr != '_is_protocol' and - attr != '__dict__' and - attr != '__args__' and - attr != '__slots__' and - attr != '_get_protocol_attrs' and - attr != '__next_in_mro__' and - attr != '__parameters__' and - attr != '__origin__' and - attr != '__orig_bases__' and - attr != '__extra__' and - attr != '__tree_hash__' and - attr != '__module__'): - attrs.add(attr) - - return attrs - - -class _Protocol(object): - """Internal base class for protocol classes. - - This implements a simple-minded structural issubclass check - (similar but more general than the one-offs in collections.abc - such as Hashable). - """ - - __metaclass__ = _ProtocolMeta - __slots__ = () - - _is_protocol = True - - # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. @@ -1692,7 +1765,8 @@ class Iterator(Iterable[T_co]): __extra__ = collections_abc.Iterator -class SupportsInt(_Protocol): +@runtime +class SupportsInt(Protocol): __slots__ = () @abstractmethod @@ -1700,7 +1774,8 @@ def __int__(self): pass -class SupportsFloat(_Protocol): +@runtime +class SupportsFloat(Protocol): __slots__ = () @abstractmethod @@ -1708,7 +1783,8 @@ def __float__(self): pass -class SupportsComplex(_Protocol): +@runtime +class SupportsComplex(Protocol): __slots__ = () @abstractmethod @@ -1716,7 +1792,8 @@ def __complex__(self): pass -class SupportsAbs(_Protocol[T_co]): +@runtime +class SupportsAbs(Protocol[T_co]): __slots__ = () @abstractmethod @@ -1729,7 +1806,8 @@ class Reversible(Iterable[T_co]): __slots__ = () __extra__ = collections_abc.Reversible else: - class Reversible(_Protocol[T_co]): + @runtime + class Reversible(Protocol[T_co]): __slots__ = () @abstractmethod diff --git a/src/test_typing.py b/src/test_typing.py index ee65be7ed..52856b1c2 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -783,15 +783,15 @@ def meth(x): ... @skipUnless(PY36, 'Python 3.6 required') def test_protocols_isinstance_py36(self): class APoint: - def __init__(self, x, y, label): - self.x = x - self.y = y - self.label = label + def __init__(self, x, y, label): + self.x = x + self.y = y + self.label = label class BPoint: - label = 'B' - def __init__(self, x, y): - self.x = x - self.y = y + label = 'B' + def __init__(self, x, y): + self.x = x + self.y = y class C: def __init__(self, attr): self.attr = attr diff --git a/src/typing.py b/src/typing.py index f43585719..bf9291c27 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1332,7 +1332,7 @@ def _collection_protocol(cls): '_abcoll', 'abc')) -class _ProtocolMeta(GenericMeta): +class ProtocolMeta(GenericMeta): """Internal metaclass for Protocol. This exists so Protocol classes can be generic without deriving @@ -1343,7 +1343,7 @@ def __init__(cls, *args, **kwargs): super().__init__(*args, **kwargs) if not cls.__dict__.get('_is_protocol', None): cls._is_protocol = any(b is Protocol or - isinstance(b, _ProtocolMeta) and + isinstance(b, ProtocolMeta) and b.__origin__ is Protocol for b in cls.__bases__) if cls._is_protocol: @@ -1404,7 +1404,7 @@ def _get_protocol_attrs(self): return attrs -class Protocol(metaclass=_ProtocolMeta): +class Protocol(metaclass=ProtocolMeta): """Base class for protocol classes. Protocol classes are defined as:: class Proto(Protocol[T]): @@ -1415,8 +1415,8 @@ def meth(self) -> T: structural subtyping (static duck-typing), for example:: class C: - def meth(self) -> int: - return 0 + def meth(self) -> int: + return 0 def func(x: Proto[int]) -> int: return x.meth() @@ -1446,14 +1446,14 @@ def runtime(cls): This allows a simple-minded structural check very similar to the one-offs in collections.abc such as Hashable. """ - if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: + if not isinstance(cls, ProtocolMeta) or not cls._is_protocol: raise TypeError('@runtime can be only applied to protocol classes,' ' got %r' % cls) cls._is_runtime_protocol = True return cls -class CallableMeta(_ProtocolMeta): +class CallableMeta(ProtocolMeta): """Metaclass for Callable (internal).""" def __repr__(self): From b6e526e23b2d78298e4995bdb4183cab92b6c89c Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 26 Apr 2017 15:25:43 +0200 Subject: [PATCH 16/18] Corner case for PEP 526 subprotocols --- src/test_typing.py | 1 + src/typing.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/test_typing.py b/src/test_typing.py index 52856b1c2..55fadd131 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -754,6 +754,7 @@ class Bad: pass self.assertIsSubclass(Other, Proto) self.assertNotIsSubclass(Concrete, Other) self.assertNotIsSubclass(Other, Concrete) + self.assertIsSubclass(Point, Position) def test_protocols_isinstance(self): T = TypeVar('T') diff --git a/src/typing.py b/src/typing.py index bf9291c27..532b68859 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1369,6 +1369,9 @@ def _proto_hook(other): if base.__dict__[attr] is None: return NotImplemented break + if (attr in getattr(base, '__annotations__', {}) and + isinstance(other, ProtocolMeta) and other._is_protocol): + break else: return NotImplemented return True From d3a5491e6e7349acddba7d3e2b7f75f75a965859 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 13 May 2017 10:39:51 +0200 Subject: [PATCH 17/18] Add test for named tuple with protocol --- src/test_typing.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/test_typing.py b/src/test_typing.py index 55fadd131..1f4c1fb92 100644 --- a/src/test_typing.py +++ b/src/test_typing.py @@ -567,13 +567,17 @@ def meth(self, arg: str) -> int: if arg == 'this': return 1 return 0 + +class NT(NamedTuple): + x: int + y: int """ if PY36: exec(PY36_PROTOCOL_TESTS) else: # fake names for the sake of static analysis - Coordinate = Point = MyPoint = BadPoint = object + Coordinate = Point = MyPoint = BadPoint = NT = object XAxis = YAxis = Position = Proto = Concrete = Other = object @@ -755,6 +759,7 @@ class Bad: pass self.assertNotIsSubclass(Concrete, Other) self.assertNotIsSubclass(Other, Concrete) self.assertIsSubclass(Point, Position) + self.assertIsSubclass(NT, Position) def test_protocols_isinstance(self): T = TypeVar('T') @@ -811,6 +816,7 @@ class Bad: pass self.assertNotIsInstance(Bad(), Position) self.assertNotIsInstance(Bad(), Concrete) self.assertNotIsInstance(Other(), Concrete) + self.assertIsInstance(NT(1, 2), Position) def test_protocols_isinstance_init(self): T = TypeVar('T') From 904c978670102c6e116c096bc3706e290f801237 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 28 Jul 2017 21:29:22 +0200 Subject: [PATCH 18/18] Fix broken merge --- python2/typing.py | 4 ++-- src/typing.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python2/typing.py b/python2/typing.py index 06a6f421d..9abb5d3dd 100644 --- a/python2/typing.py +++ b/python2/typing.py @@ -1479,7 +1479,7 @@ def _get_protocol_attrs(self): '__orig_bases__', '__extra__', '__tree_hash__', '__doc__', '__subclasshook__', '__init__', '__new__', '__module__', '_MutableMapping__marker', - '__metaclass__') and + '__metaclass__', '_gorg') and getattr(base, attr, object()) is not None): attrs.add(attr) return attrs @@ -1517,7 +1517,7 @@ def func(x): _is_protocol = True def __new__(cls, *args, **kwds): - if _geqv(cls, Protocol): + if cls._gorg is Protocol: raise TypeError("Type Protocol cannot be instantiated; " "it can be used only as a base class") return _generic_new(cls.__next_in_mro__, cls, *args, **kwds) diff --git a/src/typing.py b/src/typing.py index ddf1d2f94..2ab9d1fee 100644 --- a/src/typing.py +++ b/src/typing.py @@ -1382,7 +1382,7 @@ def _get_protocol_attrs(self): '__next_in_mro__', '__parameters__', '__origin__', '__orig_bases__', '__extra__', '__tree_hash__', '__doc__', '__subclasshook__', '__init__', '__new__', - '__module__', '_MutableMapping__marker') and + '__module__', '_MutableMapping__marker', '_gorg') and getattr(base, attr, object()) is not None): attrs.add(attr) return attrs @@ -1416,7 +1416,7 @@ def func(x: Proto[int]) -> int: _is_protocol = True def __new__(cls, *args, **kwds): - if _geqv(cls, Protocol): + if cls._gorg is Protocol: raise TypeError("Type Protocol cannot be instantiated; " "it can be used only as a base class") return _generic_new(cls.__next_in_mro__, cls, *args, **kwds)