From 5ff092e8733f5dee67bb78013ef531e89a86bec0 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 11 May 2024 18:53:49 +0100 Subject: [PATCH 1/3] Backport parameter defaults for `(Async)Generator` and `(Async)ContextManager` --- CHANGELOG.md | 9 ++++ doc/index.rst | 23 ++++++++++- src/test_typing_extensions.py | 51 +++++++++++++++++------ src/typing_extensions.py | 78 +++++++++++++++++++++++++++++++++-- 4 files changed, 142 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bbad4264..5f1b8b43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,15 @@ - At runtime, `assert_never` now includes the repr of the argument in the `AssertionError`. Patch by Hashem, backporting of the original fix https://github.com/python/cpython/pull/91720 by Jelle Zijlstra. +- The second and third parameters of `typing_extensions.Generator`, + and the second parameter of `typing_extensions.AsyncGenerator`, + now default to `None`. This matches the behaviour of `typing.Generator` + and `typing.AsyncGenerator` on Python 3.13+. +- `typing.ContextManager` and `typing.AsyncContextManager` now have an + optional second parameter, which defaults to `Optional[bool]`. The new + parameter signifies the return type of the `__(a)exit__` method, + matching `typing.ContextManager` and `typing.AsyncContextManager` on + Python 3.13+. # Release 4.11.0 (April 5, 2024) diff --git a/doc/index.rst b/doc/index.rst index 3486ae74..0b4c4bda 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -885,8 +885,8 @@ Annotation metadata Pure aliases ~~~~~~~~~~~~ -These are simply re-exported from the :mod:`typing` module on all supported -versions of Python. They are listed here for completeness. +Most of these are simply re-exported from the :mod:`typing` module on all supported +versions of Python, but all are listed here for completeness. .. class:: AbstractSet @@ -904,10 +904,19 @@ versions of Python. They are listed here for completeness. See :py:class:`typing.AsyncContextManager`. In ``typing`` since 3.5.4 and 3.6.2. + .. versionchanged:: 4.12.0 + + ``AsyncContextManager`` now has an optional second parameter, defaulting to + ``Optional[bool]``, signifying the return type of the ``__aexit__`` method. + .. class:: AsyncGenerator See :py:class:`typing.AsyncGenerator`. In ``typing`` since 3.6.1. + .. versionchanged:: 4.12.0 + + The second type parameter is now optional (it defaults to ``None``). + .. class:: AsyncIterable See :py:class:`typing.AsyncIterable`. In ``typing`` since 3.5.2. @@ -956,6 +965,11 @@ versions of Python. They are listed here for completeness. See :py:class:`typing.ContextManager`. In ``typing`` since 3.5.4. + .. versionchanged:: 4.12.0 + + ``AsyncContextManager`` now has an optional second parameter, defaulting to + ``Optional[bool]``, signifying the return type of the ``__aexit__`` method. + .. class:: Coroutine See :py:class:`typing.Coroutine`. In ``typing`` since 3.5.3. @@ -996,6 +1010,11 @@ versions of Python. They are listed here for completeness. .. versionadded:: 4.7.0 + .. versionchanged:: 4.12.0 + + The second type and third type parameters are now optional + (they both default to ``None``). + .. class:: Generic See :py:class:`typing.Generic`. diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index c7c2f0d5..ad353735 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -41,6 +41,8 @@ from typing_extensions import Doc, NoDefault from _typed_dict_test_helper import Foo, FooGeneric, VeryAnnotated +NoneType = type(None) + # Flags used to mark tests that only apply after a specific # version of the typing module. TYPING_3_9_0 = sys.version_info[:3] >= (3, 9, 0) @@ -1626,6 +1628,17 @@ async def g(): yield 0 self.assertNotIsInstance(type(g), G) self.assertNotIsInstance(g, G) + def test_generator_default(self): + g1 = typing_extensions.Generator[int] + g2 = typing_extensions.Generator[int, None, None] + self.assertEqual(get_args(g1), (int, type(None), type(None))) + self.assertEqual(get_args(g1), get_args(g2)) + + g3 = typing_extensions.Generator[int, float] + g4 = typing_extensions.Generator[int, float, None] + self.assertEqual(get_args(g3), (int, float, type(None))) + self.assertEqual(get_args(g3), get_args(g4)) + class OtherABCTests(BaseTestCase): @@ -1638,6 +1651,12 @@ def manager(): self.assertIsInstance(cm, typing_extensions.ContextManager) self.assertNotIsInstance(42, typing_extensions.ContextManager) + def test_contextmanager_type_params(self): + cm1 = typing_extensions.ContextManager[int] + self.assertEqual(get_args(cm1), (int, typing.Optional[bool])) + cm2 = typing_extensions.ContextManager[int, None] + self.assertEqual(get_args(cm2), (int, NoneType)) + def test_async_contextmanager(self): class NotACM: pass @@ -1649,11 +1668,20 @@ def manager(): cm = manager() self.assertNotIsInstance(cm, typing_extensions.AsyncContextManager) - self.assertEqual(typing_extensions.AsyncContextManager[int].__args__, (int,)) + self.assertEqual( + typing_extensions.AsyncContextManager[int].__args__, + (int, typing.Optional[bool]) + ) with self.assertRaises(TypeError): isinstance(42, typing_extensions.AsyncContextManager[int]) with self.assertRaises(TypeError): - typing_extensions.AsyncContextManager[int, str] + typing_extensions.AsyncContextManager[int, str, float] + + def test_asynccontextmanager_type_params(self): + cm1 = typing_extensions.AsyncContextManager[int] + self.assertEqual(get_args(cm1), (int, typing.Optional[bool])) + cm2 = typing_extensions.AsyncContextManager[int, None] + self.assertEqual(get_args(cm2), (int, NoneType)) class TypeTests(BaseTestCase): @@ -5533,28 +5561,25 @@ def test_all_names_in___all__(self): self.assertLessEqual(exclude, actual_names) def test_typing_extensions_defers_when_possible(self): - exclude = { - 'dataclass_transform', - 'overload', - 'ParamSpec', - 'TypeVar', - 'TypeVarTuple', - 'get_type_hints', - } + exclude = set() if sys.version_info < (3, 10): exclude |= {'get_args', 'get_origin'} if sys.version_info < (3, 10, 1): exclude |= {"Literal"} if sys.version_info < (3, 11): - exclude |= {'final', 'Any', 'NewType'} + exclude |= {'final', 'Any', 'NewType', 'overload'} if sys.version_info < (3, 12): exclude |= { 'SupportsAbs', 'SupportsBytes', 'SupportsComplex', 'SupportsFloat', 'SupportsIndex', 'SupportsInt', - 'SupportsRound', 'Unpack', + 'SupportsRound', 'Unpack', 'dataclass_transform', } if sys.version_info < (3, 13): - exclude |= {'NamedTuple', 'Protocol', 'runtime_checkable'} + exclude |= { + 'NamedTuple', 'Protocol', 'runtime_checkable', 'Generator', + 'AsyncGenerator', 'ContextManager', 'AsyncContextManager', + 'ParamSpec', 'TypeVar', 'TypeVarTuple', 'get_type_hints', + } if not typing_extensions._PEP_728_IMPLEMENTED: exclude |= {'TypedDict', 'is_typeddict'} for item in typing_extensions.__all__: diff --git a/src/typing_extensions.py b/src/typing_extensions.py index ec145c0a..6a197c06 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -1,6 +1,7 @@ import abc import collections import collections.abc +import contextlib import functools import inspect import operator @@ -408,17 +409,87 @@ def clear_overloads(): AsyncIterable = typing.AsyncIterable AsyncIterator = typing.AsyncIterator Deque = typing.Deque -ContextManager = typing.ContextManager -AsyncContextManager = typing.AsyncContextManager DefaultDict = typing.DefaultDict OrderedDict = typing.OrderedDict Counter = typing.Counter ChainMap = typing.ChainMap -AsyncGenerator = typing.AsyncGenerator Text = typing.Text TYPE_CHECKING = typing.TYPE_CHECKING +if sys.version_info >= (3, 13, 0, "beta"): + from typing import ContextManager, AsyncContextManager, Generator, AsyncGenerator +else: + def _is_dunder(attr): + return attr.startswith('__') and attr.endswith('__') + + _special_generic_alias_base = getattr(typing, "_SpecialGenericAlias", typing._GenericAlias) + + class _SpecialGenericAlias(_special_generic_alias_base, _root=True): + def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()): + if _special_generic_alias_base is typing._GenericAlias: + self.__origin__ = origin + self._nparams = nparams + super().__init__(origin, nparams, special=True, inst=inst, name=name) + else: + super().__init__(origin, nparams, inst=inst, name=name) + self._defaults = defaults + + def __setattr__(self, attr, val): + allowed_attrs = {'_name', '_inst', '_nparams', '_defaults'} + if _special_generic_alias_base is typing._GenericAlias: + allowed_attrs.add("__origin__") + if _is_dunder(attr) or attr in allowed_attrs: + object.__setattr__(self, attr, val) + else: + setattr(self.__origin__, attr, val) + + @typing._tp_cache + def __getitem__(self, params): + if not isinstance(params, tuple): + params = (params,) + msg = "Parameters to generic types must be types." + params = tuple(typing._type_check(p, msg) for p in params) + if (self._defaults + and len(params) < self._nparams + and len(params) + len(self._defaults) >= self._nparams + ): + params = (*params, *self._defaults[len(params) - self._nparams:]) + actual_len = len(params) + + if actual_len != self._nparams: + if self._defaults: + expected = f"at least {self._nparams - len(self._defaults)}" + else: + expected = str(self._nparams) + if not self._nparams: + raise TypeError(f"{self} is not a generic class") + raise TypeError(f"Too {'many' if actual_len > self._nparams else 'few'} arguments for {self};" + f" actual {actual_len}, expected {expected}") + return self.copy_with(params) + + + _NoneType = type(None) + Generator = _SpecialGenericAlias( + collections.abc.Generator, 3, defaults=(_NoneType, _NoneType) + ) + AsyncGenerator = _SpecialGenericAlias( + collections.abc.AsyncGenerator, 2, defaults=(_NoneType,) + ) + ContextManager = _SpecialGenericAlias( + contextlib.AbstractContextManager, + 2, + name="ContextManager", + defaults=(typing.Optional[bool],) + ) + AsyncContextManager = _SpecialGenericAlias( + contextlib.AbstractAsyncContextManager, + 2, + name="AsyncContextManager", + defaults=(typing.Optional[bool],) + ) + + _PROTO_ALLOWLIST = { 'collections.abc': [ 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', @@ -3344,7 +3415,6 @@ def __eq__(self, other: object) -> bool: Dict = typing.Dict ForwardRef = typing.ForwardRef FrozenSet = typing.FrozenSet -Generator = typing.Generator Generic = typing.Generic Hashable = typing.Hashable IO = typing.IO From faf949e9af441bfc7730b87ccbf7184c08cfad97 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 11 May 2024 18:59:06 +0100 Subject: [PATCH 2/3] lint --- src/typing_extensions.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 6a197c06..a85f8d39 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -423,7 +423,9 @@ def clear_overloads(): def _is_dunder(attr): return attr.startswith('__') and attr.endswith('__') - _special_generic_alias_base = getattr(typing, "_SpecialGenericAlias", typing._GenericAlias) + _special_generic_alias_base = getattr( + typing, "_SpecialGenericAlias", typing._GenericAlias + ) class _SpecialGenericAlias(_special_generic_alias_base, _root=True): def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()): @@ -450,7 +452,8 @@ def __getitem__(self, params): params = (params,) msg = "Parameters to generic types must be types." params = tuple(typing._type_check(p, msg) for p in params) - if (self._defaults + if ( + self._defaults and len(params) < self._nparams and len(params) + len(self._defaults) >= self._nparams ): @@ -464,11 +467,13 @@ def __getitem__(self, params): expected = str(self._nparams) if not self._nparams: raise TypeError(f"{self} is not a generic class") - raise TypeError(f"Too {'many' if actual_len > self._nparams else 'few'} arguments for {self};" - f" actual {actual_len}, expected {expected}") + raise TypeError( + f"Too {'many' if actual_len > self._nparams else 'few'}" + f" arguments for {self};" + f" actual {actual_len}, expected {expected}" + ) return self.copy_with(params) - _NoneType = type(None) Generator = _SpecialGenericAlias( collections.abc.Generator, 3, defaults=(_NoneType, _NoneType) From 6578c849708404316cb5c0443d3394a1e65ab955 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 11 May 2024 19:16:13 +0100 Subject: [PATCH 3/3] add some clarifying comments --- src/typing_extensions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/typing_extensions.py b/src/typing_extensions.py index a85f8d39..b4ca1bc2 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -423,6 +423,7 @@ def clear_overloads(): def _is_dunder(attr): return attr.startswith('__') and attr.endswith('__') + # Python <3.9 doesn't have typing._SpecialGenericAlias _special_generic_alias_base = getattr( typing, "_SpecialGenericAlias", typing._GenericAlias ) @@ -430,16 +431,19 @@ def _is_dunder(attr): class _SpecialGenericAlias(_special_generic_alias_base, _root=True): def __init__(self, origin, nparams, *, inst=True, name=None, defaults=()): if _special_generic_alias_base is typing._GenericAlias: + # Python <3.9 self.__origin__ = origin self._nparams = nparams super().__init__(origin, nparams, special=True, inst=inst, name=name) else: + # Python >= 3.9 super().__init__(origin, nparams, inst=inst, name=name) self._defaults = defaults def __setattr__(self, attr, val): allowed_attrs = {'_name', '_inst', '_nparams', '_defaults'} if _special_generic_alias_base is typing._GenericAlias: + # Python <3.9 allowed_attrs.add("__origin__") if _is_dunder(attr) or attr in allowed_attrs: object.__setattr__(self, attr, val)