diff --git a/arq/connections.py b/arq/connections.py index c5dd0fc4..250e3d36 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -13,8 +13,8 @@ from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError -from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix -from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job +from .constants import default_queue_name, expires_extra_ms, in_progress_key_prefix, job_key_prefix, result_key_prefix +from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, deserialize_job_raw, serialize_job from .utils import timestamp_ms, to_ms, to_unix_ms logger = logging.getLogger('arq.connections') @@ -126,6 +126,8 @@ async def enqueue_job( _defer_by: Union[None, int, float, timedelta] = None, _expires: Union[None, int, float, timedelta] = None, _job_try: Optional[int] = None, + _debounce: bool = False, + _debounce_max: Union[None, int, float, timedelta] = None, **kwargs: Any, ) -> Optional[Job]: """ @@ -140,9 +142,15 @@ async def enqueue_job( :param _expires: do not start or retry a job after this duration; defaults to 24 hours plus deferring time, if any :param _job_try: useful when re-enqueueing jobs within a job + :param _debounce: if True and a queued job with the same ID exists, update its defer time + instead of returning None + :param _debounce_max: maximum total time from the original enqueue time before debouncing + stops and the job is allowed to run :param kwargs: any keyword arguments to pass to the function :return: :class:`arq.jobs.Job` instance or ``None`` if a job with this ID already exists """ + if _debounce and not _job_id: + raise RuntimeError("'_debounce' requires '_job_id' to be set") if _queue_name is None: _queue_name = self.default_queue_name job_id = _job_id or uuid4().hex @@ -152,22 +160,36 @@ async def enqueue_job( defer_by_ms = to_ms(_defer_by) expires_ms = to_ms(_expires) + debounce_max_ms = to_ms(_debounce_max) async with self.pipeline(transaction=True) as pipe: await pipe.watch(job_key) - if await pipe.exists(job_key, result_key_prefix + job_id): + job_exists = await pipe.exists(job_key) + result_exists = await pipe.exists(result_key_prefix + job_id) + in_progress = await pipe.exists(in_progress_key_prefix + job_id) if _debounce else False + can_debounce = _debounce and job_exists and not result_exists and not in_progress + + if (job_exists or result_exists) and not can_debounce: await pipe.reset() return None - enqueue_time_ms = timestamp_ms() + now_ms = timestamp_ms() + if can_debounce: + existing_job_data = await pipe.get(job_key) + _, _, _, _, enqueue_time_ms = deserialize_job_raw(existing_job_data, deserializer=self.job_deserializer) + if debounce_max_ms is not None and now_ms - enqueue_time_ms >= debounce_max_ms: + await pipe.reset() + return None + else: + enqueue_time_ms = now_ms if _defer_until is not None: score = to_unix_ms(_defer_until) elif defer_by_ms: - score = enqueue_time_ms + defer_by_ms + score = now_ms + defer_by_ms else: - score = enqueue_time_ms + score = now_ms - expires_ms = expires_ms or score - enqueue_time_ms + self.expires_extra_ms + expires_ms = expires_ms or score - now_ms + self.expires_extra_ms job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) pipe.multi() diff --git a/tests/test_main.py b/tests/test_main.py index 96f6beac..6c1d5d45 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -283,6 +283,76 @@ async def test_get_jobs(arq_redis: ArqRedis): assert isinstance(jobs[2], JobDef) +async def test_debounce_requires_job_id(arq_redis: ArqRedis): + # when + with pytest.raises(RuntimeError, match="'_debounce' requires '_job_id'"): + await arq_redis.enqueue_job('foobar', _debounce=True) + + +async def test_debounce_updates_defer_time(arq_redis: ArqRedis): + # given: a job already enqueued with a defer + j1 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _defer_by=5) + assert isinstance(j1, Job) + score1 = await arq_redis.zscore(default_queue_name, 'debounce_id') + + await asyncio.sleep(0.05) + + # when: we enqueue the same job with debounce and the same defer + j2 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _debounce=True, _defer_by=5) + + # then: the job is returned (not None) and the score is updated (deferred from now) + assert isinstance(j2, Job) + score2 = await arq_redis.zscore(default_queue_name, 'debounce_id') + assert score2 > score1 + + +async def test_debounce_preserves_enqueue_time(arq_redis: ArqRedis): + # given: a job enqueued + j1 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _defer_by=5) + assert isinstance(j1, Job) + info1 = await j1.info() + + await asyncio.sleep(0.05) + + # when: debounced + j2 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _debounce=True, _defer_by=10) + assert isinstance(j2, Job) + info2 = await j2.info() + + # then: enqueue_time is preserved from the original job + assert info2.enqueue_time == info1.enqueue_time + + +async def test_debounce_max_stops_debouncing(arq_redis: ArqRedis): + # given: a job enqueued with a very short debounce_max + j1 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _defer_by=5) + assert isinstance(j1, Job) + + # when: we wait longer than debounce_max and try to debounce + await asyncio.sleep(0.1) + j2 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _debounce=True, _defer_by=10, _debounce_max=0.05) + + # then: debounce is refused, returns None (let existing job run) + assert j2 is None + + +async def test_debounce_does_not_touch_in_progress_job(arq_redis: ArqRedis): + # given: a job that is in progress (has in_progress key) + from arq.constants import in_progress_key_prefix + + await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _defer_by=5) + score_before = await arq_redis.zscore(default_queue_name, 'debounce_id') + await arq_redis.set(in_progress_key_prefix + 'debounce_id', b'1') + + # when: we try to debounce + j2 = await arq_redis.enqueue_job('foobar', _job_id='debounce_id', _debounce=True, _defer_by=10) + + # then: returns None, job score is unchanged + assert j2 is None + score_after = await arq_redis.zscore(default_queue_name, 'debounce_id') + assert score_after == score_before + + async def test_enqueue_multiple(arq_redis: ArqRedis, caplog): caplog.set_level(logging.DEBUG) results = await asyncio.gather(*[arq_redis.enqueue_job('foobar', i, _job_id='testing') for i in range(10)])