diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py index a869993a1de3fe2..f4e1ba58866e532 100644 --- a/Lib/asyncio/queues.py +++ b/Lib/asyncio/queues.py @@ -8,6 +8,7 @@ ) import collections +import enum import heapq from types import GenericAlias @@ -30,9 +31,10 @@ class QueueShutDown(Exception): pass -_queue_alive = "alive" -_queue_shutdown = "shutdown" -_queue_shutdown_immediate = "shutdown-immediate" +class _QueueState(enum.Enum): + ALIVE = "alive" + SHUTDOWN = "shutdown" + SHUTDOWN_IMMEDIATE = "shutdown-immediate" class Queue(mixins._LoopBoundMixin): @@ -58,7 +60,7 @@ def __init__(self, maxsize=0): self._finished = locks.Event() self._finished.set() self._init(maxsize) - self.shutdown_state = _queue_alive + self._shutdown_state = _QueueState.ALIVE # These three are overridable in subclasses. @@ -99,6 +101,8 @@ def _format(self): result += f' _putters[{len(self._putters)}]' if self._unfinished_tasks: result += f' tasks={self._unfinished_tasks}' + if self._shutdown_state is not _QueueState.ALIVE: + result += f' shutdown={self._shutdown_state.value}' return result def qsize(self): @@ -131,7 +135,7 @@ async def put(self, item): Put an item into the queue. If the queue is full, wait until a free slot is available before adding item. """ - if self.shutdown_state != _queue_alive: + if self._shutdown_state is not _QueueState.ALIVE: raise QueueShutDown while self.full(): putter = self._get_loop().create_future() @@ -152,7 +156,7 @@ async def put(self, item): # the call. Wake up the next in line. self._wakeup_next(self._putters) raise - if self.shutdown_state != _queue_alive: + if self._shutdown_state is not _QueueState.ALIVE: raise QueueShutDown return self.put_nowait(item) @@ -161,7 +165,7 @@ def put_nowait(self, item): If no free slot is immediately available, raise QueueFull. """ - if self.shutdown_state != _queue_alive: + if self._shutdown_state is not _QueueState.ALIVE: raise QueueShutDown if self.full(): raise QueueFull @@ -175,10 +179,10 @@ async def get(self): If queue is empty, wait until an item is available. """ - if self.shutdown_state == _queue_shutdown_immediate: + if self._shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: raise QueueShutDown while self.empty(): - if self.shutdown_state != _queue_alive: + if self._shutdown_state is not _QueueState.ALIVE: raise QueueShutDown getter = self._get_loop().create_future() self._getters.append(getter) @@ -198,7 +202,7 @@ async def get(self): # the call. Wake up the next in line. self._wakeup_next(self._getters) raise - if self.shutdown_state == _queue_shutdown_immediate: + if self._shutdown_state is not _QueueState.ALIVE: raise QueueShutDown return self.get_nowait() @@ -208,10 +212,10 @@ def get_nowait(self): Return an item if one is immediately available, else raise QueueEmpty. """ if self.empty(): - if self.shutdown_state != _queue_alive: + if self._shutdown_state is not _QueueState.ALIVE: raise QueueShutDown raise QueueEmpty - elif self.shutdown_state == _queue_shutdown_immediate: + elif self._shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: raise QueueShutDown item = self._get() self._wakeup_next(self._putters) @@ -257,18 +261,24 @@ def shutdown(self, immediate=False): All blocked callers of put() will be unblocked, and also get() and join() if 'immediate'. The QueueShutDown exception is raised. """ + if self._shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: + return + if immediate: - self.shutdown_state = _queue_shutdown_immediate + self._shutdown_state = _QueueState.SHUTDOWN_IMMEDIATE while self._getters: getter = self._getters.popleft() if not getter.done(): getter.set_result(None) else: - self.shutdown_state = _queue_shutdown + self._shutdown_state = _QueueState.SHUTDOWN while self._putters: putter = self._putters.popleft() if not putter.done(): putter.set_result(None) + # Release 'blocked' tasks/coros via `.join()` + self._finished.set() + class PriorityQueue(Queue): """A subclass of Queue; retrieves entries in priority order (lowest first). diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index 26b212d6a50610f..bfaf81e2dc66163 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -65,11 +65,13 @@ def __init__(self, maxsize=0, *, ctx): def __getstate__(self): context.assert_spawning(self) return (self._ignore_epipe, self._maxsize, self._reader, self._writer, - self._rlock, self._wlock, self._sem, self._opid) + self._rlock, self._wlock, self._sem, self._opid, + self._shutdown_state) def __setstate__(self, state): (self._ignore_epipe, self._maxsize, self._reader, self._writer, - self._rlock, self._wlock, self._sem, self._opid) = state + self._rlock, self._wlock, self._sem, self._opid, + self._shutdown_state) = state self._reset() def _after_fork(self): @@ -159,6 +161,9 @@ def get_nowait(self): def put_nowait(self, obj): return self.put(obj, False) + def shutdown(self, immediate=True): + pass + def close(self): self._closed = True close = self._close diff --git a/Lib/queue.py b/Lib/queue.py index f6af7cb6df5cffc..933f18062050726 100644 --- a/Lib/queue.py +++ b/Lib/queue.py @@ -1,5 +1,6 @@ '''A multi-producer, multi-consumer queue.''' +import enum import threading import types from collections import deque @@ -29,9 +30,10 @@ class ShutDown(Exception): '''Raised when put/get with shut-down queue.''' -_queue_alive = "alive" -_queue_shutdown = "shutdown" -_queue_shutdown_immediate = "shutdown-immediate" +class _QueueState(enum.Enum): + ALIVE = "alive" + SHUTDOWN = "shutdown" + SHUTDOWN_IMMEDIATE = "shutdown-immediate" class Queue: @@ -64,7 +66,7 @@ def __init__(self, maxsize=0): self.unfinished_tasks = 0 # Queue shut-down state - self.shutdown_state = _queue_alive + self.shutdown_state = _QueueState.ALIVE def task_done(self): '''Indicate that a formerly enqueued task is complete. @@ -99,7 +101,7 @@ def join(self): ''' with self.all_tasks_done: while self.unfinished_tasks: - if self.shutdown_state == _queue_shutdown_immediate: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: return self.all_tasks_done.wait() @@ -144,7 +146,7 @@ def put(self, item, block=True, timeout=None): is immediately available, else raise the Full exception ('timeout' is ignored in that case). ''' - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown with self.not_full: if self.maxsize > 0: @@ -154,7 +156,7 @@ def put(self, item, block=True, timeout=None): elif timeout is None: while self._qsize() >= self.maxsize: self.not_full.wait() - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") @@ -165,7 +167,7 @@ def put(self, item, block=True, timeout=None): if remaining <= 0.0: raise Full self.not_full.wait(remaining) - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown self._put(item) self.unfinished_tasks += 1 @@ -182,35 +184,35 @@ def get(self, block=True, timeout=None): available, else raise the Empty exception ('timeout' is ignored in that case). ''' - if self.shutdown_state == _queue_shutdown_immediate: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: raise ShutDown with self.not_empty: if not block: if not self._qsize(): - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown raise Empty elif timeout is None: while not self._qsize(): - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown self.not_empty.wait() - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: endtime = time() + timeout while not self._qsize(): - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown remaining = endtime - time() if remaining <= 0.0: raise Empty self.not_empty.wait(remaining) - if self.shutdown_state != _queue_alive: + if self.shutdown_state is not _QueueState.ALIVE: raise ShutDown - if self.shutdown_state == _queue_shutdown_immediate: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: raise ShutDown item = self._get() self.not_full.notify() @@ -242,12 +244,19 @@ def shutdown(self, immediate=False): and join() if 'immediate'. The ShutDown exception is raised. ''' with self.mutex: + if self.shutdown_state is _QueueState.SHUTDOWN_IMMEDIATE: + return + if immediate: - self.shutdown_state = _queue_shutdown_immediate + self.shutdown_state = _QueueState.SHUTDOWN_IMMEDIATE self.not_empty.notify_all() + # set self.unfinished_tasks to 0 + # to break the loop in 'self.join()' + # when quits from `wait()` + self.unfinished_tasks = 0 self.all_tasks_done.notify_all() else: - self.shutdown_state = _queue_shutdown + self.shutdown_state = _QueueState.SHUTDOWN self.not_full.notify_all() # Override these methods to implement other queue organizations diff --git a/Lib/test/test_asyncio/test_queues.py b/Lib/test/test_asyncio/test_queues.py index 418c3fe618d89b8..f2c329c94b79973 100644 --- a/Lib/test/test_asyncio/test_queues.py +++ b/Lib/test/test_asyncio/test_queues.py @@ -528,37 +528,167 @@ class _QueueShutdownTestMixin: async def test_empty(self): q = self.q_class() q.shutdown() - try: + with self.assertRaises(asyncio.QueueShutDown): await q.put("data") - self.fail("Didn't appear to shut-down queue") - except asyncio.QueueShutDown: - pass - try: + with self.assertRaises(asyncio.QueueShutDown): await q.get() - self.fail("Didn't appear to shut-down queue") - except asyncio.QueueShutDown: - pass async def test_nonempty(self): q = self.q_class() q.put_nowait("data") q.shutdown() await q.get() - try: + with self.assertRaises(asyncio.QueueShutDown): await q.get() - self.fail("Didn't appear to shut-down queue") - except asyncio.QueueShutDown: - pass async def test_immediate(self): q = self.q_class() q.put_nowait("data") q.shutdown(immediate=True) - try: + with self.assertRaises(asyncio.QueueShutDown): await q.get() - self.fail("Didn't appear to shut-down queue") - except asyncio.QueueShutDown: - pass + + async def test_shutdown_repr(self): + q = self.q_class() + q.shutdown() + self.assertIn("shutdown", repr(q)) + + q = self.q_class() + q.shutdown(immediate=True) + self.assertIn("shutdown-immediate", repr(q)) + + async def test_shutdown_transition(self): + # allowed transitions would be from alive via shutdown to immediate + q = self.q_class() + self.assertEqual("alive", q._shutdown_state.value) + + q.shutdown() + self.assertEqual("shutdown", q._shutdown_state.value) + + q.shutdown(immediate=True) + self.assertEqual("shutdown-immediate", q._shutdown_state.value) + + q.shutdown() + self.assertEqual("shutdown-immediate", q._shutdown_state.value) + + async def test_shutdown_immediate_get(self): + q = self.q_class() + results = [] + go = asyncio.Event() + + async def get_once(q, go): + await go.wait() + try: + msg = await q.get() + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return True + + async def shutdown(q, go, immediate): + q.shutdown(immediate) + go.set() + return True + + tasks = ( + (get_once, (q, go)), + (get_once, (q, go)), + ) + t = [] + for coro, params in tasks: + t.append(asyncio.create_task(coro(*params))) + t.append(asyncio.create_task(shutdown(q, go, True))) + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*len(tasks)) + + + async def _shutdown_put(self, immediate): + q = self.q_class(2) + results = [] + go = asyncio.Event() + await q.put("Y") + await q.put("D") + # queue fulled + + async def put_once(q, go, msg): + await go.wait() + try: + await q.put(msg) + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return msg + + async def shutdown(q, go, immediate): + q.shutdown(immediate) + go.set() + + tasks = ( + (put_once, (q, go, 100)), + (put_once, (q, go, 200)), + ) + t = [] + for coro, params in tasks: + t.append(asyncio.create_task(coro(*params))) + t.append(asyncio.create_task(shutdown(q, go, immediate))) + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*len(tasks)) + + async def test_shutdown_put(self): + return await self._shutdown_put(False) + + async def test_shutdown_immediate_put(self): + return await self._shutdown_put(True) + + + async def _shutdown_put_and_join(self, immediate): + q = self.q_class(2) + results = [] + go = asyncio.Event() + await q.put("Y") + await q.put("D") + # queue fulled + + async def put_once(q, go, msg): + await go.wait() + try: + await q.put(msg) + results.append(False) + except asyncio.QueueShutDown: + results.append(True) + return msg + + async def shutdown(q, go, immediate): + q.shutdown(immediate) + go.set() + + async def join(q, go): + await go.wait() + await q.join() + results.append(True) + return True + + tasks = ( + (put_once, (q, go, 'E')), + (put_once, (q, go, 'W')), + (join, (q, go)), + (join, (q, go)), + ) + t = [] + for coro, params in tasks: + t.append(asyncio.create_task(coro(*params))) + t.append(asyncio.create_task(shutdown(q, go, immediate))) + res = await asyncio.gather(*t) + + self.assertEqual(results, [True]*len(tasks)) + + async def test_shutdown_put_and_join(self): + return await self._shutdown_put_and_join(False) + + async def test_shutdown_immediate_put_and_join(self): + return await self._shutdown_put_and_join(True) class QueueShutdownTests( diff --git a/Lib/test/test_queue.py b/Lib/test/test_queue.py index 354299b9a5b16a6..a2814ac40f2e181 100644 --- a/Lib/test/test_queue.py +++ b/Lib/test/test_queue.py @@ -244,37 +244,230 @@ def test_shrinking_queue(self): def test_shutdown_empty(self): q = self.type2test() q.shutdown() - try: + with self.assertRaises(self.queue.ShutDown): q.put("data") - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass - try: + with self.assertRaises(self.queue.ShutDown): q.get() - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass def test_shutdown_nonempty(self): q = self.type2test() q.put("data") q.shutdown() q.get() - try: + with self.assertRaises(self.queue.ShutDown): q.get() - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass def test_shutdown_immediate(self): q = self.type2test() q.put("data") q.shutdown(immediate=True) - try: + with self.assertRaises(self.queue.ShutDown): q.get() - self.fail("Didn't appear to shut-down queue") - except self.queue.ShutDown: - pass + + def test_shutdown_transition(self): + # allowed transitions would be from alive via shutdown to immediate + q = self.type2test() + self.assertEqual("alive", q.shutdown_state.value) + + q.shutdown() + self.assertEqual("shutdown", q.shutdown_state.value) + + q.shutdown(immediate=True) + self.assertEqual("shutdown-immediate", q.shutdown_state.value) + + q.shutdown(immediate=False) + self.assertEqual("shutdown-immediate", q.shutdown_state.value) + + def test_shutdown_get(self): + q = self.type2test(2) + results = [] + go = threading.Event() + + def get_once(q, go): + go.wait() + try: + msg = q.get() + results.append(False) + except self.queue.ShutDown: + results.append(True) + return True + + thrds = ( + (get_once, (q, go)), + (get_once, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + q.shutdown() + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_put(self): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + # queue fulled + + def put_once(q, msg, go): + go.wait() + try: + q.put(msg) + results.append(False) + except self.queue.ShutDown: + results.append(True) + return msg + + thrds = ( + (put_once, (q, 100, go)), + (put_once, (q, 200, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + q.shutdown() + go.set() + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def _shutdown_join(self, immediate): + q = self.type2test() + results = [] + go = threading.Event() + + def join(q, go): + go.wait() + q.join() + results.append(True) + + thrds = ( + (join, (q, go)), + (join, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + go.set() + q.shutdown(immediate) + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_join(self): + return self._shutdown_join(True) + + def test_shutdown_join(self): + return self._shutdown_join(False) + + def _shutdown_put_and_join(self, immediate): + q = self.type2test(2) + results = [] + go = threading.Event() + q.put("Y") + q.put("D") + # queue fulled + + def put_once(q, msg, go): + go.wait() + try: + q.put(msg) + results.append(False) + except self.queue.ShutDown: + results.append(True) + return msg + + def join(q, go): + go.wait() + q.join() + results.append(True) + + thrds = ( + (put_once, (q, 100, go)), + (put_once, (q, 200, go)), + (join, (q, go)), + (join, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + go.set() + q.shutdown(immediate) + if not immediate: + self.assertTrue(q.unfinished_tasks, 2) + for i in range(2): + thread = threading.Thread(target=q.task_done) + thread.start() + threads.append(thread) + + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_put_and_join(self): + return self._shutdown_put_and_join(True) + + def test_shutdown_put_and_join(self): + return self._shutdown_put_and_join(False) + + def _shutdown_get_and_join(self, immediate): + q = self.type2test() + results = [] + go = threading.Event() + + def get_once(q, go): + go.wait() + try: + msg = q.get() + results.append(False) + except self.queue.ShutDown: + results.append(True) + return True + + def join(q, go): + go.wait() + q.join() + results.append(True) + + thrds = ( + (get_once, (q, go)), + (get_once, (q, go)), + (join, (q, go)), + (join, (q, go)), + ) + threads = [] + for f, params in thrds: + thread = threading.Thread(target=f, args=params) + thread.start() + threads.append(thread) + go.set() + q.shutdown(immediate) + for t in threads: + t.join() + + self.assertEqual(results, [True]*len(thrds)) + + def test_shutdown_immediate_get_and_join(self): + return self._shutdown_get_and_join(True) + + def test__shutdown_get_and_join(self): + return self._shutdown_get_and_join(False) class QueueTest(BaseQueueTestMixin):