diff --git a/Lib/operator.py b/Lib/operator.py index fb58851fa6ef67..f18cc8630963b1 100644 --- a/Lib/operator.py +++ b/Lib/operator.py @@ -10,14 +10,15 @@ This is the pure Python implementation of the module. """ -__all__ = ['abs', 'add', 'and_', 'attrgetter', 'concat', 'contains', 'countOf', - 'delitem', 'eq', 'floordiv', 'ge', 'getitem', 'gt', 'iadd', 'iand', - 'iconcat', 'ifloordiv', 'ilshift', 'imatmul', 'imod', 'imul', - 'index', 'indexOf', 'inv', 'invert', 'ior', 'ipow', 'irshift', - 'is_', 'is_not', 'isub', 'itemgetter', 'itruediv', 'ixor', 'le', - 'length_hint', 'lshift', 'lt', 'matmul', 'methodcaller', 'mod', - 'mul', 'ne', 'neg', 'not_', 'or_', 'pos', 'pow', 'rshift', - 'setitem', 'sub', 'truediv', 'truth', 'xor'] +__all__ = [ + 'abs', 'add', 'aiter', 'anext', 'and_', 'attrgetter', 'concat', 'contains', + 'countOf', 'delitem', 'eq', 'floordiv', 'ge', 'getitem', 'gt', 'iadd', + 'iand', 'iconcat', 'ifloordiv', 'ilshift', 'imatmul', 'imod', 'imul', + 'index', 'indexOf', 'inv', 'invert', 'ior', 'ipow', 'irshift', 'is_', + 'is_not', 'isub', 'itemgetter', 'itruediv', 'ixor', 'le', 'length_hint', + 'lshift', 'lt', 'matmul', 'methodcaller', 'mod', 'mul', 'ne', 'neg', 'not_', + 'or_', 'pos', 'pow', 'rshift', 'setitem', 'sub', 'truediv', 'truth', 'xor', +] from builtins import abs as _abs @@ -404,6 +405,59 @@ def ixor(a, b): return a +# Asynchronous Iterator Operations ********************************************# + + +_NOT_PROVIDED = object() # sentinel object to detect when a kwarg was not given + + +def aiter(obj, sentinel=_NOT_PROVIDED): + """aiter(async_iterable) -> async_iterator + aiter(async_callable, sentinel) -> async_iterator + + Like the iter() builtin but for async iterables and callables. + """ + from collections.abc import AsyncIterable, AsyncIterator + if sentinel is _NOT_PROVIDED: + if not isinstance(obj, AsyncIterable): + raise TypeError(f'aiter expected an AsyncIterable, got {type(obj)}') + ait = type(obj).__aiter__(obj) + if not isinstance(ait, AsyncIterator): + raise TypeError(f'obj.__aiter__() returned non-AsyncIterator: {type(ait)}') + return ait + + if not callable(obj): + raise TypeError(f'aiter expected an async callable, got {type(obj)}') + + async def ait(): + while True: + value = await obj() + if value == sentinel: + break + yield value + + return ait() + + +async def anext(async_iterator, default=_NOT_PROVIDED): + """anext(async_iterator[, default]) + + Return the next item from the async iterator. + If default is given and the iterator is exhausted, + it is returned instead of raising StopAsyncIteration. + """ + from collections.abc import AsyncIterator + if not isinstance(async_iterator, AsyncIterator): + raise TypeError(f'anext expected an AsyncIterator, got {type(async_iterator)}') + anxt = type(async_iterator).__anext__ + try: + return await anxt(async_iterator) + except StopAsyncIteration: + if default is _NOT_PROVIDED: + raise + return default + + try: from _operator import * except ImportError: diff --git a/Lib/test/test_asyncgen.py b/Lib/test/test_asyncgen.py index 1f7e05b42be99c..5c2dad836ccb19 100644 --- a/Lib/test/test_asyncgen.py +++ b/Lib/test/test_asyncgen.py @@ -1,4 +1,5 @@ import inspect +import operator import types import unittest @@ -372,6 +373,86 @@ def tearDown(self): self.loop = None asyncio.set_event_loop_policy(None) + def test_async_gen_operator_anext(self): + async def gen(): + yield 1 + yield 2 + g = gen() + async def consume(): + results = [] + results.append(await operator.anext(g)) + results.append(await operator.anext(g)) + results.append(await operator.anext(g, 'buckle my shoe')) + return results + res = self.loop.run_until_complete(consume()) + self.assertEqual(res, [1, 2, 'buckle my shoe']) + with self.assertRaises(StopAsyncIteration): + self.loop.run_until_complete(consume()) + + def test_async_gen_operator_aiter(self): + async def gen(): + yield 1 + yield 2 + g = gen() + async def consume(): + return [i async for i in operator.aiter(g)] + res = self.loop.run_until_complete(consume()) + self.assertEqual(res, [1, 2]) + + def test_async_gen_operator_aiter_class(self): + results = [] + loop = self.loop + class Gen: + async def __aiter__(self): + yield 1 + await asyncio.sleep(0.01) + yield 2 + g = Gen() + async def consume(): + ait = operator.aiter(g) + while True: + try: + results.append(await operator.anext(ait)) + except StopAsyncIteration: + break + self.loop.run_until_complete(consume()) + self.assertEqual(results, [1, 2]) + + def test_async_gen_operator_aiter_2_arg(self): + async def gen(): + yield 1 + yield 2 + yield None + g = gen() + async def foo(): + return await operator.anext(g) + async def consume(): + return [i async for i in operator.aiter(foo, None)] + res = self.loop.run_until_complete(consume()) + self.assertEqual(res, [1, 2]) + + def test_operator_anext_bad_args(self): + self._test_bad_args(operator.anext) + + def test_operator_aiter_bad_args(self): + self._test_bad_args(operator.aiter) + + def _test_bad_args(self, afn): + async def gen(): + yield 1 + async def call_with_no_args(): + await afn() + async def call_with_3_args(): + await afn(gen(), 1, 2) + async def call_with_bad_args(): + await afn(1, gen()) + with self.assertRaises(TypeError): + self.loop.run_until_complete(call_with_no_args()) + with self.assertRaises(TypeError): + self.loop.run_until_complete(call_with_3_args()) + with self.assertRaises(TypeError): + self.loop.run_until_complete(call_with_bad_args()) + async def to_list(self, gen): res = [] async for i in gen: diff --git a/Misc/NEWS.d/next/Library/2018-08-24-01-08-09.bpo-31861.-q9RKJ.rst b/Misc/NEWS.d/next/Library/2018-08-24-01-08-09.bpo-31861.-q9RKJ.rst new file mode 100644 index 00000000000000..f218a1aa2af71e --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-08-24-01-08-09.bpo-31861.-q9RKJ.rst @@ -0,0 +1 @@ +Add the operator.aiter and operator.anext functions. Patch by Josh Bronson.