diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index b83d731a54e35..c43c931f3e28a 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -345,3 +345,9 @@ class TaskStatesResponse(BaseModel):
     """Response for task states with run_id, task and state."""
 
     task_states: dict[str, Any]
+
+
+class InactiveAssetsResponse(BaseModel):
+    """Response for inactive assets."""
+
+    inactive_assets: Annotated[list[AssetProfile], Field(default_factory=list)]
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index ac1d1602460de..15cdb0a40cad6 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -17,12 +17,15 @@
 
 from __future__ import annotations
 
+import contextlib
+import itertools
 import json
 from collections import defaultdict
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, Annotated, Any
 from uuid import UUID
 
+import attrs
 import structlog
 from cadwyn import VersionedAPIRouter
 from fastapi import Body, HTTPException, Query, status
@@ -37,6 +40,7 @@
 from airflow.api_fastapi.common.db.common import SessionDep
 from airflow.api_fastapi.common.types import UtcDateTime
 from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    InactiveAssetsResponse,
     PrevSuccessfulDagRunResponse,
     TaskStatesResponse,
     TIDeferredStatePayload,
@@ -51,6 +55,8 @@
     TITerminalStatePayload,
 )
 from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException, TaskNotFound
+from airflow.models.asset import AssetActive
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
@@ -58,6 +64,7 @@
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
 from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
+from airflow.sdk.definitions.asset import Asset, AssetUniqueKey
 from airflow.sdk.definitions.taskgroup import MappedTaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -400,12 +407,16 @@ def ti_update_state(
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
         updated_state = ti_patch_payload.state
         task_instance = session.get(TI, ti_id_str)
-        TI.register_asset_changes_in_db(
-            task_instance,
-            ti_patch_payload.task_outlets,  # type: ignore
-            ti_patch_payload.outlet_events,
-            session,
-        )
+        try:
+            TI.register_asset_changes_in_db(
+                task_instance,
+                ti_patch_payload.task_outlets,  # type: ignore
+                ti_patch_payload.outlet_events,
+                session,
+            )
+        except AirflowInactiveAssetInInletOrOutletException as err:
+            log.error("Asset registration failed due to conflicting asset: %s", err)
+
         query = query.values(state=updated_state)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
@@ -840,5 +851,79 @@ def _get_group_tasks(dag_id: str, task_group_id: str, session: SessionDep, logic
     return group_tasks
 
 
+@ti_id_router.get(
+    "/{task_instance_id}/validate-inlets-and-outlets",
+    status_code=status.HTTP_200_OK,
+    responses={
+        status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
+    },
+)
+def validate_inlets_and_outlets(
+    task_instance_id: UUID,
+    session: SessionDep,
+    dag_bag: DagBagDep,
+) -> InactiveAssetsResponse:
+    """Validate whether there're inactive assets in inlets and outlets of a given task instance."""
+    ti_id_str = str(task_instance_id)
+    bind_contextvars(ti_id=ti_id_str)
+
+    ti = session.scalar(select(TI).where(TI.id == ti_id_str))
+    if not ti or not ti.logical_date:
+        log.error("Task Instance not found")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": "Task Instance not found",
+            },
+        )
+
+    if not ti.task:
+        dag = dag_bag.get_dag(ti.dag_id)
+        if dag:
+            with contextlib.suppress(TaskNotFound):
+                ti.task = dag.get_task(ti.task_id)
+
+    # The DAG may be missing from the DagBag, or may no longer define this task; answer with
+    # a 404 instead of failing with AttributeError on ``ti.task.inlets`` below.
+    if ti.task is None:
+        log.error("Task not found")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={
+                "reason": "not_found",
+                "message": "Task not found",
+            },
+        )
+
+    inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
+    outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
+    if not (inlets or outlets):
+        return InactiveAssetsResponse(inactive_assets=[])
+
+    all_asset_unique_keys: set[AssetUniqueKey] = {
+        AssetUniqueKey.from_asset(inlet_or_outlet)  # type: ignore
+        for inlet_or_outlet in itertools.chain(inlets, outlets)
+    }
+    active_asset_unique_keys = {
+        AssetUniqueKey(name, uri)
+        for name, uri in session.execute(
+            select(AssetActive.name, AssetActive.uri).where(
+                tuple_(AssetActive.name, AssetActive.uri).in_(
+                    [attrs.astuple(key) for key in all_asset_unique_keys]
+                )
+            )
+        )
+    }
+    different = all_asset_unique_keys - active_asset_unique_keys
+
+    return InactiveAssetsResponse(
+        inactive_assets=[
+            asset_unique_key.to_asset().asprofile()  # type: ignore
+            for asset_unique_key in different
+        ]
+    )
+
+
 # This line should be at the end of the file to ensure all routes are registered
 router.include_router(ti_id_router)
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 5f17ff54709c7..cd5ec56a83993 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1894,6 +1894,7 @@ def schedule_tis(
                 and not ti.task.on_execute_callback
                 and not ti.task.on_success_callback
                 and not ti.task.outlets
+                and not ti.task.inlets
             ):
                 empty_ti_ids.append(ti.id)
             # check "start_trigger_args" to see whether the operator supports start execution from triggerer
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 21c733a5cced2..49a8717c36aa2 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,7 +33,7 @@
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import TaskGroup, task, task_group
+from airflow.sdk import Asset, TaskGroup, task, task_group
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
@@ -2139,3 +2139,54 @@ def add_one(x):
         response = client.get("/execution/task-instances/states", params={"dag_id": dr.dag_id, **params})
         assert response.status_code == 200
         assert response.json() == {"task_states": {dr.run_id: expected}}
+
+
+class TestInactiveInletsAndOutlets:
+    def test_ti_inactive_inlets_and_outlets(self, client, dag_maker):
+        """Test the inactive assets in inlets and outlets can be found."""
+        with dag_maker("test_inlets_and_outlets"):
+            EmptyOperator(
+                task_id="task1",
+                inlets=[Asset(name="inlet-name"), Asset(name="inlet-name", uri="but-different-uri")],
+                outlets=[
+                    Asset(name="outlet-name", uri="uri"),
+                    Asset(name="outlet-name", uri="second-different-uri"),
+                ],
+            )
+
+        dr = dag_maker.create_dagrun()
+
+        task1_ti = dr.get_task_instance("task1")
+        response = client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
+        assert response.status_code == 200
+        inactive_assets = response.json()["inactive_assets"]
+        expected_inactive_assets = (
+            {
+                "name": "inlet-name",
+                "type": "Asset",
+                "uri": "but-different-uri",
+            },
+            {
+                "name": "outlet-name",
+                "type": "Asset",
+                "uri": "second-different-uri",
+            },
+        )
+        for asset in expected_inactive_assets:
+            assert asset in inactive_assets
+
+    def test_ti_inactive_inlets_and_outlets_without_inactive_assets(self, client, dag_maker):
+        """Test the task without inactive assets in its inlets or outlets returns empty list."""
+        with dag_maker("test_inlets_and_outlets_inactive"):
+            EmptyOperator(
+                task_id="inactive_task1",
+                inlets=[Asset(name="inlet-name")],
+                outlets=[Asset(name="outlet-name", uri="uri")],
+            )
+
+        dr = dag_maker.create_dagrun()
+
+        task1_ti = dr.get_task_instance("inactive_task1")
+        response = client.get(f"/execution/task-instances/{task1_ti.id}/validate-inlets-and-outlets")
+        assert response.status_code == 200
+        assert response.json() == {"inactive_assets": []}
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index b9d0a4511ea10..8764c59976172 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -40,6 +40,7 @@
     ConnectionResponse,
     DagRunStateResponse,
     DagRunType,
+    InactiveAssetsResponse,
     PrevSuccessfulDagRunResponse,
     TaskInstanceState,
     TaskStatesResponse,
@@ -273,6 +274,11 @@ def get_task_states(
         resp = self.client.get("task-instances/states", params=params)
         return TaskStatesResponse.model_validate_json(resp.read())
 
+    def validate_inlets_and_outlets(self, id: uuid.UUID) -> InactiveAssetsResponse:
+        """Validate whether there're inactive assets in inlets and outlets of a given task instance."""
+        resp = self.client.get(f"task-instances/{id}/validate-inlets-and-outlets")
+        return InactiveAssetsResponse.model_validate_json(resp.read())
+
 
 class ConnectionOperations:
     __slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index f6b1c907ef529..ac1e51d5e55c9 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -154,6 +154,14 @@ class DagRunType(str, Enum):
     ASSET_TRIGGERED = "asset_triggered"
 
 
+class InactiveAssetsResponse(BaseModel):
+    """
+    Response for inactive assets.
+    """
+
+    inactive_assets: Annotated[list[AssetProfile] | None, Field(title="Inactive Assets")] = None
+
+
 class IntermediateTIState(str, Enum):
     """
     States that a Task Instance can be in that indicate it is not yet in a terminal or running state.
diff --git a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
index c81732cf40414..9cb913807ee91 100644
--- a/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
+++ b/task-sdk/src/airflow/sdk/definitions/asset/__init__.py
@@ -86,6 +86,19 @@ def from_str(key: str) -> AssetUniqueKey:
     def to_str(self) -> str:
         return json.dumps(attrs.asdict(self))
 
+    @staticmethod
+    def from_profile(profile: AssetProfile) -> AssetUniqueKey:
+        """Build a key from a profile; when only one of name/uri is set it is used for both fields."""
+        if profile.name and profile.uri:
+            return AssetUniqueKey(name=profile.name, uri=profile.uri)
+
+        if name := profile.name:
+            return AssetUniqueKey(name=name, uri=name)
+        if uri := profile.uri:
+            return AssetUniqueKey(name=uri, uri=uri)
+
+        raise ValueError("name and uri cannot both be empty")
+
 
 @attrs.define(frozen=True)
 class AssetAliasUniqueKey:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index ecc34852252e0..d0622cf6ffc86 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -61,6 +61,7 @@
     BundleInfo,
     ConnectionResponse,
     DagRunStateResponse,
+    InactiveAssetsResponse,
     PrevSuccessfulDagRunResponse,
     TaskInstance,
     TaskInstanceState,
@@ -208,6 +209,24 @@ def source_task_instance(self) -> AssetEventSourceTaskInstance | None:
         )
 
 
+class InactiveAssetsResult(InactiveAssetsResponse):
+    """Response of InactiveAssets requests."""
+
+    type: Literal["InactiveAssetsResult"] = "InactiveAssetsResult"
+
+    @classmethod
+    def from_inactive_assets_response(
+        cls, inactive_assets_response: InactiveAssetsResponse
+    ) -> InactiveAssetsResult:
+        """
+        Get InactiveAssetsResult from InactiveAssetsResponse.
+
+        InactiveAssetsResponse is autogenerated from the API schema, so we need to convert it to InactiveAssetsResult
+        for communication between the Supervisor and the task process.
+        """
+        return cls(**inactive_assets_response.model_dump(exclude_defaults=True), type="InactiveAssetsResult")
+
+
 class XComResult(XComResponse):
     """Response to ReadXCom request."""
 
@@ -376,6 +395,7 @@ class OKResponse(BaseModel):
         XComResult,
         XComSequenceIndexResult,
         XComSequenceSliceResult,
+        InactiveAssetsResult,
         OKResponse,
     ],
     Field(discriminator="type"),
@@ -590,6 +610,11 @@ class GetAssetEventByAssetAlias(BaseModel):
     type: Literal["GetAssetEventByAssetAlias"] = "GetAssetEventByAssetAlias"
 
 
+class ValidateInletsAndOutlets(BaseModel):
+    ti_id: UUID
+    type: Literal["ValidateInletsAndOutlets"] = "ValidateInletsAndOutlets"
+
+
 class GetPrevSuccessfulDagRun(BaseModel):
     ti_id: UUID
     type: Literal["GetPrevSuccessfulDagRun"] = "GetPrevSuccessfulDagRun"
@@ -657,6 +682,7 @@ class GetDRCount(BaseModel):
         SetXCom,
         SkipDownstreamTasks,
         SucceedTask,
+        ValidateInletsAndOutlets,
         TaskState,
         TriggerDagRun,
         DeleteVariable,
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 65d05cc023d51..7e6b89043d221 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -88,6 +88,7 @@
     GetXComCount,
     GetXComSequenceItem,
     GetXComSequenceSlice,
+    InactiveAssetsResult,
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
@@ -101,6 +102,7 @@
     TaskStatesResult,
     ToSupervisor,
     TriggerDagRun,
+    ValidateInletsAndOutlets,
     VariableResult,
     XComCountResponse,
     XComResult,
@@ -1215,6 +1217,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
             )
         elif isinstance(msg, DeleteVariable):
             resp = self.client.variables.delete(msg.key)
+        elif isinstance(msg, ValidateInletsAndOutlets):
+            inactive_assets_resp = self.client.task_instances.validate_inlets_and_outlets(msg.ti_id)
+            resp = InactiveAssetsResult.from_inactive_assets_response(inactive_assets_resp)
+            dump_opts = {"exclude_unset": True}
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index cfe354f784e07..bbc7394a87a0d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -41,6 +41,7 @@
 
 from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock
 from airflow.dag_processing.bundles.manager import DagBundlesManager
+from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
 from airflow.listeners.listener import get_listener_manager
 from airflow.sdk.api.datamodels._generated import (
     AssetProfile,
@@ -67,6 +68,7 @@
     GetTaskRescheduleStartDate,
     GetTaskStates,
     GetTICount,
+    InactiveAssetsResult,
     RescheduleTask,
     RetryTask,
     SetRenderedFields,
@@ -80,6 +82,7 @@
     ToSupervisor,
     ToTask,
     TriggerDagRun,
+    ValidateInletsAndOutlets,
 )
 from airflow.sdk.execution_time.context import (
     ConnectionAccessor,
@@ -784,6 +787,8 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
         # so that we do not call the API unnecessarily
         SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields))
 
+    _validate_task_inlets_and_outlets(ti=ti, log=log)
+
     try:
         # TODO: Call pre execute etc.
         get_listener_manager().hook.on_task_instance_running(
@@ -796,6 +801,23 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv
     return None
 
 
+def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) -> None:
+    """Ask the supervisor for inactive inlet/outlet assets of ``ti`` and raise if any are reported."""
+    if not ti.task.inlets and not ti.task.outlets:
+        return
+
+    SUPERVISOR_COMMS.send_request(msg=ValidateInletsAndOutlets(ti_id=ti.id), log=log)
+    inactive_assets_resp = SUPERVISOR_COMMS.get_message()
+    if TYPE_CHECKING:
+        assert isinstance(inactive_assets_resp, InactiveAssetsResult)
+    if inactive_assets := inactive_assets_resp.inactive_assets:
+        raise AirflowInactiveAssetInInletOrOutletException(
+            inactive_asset_keys=[
+                AssetUniqueKey.from_profile(asset_profile) for asset_profile in inactive_assets
+            ]
+        )
+
+
 def _defer_task(
     defer: TaskDeferred, ti: RuntimeTaskInstance, log: Logger
 ) -> tuple[ToSupervisor, TaskInstanceState]:
diff --git a/task-sdk/tests/task_sdk/definitions/test_asset.py b/task-sdk/tests/task_sdk/definitions/test_asset.py
index 8328b061811ce..2a25c0907c7dc 100644
--- a/task-sdk/tests/task_sdk/definitions/test_asset.py
+++ b/task-sdk/tests/task_sdk/definitions/test_asset.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import json
 import os
 from typing import Callable
 from unittest import mock
@@ -24,6 +25,7 @@
 import pytest
 
 from airflow.providers.standard.operators.empty import EmptyOperator
+from airflow.sdk.api.datamodels._generated import AssetProfile
 from airflow.sdk.definitions.asset import (
     Asset,
     AssetAlias,
@@ -384,6 +386,39 @@ def test_normalize_uri_valid_uri(mock_get_normalized_scheme):
     assert asset.normalized_uri == "valid_aip60_uri"
 
 
+class TestAssetUniqueKey:
+    def test_from_asset(self):
+        asset = Asset(name="test", uri="test://test/")
+
+        assert AssetUniqueKey.from_asset(asset) == AssetUniqueKey(name="test", uri="test://test/")
+
+    def test_to_asset(self):
+        assert AssetUniqueKey(name="test", uri="test://test/").to_asset() == Asset(
+            name="test", uri="test://test/"
+        )
+
+    def test_from_str(self):
+        json_str = json.dumps({"name": "test", "uri": "test://test/"})
+        assert AssetUniqueKey.from_str(json_str) == AssetUniqueKey(name="test", uri="test://test/")
+
+    def test_to_str(self):
+        assert AssetUniqueKey(name="test", uri="test://test/").to_str() == json.dumps(
+            {"name": "test", "uri": "test://test/"}
+        )
+
+    @pytest.mark.parametrize(
+        "name, uri, expected_asset_unique_key",
+        [
+            ("test", None, AssetUniqueKey(name="test", uri="test")),
+            (None, "test://test/", AssetUniqueKey(name="test://test/", uri="test://test/")),
+            ("test", "test://test/", AssetUniqueKey(name="test", uri="test://test/")),
+        ],
+    )
+    def test_from_profile(self, name, uri, expected_asset_unique_key):
+        profile = AssetProfile(name=name, uri=uri, type="Asset")
+        assert AssetUniqueKey.from_profile(profile) == expected_asset_unique_key
+
+
 class TestAssetAlias:
     def test_as_expression(self):
         alias = AssetAlias(name="test_name", group="test")
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 86a2e747e0f0f..4f5e4cfc7ac9d 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -45,6 +45,7 @@
 from airflow.sdk.api.client import ServerResponseError
 from airflow.sdk.api.datamodels._generated import (
     AssetEventResponse,
+    AssetProfile,
     AssetResponse,
     DagRunState,
     TaskInstance,
@@ -77,6 +78,7 @@
     GetXCom,
     GetXComSequenceItem,
     GetXComSequenceSlice,
+    InactiveAssetsResult,
     OKResponse,
     PrevSuccessfulDagRunResult,
     PutVariable,
@@ -90,6 +92,7 @@
     TaskStatesResult,
     TICount,
     TriggerDagRun,
+    ValidateInletsAndOutlets,
     VariableResult,
     XComResult,
     XComSequenceIndexResult,
@@ -1480,6 +1483,18 @@ def watched_subprocess(self, mocker):
                 None,
                 id="get_asset_events_by_asset_alias",
             ),
+            pytest.param(
+                ValidateInletsAndOutlets(ti_id=TI_ID),
+                b'{"inactive_assets":[{"name":"asset_name","uri":"asset_uri","type":"asset"}],"type":"InactiveAssetsResult"}\n',
+                "task_instances.validate_inlets_and_outlets",
+                (TI_ID,),
+                {},
+                InactiveAssetsResult(
+                    inactive_assets=[AssetProfile(name="asset_name", uri="asset_uri", type="asset")]
+                ),
+                None,
+                id="validate_inlets_and_outlets",
+            ),
             pytest.param(
                 SucceedTask(
                     end_date=timezone.parse("2024-10-31T12:00:00Z"), rendered_map_index="test success task"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index df4be21960fbf..a994e995bace1 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -946,7 +946,12 @@ def test_run_with_asset_outlets(
     instant = timezone.datetime(2024, 12, 3, 10, 0)
     time_machine.move_to(instant, tick=False)
 
-    run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+    with mock.patch(
+        "airflow.sdk.execution_time.task_runner._validate_task_inlets_and_outlets"
+    ) as validate_mock:
+        run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+    validate_mock.assert_called_once()
 
     mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)
 