diff --git a/airflow/providers/edge/api_endpoints/__init__.py b/airflow/providers/edge/api_endpoints/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/edge/api_endpoints/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/edge/api_endpoints/health_endpoint.py b/airflow/providers/edge/api_endpoints/health_endpoint.py new file mode 100644 index 0000000000000..a6c8a9c7950da --- /dev/null +++ b/airflow/providers/edge/api_endpoints/health_endpoint.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
def health():
    """Handle the Edge Worker API ``/health`` endpoint by answering with an empty JSON object."""
    payload: dict = {}
    return payload
@functools.lru_cache
def _initialize_method_map() -> dict[str, Callable]:
    """
    Build the lookup table of methods callable via the Edge Worker RPC API.

    The table contains every function exposed by the internal API plus the Edge specific model
    methods. The result is cached by ``functools.lru_cache``, so the table is built only once.

    :return: Mapping of ``"<module>.<qualname>"`` to the callable.
    """
    from airflow.providers.edge.models.edge_job import EdgeJob
    from airflow.providers.edge.models.edge_logs import EdgeLogs
    from airflow.providers.edge.models.edge_worker import EdgeWorker

    internal_api_functions = initialize_method_map().values()
    functions: list[Callable] = [
        *internal_api_functions,
        # Additional things from EdgeExecutor
        EdgeJob.reserve_task,
        EdgeJob.set_state,
        EdgeLogs.push_logs,
        EdgeWorker.register_worker,
        EdgeWorker.set_state,
    ]
    return {f"{func.__module__}.{func.__qualname__}": func for func in functions}


def edge_worker_api(body: dict[str, Any]) -> APIResponse:
    """
    Handle Edge Worker API `/edge_worker/v1/rpcapi` endpoint.

    The request must carry ``Content-Type`` and ``Accept`` headers equal to ``application/json``
    and an ``Authorization`` header with a signed token whose ``method`` claim equals the
    ``method`` of the request body.

    :param body: Parsed JSON-RPC request using the keys ``jsonrpc``, ``method`` and ``params``.
    :return: The serialized result of the called method. For ``KeyError``, ``AttributeError`` and
        ``AirflowException`` the serialized exception is returned instead. Malformed requests
        produce an error response with status 400, any other failure of the method status 500.
    :raises PermissionDenied: If a header has an unexpected value or the token can not be verified.
    """
    # Note: Except the method map this is a 100% copy of internal API module
    #       airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api()
    content_type = request.headers.get("Content-Type")
    if content_type != "application/json":
        raise PermissionDenied("Expected Content-Type: application/json")
    accept = request.headers.get("Accept")
    if accept != "application/json":
        raise PermissionDenied("Expected Accept: application/json")
    auth = request.headers.get("Authorization", "")
    clock_grace = conf.getint("core", "internal_api_clock_grace", fallback=30)
    signer = JWTSigner(
        secret_key=conf.get("core", "internal_api_secret_key"),
        expiration_time_in_seconds=clock_grace,
        leeway_in_seconds=clock_grace,
        audience="api",
    )
    # Every verification failure is translated to PermissionDenied; the original error is kept as
    # explicit cause so it stays visible in server side tracebacks.
    try:
        payload = signer.verify_token(auth)
        signed_method = payload.get("method")
        # The token is only valid for the very method it was signed for.
        if not signed_method or signed_method != body.get("method"):
            raise BadSignature("Invalid method in token authorization.")
    except BadSignature as err:
        raise PermissionDenied("Bad Signature. Please use only the tokens provided by the API.") from err
    except InvalidAudienceError as err:
        raise PermissionDenied("Invalid audience for the request") from err
    except InvalidSignatureError as err:
        raise PermissionDenied("The signature of the request was wrong") from err
    except ImmatureSignatureError as err:
        raise PermissionDenied("The signature of the request was sent from the future") from err
    except ExpiredSignatureError as err:
        raise PermissionDenied(
            "The signature of the request has expired. Make sure that all components "
            "in your system have synchronized clocks.",
        ) from err
    except InvalidIssuedAtError as err:
        raise PermissionDenied(
            "The request was issued in the future. Make sure that all components "
            "in your system have synchronized clocks.",
        ) from err
    except Exception as err:
        raise PermissionDenied("Unable to authenticate API via token.") from err

    log.debug("Got request")
    json_rpc = body.get("jsonrpc")
    if json_rpc != "2.0":
        return log_and_build_error_response(message="Expected jsonrpc 2.0 request.", status=400)

    methods_map = _initialize_method_map()
    method_name = body.get("method")
    if method_name not in methods_map:
        return log_and_build_error_response(message=f"Unrecognized method: {method_name}.", status=400)

    handler = methods_map[method_name]
    params = {}
    params_json = body.get("params")
    try:
        # Missing or empty params mean the method is called without arguments.
        if params_json:
            params = BaseSerialization.deserialize(params_json, use_pydantic_models=True)
    except Exception:
        return log_and_build_error_response(message="Error deserializing parameters.", status=400)

    log.debug("Calling method %s\nparams: %s", method_name, params)
    try:
        # Session must be created there as it may be needed by serializer for lazy-loaded fields.
        with create_session() as session:
            output = handler(**params, session=session)
            output_json = BaseSerialization.serialize(output, use_pydantic_models=True)
            response = json.dumps(output_json) if output_json is not None else None
            log.debug("Sending response: %s", response)
            return Response(response=response, headers={"Content-Type": "application/json"})
    # In case of AirflowException or other selective known types, transport the exception class back to caller
    except (KeyError, AttributeError, AirflowException) as e:
        exception_json = BaseSerialization.serialize(e, use_pydantic_models=True)
        response = json.dumps(exception_json)
        log.debug("Sending exception response: %s", response)
        return Response(response=response, headers={"Content-Type": "application/json"})
    except Exception:
        return log_and_build_error_response(message=f"Error executing method '{method_name}'.", status=500)
diff --git a/airflow/providers/edge/openapi/edge_worker_api_v1.yaml b/airflow/providers/edge/openapi/edge_worker_api_v1.yaml new file mode 100644 index 0000000000000..fc8aadf48418b --- /dev/null +++ b/airflow/providers/edge/openapi/edge_worker_api_v1.yaml @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +openapi: 3.0.2 +info: + title: Airflow Edge Worker API + version: 1.0.0 + description: | + This is Airflow Edge Worker API - which is the access endpoint for workers + running on remote sites serving for Apache Airflow jobs. It also proxies internal API + to edge endpoints. + + It is not intended to be used by any external code. 
+ + You can find more information in AIP-69 + https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=301795932 + + +servers: + - url: /edge_worker/v1 + description: Airflow Edge Worker API +paths: + "/rpcapi": + post: + deprecated: false + x-openapi-router-controller: airflow.providers.edge.api_endpoints.rpc_api_endpoint + operationId: edge_worker_api + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response + requestBody: + x-body-name: body + required: true + content: + application/json: + schema: + type: object + required: + - method + - jsonrpc + - params + properties: + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + method: + type: string + description: Method name + params: + title: Parameters + type: object + "/health": + get: + operationId: health + deprecated: false + x-openapi-router-controller: airflow.providers.edge.api_endpoints.health_endpoint + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response +x-headers: [] +x-explorer-enabled: true +x-proxy-enabled: true +components: + schemas: + JsonRpcRequired: + type: object + required: + - method + - jsonrpc + properties: + method: + type: string + description: Method name + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + discriminator: + propertyName: method_name +tags: [] diff --git a/airflow/providers/edge/plugins/__init__.py b/airflow/providers/edge/plugins/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/edge/plugins/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/edge/plugins/edge_executor_plugin.py b/airflow/providers/edge/plugins/edge_executor_plugin.py new file mode 100644 index 0000000000000..57b4a278f5e83 --- /dev/null +++ b/airflow/providers/edge/plugins/edge_executor_plugin.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
def _get_api_endpoints() -> Blueprint:
    """
    Build the Flask blueprint that serves the Edge Worker API below ``/edge_worker/v1``.

    The OpenAPI specification is loaded from ``openapi/edge_worker_api_v1.yaml`` in the provider
    package and wrapped into a connexion ``FlaskApi``.

    :return: Blueprint of the API, exempted from CSRF protection.
    """
    folder = Path(__file__).parents[1].resolve()  # this is airflow/providers/edge/
    with folder.joinpath("openapi", "edge_worker_api_v1.yaml").open() as f:
        specification = safe_load(f)
    bp = FlaskApi(
        specification=specification,
        resolver=_LazyResolver(),
        base_path="/edge_worker/v1",
        options={"swagger_ui": SWAGGER_ENABLED, "swagger_path": SWAGGER_BUNDLE.__fspath__()},
        strict_validation=True,
        validate_responses=True,
        validator_map={"body": _CustomErrorRequestBodyValidator},
    ).blueprint
    # Need to exempt CSRF to make API usable
    # (imported inside the function, not at module level)
    from airflow.www.app import csrf

    csrf.exempt(bp)
    return bp


# registers airflow/providers/edge/plugins/templates as a Jinja template folder
template_bp = Blueprint(
    "template_blueprint",
    __name__,
    template_folder="templates",
)


class EdgeWorkerJobs(BaseView):
    """Simple view to show Edge Worker jobs."""

    # Name of the method that is served as the default page of this view.
    default_view = "jobs"

    @expose("/jobs")
    @has_access_view(AccessView.JOBS)
    @provide_session
    def jobs(self, session: Session = NEW_SESSION):
        """
        Render all Edge jobs ordered by the time they were queued.

        :param session: Database session, injected by ``provide_session`` if not given.
        :return: The rendered ``edge_worker_jobs.html`` template.
        """
        from airflow.providers.edge.models.edge_job import EdgeJobModel

        jobs = session.scalars(select(EdgeJobModel).order_by(EdgeJobModel.queued_dttm)).all()
        # Lookup from state name to the token produced by wwwutils.state_token; the template
        # renders it via html_states[job.state].
        html_states = {
            str(state): wwwutils.state_token(str(state)) for state in TaskInstanceState.__members__.values()
        }
        return self.render_template("edge_worker_jobs.html", jobs=jobs, html_states=html_states)


class EdgeWorkerHosts(BaseView):
    """Simple view to show Edge Worker status."""

    # Name of the method that is served as the default page of this view.
    default_view = "status"

    @expose("/status")
    @has_access_view(AccessView.JOBS)
    @provide_session
    def status(self, session: Session = NEW_SESSION):
        """
        Render all known Edge workers ordered by worker name.

        :param session: Database session, injected by ``provide_session`` if not given.
        :return: The rendered ``edge_worker_hosts.html`` template.
        """
        from airflow.providers.edge.models.edge_worker import EdgeWorkerModel

        hosts = session.scalars(select(EdgeWorkerModel).order_by(EdgeWorkerModel.worker_name)).all()
        # Naive local time; the template compares it via .timestamp() against the last update of
        # a worker to flag hosts that report a state but sent no heartbeat.
        five_min_ago = datetime.now() - timedelta(minutes=5)
        return self.render_template("edge_worker_hosts.html", hosts=hosts, five_min_ago=five_min_ago)


# Check if EdgeExecutor is actually loaded
# The config lookup raising AirflowConfigException is treated as "API disabled".
try:
    EDGE_EXECUTOR_ACTIVE = conf.getboolean("edge", "api_enabled")
except AirflowConfigException:
    EDGE_EXECUTOR_ACTIVE = False


class EdgeExecutorPlugin(AirflowPlugin):
    """EdgeExecutor Plugin - provides API endpoints for Edge Workers in Webserver."""

    name = "edge_executor"
    # Blueprints and views are only created if the Edge API is enabled. Both expressions are
    # evaluated once, when this class body is executed at import time.
    flask_blueprints = [_get_api_endpoints(), template_bp] if EDGE_EXECUTOR_ACTIVE else []
    appbuilder_views = (
        [
            {
                "name": "Edge Worker Jobs",
                "category": "Admin",
                "view": EdgeWorkerJobs(),
            },
            {
                "name": "Edge Worker Hosts",
                "category": "Admin",
                "view": EdgeWorkerHosts(),
            },
        ]
        if EDGE_EXECUTOR_ACTIVE
        else []
    )
the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + #} + + {% extends base_template %} + + {% block title %} + Edge Worker Hosts + {% endblock %} + + {% block content %} +

Edge Worker Hosts

+ {% if hosts|length == 0 %} +

No Edge Workers connected or known currently.

+ {% else %} + + + + + + + + + + + + + {% for host in hosts %} + + + + + + + + + + + {% endfor %} +
HostnameStateQueuesFirst OnlineLast Heart BeatActive JobsSystem Information
{{ host.worker_name }} + {%- if host.state == "offline" -%} + {{ host.state }} + {%- elif host.last_update.timestamp() <= five_min_ago.timestamp() -%} + Reported {{ host.state }} + but no heartbeat + {%- elif host.state == "starting" -%} + {{ host.state }} + {%- elif host.state == "running" -%} + {{ host.state }} + {%- elif host.state == "idle" -%} + {{ host.state }} + {%- elif host.state == "terminating" -%} + {{ host.state }} + {%- elif host.state == "unknown" -%} + {{ host.state }} + {%- else -%} + {{ host.state }} + {%- endif -%} + {% if host.queues %}{{ host.queues }}{% else %}(all){% endif %}{% if host.last_update %}{% endif %}{{ host.jobs_active }} +
    + {% for item in host.sysinfo_json %} +
  • {{ item }}: {{ host.sysinfo_json[item] }}
  • + {% endfor %} +
+
+ {% endif %} + {% endblock %} + diff --git a/airflow/providers/edge/plugins/templates/edge_worker_jobs.html b/airflow/providers/edge/plugins/templates/edge_worker_jobs.html new file mode 100644 index 0000000000000..a73e0f1d485f4 --- /dev/null +++ b/airflow/providers/edge/plugins/templates/edge_worker_jobs.html @@ -0,0 +1,63 @@ +{# + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + #} + + {% extends base_template %} + + {% block title %} + Edge Worker Jobs + {% endblock %} + + {% block content %} +

Edge Worker Jobs

+ {% if jobs|length == 0 %} +

No jobs running currently

+ {% else %} + + + + + + + + + + + + + + + + {% for job in jobs %} + + + + + + + + + + + + + {% endfor %} +
DAG IDTask IDRun IDMap IndexTry NumberStateQueueQueued DTTMEdge WorkerLast Update
{{ job.dag_id }}{{ job.task_id }}{{ job.run_id }}{% if job.map_index >= 0 %}{{ job.map_index }}{% else %}-{% endif %}{{ job.try_number }}{{ html_states[job.state] }}{{ job.queue }}{% if job.edge_worker %}{{ job.edge_worker }}{% endif %}{% if job.last_update %}{% endif %}
+ {% endif %} + {% endblock %} + diff --git a/airflow/providers/edge/provider.yaml b/airflow/providers/edge/provider.yaml index cb775ee7cc7e4..6525b7bb846ff 100644 --- a/airflow/providers/edge/provider.yaml +++ b/airflow/providers/edge/provider.yaml @@ -32,6 +32,10 @@ dependencies: - apache-airflow>=2.10.0 - pydantic>=2.3.0 +plugins: + - name: edge_executor + plugin-class: airflow.providers.edge.plugins.edge_executor_plugin.EdgeExecutorPlugin + config: edge: description: | diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 074c5dd41e93b..610124e65fd28 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -526,7 +526,12 @@ "pydantic>=2.3.0" ], "devel-deps": [], - "plugins": [], + "plugins": [ + { + "name": "edge_executor", + "plugin-class": "airflow.providers.edge.plugins.edge_executor_plugin.EdgeExecutorPlugin" + } + ], "cross-providers-deps": [], "excluded-python-versions": [], "state": "not-ready" diff --git a/tests/plugins/test_plugins_manager.py b/tests/plugins/test_plugins_manager.py index cb59afd36742a..7e4bedbfb8c1b 100644 --- a/tests/plugins/test_plugins_manager.py +++ b/tests/plugins/test_plugins_manager.py @@ -417,7 +417,7 @@ def test_does_not_double_import_entrypoint_provider_plugins(self): assert len(plugins_manager.plugins) == 0 plugins_manager.load_entrypoint_plugins() plugins_manager.load_providers_plugins() - assert len(plugins_manager.plugins) == 3 + assert len(plugins_manager.plugins) == 4 class TestPluginsDirectorySource: diff --git a/tests/providers/edge/api_endpoints/__init__.py b/tests/providers/edge/api_endpoints/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/edge/api_endpoints/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/edge/api_endpoints/test_health_endpoint.py b/tests/providers/edge/api_endpoints/test_health_endpoint.py new file mode 100644 index 0000000000000..1bfc9e5c0c5bf --- /dev/null +++ b/tests/providers/edge/api_endpoints/test_health_endpoint.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
def test_health():
    """The health handler must answer with an empty mapping."""
    result = health()
    assert result == {}
+from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Generator +from unittest import mock + +import pytest + +from airflow.api_connexion.exceptions import PermissionDenied +from airflow.configuration import conf +from airflow.models.baseoperator import BaseOperator +from airflow.models.connection import Connection +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import XCom +from airflow.operators.empty import EmptyOperator +from airflow.providers.edge.api_endpoints.rpc_api_endpoint import _initialize_method_map +from airflow.providers.edge.models.edge_job import EdgeJob +from airflow.providers.edge.models.edge_logs import EdgeLogs +from airflow.providers.edge.models.edge_worker import EdgeWorker +from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.settings import _ENABLE_AIP_44 +from airflow.utils.jwt_signer import JWTSigner +from airflow.utils.state import State +from airflow.www import app +from tests.test_utils.decorators import dont_initialize_flask_app_submodules +from tests.test_utils.mock_plugins import mock_plugin_manager + +# Note: Sounds a bit strange to disable internal API tests in isolation mode but... +# As long as the test is modelled to run its own internal API endpoints, it is conflicting +# to the test setup with a dedicated internal API server. 
pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode]


def test_initialize_method_map():
    """The method map must expose internal API methods as well as the Edge specific ones."""
    method_map = _initialize_method_map()
    assert len(method_map) > 70
    for method in [
        # Test some basics
        XCom.get_value,
        XCom.get_one,
        XCom.clear,
        XCom.set,
        DagRun.get_previous_dagrun,
        DagRun.get_previous_scheduled_dagrun,
        DagRun.get_task_instances,
        DagRun.fetch_task_instance,
        # Test some for Edge
        EdgeJob.reserve_task,
        EdgeJob.set_state,
        EdgeLogs.push_logs,
        EdgeWorker.register_worker,
        EdgeWorker.set_state,
    ]:
        method_key = f"{method.__module__}.{method.__qualname__}"
        assert method_key in method_map


if TYPE_CHECKING:
    from flask import Flask

TEST_METHOD_NAME = "test_method"
TEST_METHOD_WITH_LOG_NAME = "test_method_with_log"
TEST_API_ENDPOINT = "/edge_worker/v1/rpcapi"

# Stand-in for the RPC method; registered in the patched method map by the setup_attrs fixture.
mock_test_method = mock.MagicMock()

pytest.importorskip("pydantic", minversion="2.0.0")


def equals(a, b) -> bool:
    """Compare two values with ``==``; used as default result comparison in parametrized tests."""
    return a == b


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
class TestRpcApiEndpoint:
    """Tests calling the `/edge_worker/v1/rpcapi` endpoint via a Flask test client."""

    @pytest.fixture(scope="session")
    def minimal_app_for_edge_api(self) -> Flask:
        """Create a webserver app with only the sub-modules needed to serve the Edge API."""

        @dont_initialize_flask_app_submodules(
            skip_all_except=[
                "init_api_auth",  # This is needed for Airflow 2.10 compat tests
                "init_appbuilder",
                "init_plugins",
            ]
        )
        def factory() -> Flask:
            import airflow.providers.edge.plugins.edge_executor_plugin as plugin_module

            # The plugin only registers blueprints if the API is enabled in config, therefore
            # a subclass with unconditional blueprints is used for the test app.
            class TestingEdgeExecutorPlugin(plugin_module.EdgeExecutorPlugin):
                flask_blueprints = [plugin_module._get_api_endpoints(), plugin_module.template_bp]

            testing_edge_plugin = TestingEdgeExecutorPlugin()
            assert len(testing_edge_plugin.flask_blueprints) > 0
            with mock_plugin_manager(plugins=[testing_edge_plugin]):
                return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False})  # type:ignore

        return factory()

    @pytest.fixture
    def setup_attrs(self, minimal_app_for_edge_api: Flask) -> Generator:
        """Provide app and client and replace the method map by one holding only the mock method."""
        self.app = minimal_app_for_edge_api
        self.client = self.app.test_client()  # type:ignore
        mock_test_method.reset_mock()
        mock_test_method.side_effect = None
        with mock.patch(
            "airflow.providers.edge.api_endpoints.rpc_api_endpoint._initialize_method_map"
        ) as mock_initialize_method_map:
            mock_initialize_method_map.return_value = {
                TEST_METHOD_NAME: mock_test_method,
            }
            yield mock_initialize_method_map

    @pytest.fixture
    def signer(self) -> JWTSigner:
        """Create a signer using the same secret key and audience as the endpoint."""
        return JWTSigner(
            secret_key=conf.get("core", "internal_api_secret_key"),
            expiration_time_in_seconds=conf.getint("core", "internal_api_clock_grace", fallback=30),
            audience="api",
        )

    @pytest.mark.parametrize(
        "input_params, method_result, result_cmp_func, method_params",
        [
            ({}, None, lambda got, _: got == b"", {}),
            ({}, "test_me", equals, {}),
            (
                BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"}),
                ("dag_id_15", "fake-task", 1),
                equals,
                {"dag_id": 15, "task_id": "fake-task"},
            ),
            (
                {},
                TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING),
                lambda a, b: a.model_dump() == TaskInstancePydantic.model_validate(b).model_dump()
                and isinstance(a.task, BaseOperator),
                {},
            ),
            (
                {},
                Connection(conn_id="test_conn", conn_type="http", host="", password=""),
                lambda a, b: a.get_uri() == b.get_uri() and a.conn_id == b.conn_id,
                {},
            ),
        ],
    )
    def test_method(
        self, input_params, method_result, result_cmp_func, method_params, setup_attrs, signer: JWTSigner
    ):
        """A valid request calls the method with deserialized params and returns its result."""
        mock_test_method.return_value = method_result
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}),
        }
        input_data = {
            "jsonrpc": "2.0",
            "method": TEST_METHOD_NAME,
            "params": input_params,
        }
        response = self.client.post(
            TEST_API_ENDPOINT,
            headers=headers,
            data=json.dumps(input_data),
        )
        assert response.status_code == 200
        if method_result:
            response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)
        else:
            # A falsy result is compared against the raw response body.
            response_data = response.data

        assert result_cmp_func(response_data, method_result)

        mock_test_method.assert_called_once_with(**method_params, session=mock.ANY)

    def test_method_with_exception(self, setup_attrs, signer: JWTSigner):
        """An unexpected exception raised by the method results in a 500 error response."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}),
        }
        mock_test_method.side_effect = ValueError("Error!!!")
        data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": {}}

        response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data))
        assert response.status_code == 500
        # Compare the body for real: `assert data, b"..."` would only use the bytes as message.
        assert response.data.startswith(b"Error executing method 'test_method'.")
        mock_test_method.assert_called_once()

    def test_unknown_method(self, setup_attrs, signer: JWTSigner):
        """A method missing in the method map is rejected with status 400."""
        UNKNOWN_METHOD = "i-bet-it-does-not-exist"
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": signer.generate_signed_token({"method": UNKNOWN_METHOD}),
        }
        data = {"jsonrpc": "2.0", "method": UNKNOWN_METHOD, "params": {}}

        response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data))
        assert response.status_code == 400
        assert response.data.startswith(b"Unrecognized method: i-bet-it-does-not-exist.")
        mock_test_method.assert_not_called()

    def test_invalid_jsonrpc(self, setup_attrs, signer: JWTSigner):
        """A request using another JSON-RPC version than 2.0 is rejected with status 400."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": signer.generate_signed_token({"method": TEST_METHOD_NAME}),
        }
        data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}}

        response = self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data))
        assert response.status_code == 400
        assert response.data.startswith(b"Expected jsonrpc 2.0 request.")
        mock_test_method.assert_not_called()

    def test_missing_token(self, setup_attrs):
        """A request without Authorization header is denied."""
        mock_test_method.return_value = None

        input_data = {
            "jsonrpc": "2.0",
            "method": TEST_METHOD_NAME,
            "params": {},
        }
        with pytest.raises(PermissionDenied, match="Unable to authenticate API via token."):
            self.client.post(
                TEST_API_ENDPOINT,
                headers={"Content-Type": "application/json", "Accept": "application/json"},
                data=json.dumps(input_data),
            )

    def test_invalid_token(self, setup_attrs, signer: JWTSigner):
        """A token signed for another method than the requested one is denied."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}),
        }
        data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}}

        with pytest.raises(
            PermissionDenied, match="Bad Signature. Please use only the tokens provided by the API."
        ):
            self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data))

    def test_missing_accept(self, setup_attrs, signer: JWTSigner):
        """A request without Accept header is denied."""
        headers = {
            "Content-Type": "application/json",
            "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}),
        }
        data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}}

        with pytest.raises(PermissionDenied, match="Expected Accept: application/json"):
            self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data))

    def test_wrong_accept(self, setup_attrs, signer: JWTSigner):
        """A request with an Accept header other than application/json is denied."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/html",
            "Authorization": signer.generate_signed_token({"method": "WRONG_METHOD_NAME"}),
        }
        data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": {}}

        with pytest.raises(PermissionDenied, match="Expected Accept: application/json"):
            self.client.post(TEST_API_ENDPOINT, headers=headers, data=json.dumps(data))
0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/edge/plugins/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/edge/plugins/test_edge_executor_plugin.py b/tests/providers/edge/plugins/test_edge_executor_plugin.py new file mode 100644 index 0000000000000..e3422b17da3c8 --- /dev/null +++ b/tests/providers/edge/plugins/test_edge_executor_plugin.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import importlib + +import pytest + +from airflow.plugins_manager import AirflowPlugin +from airflow.providers.edge.plugins import edge_executor_plugin +from tests.test_utils.config import conf_vars + + +def test_plugin_inactive(): + with conf_vars({("edge", "api_enabled"): "false"}): + importlib.reload(edge_executor_plugin) + + from airflow.providers.edge.plugins.edge_executor_plugin import ( + EDGE_EXECUTOR_ACTIVE, + EdgeExecutorPlugin, + ) + + rep = EdgeExecutorPlugin() + assert not EDGE_EXECUTOR_ACTIVE + assert len(rep.flask_blueprints) == 0 + assert len(rep.appbuilder_views) == 0 + + +def test_plugin_active(): + with conf_vars({("edge", "api_enabled"): "true"}): + importlib.reload(edge_executor_plugin) + + from airflow.providers.edge.plugins.edge_executor_plugin import ( + EDGE_EXECUTOR_ACTIVE, + EdgeExecutorPlugin, + ) + + rep = EdgeExecutorPlugin() + assert EDGE_EXECUTOR_ACTIVE + assert len(rep.flask_blueprints) == 2 + assert len(rep.appbuilder_views) == 2 + + +@pytest.fixture +def plugin(): + from airflow.providers.edge.plugins.edge_executor_plugin import EdgeExecutorPlugin + + return EdgeExecutorPlugin() + + +def test_plugin_is_airflow_plugin(plugin): + assert isinstance(plugin, AirflowPlugin)