Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions airflow-core/src/airflow/api_fastapi/common/db/dag_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@
from __future__ import annotations

from sqlalchemy import func, select
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy.orm.interfaces import LoaderOption

from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory

dagruns_select_with_state_count = (
select(
Expand All @@ -33,3 +38,17 @@
.group_by(DagRun.dag_id, DagRun.state, DagModel.dag_display_name)
.order_by(DagRun.dag_id)
)


def eager_load_dag_run_for_validation() -> tuple[LoaderOption, ...]:
"""Construct the eager loading options necessary for a DagRunResponse object."""
return (
joinedload(DagRun.dag_model),
selectinload(DagRun.task_instances)
.joinedload(TaskInstance.dag_version)
.joinedload(DagVersion.bundle),
selectinload(DagRun.task_instances_histories)
.joinedload(TaskInstanceHistory.dag_version)
.joinedload(DagVersion.bundle),
joinedload(DagRun.dag_run_note),
)
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
from datetime import datetime
from uuid import UUID

from pydantic import AliasPath, Field, computed_field
from sqlalchemy import select
from pydantic import AliasPath, Field

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.dag_processing.bundles.manager import DagBundlesManager


class DagVersionResponse(BaseModel):
Expand All @@ -37,29 +35,7 @@ class DagVersionResponse(BaseModel):
created_at: datetime
dag_display_name: str = Field(validation_alias=AliasPath("dag_model", "dag_display_name"))

# Mypy issue https://github.com/python/mypy/issues/1362
@computed_field # type: ignore[prop-decorator]
@property
def bundle_url(self) -> str | None:
if self.bundle_name:
# Get the bundle model from the database and render the URL
from airflow.models.dagbundle import DagBundleModel
from airflow.utils.session import create_session

with create_session() as session:
bundle_model = session.scalar(
select(DagBundleModel).where(DagBundleModel.name == self.bundle_name)
)

if bundle_model and hasattr(bundle_model, "signed_url_template"):
return bundle_model.render_url(self.bundle_version)
# fallback to the deprecated option if the bundle model does not have a signed_url_template
# attribute
try:
return DagBundlesManager().view_url(self.bundle_name, self.bundle_version)
except ValueError:
return None
return None
bundle_url: str | None = Field(validation_alias="bundle_url")


class DAGVersionCollectionResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def concurrency(self) -> int:
@property
def latest_dag_version(self) -> DagVersionResponse | None:
"""Return the latest DagVersion."""
latest_dag_version = DagVersion.get_latest_version(self.dag_id, load_dag_model=True)
latest_dag_version = DagVersion.get_latest_version(
self.dag_id, load_dag_model=True, load_bundle_model=True
)
if latest_dag_version is None:
return latest_dag_version
return DagVersionResponse.model_validate(latest_dag_version)
Original file line number Diff line number Diff line change
Expand Up @@ -1676,7 +1676,6 @@ components:
- type: string
- type: 'null'
title: Bundle Url
readOnly: true
type: object
required:
- id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10717,7 +10717,6 @@ components:
- type: string
- type: 'null'
title: Bundle Url
readOnly: true
type: object
required:
- id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
from airflow.api_fastapi.common.dagbag import DagBagDep, get_dag_for_run, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.db.dag_runs import eager_load_dag_run_for_validation
from airflow.api_fastapi.common.parameters import (
FilterOptionEnum,
FilterParam,
Expand Down Expand Up @@ -81,6 +82,7 @@
from airflow.exceptions import ParamValidationError
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagModel, DagRun
from airflow.models.asset import AssetEvent
from airflow.models.dag_version import DagVersion
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
Expand Down Expand Up @@ -232,10 +234,12 @@ def get_upstream_asset_events(
) -> AssetEventCollectionResponse:
"""If dag run is asset-triggered, return the asset events that triggered it."""
dag_run: DagRun | None = session.scalar(
select(DagRun).where(
select(DagRun)
.where(
DagRun.dag_id == dag_id,
DagRun.run_id == dag_run_id,
)
.options(joinedload(DagRun.consumed_asset_events).joinedload(AssetEvent.asset))
)
if dag_run is None:
raise HTTPException(
Expand Down Expand Up @@ -357,11 +361,11 @@ def get_dag_runs(

This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs.
"""
query = select(DagRun)
query = select(DagRun).options(*eager_load_dag_run_for_validation())

if dag_id != "~":
get_latest_version_of_dag(dag_bag, dag_id, session) # Check if the DAG exists.
query = query.filter(DagRun.dag_id == dag_id).options(joinedload(DagRun.dag_model))
query = query.filter(DagRun.dag_id == dag_id).options()

# Add join with DagVersion if dag_version filter is active
if dag_version.value:
Expand Down Expand Up @@ -585,7 +589,8 @@ def get_list_dag_runs_batch(
{"dag_run_id": "run_id"},
).set_value([body.order_by] if body.order_by else None)

base_query = select(DagRun).options(joinedload(DagRun.dag_model))
base_query = select(DagRun).options(*eager_load_dag_run_for_validation())

dag_runs_select, total_entries = paginated_select(
statement=base_query,
filters=[dag_ids, logical_date, run_after, start_date, end_date, state, readable_dag_runs_filter],
Expand Down
39 changes: 37 additions & 2 deletions airflow-core/src/airflow/models/dag_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from sqlalchemy_utils import UUIDType

from airflow._shared.timezones import timezone
from airflow.dag_processing.bundles.manager import DagBundlesManager
from airflow.models.base import Base, StringID
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
Expand All @@ -47,6 +48,12 @@ class DagVersion(Base):
dag_model = relationship("DagModel", back_populates="dag_versions")
bundle_name = Column(StringID(), nullable=True)
bundle_version = Column(StringID())
bundle = relationship(
"DagBundleModel",
primaryjoin="foreign(DagVersion.bundle_name) == DagBundleModel.name",
uselist=False,
viewonly=True,
)
dag_code = relationship(
"DagCode",
back_populates="dag_version",
Expand All @@ -73,6 +80,20 @@ def __repr__(self):
"""Represent the object as a string."""
return f"<DagVersion {self.dag_id} {self.version}>"

@property
def bundle_url(self) -> str | None:
"""Render the bundle URL using the joined bundle metadata if available."""
# Prefer using the joined bundle relationship when present to avoid extra queries
if getattr(self, "bundle", None) is not None and hasattr(self.bundle, "signed_url_template"):
return self.bundle.render_url(self.bundle_version)

# fallback to the deprecated option if the bundle model does not have a signed_url_template
# attribute
try:
return DagBundlesManager().view_url(self.bundle_name, self.bundle_version)
except ValueError:
return None

@classmethod
@provide_session
def write_dag(
Expand Down Expand Up @@ -114,7 +135,11 @@ def write_dag(

@classmethod
def _latest_version_select(
cls, dag_id: str, bundle_version: str | None = None, load_dag_model: bool = False
cls,
dag_id: str,
bundle_version: str | None = None,
load_dag_model: bool = False,
load_bundle_model: bool = False,
) -> Select:
"""
Get the select object to get the latest version of the DAG.
Expand All @@ -129,6 +154,9 @@ def _latest_version_select(
if load_dag_model:
query = query.options(joinedload(cls.dag_model))

if load_bundle_model:
query = query.options(joinedload(cls.bundle))

query = query.order_by(cls.created_at.desc()).limit(1)
return query

Expand All @@ -140,6 +168,7 @@ def get_latest_version(
*,
bundle_version: str | None = None,
load_dag_model: bool = False,
load_bundle_model: bool = False,
session: Session = NEW_SESSION,
) -> DagVersion | None:
"""
Expand All @@ -148,10 +177,16 @@ def get_latest_version(
:param dag_id: The DAG ID.
:param session: The database session.
:param load_dag_model: Whether to load the DAG model.
:param load_bundle_model: Whether to load the DagBundle model.
:return: The latest version of the DAG or None if not found.
"""
return session.scalar(
cls._latest_version_select(dag_id, bundle_version=bundle_version, load_dag_model=load_dag_model)
cls._latest_version_select(
dag_id,
bundle_version=bundle_version,
load_dag_model=load_dag_model,
load_bundle_model=load_bundle_model,
)
)

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3185,8 +3185,7 @@ export const $DagVersionResponse = {
type: 'null'
}
],
title: 'Bundle Url',
readOnly: true
title: 'Bundle Url'
}
},
type: 'object',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ export type DagVersionResponse = {
bundle_version: string | null;
created_at: string;
dag_display_name: string;
readonly bundle_url: string | null;
bundle_url: string | null;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.utils.types import DagRunTriggeredByType, DagRunType

from tests_common.test_utils.api_fastapi import _check_dag_run_note, _check_last_log
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import (
clear_db_connections,
clear_db_dag_bundles,
Expand Down Expand Up @@ -307,7 +308,10 @@ def test_should_respond_403(self, unauthorized_test_client):


class TestGetDagRuns:
@pytest.mark.parametrize("dag_id, total_entries", [(DAG1_ID, 2), (DAG2_ID, 2), ("~", 4)])
@pytest.mark.parametrize(
"dag_id, total_entries",
[(DAG1_ID, 2), (DAG2_ID, 2), ("~", 4)],
)
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_get_dag_runs(self, test_client, session, dag_id, total_entries):
response = test_client.get(f"/dags/{dag_id}/dagRuns")
Expand Down Expand Up @@ -364,7 +368,10 @@ def test_should_respond_403(self, unauthorized_test_client):
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_return_correct_results_with_order_by(self, test_client, order_by, expected_order):
# Test ascending order
response = test_client.get("/dags/test_dag1/dagRuns", params={"order_by": order_by})

with assert_queries_count(7):
response = test_client.get("/dags/test_dag1/dagRuns", params={"order_by": order_by})

assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 2
Expand Down Expand Up @@ -750,7 +757,8 @@ def test_invalid_dag_version(self, test_client):
class TestListDagRunsBatch:
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_list_dag_runs_return_200(self, test_client, session):
response = test_client.post("/dags/~/dagRuns/list", json={})
with assert_queries_count(5):
response = test_client.post("/dags/~/dagRuns/list", json={})
assert response.status_code == 200
body = response.json()
assert body["total_entries"] == 4
Expand Down Expand Up @@ -791,7 +799,8 @@ def test_list_dag_runs_with_invalid_dag_id(self, test_client):
)
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_list_dag_runs_with_dag_ids_filter(self, test_client, dag_ids, status_code, expected_dag_id_list):
response = test_client.post("/dags/~/dagRuns/list", json={"dag_ids": dag_ids})
with assert_queries_count(5):
response = test_client.post("/dags/~/dagRuns/list", json={"dag_ids": dag_ids})
assert response.status_code == status_code
assert set([each["dag_run_id"] for each in response.json()["dag_runs"]]) == set(expected_dag_id_list)

Expand Down Expand Up @@ -1293,9 +1302,10 @@ def test_should_respond_200(self, test_client, dag_maker, session):
session.commit()
assert event.timestamp

response = test_client.get(
"/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamAssetEvents",
)
with assert_queries_count(3):
response = test_client.get(
"/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamAssetEvents",
)
assert response.status_code == 200
expected_response = {
"asset_events": [
Expand Down Expand Up @@ -1504,6 +1514,7 @@ def test_should_respond_200(
run = (
session.query(DagRun).where(DagRun.dag_id == DAG1_ID, DagRun.run_id == expected_dag_run_id).one()
)

expected_response_json = {
"bundle_version": None,
"conf": {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ class TestGetDagVersion(TestDagVersionEndpoint):
],
)
@pytest.mark.usefixtures("make_dag_with_multiple_versions")
@mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr")
def test_get_dag_version(self, mock_hasattr, test_client, dag_id, dag_version, expected_response):
mock_hasattr.return_value = False
def test_get_dag_version(self, test_client, dag_id, dag_version, expected_response):
response = test_client.get(f"/dags/{dag_id}/dagVersions/{dag_version}")
assert response.status_code == 200
assert response.json() == expected_response
Expand Down Expand Up @@ -180,7 +178,7 @@ def test_get_dag_version_with_url_template(self, test_client, dag_id, dag_versio

@pytest.mark.usefixtures("make_dag_with_multiple_versions")
@mock.patch("airflow.dag_processing.bundles.manager.DagBundlesManager.view_url")
@mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr")
@mock.patch("airflow.models.dag_version.hasattr")
def test_get_dag_version_with_unconfigured_bundle(
self, mock_hasattr, mock_view_url, test_client, dag_maker, session
):
Expand Down Expand Up @@ -305,9 +303,7 @@ class TestGetDagVersions(TestDagVersionEndpoint):
],
)
@pytest.mark.usefixtures("make_dag_with_multiple_versions")
@mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.hasattr")
def test_get_dag_versions(self, mock_hasattr, test_client, dag_id, expected_response):
mock_hasattr.return_value = False
def test_get_dag_versions(self, test_client, dag_id, expected_response):
response = test_client.get(f"/dags/{dag_id}/dagVersions")
assert response.status_code == 200
assert response.json() == expected_response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def test_dag_details(
"is_paused_upon_creation": None,
"latest_dag_version": {
"bundle_name": "dag_maker",
"bundle_url": None,
"bundle_url": "http://test_host.github.com/tree/None/dags",
"bundle_version": None,
"created_at": mock.ANY,
"dag_id": "test_dag2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ def test_should_respond_200(
update_extras=update_extras,
task_instances=task_instances,
)
with mock.patch("airflow.api_fastapi.core_api.datamodels.dag_versions.DagBundlesManager"):
with mock.patch("airflow.models.dag_version.DagBundlesManager"):
# Mock DagBundlesManager to avoid checking if dags-folder bundle is configured
response = test_client.get(url, params=params)
if params == {"task_id_pattern": "task_match_id"}:
Expand Down