Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions airflow-core/src/airflow/api_fastapi/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
"JWKS",
"JWTGenerator",
"JWTValidator",
"SCOPE_EXECUTION",
"SCOPE_MAPPING",
"SCOPE_WORKLOAD",
"TOKEN_SCOPE_WORKLOAD",
"generate_private_key",
"get_sig_validation_args",
"get_signing_args",
Expand All @@ -54,6 +58,14 @@
"key_to_jwk_dict",
]

# Value of the ``scope`` claim that workload tokens are issued with.
TOKEN_SCOPE_WORKLOAD = "ExecuteTaskWorkload"

# Security scope names that API dependencies declare as requirements.
SCOPE_WORKLOAD = "workload"
SCOPE_EXECUTION = "execution"

# Translates a token's ``scope`` claim into the security scope it grants.
# The empty-string key is the entry used for tokens that carry no ``scope`` claim.
SCOPE_MAPPING: dict[str, str] = {TOKEN_SCOPE_WORKLOAD: SCOPE_WORKLOAD, "": SCOPE_EXECUTION}


class InvalidClaimError(ValueError):
"""Raised when a claim in the JWT is invalid."""
Expand Down Expand Up @@ -434,15 +446,28 @@ def signing_arg(self) -> AllowedPrivateKeys | str:
assert self._secret_key
return self._secret_key

def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None) -> str:
"""Generate a signed JWT for the subject."""
def generate(
self,
extras: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
valid_for: int | None = None,
) -> str:
"""
Generate a signed JWT.

Args:
extras: Additional claims to include in the token. These are merged with default claims.
headers: Additional headers to include in the JWT.
valid_for: Optional custom validity duration in seconds. If not provided, uses self.valid_for.
"""
now = int(datetime.now(tz=timezone.utc).timestamp())
token_valid_for = valid_for if valid_for is not None else self.valid_for
claims = {
"jti": uuid.uuid4().hex,
"iss": self.issuer,
"aud": self.audience,
"nbf": now,
"exp": int(now + self.valid_for),
"exp": int(now + token_valid_for),
"iat": now,
}

Expand All @@ -458,6 +483,23 @@ def generate(self, extras: dict[str, Any] | None = None, headers: dict[str, Any]
headers["kid"] = self.kid
return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers)

def generate_workload_token(self, sub: str) -> str:
    """
    Issue a workload-scoped JWT for the given subject.

    The token carries a ``scope`` claim set to ``TOKEN_SCOPE_WORKLOAD``. Its lifetime is read
    from ``[execution_api] jwt_workload_token_expiration_time`` (86400 seconds, i.e. 24h, when
    the option is unset) instead of ``self.valid_for``, so the token can survive executor
    queue wait times.

    Args:
        sub: Value stored in the token's ``sub`` claim.

    Returns:
        The signed, encoded JWT produced by ``self.generate``.
    """
    from airflow.configuration import conf

    option = "jwt_workload_token_expiration_time"
    lifetime_seconds = conf.getint("execution_api", option, fallback=86400)
    workload_claims = {"sub": sub, "scope": TOKEN_SCOPE_WORKLOAD}
    return self.generate(extras=workload_claims, valid_for=lifetime_seconds)


def generate_private_key(key_type: str = "RSA", key_size: int = 2048):
"""
Expand Down
48 changes: 48 additions & 0 deletions airflow-core/src/airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import secrets
import time
from contextlib import AsyncExitStack
from functools import cached_property
Expand Down Expand Up @@ -301,20 +302,66 @@ def app(self):
from airflow.api_fastapi.execution_api.deps import (
JWTBearerDep,
JWTBearerTIPathDep,
JWTBearerWorkloadDep,
)
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
from airflow.configuration import conf

# Ensure JWT secret is available for in-process execution.
# The /run endpoint needs JWTGenerator to issue execution tokens.
# If the config option is empty, generate a random one for the duration of this process.
if not conf.get("api_auth", "jwt_secret", fallback=None):
logger.debug(
"`api_auth/jwt_secret` is not set, generating a temporary one for in-process execution"
)
conf.set("api_auth", "jwt_secret", secrets.token_urlsafe(16))

self._app = create_task_execution_api_app()

# Set up dag_bag in app state for dependency injection
self._app.state.dag_bag = create_dag_bag()

self._app.state.jwt_generator = _jwt_generator()
self._app.state.jwt_validator = _jwt_validator()

# Why InProcessContainer instead of lifespan.registry or svcs.Container?
#
# The normal app uses @svcs.fastapi.lifespan which manages the registry lifecycle.
# In tests (conftest.py), lifespan.registry.register_value() works because the
# TestClient initializes the lifespan before requests. However, in InProcessExecutionAPI,
# the lifespan runs later (when transport is accessed), but services may be needed
# before that. Using lifespan.registry fails in CI with ServiceNotFoundError.
#
# This minimal container bypasses the svcs lifecycle and directly returns pre-created
# service instances from app.state. If you add new services, update this class.
from airflow.api_fastapi.execution_api.deps import _container

class InProcessContainer:
"""Minimal container for in-process execution, bypassing svcs lifecycle."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't we use svcs? Why do we need to implement our version of it?

Copy link
Contributor Author

@anishgirianish anishgirianish Jan 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still getting familiar with the codebase, but from what I understand, I added this to fix ServiceNotFoundError failures in CI. With InProcessExecutionAPI, the svcs lifespan runs later (when transport is accessed), but services like JWTGenerator are needed before that. This container bypasses the lifecycle and returns pre-created instances from app.state. I may well be missing something - if there's a cleaner pattern you'd recommend, I'd really appreciate the guidance.


def __init__(self, app_state):
    """Capture the pre-created JWT services from ``app_state``, keyed by their service type."""
    generator = app_state.jwt_generator
    validator = app_state.jwt_validator
    self._services = {JWTGenerator: generator, JWTValidator: validator}

async def aget(self, svc_type):
    """
    Return the pre-created instance registered for ``svc_type``.

    Raises:
        KeyError: If ``svc_type`` has no entry in this container.
    """
    registered = self._services
    if svc_type in registered:
        return registered[svc_type]
    raise KeyError(f"{svc_type} not registered in InProcessContainer")

async def _inprocess_container():
    """Yield a container that serves the services stored on this app's state."""
    container = InProcessContainer(self._app.state)
    yield container

self._app.dependency_overrides[_container] = _inprocess_container

async def always_allow():
    """No-op stand-in for an auth/access dependency: checks nothing and returns ``None``."""
    return None

self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
self._app.dependency_overrides[JWTBearerWorkloadDep.dependency] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
Expand All @@ -337,6 +384,7 @@ async def start_lifespan(cm: AsyncExitStack, app: FastAPI):
self._cm = AsyncExitStack()

asyncio.run_coroutine_threadsafe(start_lifespan(self._cm, self.app), middleware.loop)

return httpx.WSGITransport(app=middleware) # type: ignore[arg-type]

@cached_property
Expand Down
85 changes: 59 additions & 26 deletions airflow-core/src/airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@

import structlog
import svcs
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer
from fastapi import Depends, HTTPException, Request, Security, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer, SecurityScopes
from sqlalchemy import select

from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.auth.tokens import (
SCOPE_EXECUTION,
SCOPE_MAPPING,
SCOPE_WORKLOAD,
JWTValidator,
)
from airflow.api_fastapi.common.db.common import AsyncSessionDep
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.configuration import conf
Expand All @@ -47,14 +52,7 @@ async def _container(request: Request):


class JWTBearer(HTTPBearer):
"""
A FastAPI security dependency that validates JWT tokens using for the Execution API.

This will validate the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that.

The dependency result will be an `TIToken` object containing the ``id`` UUID (from the ``sub``) and other
validated claims.
"""
"""JWT Bearer auth with scope validation via FastAPI's SecurityScopes."""

def __init__(
self,
Expand All @@ -68,39 +66,74 @@ def __init__(
async def __call__( # type: ignore[override]
self,
request: Request,
security_scopes: SecurityScopes,
services=DepContainer,
) -> TIToken | None:
creds = await super().__call__(request)
) -> TIToken:
creds: HTTPAuthorizationCredentials | None = await super().__call__(request)
if not creds:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing auth token",
headers={"WWW-Authenticate": "Bearer"},
)

validator: JWTValidator = await services.aget(JWTValidator)

try:
# Example: Validate "task_instance_id" component of the path matches the one in the token
if self.path_param_name:
id = request.path_params[self.path_param_name]
ti_id = request.path_params[self.path_param_name]
validators: dict[str, Any] = {
**self.required_claims,
"sub": {"essential": True, "value": id},
"sub": {"essential": True, "value": ti_id},
}
else:
validators = self.required_claims
claims = await validator.avalidated_claims(creds.credentials, validators)
self._validate_scopes(claims, security_scopes)
return TIToken(id=claims["sub"], claims=claims)
except HTTPException:
raise
except Exception as err:
log.warning(
"Failed to validate JWT",
exc_info=True,
token=creds.credentials,
log.warning("Failed to validate JWT", exc_info=True)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Invalid auth token: {err}",
headers={"WWW-Authenticate": "Bearer"},
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}")

def _validate_scopes(self, claims: dict[str, Any], security_scopes: SecurityScopes) -> None:
if not security_scopes.scopes:
return

token_scope = claims.get("scope", "")
mapped_scope = SCOPE_MAPPING.get(token_scope)
if mapped_scope is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Unknown token scope: {token_scope}",
headers={"WWW-Authenticate": "Bearer"},
)

JWTBearerDep: TIToken = Depends(JWTBearer())

# This checks that the UUID in the url matches the one in the token for us.
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
for required_scope in security_scopes.scopes:
if required_scope != mapped_scope:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Token missing required scope: {required_scope}",
headers={"WWW-Authenticate": "Bearer"},
)


# Shared bearer instances; several of the Security(...) wrappers below reuse the same one.
_jwt_bearer = JWTBearer()
# Same validation, plus the token's ``sub`` must match the ``task_instance_id`` path parameter.
_jwt_bearer_with_path = JWTBearer(path_param_name="task_instance_id")

# No scope check - for router-level auth (with no required scopes, scope validation returns early)
JWTBearerBaseDep = Security(_jwt_bearer, scopes=[])
# Execution scope - most endpoints
JWTBearerDep = Security(_jwt_bearer, scopes=[SCOPE_EXECUTION])
# Execution scope with path param validation
JWTBearerTIPathDep = Security(_jwt_bearer_with_path, scopes=[SCOPE_EXECUTION])
# Workload scope with path param validation
JWTBearerWorkloadDep = Security(_jwt_bearer_with_path, scopes=[SCOPE_WORKLOAD])


async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cadwyn import VersionedAPIRouter
from fastapi import APIRouter

from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.deps import JWTBearerBaseDep
from airflow.api_fastapi.execution_api.routes import (
asset_events,
assets,
Expand All @@ -36,8 +36,8 @@
execution_api_router = APIRouter()
execution_api_router.include_router(health.router, prefix="/health", tags=["Health"])

# _Every_ single endpoint under here must be authenticated. Some do further checks on top of these
authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item]
# Base JWT auth; scopes checked at endpoint/router level
authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerBaseDep]) # type: ignore[list-item]

authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
AssetEventResponse,
AssetEventsResponse,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.models.asset import AssetAliasModel, AssetEvent, AssetModel

# Every route on this router requires an execution-scoped JWT (JWTBearerDep).
router = APIRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
dependencies=[JWTBearerDep],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.models.asset import AssetModel

# Every route on this router requires an execution-scoped JWT (JWTBearerDep).
router = APIRouter(
responses={
status.HTTP_404_NOT_FOUND: {"description": "Asset not found"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
dependencies=[JWTBearerDep],
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@
from airflow.api_fastapi.compat import HTTP_422_UNPROCESSABLE_CONTENT
from airflow.api_fastapi.execution_api.datamodels.dagrun import DagRunStateResponse, TriggerDAGRunPayload
from airflow.api_fastapi.execution_api.datamodels.taskinstance import DagRun
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.exceptions import DagRunAlreadyExists
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun as DagRunModel
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType

router = VersionedAPIRouter()
router = VersionedAPIRouter(dependencies=[JWTBearerDep])

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
HITLDetailResponse,
UpdateHITLDetailPayload,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.models.hitl import HITLDetail

router = APIRouter()
router = APIRouter(dependencies=[JWTBearerDep])

log = structlog.get_logger(__name__)

Expand Down
Loading
Loading