diff --git a/requirements_dev.txt b/requirements_dev.txt index c54b2310..23e72b92 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -1,3 +1,4 @@ +bidict==0.21.4 pytest==6.2.5 pytest-cov==3.0.0 pytest-timeout==2.0.1 diff --git a/setup.cfg b/setup.cfg index 3ced5fab..403d7777 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,6 +26,7 @@ project-urls = [options] install_requires = + bidict pika setuptools stomp.py>=7 diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index 362ae672..09c314f6 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -15,6 +15,7 @@ import pika import pika.exceptions +from bidict import bidict from pika.adapters.blocking_connection import BlockingChannel import workflows @@ -563,7 +564,12 @@ def _ack( :param **kwargs: Further parameters for the transport layer. """ - self._pika_thread.ack(message_id, subscription_id, multiple=multiple) + self._pika_thread.ack( + message_id, + subscription_id, + multiple=multiple, + transaction_id=_kwargs.get("transaction"), + ) def _nack( self, @@ -588,7 +594,11 @@ def _nack( requeue: Attempt to requeue. see AMQP basic.nack. """ self._pika_thread.nack( - message_id, subscription_id, multiple=multiple, requeue=requeue + message_id, + subscription_id, + multiple=multiple, + requeue=requeue, + transaction_id=_kwargs.get("transaction"), ) @staticmethod @@ -707,10 +717,11 @@ def __init__( # The pika connection object self._connection: Optional[pika.BlockingConnection] = None # Index of per-subscription channels. - self._pika_channels: Dict[int, BlockingChannel] = {} + self._pika_channels: bidict[int, BlockingChannel] = bidict() # Bidirectional index of all ongoing transactions. May include the shared channel - self._transactions_by_id: Dict[int, BlockingChannel] = {} - self._transactions_by_channel: Dict[BlockingChannel, int] = {} + self._transaction_on_channel: bidict[BlockingChannel, int] = bidict() + # Information on whether a channel has uncommitted messages + self._channel_has_active_tx: Dict[BlockingChannel, bool] = {} # A common, shared channel, used for sending messages outside of transactions. self._pika_shared_channel: Optional[BlockingChannel] # Are we allowed to reconnect. Can only be turned off, never on @@ -923,9 +934,7 @@ def _unsubscribe(): channel.close() # Forget about any ongoing transactions on the channel - if channel in self._transactions_by_channel: - transaction_id = self._transactions_by_channel.pop(channel) - self._transactions_by_id.pop(transaction_id) + self._transaction_on_channel.pop(channel, None) result.set_result(None) except BaseException as e: @@ -955,7 +964,8 @@ def _send(): if future.set_running_or_notify_cancel(): try: if transaction_id: - channel = self._transactions_by_id[transaction_id] + channel = self._transaction_on_channel.inverse[transaction_id] + self._channel_has_active_tx[channel] = True else: channel = self._get_shared_channel() channel.basic_publish( @@ -973,7 +983,14 @@ def _send(): self._connection.add_callback_threadsafe(_send) return future - def ack(self, delivery_tag: int, subscription_id: int, *, multiple=False): + def ack( + self, + delivery_tag: int, + subscription_id: int, + *, + multiple=False, + transaction_id: Optional[int], + ): if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to ACK") @@ -981,12 +998,35 @@ def ack(self, delivery_tag: int, subscription_id: int, *, multiple=False): assert self._connection is not None - self._connection.add_callback_threadsafe( - lambda: channel.basic_ack(delivery_tag, multiple=multiple) - ) + # Check if channel is in tx mode + transaction = self._transaction_on_channel.get(channel) + if transaction == transaction_id: + # Matching transaction IDs - perfect + self._channel_has_active_tx[channel] |= transaction is not None + self._connection.add_callback_threadsafe( + lambda: channel.basic_ack(delivery_tag, multiple=multiple) + ) + elif transaction_id is None and not self._channel_has_active_tx[channel]: + + def _ack_callback(): + channel.basic_ack(delivery_tag, multiple=multiple) + channel.tx_commit() + + self._connection.add_callback_threadsafe(_ack_callback) + else: + raise workflows.Error( + "Transaction state mismatch. " + f"Call assumes transaction {transaction_id}, channel has transaction {transaction}" + ) def nack( - self, delivery_tag: int, subscription_id: int, *, multiple=False, requeue=True + self, + delivery_tag: int, + subscription_id: int, + *, + multiple=False, + requeue=True, + transaction_id: Optional[int], ): if subscription_id not in self._subscriptions: raise KeyError(f"Could not find subscription {subscription_id} to NACK") @@ -995,9 +1035,28 @@ def nack( assert self._connection is not None - self._connection.add_callback_threadsafe( - lambda: channel.basic_nack(delivery_tag, multiple=multiple, requeue=requeue) - ) + # Check if channel is in tx mode + transaction = self._transaction_on_channel.get(channel) + if transaction == transaction_id: + # Matching transaction IDs - perfect + self._channel_has_active_tx[channel] |= transaction is not None + self._connection.add_callback_threadsafe( + lambda: channel.basic_nack( + delivery_tag, multiple=multiple, requeue=requeue + ) + ) + elif transaction_id is None and not self._channel_has_active_tx[channel]: + + def _nack_callback(): + channel.basic_nack(delivery_tag, multiple=multiple, requeue=requeue) + channel.tx_commit() + + self._connection.add_callback_threadsafe(_nack_callback) + else: + raise workflows.Error( + "Transaction state mismatch. " + f"Call assumes transaction {transaction_id}, channel has transaction {transaction}" + ) def tx_select( self, transaction_id: int, subscription_id: Optional[int] @@ -1023,13 +1082,12 @@ def _tx_select(): channel = self._pika_channels[subscription_id] else: channel = self._get_shared_channel() - if channel in self._transactions_by_channel: + if channel in self._transaction_on_channel: raise KeyError( - f"Channel {channel} is already running transaction {self._transactions_by_channel[channel]}, so can't start transaction {transaction_id}" + f"Channel {channel} is already running transaction {self._transaction_on_channel[channel]}, so can't start transaction {transaction_id}" ) channel.tx_select() - self._transactions_by_channel[channel] = transaction_id - self._transactions_by_id[transaction_id] = channel + self._transaction_on_channel[channel] = transaction_id future.set_result(None) except BaseException as e: @@ -1051,13 +1109,15 @@ def tx_rollback(self, transaction_id: int) -> Future[None]: def _tx_rollback(): if future.set_running_or_notify_cancel(): try: - channel = self._transactions_by_id.pop(transaction_id, None) + channel = self._transaction_on_channel.inverse.pop( + transaction_id, None + ) if not channel: raise KeyError( f"Could not find transaction {transaction_id} to roll back" ) - self._transactions_by_channel.pop(channel) channel.tx_rollback() + self._channel_has_active_tx.pop(channel, None) future.set_result(None) except BaseException as e: future.set_exception(e) @@ -1078,13 +1138,15 @@ def tx_commit(self, transaction_id: int) -> Future[None]: def _tx_commit(): if future.set_running_or_notify_cancel(): try: - channel = self._transactions_by_id.pop(transaction_id, None) + channel = self._transaction_on_channel.inverse.pop( + transaction_id, None + ) if not channel: raise KeyError( f"Could not find transaction {transaction_id} to commit" ) - self._transactions_by_channel.pop(channel) channel.tx_commit() + self._channel_has_active_tx.pop(channel, None) future.set_result(None) except BaseException as e: future.set_exception(e) diff --git a/tests/transport/test_pika.py b/tests/transport/test_pika.py index c5cc8076..888f1896 100644 --- a/tests/transport/test_pika.py +++ b/tests/transport/test_pika.py @@ -764,7 +764,10 @@ def test_ack_message(mock_pikathread): transport._ack(mock.sentinel.messageid, mock.sentinel.sub_id) mock_pikathread.ack.assert_called_once_with( - mock.sentinel.messageid, mock.sentinel.sub_id, multiple=False + mock.sentinel.messageid, + mock.sentinel.sub_id, + multiple=False, + transaction_id=None, ) @@ -776,11 +779,15 @@ def test_nack_message(mock_pikathread): transport._nack(mock.sentinel.messageid, mock.sentinel.sub_id) mock_pikathread.nack.assert_called_once_with( - mock.sentinel.messageid, mock.sentinel.sub_id, multiple=False, requeue=True + mock.sentinel.messageid, + mock.sentinel.sub_id, + multiple=False, + requeue=True, + transaction_id=None, ) -@pytest.fixture +@pytest.fixture(scope="session") def connection_params(): """Connection Parameters for connecting to a physical RabbitMQ server""" params = [