From 11be4054ffce9136865c576e1ddb8faef194ac53 Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Thu, 24 Nov 2022 15:32:30 +0100 Subject: [PATCH 1/9] [AIP-44] Add internal API definition. --- airflow/api_internal/__init__.py | 16 +++ airflow/api_internal/endpoints/__init__.py | 16 +++ .../endpoints/rpc_api_endpoint.py | 70 ++++++++++++ airflow/api_internal/internal_api_call.py | 84 +++++++++++++++ .../api_internal/openapi/internal_api_v1.yaml | 100 ++++++++++++++++++ airflow/config_templates/config.yml | 19 ++++ airflow/config_templates/default_airflow.cfg | 10 ++ airflow/dag_processing/processor.py | 26 +++-- airflow/www/app.py | 5 +- airflow/www/extensions/init_views.py | 19 ++++ tests/test_utils/decorators.py | 1 + 11 files changed, 354 insertions(+), 12 deletions(-) create mode 100644 airflow/api_internal/__init__.py create mode 100644 airflow/api_internal/endpoints/__init__.py create mode 100644 airflow/api_internal/endpoints/rpc_api_endpoint.py create mode 100644 airflow/api_internal/internal_api_call.py create mode 100644 airflow/api_internal/openapi/internal_api_v1.yaml diff --git a/airflow/api_internal/__init__.py b/airflow/api_internal/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_internal/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_internal/endpoints/__init__.py b/airflow/api_internal/endpoints/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_internal/endpoints/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py new file mode 100644 index 0000000000000..585ae43f470ca --- /dev/null +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging + +from flask import Response + +from airflow.api_connexion.types import APIResponse +from airflow.dag_processing.processor import DagFileProcessor +from airflow.serialization.serialized_objects import BaseSerialization + +log = logging.getLogger(__name__) + +METHODS = { + "dag_processing.processor.update_import_errors": DagFileProcessor.update_import_errors, +} + + +def json_rpc( + body: dict, +) -> APIResponse: + """Handler for Internal API /internal/v1/rpcapi endpoint.""" + log.debug("Got request") + json_rpc = body.get("jsonrpc") + if json_rpc != "2.0": + log.warning("Not jsonrpc-2.0 request") + return Response(response="Expected jsonrpc 2.0 request.", status=400) + + method_name = str(body.get("method")) + if method_name not in METHODS: + log.warning("Unrecognized method: %", method_name) + return Response(response=f"Unrecognized method: {method_name}", status=400) + + params_json = body.get("params") + if not params_json: + params_json = "{}" + handler = METHODS[method_name] + try: + params = BaseSerialization.deserialize(json.loads(params_json)) + except Exception as err: + log.warning("Error deserializing parameters.") + log.warning(err) + return Response(response="Error deserializing parameters.", status=400) + + log.debug("Calling method %.", {method_name}) + handler = METHODS[method_name] + output = handler(**params) + if output: + output_json = BaseSerialization.serialize(json.dumps(output)) + else: + output_json = "" + log.debug("Returning response") + return Response(response=str(output_json), headers={"Content-Type": "application/json"}) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py new file mode 100644 index 0000000000000..12b583de2614f --- /dev/null +++ b/airflow/api_internal/internal_api_call.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import inspect +import json +from typing import Callable, TypeVar + +import requests + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.typing_compat import ParamSpec + +PS = ParamSpec("PS") +RT = TypeVar("RT") + +_use_internal_api = conf.get("core", "database_access_isolation") +_internal_api_url = conf.get("core", "database_api_url") + +_internal_api_endpoint = _internal_api_url + "/internal/v1/rpcapi" +if not _internal_api_endpoint.startswith("http://"): + _internal_api_endpoint = "http://" + _internal_api_endpoint + + +def internal_api_call(method_name: str): + """Decorator for methods which may be executed in database isolation mode. + + If [core]database_access_isolation is true then such method are not executed locally, + but instead RPC call is made to Database API (aka Internal API). This makes some components + stop depending on Airflow database access. + Each decorated method must be present in METHODS list in airflow.api_internal.endpoints.rpc_api_endpoint. + Only static methods can be decorated. This decorator must be before "provide_session". + + See AIP-44 for more information. + """ + headers = { + "Content-Type": "application/json", + } + + def make_jsonrpc_request(params_json: str) -> bytes: + data = {"jsonrpc": "2.0", "method": method_name, "params": params_json} + response = requests.post(_internal_api_endpoint, data=json.dumps(data), headers=headers) + if response.status_code != 200: + raise AirflowException( + f"Got {response.status_code}:{response.reason} when sending the internal api request." + ) + return response.content + + def inner(func: Callable[PS, RT | None]) -> Callable[PS, RT | None]: + def make_call(*args, **kwargs) -> RT | None: + if not _use_internal_api: + return func(*args, **kwargs) + + bound = inspect.signature(func).bind(*args, **kwargs) + arguments_dict = dict(bound.arguments) + if "session" in arguments_dict: + del arguments_dict["session"] + args_json = json.dumps(BaseSerialization.serialize(arguments_dict)) + result = make_jsonrpc_request(args_json) + if result: + return BaseSerialization.deserialize(json.loads(result)) + else: + return None + + return make_call + + return inner diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml b/airflow/api_internal/openapi/internal_api_v1.yaml new file mode 100644 index 0000000000000..e94fa7124114d --- /dev/null +++ b/airflow/api_internal/openapi/internal_api_v1.yaml @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +--- +openapi: 3.0.2 +info: + title: Airflow Internal API + version: 1.0.0 + description: | + This is Airflow Internal API - which is a proxy for components running + customer code for connecting to Airflow Database. + + It is not intended to be used by any external code. + + You can find more information in AIP-44 + https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API + + +servers: + - url: /internal/v1 + description: Airflow Internal API +paths: + "/rpcapi": + post: + operationId: rpcapi + deprecated: false + x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint + operationId: json_rpc + tags: + - JSONRPC + parameters: [] + responses: + '200': + description: Successful response + requestBody: + x-body-name: body + required: true + content: + application/json: + schema: + type: object + required: + - method + - jsonrpc + - params + properties: + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + method: + type: string + description: Method name + params: + title: Parameters + type: string +x-headers: [] +x-explorer-enabled: true +x-proxy-enabled: true +x-samples-enabled: true +components: + schemas: + JsonRpcRequired: + type: object + required: + - method + - jsonrpc + properties: + method: + type: string + description: Method name + jsonrpc: + type: string + default: '2.0' + description: JSON-RPC Version (2.0) + discriminator: + propertyName: method_name + examplePost: + allOf: + - "$ref": "#/components/schemas/JsonRpcRequired" + - type: object + properties: + params: + title: Method arguments + type: string +tags: [] diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 009f6a48466c2..25f4e646a90b8 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -418,6 +418,18 @@ type: string default: ~ example: '{"some_param": "some_value"}' + - name: database_access_isolation + description: Whether components should use Airflow Internal API for DB connectivity. + version_added: 2.6.0 + type: boolean + example: ~ + default: "False" + - name: database_api_url + description: Airflow Internal API url. Only used if [core] database_access_isolation is True. + version_added: 2.6.0 + type: string + default: ~ + example: 'localhost:8080' - name: database description: ~ @@ -1482,6 +1494,13 @@ type: string example: "dagrun_cleared,failed" default: ~ + - name: run_internal_api + description: | + Boolean for for running Internal API in the webserver. + version_added: 2.6.0 + type: boolean + example: ~ + default: "False" - name: email description: | diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 1cff5f8dbe75f..667aa4f396316 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -240,6 +240,13 @@ daemon_umask = 0o077 # Example: dataset_manager_kwargs = {{"some_param": "some_value"}} # dataset_manager_kwargs = +# Whether components should use Airflow Internal API for DB connectivity. +database_access_isolation = False + +# Airflow Internal API url. Only used if [core] database_access_isolation is True. +# Example: database_api_url = localhost:8080 +# database_api_url = + [database] # The SqlAlchemy connection string to the metadata database. # SqlAlchemy supports many different database engines. 
@@ -752,6 +759,9 @@ audit_view_excluded_events = gantt,landing_times,tries,duration,calendar,graph,g # Example: audit_view_included_events = dagrun_cleared,failed # audit_view_included_events = +# Boolean for for running Internal API in the webserver. +run_internal_api = False + [email] # Configuration email backend and whether to diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 323ff94c5e63f..06dfbcb5caac5 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import datetime import logging import multiprocessing import os @@ -24,7 +23,7 @@ import threading import time from contextlib import redirect_stderr, redirect_stdout, suppress -from datetime import timedelta +from datetime import datetime, timedelta from multiprocessing.connection import Connection as MultiprocessingConnection from typing import TYPE_CHECKING, Iterator @@ -33,6 +32,7 @@ from sqlalchemy.orm.session import Session from airflow import settings +from airflow.api_internal.internal_api_call import internal_api_call from airflow.callbacks.callback_requests import ( CallbackRequest, DagCallbackRequest, @@ -94,7 +94,7 @@ def __init__( # Whether the process is done running. self._done = False # When the process started. - self._start_time: datetime.datetime | None = None + self._start_time: datetime | None = None # This ID is use to uniquely name the process / thread that's launched # by this processor instance self._instance_id = DagFileProcessorProcess.class_creation_counter @@ -327,7 +327,7 @@ def result(self) -> tuple[int, int] | None: return self._result @property - def start_time(self) -> datetime.datetime: + def start_time(self) -> datetime: """Time when this started to process the file.""" if self._start_time is None: raise AirflowException("Tried to get start time before it started!") @@ -448,7 +448,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: .all() ) if slas: - sla_dates: list[datetime.datetime] = [sla.execution_date for sla in slas] + sla_dates: list[datetime] = [sla.execution_date for sla in slas] fetched_tis: list[TI] = ( session.query(TI) .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id) @@ -524,17 +524,21 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: session.commit() @staticmethod - def update_import_errors(session: Session, dagbag: DagBag) -> None: + @internal_api_call("dag_processing.processor.update_import_errors") + @provide_session + def update_import_errors( + file_last_changed: dict[str, datetime], import_errors: dict[str, str], session: Session = NEW_SESSION + ) -> None: """ Update any import errors to be displayed in the UI. For the DAGs in the given DagBag, record any associated import errors and clears errors for files that no longer have them. These are usually displayed through the Airflow UI so that users know that there are issues parsing DAGs. 
- :param session: session for ORM operations :param dagbag: DagBag containing DAGs with import errors + :param session: session for ORM operations """ - files_without_error = dagbag.file_last_changed - dagbag.import_errors.keys() + files_without_error = file_last_changed - import_errors.keys() # Clear the errors of the processed files # that no longer have errors @@ -547,7 +551,7 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: existing_import_error_files = [x.filename for x in session.query(errors.ImportError.filename).all()] # Add the errors of the processed files - for filename, stacktrace in dagbag.import_errors.items(): + for filename, stacktrace in import_errors.items(): if filename in existing_import_error_files: session.query(errors.ImportError).filter(errors.ImportError.filename == filename).update( dict(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace), @@ -754,7 +758,7 @@ def process_file( self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) - self.update_import_errors(session, dagbag) + DagFileProcessor.update_import_errors(dagbag.file_last_changed, dagbag.import_errors, session) if callback_requests: # If there were callback requests for this file but there was a # parse error we still need to progress the state of TIs, @@ -781,7 +785,7 @@ def process_file( # Record import errors into the ORM try: - self.update_import_errors(session, dagbag) + DagFileProcessor.update_import_errors(dagbag.file_last_changed, dagbag.import_errors, session) except Exception: self.log.exception("Error logging import errors!") diff --git a/airflow/www/app.py b/airflow/www/app.py index 6a64401e1cecb..783b4211e9196 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -48,6 +48,7 @@ from airflow.www.extensions.init_views import ( init_api_connexion, init_api_experimental, + init_api_internal, init_appbuilder_views, init_connection_form, init_error_handlers, @@ -148,7 +149,9 @@ def create_app(config=None, testing=False): init_plugins(flask_app) init_connection_form() init_error_handlers(flask_app) - init_api_connexion(flask_app) + if conf.get("webserver", "run_internal_api"): + init_api_connexion(flask_app) + init_api_internal(flask_app) init_api_experimental(flask_app) sync_appbuilder_roles(flask_app) diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index 25d5d5898c642..0dcb51c22e74a 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -220,6 +220,25 @@ def _handle_method_not_allowed(ex): app.extensions["csrf"].exempt(api_bp) +def init_api_internal(app: Flask) -> None: + """Initialize Internal API""" + base_path = "/internal/v1" + + spec_dir = path.join(ROOT_APP_DIR, "api_internal", "openapi") + internal_app = App(__name__, specification_dir=spec_dir, skip_error_handlers=True) + internal_app.app = app + api_bp = internal_app.add_api( + specification="internal_api_v1.yaml", + base_path=base_path, + validate_responses=True, + strict_validation=True, + ).blueprint + # Like "api_bp.after_request", but the BP is already registered, so we have + # to register it in the app directly. 
+ app.after_request_funcs.setdefault(api_bp.name, []).append(set_cors_headers_on_response) + app.extensions["csrf"].exempt(api_bp) + + def init_api_experimental(app): """Initialize Experimental API""" if not conf.getboolean("api", "enable_experimental_api", fallback=False): diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py index d0b71b502c3d0..7d809834da4d5 100644 --- a/tests/test_utils/decorators.py +++ b/tests/test_utils/decorators.py @@ -39,6 +39,7 @@ def no_op(*args, **kwargs): "init_connection_form", "init_error_handlers", "init_api_connexion", + "init_api_internal", "init_api_experimental", "sync_appbuilder_roles", "init_jinja_globals", From cc901d15a004a42e3d3a7a347e7d7936e7bdda48 Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Wed, 30 Nov 2022 13:44:55 +0100 Subject: [PATCH 2/9] Apply reviewer suggestions. --- .../api_internal/endpoints/rpc_api_endpoint.py | 15 ++++++--------- airflow/config_templates/config.yml | 2 +- airflow/config_templates/default_airflow.cfg | 2 +- airflow/www/app.py | 4 ++-- 4 files changed, 10 insertions(+), 13 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 585ae43f470ca..098cb64898e38 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -40,12 +40,12 @@ def json_rpc( log.debug("Got request") json_rpc = body.get("jsonrpc") if json_rpc != "2.0": - log.warning("Not jsonrpc-2.0 request") + log.error("Not jsonrpc-2.0 request") return Response(response="Expected jsonrpc 2.0 request.", status=400) method_name = str(body.get("method")) if method_name not in METHODS: - log.warning("Unrecognized method: %", method_name) + log.error("Unrecognized method: %", method_name) return Response(response=f"Unrecognized method: {method_name}", status=400) params_json = body.get("params") @@ -55,16 +55,13 @@ def json_rpc( try: params = BaseSerialization.deserialize(json.loads(params_json)) except Exception as err: - log.warning("Error deserializing parameters.") - log.warning(err) + log.error("Error deserializing parameters.") + log.error(err) return Response(response="Error deserializing parameters.", status=400) log.debug("Calling method %.", {method_name}) handler = METHODS[method_name] output = handler(**params) - if output: - output_json = BaseSerialization.serialize(json.dumps(output)) - else: - output_json = "" + output_json = BaseSerialization.serialize(json.dumps(output)) log.debug("Returning response") - return Response(response=str(output_json), headers={"Content-Type": "application/json"}) + return Response(response=str(output_json or ""), headers={"Content-Type": "application/json"}) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 25f4e646a90b8..9b6ba0840c1f6 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1496,7 +1496,7 @@ default: ~ - name: run_internal_api description: | - Boolean for for running Internal API in the webserver. + Boolean for running Internal API in the webserver. 
version_added: 2.6.0 type: boolean example: ~ diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 667aa4f396316..46b1667ebb0e9 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -759,7 +759,7 @@ audit_view_excluded_events = gantt,landing_times,tries,duration,calendar,graph,g # Example: audit_view_included_events = dagrun_cleared,failed # audit_view_included_events = -# Boolean for for running Internal API in the webserver. +# Boolean for running Internal API in the webserver. run_internal_api = False [email] diff --git a/airflow/www/app.py b/airflow/www/app.py index 783b4211e9196..f5ff2bcf4a976 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -149,9 +149,9 @@ def create_app(config=None, testing=False): init_plugins(flask_app) init_connection_form() init_error_handlers(flask_app) + init_api_connexion(flask_app) if conf.get("webserver", "run_internal_api"): - init_api_connexion(flask_app) - init_api_internal(flask_app) + init_api_internal(flask_app) init_api_experimental(flask_app) sync_appbuilder_roles(flask_app) From 0c6843ebb1d1b480afc70e17920d2b794b8291d5 Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Wed, 30 Nov 2022 13:45:45 +0100 Subject: [PATCH 3/9] Add tests --- .../endpoints/rpc_api_endpoint.py | 48 ++++-- airflow/api_internal/internal_api_call.py | 85 ++++++---- airflow/config_templates/config.yml | 2 +- airflow/config_templates/default_airflow.cfg | 2 +- airflow/dag_processing/processor.py | 2 +- tests/api_internal/__init__.py | 16 ++ tests/api_internal/endpoints/__init__.py | 16 ++ .../endpoints/test_rpc_api_endpoint.py | 124 ++++++++++++++ tests/api_internal/test_internal_api_call.py | 151 ++++++++++++++++++ 9 files changed, 398 insertions(+), 48 deletions(-) create mode 100644 tests/api_internal/__init__.py create mode 100644 tests/api_internal/endpoints/__init__.py create mode 100644 tests/api_internal/endpoints/test_rpc_api_endpoint.py create mode 100644 tests/api_internal/test_internal_api_call.py diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 098cb64898e38..bb59d3a09543d 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -28,9 +28,16 @@ log = logging.getLogger(__name__) -METHODS = { - "dag_processing.processor.update_import_errors": DagFileProcessor.update_import_errors, -} + +def _build_methods_map(list) -> dict: + return {f"{func.__module__}.{func.__name__}": func for func in list} + + +METHODS_MAP = _build_methods_map( + [ + DagFileProcessor.update_import_errors, + ] +) def json_rpc( @@ -40,28 +47,35 @@ def json_rpc( log.debug("Got request") json_rpc = body.get("jsonrpc") if json_rpc != "2.0": - log.error("Not jsonrpc-2.0 request") + log.error("Not jsonrpc-2.0 request.") return Response(response="Expected jsonrpc 2.0 request.", status=400) method_name = str(body.get("method")) - if method_name not in METHODS: - log.error("Unrecognized method: %", method_name) - return Response(response=f"Unrecognized method: {method_name}", status=400) + if method_name not in METHODS_MAP: + log.error("Unrecognized method: %s.", method_name) + return Response(response=f"Unrecognized method: {method_name}.", status=400) - params_json = body.get("params") - if not params_json: - params_json = "{}" - handler = METHODS[method_name] + handler = METHODS_MAP[method_name] try: - params = 
BaseSerialization.deserialize(json.loads(params_json)) + params = {} + if body.get("params"): + params_json = json.loads(str(body.get("params"))) + params = BaseSerialization.deserialize(params_json) except Exception as err: log.error("Error deserializing parameters.") log.error(err) return Response(response="Error deserializing parameters.", status=400) log.debug("Calling method %.", {method_name}) - handler = METHODS[method_name] - output = handler(**params) - output_json = BaseSerialization.serialize(json.dumps(output)) - log.debug("Returning response") - return Response(response=str(output_json or ""), headers={"Content-Type": "application/json"}) + try: + handler = METHODS_MAP[method_name] + output = handler(**params) + output_json = BaseSerialization.serialize(output) + log.debug("Returning response") + return Response( + response=json.dumps(output_json or "{}"), headers={"Content-Type": "application/json"} + ) + except Exception as e: + log.error("Error when calling method %s.", method_name) + log.error(e) + return Response(response=f"Error executing method: {method_name}.", status=500) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index 12b583de2614f..ca4a527f8e574 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -19,27 +19,55 @@ import inspect import json +from functools import wraps from typing import Callable, TypeVar import requests from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowConfigException, AirflowException from airflow.serialization.serialized_objects import BaseSerialization from airflow.typing_compat import ParamSpec PS = ParamSpec("PS") RT = TypeVar("RT") -_use_internal_api = conf.get("core", "database_access_isolation") -_internal_api_url = conf.get("core", "database_api_url") -_internal_api_endpoint = _internal_api_url + "/internal/v1/rpcapi" -if not _internal_api_endpoint.startswith("http://"): - _internal_api_endpoint = "http://" + _internal_api_endpoint +class InternalApiConfig: + """Stores and caches configuration for Internal API.""" + _initialized = False + _use_internal_api = False + _internal_api_endpoint = "" -def internal_api_call(method_name: str): + @staticmethod + def get_use_internal_api(): + if not InternalApiConfig._initialized: + InternalApiConfig._init_values() + return InternalApiConfig._use_internal_api + + @staticmethod + def get_internal_api_endpoint(): + if not InternalApiConfig._initialized: + InternalApiConfig._init_values() + return InternalApiConfig._internal_api_endpoint + + @staticmethod + def _init_values(): + use_internal_api = conf.getboolean("core", "database_access_isolation") + internal_api_url = conf.get("core", "database_api_url") + + internal_api_endpoint = internal_api_url + "/internal/v1/rpcapi" + + if use_internal_api and not internal_api_endpoint.startswith("http://"): + raise AirflowConfigException("[core]database_api_url must start with http://") + + InternalApiConfig._initialized = True + InternalApiConfig._use_internal_api = use_internal_api + InternalApiConfig._internal_api_endpoint = internal_api_endpoint + + +def internal_api_call(func: Callable[PS, RT | None]) -> Callable[PS, RT | None]: """Decorator for methods which may be executed in database isolation mode. 
If [core]database_access_isolation is true then such method are not executed locally, @@ -54,31 +82,32 @@ def internal_api_call(method_name: str): "Content-Type": "application/json", } - def make_jsonrpc_request(params_json: str) -> bytes: + def make_jsonrpc_request(method_name: str, params_json: str) -> bytes: data = {"jsonrpc": "2.0", "method": method_name, "params": params_json} - response = requests.post(_internal_api_endpoint, data=json.dumps(data), headers=headers) + internal_api_endpoint = InternalApiConfig.get_internal_api_endpoint() + response = requests.post(url=internal_api_endpoint, data=json.dumps(data), headers=headers) if response.status_code != 200: raise AirflowException( f"Got {response.status_code}:{response.reason} when sending the internal api request." ) return response.content - def inner(func: Callable[PS, RT | None]) -> Callable[PS, RT | None]: - def make_call(*args, **kwargs) -> RT | None: - if not _use_internal_api: - return func(*args, **kwargs) - - bound = inspect.signature(func).bind(*args, **kwargs) - arguments_dict = dict(bound.arguments) - if "session" in arguments_dict: - del arguments_dict["session"] - args_json = json.dumps(BaseSerialization.serialize(arguments_dict)) - result = make_jsonrpc_request(args_json) - if result: - return BaseSerialization.deserialize(json.loads(result)) - else: - return None - - return make_call - - return inner + @wraps(func) + def wrapper(*args, **kwargs) -> RT | None: + use_internal_api = InternalApiConfig.get_use_internal_api() + if not use_internal_api: + return func(*args, **kwargs) + + bound = inspect.signature(func).bind(*args, **kwargs) + arguments_dict = dict(bound.arguments) + if "session" in arguments_dict: + del arguments_dict["session"] + args_json = json.dumps(BaseSerialization.serialize(arguments_dict)) + method_name = f"{func.__module__}.{func.__name__}" + result = make_jsonrpc_request(method_name, args_json) + if result: + return BaseSerialization.deserialize(json.loads(result)) + else: + return None + + return wrapper diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 9b6ba0840c1f6..7dd24a7413d9b 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -429,7 +429,7 @@ version_added: 2.6.0 type: string default: ~ - example: 'localhost:8080' + example: 'http://localhost:8080' - name: database description: ~ diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 46b1667ebb0e9..565f2fb4fd732 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -244,7 +244,7 @@ daemon_umask = 0o077 database_access_isolation = False # Airflow Internal API url. Only used if [core] database_access_isolation is True. 
-# Example: database_api_url = localhost:8080 +# Example: database_api_url = http://localhost:8080 # database_api_url = [database] diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 06dfbcb5caac5..0e6d58334034b 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -524,7 +524,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None: session.commit() @staticmethod - @internal_api_call("dag_processing.processor.update_import_errors") + @internal_api_call @provide_session def update_import_errors( file_last_changed: dict[str, datetime], import_errors: dict[str, str], session: Session = NEW_SESSION diff --git a/tests/api_internal/__init__.py b/tests/api_internal/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/api_internal/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/api_internal/endpoints/__init__.py b/tests/api_internal/endpoints/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/api_internal/endpoints/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py new file mode 100644 index 0000000000000..45b027d1c8aaa --- /dev/null +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import unittest +from unittest import mock + +import pytest +from flask import Flask + +from airflow.api_internal.endpoints import rpc_api_endpoint +from airflow.serialization.serialized_objects import BaseSerialization +from airflow.www import app +from tests.test_utils.decorators import dont_initialize_flask_app_submodules + +TEST_METHOD_NAME = "test_method" + +mock_test_method = mock.MagicMock() + + +@pytest.fixture(scope="session") +def minimal_app_for_internal_api() -> Flask: + @dont_initialize_flask_app_submodules( + skip_all_except=[ + "init_appbuilder", + "init_api_internal", + ] + ) + def factory() -> Flask: + return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + + return factory() + + +class TestRpcApiEndpoint(unittest.TestCase): + @pytest.fixture(autouse=True) + def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None: + rpc_api_endpoint.METHODS_MAP[TEST_METHOD_NAME] = mock_test_method + self.app = minimal_app_for_internal_api + self.client = self.app.test_client() # type:ignore + mock_test_method.reset_mock() + mock_test_method.side_effect = None + + def test_method_without_params(self): + mock_test_method.return_value = "test_me" + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, b'"test_me"') + mock_test_method.assert_called_once() + + def test_method_without_result(self): + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 200) + mock_test_method.assert_called_once() + + def test_method_with_params(self): + mock_test_method.return_value = ("dag_id_15", "fake-task", 1) + data = { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})), + } + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 200) + response_content = BaseSerialization.deserialize(json.loads(response.data)) + self.assertEqual(response_content, ("dag_id_15", "fake-task", 1)) + mock_test_method.assert_called_once_with(dag_id=15, task_id="fake-task") + + def test_method_with_exception(self): + mock_test_method.side_effect = ValueError("Error!!!") + data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 500) + self.assertEqual(response.data, b"Error executing method: test_method.") + mock_test_method.assert_called_once() + + def test_unknown_method(self): + data = {"jsonrpc": "2.0", "method": "i-bet-it-does-not-exist", 
"params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.data, b"Unrecognized method: i-bet-it-does-not-exist.") + mock_test_method.assert_not_called() + + def test_invalid_jsonrpc(self): + data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": ""} + + response = self.client.post( + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + ) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.data, b"Expected jsonrpc 2.0 request.") + mock_test_method.assert_not_called() diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py new file mode 100644 index 0000000000000..5ffde1a13930f --- /dev/null +++ b/tests/api_internal/test_internal_api_call.py @@ -0,0 +1,151 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from __future__ import annotations + +import json +import unittest +from unittest import mock + +import requests + +from airflow.api_internal.internal_api_call import InternalApiConfig, internal_api_call +from airflow.serialization.serialized_objects import BaseSerialization +from tests.test_utils.config import conf_vars + + +class TestInternalApiConfig(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_disabled(self): + self.assertFalse(InternalApiConfig.get_use_internal_api()) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + def test_get_use_internal_api_enabled(self): + self.assertTrue(InternalApiConfig.get_use_internal_api()) + self.assertEqual( + InternalApiConfig.get_internal_api_endpoint(), + "http://localhost:8888/internal/v1/rpcapi", + ) + + +@internal_api_call +def fake_method() -> str: + return "local-call" + + +@internal_api_call +def fake_method_with_params(dag_id: str, task_id: int) -> str: + return f"local-call-with-params-{dag_id}-{task_id}" + + +class TestInternalApiCall(unittest.TestCase): + def setUp(self): + InternalApiConfig._initialized = False + + @conf_vars( + { + ("core", "database_access_isolation"): "false", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_local_call(self, mock_requests): + result = fake_method() + + self.assertEqual(result, "local-call") + mock_requests.post.assert_not_called() + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", 
"database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call(self, mock_requests): + response = requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + + result = fake_method() + self.assertEqual(result, "remote-call") + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.fake_method", + "params": '{"__var": {}, "__type": "dict"}', + } + ) + mock_requests.post.assert_called_once_with( + url="http://localhost:8888/internal/v1/rpcapi", + data=expected_data, + headers={"Content-Type": "application/json"}, + ) + + @conf_vars( + { + ("core", "database_access_isolation"): "true", + ("core", "database_api_url"): "http://localhost:8888", + } + ) + @mock.patch("airflow.api_internal.internal_api_call.requests") + def test_remote_call_with_params(self, mock_requests): + response = requests.Response() + response.status_code = 200 + + response._content = json.dumps(BaseSerialization.serialize("remote-call")) + + mock_requests.post.return_value = response + + result = fake_method_with_params("fake-dag", task_id=123) + self.assertEqual(result, "remote-call") + expected_data = json.dumps( + { + "jsonrpc": "2.0", + "method": "tests.api_internal.test_internal_api_call.fake_method_with_params", + "params": json.dumps( + BaseSerialization.serialize( + { + "dag_id": "fake-dag", + "task_id": 123, + } + ) + ), + } + ) + mock_requests.post.assert_called_once_with( + url="http://localhost:8888/internal/v1/rpcapi", + data=expected_data, + headers={"Content-Type": "application/json"}, + ) From 2d2bf1e7ed35200a73a11890a5987b7981d84e9a Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Wed, 30 Nov 2022 14:19:25 +0100 Subject: [PATCH 4/9] Fix dag_processing test --- airflow/api_internal/internal_api_call.py | 12 ++++++------ airflow/api_internal/openapi/internal_api_v1.yaml | 8 -------- airflow/dag_processing/processor.py | 12 ++++++++++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index ca4a527f8e574..ad2bf1282110c 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -55,12 +55,12 @@ def get_internal_api_endpoint(): @staticmethod def _init_values(): use_internal_api = conf.getboolean("core", "database_access_isolation") - internal_api_url = conf.get("core", "database_api_url") - - internal_api_endpoint = internal_api_url + "/internal/v1/rpcapi" - - if use_internal_api and not internal_api_endpoint.startswith("http://"): - raise AirflowConfigException("[core]database_api_url must start with http://") + internal_api_endpoint = "" + if use_internal_api: + internal_api_url = conf.get("core", "database_api_url") + internal_api_endpoint = internal_api_url + "/internal/v1/rpcapi" + if not internal_api_endpoint.startswith("http://"): + raise AirflowConfigException("[core]database_api_url must start with http://") InternalApiConfig._initialized = True InternalApiConfig._use_internal_api = use_internal_api diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml b/airflow/api_internal/openapi/internal_api_v1.yaml index e94fa7124114d..311c42b4c82af 100644 --- a/airflow/api_internal/openapi/internal_api_v1.yaml +++ b/airflow/api_internal/openapi/internal_api_v1.yaml @@ -89,12 +89,4 @@ components: description: 
JSON-RPC Version (2.0) discriminator: propertyName: method_name - examplePost: - allOf: - - "$ref": "#/components/schemas/JsonRpcRequired" - - type: object - properties: - params: - title: Method arguments - type: string tags: [] diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py index 0e6d58334034b..fbb2bd298d743 100644 --- a/airflow/dag_processing/processor.py +++ b/airflow/dag_processing/processor.py @@ -758,7 +758,11 @@ def process_file( self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) - DagFileProcessor.update_import_errors(dagbag.file_last_changed, dagbag.import_errors, session) + DagFileProcessor.update_import_errors( + file_last_changed=dagbag.file_last_changed, + import_errors=dagbag.import_errors, + session=session, + ) if callback_requests: # If there were callback requests for this file but there was a # parse error we still need to progress the state of TIs, @@ -785,7 +789,11 @@ def process_file( # Record import errors into the ORM try: - DagFileProcessor.update_import_errors(dagbag.file_last_changed, dagbag.import_errors, session) + DagFileProcessor.update_import_errors( + file_last_changed=dagbag.file_last_changed, + import_errors=dagbag.import_errors, + session=session, + ) except Exception: self.log.exception("Error logging import errors!") From 64a9da7d49a0099d0931933b543f8fa9c6d37fac Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Thu, 1 Dec 2022 09:54:49 +0100 Subject: [PATCH 5/9] Change tests to parametrized. --- .../endpoints/rpc_api_endpoint.py | 3 +- .../api_internal/openapi/internal_api_v1.yaml | 1 - airflow/www/app.py | 2 +- airflow/www/extensions/init_views.py | 2 + .../endpoints/test_rpc_api_endpoint.py | 80 +++++++++---------- 5 files changed, 43 insertions(+), 45 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index bb59d3a09543d..cc7638fb0eeb2 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -56,8 +56,8 @@ def json_rpc( return Response(response=f"Unrecognized method: {method_name}.", status=400) handler = METHODS_MAP[method_name] + params = {} try: - params = {} if body.get("params"): params_json = json.loads(str(body.get("params"))) params = BaseSerialization.deserialize(params_json) @@ -68,7 +68,6 @@ def json_rpc( log.debug("Calling method %.", {method_name}) try: - handler = METHODS_MAP[method_name] output = handler(**params) output_json = BaseSerialization.serialize(output) log.debug("Returning response") diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml b/airflow/api_internal/openapi/internal_api_v1.yaml index 311c42b4c82af..c24f7fd1a8e01 100644 --- a/airflow/api_internal/openapi/internal_api_v1.yaml +++ b/airflow/api_internal/openapi/internal_api_v1.yaml @@ -71,7 +71,6 @@ paths: x-headers: [] x-explorer-enabled: true x-proxy-enabled: true -x-samples-enabled: true components: schemas: JsonRpcRequired: diff --git a/airflow/www/app.py b/airflow/www/app.py index f5ff2bcf4a976..19d2831dfdbcb 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -150,7 +150,7 @@ def create_app(config=None, testing=False): init_connection_form() init_error_handlers(flask_app) init_api_connexion(flask_app) - if conf.get("webserver", "run_internal_api"): + if conf.getboolean("webserver", "run_internal_api", fallback=False): init_api_internal(flask_app) 
init_api_experimental(flask_app) diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py index 0dcb51c22e74a..7005582c2f3b6 100644 --- a/airflow/www/extensions/init_views.py +++ b/airflow/www/extensions/init_views.py @@ -222,6 +222,8 @@ def _handle_method_not_allowed(ex): def init_api_internal(app: Flask) -> None: """Initialize Internal API""" + if not conf.getboolean("webserver", "run_internal_api", fallback=False): + return base_path = "/internal/v1" spec_dir = path.join(ROOT_APP_DIR, "api_internal", "openapi") diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index 45b027d1c8aaa..60ce91f2a668d 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import unittest from unittest import mock import pytest @@ -26,6 +25,7 @@ from airflow.api_internal.endpoints import rpc_api_endpoint from airflow.serialization.serialized_objects import BaseSerialization from airflow.www import app +from tests.test_utils.config import conf_vars from tests.test_utils.decorators import dont_initialize_flask_app_submodules TEST_METHOD_NAME = "test_method" @@ -42,12 +42,13 @@ def minimal_app_for_internal_api() -> Flask: ] ) def factory() -> Flask: - return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + with conf_vars({("webserver", "run_internal_api"): "true"}): + return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore return factory() -class TestRpcApiEndpoint(unittest.TestCase): +class TestRpcApiEndpoint: @pytest.fixture(autouse=True) def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None: rpc_api_endpoint.METHODS_MAP[TEST_METHOD_NAME] = mock_test_method @@ -56,41 +57,38 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> None: mock_test_method.reset_mock() mock_test_method.side_effect = None - def test_method_without_params(self): - mock_test_method.return_value = "test_me" - data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} - - response = self.client.post( - "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data, b'"test_me"') - mock_test_method.assert_called_once() - - def test_method_without_result(self): - data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""} - - response = self.client.post( - "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) - ) - self.assertEqual(response.status_code, 200) - mock_test_method.assert_called_once() - - def test_method_with_params(self): - mock_test_method.return_value = ("dag_id_15", "fake-task", 1) - data = { - "jsonrpc": "2.0", - "method": TEST_METHOD_NAME, - "params": json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})), - } + @pytest.mark.parametrize( + "input_data, method_result, method_params, expected_code", + [ + ({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, "test_me", None, 200), + ({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, None, None, 200), + ( + { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})), + }, + ("dag_id_15", "fake-task", 1), + {"dag_id": 15, "task_id": "fake-task"}, + 200, + ), + ], + 
) + def test_method(self, input_data, method_result, method_params, expected_code): + if method_result: + mock_test_method.return_value = method_result response = self.client.post( - "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) + "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(input_data) ) - self.assertEqual(response.status_code, 200) - response_content = BaseSerialization.deserialize(json.loads(response.data)) - self.assertEqual(response_content, ("dag_id_15", "fake-task", 1)) - mock_test_method.assert_called_once_with(dag_id=15, task_id="fake-task") + assert response.status_code == expected_code + if method_result: + response_data = BaseSerialization.deserialize(json.loads(response.data)) + assert response_data == method_result + if method_params: + mock_test_method.assert_called_once_with(**method_params) + else: + mock_test_method.assert_called_once() def test_method_with_exception(self): mock_test_method.side_effect = ValueError("Error!!!") @@ -99,8 +97,8 @@ def test_method_with_exception(self): response = self.client.post( "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) - self.assertEqual(response.status_code, 500) - self.assertEqual(response.data, b"Error executing method: test_method.") + assert response.status_code == 500 + assert response.data, b"Error executing method: test_method." mock_test_method.assert_called_once() def test_unknown_method(self): @@ -109,8 +107,8 @@ def test_unknown_method(self): response = self.client.post( "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) - self.assertEqual(response.status_code, 400) - self.assertEqual(response.data, b"Unrecognized method: i-bet-it-does-not-exist.") + assert response.status_code == 400 + assert response.data == b"Unrecognized method: i-bet-it-does-not-exist." mock_test_method.assert_not_called() def test_invalid_jsonrpc(self): @@ -119,6 +117,6 @@ def test_invalid_jsonrpc(self): response = self.client.post( "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data) ) - self.assertEqual(response.status_code, 400) - self.assertEqual(response.data, b"Expected jsonrpc 2.0 request.") + assert response.status_code == 400 + assert response.data == b"Expected jsonrpc 2.0 request." 
mock_test_method.assert_not_called() From c9eb436e981295c1c0eb1eb44be5973e22c6219c Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Tue, 6 Dec 2022 11:46:57 +0100 Subject: [PATCH 6/9] Fix for test_internal_api_call --- tests/api_internal/test_internal_api_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py index 5ffde1a13930f..139590009def8 100644 --- a/tests/api_internal/test_internal_api_call.py +++ b/tests/api_internal/test_internal_api_call.py @@ -104,7 +104,7 @@ def test_remote_call(self, mock_requests): { "jsonrpc": "2.0", "method": "tests.api_internal.test_internal_api_call.fake_method", - "params": '{"__var": {}, "__type": "dict"}', + "params": json.dumps(BaseSerialization.serialize({})), } ) mock_requests.post.assert_called_once_with( From 97a59864fff7fdaa032a1c0157fcd7e97685ebd9 Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Tue, 6 Dec 2022 13:51:17 +0100 Subject: [PATCH 7/9] Update internal_api_call comment --- airflow/api_internal/internal_api_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index ad2bf1282110c..041bb8ed0fe42 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -72,7 +72,7 @@ def internal_api_call(func: Callable[PS, RT | None]) -> Callable[PS, RT | None]: If [core]database_access_isolation is true then such method are not executed locally, but instead RPC call is made to Database API (aka Internal API). This makes some components - stop depending on Airflow database access. + decouple from direct Airflow database access. Each decorated method must be present in METHODS list in airflow.api_internal.endpoints.rpc_api_endpoint. Only static methods can be decorated. This decorator must be before "provide_session". From 15ee13e8a3d12664ae8419b3f1030b4196d423af Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Wed, 7 Dec 2022 16:22:12 +0100 Subject: [PATCH 8/9] Apply reviewer suggestions. 
---
 airflow/api_internal/endpoints/rpc_api_endpoint.py    |  4 ++--
 airflow/api_internal/internal_api_call.py             |  5 +++--
 airflow/api_internal/openapi/internal_api_v1.yaml     |  4 ++--
 airflow/config_templates/config.yml                   |  5 +++--
 airflow/config_templates/default_airflow.cfg          |  4 ++--
 airflow/www/extensions/init_views.py                  |  2 +-
 tests/api_internal/endpoints/test_rpc_api_endpoint.py | 10 ++++++----
 tests/api_internal/test_internal_api_call.py          |  6 +++---
 8 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py
index cc7638fb0eeb2..90bb23e112840 100644
--- a/airflow/api_internal/endpoints/rpc_api_endpoint.py
+++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -40,10 +40,10 @@ def _build_methods_map(list) -> dict:
     )
 
 
-def json_rpc(
+def internal_airflow_api(
     body: dict,
 ) -> APIResponse:
-    """Handler for Internal API /internal/v1/rpcapi endpoint."""
+    """Handler for Internal API /internal_api/v1/rpcapi endpoint."""
     log.debug("Got request")
     json_rpc = body.get("jsonrpc")
     if json_rpc != "2.0":
diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index 041bb8ed0fe42..038369fce0151 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -58,7 +58,7 @@ def _init_values():
         internal_api_endpoint = ""
         if use_internal_api:
             internal_api_url = conf.get("core", "database_api_url")
-            internal_api_endpoint = internal_api_url + "/internal/v1/rpcapi"
+            internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi"
             if not internal_api_endpoint.startswith("http://"):
                 raise AirflowConfigException("[core]database_api_url must start with http://")
 
@@ -76,7 +76,8 @@ def internal_api_call(func: Callable[PS, RT | None]) -> Callable[PS, RT | None]:
     Each decorated method must be present in METHODS list in
     airflow.api_internal.endpoints.rpc_api_endpoint. Only static methods can be decorated. This
     decorator must be before "provide_session".
-    See AIP-44 for more information.
+    See [AIP-44](https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-44+Airflow+Internal+API)
+    for more information .
     """
     headers = {
         "Content-Type": "application/json",
diff --git a/airflow/api_internal/openapi/internal_api_v1.yaml b/airflow/api_internal/openapi/internal_api_v1.yaml
index c24f7fd1a8e01..58ef96217969e 100644
--- a/airflow/api_internal/openapi/internal_api_v1.yaml
+++ b/airflow/api_internal/openapi/internal_api_v1.yaml
@@ -31,7 +31,7 @@ info:
 
 servers:
-  - url: /internal/v1
+  - url: /internal_api/v1
     description: Airflow Internal API
 
 paths:
   "/rpcapi":
@@ -39,7 +39,7 @@ paths:
       operationId: rpcapi
       deprecated: false
       x-openapi-router-controller: airflow.api_internal.endpoints.rpc_api_endpoint
-      operationId: json_rpc
+      operationId: internal_airflow_api
       tags:
         - JSONRPC
       parameters: []
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 7dd24a7413d9b..1a4099e579ec2 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -419,13 +419,14 @@
       default: ~
       example: '{"some_param": "some_value"}'
     - name: database_access_isolation
-      description: Whether components should use Airflow Internal API for DB connectivity.
+      description: (experimental) Whether components should use Airflow Internal API for DB connectivity.
       version_added: 2.6.0
       type: boolean
      example: ~
       default: "False"
     - name: database_api_url
-      description: Airflow Internal API url. Only used if [core] database_access_isolation is True.
+      description: |
+        (experimental)Airflow Internal API url. Only used if [core] database_access_isolation is True.
       version_added: 2.6.0
       type: string
       default: ~
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 565f2fb4fd732..491f66a43d01a 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -240,10 +240,10 @@ daemon_umask = 0o077
 # Example: dataset_manager_kwargs = {{"some_param": "some_value"}}
 # dataset_manager_kwargs =
 
-# Whether components should use Airflow Internal API for DB connectivity.
+# (experimental) Whether components should use Airflow Internal API for DB connectivity.
 database_access_isolation = False
 
-# Airflow Internal API url. Only used if [core] database_access_isolation is True.
+# (experimental)Airflow Internal API url. Only used if [core] database_access_isolation is True.
 # Example: database_api_url = http://localhost:8080
 # database_api_url =
diff --git a/airflow/www/extensions/init_views.py b/airflow/www/extensions/init_views.py
index 7005582c2f3b6..86f94d2f229ce 100644
--- a/airflow/www/extensions/init_views.py
+++ b/airflow/www/extensions/init_views.py
@@ -224,7 +224,7 @@ def init_api_internal(app: Flask) -> None:
     """Initialize Internal API"""
     if not conf.getboolean("webserver", "run_internal_api", fallback=False):
         return
-    base_path = "/internal/v1"
+    base_path = "/internal_api/v1"
 
     spec_dir = path.join(ROOT_APP_DIR, "api_internal", "openapi")
     internal_app = App(__name__, specification_dir=spec_dir, skip_error_handlers=True)
diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
index 60ce91f2a668d..68f22fe6cc1a6 100644
--- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py
+++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -79,7 +79,9 @@ def test_method(self, input_data, method_result, method_params, expected_code):
             mock_test_method.return_value = method_result
 
         response = self.client.post(
-            "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(input_data)
+            "/internal_api/v1/rpcapi",
+            headers={"Content-Type": "application/json"},
+            data=json.dumps(input_data),
         )
         assert response.status_code == expected_code
         if method_result:
@@ -95,7 +97,7 @@ def test_method_with_exception(self):
         data = {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}
 
         response = self.client.post(
-            "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data)
+            "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data)
         )
         assert response.status_code == 500
         assert response.data, b"Error executing method: test_method."
@@ -105,7 +107,7 @@ def test_unknown_method(self):
         data = {"jsonrpc": "2.0", "method": "i-bet-it-does-not-exist", "params": ""}
 
         response = self.client.post(
-            "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data)
+            "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data)
         )
         assert response.status_code == 400
         assert response.data == b"Unrecognized method: i-bet-it-does-not-exist."
@@ -115,7 +117,7 @@ def test_invalid_jsonrpc(self):
         data = {"jsonrpc": "1.0", "method": TEST_METHOD_NAME, "params": ""}
 
         response = self.client.post(
-            "/internal/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data)
+            "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(data)
         )
         assert response.status_code == 400
         assert response.data == b"Expected jsonrpc 2.0 request."
diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py
index 139590009def8..23508f8351007 100644
--- a/tests/api_internal/test_internal_api_call.py
+++ b/tests/api_internal/test_internal_api_call.py
@@ -52,7 +52,7 @@ def test_get_use_internal_api_enabled(self):
         self.assertTrue(InternalApiConfig.get_use_internal_api())
         self.assertEqual(
             InternalApiConfig.get_internal_api_endpoint(),
-            "http://localhost:8888/internal/v1/rpcapi",
+            "http://localhost:8888/internal_api/v1/rpcapi",
         )
 
 
@@ -108,7 +108,7 @@ def test_remote_call(self, mock_requests):
             }
         )
         mock_requests.post.assert_called_once_with(
-            url="http://localhost:8888/internal/v1/rpcapi",
+            url="http://localhost:8888/internal_api/v1/rpcapi",
             data=expected_data,
             headers={"Content-Type": "application/json"},
         )
@@ -145,7 +145,7 @@ def test_remote_call_with_params(self, mock_requests):
             }
         )
         mock_requests.post.assert_called_once_with(
-            url="http://localhost:8888/internal/v1/rpcapi",
+            url="http://localhost:8888/internal_api/v1/rpcapi",
             data=expected_data,
             headers={"Content-Type": "application/json"},
         )

From 4c55f94ae05f433b6ea6359ebb32891970299338 Mon Sep 17 00:00:00 2001
From: Mateusz Henc
Date: Wed, 7 Dec 2022 16:24:04 +0100
Subject: [PATCH 9/9] Rename [core]database_api_url to [core]internal_api_url

---
 airflow/api_internal/internal_api_call.py    |  4 ++--
 airflow/config_templates/config.yml          |  2 +-
 airflow/config_templates/default_airflow.cfg |  4 ++--
 tests/api_internal/test_internal_api_call.py | 10 +++++-----
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py
index 038369fce0151..9de4d33c864d7 100644
--- a/airflow/api_internal/internal_api_call.py
+++ b/airflow/api_internal/internal_api_call.py
@@ -57,10 +57,10 @@ def _init_values():
         use_internal_api = conf.getboolean("core", "database_access_isolation")
         internal_api_endpoint = ""
         if use_internal_api:
-            internal_api_url = conf.get("core", "database_api_url")
+            internal_api_url = conf.get("core", "internal_api_url")
             internal_api_endpoint = internal_api_url + "/internal_api/v1/rpcapi"
             if not internal_api_endpoint.startswith("http://"):
-                raise AirflowConfigException("[core]database_api_url must start with http://")
+                raise AirflowConfigException("[core]internal_api_url must start with http://")
 
         InternalApiConfig._initialized = True
         InternalApiConfig._use_internal_api = use_internal_api
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 1a4099e579ec2..5446d617c2796 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -424,7 +424,7 @@
       type: boolean
       example: ~
       default: "False"
-    - name: database_api_url
+    - name: internal_api_url
       description: |
         (experimental)Airflow Internal API url. Only used if [core] database_access_isolation is True.
       version_added: 2.6.0
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 491f66a43d01a..54001d36e102d 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -244,8 +244,8 @@ daemon_umask = 0o077
 database_access_isolation = False
 
 # (experimental)Airflow Internal API url. Only used if [core] database_access_isolation is True.
-# Example: database_api_url = http://localhost:8080
-# database_api_url =
+# Example: internal_api_url = http://localhost:8080
+# internal_api_url =
 
 [database]
 # The SqlAlchemy connection string to the metadata database.
diff --git a/tests/api_internal/test_internal_api_call.py b/tests/api_internal/test_internal_api_call.py
index 23508f8351007..579a7720cc8b0 100644
--- a/tests/api_internal/test_internal_api_call.py
+++ b/tests/api_internal/test_internal_api_call.py
@@ -36,7 +36,7 @@ def setUp(self):
     @conf_vars(
         {
             ("core", "database_access_isolation"): "false",
-            ("core", "database_api_url"): "http://localhost:8888",
+            ("core", "internal_api_url"): "http://localhost:8888",
         }
     )
     def test_get_use_internal_api_disabled(self):
@@ -45,7 +45,7 @@ def test_get_use_internal_api_disabled(self):
     @conf_vars(
         {
             ("core", "database_access_isolation"): "true",
-            ("core", "database_api_url"): "http://localhost:8888",
+            ("core", "internal_api_url"): "http://localhost:8888",
         }
     )
     def test_get_use_internal_api_enabled(self):
@@ -73,7 +73,7 @@ def setUp(self):
     @conf_vars(
         {
             ("core", "database_access_isolation"): "false",
-            ("core", "database_api_url"): "http://localhost:8888",
+            ("core", "internal_api_url"): "http://localhost:8888",
         }
     )
     @mock.patch("airflow.api_internal.internal_api_call.requests")
     def test_local_call(self, mock_requests):
@@ -86,7 +86,7 @@ def test_local_call(self, mock_requests):
     @conf_vars(
         {
             ("core", "database_access_isolation"): "true",
-            ("core", "database_api_url"): "http://localhost:8888",
+            ("core", "internal_api_url"): "http://localhost:8888",
         }
     )
     @mock.patch("airflow.api_internal.internal_api_call.requests")
     def test_remote_call(self, mock_requests):
@@ -116,7 +116,7 @@ def test_remote_call(self, mock_requests):
     @conf_vars(
         {
             ("core", "database_access_isolation"): "true",
-            ("core", "database_api_url"): "http://localhost:8888",
+            ("core", "internal_api_url"): "http://localhost:8888",
         }
     )
     @mock.patch("airflow.api_internal.internal_api_call.requests")
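
With the whole series applied, components opt in through [core] database_access_isolation,
point at the API server through [core] internal_api_url, and decorated calls are routed to
/internal_api/v1/rpcapi. A minimal usage sketch follows; the class, method and query below are
placeholders invented for illustration, while the decorator name, its ordering relative to
provide_session, and the configuration keys come from these patches:

    from __future__ import annotations

    from sqlalchemy import text
    from sqlalchemy.orm import Session

    from airflow.api_internal.internal_api_call import internal_api_call
    from airflow.utils.session import NEW_SESSION, provide_session

    # Assumed airflow.cfg (values follow the examples in these patches):
    #   [core]
    #   database_access_isolation = True
    #   internal_api_url = http://localhost:8080


    class SomeDbAccessor:
        """Hypothetical component; the method must also be registered in METHODS in
        airflow.api_internal.endpoints.rpc_api_endpoint to be callable remotely."""

        @staticmethod
        @internal_api_call
        @provide_session
        def get_flag(dag_id: str, session: Session = NEW_SESSION) -> int | None:
            # Runs against the local metadata DB when database_access_isolation is False;
            # otherwise internal_api_call forwards it to <internal_api_url>/internal_api/v1/rpcapi.
            return session.execute(text("SELECT 1")).scalar()  # placeholder query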