From 5dcd9c2a9631ceabb16baa332b53f8d469034b80 Mon Sep 17 00:00:00 2001 From: "Laura F. Dickinson" Date: Sat, 28 Jul 2018 12:04:10 +0100 Subject: [PATCH 1/5] Add the ability to close a queue. --- newsfragments/573.feature.rst | 2 ++ trio/_sync.py | 34 ++++++++++++++++++++++++++++++++++ trio/tests/test_sync.py | 18 +++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 newsfragments/573.feature.rst diff --git a/newsfragments/573.feature.rst b/newsfragments/573.feature.rst new file mode 100644 index 0000000000..e3f6b9e9fa --- /dev/null +++ b/newsfragments/573.feature.rst @@ -0,0 +1,2 @@ +Add the ability to close a :class:`trio.Queue`, cancelling all waiting getters and putters, and +preventing anyone else from getting or putting onto it. \ No newline at end of file diff --git a/trio/_sync.py b/trio/_sync.py index fe806dd34a..abcb96860a 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -15,6 +15,7 @@ "StrictFIFOLock", "Condition", "Queue", + "QueueClosed", ] @@ -802,6 +803,11 @@ class _QueueStats: tasks_waiting_get = attr.ib() +class QueueClosed(Exception): + """Raised on waiters for the queue when a queue is closed. + """ + + # Like queue.Queue, with the notable difference that the capacity argument is # mandatory. class Queue: @@ -842,6 +848,8 @@ def __init__(self, capacity): # if len(self._data) < self.capacity, then self._put_wait is empty # if len(self._data) > 0, then self._get_wait is empty self._data = deque() + # closed state, prevents any more waiters or putters when this is true + self._closed = False def __repr__(self): return ( @@ -887,6 +895,9 @@ def put_nowait(self, obj): WouldBlock: if the queue is full. """ + if self._closed: + raise QueueClosed + if self._get_wait: assert not self._data task, _ = self._get_wait.popitem(last=False) @@ -905,6 +916,9 @@ async def put(self, obj): """ await _core.checkpoint_if_cancelled() + if self._closed: + raise QueueClosed + try: self.put_nowait(obj) except _core.WouldBlock: @@ -933,6 +947,9 @@ def get_nowait(self): WouldBlock: if the queue is empty. """ + if self._closed: + raise QueueClosed + if self._put_wait: task, value = self._put_wait.popitem(last=False) # No need to check max_size, b/c we'll pop an item off again right @@ -953,6 +970,9 @@ async def get(self): """ await _core.checkpoint_if_cancelled() + if self._closed: + raise QueueClosed + try: value = self.get_nowait() except _core.WouldBlock: @@ -972,6 +992,20 @@ def abort_fn(_): value = await _core.wait_task_rescheduled(abort_fn) return value + def close(self): + """Closes this queue, cancelling any remaining tasks waiting to get or put, and removes + any buffered data. + """ + self._closed = True + + for task in self._get_wait.values(): + _core.reschedule(task, outcome.Error(QueueClosed)) + + for task in self._put_wait.values(): + _core.reschedule(task, outcome.Error(QueueClosed)) + + self._data.clear() + @aiter_compat def __aiter__(self): return self diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index bded1b0544..f958999908 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -5,7 +5,6 @@ from ..testing import wait_all_tasks_blocked, assert_checkpoints from .. import _core -from .. import _timeouts from .._timeouts import sleep_forever, move_on_after from .._sync import * @@ -542,6 +541,23 @@ async def do_put(q, v): q.get_nowait() +async def test_Queue_close(): + queue = Queue(capacity=0) + + async def first(): + with pytest.raises(QueueClosed): + await queue.get() + + async with _core.open_nursery() as n: + n.start_soon(first) + queue.close() + + with pytest.raises(QueueClosed): + await queue.put(None) + + assert len(queue._get_wait) == 0 + + # Two ways of implementing a Lock in terms of a Queue. Used to let us put the # Queue through the generic lock tests. From 5178dd54a1ad8fe5b4c6d3fd17cf6371e7830df4 Mon Sep 17 00:00:00 2001 From: "Laura F. Dickinson" Date: Sat, 28 Jul 2018 12:13:31 +0100 Subject: [PATCH 2/5] Make it so you can close each side of a queue independently. --- trio/_sync.py | 35 ++++++++++++++++++++++++----------- trio/tests/test_sync.py | 21 +++++++++++---------- 2 files changed, 35 insertions(+), 21 deletions(-) diff --git a/trio/_sync.py b/trio/_sync.py index abcb96860a..62fa2de232 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -848,8 +848,8 @@ def __init__(self, capacity): # if len(self._data) < self.capacity, then self._put_wait is empty # if len(self._data) > 0, then self._get_wait is empty self._data = deque() - # closed state, prevents any more waiters or putters when this is true - self._closed = False + # closed state, prevents any action when [0] is True, and prevents any new putters when [1] + self._close_state = [False, False] def __repr__(self): return ( @@ -895,7 +895,7 @@ def put_nowait(self, obj): WouldBlock: if the queue is full. """ - if self._closed: + if any(self._close_state): raise QueueClosed if self._get_wait: @@ -916,7 +916,7 @@ async def put(self, obj): """ await _core.checkpoint_if_cancelled() - if self._closed: + if any(self._close_state): raise QueueClosed try: @@ -947,7 +947,7 @@ def get_nowait(self): WouldBlock: if the queue is empty. """ - if self._closed: + if self._close_state[0]: raise QueueClosed if self._put_wait: @@ -970,7 +970,7 @@ async def get(self): """ await _core.checkpoint_if_cancelled() - if self._closed: + if self._close_state[0]: raise QueueClosed try: @@ -992,12 +992,22 @@ def abort_fn(_): value = await _core.wait_task_rescheduled(abort_fn) return value - def close(self): - """Closes this queue, cancelling any remaining tasks waiting to get or put, and removes - any buffered data. + def close_put(self): + """Closes one side of this queue, preventing any putters from putting data onto the queue. + + If this queue is empty, it will also cancel all getters. """ - self._closed = True + if self.empty(): + # pointless to let the getters wait on closed data + self.close_both_sides() + else: + self._close_state[1] = True + for task in self._put_wait.values(): + _core.reschedule(task, outcome.Error(QueueClosed)) + def close_both_sides(self): + """Closes both the getter and putter sides of the queue, discarding all data. + """ for task in self._get_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed)) @@ -1011,7 +1021,10 @@ def __aiter__(self): return self async def __anext__(self): - return await self.get() + try: + return await self.get() + except QueueClosed: + raise StopAsyncIteration from None def statistics(self): """Returns an object containing debugging information. diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index f958999908..0bd123f12d 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -542,20 +542,21 @@ async def do_put(q, v): async def test_Queue_close(): - queue = Queue(capacity=0) + q1 = Queue(capacity=1) - async def first(): - with pytest.raises(QueueClosed): - await queue.get() + await q1.put(1) + q1.close_put() + with pytest.raises(QueueClosed): + await q1.put(2) - async with _core.open_nursery() as n: - n.start_soon(first) - queue.close() + assert (await q1.get()) == 1 - with pytest.raises(QueueClosed): - await queue.put(None) + q2 = Queue(capacity=1) + await q2.put(1) + q2.close_both_sides() - assert len(queue._get_wait) == 0 + with pytest.raises(QueueClosed): + await q2.get() # Two ways of implementing a Lock in terms of a Queue. Used to let us put the From 812c0eaf05a3f4bb13ebf64e7abf72423f94bfe5 Mon Sep 17 00:00:00 2001 From: "Laura F. Dickinson" Date: Sun, 29 Jul 2018 10:43:13 +0100 Subject: [PATCH 3/5] Automatically close the get side if the put side is closed. --- trio/_sync.py | 11 +++++++++++ trio/tests/test_sync.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/trio/_sync.py b/trio/_sync.py index 62fa2de232..71eb9c6caa 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -959,6 +959,16 @@ def get_nowait(self): if self._data: value = self._data.popleft() return value + if self._close_state[1]: + # this confused me a bit so its bound to confuse somebody else as to why this is here + # 1) there's no put waiters, so we skip that branch + # 2) there's no data so we skip that branch + # that means that if there's no data at all, and the put size is closed + # we cannot ever have more data, so we close this side and raise QueueClosed so that + # any getters from here on close early + self._close_state[0] = True + raise QueueClosed + raise _core.WouldBlock() @_core.enable_ki_protection @@ -1008,6 +1018,7 @@ def close_put(self): def close_both_sides(self): """Closes both the getter and putter sides of the queue, discarding all data. """ + self._close_state = [True, True] for task in self._get_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed)) diff --git a/trio/tests/test_sync.py b/trio/tests/test_sync.py index 0bd123f12d..2d0c354907 100644 --- a/trio/tests/test_sync.py +++ b/trio/tests/test_sync.py @@ -550,6 +550,8 @@ async def test_Queue_close(): await q1.put(2) assert (await q1.get()) == 1 + with pytest.raises(QueueClosed): + await q1.get() q2 = Queue(capacity=1) await q2.put(1) From 6bbb66f688c6174d53d114d5b191f41bb22c1e43 Mon Sep 17 00:00:00 2001 From: "Laura F. Dickinson" Date: Sun, 29 Jul 2018 11:29:36 +0100 Subject: [PATCH 4/5] Clear _get_wait and _put_wait when closing the queue. --- trio/_sync.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trio/_sync.py b/trio/_sync.py index 71eb9c6caa..13aef7a410 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -1015,6 +1015,8 @@ def close_put(self): for task in self._put_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed)) + self._put_wait.clear() + def close_both_sides(self): """Closes both the getter and putter sides of the queue, discarding all data. """ @@ -1022,9 +1024,13 @@ def close_both_sides(self): for task in self._get_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed)) + self._get_wait.clear() + for task in self._put_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed)) + self._put_wait.clear() + self._data.clear() @aiter_compat From 8a85421a09647f57a69e0e84ee639303e39d09f5 Mon Sep 17 00:00:00 2001 From: "Laura F. Dickinson" Date: Sun, 29 Jul 2018 11:32:01 +0100 Subject: [PATCH 5/5] Change close state into two fields. --- trio/_sync.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/trio/_sync.py b/trio/_sync.py index 13aef7a410..c9a2c0a950 100644 --- a/trio/_sync.py +++ b/trio/_sync.py @@ -848,8 +848,9 @@ def __init__(self, capacity): # if len(self._data) < self.capacity, then self._put_wait is empty # if len(self._data) > 0, then self._get_wait is empty self._data = deque() - # closed state, prevents any action when [0] is True, and prevents any new putters when [1] - self._close_state = [False, False] + # closed state + self._put_close = False + self._all_closed = False def __repr__(self): return ( @@ -895,7 +896,7 @@ def put_nowait(self, obj): WouldBlock: if the queue is full. """ - if any(self._close_state): + if self._put_close or self._all_closed: raise QueueClosed if self._get_wait: @@ -916,7 +917,7 @@ async def put(self, obj): """ await _core.checkpoint_if_cancelled() - if any(self._close_state): + if self._put_close or self._all_closed: raise QueueClosed try: @@ -947,7 +948,7 @@ def get_nowait(self): WouldBlock: if the queue is empty. """ - if self._close_state[0]: + if self._all_closed: raise QueueClosed if self._put_wait: @@ -959,14 +960,14 @@ def get_nowait(self): if self._data: value = self._data.popleft() return value - if self._close_state[1]: + if self._put_close: # this confused me a bit so its bound to confuse somebody else as to why this is here # 1) there's no put waiters, so we skip that branch # 2) there's no data so we skip that branch - # that means that if there's no data at all, and the put size is closed + # that means that if there's no data at all, and the put side is closed # we cannot ever have more data, so we close this side and raise QueueClosed so that # any getters from here on close early - self._close_state[0] = True + self._all_closed = True raise QueueClosed raise _core.WouldBlock() @@ -980,7 +981,7 @@ async def get(self): """ await _core.checkpoint_if_cancelled() - if self._close_state[0]: + if self._all_closed: raise QueueClosed try: @@ -1011,7 +1012,7 @@ def close_put(self): # pointless to let the getters wait on closed data self.close_both_sides() else: - self._close_state[1] = True + self._put_close = True for task in self._put_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed)) @@ -1020,7 +1021,7 @@ def close_put(self): def close_both_sides(self): """Closes both the getter and putter sides of the queue, discarding all data. """ - self._close_state = [True, True] + self._put_close, self._all_closed = True, True for task in self._get_wait.values(): _core.reschedule(task, outcome.Error(QueueClosed))