diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index f58c010b3f58e..d2f9a81d7441b 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -528,7 +528,8 @@ "edge": { "deps": [ "apache-airflow>=2.10.0", - "pydantic>=2.10.2" + "pydantic>=2.10.2", + "retryhttp>=1.2.0" ], "devel-deps": [], "plugins": [ diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst index af4df15918cd3..2c120ed652e82 100644 --- a/providers/src/airflow/providers/edge/CHANGELOG.rst +++ b/providers/src/airflow/providers/edge/CHANGELOG.rst @@ -27,6 +27,15 @@ Changelog --------- +0.9.7pre0 +......... + +Misc +~~~~ + +* ``Make API retries configurable via ENV. Connection loss is sustained for 5min by default.`` +* ``Align retry handling logic and tooling with Task SDK, via retryhttp.`` + 0.9.6pre0 ......... diff --git a/providers/src/airflow/providers/edge/__init__.py b/providers/src/airflow/providers/edge/__init__.py index 7c0490c20785e..9c2324041e8a0 100644 --- a/providers/src/airflow/providers/edge/__init__.py +++ b/providers/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.9.6pre0" +__version__ = "0.9.7pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/src/airflow/providers/edge/cli/api_client.py b/providers/src/airflow/providers/edge/cli/api_client.py index 483c5ab3759e5..75fb821d259d6 100644 --- a/providers/src/airflow/providers/edge/cli/api_client.py +++ b/providers/src/airflow/providers/edge/cli/api_client.py @@ -18,6 +18,7 @@ import json import logging +import os from datetime import datetime from http import HTTPStatus from pathlib import Path @@ -25,12 +26,10 @@ from urllib.parse import quote, urljoin import requests -import tenacity -from requests.exceptions import ConnectionError -from 
urllib3.exceptions import NewConnectionError +from retryhttp import retry, wait_retry_after +from tenacity import before_log, wait_random_exponential from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.providers.edge.worker_api.auth import jwt_signer from airflow.providers.edge.worker_api.datamodels import ( EdgeJobFetched, @@ -47,29 +46,30 @@ logger = logging.getLogger(__name__) -def _is_retryable_exception(exception: BaseException) -> bool: - """ - Evaluate which exception types to retry. +# Hidden config options controlling how the Edge Worker retries HTTP requests +# Note: The defaults retry after 1, 3, 7, 15 and 31 seconds, 1:03, 2:07 and 3:37 min, and fail after 5:07 min +# So far there is no other config facility in Task SDK, so we use ENV for the moment +# TODO: Consider these env variables jointly in task sdk together with task_sdk/src/airflow/sdk/api/client.py +API_RETRIES = int(os.getenv("AIRFLOW__EDGE__API_RETRIES", os.getenv("AIRFLOW__WORKERS__API_RETRIES", 10))) +API_RETRY_WAIT_MIN = float( + os.getenv("AIRFLOW__EDGE__API_RETRY_WAIT_MIN", os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1.0)) +) +API_RETRY_WAIT_MAX = float( + os.getenv("AIRFLOW__EDGE__API_RETRY_WAIT_MAX", os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90.0)) +) - This is especially demanded for cases where an application gateway or Kubernetes ingress can - not find a running instance of a webserver hosting the API (HTTP 502+504) or when the - HTTP request fails in general on network level. - Note that we want to fail on other general errors on the webserver not to send bad requests in an endless loop. 
- """ - retryable_status_codes = (HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT) - return ( - isinstance(exception, AirflowException) - and exception.status_code in retryable_status_codes - or isinstance(exception, (ConnectionError, NewConnectionError)) - ) +_default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX) -@tenacity.retry( - stop=tenacity.stop_after_attempt(10), # TODO: Make this configurable - wait=tenacity.wait_exponential(min=1), # TODO: Make this configurable - retry=tenacity.retry_if_exception(_is_retryable_exception), - before_sleep=tenacity.before_log(logger, logging.WARNING), +@retry( + reraise=True, + max_attempt_number=API_RETRIES, + wait_server_errors=_default_wait, + wait_network_errors=_default_wait, + wait_timeouts=_default_wait, + wait_rate_limited=wait_retry_after(fallback=_default_wait), # No infinite timeout on HTTP 429 + before_sleep=before_log(logger, logging.WARNING), ) def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any: signer = jwt_signer() @@ -81,14 +81,9 @@ def _make_generic_request(method: str, rest_path: str, data: str | None = None) } api_endpoint = urljoin(api_url, rest_path) response = requests.request(method, url=api_endpoint, data=data, headers=headers) + response.raise_for_status() if response.status_code == HTTPStatus.NO_CONTENT: return None - if response.status_code != HTTPStatus.OK: - raise AirflowException( - f"Got {response.status_code}:{response.reason} when sending " - f"the internal api request: {response.text}", - HTTPStatus(response.status_code), - ) return json.loads(response.content) diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py index 115923e981fb7..6835e7909d2b4 100644 --- a/providers/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/src/airflow/providers/edge/cli/edge_command.py @@ -23,6 +23,7 @@ import sys from dataclasses import dataclass from 
datetime import datetime +from http import HTTPStatus from pathlib import Path from subprocess import Popen from time import sleep @@ -30,11 +31,11 @@ import psutil from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile +from requests import HTTPError from airflow import __version__ as airflow_version, settings from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg from airflow.configuration import conf -from airflow.exceptions import AirflowException from airflow.providers.edge import __version__ as edge_provider_version from airflow.providers.edge.cli.api_client import ( jobs_fetch, @@ -199,8 +200,8 @@ def start(self): except EdgeWorkerVersionException as e: logger.info("Version mismatch of Edge worker and Core. Shutting down worker.") raise SystemExit(str(e)) - except AirflowException as e: - if "404:NOT FOUND" in str(e): + except HTTPError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: raise SystemExit("Error: API endpoint is not ready, please set [edge] api_enabled=True.") raise SystemExit(str(e)) _write_pid_to_pidfile(self.pid_file_path) diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml index f6b0457c07d7e..6628aceab5786 100644 --- a/providers/src/airflow/providers/edge/provider.yaml +++ b/providers/src/airflow/providers/edge/provider.yaml @@ -27,11 +27,12 @@ source-date-epoch: 1729683247 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.9.6pre0 + - 0.9.7pre0 dependencies: - apache-airflow>=2.10.0 - pydantic>=2.10.2 + - retryhttp>=1.2.0 plugins: - name: edge_executor diff --git a/providers/tests/edge/cli/test_api_client.py b/providers/tests/edge/cli/test_api_client.py new file mode 100644 index 0000000000000..e77b157b302d5 --- /dev/null +++ b/providers/tests/edge/cli/test_api_client.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under 
one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from http import HTTPStatus +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest +from requests import HTTPError +from requests.exceptions import ConnectTimeout +from requests_mock import ANY + +from airflow.providers.edge.cli.api_client import _make_generic_request + +from tests_common.test_utils.config import conf_vars + +if TYPE_CHECKING: + from requests_mock import Mocker as RequestsMocker + +MOCK_ENDPOINT = "https://invalid-api-test-endpoint" + + +class TestApiClient: + @conf_vars({("edge", "api_url"): MOCK_ENDPOINT}) + def test_make_generic_request_success(self, requests_mock: RequestsMocker): + requests_mock.get( + ANY, + [ + {"json": {"test": "ok"}}, + {"status_code": HTTPStatus.NO_CONTENT}, + ], + ) + + result1 = _make_generic_request("GET", f"{MOCK_ENDPOINT}/dummy_service", "test") + result2 = _make_generic_request("GET", f"{MOCK_ENDPOINT}/service_no_content", "test") + + assert result1 == {"test": "ok"} + assert result2 is None + assert requests_mock.call_count == 2 + + @patch("time.sleep", return_value=None) + @conf_vars({("edge", "api_url"): MOCK_ENDPOINT}) + def test_make_generic_request_retry(self, mock_sleep, requests_mock: RequestsMocker): + 
requests_mock.get( + ANY, + [ + *[{"status_code": HTTPStatus.SERVICE_UNAVAILABLE}] * 3, + {"exc": ConnectTimeout}, + {"json": {"test": 42}}, + ], + ) + + result = _make_generic_request("GET", f"{MOCK_ENDPOINT}/flaky_service", "test") + + assert result == {"test": 42} + assert requests_mock.call_count == 5 + + @patch("time.sleep", return_value=None) + @conf_vars({("edge", "api_url"): MOCK_ENDPOINT}) + def test_make_generic_request_unrecoverable_error(self, mock_sleep, requests_mock: RequestsMocker): + requests_mock.get( + ANY, + [ + *[{"status_code": HTTPStatus.INTERNAL_SERVER_ERROR}] * 11, + {"json": {"test": "uups"}}, + ], + ) + + with pytest.raises(HTTPError) as err: + _make_generic_request("GET", f"{MOCK_ENDPOINT}/broken_service", "test") + + assert err.value.response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR + assert requests_mock.call_count == 10 diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py index 123b06af3f9c4..df0cd7c81407f 100644 --- a/providers/tests/edge/cli/test_edge_command.py +++ b/providers/tests/edge/cli/test_edge_command.py @@ -25,8 +25,8 @@ import pytest import time_machine +from requests import HTTPError, Response -from airflow.exceptions import AirflowException from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, _write_pid_to_pidfile from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched @@ -282,15 +282,21 @@ def test_version_mismatch(self, mock_set_state, worker_with_job): @patch("airflow.providers.edge.cli.edge_command.worker_register") def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): - mock_register_worker.side_effect = AirflowException( - "Something with 404:NOT FOUND means API is not active" + mock_response = Response() + mock_response.status_code = 404 + mock_register_worker.side_effect = 
HTTPError( + "Something with 404:NOT FOUND means API is not active", response=mock_response ) with pytest.raises(SystemExit, match=r"API endpoint is not ready"): worker_with_job.start() @patch("airflow.providers.edge.cli.edge_command.worker_register") def test_start_server_error(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): - mock_register_worker.side_effect = AirflowException("Something other error not FourhundretFour") + mock_response = Response() + mock_response.status_code = 500 + mock_register_worker.side_effect = HTTPError( + "Something other error not FourhundretFour", response=mock_response + ) with pytest.raises(SystemExit, match=r"Something other"): worker_with_job.start()