From 83c02c5f10446ea34415ba93b73a42a4e0d464ab Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Wed, 20 Oct 2021 08:52:58 +0100 Subject: [PATCH 1/6] NotImplement the correct interface --- src/workflows/transport/pika_transport.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index e310f67e..18381d50 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -516,25 +516,28 @@ def _broadcast( mandatory=False, ).result() - def _transaction_begin(self, **kwargs): - """Enter transaction mode. + def _transaction_begin(self, transaction_id, **kwargs): + """Start a new transaction. + :param transaction_id: ID for this transaction in the transport layer. :param **kwargs: Further parameters for the transport layer. """ - raise NotImplementedError() + raise NotImplementedError("Transport interface not implemented") # self._channel.tx_select() - def _transaction_abort(self, **kwargs): + def _transaction_abort(self, transaction_id, **kwargs): """Abort a transaction and roll back all operations. + :param transaction_id: ID of transaction to be aborted. :param **kwargs: Further parameters for the transport layer. """ - raise NotImplementedError() + raise NotImplementedError("Transport interface not implemented") # self._channel.tx_rollback() - def _transaction_commit(self, **kwargs): + def _transaction_commit(self, transaction_id, **kwargs): """Commit a transaction. + :param transaction_id: ID of transaction to be committed. :param **kwargs: Further parameters for the transport layer. """ - raise NotImplementedError() + raise NotImplementedError("Transport interface not implemented") # self._channel.tx_commit() def _ack( From ce4c3d42973e266241bac4e2494840cc71654909 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Wed, 20 Oct 2021 11:46:34 +0100 Subject: [PATCH 2/6] Add transaction support --- src/workflows/services/sample_transaction.py | 5 +- src/workflows/transport/pika_transport.py | 124 +++++++++++++++++-- 2 files changed, 117 insertions(+), 12 deletions(-) diff --git a/src/workflows/services/sample_transaction.py b/src/workflows/services/sample_transaction.py index 48530b2e..cbce82c5 100644 --- a/src/workflows/services/sample_transaction.py +++ b/src/workflows/services/sample_transaction.py @@ -17,7 +17,10 @@ class SampleTxn(CommonService): def initializing(self): """Subscribe to a channel. Received messages must be acknowledged.""" self.subid = self._transport.subscribe( - "transient.transaction", self.receive_message, acknowledgement=True + "transient.transaction", + self.receive_message, + acknowledgement=True, + prefetch_count=1000, ) @staticmethod diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index 18381d50..a9b4b31d 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -516,29 +516,29 @@ def _broadcast( mandatory=False, ).result() - def _transaction_begin(self, transaction_id, **kwargs): + def _transaction_begin( + self, transaction_id: int, subscription_id: Optional[int] = None, **kwargs + ) -> None: """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. + :param subscription_id: Tie the transaction to a specific channel containing this subscription. :param **kwargs: Further parameters for the transport layer. """ - raise NotImplementedError("Transport interface not implemented") - # self._channel.tx_select() + self._pika_thread.tx_select(transaction_id, subscription_id) - def _transaction_abort(self, transaction_id, **kwargs): + def _transaction_abort(self, transaction_id: int, **kwargs) -> None: """Abort a transaction and roll back all operations. :param transaction_id: ID of transaction to be aborted. :param **kwargs: Further parameters for the transport layer. """ - raise NotImplementedError("Transport interface not implemented") - # self._channel.tx_rollback() + self._pika_thread.tx_rollback(transaction_id) - def _transaction_commit(self, transaction_id, **kwargs): + def _transaction_commit(self, transaction_id: int, **kwargs) -> None: """Commit a transaction. :param transaction_id: ID of transaction to be committed. :param **kwargs: Further parameters for the transport layer. """ - raise NotImplementedError("Transport interface not implemented") - # self._channel.tx_commit() + self._pika_thread.tx_commit(transaction_id) def _ack( self, message_id, subscription_id: int, *, multiple: bool = False, **_kwargs @@ -704,6 +704,9 @@ def __init__( self._connection: Optional[pika.BlockingConnection] = None # Per-subscription channels. May be pointing to the shared channel self._pika_channels: Dict[int, BlockingChannel] = {} + # Bidirectional index of all ongoing transactions. May include the shared channel + self._transactions_by_id: Dict[int, BlockingChannel] = {} + self._transactions_by_channel: Dict[BlockingChannel, int] = {} # A common, shared channel, used for non-QoS subscriptions self._pika_shared_channel: Optional[BlockingChannel] # Are we allowed to reconnect. Can only be turned off, never on @@ -919,6 +922,11 @@ def _unsubscribe(): logger.debug("Closing channel that is now unused") channel.close() + # Forget about any ongoing transactions on the channel + if channel in self._transactions_by_channel: + transaction_id = self._transactions_by_channel.pop(channel) + self._transactions_by_id.pop(transaction_id) + result.set_result(None) except BaseException as e: result.set_exception(e) @@ -986,6 +994,100 @@ def nack( lambda: channel.basic_nack(delivery_tag, multiple=multiple, requeue=requeue) ) + def tx_select( + self, transaction_id: int, subscription_id: Optional[int] + ) -> Future[None]: + """Set a channel to transaction mode. Thread-safe. + :param transaction_id: ID for this transaction in the transport layer. + :param subscription_id: Tie the transaction to a specific channel containing this subscription. + """ + + if not self._connection: + raise RuntimeError("Cannot transact on unstarted connection") + + future: Future[None] = Future() + + def _tx_select(): + if future.set_running_or_notify_cancel(): + try: + if subscription_id: + if subscription_id not in self._pika_channels: + raise KeyError( + f"Could not find subscription {subscription_id} to begin transaction" + ) + channel = self._pika_channels[subscription_id] + else: + channel = self._get_shared_channel() + if channel in self._transactions_by_channel: + raise KeyError( + f"Channel {channel} is already running transaction {self._transactions_by_channel[channel]}, so can't start transaction {transaction_id}" + ) + channel.tx_select() + self._transactions_by_channel[channel] = transaction_id + self._transactions_by_id[transaction_id] = channel + + future.set_result(None) + except BaseException as e: + future.set_exception(e) + raise + + self._connection.add_callback_threadsafe(_tx_select) + return future + + def tx_rollback(self, transaction_id: int) -> Future[None]: + """Abort a transaction and roll back all operations. Thread-safe. + :param transaction_id: ID of transaction to be aborted. + """ + if not self._connection: + raise RuntimeError("Cannot transact on unstarted connection") + + future: Future[None] = Future() + + def _tx_rollback(): + if future.set_running_or_notify_cancel(): + try: + channel = self._transactions_by_id.pop(transaction_id, None) + if not channel: + raise KeyError( + f"Could not find transaction {transaction_id} to roll back" + ) + self._transactions_by_channel.pop(channel) + channel.tx_rollback() + future.set_result(None) + except BaseException as e: + future.set_exception(e) + raise + + self._connection.add_callback_threadsafe(_tx_rollback) + return future + + def tx_commit(self, transaction_id: int) -> Future[None]: + """Commit a transaction. + :param transaction_id: ID of transaction to be committed. Thread-safe.. + """ + if not self._connection: + raise RuntimeError("Cannot transact on unstarted connection") + + future: Future[None] = Future() + + def _tx_commit(): + if future.set_running_or_notify_cancel(): + try: + channel = self._transactions_by_id.pop(transaction_id, None) + if not channel: + raise KeyError( + f"Could not find transaction {transaction_id} to commit" + ) + self._transactions_by_channel.pop(channel) + channel.tx_commit() + future.set_result(None) + except BaseException as e: + future.set_exception(e) + raise + + self._connection.add_callback_threadsafe(_tx_commit) + return future + @property def connection_alive(self) -> bool: """ @@ -1029,7 +1131,7 @@ def _get_shared_channel(self) -> BlockingChannel: if not self._pika_shared_channel: self._pika_shared_channel = self._connection.channel() - self._pika_shared_channel.confirm_delivery() + ##### self._pika_shared_channel.confirm_delivery() return self._pika_shared_channel def _recreate_subscriptions(self): @@ -1067,7 +1169,7 @@ def _add_subscription(self, subscription_id: int, subscription: _PikaSubscriptio channel = self._get_shared_channel() else: channel = self._connection.channel() - channel.confirm_delivery() + ##### channel.confirm_delivery() channel.basic_qos(prefetch_count=subscription.prefetch_count) if subscription.kind == _PikaSubscriptionKind.FANOUT: From 960ac65befa17fc23ff3e03fa0cd12c4603d2637 Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Fri, 22 Oct 2021 13:17:33 +0100 Subject: [PATCH 3/6] Make every subscription run on a separate channel And disable confirm mode on all channels. This is so that transactions can be used on channels. So if a message is to be send as part of a transaction then it can be sent on the relevant channel, and if a message is sent outside of a transaction it can be sent via the default channel. The confirm mode currently does not have any effect as we are not observing the callback. However, as confirm and transaction modes are mutually exclusive remove the confirm mode from all subscription channels. For the default channel we may eventually want to have one in normal mode and one in confirm mode, but this is another issue for another day. --- src/workflows/transport/pika_transport.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index a9b4b31d..831ceda8 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -702,12 +702,12 @@ def __init__( self._subscriptions: Dict[int, _PikaSubscription] = {} # The pika connection object self._connection: Optional[pika.BlockingConnection] = None - # Per-subscription channels. May be pointing to the shared channel + # Index of per-subscription channels. self._pika_channels: Dict[int, BlockingChannel] = {} # Bidirectional index of all ongoing transactions. May include the shared channel self._transactions_by_id: Dict[int, BlockingChannel] = {} self._transactions_by_channel: Dict[BlockingChannel, int] = {} - # A common, shared channel, used for non-QoS subscriptions + # A common, shared channel, used for sending messages outside of transactions. self._pika_shared_channel: Optional[BlockingChannel] # Are we allowed to reconnect. Can only be turned off, never on self._reconnection_allowed: bool = True @@ -1103,7 +1103,7 @@ def connection_alive(self) -> bool: ) # NOTE: With reconnection lifecycle this probably doesn't make sense - # on it's own. It might make sense to add this returning a + # on its own. It might make sense to add this returning a # connection-specific 'token' - presumably the user might want # to ensure that a connection is still the same connection # and thus adhering to various within-connection guarantees. @@ -1164,22 +1164,18 @@ def _add_subscription(self, subscription_id: int, subscription: _PikaSubscriptio f"Subscription {subscription_id} to '{subscription.destination}' is not reconnectable. Turning reconnection off." ) - # Either open a channel (if prefetch) or use the shared one - if subscription.prefetch_count == 0: - channel = self._get_shared_channel() - else: - channel = self._connection.channel() - ##### channel.confirm_delivery() - channel.basic_qos(prefetch_count=subscription.prefetch_count) + # Open a dedicated channel for this subscription + channel = self._connection.channel() + channel.basic_qos(prefetch_count=subscription.prefetch_count) - if subscription.kind == _PikaSubscriptionKind.FANOUT: + if subscription.kind is _PikaSubscriptionKind.FANOUT: # If a FANOUT subscription, then we need to create and bind # a temporary queue to receive messages from the exchange queue = channel.queue_declare("", exclusive=True).method.queue assert queue is not None channel.queue_bind(queue, subscription.destination) subscription.queue = queue - elif subscription.kind == _PikaSubscriptionKind.DIRECT: + elif subscription.kind is _PikaSubscriptionKind.DIRECT: subscription.queue = subscription.destination else: raise NotImplementedError(f"Unknown subscription kind: {subscription.kind}") From 4187cf97e47328258434fe0ce3d8b8dc05db568a Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Mon, 25 Oct 2021 09:35:22 +0100 Subject: [PATCH 4/6] Add optional subscription ID to transaction calls so that a transaction can be associated with the correct channel where needed --- src/workflows/transport/common_transport.py | 23 ++++++++++++++------ src/workflows/transport/offline_transport.py | 2 +- src/workflows/transport/stomp_transport.py | 2 +- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index d10012f8..61d305ac 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -288,15 +288,22 @@ def nack(self, message, subscription_id: Optional[int] = None, **kwargs): ) self._nack(message_id, subscription_id=subscription_id, **kwargs) - def transaction_begin(self, **kwargs) -> int: + def transaction_begin(self, subscription_id: Optional[int] = None, **kwargs) -> int: """Start a new transaction. - :param **kwargs: Further parameters for the transport layer. For example + :param **kwargs: Further parameters for the transport layer. :return: A transaction ID that can be passed to other functions. """ self.__transaction_id += 1 self.__transactions.add(self.__transaction_id) - self.log.debug("Starting transaction with ID %d", self.__subscription_id) - self._transaction_begin(self.__transaction_id, **kwargs) + if subscription_id: + self.log.debug( + "Starting transaction with ID %d on subscription %d", + self.__transaction_id, + subscription_id, + ) + else: + self.log.debug("Starting transaction with ID %d", self.__transaction_id) + self._transaction_begin(self.__transaction_id, subscription_id, **kwargs) return self.__transaction_id def transaction_abort(self, transaction_id: int, **kwargs): @@ -405,21 +412,23 @@ def _nack(self, message_id, subscription_id, **kwargs): """ raise NotImplementedError("Transport interface not implemented") - def _transaction_begin(self, transaction_id, **kwargs): + def _transaction_begin( + self, transaction_id: int, subscription_id: Optional[int] = None, **kwargs + ) -> None: """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. :param **kwargs: Further parameters for the transport layer. """ raise NotImplementedError("Transport interface not implemented") - def _transaction_abort(self, transaction_id, **kwargs): + def _transaction_abort(self, transaction_id: int, **kwargs) -> None: """Abort a transaction and roll back all operations. :param transaction_id: ID of transaction to be aborted. :param **kwargs: Further parameters for the transport layer. """ raise NotImplementedError("Transport interface not implemented") - def _transaction_commit(self, transaction_id, **kwargs): + def _transaction_commit(self, transaction_id: int, **kwargs) -> None: """Commit a transaction. :param transaction_id: ID of transaction to be committed. :param **kwargs: Further parameters for the transport layer. diff --git a/src/workflows/transport/offline_transport.py b/src/workflows/transport/offline_transport.py index 2838acb0..fb6e9446 100644 --- a/src/workflows/transport/offline_transport.py +++ b/src/workflows/transport/offline_transport.py @@ -66,7 +66,7 @@ def _broadcast( ): self._output(f"Broadcasting {len(message)} bytes to {destination}", message) - def _transaction_begin(self, transaction_id, **kwargs): + def _transaction_begin(self, transaction_id, subscription_id, **kwargs): self._output(f"Starting transaction {transaction_id}") def _transaction_abort(self, transaction_id, **kwargs): diff --git a/src/workflows/transport/stomp_transport.py b/src/workflows/transport/stomp_transport.py index fc774f4e..46a937ed 100644 --- a/src/workflows/transport/stomp_transport.py +++ b/src/workflows/transport/stomp_transport.py @@ -407,7 +407,7 @@ def _broadcast( self._connected = False raise workflows.Disconnected("No connection to stomp host") - def _transaction_begin(self, transaction_id, **kwargs): + def _transaction_begin(self, transaction_id, subscription_id, **kwargs): """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. :param **kwargs: Further parameters for the transport layer. From d7b0f1881e082199d7883fc939821867f01e10fe Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Mon, 25 Oct 2021 09:53:13 +0100 Subject: [PATCH 5/6] same, but without breaking the API --- src/workflows/transport/common_transport.py | 6 ++++-- src/workflows/transport/offline_transport.py | 2 +- src/workflows/transport/pika_transport.py | 5 +---- src/workflows/transport/stomp_transport.py | 5 +---- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/workflows/transport/common_transport.py b/src/workflows/transport/common_transport.py index 61d305ac..d7fd6e7a 100644 --- a/src/workflows/transport/common_transport.py +++ b/src/workflows/transport/common_transport.py @@ -303,7 +303,9 @@ def transaction_begin(self, subscription_id: Optional[int] = None, **kwargs) -> ) else: self.log.debug("Starting transaction with ID %d", self.__transaction_id) - self._transaction_begin(self.__transaction_id, subscription_id, **kwargs) + self._transaction_begin( + self.__transaction_id, subscription_id=subscription_id, **kwargs + ) return self.__transaction_id def transaction_abort(self, transaction_id: int, **kwargs): @@ -413,7 +415,7 @@ def _nack(self, message_id, subscription_id, **kwargs): raise NotImplementedError("Transport interface not implemented") def _transaction_begin( - self, transaction_id: int, subscription_id: Optional[int] = None, **kwargs + self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs ) -> None: """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. diff --git a/src/workflows/transport/offline_transport.py b/src/workflows/transport/offline_transport.py index fb6e9446..2838acb0 100644 --- a/src/workflows/transport/offline_transport.py +++ b/src/workflows/transport/offline_transport.py @@ -66,7 +66,7 @@ def _broadcast( ): self._output(f"Broadcasting {len(message)} bytes to {destination}", message) - def _transaction_begin(self, transaction_id, subscription_id, **kwargs): + def _transaction_begin(self, transaction_id, **kwargs): self._output(f"Starting transaction {transaction_id}") def _transaction_abort(self, transaction_id, **kwargs): diff --git a/src/workflows/transport/pika_transport.py b/src/workflows/transport/pika_transport.py index 831ceda8..88673f80 100644 --- a/src/workflows/transport/pika_transport.py +++ b/src/workflows/transport/pika_transport.py @@ -517,26 +517,23 @@ def _broadcast( ).result() def _transaction_begin( - self, transaction_id: int, subscription_id: Optional[int] = None, **kwargs + self, transaction_id: int, *, subscription_id: Optional[int] = None, **kwargs ) -> None: """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. :param subscription_id: Tie the transaction to a specific channel containing this subscription. - :param **kwargs: Further parameters for the transport layer. """ self._pika_thread.tx_select(transaction_id, subscription_id) def _transaction_abort(self, transaction_id: int, **kwargs) -> None: """Abort a transaction and roll back all operations. :param transaction_id: ID of transaction to be aborted. - :param **kwargs: Further parameters for the transport layer. """ self._pika_thread.tx_rollback(transaction_id) def _transaction_commit(self, transaction_id: int, **kwargs) -> None: """Commit a transaction. :param transaction_id: ID of transaction to be committed. - :param **kwargs: Further parameters for the transport layer. """ self._pika_thread.tx_commit(transaction_id) diff --git a/src/workflows/transport/stomp_transport.py b/src/workflows/transport/stomp_transport.py index 46a937ed..c484fb60 100644 --- a/src/workflows/transport/stomp_transport.py +++ b/src/workflows/transport/stomp_transport.py @@ -407,24 +407,21 @@ def _broadcast( self._connected = False raise workflows.Disconnected("No connection to stomp host") - def _transaction_begin(self, transaction_id, subscription_id, **kwargs): + def _transaction_begin(self, transaction_id, **kwargs): """Start a new transaction. :param transaction_id: ID for this transaction in the transport layer. - :param **kwargs: Further parameters for the transport layer. """ self._conn.begin(transaction=transaction_id) def _transaction_abort(self, transaction_id, **kwargs): """Abort a transaction and roll back all operations. :param transaction_id: ID of transaction to be aborted. - :param **kwargs: Further parameters for the transport layer. """ self._conn.abort(transaction_id) def _transaction_commit(self, transaction_id, **kwargs): """Commit a transaction. :param transaction_id: ID of transaction to be committed. - :param **kwargs: Further parameters for the transport layer. """ self._conn.commit(transaction_id) From 55a8dc97b32acc0db272ab117683b1ac1076b7ea Mon Sep 17 00:00:00 2001 From: Markus Gerstel Date: Mon, 25 Oct 2021 11:01:05 +0100 Subject: [PATCH 6/6] fix up tests --- tests/services/test_sample_transaction.py | 2 +- tests/transport/test_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/services/test_sample_transaction.py b/tests/services/test_sample_transaction.py index 2b57eca7..a2d52285 100644 --- a/tests/services/test_sample_transaction.py +++ b/tests/services/test_sample_transaction.py @@ -52,7 +52,7 @@ def test_txnservice_subscribes_to_channel(): p.initializing() mock_transport.subscribe.assert_called_once_with( - mock.ANY, p.receive_message, acknowledgement=True + mock.ANY, p.receive_message, acknowledgement=True, prefetch_count=1000 ) diff --git a/tests/transport/test_common.py b/tests/transport/test_common.py index 9439ec16..329fe792 100644 --- a/tests/transport/test_common.py +++ b/tests/transport/test_common.py @@ -248,7 +248,7 @@ def test_create_and_destroy_transactions(): t = ct.transaction_begin() assert t - ct._transaction_begin.assert_called_once_with(t) + ct._transaction_begin.assert_called_once_with(t, subscription_id=None) ct.transaction_abort(t) with pytest.raises(workflows.Error):