From e1caf1d4d758aff78a98b99d0addb86c975e29e9 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Thu, 13 Apr 2023 14:44:30 +0100 Subject: [PATCH] Backport the ability to define `__init__` methods on Protocol classes --- CHANGELOG.md | 4 ++++ src/test_typing_extensions.py | 26 ++++++++++++++++++++++++++ src/typing_extensions.py | 3 ++- 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d22b4049..7fa13198 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,10 @@ `typing_extensions` may no longer be considered instances of that protocol using the new release, and vice versa. Most users are unlikely to be affected by this change. Patch by Alex Waygood. +- Backport the ability to define `__init__` methods on Protocol classes, a + change made in Python 3.11 (originally implemented in + https://github.com/python/cpython/pull/31628 by Adrian Garcia Badaracco). + Patch by Alex Waygood. - Speedup `isinstance(3, typing_extensions.SupportsIndex)` by >10x on Python <3.12. Patch by Alex Waygood. diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 0e4a1b7e..19a116fd 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -1454,6 +1454,32 @@ class PG(Protocol[T]): pass class CG(PG[T]): pass self.assertIsInstance(CG[int](), CG) + def test_protocol_defining_init_does_not_get_overridden(self): + # check that P.__init__ doesn't get clobbered + # see https://bugs.python.org/issue44807 + + class P(Protocol): + x: int + def __init__(self, x: int) -> None: + self.x = x + class C: pass + + c = C() + P.__init__(c, 1) + self.assertEqual(c.x, 1) + + def test_concrete_class_inheriting_init_from_protocol(self): + class P(Protocol): + x: int + def __init__(self, x: int) -> None: + self.x = x + + class C(P): pass + + c = C(1) + self.assertIsInstance(c, C) + self.assertEqual(c.x, 1) + def test_cannot_instantiate_abstract(self): @runtime_checkable class P(Protocol): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index b7864a98..16a8fdd3 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -662,7 +662,8 @@ def _proto_hook(other): isinstance(base, _ProtocolMeta) and base._is_protocol): raise TypeError('Protocols can only inherit from other' f' protocols, got {repr(base)}') - cls.__init__ = _no_init + if cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init def runtime_checkable(cls): """Mark a protocol class as a runtime protocol, so that it