@dataclass()
class QueueMessageTTLArguments():
    """Construct additional time-to-live arguments when declaring a queue.

    All fields are optional. `to_argument` validates the configured values and only
    emits the entries that were actually set.
    """

    queue_ttl: Optional[timedelta] = None
    """Expires and deletes the queue after a period of time when it is not used.
    The timedelta must be convertible into a positive integer number of milliseconds.
    Ref: https://www.rabbitmq.com/docs/ttl#queue-ttl"""
    message_ttl: Optional[timedelta] = None
    """Expires and deletes the message within the queue after the defined TTL.
    The timedelta must be convertible into a non-negative integer number of milliseconds.
    Ref: https://www.rabbitmq.com/docs/ttl#per-queue-message-ttl"""
    dead_letter_routing_key: Optional[str] = None
    """When specified, the expired message is republished to the designated dead letter queue.
    If not set, the message's own routing key is used.
    Ref: https://www.rabbitmq.com/docs/dlx#routing"""
    dead_letter_exchange: Optional[str] = None
    """Dead letter exchange name.
    Ref: https://www.rabbitmq.com/docs/dlx"""

    @staticmethod
    def _to_milliseconds(ttl: timedelta) -> int:
        """Convert a timedelta into whole milliseconds.

        Uses timedelta floor division, which is exact integer arithmetic, instead of
        `total_seconds() * 1000`, which goes through floats and can lose a millisecond
        (e.g. 1.001 s * 1000 == 1000.9999999999999).

        :param ttl: The duration to convert.
        :return: The number of whole milliseconds in `ttl` (floored).
        """
        return ttl // timedelta(milliseconds=1)

    def to_argument(self) -> Arguments:
        """Convert the time-to-live variables to the aio-pika `declare_queue` keyword arguments.

        :raises ValueError: If `queue_ttl` is zero, negative or shorter than 1 millisecond,
            if `message_ttl` is negative, or if `message_ttl` is larger than `queue_ttl`.
        :return: The time-to-live keyword arguments in AMQP method arguments data type.
        """
        arguments: Arguments = {}
        # Ensure this is not None to avoid typecheck error.
        arguments = cast(dict, arguments)

        if self.queue_ttl is not None:
            if self.queue_ttl <= timedelta(0):
                raise ValueError(
                    f"queue_ttl must be a positive value, {self.queue_ttl} received."
                )
            queue_ttl_ms = self._to_milliseconds(self.queue_ttl)
            if queue_ttl_ms < 1:
                # A positive sub-millisecond TTL would otherwise be truncated to 0,
                # which violates the "positive integer" requirement of `x-expires`.
                raise ValueError(
                    f"queue_ttl must be at least 1 millisecond, {self.queue_ttl} received."
                )
            arguments["x-expires"] = queue_ttl_ms
        if self.message_ttl is not None:
            if self.message_ttl < timedelta(0):
                raise ValueError(
                    f"message_ttl can not be a negative value, {self.message_ttl} received."
                )
            if self.queue_ttl is not None and self.message_ttl > self.queue_ttl:
                # Raise an error as it serves no purpose.
                raise ValueError("message_ttl shall be smaller or equal to queue_ttl.")
            # Zero is allowed here; a positive sub-millisecond value is floored to 0.
            arguments["x-message-ttl"] = self._to_milliseconds(self.message_ttl)
        if self.dead_letter_routing_key is not None:
            arguments["x-dead-letter-routing-key"] = str(self.dead_letter_routing_key)
        if self.dead_letter_exchange is not None:
            arguments["x-dead-letter-exchange"] = str(self.dead_letter_exchange)
        return arguments
:param exchange_name: Name of the exchange on which the messages will be published. + :param queue_message_ttl: Additional arguments to specify queue or message TTL. """ if bind_to_routing_key is not None and exchange_name is None: raise RuntimeError( @@ -238,8 +292,18 @@ async def _declare_queue( f"exchange name was provided." ) - logger.info("Declaring queue %s as %s", queue_name, queue_type) - queue = await self._channel.declare_queue(queue_name, **queue_type.to_argument()) + if queue_message_ttl is not None: + ttl_arguments = queue_message_ttl.to_argument() + else: + ttl_arguments = None + + logger.info("Declaring queue %s as %s with arguments as %s", + queue_name, + queue_type, + ttl_arguments) + queue = await self._channel.declare_queue(queue_name, + **queue_type.to_argument(), + arguments=ttl_arguments) if exchange_name is not None: if exchange_name not in self._exchanges: @@ -260,6 +324,7 @@ async def _declare_queue_and_add_subscription( bind_to_routing_key: Optional[str] = None, exchange_name: Optional[str] = None, delete_after_messages: Optional[int] = None, + queue_message_ttl: Optional[QueueMessageTTLArguments] = None ) -> None: """Declare an AMQP queue and subscribe to the messages. @@ -273,6 +338,7 @@ async def _declare_queue_and_add_subscription( :param exchange_name: Name of the exchange on which the messages will be published. :param delete_after_messages: Delete the subscription & queue after this limit of messages have been successfully processed. + :param queue_message_ttl: Additional arguments to specify queue or message TTL. 
async def _queue_exists(self, queue_name: str) -> bool:
    """Report whether a queue with the given name can be looked up on the channel.

    The lookup is delegated to the channel's `get_queue` with `ensure=True`. A
    `ChannelClosed` raised by that call is logged as a warning and treated as
    "the queue does not exist".

    :param queue_name: Name of the queue to be checked.
    :return: True when the lookup succeeds, False when it raises `ChannelClosed`.
    """
    try:
        await self._channel.get_queue(queue_name, ensure=True)
    except ChannelClosed as closed_error:
        logger.warning(closed_error)
        found = False
    else:
        logger.info("The %s queue exists", queue_name)
        found = True
    return found
def queue_exists(self, queue_name: str) -> bool:
    """Check if the queue exists.

    Schedules the `_queue_exists` coroutine on `self._loop` and blocks the calling
    thread until its result is available.

    :param queue_name: Name of the queue to be checked.
    :return: The boolean produced by `_queue_exists` for this queue name.
    """
    pending_check = asyncio.run_coroutine_threadsafe(
        self._queue_exists(queue_name=queue_name),
        self._loop,
    )
    return pending_check.result()
diff --git a/src/omotes_sdk/omotes_interface.py b/src/omotes_sdk/omotes_interface.py index 343bdc5..d0eabc2 100644 --- a/src/omotes_sdk/omotes_interface.py +++ b/src/omotes_sdk/omotes_interface.py @@ -5,7 +5,11 @@ from datetime import timedelta from typing import Callable, Optional, Union -from omotes_sdk.internal.common.broker_interface import BrokerInterface, AMQPQueueType +from omotes_sdk.internal.common.broker_interface import ( + BrokerInterface, + AMQPQueueType, + QueueMessageTTLArguments +) from omotes_sdk.config import RabbitMQConfig from omotes_sdk_protocol.job_pb2 import ( JobResult, @@ -101,6 +105,9 @@ class OmotesInterface: """How long the SDK should wait for the first reply when requesting the current workflow definitions from the orchestrator.""" + JOB_RESULT_MESSAGE_TTL: timedelta = timedelta(hours=48) + """Default value of job result message TTL.""" + def __init__( self, rabbitmq_config: RabbitMQConfig, @@ -171,6 +178,8 @@ def connect_to_submitted_job( callback_on_progress_update: Optional[Callable[[Job, JobProgressUpdate], None]], callback_on_status_update: Optional[Callable[[Job, JobStatusUpdate], None]], auto_disconnect_on_result: bool, + auto_dead_letter_after_ttl: Optional[timedelta] = JOB_RESULT_MESSAGE_TTL, + reconnect: bool = True ) -> None: """(Re)connect to the running job. @@ -184,7 +193,40 @@ def connect_to_submitted_job( :param auto_disconnect_on_result: Remove/disconnect from all queues pertaining to this job once the result is received and handled without exceptions through `callback_on_finished`. + :param auto_dead_letter_after_ttl: When erroneous situations occur (e.g. client is offline), + the job result message (if available) will be dead lettered after the given TTL, + and all queues of this job will be removed subsequently. Default to 48 hours if unset. 
+ Set to `None` to turn off auto dead letter and clean up, but be aware this may lead to + messages and queues to be stored in RabbitMQ indefinitely + (which uses up memory & disk space). + :param reconnect: When True, first check the job queues status and raise an error if not + exist. Default to True. """ + job_results_queue_name = OmotesQueueNames.job_results_queue_name(job.id) + job_progress_queue_name = OmotesQueueNames.job_progress_queue_name(job.id) + job_status_queue_name = OmotesQueueNames.job_status_queue_name(job.id) + + if reconnect: + logger.info("Reconnect to the submitted job %s is set to True. " + + "Checking job queues status...", job.id) + if not self.broker_if.queue_exists(job_results_queue_name): + raise RuntimeError( + f"The {job_results_queue_name} queue does not exist or is removed. " + "Abort reconnecting to the queue." + ) + if (callback_on_progress_update + and not self.broker_if.queue_exists(job_progress_queue_name)): + raise RuntimeError( + f"The {job_progress_queue_name} queue does not exist or is removed. " + "Abort reconnecting to the queue." + ) + if (callback_on_status_update + and not self.broker_if.queue_exists(job_status_queue_name)): + raise RuntimeError( + f"The {job_status_queue_name} queue does not exist or is removed. " + "Abort reconnecting to the queue." + ) + if auto_disconnect_on_result: logger.info("Connecting to update for job %s with auto disconnect on result", job.id) auto_disconnect_handler = self._autodelete_progres_status_queues_on_result @@ -192,6 +234,27 @@ def connect_to_submitted_job( logger.info("Connecting to update for job %s and expect manual disconnect", job.id) auto_disconnect_handler = None + # TODO: handle reconnection after the message is dead lettered but queue still exists. + + if auto_dead_letter_after_ttl is not None: + message_ttl = auto_dead_letter_after_ttl + queue_ttl = auto_dead_letter_after_ttl * 2 + logger.info("Auto dead letter and cleanup on error after TTL is set. 
" + + "The leftover job result message will be dead lettered after %s, " + + "and leftover job queues will be discarded after %s.", + message_ttl, queue_ttl) + job_result_queue_message_ttl = QueueMessageTTLArguments( + queue_ttl=queue_ttl, + message_ttl=message_ttl, + dead_letter_routing_key=OmotesQueueNames.job_result_dead_letter_queue_name(), + dead_letter_exchange=OmotesQueueNames.omotes_exchange_name()) + job_progress_status_queue_ttl = QueueMessageTTLArguments(queue_ttl=queue_ttl) + else: + logger.info("Auto dead letter and cleanup on error after TTL is not set. " + + "Manual cleanup on leftover job queues and messages might be required.") + job_result_queue_message_ttl = None + job_progress_status_queue_ttl = None + callback_handler = JobSubmissionCallbackHandler( job, callback_on_finished, @@ -201,25 +264,28 @@ def connect_to_submitted_job( ) self.broker_if.declare_queue_and_add_subscription( - queue_name=OmotesQueueNames.job_results_queue_name(job.id), + queue_name=job_results_queue_name, callback_on_message=callback_handler.callback_on_finished_wrapped, queue_type=AMQPQueueType.DURABLE, exchange_name=OmotesQueueNames.omotes_exchange_name(), delete_after_messages=1, + queue_message_ttl=job_result_queue_message_ttl ) if callback_on_progress_update: self.broker_if.declare_queue_and_add_subscription( - queue_name=OmotesQueueNames.job_progress_queue_name(job.id), + queue_name=job_progress_queue_name, callback_on_message=callback_handler.callback_on_progress_update_wrapped, queue_type=AMQPQueueType.DURABLE, exchange_name=OmotesQueueNames.omotes_exchange_name(), + queue_message_ttl=job_progress_status_queue_ttl ) if callback_on_status_update: self.broker_if.declare_queue_and_add_subscription( - queue_name=OmotesQueueNames.job_status_queue_name(job.id), + queue_name=job_status_queue_name, callback_on_message=callback_handler.callback_on_status_update_wrapped, queue_type=AMQPQueueType.DURABLE, exchange_name=OmotesQueueNames.omotes_exchange_name(), + 
queue_message_ttl=job_progress_status_queue_ttl ) def submit_job( @@ -232,6 +298,7 @@ def submit_job( callback_on_progress_update: Optional[Callable[[Job, JobProgressUpdate], None]], callback_on_status_update: Optional[Callable[[Job, JobStatusUpdate], None]], auto_disconnect_on_result: bool, + auto_dead_letter_after_ttl: Optional[timedelta] = JOB_RESULT_MESSAGE_TTL ) -> Job: """Submit a new job and connect to progress and status updates and the job result. @@ -249,6 +316,12 @@ def submit_job( :param auto_disconnect_on_result: Remove/disconnect from all queues pertaining to this job once the result is received and handled without exceptions through `callback_on_finished`. + :param auto_dead_letter_after_ttl: When erroneous situations occur (e.g. client is offline), + the job result message (if available) will be dead lettered after the given TTL, + and all queues of this job will be removed subsequently. Default to 48 hours if unset. + Set to `None` to turn off auto dead letter and clean up, but be aware this may lead to + messages and queues to be stored in RabbitMQ indefinitely + (which uses up memory & disk space). :raises UnknownWorkflowException: If `workflow_type` is unknown as a possible workflow in this interface. :return: The job handle which is created. 
class TestQueueMessageTTLArguments(unittest.TestCase):
    """Unit tests for `QueueMessageTTLArguments.to_argument`."""

    def test__to_argument__no_arguments(self) -> None:
        """Without any field set, an empty arguments dictionary is produced."""
        # Arrange / Act
        args = QueueMessageTTLArguments()

        # Assert
        self.assertEqual(args.to_argument(), {})

    def test__to_argument__queue_ttl(self) -> None:
        """A queue TTL is emitted as `x-expires` in milliseconds."""
        # Arrange
        q_ttl = timedelta(seconds=60)

        # Act
        args = QueueMessageTTLArguments(queue_ttl=q_ttl)

        # Assert
        self.assertEqual(args.to_argument(), {"x-expires": 60000})

    def test__to_argument__negative_queue_ttl(self) -> None:
        """A negative queue TTL is rejected."""
        # Arrange
        q_ttl = timedelta(seconds=-60)

        # Act / Assert
        with self.assertRaises(ValueError):
            QueueMessageTTLArguments(queue_ttl=q_ttl).to_argument()

    def test__to_argument__zero_queue_ttl(self) -> None:
        """A zero queue TTL is rejected."""
        # Arrange
        q_ttl = timedelta(seconds=0)

        # Act / Assert
        with self.assertRaises(ValueError):
            QueueMessageTTLArguments(queue_ttl=q_ttl).to_argument()

    def test__to_argument__message_ttl(self) -> None:
        """A message TTL is emitted as `x-message-ttl` in milliseconds."""
        # Arrange
        msg_ttl = timedelta(seconds=30)

        # Act
        args = QueueMessageTTLArguments(message_ttl=msg_ttl)

        # Assert
        self.assertEqual(args.to_argument(), {"x-message-ttl": 30000})

    def test__to_argument__zero_message_ttl(self) -> None:
        """A zero message TTL is accepted, unlike a zero queue TTL."""
        # Arrange
        msg_ttl = timedelta(seconds=0)

        # Act
        args = QueueMessageTTLArguments(message_ttl=msg_ttl)

        # Assert
        self.assertEqual(args.to_argument(), {"x-message-ttl": 0})

    def test__to_argument__negative_message_ttl(self) -> None:
        """A negative message TTL is rejected."""
        # Arrange
        msg_ttl = timedelta(seconds=-30)

        # Act / Assert
        with self.assertRaises(ValueError):
            QueueMessageTTLArguments(message_ttl=msg_ttl).to_argument()

    def test__to_argument__message_ttl_larger_than_queue_ttl(self) -> None:
        """A message TTL that outlives the queue TTL is rejected."""
        # Arrange
        q_ttl = timedelta(seconds=30)
        msg_ttl = timedelta(seconds=60)

        # Act / Assert
        with self.assertRaises(ValueError):
            QueueMessageTTLArguments(
                queue_ttl=q_ttl,
                message_ttl=msg_ttl
            ).to_argument()

    def test__to_argument__message_ttl_equal_to_queue_ttl(self) -> None:
        """A message TTL equal to the queue TTL is the accepted boundary."""
        # Arrange
        ttl = timedelta(seconds=30)

        # Act
        args = QueueMessageTTLArguments(queue_ttl=ttl, message_ttl=ttl)

        # Assert
        self.assertEqual(args.to_argument(), {
            "x-expires": 30000,
            "x-message-ttl": 30000,
        })

    def test__to_argument__dead_letter_routing_key(self) -> None:
        """The dead letter routing key is passed through unchanged."""
        # Arrange
        dl_routing_key = "test-dlq"

        # Act
        args = QueueMessageTTLArguments(dead_letter_routing_key=dl_routing_key)

        # Assert
        self.assertEqual(args.to_argument(), {"x-dead-letter-routing-key": "test-dlq"})

    def test__to_argument__dead_letter_exchange(self) -> None:
        """The dead letter exchange name is passed through unchanged."""
        # Arrange
        dl_exchange = "test-exchange"

        # Act
        args = QueueMessageTTLArguments(dead_letter_exchange=dl_exchange)

        # Assert
        self.assertEqual(args.to_argument(), {"x-dead-letter-exchange": "test-exchange"})

    def test__to_argument__all_arguments(self) -> None:
        """All fields set together produce all four AMQP arguments."""
        # Arrange
        q_ttl = timedelta(minutes=2)
        msg_ttl = timedelta(minutes=1)
        dl_routing_key = "test-dlq"
        dl_exchange = "test-exchange"

        # Act
        args = QueueMessageTTLArguments(
            queue_ttl=q_ttl,
            message_ttl=msg_ttl,
            dead_letter_routing_key=dl_routing_key,
            dead_letter_exchange=dl_exchange
        )

        # Assert
        self.assertEqual(args.to_argument(), {
            "x-expires": 120000,
            "x-message-ttl": 60000,
            "x-dead-letter-routing-key": "test-dlq",
            "x-dead-letter-exchange": "test-exchange"
        })