diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py index 276ae17153da0..123e83643e225 100644 --- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py +++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py @@ -46,6 +46,10 @@ "JWKS", "JWTGenerator", "JWTValidator", + "SCOPE_EXECUTION", + "SCOPE_MAPPING", + "SCOPE_WORKLOAD", + "TOKEN_SCOPE_WORKLOAD", "generate_private_key", "get_sig_validation_args", "get_signing_args", @@ -54,6 +58,14 @@ "key_to_jwk_dict", ] +TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload" +SCOPE_WORKLOAD = "workload" +SCOPE_EXECUTION = "execution" +SCOPE_MAPPING: dict[str, str] = { + TOKEN_SCOPE_WORKLOAD: SCOPE_WORKLOAD, + "": SCOPE_EXECUTION, +} + class InvalidClaimError(ValueError): """Raised when a claim in the JWT is invalid.""" @@ -434,15 +446,28 @@ def signing_arg(self) -> AllowedPrivateKeys | str: assert self._secret_key return self._secret_key - def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str: - """Generate a signed JWT for the subject.""" + def generate( + self, + extras: dict[str, Any] | None = None, + headers: dict[str, Any] | None = None, + valid_for: int | None = None, + ) -> str: + """ + Generate a signed JWT. + + Args: + extras: Additional claims to include in the token. These are merged with default claims. + headers: Additional headers to include in the JWT. + valid_for: Optional custom validity duration in seconds. If not provided, uses self.valid_for. 
+ """ now = int(datetime.now(tz=timezone.utc).timestamp()) + token_valid_for = valid_for if valid_for is not None else self.valid_for claims = { "jti": uuid.uuid4().hex, "iss": self.issuer, "aud": self.audience, "nbf": now, - "exp": int(now + self.valid_for), + "exp": int(now + token_valid_for), "iat": now, } @@ -458,6 +483,23 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] headers["kid"] = self.kid return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + def generate_workload_token(self, sub: str) -> str: + """ + Generate a long-lived workload token for task execution. + + Workload tokens have a special 'scope' claim that restricts them to the /run endpoint only. + They are valid for longer (default 24h) to survive executor queue wait times. + """ + from airflow.configuration import conf + + workload_valid_for = conf.getint( + "execution_api", "jwt_workload_token_expiration_time", fallback=86400 + ) + return self.generate( + extras={"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD}, + valid_for=workload_valid_for, + ) + def generate_private_key(key_type: str = "RSA", key_size: int = 2048): """ diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index 9d93f3bf84daf..70e9463769675 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import secrets import time from contextlib import AsyncExitStack from functools import cached_property @@ -301,20 +302,66 @@ def app(self): from airflow.api_fastapi.execution_api.deps import ( JWTBearerDep, JWTBearerTIPathDep, + JWTBearerWorkloadDep, ) from airflow.api_fastapi.execution_api.routes.connections import has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from 
airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access + from airflow.configuration import conf + + # Ensure JWT secret is available for in-process execution. + # The /run endpoint needs JWTGenerator to issue execution tokens. + # If the config option is empty, generate a random one for the duration of this process. + if not conf.get("api_auth", "jwt_secret", fallback=None): + logger.debug( + "`api_auth/jwt_secret` is not set, generating a temporary one for in-process execution" + ) + conf.set("api_auth", "jwt_secret", secrets.token_urlsafe(16)) self._app = create_task_execution_api_app() # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() + self._app.state.jwt_generator = _jwt_generator() + self._app.state.jwt_validator = _jwt_validator() + + # Why InProcessContainer instead of lifespan.registry or svcs.Container? + # + # The normal app uses @svcs.fastapi.lifespan which manages the registry lifecycle. + # In tests (conftest.py), lifespan.registry.register_value() works because the + # TestClient initializes the lifespan before requests. However, in InProcessExecutionAPI, + # the lifespan runs later (when transport is accessed), but services may be needed + # before that. Using lifespan.registry fails in CI with ServiceNotFoundError. + # + # This minimal container bypasses the svcs lifecycle and directly returns pre-created + # service instances from app.state. If you add new services, update this class. 
+ from airflow.api_fastapi.execution_api.deps import _container + + class InProcessContainer: + """Minimal container for in-process execution, bypassing svcs lifecycle.""" + + def __init__(self, app_state): + self._services = { + JWTGenerator: app_state.jwt_generator, + JWTValidator: app_state.jwt_validator, + } + + async def aget(self, svc_type): + if svc_type not in self._services: + raise KeyError(f"{svc_type} not registered in InProcessContainer") + return self._services[svc_type] + + async def _inprocess_container(): + yield InProcessContainer(self._app.state) + + self._app.dependency_overrides[_container] = _inprocess_container + async def always_allow(): ... self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow + self._app.dependency_overrides[JWTBearerWorkloadDep.dependency] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow @@ -337,6 +384,7 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI): self._cm = AsyncExitStack() asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop) + return httpx.WSGITransport(app=middleware) # type: ignore[arg-type] @cached_property diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index fce188d48ed6d..549c05851aaf3 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -22,11 +22,16 @@ import structlog import svcs -from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBearer +from fastapi import Depends, HTTPException, Request, Security, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes 
from sqlalchemy import select -from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.auth.tokens import ( + SCOPE_EXECUTION, + SCOPE_MAPPING, + SCOPE_WORKLOAD, + JWTValidator, +) from airflow.api_fastapi.common.db.common import AsyncSessionDep from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.configuration import conf @@ -47,14 +52,7 @@ async def _container(request: Request): class JWTBearer(HTTPBearer): - """ - A FastAPI security dependency that validates JWT tokens using for the Execution API. - - This will validate the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that. - - The dependency result will be an `TIToken` object containing the ``id`` UUID (from the ``sub``) and other - validated claims. - """ + """JWT Bearer auth with scope validation via FastAPI's SecurityScopes.""" def __init__( self, @@ -68,39 +66,74 @@ def __init__( async def __call__( # type: ignore[override] self, request: Request, + security_scopes: SecurityScopes, services=DepContainer, - ) -> TIToken | None: - creds = await super().__call__(request) + ) -> TIToken: + creds: HTTPAuthorizationCredentials | None = await super().__call__(request) if not creds: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing auth token", + headers={"WWW-Authenticate": "Bearer"}, + ) validator: JWTValidator = await services.aget(JWTValidator) try: - # Example: Validate "task_instance_id" component of the path matches the one in the token if self.path_param_name: - id = request.path_params[self.path_param_name] + ti_id = request.path_params[self.path_param_name] validators: dict[str, Any] = { **self.required_claims, - "sub": {"essential": True, "value": id}, + "sub": {"essential": True, "value": ti_id}, } else: validators = self.required_claims claims = await 
validator.avalidated_claims(creds.credentials, validators) + self._validate_scopes(claims, security_scopes) return TIToken(id=claims["sub"], claims=claims) + except HTTPException: + raise except Exception as err: - log.warning( - "Failed to validate JWT", - exc_info=True, - token=creds.credentials, + log.warning("Failed to validate JWT", exc_info=True) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Invalid auth token: {err}", + headers={"WWW-Authenticate": "Bearer"}, ) - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") + def _validate_scopes(self, claims: dict[str, Any], security_scopes: SecurityScopes) -> None: + if not security_scopes.scopes: + return + + token_scope = claims.get("scope", "") + mapped_scope = SCOPE_MAPPING.get(token_scope) + if mapped_scope is None: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Unknown token scope: {token_scope}", + headers={"WWW-Authenticate": "Bearer"}, + ) -JWTBearerDep: TIToken = Depends(JWTBearer()) - -# This checks that the UUID in the url matches the one in the token for us. 
-JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) + for required_scope in security_scopes.scopes: + if required_scope != mapped_scope: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token missing required scope: {required_scope}", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +_jwt_bearer = JWTBearer() +_jwt_bearer_with_path = JWTBearer(path_param_name="task_instance_id") + +# No scope check - for router-level auth +JWTBearerBaseDep = Security(_jwt_bearer, scopes=[]) +# Execution scope - most endpoints +JWTBearerDep = Security(_jwt_bearer, scopes=[SCOPE_EXECUTION]) +# Execution scope with path param validation +JWTBearerTIPathDep = Security(_jwt_bearer_with_path, scopes=[SCOPE_EXECUTION]) +# Workload scope +JWTBearerWorkloadDep = Security(_jwt_bearer_with_path, scopes=[SCOPE_WORKLOAD]) async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None: diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 562b8588fbf2c..b8f5cb25f2c5d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -19,7 +19,7 @@ from cadwyn import VersionedAPIRouter from fastapi import APIRouter -from airflow.api_fastapi.execution_api.deps import JWTBearerDep +from airflow.api_fastapi.execution_api.deps import JWTBearerBaseDep from airflow.api_fastapi.execution_api.routes import ( asset_events, assets, @@ -36,8 +36,8 @@ execution_api_router = APIRouter() execution_api_router.include_router(health.router, prefix="/health", tags=["Health"]) -# _Every_ single endpoint under here must be authenticated. 
Some do further checks on top of these -authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item] +# Base JWT auth; scopes checked at endpoint/router level +authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerBaseDep]) # type: ignore[list-item] authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py index 4525a9140e5d1..dd7a0d7b9859a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_events.py @@ -29,14 +29,15 @@ AssetEventResponse, AssetEventsResponse, ) +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel -# TODO: Add dependency on JWT token router = APIRouter( responses={ status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[JWTBearerDep], ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py index 316d4fab4770d..7ea9b6d4335c8 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/assets.py @@ -24,14 +24,15 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.asset import AssetModel -# TODO: Add dependency on JWT token router = APIRouter( responses={ status.HTTP_404_NOT_FOUND: 
{"description": "Asset not found"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[JWTBearerDep], ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py index 56fd79b825eba..9696c352093ff 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py @@ -32,13 +32,14 @@ from airflow.api_fastapi.compat import HTTP_422_UNPROCESSABLE_CONTENT from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.exceptions import DagRunAlreadyExists from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun as DagRunModel from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType -router = VersionedAPIRouter() +router = VersionedAPIRouter(dependencies=[JWTBearerDep]) log = logging.getLogger(__name__) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py index 1c91c3e5b3b84..a7720f31b4195 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/hitl.py @@ -29,9 +29,10 @@ HITLDetailResponse, UpdateHITLDetailPayload, ) +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.hitl import HITLDetail -router = APIRouter() +router = APIRouter(dependencies=[JWTBearerDep]) log = structlog.get_logger(__name__) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index 
a73145b30ab5a..6c087e048a4d1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, HTTPException, Query, Response, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.engine import CursorResult @@ -38,6 +38,7 @@ from structlog.contextvars import bind_contextvars from airflow._shared.timezones import timezone +from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime @@ -59,7 +60,11 @@ TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep +from airflow.api_fastapi.execution_api.deps import ( + DepContainer, + JWTBearerTIPathDep, + JWTBearerWorkloadDep, +) from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive from airflow.models.dag import DagModel @@ -75,18 +80,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.dml import Update -router = VersionedAPIRouter() - -ti_id_router = VersionedAPIRouter( - dependencies=[ - # This checks that the UUID in the url matches the one in the token for us. 
- JWTBearerTIPathDep - ] -) - - log = structlog.get_logger(__name__) +router = VersionedAPIRouter() +ti_id_router = VersionedAPIRouter() + @ti_id_router.patch( "/{task_instance_id}/run", @@ -97,12 +95,15 @@ HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, response_model_exclude_unset=True, + dependencies=[JWTBearerWorkloadDep], ) -def ti_run( +async def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep, dag_bag: DagBagDep, + response: Response, + services=DepContainer, ) -> TIRunContext: """ Run a TaskInstance. @@ -264,6 +265,11 @@ def ti_run( context.next_method = ti.next_method context.next_kwargs = ti.next_kwargs + # Generate short-lived execution token for subsequent API calls + generator: JWTGenerator = await services.aget(JWTGenerator) + execution_token = generator.generate(extras={"sub": ti_id_str}) + response.headers["X-Execution-Token"] = execution_token + return context except SQLAlchemyError: log.exception("Error marking Task Instance state as running") @@ -280,6 +286,7 @@ def ti_run( status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_update_state( task_instance_id: UUID, @@ -518,6 +525,7 @@ def _create_ti_state_update_query_and_update_state( status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_skip_downstream( task_instance_id: UUID, @@ -565,6 +573,7 @@ def ti_skip_downstream( }, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid payload for the state transition"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_heartbeat( task_instance_id: UUID, @@ -642,6 +651,7 @@ def ti_heartbeat( "description": "Invalid 
payload for the setting rendered task instance fields" }, }, + dependencies=[JWTBearerTIPathDep], ) def ti_put_rtif( task_instance_id: UUID, @@ -672,6 +682,7 @@ def ti_put_rtif( status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid rendered_map_index value"}, }, + dependencies=[JWTBearerTIPathDep], ) def ti_patch_rendered_map_index( task_instance_id: UUID, @@ -709,9 +720,11 @@ def ti_patch_rendered_map_index( responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance or Dag Run not found"}, }, + dependencies=[JWTBearerTIPathDep], ) def get_previous_successful_dagrun( - task_instance_id: UUID, session: SessionDep + task_instance_id: UUID, + session: SessionDep, ) -> PrevSuccessfulDagRunResponse: """ Get the previous successful DagRun for a TaskInstance. @@ -968,6 +981,7 @@ def _get_group_tasks( responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, }, + dependencies=[JWTBearerTIPathDep], ) def validate_inlets_and_outlets( task_instance_id: UUID, @@ -1030,5 +1044,4 @@ def validate_inlets_and_outlets( ) -# This line should be at the end of the file to ensure all routes are registered router.include_router(ti_id_router) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py index f763858f9b9c1..625585d50979e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_reschedules.py @@ -24,6 +24,7 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.types import UtcDateTime +from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.models.taskreschedule import TaskReschedule router = APIRouter( @@ -31,6 +32,7 @@ status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"}, 
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[JWTBearerDep], ) diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 9532de60a6257..97108feacd5c0 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -1858,6 +1858,17 @@ execution_api: type: integer example: ~ default: "600" + jwt_workload_token_expiration_time: + description: | + Number in seconds until the workload JWT token expires. Workload tokens are long-lived tokens + sent with task workloads to executors (e.g., Celery). They can only be used to call + the /run endpoint, which then issues a short-lived execution token. + + This should be set long enough to cover the maximum expected queue wait time. + version_added: 3.1.7 + type: integer + example: ~ + default: "86400" jwt_audience: version_added: 3.0.0 description: | diff --git a/airflow-core/src/airflow/executors/workloads.py b/airflow-core/src/airflow/executors/workloads.py index 7cf1aae60ff21..c73231dc0571b 100644 --- a/airflow-core/src/airflow/executors/workloads.py +++ b/airflow-core/src/airflow/executors/workloads.py @@ -45,7 +45,13 @@ class BaseWorkload(BaseModel): @staticmethod def generate_token(sub_id: str, generator: JWTGenerator | None = None) -> str: - return generator.generate({"sub": sub_id}) if generator else "" + """ + Generate a workload-scoped token for this workload. + + Workload tokens are long-lived and can only be used on the /run endpoint, + which exchanges them for short-lived execution tokens. 
+ """ + return generator.generate_workload_token(sub_id) if generator else "" class BundleInfo(BaseModel): diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py index b17c8147dae24..ed7d51b7fad9e 100644 --- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py +++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py @@ -30,6 +30,7 @@ from airflow._shared.timezones import timezone from airflow.api_fastapi.auth.tokens import ( JWKS, + TOKEN_SCOPE_WORKLOAD, InvalidClaimError, JWTGenerator, JWTValidator, @@ -238,3 +239,42 @@ def rsa_private_key(): @pytest.fixture(scope="session") def ed25519_private_key(): return generate_private_key(key_type="Ed25519") + + +async def test_generate_workload_token(jwt_generator: JWTGenerator, jwt_validator: JWTValidator): + """Test that generate_workload_token creates tokens with workload scope and longer expiration.""" + token = jwt_generator.generate_workload_token("test_subject") + + claims = await jwt_validator.avalidated_claims( + token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + # Verify workload scope is set + assert claims.get("scope") == TOKEN_SCOPE_WORKLOAD, "Workload token should have workload scope" + + # Verify the token has extended expiration (default 24h = 86400s) + nbf = datetime.fromtimestamp(claims["nbf"], timezone.utc) + exp = datetime.fromtimestamp(claims["exp"], timezone.utc) + expiration_seconds = (exp - nbf).total_seconds() + + # Should be around 24 hours (86400 seconds) - allow some tolerance + assert expiration_seconds >= 86000, "Workload token should have extended expiration (~24h)" + assert expiration_seconds <= 90000, "Workload token expiration should not exceed expected duration" + + +async def test_workload_token_vs_regular_token_scope( + jwt_generator: JWTGenerator, jwt_validator: JWTValidator +): + """Test that regular tokens don't have scope claim while workload tokens do.""" + 
regular_token = jwt_generator.generate({"sub": "test_subject"}) + workload_token = jwt_generator.generate_workload_token("test_subject") + + regular_claims = await jwt_validator.avalidated_claims( + regular_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + workload_claims = await jwt_validator.avalidated_claims( + workload_token, required_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + assert "scope" not in regular_claims, "Regular token should not have scope claim" + assert workload_claims.get("scope") == TOKEN_SCOPE_WORKLOAD, "Workload token should have workload scope" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 9e26937b63c06..49856b52785bc 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -16,14 +16,21 @@ # under the License. from __future__ import annotations -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest from fastapi.testclient import TestClient from airflow.api_fastapi.app import cached_app -from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator from airflow.api_fastapi.execution_api.app import lifespan +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTBearerTIPathDep, JWTBearerWorkloadDep + + +def _always_allow(ti_id: str | None = None) -> TIToken: + """Return a mock TIToken for bypassing auth in tests.""" + return TIToken(id=ti_id or "00000000-0000-0000-0000-000000000000", claims={}) @pytest.fixture @@ -63,4 +70,29 @@ def smart_validated_claims(cred, validators=None): auth.avalidated_claims.side_effect = smart_validated_claims lifespan.registry.register_value(JWTValidator, auth) + # Mock 
JWTGenerator for /run endpoint that returns execution tokens + jwt_generator = MagicMock(spec=JWTGenerator) + jwt_generator.generate.return_value = "mock-execution-token" + lifespan.registry.register_value(JWTGenerator, jwt_generator) + + jwt_bearer_instance = JWTBearerDep.dependency + jwt_bearer_ti_path_instance = JWTBearerTIPathDep.dependency + jwt_bearer_workload_instance = JWTBearerWorkloadDep.dependency + + execution_app = None + for route in app.routes: + if hasattr(route, "path") and route.path == "/execution": + execution_app = route.app + break + + if execution_app: + execution_app.dependency_overrides[jwt_bearer_instance] = lambda: _always_allow() + execution_app.dependency_overrides[jwt_bearer_ti_path_instance] = lambda: _always_allow() + execution_app.dependency_overrides[jwt_bearer_workload_instance] = lambda: _always_allow() + yield client + + if execution_app: + execution_app.dependency_overrides.pop(jwt_bearer_instance, None) + execution_app.dependency_overrides.pop(jwt_bearer_ti_path_instance, None) + execution_app.dependency_overrides.pop(jwt_bearer_workload_instance, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index cfee6c9d46cb1..714200a2cf3ed 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -29,8 +29,6 @@ from sqlalchemy.orm import Session from airflow._shared.timezones import timezone -from airflow.api_fastapi.auth.tokens import JWTValidator -from airflow.api_fastapi.execution_api.app import lifespan from airflow.exceptions import AirflowSkipException from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel @@ -79,50 +77,35 @@ def 
_create_asset_aliases(session, num: int = 2) -> None: def client_with_extra_route(): ... -def test_id_matches_sub_claim(client, session, create_task_instance): - # Test that this is validated at the router level, so we don't have to test it in each component - # We validate it is set correctly, and test it once +def test_run_endpoint_returns_execution_token(client, session, create_task_instance, time_machine): + """Test that /run endpoint returns an execution token in the response header.""" + instant = timezone.parse("2024-09-30T12:00:00Z") + time_machine.move_to(instant, tick=False) ti = create_task_instance( - task_id="test_ti_run_state_conflict_if_not_queued", - state="queued", + task_id="test_run_endpoint_returns_execution_token", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), ) session.commit() - validator = mock.AsyncMock(spec=JWTValidator) - claims = {"sub": ti.id} - - def side_effect(cred, validators): - if not validators: - return claims - if validators["sub"]["value"] != ti.id: - raise RuntimeError("Fake auth denied") - return claims - - validator.avalidated_claims.side_effect = side_effect - - lifespan.registry.register_value(JWTValidator, validator) - payload = { "state": "running", "hostname": "random-hostname", "unixname": "random-unixname", "pid": 100, - "start_date": "2024-10-31T12:00:00Z", + "start_date": "2024-09-30T12:00:00Z", } - resp = client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run", json=payload) - assert resp.status_code == 403 - assert validator.avalidated_claims.call_args_list[1] == mock.call( - mock.ANY, {"sub": {"essential": True, "value": "9c230b40-da03-451d-8bd7-be30471be383"}} - ) - validator.avalidated_claims.reset_mock() - resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload) + assert resp.status_code == 200 - assert resp.status_code == 200, resp.json() - - validator.avalidated_claims.assert_awaited() + 
# Verify execution token is returned in header + assert "X-Execution-Token" in resp.headers + assert resp.headers["X-Execution-Token"] == "mock-execution-token" class TestTIRunState: diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 3636c2b6db980..01b2a492abbe3 100644 --- a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -235,7 +235,7 @@ def set_instance_attrs(self) -> Generator: @pytest.fixture def mock_executors(self): mock_jwt_generator = MagicMock(spec=JWTGenerator) - mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-token" default_executor = mock.MagicMock(name="DefaultExecutor", slots_available=8, slots_occupied=0) default_executor.name = ExecutorName(alias="default_exec", module_path="default.exec.module.path") diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index 4e95ed3a4eea7..6280bf03f0ee0 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -58,7 +58,7 @@ def __init__(self, do_update=True, *args, **kwargs): # Mock JWT generator for token generation mock_jwt_generator = MagicMock() - mock_jwt_generator.generate.return_value = "mock-token" + mock_jwt_generator.generate_workload_token.return_value = "mock-token" self.jwt_generator = mock_jwt_generator diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 2a38ef2fad7b3..c6e70a91306c6 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -928,7 +928,12 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, * ) def _update_auth(self, response: httpx.Response): - if new_token := response.headers.get("Refreshed-API-Token"): + # 
Check for execution token from /run endpoint (replaces workload token with short-lived execution token) + if new_token := response.headers.get("X-Execution-Token"): + log.debug("Received execution token from /run endpoint") + self.auth = BearerAuth(new_token) + # Check for refreshed token from heartbeat/other endpoints + elif new_token := response.headers.get("Refreshed-API-Token"): log.debug("Execution API issued us a refreshed Task token") self.auth = BearerAuth(new_token)