diff --git a/pyproject.toml b/pyproject.toml index b844ca2..f0fa7c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "nats_queue" -version = "1.1.1" +version = "1.1.3" description = "" authors = ["Kristina Shishkina "] readme = "README.md" diff --git a/src/nats_queue/nats_job.py b/src/nats_queue/nats_job.py index cc283c7..a87bc09 100644 --- a/src/nats_queue/nats_job.py +++ b/src/nats_queue/nats_job.py @@ -12,7 +12,7 @@ def __init__( data: Dict[str, Any] = {}, timeout=None, delay=0, - meta=None, + meta={}, ): for param, param_name in [ (queue_name, "queue_name"), @@ -25,7 +25,7 @@ def __init__( self.queue_name = queue_name self.name = name self.data = data - self.meta = meta or { + self.meta = meta | { "retry_count": 0, "start_time": (datetime.now() + timedelta(seconds=delay)).isoformat(), "timeout": timeout, diff --git a/src/nats_queue/nats_queue.py b/src/nats_queue/nats_queue.py index 3388073..b1542fd 100644 --- a/src/nats_queue/nats_queue.py +++ b/src/nats_queue/nats_queue.py @@ -1,5 +1,7 @@ import logging +from typing import Dict, List, Optional, Union from nats.aio.client import Client +from nats.js.kv import KeyValue from nats.js.errors import BadRequestError from nats.errors import ConnectionClosedError import json @@ -36,6 +38,7 @@ def __init__( self.client = client self.manager = None self.duplicate_window = duplicate_window + self.kv: Optional[KeyValue] = None self.logger: Logger = logger self.logger.info( @@ -53,6 +56,11 @@ async def setup(self): duplicate_window=self.duplicate_window, ) self.logger.info(f"Stream '{self.name}' created successfully.") + + self.kv = await self.manager.create_key_value( + bucket=f"{self.name}_parent_id" + ) + except BadRequestError: self.logger.warning( f"Stream '{self.name}' already exists. 
Attempting to update" @@ -106,3 +114,31 @@ async def addJobs(self, jobs: list[Job], priority: int = 1): for job in jobs: await self.addJob(job, priority) + + async def addFlowJob( + self, tree: Dict[str, Union[List[Job], Job]], priority: int = 1 + ): + async def traverse(node: Dict[str, Union[List[Job], Job]], parent_id=None): + current_job: Job = node["job"] + if parent_id: + current_job.meta["parent_id"] = parent_id + + children = node.get("children", []) + if not children: + return [current_job] + + await self.kv.put( + current_job.id, + json.dumps( + {**current_job.to_dict(), "children_count": len(children)} + ).encode(), + ) + + deepest_jobs = [] + for child in children: + deepest_jobs.extend(await traverse(child, current_job.id)) + + return deepest_jobs + + deepest_jobs = await traverse(tree) + await self.addJobs(deepest_jobs, priority) diff --git a/src/nats_queue/nats_worker.py b/src/nats_queue/nats_worker.py index 9acd185..2536dc4 100644 --- a/src/nats_queue/nats_worker.py +++ b/src/nats_queue/nats_worker.py @@ -9,8 +9,8 @@ from nats_queue.nats_limiter import FixedWindowLimiter, IntervalLimiter from nats.js.client import JetStreamContext from nats.aio.client import Client +from nats.js.kv import KeyValue from nats.aio.msg import Msg -from nats_queue.nats_job import Job from nats.errors import TimeoutError logger = logging.getLogger("nats_queue") @@ -56,6 +56,7 @@ def __init__( self.processing_now: int = 0 self.loop_task: Optional[asyncio.Task] = None self.logger: Logger = logger + self.kv: Optional[KeyValue] = None self.logger.info( ( @@ -69,6 +70,7 @@ async def setup(self): try: self.manager = self.client.jetstream() self.consumers = await self.get_subscriptions() + self.kv = await self.manager.key_value(f"{self.name}_parent_id") except Exception as e: raise e @@ -91,11 +93,13 @@ async def loop(self): while self.running: for consumer in self.consumers: max_jobs = self.limiter.get(self.concurrency - self.processing_now) - if max_jobs == 0: + if max_jobs <= 
0: continue jobs = await self.fetch_messages(consumer, max_jobs) if jobs: break + else: + jobs = [] for job in jobs: self.limiter.inc() @@ -103,10 +107,46 @@ await asyncio.sleep(self.limiter.timeout()) + async def _mark_parents_failed(self, job_data: dict): + parent_id = job_data["meta"].get("parent_id") + if not parent_id: + return + + parent_job = await self.kv.get(parent_id) + if not parent_job: + self.logger.warning( + f"Parent job with ID {parent_id} not found in KV store." + ) + return + + parent_job_data = json.loads(parent_job.value.decode()) + + parent_job_data["meta"]["failed"] = True + await self._publish_parent_job(parent_job_data) + await self._mark_parents_failed(parent_job_data) + + async def _publish_parent_job(self, parent_job_data): + subject = f"{parent_job_data['queue_name']}.{parent_job_data['name']}.1" + job_bytes = json.dumps(parent_job_data).encode() + await self.manager.publish( + subject, job_bytes, headers={"Nats-Msg-Id": parent_job_data["id"]} + ) + self.logger.info( + f"Parent Job id={parent_job_data['id']} " + f"subject={subject} added successfully" + ) + async def _process_task(self, job: Msg): try: self.processing_now += 1 job_data = json.loads(job.data.decode()) + if job_data["meta"].get("failed"): + await job.term() + self.logger.warning( + f"Job: {job_data['name']} id={job_data['id']} failed because " + f"child job did not complete successfully " + ) + job_start_time = datetime.fromisoformat(job_data["meta"]["start_time"]) if job_start_time > datetime.now(): planned_time = job_start_time - datetime.now() @@ -124,8 +164,11 @@ if job_data.get("meta").get("retry_count") > self.max_retries: await job.term() self.logger.warning( - f"Job: {job_data['name']} id={job_data['id']} max retries exceeded" + f"Job: {job_data['name']} id={job_data['id']} " + f"failed max retries exceeded" ) + + await self._mark_parents_failed(job_data) return self.logger.info( @@ -143,27 +186,37 @@ 
async def _process_task(self, job: Msg): f'Job: {job_data["name"]} id={job_data["id"]} is completed' ) + parent_id = job_data["meta"].get("parent_id") + if parent_id: + parent_job_data = json.loads( + (await self.kv.get(parent_id)).value.decode() + ) + parent_job_data["children_count"] -= 1 + await self.kv.put(parent_id, json.dumps(parent_job_data).encode()) + if parent_job_data["children_count"] == 0: + await self.kv.delete(parent_id) + await self._publish_parent_job(parent_job_data) + except Exception as e: if isinstance(e, asyncio.TimeoutError): self.logger.error( - f"Job: {job_data['name']} id={job_data['id']} TimeoutError: {e}" + f"Job: {job_data['name']} id={job_data['id']} " + f"TimeoutError start retry" ) else: - self.logger.error(f"Error while processing job {job_data['id']}: {e}") - + self.logger.error( + f"Error while processing job {job_data['id']}: {e} start retry" + ) + new_id = f"{uuid.uuid4()}_{int(time.time())}" job_data["meta"]["retry_count"] += 1 - new_job = Job( - queue_name=job_data["queue_name"], - name=job_data["name"], - data=job_data["data"], - meta=job_data["meta"], - ) - job_bytes = json.dumps(new_job.to_dict()).encode() + job_data["id"] = new_id + + job_bytes = json.dumps(job_data).encode() await job.term() await self.manager.publish( job.subject, job_bytes, - headers={"Nats-Msg-Id": f"{uuid.uuid4()}_{int(time.time())}"}, + headers={"Nats-Msg-Id": new_id}, ) finally: self.processing_now -= 1 diff --git a/tests/test_job.py b/tests/test_job.py index 6ca51f3..86c061d 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -7,6 +7,7 @@ def test_job_initialization(): queue_name="my_queue", name="task_1", data={"key": "value"}, + meta={"parent_id": "1"}, ) assert job.queue_name == "my_queue" @@ -14,6 +15,7 @@ def test_job_initialization(): assert job.data == {"key": "value"} assert job.meta["retry_count"] == 0 assert job.meta["timeout"] is None + assert job.meta["parent_id"] == "1" def test_job_initialization_with_delay(): diff --git 
a/tests/test_queue.py b/tests/test_queue.py index 3145c11..bfae7b1 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -1,3 +1,4 @@ +from typing import Dict, List, Union import pytest import pytest_asyncio import json @@ -243,3 +244,39 @@ async def test_create_queue_with_one_conf(get_client): queue2 = Queue(client, name="my_queue", duplicate_window=1) await queue2.setup() + + +@pytest.mark.asyncio +async def test_create_flow_job(get_client): + + client = get_client + queue = Queue(client, name="my_queue") + await queue.setup() + + flowJob: Dict[ + str, Union[Job, List[Dict[str, Union[Job, List[Dict[str, Job]]]]]] + ] = { + "job": Job("my_queue", "parent_job"), + "children": [ + { + "job": Job("my_queue", "child_job_1"), + "children": [ + { + "job": Job("my_queue", "child_job_1_1"), + }, + {"job": Job("my_queue", "child_job_1_2")}, + ], + }, + {"job": Job("my_queue", "child_job_2")}, + ], + } + parent_job_id = [flowJob["job"].id, flowJob["children"][0]["job"].id] + + await queue.addFlowJob(flowJob) + key_value = await queue.manager.key_value(f"{queue.name}_parent_id") + kv_keys = await key_value.keys() + assert set(kv_keys) == set(parent_job_id) + + stream_info = await queue.manager.stream_info(queue.name) + messages = stream_info.state.messages + assert messages == 3 diff --git a/tests/test_workers.py b/tests/test_workers.py index 988d8f4..ef01063 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -1,5 +1,5 @@ import json -from typing import Dict +from typing import Dict, List, Union import pytest import asyncio import pytest_asyncio @@ -10,6 +10,7 @@ from nats.aio.client import Client from nats.js.client import JetStreamContext import nats +from nats.js.errors import NoKeysError @pytest_asyncio.fixture @@ -62,10 +63,11 @@ async def test_worker_initialization(get_client): assert worker.processing_now == 0 assert worker.loop_task is None assert worker.consumers is None + assert worker.kv is None @pytest.mark.asyncio -async def 
test_worker_connect_success(get_client): +async def test_worker_setup_success(get_client): client: Client = get_client queue = Queue(client, "my_queue") await queue.setup() @@ -82,7 +84,7 @@ async def test_worker_connect_success(get_client): @pytest.mark.asyncio -async def test_worker_connect_faild(get_client): +async def test_worker_setup_faild(get_client): client: Client = get_client queue = Queue(client, "my_queue") await queue.setup() @@ -96,6 +98,21 @@ async def test_worker_connect_faild(get_client): await worker.setup() +@pytest.mark.asyncio +async def test_worker_setup_unknow_queue(get_client): + client: Client = get_client + queue = Queue(client, "my_queue") + await queue.setup() + + worker = Worker( + client, + name="my_queue_1", + processor=process_job, + ) + with pytest.raises(Exception): + await worker.setup() + + @pytest.mark.asyncio async def test_worker_connect_stop_success(get_client): client: Client = get_client @@ -526,12 +543,14 @@ async def test_worker_start_one_worker(get_client): await worker.stop() stream_info = await worker.manager.streams_info() - assert len(stream_info) == 1 - assert stream_info[0].config.subjects == ["my_queue.*.*"] - stream = stream_info[0].config.name + assert len(stream_info) == 2 + stream_name = [stream.config.name for stream in stream_info] + assert set(stream_name) == set(["my_queue", f"KV_{queue.name}_parent_id"]) + assert stream_info[1].config.subjects == ["my_queue.*.*"] + stream = stream_info[1].config.name assert stream == queue.name - worker_count = stream_info[0].state.consumer_count + worker_count = stream_info[1].state.consumer_count assert worker_count == 1 worker_info = await worker.manager.consumers_info(stream) @@ -580,25 +599,25 @@ async def test_worker_start_many_worker_with_one_durable(get_client): await worker2.stop() stream_info = await worker.manager.streams_info() - assert len(stream_info) == 2 + assert len(stream_info) == 4 stream_name = [] for stream in stream_info: - 
stream_name.extend(stream.config.subjects) - assert stream_name == [ - "my_queue.*.*", - "my_queue_2.*.*", - ] + stream_name.append(stream.config.name) + assert set(stream_name) == set( + [ + "KV_my_queue_2_parent_id", + "KV_my_queue_parent_id", + "my_queue", + "my_queue_2", + ] + ) - stream = stream_info[0].config.name - assert stream == queue.name worker_count = 0 for stream in stream_info: worker_count += stream.state.consumer_count assert worker_count == 2 - worker_info1 = await worker.manager.consumers_info(queue.name) assert len(worker_info1) == 1 - worker_name1 = worker_info1[0].name assert worker_name1 == "worker_group_1" @@ -613,8 +632,8 @@ async def test_worker_start_many_worker_with_one_durable(get_client): @pytest.mark.asyncio async def test_worker_start(get_client): - nc = get_client - queue = Queue(nc, name="my_queue") + client = get_client + queue = Queue(client, name="my_queue") await queue.setup() job = Job(queue_name="my_queue", name="task_1", data={"key": "value"}, timeout=1) @@ -622,7 +641,7 @@ async def test_worker_start(get_client): await queue.addJobs([job, job2]) worker = Worker( - nc, + client, name="my_queue", concurrency=3, processor=process_job, @@ -632,7 +651,145 @@ async def test_worker_start(get_client): await asyncio.sleep(4) await worker.stop() stream_info = await worker.manager.streams_info() - assert len(stream_info) == 1 + assert len(stream_info) == 2 + + +@pytest.mark.asyncio +async def test_publish_parent_job(get_client): + client = get_client + queue = Queue(client, name="my_queue") + await queue.setup() + + job = Job( + queue_name="my_queue", + name="parent_job", + data={"key": "value"}, + timeout=1, + meta={"parent_id": "1"}, + ) + + worker = Worker( + client, + name="my_queue", + concurrency=3, + processor=process_job, + ) + await worker.setup() + await worker._publish_parent_job(job.to_dict()) + + job_data = json.loads( + ( + await worker.manager.get_last_msg(queue.name, f"{queue.name}.*.*") + ).data.decode() + ) + assert 
job.name == job_data["name"] + + +@pytest.mark.asyncio +async def test__mark_parents_failed(get_client): + client = get_client + queue = Queue(client, name="my_queue") + await queue.setup() + job_1 = Job(queue_name="my_queue", name="parent_job") + job_2 = Job(queue_name="my_queue", name="child_job_1", meta={"parent_id": job_1.id}) + job_3 = Job(queue_name="my_queue", name="child_job_2", meta={"parent_id": job_2.id}) + + await queue.kv.put( + job_1.id, + json.dumps({**job_1.to_dict(), "children_count": 1}).encode(), + ) + await queue.kv.put( + job_2.id, + json.dumps({**job_2.to_dict(), "children_count": 1}).encode(), + ) + + worker = Worker( + client, + name="my_queue", + concurrency=3, + processor=process_job, + ) + await worker.setup() + + await worker._mark_parents_failed(job_3.to_dict()) + + sub = await worker.manager.pull_subscribe(f"{queue.name}.*.*") + msg = await sub.fetch(5) + job_data = [json.loads(job.data.decode()) for job in msg] + + assert job_data[0]["name"] == "child_job_1" + assert job_data[0]["meta"]["failed"] is True + + assert job_data[1]["name"] == "parent_job" + assert job_data[1]["meta"]["failed"] is True + + +@pytest.mark.asyncio +async def test_process_task_with_flow_job(get_client): + client = get_client + queue = Queue(client, name="my_queue") + await queue.setup() + + parent_job = Job("my_queue", "parent_job") + child_job_1 = Job("my_queue", "child_job_1") + child_job_1_1 = Job("my_queue", "child_job_1_1") + child_job_1_2 = Job("my_queue", "child_job_1_2") + child_job_1_2_1 = Job("my_queue", "child_job_1_2_1") + child_job_1_2_2 = Job("my_queue", "child_job_1_2_2") + child_job_2 = Job("my_queue", "child_job_2") + + flowJob: Dict[ + str, Union[Job, List[Dict[str, Union[Job, List[Dict[str, Job]]]]]] + ] = { + "job": parent_job, + "children": [ + { + "job": child_job_1, + "children": [ + {"job": child_job_1_1}, + { + "job": child_job_1_2, + "children": [ + {"job": child_job_1_2_1}, + {"job": child_job_1_2_2}, + ], + }, + ], + }, + {"job": 
child_job_2}, + ], + } + + await queue.addFlowJob(flowJob) + + keys = await queue.kv.keys() + + assert set(keys) == set( + [ + parent_job.id, + child_job_1.id, + child_job_1_2.id, + ], + ) + + messages = (await queue.manager.stream_info(queue.name)).state.messages + assert messages == 4 + + worker = Worker( + client, + name="my_queue", + concurrency=3, + processor=process_job, + limiter={"max": 2, "duration": 6}, + ) + + await worker.setup() + await worker.start() + await asyncio.sleep(10) + await worker.stop() + + with pytest.raises(NoKeysError): + await queue.kv.keys() async def process_job(job_data: Dict):