From 04bdf8c9965f963fd1f9b76e3017ccfc2cb91c2f Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 8 Jan 2026 12:10:20 -0600 Subject: [PATCH 01/19] added two token mechanism for task execution --- .../src/airflow/api_fastapi/auth/tokens.py | 36 ++++++++ .../airflow/api_fastapi/execution_api/app.py | 2 + .../airflow/api_fastapi/execution_api/deps.py | 82 ++++++++++++++++--- .../execution_api/routes/__init__.py | 3 +- .../execution_api/routes/task_instances.py | 18 ++-- .../src/airflow/config_templates/config.yml | 11 +++ .../src/airflow/executors/workloads.py | 8 +- .../unit/api_fastapi/auth/test_tokens.py | 38 +++++++++ .../api_fastapi/execution_api/conftest.py | 36 +++++++- .../versions/head/test_task_instances.py | 47 ++++------- task-sdk/src/airflow/sdk/api/client.py | 7 +- 11 files changed, 235 insertions(+), 53 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 276ae17153da0..2e40413a05e04 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -46,6 +46,7 @@ "JWKS", "JWTGenerator", "JWTValidator", + "TOKEN_SCOPE_QUEUE", "generate_private_key", "get_sig_validation_args", "get_signing_args", @@ -54,6 +55,8 @@ "key_to_jwk_dict", ] +TOKEN_SCOPE_QUEUE = "queue" + class InvalidClaimError(ValueError): """Raised when a claim in the JWT is invalid.""" @@ -458,6 +461,39 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] headers["kid"] = self.kid return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + def generate_queue_token(self, sub: str) -> str: + """ + Generate a long-lived queue token for task workloads. + + Queue tokens have a special 'scope' claim that restricts them to the /run endpoint only. + They are valid for longer (default 24h) to survive queue wait times. + """ + from airflow.configuration import conf + + queue_expiry = conf.getint("execution_api", "jwt_queue_token_expiration_time", fallback=86400) + now = int(datetime.now(tz=timezone.utc).timestamp()) + + claims = { + "jti": uuid.uuid4().hex, + "iss": self.issuer, + "aud": self.audience, + "nbf": now, + "exp": now + queue_expiry, + "iat": now, + "sub": sub, + "scope": TOKEN_SCOPE_QUEUE, + } + + if claims["iss"] is None: + del claims["iss"] + if claims["aud"] is None: + del claims["aud"] + + headers = {"alg": self.algorithm} + if self._private_key: + headers["kid"] = self.kid + return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + def generate_private_key(key_type: str = "RSA", key_size: int = 2048): """ diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 9d93f3bf84daf..59d030b198248 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -300,6 +300,7 @@ def app(self): from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.deps import ( JWTBearerDep, + JWTBearerQueueDep, JWTBearerTIPathDep, ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access @@ -315,6 +316,7 @@ async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow + self._app.dependency_overrides[JWTBearerQueueDep.dependency] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index fce188d48ed6d..c324a9a81955e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -26,7 +26,7 @@ from fastapi.security import HTTPBearer from sqlalchemy import select -from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_QUEUE, JWTValidator from airflow.api_fastapi.common.db.common import AsyncSessionDep from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.configuration import conf @@ -48,12 +48,13 @@ async def _container(request: Request): class JWTBearer(HTTPBearer): """ - A FastAPI security dependency that validates JWT tokens using for the Execution API. + A FastAPI security dependency that validates JWT tokens for the Execution API. - This will validate the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that. + This validates tokens are signed and that the ``sub`` is a UUID. Queue-scoped tokens + (with scope="queue") are rejected - they can only be used on the /run endpoint. - The dependency result will be an `TIToken` object containing the ``id`` UUID (from the ``sub``) and other - validated claims. + The dependency result will be a `TIToken` object containing the ``id`` UUID (from the ``sub``) + and other validated claims. """ def __init__( @@ -77,7 +78,6 @@ async def __call__( # type: ignore[override] validator: JWTValidator = await services.aget(JWTValidator) try: - # Example: Validate "task_instance_id" component of the path matches the one in the token if self.path_param_name: id = request.path_params[self.path_param_name] validators: dict[str, Any] = { @@ -87,13 +87,70 @@ async def __call__( # type: ignore[override] else: validators = self.required_claims claims = await validator.avalidated_claims(creds.credentials, validators) + + # Reject queue-scoped tokens - they can only be used on /run endpoint + # Only check if scope claim is present (allows backwards compatibility with tests) + scope = claims.get("scope") + if scope is not None and scope == TOKEN_SCOPE_QUEUE: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Queue tokens cannot access this endpoint. Use the token from /run response.", + ) + return TIToken(id=claims["sub"], claims=claims) + except HTTPException: + raise except Exception as err: - log.warning( - "Failed to validate JWT", - exc_info=True, - token=creds.credentials, - ) + log.warning("Failed to validate JWT", exc_info=True) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") + + +class JWTBearerQueueScope(HTTPBearer): + """ + JWT auth dependency that ONLY accepts queue-scoped tokens. + + Used exclusively by the /run endpoint. Queue tokens have scope="queue" and are + long-lived to survive executor queue wait times. The /run endpoint validates + the queue token and issues a short-lived execution token for subsequent API calls. + """ + + def __init__(self, path_param_name: str | None = None): + super().__init__(auto_error=False) + self.path_param_name = path_param_name + + async def __call__( # type: ignore[override] + self, + request: Request, + services=DepContainer, + ) -> TIToken | None: + creds = await super().__call__(request) + if not creds: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + + validator: JWTValidator = await services.aget(JWTValidator) + + try: + if self.path_param_name: + id = request.path_params[self.path_param_name] + validators: dict[str, Any] = {"sub": {"essential": True, "value": id}} + else: + validators = {} + claims = await validator.avalidated_claims(creds.credentials, validators) + + # Only accept queue-scoped tokens (if scope claim is present) + # This allows backwards compatibility with tests that don't set scope + scope = claims.get("scope") + if scope is not None and scope != TOKEN_SCOPE_QUEUE: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="This endpoint requires a queue-scoped token", + ) + + return TIToken(id=claims["sub"], claims=claims) + except HTTPException: + raise + except Exception as err: + log.warning("Failed to validate JWT", exc_info=True) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") @@ -102,6 +159,9 @@ async def __call__( # type: ignore[override] # This checks that the UUID in the url matches the one in the token for us. JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) +# For /run endpoint only - accepts queue-scoped tokens and validates task_instance_id +JWTBearerQueueDep = Depends(JWTBearerQueueScope(path_param_name="task_instance_id")) + async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None: """Return the team name associated to the task (if any).""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 562b8588fbf2c..4871643ed7c8f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -43,7 +43,6 @@ authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) -authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) authenticated_router.include_router( task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"] ) @@ -52,3 +51,5 @@ authenticated_router.include_router(hitl.router, prefix="/hitlDetails", tags=["Human in the Loop"]) execution_api_router.include_router(authenticated_router) + +execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index a73145b30ab5a..e0d67f4a831b0 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, HTTPException, Query, Response, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.engine import CursorResult @@ -38,6 +38,7 @@ from structlog.contextvars import bind_contextvars from airflow._shared.timezones import timezone +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime @@ -59,7 +60,7 @@ TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep +from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerQueueDep, JWTBearerTIPathDep from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive from airflow.models.dag import DagModel @@ -84,11 +85,10 @@ ] ) - log = structlog.get_logger(__name__) -@ti_id_router.patch( +@router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, responses={ @@ -97,12 +97,15 @@ HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, response_model_exclude_unset=True, + dependencies=[JWTBearerQueueDep], ) -def ti_run( +async def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep, dag_bag: DagBagDep, + response: Response, + services=DepContainer, ) -> TIRunContext: """ Run a TaskInstance. @@ -264,6 +267,11 @@ def ti_run( context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs + # Generate short-lived execution token for subsequent API calls + generator: JWTGenerator = await services.aget(JWTGenerator) + execution_token = generator.generate(extras={"sub": ti_id_str}) + response.headers["X-Execution-Token"] = execution_token + return context except SQLAlchemyError: log.exception("Error marking Task Instance state as running") diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 9532de60a6257..191db34eb6064 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1858,6 +1858,17 @@ execution_api: type: integer example: ~ default: "600" + jwt_queue_token_expiration_time: + description: | + Number in seconds until the queue JWT token expires. Queue tokens are long-lived tokens + sent with task workloads to executors (e.g., Celery). They can only be used to call + the /run endpoint, which then issues a short-lived execution token. + + This should be set long enough to cover the maximum expected queue wait time. + version_added: 3.1.0 + type: integer + example: ~ + default: "86400" jwt_audience: version_added: 3.0.0 description: | diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 7cf1aae60ff21..b49d79db69c1b 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -45,7 +45,13 @@ class BaseWorkload(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" + """ + Generate a queue-scoped token for this workload. + + Queue tokens are long-lived and can only be used on the /run endpoint, + which exchanges them for short-lived execution tokens. + """ + return generator.generate_queue_token(sub_id) if generator else "" class BundleInfo(BaseModel): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index b17c8147dae24..cc1c610722018 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -30,6 +30,7 @@ from airflow._shared.timezones import timezone from airflow.api_fastapi.auth.tokens import ( JWKS, + TOKEN_SCOPE_QUEUE, InvalidClaimError, JWTGenerator, JWTValidator, @@ -238,3 +239,40 @@ def rsa_private_key(): @pytest.fixture(scope="session") def ed25519_private_key(): return generate_private_key(key_type="Ed25519") + + +async def test_generate_queue_token(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): + """Test that generate_queue_token creates tokens with queue scope and longer expiration.""" + token = jwt_generator.generate_queue_token("test_subject") + + claims = await jwt_validator.avalidated_claims( + token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + # Verify queue scope is set + assert claims.get("scope") == TOKEN_SCOPE_QUEUE, "Queue token should have queue scope" + + # Verify the token has extended expiration (default 24h = 86400s) + nbf = datetime.fromtimestamp(claims["nbf"], timezone.utc) + exp = datetime.fromtimestamp(claims["exp"], timezone.utc) + expiration_seconds = (exp - nbf).total_seconds() + + # Should be around 24 hours (86400 seconds) - allow some tolerance + assert expiration_seconds >= 86000, "Queue token should have extended expiration (~24h)" + assert expiration_seconds <= 90000, "Queue token expiration should not exceed expected duration" + + +async def test_queue_token_vs_regular_token_scope(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): + """Test that regular tokens don't have scope claim while queue tokens do.""" + regular_token = jwt_generator.generate({"sub": "test_subject"}) + queue_token = jwt_generator.generate_queue_token("test_subject") + + regular_claims = await jwt_validator.avalidated_claims( + regular_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + queue_claims = await jwt_validator.avalidated_claims( + queue_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + assert "scope" not in regular_claims, "Regular token should not have scope claim" + assert queue_claims.get("scope") == TOKEN_SCOPE_QUEUE, "Queue token should have queue scope" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 9e26937b63c06..5a673556bd718 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -16,14 +16,25 @@ # under the License. from __future__ import annotations -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest from fastapi.testclient import TestClient from airflow.api_fastapi.app import cached_app -from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator from airflow.api_fastapi.execution_api.app import lifespan +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.deps import ( + JWTBearerDep, + JWTBearerQueueDep, + JWTBearerTIPathDep, +) + + +def _always_allow(ti_id: str | None = None) -> TIToken: + """Return a mock TIToken for bypassing auth in tests.""" + return TIToken(id=ti_id or "00000000-0000-0000-0000-000000000000", claims={}) @pytest.fixture @@ -63,4 +74,25 @@ def smart_validated_claims(cred, validators=None): auth.avalidated_claims.side_effect = smart_validated_claims lifespan.registry.register_value(JWTValidator, auth) + # Mock JWTGenerator for /run endpoint that returns execution tokens + jwt_generator = MagicMock(spec=JWTGenerator) + jwt_generator.generate.return_value = "mock-execution-token" + lifespan.registry.register_value(JWTGenerator, jwt_generator) + + # Override auth dependencies to bypass token scope validation in tests + # This allows tests to focus on business logic rather than auth mechanics + # We need to override both the specific instances AND the classes to cover all cases + jwt_bearer_instance = JWTBearerDep.dependency + jwt_bearer_ti_path_instance = JWTBearerTIPathDep.dependency + jwt_bearer_queue_instance = JWTBearerQueueDep.dependency + + app.dependency_overrides[jwt_bearer_instance] = lambda: _always_allow() + app.dependency_overrides[jwt_bearer_ti_path_instance] = lambda: _always_allow() + app.dependency_overrides[jwt_bearer_queue_instance] = lambda: _always_allow() + yield client + + # Clean up dependency overrides + app.dependency_overrides.pop(jwt_bearer_instance, None) + app.dependency_overrides.pop(jwt_bearer_ti_path_instance, None) + app.dependency_overrides.pop(jwt_bearer_queue_instance, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index cfee6c9d46cb1..714200a2cf3ed 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -29,8 +29,6 @@ from sqlalchemy.orm import Session from airflow._shared.timezones import timezone -from airflow.api_fastapi.auth.tokens import JWTValidator -from airflow.api_fastapi.execution_api.app import lifespan from airflow.exceptions import AirflowSkipException from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel @@ -79,50 +77,35 @@ def _create_asset_aliases(session, num: int = 2) -> None: def client_with_extra_route(): ... -def test_id_matches_sub_claim(client, session, create_task_instance): - # Test that this is validated at the router level, so we don't have to test it in each component - # We validate it is set correctly, and test it once +def test_run_endpoint_returns_execution_token(client, session, create_task_instance, time_machine): + """Test that /run endpoint returns an execution token in the response header.""" + instant = timezone.parse("2024-09-30T12:00:00Z") + time_machine.move_to(instant, tick=False) ti = create_task_instance( - task_id="test_ti_run_state_conflict_if_not_queued", - state="queued", + task_id="test_run_endpoint_returns_execution_token", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), ) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - claims = {"sub": ti.id} - - def side_effect(cred, validators): - if not validators: - return claims - if validators["sub"]["value"] != ti.id: - raise RuntimeError("Fake auth denied") - return claims - - validator.avalidated_claims.side_effect = side_effect - - lifespan.registry.register_value(JWTValidator, validator) - payload = { "state": "running", "hostname": "random-hostname", "unixname": "random-unixname", "pid": 100, - "start_date": "2024-10-31T12:00:00Z", + "start_date": "2024-09-30T12:00:00Z", } - resp = client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run", json=payload) - assert resp.status_code == 403 - assert validator.avalidated_claims.call_args_list[1] == mock.call( - mock.ANY, {"sub": {"essential": True, "value": "9c230b40-da03-451d-8bd7-be30471be383"}} - ) - validator.avalidated_claims.reset_mock() - resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload) + assert resp.status_code == 200 - assert resp.status_code == 200, resp.json() - - validator.avalidated_claims.assert_awaited() + # Verify execution token is returned in header + assert "X-Execution-Token" in resp.headers + assert resp.headers["X-Execution-Token"] == "mock-execution-token" class TestTIRunState: diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 2a38ef2fad7b3..c6e70a91306c6 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -928,7 +928,12 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * ) def _update_auth(self, response: httpx.Response): - if new_token := response.headers.get("Refreshed-API-Token"): + # Check for execution token from /run endpoint (replaces queue token with short-lived execution token) + if new_token := response.headers.get("X-Execution-Token"): + log.debug("Received execution token from /run endpoint") + self.auth = BearerAuth(new_token) + # Check for refreshed token from heartbeat/other endpoints + elif new_token := response.headers.get("Refreshed-API-Token"): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token) From e104ffa3d44e6945193c422fe5e3cbf34146bdb9 Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 8 Jan 2026 19:19:20 -0600 Subject: [PATCH 02/19] fix failing tests --- airflow-core/tests/unit/jobs/test_scheduler_job.py | 2 +- devel-common/src/tests_common/test_utils/mock_executor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 3636c2b6db980..b38537cd8708f 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -235,7 +235,7 @@ def set_instance_attrs(self) -> Generator: @pytest.fixture def mock_executors(self): mock_jwt_generator = MagicMock(spec=JWTGenerator) - mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_queue_token.return_value = "mock-token" default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 4e95ed3a4eea7..17c2c648b309b 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -58,7 +58,7 @@ def __init__(self, do_update=True, *args, **kwargs): # Mock JWT generator for token generation mock_jwt_generator = MagicMock() - mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_queue_token.return_value = "mock-token" self.jwt_generator = mock_jwt_generator From e5905459d9802fafe7adfaefb0ff0f3a3d79e993 Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 9 Jan 2026 01:14:23 -0600 Subject: [PATCH 03/19] fix failing test --- .../airflow/api_fastapi/execution_api/app.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 59d030b198248..b6e60f249c2ef 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -296,9 +296,12 @@ class InProcessExecutionAPI: @cached_property def app(self): if not self._app: + from unittest.mock import AsyncMock, MagicMock + + from airflow.api_fastapi.auth.tokens import JWTValidator from airflow.api_fastapi.common.dagbag import create_dag_bag - from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.deps import ( + DepContainer, JWTBearerDep, JWTBearerQueueDep, JWTBearerTIPathDep, @@ -306,6 +309,11 @@ def app(self): from airflow.api_fastapi.execution_api.routes.connections import has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access + from airflow.configuration import conf + + # Set a dummy JWT secret so the lifespan can create JWT services without failing. + if not conf.get("api_auth", "jwt_secret", fallback=None): + conf.set("api_auth", "jwt_secret", "in-process-test-secret-key") self._app = create_task_execution_api_app() @@ -321,24 +329,52 @@ async def always_allow(): ... self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow + # Create a mock container that provides mock JWT services + mock_jwt_generator = MagicMock(spec=JWTGenerator) + mock_jwt_generator.generate.return_value = "mock-execution-token" + + mock_jwt_validator = AsyncMock(spec=JWTValidator) + mock_jwt_validator.avalidated_claims.return_value = {"sub": "test", "exp": 9999999999} + + class MockContainer: + """A mock svcs container that returns mock services.""" + + async def aget(self, svc_type): + if svc_type is JWTGenerator: + return mock_jwt_generator + if svc_type is JWTValidator: + return mock_jwt_validator + raise ValueError(f"Unknown service type: {svc_type}") + + async def mock_container_dep(): + return MockContainer() + + self._app.dependency_overrides[DepContainer.dependency] = mock_container_dep + return self._app @cached_property def transport(self) -> httpx.WSGITransport: import asyncio + import threading import httpx from a2wsgi import ASGIMiddleware middleware = ASGIMiddleware(self.app) + lifespan_started = threading.Event() # https://github.com/abersheeran/a2wsgi/discussions/64 async def start_lifespan(cm: AsyncExitStack, app: FastAPI): await cm.enter_async_context(app.router.lifespan_context(app)) + lifespan_started.set() self._cm = AsyncExitStack() asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop) + # Wait for lifespan to complete before returning the transport + lifespan_started.wait(timeout=5.0) + return httpx.WSGITransport(app=middleware) # type: ignore[arg-type] @cached_property From a4eb6120a30f1a819cef517d907a230ce96575ca Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 9 Jan 2026 17:44:00 -0600 Subject: [PATCH 04/19] further enhanced the implementation --- .../src/airflow/api_fastapi/auth/tokens.py | 59 ++++++------ .../airflow/api_fastapi/execution_api/app.py | 36 +------ .../airflow/api_fastapi/execution_api/deps.py | 94 +++++++------------ .../execution_api/routes/task_instances.py | 9 +- .../src/airflow/config_templates/config.yml | 6 +- .../src/airflow/executors/workloads.py | 6 +- .../unit/api_fastapi/auth/test_tokens.py | 30 +++--- .../api_fastapi/execution_api/conftest.py | 8 +- .../tests/unit/jobs/test_scheduler_job.py | 2 +- .../tests_common/test_utils/mock_executor.py | 2 +- 10 files changed, 100 insertions(+), 152 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 2e40413a05e04..b39959ca18be0 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -46,7 +46,7 @@ "JWKS", "JWTGenerator", "JWTValidator", - "TOKEN_SCOPE_QUEUE", + "TOKEN_SCOPE_WORKLOAD", "generate_private_key", "get_sig_validation_args", "get_signing_args", @@ -55,7 +55,7 @@ "key_to_jwk_dict", ] -TOKEN_SCOPE_QUEUE = "queue" +TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload" class InvalidClaimError(ValueError): @@ -437,15 +437,28 @@ def signing_arg(self) -> AllowedPrivateKeys | str: assert self._secret_key return self._secret_key - def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str: - """Generate a signed JWT for the subject.""" + def generate( + self, + extras: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + expiry: int | None = None, + ) -> str: + """ + Generate a signed JWT. + + Args: + extras: Additional claims to include in the token. These are merged with default claims. + headers: Additional headers to include in the JWT. + expiry: Optional custom expiry time in seconds. If not provided, uses self.valid_for. + """ now = int(datetime.now(tz=timezone.utc).timestamp()) + valid_for = expiry if expiry is not None else self.valid_for claims = { "jti": uuid.uuid4().hex, "iss": self.issuer, "aud": self.audience, "nbf": now, - "exp": int(now + self.valid_for), + "exp": int(now + valid_for), "iat": now, } @@ -461,38 +474,20 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] headers["kid"] = self.kid return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) - def generate_queue_token(self, sub: str) -> str: + def generate_workload_token(self, sub: str) -> str: """ - Generate a long-lived queue token for task workloads. + Generate a long-lived workload token for task execution. - Queue tokens have a special 'scope' claim that restricts them to the /run endpoint only. - They are valid for longer (default 24h) to survive queue wait times. + Workload tokens have a special 'scope' claim that restricts them to the /run endpoint only. + They are valid for longer (default 24h) to survive executor queue wait times. """ from airflow.configuration import conf - queue_expiry = conf.getint("execution_api", "jwt_queue_token_expiration_time", fallback=86400) - now = int(datetime.now(tz=timezone.utc).timestamp()) - - claims = { - "jti": uuid.uuid4().hex, - "iss": self.issuer, - "aud": self.audience, - "nbf": now, - "exp": now + queue_expiry, - "iat": now, - "sub": sub, - "scope": TOKEN_SCOPE_QUEUE, - } - - if claims["iss"] is None: - del claims["iss"] - if claims["aud"] is None: - del claims["aud"] - - headers = {"alg": self.algorithm} - if self._private_key: - headers["kid"] = self.kid - return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + workload_expiry = conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400) + return self.generate( + extras={"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD}, + expiry=workload_expiry, + ) def generate_private_key(key_type: str = "RSA", key_size: int = 2048): diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index b6e60f249c2ef..fb88b09e3d9da 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -296,24 +296,16 @@ class InProcessExecutionAPI: @cached_property def app(self): if not self._app: - from unittest.mock import AsyncMock, MagicMock - - from airflow.api_fastapi.auth.tokens import JWTValidator from airflow.api_fastapi.common.dagbag import create_dag_bag + from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.deps import ( - DepContainer, JWTBearerDep, - JWTBearerQueueDep, JWTBearerTIPathDep, ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access + from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access - from airflow.configuration import conf - - # Set a dummy JWT secret so the lifespan can create JWT services without failing. - if not conf.get("api_auth", "jwt_secret", fallback=None): - conf.set("api_auth", "jwt_secret", "in-process-test-secret-key") self._app = create_task_execution_api_app() @@ -324,33 +316,11 @@ async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow - self._app.dependency_overrides[JWTBearerQueueDep.dependency] = always_allow + self._app.dependency_overrides[JWTBearerWorkloadDep.dependency] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow - # Create a mock container that provides mock JWT services - mock_jwt_generator = MagicMock(spec=JWTGenerator) - mock_jwt_generator.generate.return_value = "mock-execution-token" - - mock_jwt_validator = AsyncMock(spec=JWTValidator) - mock_jwt_validator.avalidated_claims.return_value = {"sub": "test", "exp": 9999999999} - - class MockContainer: - """A mock svcs container that returns mock services.""" - - async def aget(self, svc_type): - if svc_type is JWTGenerator: - return mock_jwt_generator - if svc_type is JWTValidator: - return mock_jwt_validator - raise ValueError(f"Unknown service type: {svc_type}") - - async def mock_container_dep(): - return MockContainer() - - self._app.dependency_overrides[DepContainer.dependency] = mock_container_dep - return self._app @cached_property diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index c324a9a81955e..e088c2ff28f5e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -26,7 +26,7 @@ from fastapi.security import HTTPBearer from sqlalchemy import select -from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_QUEUE, JWTValidator +from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_WORKLOAD, JWTValidator from airflow.api_fastapi.common.db.common import AsyncSessionDep from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.configuration import conf @@ -46,15 +46,12 @@ async def _container(request: Request): DepContainer: svcs.Container = Depends(_container) -class JWTBearer(HTTPBearer): +class _BaseJWTBearer(HTTPBearer): """ - A FastAPI security dependency that validates JWT tokens for the Execution API. + Base class for JWT validation in the Execution API. - This validates tokens are signed and that the ``sub`` is a UUID. Queue-scoped tokens - (with scope="queue") are rejected - they can only be used on the /run endpoint. - - The dependency result will be a `TIToken` object containing the ``id`` UUID (from the ``sub``) - and other validated claims. + Validates JWT tokens are properly signed and extracts claims. Subclasses + handle scope-specific validation. """ def __init__( @@ -88,14 +85,8 @@ async def __call__( # type: ignore[override] validators = self.required_claims claims = await validator.avalidated_claims(creds.credentials, validators) - # Reject queue-scoped tokens - they can only be used on /run endpoint - # Only check if scope claim is present (allows backwards compatibility with tests) - scope = claims.get("scope") - if scope is not None and scope == TOKEN_SCOPE_QUEUE: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Queue tokens cannot access this endpoint. Use the token from /run response.", - ) + # Let subclasses validate scope + self._check_scope(claims) return TIToken(id=claims["sub"], claims=claims) except HTTPException: @@ -104,54 +95,44 @@ async def __call__( # type: ignore[override] log.warning("Failed to validate JWT", exc_info=True) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") + def _check_scope(self, claims: dict[str, Any]) -> None: + """Override in subclasses to validate scope. Raise HTTPException if invalid.""" + pass -class JWTBearerQueueScope(HTTPBearer): - """ - JWT auth dependency that ONLY accepts queue-scoped tokens. - Used exclusively by the /run endpoint. Queue tokens have scope="queue" and are - long-lived to survive executor queue wait times. The /run endpoint validates - the queue token and issues a short-lived execution token for subsequent API calls. +class JWTBearer(_BaseJWTBearer): """ + JWT validation that rejects workload-scoped tokens. - def __init__(self, path_param_name: str | None = None): - super().__init__(auto_error=False) - self.path_param_name = path_param_name + Used for most Execution API endpoints. Workload-scoped tokens can only be used + on the /run endpoint, which exchanges them for short-lived execution tokens. + """ - async def __call__( # type: ignore[override] - self, - request: Request, - services=DepContainer, - ) -> TIToken | None: - creds = await super().__call__(request) - if not creds: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + def _check_scope(self, claims: dict[str, Any]) -> None: + if claims.get("scope") == TOKEN_SCOPE_WORKLOAD: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Workload tokens cannot access this endpoint. Use the token from /run response.", + ) - validator: JWTValidator = await services.aget(JWTValidator) - try: - if self.path_param_name: - id = request.path_params[self.path_param_name] - validators: dict[str, Any] = {"sub": {"essential": True, "value": id}} - else: - validators = {} - claims = await validator.avalidated_claims(creds.credentials, validators) +class JWTBearerWorkloadScope(_BaseJWTBearer): + """ + JWT validation that ONLY accepts workload-scoped tokens. - # Only accept queue-scoped tokens (if scope claim is present) - # This allows backwards compatibility with tests that don't set scope - scope = claims.get("scope") - if scope is not None and scope != TOKEN_SCOPE_QUEUE: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="This endpoint requires a queue-scoped token", - ) + Used exclusively by the /run endpoint. Workload tokens have scope="ExecuteTaskWorkload" + and are long-lived to survive executor queue wait times. The /run endpoint validates + the workload token and issues a short-lived execution token for subsequent API calls. + """ - return TIToken(id=claims["sub"], claims=claims) - except HTTPException: - raise - except Exception as err: - log.warning("Failed to validate JWT", exc_info=True) - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") + def _check_scope(self, claims: dict[str, Any]) -> None: + scope = claims.get("scope") + # Reject if scope is explicitly set to something other than workload scope + if scope is not None and scope != TOKEN_SCOPE_WORKLOAD: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="This endpoint requires a workload-scoped token", + ) JWTBearerDep: TIToken = Depends(JWTBearer()) @@ -159,9 +140,6 @@ async def __call__( # type: ignore[override] # This checks that the UUID in the url matches the one in the token for us. JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) -# For /run endpoint only - accepts queue-scoped tokens and validates task_instance_id -JWTBearerQueueDep = Depends(JWTBearerQueueScope(path_param_name="task_instance_id")) - async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None: """Return the team name associated to the task (if any).""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index e0d67f4a831b0..4062ab64d94ad 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, Response, status +from fastapi import Body, Depends, HTTPException, Query, Response, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.engine import CursorResult @@ -60,7 +60,7 @@ TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerQueueDep, JWTBearerTIPathDep +from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerTIPathDep, JWTBearerWorkloadScope from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive from airflow.models.dag import DagModel @@ -87,6 +87,9 @@ log = structlog.get_logger(__name__) +# For /run endpoint only - accepts workload-scoped tokens and validates task_instance_id +JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id")) + @router.patch( "/{task_instance_id}/run", @@ -97,7 +100,7 @@ HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, response_model_exclude_unset=True, - dependencies=[JWTBearerQueueDep], + dependencies=[JWTBearerWorkloadDep], ) async def ti_run( task_instance_id: UUID, diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 191db34eb6064..97108feacd5c0 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1858,14 +1858,14 @@ execution_api: type: integer example: ~ default: "600" - jwt_queue_token_expiration_time: + jwt_workload_token_expiration_time: description: | - Number in seconds until the queue JWT token expires. Queue tokens are long-lived tokens + Number in seconds until the workload JWT token expires. Workload tokens are long-lived tokens sent with task workloads to executors (e.g., Celery). They can only be used to call the /run endpoint, which then issues a short-lived execution token. This should be set long enough to cover the maximum expected queue wait time. - version_added: 3.1.0 + version_added: 3.1.7 type: integer example: ~ default: "86400" diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index b49d79db69c1b..c73231dc0571b 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -46,12 +46,12 @@ class BaseWorkload(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: """ - Generate a queue-scoped token for this workload. + Generate a workload-scoped token for this workload. - Queue tokens are long-lived and can only be used on the /run endpoint, + Workload tokens are long-lived and can only be used on the /run endpoint, which exchanges them for short-lived execution tokens. """ - return generator.generate_queue_token(sub_id) if generator else "" + return generator.generate_workload_token(sub_id) if generator else "" class BundleInfo(BaseModel): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index cc1c610722018..ed7d51b7fad9e 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -30,7 +30,7 @@ from airflow._shared.timezones import timezone from airflow.api_fastapi.auth.tokens import ( JWKS, - TOKEN_SCOPE_QUEUE, + TOKEN_SCOPE_WORKLOAD, InvalidClaimError, JWTGenerator, JWTValidator, @@ -241,16 +241,16 @@ def ed25519_private_key(): return generate_private_key(key_type="Ed25519") -async def test_generate_queue_token(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): - """Test that generate_queue_token creates tokens with queue scope and longer expiration.""" - token = jwt_generator.generate_queue_token("test_subject") +async def test_generate_workload_token(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): + """Test that generate_workload_token creates tokens with workload scope and longer expiration.""" + token = jwt_generator.generate_workload_token("test_subject") claims = await jwt_validator.avalidated_claims( token, required_claims={"sub": {"essential": True, "value": "test_subject"}} ) - # Verify queue scope is set - assert claims.get("scope") == TOKEN_SCOPE_QUEUE, "Queue token should have queue scope" + # Verify workload scope is set + assert claims.get("scope") == TOKEN_SCOPE_WORKLOAD, "Workload token should have workload scope" # Verify the token has extended expiration (default 24h = 86400s) nbf = datetime.fromtimestamp(claims["nbf"], timezone.utc) @@ -258,21 +258,23 @@ async def test_generate_queue_token(jwt_generator: JWTGenerator, jwt_validator: expiration_seconds = (exp - nbf).total_seconds() # Should be around 24 hours (86400 seconds) - allow some tolerance - assert expiration_seconds >= 86000, "Queue token should have extended expiration (~24h)" - assert expiration_seconds <= 90000, "Queue token expiration should not exceed expected duration" + assert expiration_seconds >= 86000, "Workload token should have extended expiration (~24h)" + assert expiration_seconds <= 90000, "Workload token expiration should not exceed expected duration" -async def test_queue_token_vs_regular_token_scope(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): - """Test that regular tokens don't have scope claim while queue tokens do.""" +async def test_workload_token_vs_regular_token_scope( + jwt_generator: JWTGenerator, jwt_validator: JWTValidator +): + """Test that regular tokens don't have scope claim while workload tokens do.""" regular_token = jwt_generator.generate({"sub": "test_subject"}) - queue_token = jwt_generator.generate_queue_token("test_subject") + workload_token = jwt_generator.generate_workload_token("test_subject") regular_claims = await jwt_validator.avalidated_claims( regular_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} ) - queue_claims = await jwt_validator.avalidated_claims( - queue_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + workload_claims = await jwt_validator.avalidated_claims( + workload_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} ) assert "scope" not in regular_claims, "Regular token should not have scope claim" - assert queue_claims.get("scope") == TOKEN_SCOPE_QUEUE, "Queue token should have queue scope" + assert workload_claims.get("scope") == TOKEN_SCOPE_WORKLOAD, "Workload token should have workload scope" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 5a673556bd718..4f295a45ad18e 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -27,9 +27,9 @@ from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.api_fastapi.execution_api.deps import ( JWTBearerDep, - JWTBearerQueueDep, JWTBearerTIPathDep, ) +from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep def _always_allow(ti_id: str | None = None) -> TIToken: @@ -84,15 +84,15 @@ def smart_validated_claims(cred, validators=None): # We need to override both the specific instances AND the classes to cover all cases jwt_bearer_instance = JWTBearerDep.dependency jwt_bearer_ti_path_instance = JWTBearerTIPathDep.dependency - jwt_bearer_queue_instance = JWTBearerQueueDep.dependency + jwt_bearer_workload_instance = JWTBearerWorkloadDep.dependency app.dependency_overrides[jwt_bearer_instance] = lambda: _always_allow() app.dependency_overrides[jwt_bearer_ti_path_instance] = lambda: _always_allow() - app.dependency_overrides[jwt_bearer_queue_instance] = lambda: _always_allow() + app.dependency_overrides[jwt_bearer_workload_instance] = lambda: _always_allow() yield client # Clean up dependency overrides app.dependency_overrides.pop(jwt_bearer_instance, None) app.dependency_overrides.pop(jwt_bearer_ti_path_instance, None) - app.dependency_overrides.pop(jwt_bearer_queue_instance, None) + app.dependency_overrides.pop(jwt_bearer_workload_instance, None) diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index b38537cd8708f..01b2a492abbe3 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -235,7 +235,7 @@ def set_instance_attrs(self) -> Generator: @pytest.fixture def mock_executors(self): mock_jwt_generator = MagicMock(spec=JWTGenerator) - mock_jwt_generator.generate_queue_token.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-token" default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 17c2c648b309b..6280bf03f0ee0 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -58,7 +58,7 @@ def __init__(self, do_update=True, *args, **kwargs): # Mock JWT generator for token generation mock_jwt_generator = MagicMock() - mock_jwt_generator.generate_queue_token.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-token" self.jwt_generator = mock_jwt_generator From 560fe2f7c09b17276a547635117d3805a9a47903 Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 9 Jan 2026 18:22:47 -0600 Subject: [PATCH 05/19] further clean ups --- airflow-core/src/airflow/api_fastapi/execution_api/app.py | 5 ----- .../src/airflow/api_fastapi/execution_api/routes/__init__.py | 3 +++ 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index fb88b09e3d9da..dfce9e4884cde 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -326,24 +326,19 @@ async def always_allow(): ... @cached_property def transport(self) -> httpx.WSGITransport: import asyncio - import threading import httpx from a2wsgi import ASGIMiddleware middleware = ASGIMiddleware(self.app) - lifespan_started = threading.Event() # https://github.com/abersheeran/a2wsgi/discussions/64 async def start_lifespan(cm: AsyncExitStack, app: FastAPI): await cm.enter_async_context(app.router.lifespan_context(app)) - lifespan_started.set() self._cm = AsyncExitStack() asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop) - # Wait for lifespan to complete before returning the transport - lifespan_started.wait(timeout=5.0) return httpx.WSGITransport(app=middleware) # type: ignore[arg-type] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 4871643ed7c8f..7caa42a224faf 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -52,4 +52,7 @@ execution_api_router.include_router(authenticated_router) +# task_instances.router is NOT in authenticated_router because its /run endpoint requires +# workload-scoped tokens (JWTBearerWorkloadDep), which are rejected by JWTBearerDep. +# The router handles its own auth: /run uses JWTBearerWorkloadDep, others use JWTBearerTIPathDep. execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) From ee8a73da3ee5daf292d59b4ce79dfcc077305b00 Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 9 Jan 2026 20:17:58 -0600 Subject: [PATCH 06/19] clean ups --- .../api_fastapi/execution_api/routes/__init__.py | 9 +++++---- .../execution_api/routes/task_instances.py | 14 +++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 7caa42a224faf..bb2e7c33b7ae9 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -43,6 +43,7 @@ authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) +authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) authenticated_router.include_router( task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"] ) @@ -52,7 +53,7 @@ execution_api_router.include_router(authenticated_router) -# task_instances.router is NOT in authenticated_router because its /run endpoint requires -# workload-scoped tokens (JWTBearerWorkloadDep), which are rejected by JWTBearerDep. -# The router handles its own auth: /run uses JWTBearerWorkloadDep, others use JWTBearerTIPathDep. -execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) +# ti_run_router: /run endpoint - requires workload-scoped tokens (JWTBearerWorkloadDep) +execution_api_router.include_router( + task_instances.ti_run_router, prefix="/task-instances", tags=["Task Instances"] +) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 4062ab64d94ad..4febfcb7b83d1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -76,6 +76,12 @@ if TYPE_CHECKING: from sqlalchemy.sql.dml import Update +log = structlog.get_logger(__name__) + +JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id")) + +ti_run_router = VersionedAPIRouter(dependencies=[JWTBearerWorkloadDep]) + router = VersionedAPIRouter() ti_id_router = VersionedAPIRouter( @@ -85,13 +91,8 @@ ] ) -log = structlog.get_logger(__name__) - -# For /run endpoint only - accepts workload-scoped tokens and validates task_instance_id -JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id")) - -@router.patch( +@ti_run_router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, responses={ @@ -100,7 +101,6 @@ HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, response_model_exclude_unset=True, - dependencies=[JWTBearerWorkloadDep], ) async def ti_run( task_instance_id: UUID, From b4a5e05f56a7a9d7e9c8d5389173966d785fc36b Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 01:06:36 -0600 Subject: [PATCH 07/19] test fixes --- .../src/airflow/api_fastapi/execution_api/app.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index dfce9e4884cde..13a8c93186a1f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -296,6 +296,9 @@ class InProcessExecutionAPI: @cached_property def app(self): if not self._app: + import os + from base64 import urlsafe_b64encode + from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.deps import ( @@ -306,6 +309,16 @@ def app(self): from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access + from airflow.configuration import conf + + # Ensure JWT secret is available for in-process execution. + # The /run endpoint needs JWTGenerator to issue execution tokens. + # If the config option is empty, generate a random one for the duration of this process. + if not conf.get("api_auth", "jwt_secret", fallback=None): + logger.debug( + "`api_auth/jwt_secret` is not set, generating a temporary one for in-process execution" + ) + conf.set("api_auth", "jwt_secret", urlsafe_b64encode(os.urandom(16)).decode()) self._app = create_task_execution_api_app() From 041b0033dae46222050e944983b1c11c95b1d771 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 15:36:18 -0600 Subject: [PATCH 08/19] fix jwt bearer class not registered --- .../airflow/api_fastapi/execution_api/app.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 13a8c93186a1f..a5e09af5cfabe 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -325,6 +325,33 @@ def app(self): # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() + self._app.state.jwt_generator = _jwt_generator() + self._app.state.jwt_validator = _jwt_validator() + + self._app.state.svcs_registry = svcs.Registry() + self._app.state.svcs_registry.register_value(JWTGenerator, self._app.state.jwt_generator) + self._app.state.svcs_registry.register_value(JWTValidator, self._app.state.jwt_validator) + + from airflow.api_fastapi.execution_api.deps import _container + + class InProcessContainer: + """A container-like object that provides services from app.state.""" + + def __init__(self, app_state): + self._app_state = app_state + + async def aget(self, svc_type): + if svc_type is JWTGenerator: + return self._app_state.jwt_generator + if svc_type is JWTValidator: + return self._app_state.jwt_validator + raise KeyError(svc_type) + + async def _inprocess_container(): + yield InProcessContainer(self._app.state) + + self._app.dependency_overrides[_container] = _inprocess_container + async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow From f34cef4a01972b10eba087685795d16b005af1f0 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 17:21:44 -0600 Subject: [PATCH 09/19] clean ups uneccesary overrides --- .../airflow/api_fastapi/execution_api/app.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index a5e09af5cfabe..2b505ebb84bbb 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -332,26 +332,6 @@ def app(self): self._app.state.svcs_registry.register_value(JWTGenerator, self._app.state.jwt_generator) self._app.state.svcs_registry.register_value(JWTValidator, self._app.state.jwt_validator) - from airflow.api_fastapi.execution_api.deps import _container - - class InProcessContainer: - """A container-like object that provides services from app.state.""" - - def __init__(self, app_state): - self._app_state = app_state - - async def aget(self, svc_type): - if svc_type is JWTGenerator: - return self._app_state.jwt_generator - if svc_type is JWTValidator: - return self._app_state.jwt_validator - raise KeyError(svc_type) - - async def _inprocess_container(): - yield InProcessContainer(self._app.state) - - self._app.dependency_overrides[_container] = _inprocess_container - async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow From 8ea9379922134c215ee25e7e628b1cb9d6e3d4b8 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 17:29:40 -0600 Subject: [PATCH 10/19] clean ups --- airflow-core/src/airflow/api_fastapi/execution_api/app.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 2b505ebb84bbb..f196bc5c1a89e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -359,7 +359,6 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): self._cm = AsyncExitStack() asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop) - return httpx.WSGITransport(app=middleware) # type: ignore[arg-type] @cached_property From 66f92d48c2516b08c49579e12832b0f75b73a4b4 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 18:08:45 -0600 Subject: [PATCH 11/19] add back the override --- .../airflow/api_fastapi/execution_api/app.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index f196bc5c1a89e..288028619847d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -332,6 +332,26 @@ def app(self): self._app.state.svcs_registry.register_value(JWTGenerator, self._app.state.jwt_generator) self._app.state.svcs_registry.register_value(JWTValidator, self._app.state.jwt_validator) + from airflow.api_fastapi.execution_api.deps import _container + + class InProcessContainer: + """A container-like object that provides services from app.state.""" + + def __init__(self, app_state): + self._app_state = app_state + + async def aget(self, svc_type): + if svc_type is JWTGenerator: + return self._app_state.jwt_generator + if svc_type is JWTValidator: + return self._app_state.jwt_validator + raise KeyError(svc_type) + + async def _inprocess_container(): + yield InProcessContainer(self._app.state) + + self._app.dependency_overrides[_container] = _inprocess_container + async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow From c9702712882ccf33c42d568a468b2f3cf6f6e669 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 19:38:38 -0600 Subject: [PATCH 12/19] clean up ovveride --- .../airflow/api_fastapi/execution_api/app.py | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 288028619847d..f43470d495aa5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -325,32 +325,9 @@ def app(self): # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() - self._app.state.jwt_generator = _jwt_generator() - self._app.state.jwt_validator = _jwt_validator() - self._app.state.svcs_registry = svcs.Registry() - self._app.state.svcs_registry.register_value(JWTGenerator, self._app.state.jwt_generator) - self._app.state.svcs_registry.register_value(JWTValidator, self._app.state.jwt_validator) - - from airflow.api_fastapi.execution_api.deps import _container - - class InProcessContainer: - """A container-like object that provides services from app.state.""" - - def __init__(self, app_state): - self._app_state = app_state - - async def aget(self, svc_type): - if svc_type is JWTGenerator: - return self._app_state.jwt_generator - if svc_type is JWTValidator: - return self._app_state.jwt_validator - raise KeyError(svc_type) - - async def _inprocess_container(): - yield InProcessContainer(self._app.state) - - self._app.dependency_overrides[_container] = _inprocess_container + self._app.state.svcs_registry.register_value(JWTGenerator, _jwt_generator()) + self._app.state.svcs_registry.register_value(JWTValidator, _jwt_validator()) async def always_allow(): ... From e97781756fe5ebd43c28144e1f2ec18f2d67cdb3 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 10 Jan 2026 20:03:47 -0600 Subject: [PATCH 13/19] rollback to explicit override --- .../airflow/api_fastapi/execution_api/app.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index f43470d495aa5..a5e09af5cfabe 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -325,9 +325,32 @@ def app(self): # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() + self._app.state.jwt_generator = _jwt_generator() + self._app.state.jwt_validator = _jwt_validator() + self._app.state.svcs_registry = svcs.Registry() - self._app.state.svcs_registry.register_value(JWTGenerator, _jwt_generator()) - self._app.state.svcs_registry.register_value(JWTValidator, _jwt_validator()) + self._app.state.svcs_registry.register_value(JWTGenerator, self._app.state.jwt_generator) + self._app.state.svcs_registry.register_value(JWTValidator, self._app.state.jwt_validator) + + from airflow.api_fastapi.execution_api.deps import _container + + class InProcessContainer: + """A container-like object that provides services from app.state.""" + + def __init__(self, app_state): + self._app_state = app_state + + async def aget(self, svc_type): + if svc_type is JWTGenerator: + return self._app_state.jwt_generator + if svc_type is JWTValidator: + return self._app_state.jwt_validator + raise KeyError(svc_type) + + async def _inprocess_container(): + yield InProcessContainer(self._app.state) + + self._app.dependency_overrides[_container] = _inprocess_container async def always_allow(): ... @@ -356,6 +379,7 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): self._cm = AsyncExitStack() asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop) + return httpx.WSGITransport(app=middleware) # type: ignore[arg-type] @cached_property From 27b0e8b4f6c47db012a510e15de92479591d3a35 Mon Sep 17 00:00:00 2001 From: Anish Date: Tue, 13 Jan 2026 00:25:25 -0600 Subject: [PATCH 14/19] use real container --- .../src/airflow/api_fastapi/execution_api/app.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index a5e09af5cfabe..02d68359a5430 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -334,21 +334,9 @@ def app(self): from airflow.api_fastapi.execution_api.deps import _container - class InProcessContainer: - """A container-like object that provides services from app.state.""" - - def __init__(self, app_state): - self._app_state = app_state - - async def aget(self, svc_type): - if svc_type is JWTGenerator: - return self._app_state.jwt_generator - if svc_type is JWTValidator: - return self._app_state.jwt_validator - raise KeyError(svc_type) - async def _inprocess_container(): - yield InProcessContainer(self._app.state) + async with svcs.Container(self._app.state.svcs_registry) as container: + yield container self._app.dependency_overrides[_container] = _inprocess_container From d782eb9ee37c4cc1c2b18bb243ba68f589b25f31 Mon Sep 17 00:00:00 2001 From: Anish Date: Tue, 13 Jan 2026 10:48:19 -0600 Subject: [PATCH 15/19] refactor in process container override --- .../src/airflow/api_fastapi/execution_api/app.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 02d68359a5430..2089788b46fe2 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -325,20 +325,8 @@ def app(self): # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() - self._app.state.jwt_generator = _jwt_generator() - self._app.state.jwt_validator = _jwt_validator() - - self._app.state.svcs_registry = svcs.Registry() - self._app.state.svcs_registry.register_value(JWTGenerator, self._app.state.jwt_generator) - self._app.state.svcs_registry.register_value(JWTValidator, self._app.state.jwt_validator) - - from airflow.api_fastapi.execution_api.deps import _container - - async def _inprocess_container(): - async with svcs.Container(self._app.state.svcs_registry) as container: - yield container - - self._app.dependency_overrides[_container] = _inprocess_container + lifespan.registry.register_value(JWTGenerator, _jwt_generator()) + lifespan.registry.register_value(JWTValidator, _jwt_validator()) async def always_allow(): ... From 8264c20ec31a948ce9b3820f9f5c7f0e207ea188 Mon Sep 17 00:00:00 2001 From: Anish Date: Tue, 13 Jan 2026 11:13:50 -0600 Subject: [PATCH 16/19] bring back the override --- .../airflow/api_fastapi/execution_api/app.py | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 2089788b46fe2..b85c04dad1b48 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -325,8 +325,39 @@ def app(self): # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() - lifespan.registry.register_value(JWTGenerator, _jwt_generator()) - lifespan.registry.register_value(JWTValidator, _jwt_validator()) + self._app.state.jwt_generator = _jwt_generator() + self._app.state.jwt_validator = _jwt_validator() + + # Why InProcessContainer instead of lifespan.registry or svcs.Container? + # + # The normal app uses @svcs.fastapi.lifespan which manages the registry lifecycle. + # In tests (conftest.py), lifespan.registry.register_value() works because the + # TestClient initializes the lifespan before requests. However, in InProcessExecutionAPI, + # the lifespan runs later (when transport is accessed), but services may be needed + # before that. Using lifespan.registry fails in CI with ServiceNotFoundError. + # + # This minimal container bypasses the svcs lifecycle and directly returns pre-created + # service instances from app.state. If you add new services, update this class. + from airflow.api_fastapi.execution_api.deps import _container + + class InProcessContainer: + """Minimal container for in-process execution, bypassing svcs lifecycle.""" + + def __init__(self, app_state): + self._services = { + JWTGenerator: app_state.jwt_generator, + JWTValidator: app_state.jwt_validator, + } + + async def aget(self, svc_type): + if svc_type not in self._services: + raise KeyError(f"{svc_type} not registered in InProcessContainer") + return self._services[svc_type] + + async def _inprocess_container(): + yield InProcessContainer(self._app.state) + + self._app.dependency_overrides[_container] = _inprocess_container async def always_allow(): ... From d38bb20a15f4c85a3b2eafa3013805ffb8a7fab6 Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 17 Jan 2026 02:58:18 -0600 Subject: [PATCH 17/19] refactored token checks and dependecy --- .../src/airflow/api_fastapi/auth/tokens.py | 14 +++++---- .../airflow/api_fastapi/execution_api/app.py | 6 ++-- .../airflow/api_fastapi/execution_api/deps.py | 8 ++--- .../api_fastapi/execution_api/conftest.py | 29 ++++++++++--------- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index b39959ca18be0..3e82671799efc 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -441,7 +441,7 @@ def generate( self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None, - expiry: int | None = None, + valid_for: int | None = None, ) -> str: """ Generate a signed JWT. @@ -449,16 +449,16 @@ def generate( Args: extras: Additional claims to include in the token. These are merged with default claims. headers: Additional headers to include in the JWT. - expiry: Optional custom expiry time in seconds. If not provided, uses self.valid_for. + valid_for: Optional custom validity duration in seconds. If not provided, uses self.valid_for. """ now = int(datetime.now(tz=timezone.utc).timestamp()) - valid_for = expiry if expiry is not None else self.valid_for + token_valid_for = valid_for if valid_for is not None else self.valid_for claims = { "jti": uuid.uuid4().hex, "iss": self.issuer, "aud": self.audience, "nbf": now, - "exp": int(now + valid_for), + "exp": int(now + token_valid_for), "iat": now, } @@ -483,10 +483,12 @@ def generate_workload_token(self, sub: str) -> str: """ from airflow.configuration import conf - workload_expiry = conf.getint("execution_api", "jwt_workload_token_expiration_time", fallback=86400) + workload_valid_for = conf.getint( + "execution_api", "jwt_workload_token_expiration_time", fallback=86400 + ) return self.generate( extras={"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD}, - expiry=workload_expiry, + valid_for=workload_valid_for, ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index b85c04dad1b48..fbfd27a54a254 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import secrets import time from contextlib import AsyncExitStack from functools import cached_property @@ -296,9 +297,6 @@ class InProcessExecutionAPI: @cached_property def app(self): if not self._app: - import os - from base64 import urlsafe_b64encode - from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app from airflow.api_fastapi.execution_api.deps import ( @@ -318,7 +316,7 @@ def app(self): logger.debug( "`api_auth/jwt_secret` is not set, generating a temporary one for in-process execution" ) - conf.set("api_auth", "jwt_secret", urlsafe_b64encode(os.urandom(16)).decode()) + conf.set("api_auth", "jwt_secret", secrets.token_urlsafe(16)) self._app = create_task_execution_api_app() diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index e088c2ff28f5e..4195bbddcdec3 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -109,10 +109,10 @@ class JWTBearer(_BaseJWTBearer): """ def _check_scope(self, claims: dict[str, Any]) -> None: - if claims.get("scope") == TOKEN_SCOPE_WORKLOAD: + if claims.get("scope"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Workload tokens cannot access this endpoint. Use the token from /run response.", + detail="Scoped tokens cannot access this endpoint. Use the token from /run response.", ) @@ -127,8 +127,8 @@ class JWTBearerWorkloadScope(_BaseJWTBearer): def _check_scope(self, claims: dict[str, Any]) -> None: scope = claims.get("scope") - # Reject if scope is explicitly set to something other than workload scope - if scope is not None and scope != TOKEN_SCOPE_WORKLOAD: + # Reject if scope is missing or not the workload scope + if scope != TOKEN_SCOPE_WORKLOAD: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="This endpoint requires a workload-scoped token", diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 4f295a45ad18e..c9ae4ea364681 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -25,10 +25,7 @@ from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator from airflow.api_fastapi.execution_api.app import lifespan from airflow.api_fastapi.execution_api.datamodels.token import TIToken -from airflow.api_fastapi.execution_api.deps import ( - JWTBearerDep, - JWTBearerTIPathDep, -) +from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTBearerTIPathDep from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep @@ -79,20 +76,24 @@ def smart_validated_claims(cred, validators=None): jwt_generator.generate.return_value = "mock-execution-token" lifespan.registry.register_value(JWTGenerator, jwt_generator) - # Override auth dependencies to bypass token scope validation in tests - # This allows tests to focus on business logic rather than auth mechanics - # We need to override both the specific instances AND the classes to cover all cases jwt_bearer_instance = JWTBearerDep.dependency jwt_bearer_ti_path_instance = JWTBearerTIPathDep.dependency jwt_bearer_workload_instance = JWTBearerWorkloadDep.dependency - app.dependency_overrides[jwt_bearer_instance] = lambda: _always_allow() - app.dependency_overrides[jwt_bearer_ti_path_instance] = lambda: _always_allow() - app.dependency_overrides[jwt_bearer_workload_instance] = lambda: _always_allow() + execution_app = None + for route in app.routes: + if hasattr(route, "path") and route.path == "/execution": + execution_app = route.app + break + + if execution_app: + execution_app.dependency_overrides[jwt_bearer_instance] = lambda: _always_allow() + execution_app.dependency_overrides[jwt_bearer_ti_path_instance] = lambda: _always_allow() + execution_app.dependency_overrides[jwt_bearer_workload_instance] = lambda: _always_allow() yield client - # Clean up dependency overrides - app.dependency_overrides.pop(jwt_bearer_instance, None) - app.dependency_overrides.pop(jwt_bearer_ti_path_instance, None) - app.dependency_overrides.pop(jwt_bearer_workload_instance, None) + if execution_app: + execution_app.dependency_overrides.pop(jwt_bearer_instance, None) + execution_app.dependency_overrides.pop(jwt_bearer_ti_path_instance, None) + execution_app.dependency_overrides.pop(jwt_bearer_workload_instance, None) From 3e0e44388cdbbaa7c9797769ac7f475cef71dc4e Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 22 Jan 2026 19:04:57 -0600 Subject: [PATCH 18/19] implement scope based token authentication --- .../src/airflow/api_fastapi/auth/tokens.py | 9 ++ .../airflow/api_fastapi/execution_api/app.py | 2 +- .../airflow/api_fastapi/execution_api/deps.py | 109 +++++++++--------- .../execution_api/routes/__init__.py | 11 +- .../execution_api/routes/asset_events.py | 3 +- .../execution_api/routes/assets.py | 3 +- .../execution_api/routes/dag_runs.py | 3 +- .../api_fastapi/execution_api/routes/hitl.py | 3 +- .../execution_api/routes/task_instances.py | 34 +++--- .../execution_api/routes/task_reschedules.py | 2 + .../api_fastapi/execution_api/conftest.py | 3 +- 11 files changed, 94 insertions(+), 88 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 3e82671799efc..123e83643e225 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -46,6 +46,9 @@ "JWKS", "JWTGenerator", "JWTValidator", + "SCOPE_EXECUTION", + "SCOPE_MAPPING", + "SCOPE_WORKLOAD", "TOKEN_SCOPE_WORKLOAD", "generate_private_key", "get_sig_validation_args", @@ -56,6 +59,12 @@ ] TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload" +SCOPE_WORKLOAD = "workload" +SCOPE_EXECUTION = "execution" +SCOPE_MAPPING: dict[str, str] = { + TOKEN_SCOPE_WORKLOAD: SCOPE_WORKLOAD, + "": SCOPE_EXECUTION, +} class InvalidClaimError(ValueError): diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index fbfd27a54a254..70e9463769675 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -302,9 +302,9 @@ def app(self): from airflow.api_fastapi.execution_api.deps import ( JWTBearerDep, JWTBearerTIPathDep, + JWTBearerWorkloadDep, ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access - from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access from airflow.configuration import conf diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index 4195bbddcdec3..11e8b1c2cf99c 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -22,11 +22,16 @@ import structlog import svcs -from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBearer +from fastapi import Depends, HTTPException, Request, Security, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes from sqlalchemy import select -from airflow.api_fastapi.auth.tokens import TOKEN_SCOPE_WORKLOAD, JWTValidator +from airflow.api_fastapi.auth.tokens import ( + SCOPE_EXECUTION, + SCOPE_MAPPING, + SCOPE_WORKLOAD, + JWTValidator, +) from airflow.api_fastapi.common.db.common import AsyncSessionDep from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.configuration import conf @@ -46,13 +51,8 @@ async def _container(request: Request): DepContainer: svcs.Container = Depends(_container) -class _BaseJWTBearer(HTTPBearer): - """ - Base class for JWT validation in the Execution API. - - Validates JWT tokens are properly signed and extracts claims. Subclasses - handle scope-specific validation. - """ +class JWTBearer(HTTPBearer): + """JWT Bearer auth with scope validation via FastAPI's SecurityScopes.""" def __init__( self, @@ -63,82 +63,77 @@ def __init__( self.path_param_name = path_param_name self.required_claims = required_claims or {} - async def __call__( # type: ignore[override] + async def __call__( self, request: Request, + security_scopes: SecurityScopes, services=DepContainer, - ) -> TIToken | None: - creds = await super().__call__(request) + ) -> TIToken: + creds: HTTPAuthorizationCredentials | None = await super().__call__(request) if not creds: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing auth token", + headers={"WWW-Authenticate": "Bearer"}, + ) validator: JWTValidator = await services.aget(JWTValidator) try: if self.path_param_name: - id = request.path_params[self.path_param_name] + ti_id = request.path_params[self.path_param_name] validators: dict[str, Any] = { **self.required_claims, - "sub": {"essential": True, "value": id}, + "sub": {"essential": True, "value": ti_id}, } else: validators = self.required_claims claims = await validator.avalidated_claims(creds.credentials, validators) - - # Let subclasses validate scope - self._check_scope(claims) - + self._validate_scopes(claims, security_scopes) return TIToken(id=claims["sub"], claims=claims) except HTTPException: raise except Exception as err: log.warning("Failed to validate JWT", exc_info=True) - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") - - def _check_scope(self, claims: dict[str, Any]) -> None: - """Override in subclasses to validate scope. Raise HTTPException if invalid.""" - pass - - -class JWTBearer(_BaseJWTBearer): - """ - JWT validation that rejects workload-scoped tokens. - - Used for most Execution API endpoints. Workload-scoped tokens can only be used - on the /run endpoint, which exchanges them for short-lived execution tokens. - """ - - def _check_scope(self, claims: dict[str, Any]) -> None: - if claims.get("scope"): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Scoped tokens cannot access this endpoint. Use the token from /run response.", + detail=f"Invalid auth token: {err}", + headers={"WWW-Authenticate": "Bearer"}, ) + def _validate_scopes(self, claims: dict[str, Any], security_scopes: SecurityScopes) -> None: + if not security_scopes.scopes: + return -class JWTBearerWorkloadScope(_BaseJWTBearer): - """ - JWT validation that ONLY accepts workload-scoped tokens. - - Used exclusively by the /run endpoint. Workload tokens have scope="ExecuteTaskWorkload" - and are long-lived to survive executor queue wait times. The /run endpoint validates - the workload token and issues a short-lived execution token for subsequent API calls. - """ - - def _check_scope(self, claims: dict[str, Any]) -> None: - scope = claims.get("scope") - # Reject if scope is missing or not the workload scope - if scope != TOKEN_SCOPE_WORKLOAD: + token_scope = claims.get("scope", "") + mapped_scope = SCOPE_MAPPING.get(token_scope) + if mapped_scope is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="This endpoint requires a workload-scoped token", + detail=f"Unknown token scope: {token_scope}", + headers={"WWW-Authenticate": "Bearer"}, ) - -JWTBearerDep: TIToken = Depends(JWTBearer()) - -# This checks that the UUID in the url matches the one in the token for us. -JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) + for required_scope in security_scopes.scopes: + if required_scope != mapped_scope: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token missing required scope: {required_scope}", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +_jwt_bearer = JWTBearer() +_jwt_bearer_with_path = JWTBearer(path_param_name="task_instance_id") + +# No scope check - for router-level auth +JWTBearerBaseDep = Security(_jwt_bearer, scopes=[]) +# Execution scope - most endpoints +JWTBearerDep = Security(_jwt_bearer, scopes=[SCOPE_EXECUTION]) +# Execution scope with path param validation +JWTBearerTIPathDep = Security(_jwt_bearer_with_path, scopes=[SCOPE_EXECUTION]) +# Workload scope +JWTBearerWorkloadDep = Security(_jwt_bearer_with_path, scopes=[SCOPE_WORKLOAD]) async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index bb2e7c33b7ae9..b8f5cb25f2c5d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -19,7 +19,7 @@ from cadwyn import VersionedAPIRouter from fastapi import APIRouter -from airflow.api_fastapi.execution_api.deps import JWTBearerDep +from airflow.api_fastapi.execution_api.deps import JWTBearerBaseDep from airflow.api_fastapi.execution_api.routes import ( asset_events, assets, @@ -36,8 +36,8 @@ execution_api_router = APIRouter() execution_api_router.include_router(health.router, prefix="/health", tags=["Health"]) -# _Every_ single endpoint under here must be authenticated. Some do further checks on top of these -authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item] +# Base JWT auth; scopes checked at endpoint/router level +authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerBaseDep]) # type: ignore[list-item] authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) @@ -52,8 +52,3 @@ authenticated_router.include_router(hitl.router, prefix="/hitlDetails", tags=["Human in the Loop"]) execution_api_router.include_router(authenticated_router) - -# ti_run_router: /run endpoint - requires workload-scoped tokens (JWTBearerWorkloadDep) -execution_api_router.include_router( - task_instances.ti_run_router, prefix="/task-instances", tags=["Task Instances"] -) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py index 4525a9140e5d1..dd7a0d7b9859a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py @@ -29,14 +29,15 @@ AssetEventResponse, AssetEventsResponse, ) +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel -# TODO: Add dependency on JWT token router = APIRouter( responses={ status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[JWTBearerDep], ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py index 316d4fab4770d..7ea9b6d4335c8 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py @@ -24,14 +24,15 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.asset import AssetModel -# TODO: Add dependency on JWT token router = APIRouter( responses={ status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[JWTBearerDep], ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py index 56fd79b825eba..9696c352093ff 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py @@ -32,13 +32,14 @@ from airflow.api_fastapi.compat import HTTP_422_UNPROCESSABLE_CONTENT from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.exceptions import DagRunAlreadyExists from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun as DagRunModel from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType -router = VersionedAPIRouter() +router = VersionedAPIRouter(dependencies=[JWTBearerDep]) log = logging.getLogger(__name__) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py index 1c91c3e5b3b84..a7720f31b4195 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py @@ -29,9 +29,10 @@ HITLDetailResponse, UpdateHITLDetailPayload, ) +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.hitl import HITLDetail -router = APIRouter() +router = APIRouter(dependencies=[JWTBearerDep]) log = structlog.get_logger(__name__) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 4febfcb7b83d1..6c087e048a4d1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, Depends, HTTPException, Query, Response, status +from fastapi import Body, HTTPException, Query, Response, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.engine import CursorResult @@ -60,7 +60,11 @@ TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import DepContainer, JWTBearerTIPathDep, JWTBearerWorkloadScope +from airflow.api_fastapi.execution_api.deps import ( + DepContainer, + JWTBearerTIPathDep, + JWTBearerWorkloadDep, +) from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive from airflow.models.dag import DagModel @@ -78,21 +82,11 @@ log = structlog.get_logger(__name__) -JWTBearerWorkloadDep = Depends(JWTBearerWorkloadScope(path_param_name="task_instance_id")) - -ti_run_router = VersionedAPIRouter(dependencies=[JWTBearerWorkloadDep]) - router = VersionedAPIRouter() +ti_id_router = VersionedAPIRouter() -ti_id_router = VersionedAPIRouter( - dependencies=[ - # This checks that the UUID in the url matches the one in the token for us. - JWTBearerTIPathDep - ] -) - -@ti_run_router.patch( +@ti_id_router.patch( "/{task_instance_id}/run", status_code=status.HTTP_200_OK, responses={ @@ -101,6 +95,7 @@ HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, response_model_exclude_unset=True, + dependencies=[JWTBearerWorkloadDep], ) async def ti_run( task_instance_id: UUID, @@ -291,6 +286,7 @@ async def ti_run( status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_update_state( task_instance_id: UUID, @@ -529,6 +525,7 @@ def _create_ti_state_update_query_and_update_state( status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_skip_downstream( task_instance_id: UUID, @@ -576,6 +573,7 @@ def ti_skip_downstream( }, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_heartbeat( task_instance_id: UUID, @@ -653,6 +651,7 @@ def ti_heartbeat( "description": "Invalid payload for the setting rendered task instance fields" }, }, + dependencies=[JWTBearerTIPathDep], ) def ti_put_rtif( task_instance_id: UUID, @@ -683,6 +682,7 @@ def ti_put_rtif( status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid rendered_map_index value"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_patch_rendered_map_index( task_instance_id: UUID, @@ -720,9 +720,11 @@ def ti_patch_rendered_map_index( responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance or Dag Run not found"}, }, + dependencies=[JWTBearerTIPathDep], ) def get_previous_successful_dagrun( - task_instance_id: UUID, session: SessionDep + task_instance_id: UUID, + session: SessionDep, ) -> PrevSuccessfulDagRunResponse: """ Get the previous successful DagRun for a TaskInstance. @@ -979,6 +981,7 @@ def _get_group_tasks( responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, }, + dependencies=[JWTBearerTIPathDep], ) def validate_inlets_and_outlets( task_instance_id: UUID, @@ -1041,5 +1044,4 @@ def validate_inlets_and_outlets( ) -# This line should be at the end of the file to ensure all routes are registered router.include_router(ti_id_router) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py index f763858f9b9c1..625585d50979e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py @@ -24,6 +24,7 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.taskreschedule import TaskReschedule router = APIRouter( @@ -31,6 +32,7 @@ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[JWTBearerDep], ) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index c9ae4ea364681..49856b52785bc 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -25,8 +25,7 @@ from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator from airflow.api_fastapi.execution_api.app import lifespan from airflow.api_fastapi.execution_api.datamodels.token import TIToken -from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTBearerTIPathDep -from airflow.api_fastapi.execution_api.routes.task_instances import JWTBearerWorkloadDep +from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTBearerTIPathDep, JWTBearerWorkloadDep def _always_allow(ti_id: str | None = None) -> TIToken: From 7ac2da127d1b5ee26f7440dfa93672a5447dd3c2 Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 22 Jan 2026 19:28:17 -0600 Subject: [PATCH 19/19] fix failing test --- airflow-core/src/airflow/api_fastapi/execution_api/deps.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index 11e8b1c2cf99c..549c05851aaf3 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -63,7 +63,7 @@ def __init__( self.path_param_name = path_param_name self.required_claims = required_claims or {} - async def __call__( + async def __call__( # type: ignore[override] self, request: Request, security_scopes: SecurityScopes,