diff --git a/state-manager/app/controller/get_runs.py b/state-manager/app/controller/get_runs.py index 93b24b3d..f4549b56 100644 --- a/state-manager/app/controller/get_runs.py +++ b/state-manager/app/controller/get_runs.py @@ -1,5 +1,3 @@ -import asyncio -from beanie.operators import In, NotIn from ..models.run_models import RunsResponse, RunListItem, RunStatusEnum from ..models.db.state import State @@ -9,40 +7,129 @@ logger = LogsManager().get_logger() -async def get_run_status(run_id: str) -> RunStatusEnum: - if await State.find(State.run_id == run_id, In(State.status, [StateStatusEnum.ERRORED, StateStatusEnum.NEXT_CREATED_ERROR])).count() > 0: - return RunStatusEnum.FAILED - elif await State.find(State.run_id == run_id, NotIn(State.status, [StateStatusEnum.SUCCESS, StateStatusEnum.RETRY_CREATED, StateStatusEnum.PRUNED])).count() == 0: - return RunStatusEnum.SUCCESS - else: - return RunStatusEnum.PENDING - -async def get_run_info(run: Run) -> RunListItem: - return RunListItem( - run_id=run.run_id, - graph_name=run.graph_name, - success_count=await State.find(State.run_id == run.run_id, In(State.status, [StateStatusEnum.SUCCESS, StateStatusEnum.PRUNED])).count(), - pending_count=await State.find(State.run_id == run.run_id, In(State.status, [StateStatusEnum.CREATED, StateStatusEnum.QUEUED, StateStatusEnum.EXECUTED])).count(), - errored_count=await State.find(State.run_id == run.run_id, In(State.status, [StateStatusEnum.ERRORED, StateStatusEnum.NEXT_CREATED_ERROR])).count(), - retried_count=await State.find(State.run_id == run.run_id, State.status == StateStatusEnum.RETRY_CREATED).count(), - total_count=await State.find(State.run_id == run.run_id,).count(), - status=await get_run_status(run.run_id), - created_at=run.created_at - ) - - async def get_runs(namespace_name: str, page: int, size: int, x_exosphere_request_id: str) -> RunsResponse: try: logger.info(f"Getting runs for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) runs = await Run.find(Run.namespace_name == namespace_name).sort(-Run.created_at).skip((page - 1) * size).limit(size).to_list() # type: ignore + + if len(runs) == 0: + return RunsResponse( + namespace=namespace_name, + total=await Run.find(Run.namespace_name == namespace_name).count(), + page=page, + size=size, + runs=[] + ) + + look_up_table = { + run.run_id: run for run in runs + } + viewed = set() + + + data_cursor = await State.get_pymongo_collection().aggregate( + [ + { + "$match": { + "run_id": { + "$in": [run.run_id for run in runs] + } + } + }, + { + "$group": { + "_id": "$run_id", + "total_count": { + "$sum": 1 + }, + "success_count": { + "$sum": { + "$cond": { + "if": {"$in": ["$status", [StateStatusEnum.SUCCESS, StateStatusEnum.PRUNED]]}, + "then": 1, + "else": 0 + } + } + }, + "pending_count": { + "$sum": { + "$cond": { + "if": {"$in": ["$status", [StateStatusEnum.CREATED, StateStatusEnum.QUEUED, StateStatusEnum.EXECUTED]]}, + "then": 1, + "else": 0 + } + } + }, + "errored_count": { + "$sum": { + "$cond": { + "if": {"$in": ["$status", [StateStatusEnum.ERRORED, StateStatusEnum.NEXT_CREATED_ERROR]]}, + "then": 1, + "else": 0 + } + } + }, + "retried_count": { + "$sum": { + "$cond": { + "if": {"$eq": ["$status", StateStatusEnum.RETRY_CREATED]}, + "then": 1, + "else": 0 + } + } + } + } + } + ] + ) + data = await data_cursor.to_list() + runs = [] + for run in data: + success_count = run["success_count"] + pending_count = run["pending_count"] + errored_count = run["errored_count"] + retried_count = run["retried_count"] + + runs.append( + RunListItem( + run_id=run["_id"], + graph_name=look_up_table[run["_id"]].graph_name, + success_count=success_count, + pending_count=pending_count, + errored_count=errored_count, + retried_count=retried_count, + total_count=run["total_count"], + status=RunStatusEnum.PENDING if pending_count > 0 else RunStatusEnum.FAILED if errored_count > 0 else RunStatusEnum.SUCCESS, + created_at=look_up_table[run["_id"]].created_at + ) + ) + viewed.add(run["_id"]) + + if len(look_up_table) > 0: + for run_id in look_up_table: + if run_id not in viewed: + runs.append( + RunListItem( + run_id=run_id, + graph_name=look_up_table[run_id].graph_name, + success_count=0, + pending_count=0, + errored_count=0, + retried_count=0, + total_count=0, + status=RunStatusEnum.FAILED, + created_at=look_up_table[run_id].created_at + ) + ) + return RunsResponse( namespace=namespace_name, total=await Run.find(Run.namespace_name == namespace_name).count(), page=page, size=size, - runs=await asyncio.gather(*[get_run_info(run) for run in runs]) + runs=sorted(runs, key=lambda x: x.created_at, reverse=True) ) except Exception as e: diff --git a/state-manager/app/controller/trigger_graph.py b/state-manager/app/controller/trigger_graph.py index e3ee4ffd..a57a44b9 100644 --- a/state-manager/app/controller/trigger_graph.py +++ b/state-manager/app/controller/trigger_graph.py @@ -61,7 +61,8 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph ) for key, value in body.store.items() ] - await Store.insert_many(new_stores) + if len(new_stores) > 0: + await Store.insert_many(new_stores) root = graph_template.get_root_node() diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 05441ec3..2364bc3a 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -93,5 +93,12 @@ class Settings: ], unique=True, name="uniq_fanout_retry" + ), + IndexModel( + [ + ("run_id", 1), + ("status", 1), + ], + name="run_id_status_index" ) ] \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_get_runs.py b/state-manager/tests/unit/controller/test_get_runs.py index 455e8d53..c146e10d 100644 --- a/state-manager/tests/unit/controller/test_get_runs.py +++ b/state-manager/tests/unit/controller/test_get_runs.py @@ -2,194 +2,10 @@ from unittest.mock import AsyncMock, MagicMock, patch from datetime import datetime -from app.controller.get_runs import get_runs, get_run_status, get_run_info +from app.controller.get_runs import get_runs from app.models.db.run import Run -from app.models.run_models import RunsResponse, RunListItem, RunStatusEnum - - -class TestGetRunStatus: - """Test cases for get_run_status function""" - - @pytest.mark.asyncio - async def test_get_run_status_failed(self): - """Test get_run_status returns FAILED when there are errored states""" - run_id = "test_run_id" - - with patch('app.controller.get_runs.State') as mock_state_class: - # Mock count to return > 0 for errored states - mock_state_class.find.return_value.count = AsyncMock(return_value=1) - - result = await get_run_status(run_id) - - assert result == RunStatusEnum.FAILED - mock_state_class.find.assert_called_once() - - @pytest.mark.asyncio - async def test_get_run_status_success(self): - """Test get_run_status returns SUCCESS when all states are completed""" - run_id = "test_run_id" - - with patch('app.controller.get_runs.State') as mock_state_class: - # Mock count to return 0 for errored states and 0 for pending states - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=0)), # First call for errored states - MagicMock(count=AsyncMock(return_value=0)) # Second call for pending states - ] - - result = await get_run_status(run_id) - - assert result == RunStatusEnum.SUCCESS - assert mock_state_class.find.call_count == 2 - - @pytest.mark.asyncio - async def test_get_run_status_pending(self): - """Test get_run_status returns PENDING when there are pending states""" - run_id = "test_run_id" - - with patch('app.controller.get_runs.State') as mock_state_class: - # Mock count to return 0 for errored states but > 0 for pending states - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=0)), # First call for errored states - MagicMock(count=AsyncMock(return_value=1)) # Second call for pending states - ] - - result = await get_run_status(run_id) - - assert result == RunStatusEnum.PENDING - assert mock_state_class.find.call_count == 2 - - @pytest.mark.asyncio - async def test_get_run_status_multiple_errored_states(self): - """Test get_run_status with multiple errored states""" - run_id = "test_run_id" - - with patch('app.controller.get_runs.State') as mock_state_class: - mock_state_class.find.return_value.count = AsyncMock(return_value=5) - - result = await get_run_status(run_id) - - assert result == RunStatusEnum.FAILED - - @pytest.mark.asyncio - async def test_get_run_status_mixed_states(self): - """Test get_run_status with mixed state statuses""" - run_id = "test_run_id" - - with patch('app.controller.get_runs.State') as mock_state_class: - # Mock count to return 0 for errored states but > 0 for pending states - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=0)), # First call for errored states - MagicMock(count=AsyncMock(return_value=3)) # Second call for pending states - ] - - result = await get_run_status(run_id) - - assert result == RunStatusEnum.PENDING - - -class TestGetRunInfo: - """Test cases for get_run_info function""" - - @pytest.fixture - def mock_run(self): - """Create a mock Run object""" - run = MagicMock(spec=Run) - run.run_id = "test_run_id" - run.graph_name = "test_graph" - run.created_at = datetime.now() - return run - - @pytest.mark.asyncio - async def test_get_run_info_success(self, mock_run): - """Test get_run_info returns correct RunListItem""" - with patch('app.controller.get_runs.State') as mock_state_class: - # Mock different count queries - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=5)), # success_count - MagicMock(count=AsyncMock(return_value=2)), # pending_count - MagicMock(count=AsyncMock(return_value=0)), # errored_count - MagicMock(count=AsyncMock(return_value=1)), # retried_count - MagicMock(count=AsyncMock(return_value=8)), # total_count - ] - - with patch('app.controller.get_runs.get_run_status') as mock_get_status: - mock_get_status.return_value = RunStatusEnum.SUCCESS - - result = await get_run_info(mock_run) - - assert isinstance(result, RunListItem) - assert result.run_id == "test_run_id" - assert result.graph_name == "test_graph" - assert result.success_count == 5 - assert result.pending_count == 2 - assert result.errored_count == 0 - assert result.retried_count == 1 - assert result.total_count == 8 - assert result.status == RunStatusEnum.SUCCESS - assert result.created_at == mock_run.created_at - - @pytest.mark.asyncio - async def test_get_run_info_with_errored_states(self, mock_run): - """Test get_run_info with errored states""" - with patch('app.controller.get_runs.State') as mock_state_class: - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=3)), # success_count - MagicMock(count=AsyncMock(return_value=1)), # pending_count - MagicMock(count=AsyncMock(return_value=2)), # errored_count - MagicMock(count=AsyncMock(return_value=0)), # retried_count - MagicMock(count=AsyncMock(return_value=6)), # total_count - ] - - with patch('app.controller.get_runs.get_run_status') as mock_get_status: - mock_get_status.return_value = RunStatusEnum.FAILED - - result = await get_run_info(mock_run) - - assert result.errored_count == 2 - assert result.status == RunStatusEnum.FAILED - - @pytest.mark.asyncio - async def test_get_run_info_with_pending_states(self, mock_run): - """Test get_run_info with pending states""" - with patch('app.controller.get_runs.State') as mock_state_class: - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=2)), # success_count - MagicMock(count=AsyncMock(return_value=4)), # pending_count - MagicMock(count=AsyncMock(return_value=0)), # errored_count - MagicMock(count=AsyncMock(return_value=1)), # retried_count - MagicMock(count=AsyncMock(return_value=7)), # total_count - ] - - with patch('app.controller.get_runs.get_run_status') as mock_get_status: - mock_get_status.return_value = RunStatusEnum.PENDING - - result = await get_run_info(mock_run) - - assert result.pending_count == 4 - assert result.status == RunStatusEnum.PENDING - - @pytest.mark.asyncio - async def test_get_run_info_zero_counts(self, mock_run): - """Test get_run_info with zero counts""" - with patch('app.controller.get_runs.State') as mock_state_class: - mock_state_class.find.side_effect = [ - MagicMock(count=AsyncMock(return_value=0)), # success_count - MagicMock(count=AsyncMock(return_value=0)), # pending_count - MagicMock(count=AsyncMock(return_value=0)), # errored_count - MagicMock(count=AsyncMock(return_value=0)), # retried_count - MagicMock(count=AsyncMock(return_value=0)), # total_count - ] - - with patch('app.controller.get_runs.get_run_status') as mock_get_status: - mock_get_status.return_value = RunStatusEnum.SUCCESS - - result = await get_run_info(mock_run) - - assert result.success_count == 0 - assert result.pending_count == 0 - assert result.errored_count == 0 - assert result.retried_count == 0 - assert result.total_count == 0 +from app.models.run_models import RunsResponse, RunStatusEnum +from app.models.state_status_enum import StateStatusEnum class TestGetRuns: @@ -211,43 +27,71 @@ def mock_runs(self): run = MagicMock(spec=Run) run.run_id = f"run_{i}" run.graph_name = f"graph_{i}" - run.created_at = datetime.now() + run.created_at = datetime(2024, 1, 15, 10 + i, 30, 0) runs.append(run) return runs + @pytest.fixture + def mock_aggregation_data(self): + """Create mock aggregation data that matches the MongoDB aggregation pipeline output""" + return [ + { + "_id": "run_0", + "total_count": 8, + "success_count": 5, + "pending_count": 2, + "errored_count": 0, + "retried_count": 1 + }, + { + "_id": "run_1", + "total_count": 6, + "success_count": 3, + "pending_count": 0, + "errored_count": 2, + "retried_count": 1 + }, + { + "_id": "run_2", + "total_count": 4, + "success_count": 4, + "pending_count": 0, + "errored_count": 0, + "retried_count": 0 + } + ] + @pytest.mark.asyncio - async def test_get_runs_success(self, mock_namespace, mock_request_id, mock_runs): - """Test successful retrieval of runs""" + async def test_get_runs_success(self, mock_namespace, mock_request_id, mock_runs, mock_aggregation_data): + """Test successful retrieval of runs with aggregation data""" page = 1 size = 10 with patch('app.controller.get_runs.Run') as mock_run_class, \ - patch('app.controller.get_runs.get_run_info') as mock_get_run_info, \ + patch('app.controller.get_runs.State') as mock_state_class, \ patch('app.controller.get_runs.logger') as mock_logger: - # Mock the entire query chain for runs list + # Mock the Run query chain for the main runs list mock_query_chain = MagicMock() mock_query_chain.to_list = AsyncMock(return_value=mock_runs) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain - # Mock the count query separately + # Mock the count query for total calculation mock_count_query = MagicMock() mock_count_query.count = AsyncMock(return_value=25) - # Use side_effect to return different mocks for different calls mock_run_class.find.side_effect = [ mock_run_class.find.return_value, # First call for runs list mock_count_query # Second call for count ] - # Mock get_run_info for each run - mock_run_items = [] - for i, run in enumerate(mock_runs): - mock_item = MagicMock(spec=RunListItem) - mock_item.run_id = run.run_id - mock_item.graph_name = run.graph_name - mock_run_items.append(mock_item) - - mock_get_run_info.side_effect = mock_run_items + # Mock the State aggregation pipeline with cursor approach + mock_collection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=mock_aggregation_data) + # Mock aggregate to return an awaitable cursor since source code awaits the entire expression + mock_collection.aggregate = AsyncMock(return_value=mock_cursor) + # Mock get_pymongo_collection to return a mock collection + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) result = await get_runs(mock_namespace, page, size, mock_request_id) @@ -259,6 +103,16 @@ async def test_get_runs_success(self, mock_namespace, mock_request_id, mock_runs assert result.size == size assert len(result.runs) == 3 + # Verify the runs are sorted by created_at in descending order + assert result.runs[0].created_at == mock_runs[2].created_at # Most recent first + assert result.runs[2].created_at == mock_runs[0].created_at # Oldest last + + # Verify aggregation pipeline was called correctly + mock_collection.aggregate.assert_called_once() + aggregate_call = mock_collection.aggregate.call_args[0][0] + assert len(aggregate_call) == 2 + assert aggregate_call[0]["$match"]["run_id"]["$in"] == ["run_0", "run_1", "run_2"] + # Verify logging mock_logger.info.assert_called_once_with( f"Getting runs for namespace {mock_namespace}", @@ -266,21 +120,51 @@ async def test_get_runs_success(self, mock_namespace, mock_request_id, mock_runs ) @pytest.mark.asyncio - async def test_get_runs_pagination(self, mock_namespace, mock_request_id, mock_runs): + async def test_get_runs_empty_result(self, mock_namespace, mock_request_id): + """Test get_runs when no runs are found""" + page = 1 + size = 10 + + with patch('app.controller.get_runs.Run') as mock_run_class, \ + patch('app.controller.get_runs.logger') as _: + + # Mock the Run query chain to return empty list + mock_query_chain = MagicMock() + mock_query_chain.to_list = AsyncMock(return_value=[]) + mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain + + # Mock the count query for total calculation when no runs are found + mock_count_query = MagicMock() + mock_count_query.count = AsyncMock(return_value=0) + mock_run_class.find.side_effect = [ + mock_run_class.find.return_value, # First call for runs list + mock_count_query # Second call for count + ] + + result = await get_runs(mock_namespace, page, size, mock_request_id) + + assert result.runs == [] + assert result.total == 0 + assert result.namespace == mock_namespace + assert result.page == page + assert result.size == size + + @pytest.mark.asyncio + async def test_get_runs_pagination(self, mock_namespace, mock_request_id, mock_runs, mock_aggregation_data): """Test get_runs with different pagination parameters""" page = 2 size = 5 with patch('app.controller.get_runs.Run') as mock_run_class, \ - patch('app.controller.get_runs.get_run_info') as mock_get_run_info, \ + patch('app.controller.get_runs.State') as mock_state_class, \ patch('app.controller.get_runs.logger') as _: - # Mock the entire query chain for runs list + # Mock the Run query chain mock_query_chain = MagicMock() mock_query_chain.to_list = AsyncMock(return_value=mock_runs) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain - # Mock the count query separately + # Mock the count query mock_count_query = MagicMock() mock_count_query.count = AsyncMock(return_value=15) mock_run_class.find.side_effect = [ @@ -288,40 +172,176 @@ async def test_get_runs_pagination(self, mock_namespace, mock_request_id, mock_r mock_count_query # Second call for count ] - mock_get_run_info.side_effect = [MagicMock(spec=RunListItem) for _ in mock_runs] + # Mock the State aggregation pipeline with cursor approach + mock_collection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=mock_aggregation_data) + # Mock aggregate to return an awaitable cursor since source code awaits the entire expression + mock_collection.aggregate = AsyncMock(return_value=mock_cursor) + # Mock get_pymongo_collection to return a mock collection + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) result = await get_runs(mock_namespace, page, size, mock_request_id) assert result.page == page assert result.size == size assert result.total == 15 + assert len(result.runs) == 3 @pytest.mark.asyncio - async def test_get_runs_empty_result(self, mock_namespace, mock_request_id): - """Test get_runs when no runs are found""" + async def test_get_runs_with_missing_states(self, mock_namespace, mock_request_id, mock_runs): + """Test get_runs when some runs have no states in the aggregation""" page = 1 size = 10 + # Only first two runs have aggregation data + mock_aggregation_data = [ + { + "_id": "run_0", + "total_count": 5, + "success_count": 3, + "pending_count": 1, + "errored_count": 0, + "retried_count": 1 + }, + { + "_id": "run_1", + "total_count": 3, + "success_count": 2, + "pending_count": 0, + "errored_count": 1, + "retried_count": 0 + } + # run_2 has no aggregation data + ] + with patch('app.controller.get_runs.Run') as mock_run_class, \ + patch('app.controller.get_runs.State') as mock_state_class, \ patch('app.controller.get_runs.logger') as _: - # Mock the entire query chain for runs list + # Mock the Run query chain mock_query_chain = MagicMock() - mock_query_chain.to_list = AsyncMock(return_value=[]) + mock_query_chain.to_list = AsyncMock(return_value=mock_runs) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain - # Mock the count query separately + # Mock the count query mock_count_query = MagicMock() - mock_count_query.count = AsyncMock(return_value=0) + mock_count_query.count = AsyncMock(return_value=15) mock_run_class.find.side_effect = [ mock_run_class.find.return_value, # First call for runs list mock_count_query # Second call for count ] + # Mock the State aggregation pipeline with cursor approach + mock_collection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=mock_aggregation_data) + # Mock aggregate to return an awaitable cursor since source code awaits the entire expression + mock_collection.aggregate = AsyncMock(return_value=mock_cursor) + # Mock get_pymongo_collection to return a mock collection + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + result = await get_runs(mock_namespace, page, size, mock_request_id) - assert result.runs == [] - assert result.total == 0 + assert len(result.runs) == 3 + + # Check that runs with aggregation data have correct counts + run_0 = next(r for r in result.runs if r.run_id == "run_0") + assert run_0.total_count == 5 + assert run_0.success_count == 3 + assert run_0.pending_count == 1 + assert run_0.errored_count == 0 + assert run_0.retried_count == 1 + assert run_0.status == RunStatusEnum.PENDING # Has pending states + + run_1 = next(r for r in result.runs if r.run_id == "run_1") + assert run_1.total_count == 3 + assert run_1.success_count == 2 + assert run_1.pending_count == 0 + assert run_1.errored_count == 1 + assert run_1.retried_count == 0 + assert run_1.status == RunStatusEnum.FAILED # Has errored states + + # Check that run_2 (no aggregation data) has zero counts and FAILED status + run_2 = next(r for r in result.runs if r.run_id == "run_2") + assert run_2.total_count == 0 + assert run_2.success_count == 0 + assert run_2.pending_count == 0 + assert run_2.errored_count == 0 + assert run_2.retried_count == 0 + assert run_2.status == RunStatusEnum.FAILED + + @pytest.mark.asyncio + async def test_get_runs_status_calculation(self, mock_namespace, mock_request_id, mock_runs): + """Test that run status is calculated correctly based on state counts""" + page = 1 + size = 10 + + # Test different status scenarios + mock_aggregation_data = [ + { + "_id": "run_0", + "total_count": 5, + "success_count": 5, + "pending_count": 0, + "errored_count": 0, + "retried_count": 0 + }, + { + "_id": "run_1", + "total_count": 3, + "success_count": 1, + "pending_count": 2, + "errored_count": 0, + "retried_count": 0 + }, + { + "_id": "run_2", + "total_count": 4, + "success_count": 2, + "pending_count": 0, + "errored_count": 2, + "retried_count": 0 + } + ] + + with patch('app.controller.get_runs.Run') as mock_run_class, \ + patch('app.controller.get_runs.State') as mock_state_class, \ + patch('app.controller.get_runs.logger') as _: + + # Mock the Run query chain + mock_query_chain = MagicMock() + mock_query_chain.to_list = AsyncMock(return_value=mock_runs) + mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain + + # Mock the count query + mock_count_query = MagicMock() + mock_count_query.count = AsyncMock(return_value=15) + mock_run_class.find.side_effect = [ + mock_run_class.find.return_value, # First call for runs list + mock_count_query # Second call for count + ] + + # Mock the State aggregation pipeline with cursor approach + mock_collection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=mock_aggregation_data) + # Mock aggregate to return an awaitable cursor since source code awaits the entire expression + mock_collection.aggregate = AsyncMock(return_value=mock_cursor) + # Mock get_pymongo_collection to return a mock collection + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + + result = await get_runs(mock_namespace, page, size, mock_request_id) + + # Verify status calculations + run_0 = next(r for r in result.runs if r.run_id == "run_0") + assert run_0.status == RunStatusEnum.SUCCESS # All states successful + + run_1 = next(r for r in result.runs if r.run_id == "run_1") + assert run_1.status == RunStatusEnum.PENDING # Has pending states + + run_2 = next(r for r in result.runs if r.run_id == "run_2") + assert run_2.status == RunStatusEnum.FAILED # Has errored states @pytest.mark.asyncio async def test_get_runs_exception_handling(self, mock_namespace, mock_request_id): @@ -352,10 +372,9 @@ async def test_get_runs_different_namespaces(self, mock_request_id): for namespace in namespaces: with patch('app.controller.get_runs.Run') as mock_run_class, \ - patch('app.controller.get_runs.get_run_info') as _, \ patch('app.controller.get_runs.logger') as _: - # Mock the entire query chain for runs list + # Mock the Run query chain to return empty list mock_query_chain = MagicMock() mock_query_chain.to_list = AsyncMock(return_value=[]) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain @@ -385,19 +404,31 @@ async def test_get_runs_large_page_size(self, mock_namespace, mock_request_id): run = MagicMock(spec=Run) run.run_id = f"run_{i}" run.graph_name = f"graph_{i}" - run.created_at = datetime.now() + run.created_at = datetime(2024, 1, 15, 10, 30, 0) large_runs_list.append(run) + # Create corresponding aggregation data + large_aggregation_data = [] + for i in range(1000): + large_aggregation_data.append({ + "_id": f"run_{i}", + "total_count": 5, + "success_count": 3, + "pending_count": 1, + "errored_count": 0, + "retried_count": 1 + }) + with patch('app.controller.get_runs.Run') as mock_run_class, \ - patch('app.controller.get_runs.get_run_info') as mock_get_run_info, \ + patch('app.controller.get_runs.State') as mock_state_class, \ patch('app.controller.get_runs.logger') as _: - # Mock the entire query chain for runs list + # Mock the Run query chain mock_query_chain = MagicMock() mock_query_chain.to_list = AsyncMock(return_value=large_runs_list) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain - # Mock the count query separately + # Mock the count query mock_count_query = MagicMock() mock_count_query.count = AsyncMock(return_value=1000) mock_run_class.find.side_effect = [ @@ -405,7 +436,14 @@ async def test_get_runs_large_page_size(self, mock_namespace, mock_request_id): mock_count_query # Second call for count ] - mock_get_run_info.side_effect = [MagicMock(spec=RunListItem) for _ in large_runs_list] + # Mock the State aggregation pipeline with cursor approach + mock_collection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=large_aggregation_data) + # Mock aggregate to return an awaitable cursor since source code awaits the entire expression + mock_collection.aggregate = AsyncMock(return_value=mock_cursor) + # Mock get_pymongo_collection to return a mock collection + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) result = await get_runs(mock_namespace, page, size, mock_request_id) @@ -419,10 +457,9 @@ async def test_get_runs_edge_case_page_zero(self, mock_namespace, mock_request_i size = 10 with patch('app.controller.get_runs.Run') as mock_run_class, \ - patch('app.controller.get_runs.get_run_info') as _, \ patch('app.controller.get_runs.logger') as _: - # Mock the entire query chain for runs list + # Mock the Run query chain to return empty list mock_query_chain = MagicMock() mock_query_chain.to_list = AsyncMock(return_value=[]) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain @@ -447,10 +484,9 @@ async def test_get_runs_edge_case_size_zero(self, mock_namespace, mock_request_i size = 0 with patch('app.controller.get_runs.Run') as mock_run_class, \ - patch('app.controller.get_runs.get_run_info') as _, \ patch('app.controller.get_runs.logger') as _: - # Mock the entire query chain for runs list + # Mock the Run query chain to return empty list mock_query_chain = MagicMock() mock_query_chain.to_list = AsyncMock(return_value=[]) mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain @@ -466,4 +502,71 @@ async def test_get_runs_edge_case_size_zero(self, mock_namespace, mock_request_i result = await get_runs(mock_namespace, page, size, mock_request_id) assert result.page == page - assert result.size == size \ No newline at end of file + assert result.size == size + + @pytest.mark.asyncio + async def test_get_runs_aggregation_pipeline_structure(self, mock_namespace, mock_request_id, mock_runs): + """Test that the MongoDB aggregation pipeline is structured correctly""" + page = 1 + size = 10 + + with patch('app.controller.get_runs.Run') as mock_run_class, \ + patch('app.controller.get_runs.State') as mock_state_class, \ + patch('app.controller.get_runs.logger') as _: + + # Mock the Run query chain + mock_query_chain = MagicMock() + mock_query_chain.to_list = AsyncMock(return_value=mock_runs) + mock_run_class.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_query_chain + + # Mock the count query + mock_count_query = MagicMock() + mock_count_query.count = AsyncMock(return_value=15) + mock_run_class.find.side_effect = [ + mock_run_class.find.return_value, # First call for runs list + mock_count_query # Second call for count + ] + + # Mock the State aggregation pipeline with cursor approach + mock_collection = MagicMock() + mock_cursor = MagicMock() + mock_cursor.to_list = AsyncMock(return_value=[]) + # Mock aggregate to return an awaitable cursor since source code awaits the entire expression + mock_collection.aggregate = AsyncMock(return_value=mock_cursor) + # Mock get_pymongo_collection to return a mock collection + mock_state_class.get_pymongo_collection = MagicMock(return_value=mock_collection) + + await get_runs(mock_namespace, page, size, mock_request_id) + + # Verify aggregation pipeline structure + mock_collection.aggregate.assert_called_once() + pipeline = mock_collection.aggregate.call_args[0][0] + + # Check $match stage + assert pipeline[0]["$match"]["run_id"]["$in"] == ["run_0", "run_1", "run_2"] + + # Check $group stage + group_stage = pipeline[1]["$group"] + assert group_stage["_id"] == "$run_id" + assert "total_count" in group_stage + assert "success_count" in group_stage + assert "pending_count" in group_stage + assert "errored_count" in group_stage + assert "retried_count" in group_stage + + # Check that the aggregation conditions use the correct StateStatusEnum values + success_condition = group_stage["success_count"]["$sum"]["$cond"]["if"]["$in"][1] + assert StateStatusEnum.SUCCESS in success_condition + assert StateStatusEnum.PRUNED in success_condition + + pending_condition = group_stage["pending_count"]["$sum"]["$cond"]["if"]["$in"][1] + assert StateStatusEnum.CREATED in pending_condition + assert StateStatusEnum.QUEUED in pending_condition + assert StateStatusEnum.EXECUTED in pending_condition + + errored_condition = group_stage["errored_count"]["$sum"]["$cond"]["if"]["$in"][1] + assert StateStatusEnum.ERRORED in errored_condition + assert StateStatusEnum.NEXT_CREATED_ERROR in errored_condition + + retried_condition = group_stage["retried_count"]["$sum"]["$cond"]["if"]["$eq"][1] + assert retried_condition == StateStatusEnum.RETRY_CREATED \ No newline at end of file