diff --git a/airflow/providers/http/hooks/http.py b/airflow/providers/http/hooks/http.py
index fc19d0b102c05..7b98ec25dfdee 100644
--- a/airflow/providers/http/hooks/http.py
+++ b/airflow/providers/http/hooks/http.py
@@ -26,6 +26,7 @@
 from aiohttp import ClientResponseError
 from asgiref.sync import sync_to_async
 from requests.auth import HTTPBasicAuth
+from requests.models import DEFAULT_REDIRECT_LIMIT
 from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter
 
 from airflow.exceptions import AirflowException
@@ -34,6 +35,8 @@
 if TYPE_CHECKING:
     from aiohttp.client_reqrep import ClientResponse
 
+    from airflow.models import Connection
+
 
 class HttpHook(BaseHook):
     """Interact with HTTP servers.
@@ -113,8 +116,19 @@ def get_conn(self, headers: dict[Any, Any] | None = None) -> requests.Session:
             elif self._auth_type:
                 session.auth = self.auth_type()
             if conn.extra:
+                extra = conn.extra_dejson
+                extra.pop(
+                    "timeout", None
+                )  # ignore this as timeout is only accepted in request method of Session
+                extra.pop("allow_redirects", None)  # ignore this as only max_redirects is accepted in Session
+                session.proxies = extra.pop("proxies", extra.pop("proxy", {}))
+                session.stream = extra.pop("stream", False)
+                session.verify = extra.pop("verify", extra.pop("verify_ssl", True))
+                session.cert = extra.pop("cert", None)
+                session.max_redirects = extra.pop("max_redirects", DEFAULT_REDIRECT_LIMIT)
+
                 try:
-                    session.headers.update(conn.extra_dejson)
+                    session.headers.update(extra)
                 except TypeError:
                     self.log.warning("Connection to %s has invalid extra field.", conn.host)
         if headers:
@@ -336,8 +350,10 @@ async def run(
             if conn.login:
                 auth = self.auth_type(conn.login, conn.password)
             if conn.extra:
+                extra = self._process_extra_options_from_connection(conn=conn, extra_options=extra_options)
+
                 try:
-                    _headers.update(conn.extra_dejson)
+                    _headers.update(extra)
                 except TypeError:
                     self.log.warning("Connection to %s has invalid extra field.", conn.host)
         if headers:
@@ -395,6 +411,42 @@ async def run(
             else:
                 raise NotImplementedError  # should not reach this, but makes mypy happy
 
+    @classmethod
+    def _process_extra_options_from_connection(cls, conn: Connection, extra_options: dict) -> dict:
+        """Split the connection extras into request options and headers.
+
+        Request parameters found in the connection extras (``proxies``/``proxy``, ``timeout``,
+        ``verify``/``verify_ssl``, ``allow_redirects``, ``max_redirects``) are copied into
+        ``extra_options`` unless the caller already set the corresponding key there.
+        ``stream`` and ``cert`` are removed from the extras and discarded.
+
+        :param conn: Connection whose ``extra_dejson`` is processed.
+        :param extra_options: Options passed by the caller; updated in place.
+        :return: The extras left after removing the request parameters.
+        """
+        extra = conn.extra_dejson
+        extra.pop("stream", None)
+        extra.pop("cert", None)
+        proxies = extra.pop("proxies", extra.pop("proxy", None))
+        timeout = extra.pop("timeout", None)
+        verify_ssl = extra.pop("verify", extra.pop("verify_ssl", None))
+        allow_redirects = extra.pop("allow_redirects", None)
+        max_redirects = extra.pop("max_redirects", None)
+
+        # NOTE(review): aiohttp documents ``proxy`` as a single proxy URL, while ``proxies`` in the
+        # connection extras may be a requests-style mapping; it is forwarded unchanged -- confirm.
+        if proxies is not None and "proxy" not in extra_options:
+            extra_options["proxy"] = proxies
+        if timeout is not None and "timeout" not in extra_options:
+            extra_options["timeout"] = timeout
+        if verify_ssl is not None and "verify_ssl" not in extra_options:
+            extra_options["verify_ssl"] = verify_ssl
+        if allow_redirects is not None and "allow_redirects" not in extra_options:
+            extra_options["allow_redirects"] = allow_redirects
+        if max_redirects is not None and "max_redirects" not in extra_options:
+            extra_options["max_redirects"] = max_redirects
+        return extra
+
     def _retryable_error_async(self, exception: ClientResponseError) -> bool:
         """Determine whether an exception may successful on a subsequent attempt.
 
diff --git a/docs/apache-airflow-providers-http/connections/http.rst b/docs/apache-airflow-providers-http/connections/http.rst
index 41856cefee7d8..6f1decdec9517 100644
--- a/docs/apache-airflow-providers-http/connections/http.rst
+++ b/docs/apache-airflow-providers-http/connections/http.rst
@@ -54,7 +54,16 @@ Schema (optional)
     Specify the service type etc: http/https.
 
 Extras (optional)
-    Specify headers in json format.
+    Specify headers and default requests parameters in json format.
+    The following default requests parameters are taken into account:
+
+    * ``stream``
+    * ``cert``
+    * ``proxies`` or ``proxy``
+    * ``verify`` or ``verify_ssl``
+    * ``allow_redirects``
+    * ``max_redirects``
+
 
 When specifying the connection in environment variable you should specify
 it using URI syntax.
diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py
index 617009d5750d6..7b093c66bbabc 100644
--- a/tests/providers/http/hooks/test_http.py
+++ b/tests/providers/http/hooks/test_http.py
@@ -31,6 +31,7 @@
 from aioresponses import aioresponses
 from requests.adapters import Response
 from requests.auth import AuthBase, HTTPBasicAuth
+from requests.models import DEFAULT_REDIRECT_LIMIT
 
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
@@ -46,18 +47,23 @@ def aioresponse():
         yield async_response
 
 
-def get_airflow_connection(unused_conn_id=None):
-    return Connection(conn_id="http_default", conn_type="http", host="test:8080/", extra='{"bearer": "test"}')
+def get_airflow_connection(conn_id: str = "http_default"):
+    return Connection(conn_id=conn_id, conn_type="http", host="test:8080/", extra='{"bearer": "test"}')
 
 
-def get_airflow_connection_with_port(unused_conn_id=None):
-    return Connection(conn_id="http_default", conn_type="http", host="test.com", port=1234)
+def get_airflow_connection_with_extra(extra: dict):
+    def inner(conn_id: str = "http_default"):
+        return Connection(conn_id=conn_id, conn_type="http", host="test:8080/", extra=json.dumps(extra))
 
+    return inner
 
-def get_airflow_connection_with_login_and_password(unused_conn_id=None):
-    return Connection(
-        conn_id="http_default", conn_type="http", host="test.com", login="username", password="pass"
-    )
+
+def get_airflow_connection_with_port(conn_id: str = "http_default"):
+    return Connection(conn_id=conn_id, conn_type="http", host="test.com", port=1234)
+
+
+def get_airflow_connection_with_login_and_password(conn_id: str = "http_default"):
+    return Connection(conn_id=conn_id, conn_type="http", host="test.com", login="username", password="pass")
 
 
 class TestHttpHook:
@@ -119,6 +125,64 @@ def test_hook_contains_header_from_extra_field(self):
             assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers
             assert conn.headers.get("bearer") == "test"
 
+    def test_hook_ignore_max_redirects_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "max_redirects": 3})
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("max_redirects") is None
+            assert conn.proxies == {}
+            assert conn.stream is False
+            assert conn.verify is True
+            assert conn.cert is None
+            assert conn.max_redirects == 3
+
+    def test_hook_ignore_proxies_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(
+            extra={"bearer": "test", "proxies": {"http": "http://proxy:80", "https": "https://proxy:80"}}
+        )
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("proxies") is None
+            assert conn.proxies == {"http": "http://proxy:80", "https": "https://proxy:80"}
+            assert conn.stream is False
+            assert conn.verify is True
+            assert conn.cert is None
+            assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+
+    def test_hook_ignore_verify_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "verify": False})
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("verify") is None
+            assert conn.proxies == {}
+            assert conn.stream is False
+            assert conn.verify is False
+            assert conn.cert is None
+            assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+
+    def test_hook_ignore_cert_from_extra_field_as_header(self):
+        airflow_connection = get_airflow_connection_with_extra(extra={"bearer": "test", "cert": "cert.crt"})
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            expected_conn = airflow_connection()
+            conn = self.get_hook.get_conn()
+            assert dict(conn.headers, **json.loads(expected_conn.extra)) != conn.headers
+            assert conn.headers.get("bearer") == "test"
+            assert conn.headers.get("cert") is None
+            assert conn.proxies == {}
+            assert conn.stream is False
+            assert conn.verify is True
+            assert conn.cert == "cert.crt"
+            assert conn.max_redirects == DEFAULT_REDIRECT_LIMIT
+
     @mock.patch("requests.Request")
     def test_hook_with_method_in_lowercase(self, mock_requests):
         from requests.exceptions import InvalidURL, MissingSchema
@@ -525,3 +589,60 @@ async def test_async_request_uses_connection_extra(self, aioresponse):
                 assert all(
                     key in headers and headers[key] == value for key, value in connection_extra.items()
                 )
+
+    @pytest.mark.asyncio
+    async def test_async_request_uses_connection_extra_with_requests_parameters(self):
+        """Test api call asynchronously with a connection that has extra field."""
+        connection_extra = {"bearer": "test"}
+        proxy = {"http": "http://proxy:80", "https": "https://proxy:80"}
+        airflow_connection = get_airflow_connection_with_extra(
+            extra={
+                **connection_extra,
+                "proxies": proxy,
+                "timeout": 60,
+                "verify": False,
+                "allow_redirects": False,
+                "max_redirects": 3,
+            }
+        )
+
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=airflow_connection):
+            hook = HttpAsyncHook()
+            with mock.patch("aiohttp.ClientSession.post", new_callable=mock.AsyncMock) as mocked_function:
+                await hook.run("v1/test")
+                headers = mocked_function.call_args.kwargs.get("headers")
+                assert all(
+                    key in headers and headers[key] == value for key, value in connection_extra.items()
+                )
+                assert mocked_function.call_args.kwargs.get("proxy") == proxy
+                assert mocked_function.call_args.kwargs.get("timeout") == 60
+                assert mocked_function.call_args.kwargs.get("verify_ssl") is False
+                assert mocked_function.call_args.kwargs.get("allow_redirects") is False
+                assert mocked_function.call_args.kwargs.get("max_redirects") == 3
+
+    def test_process_extra_options_from_connection(self):
+        extra_options = {}
+        proxy = {"http": "http://proxy:80", "https": "https://proxy:80"}
+        conn = get_airflow_connection_with_extra(
+            extra={
+                "bearer": "test",
+                "stream": True,
+                "cert": "cert.crt",
+                "proxies": proxy,
+                "timeout": 60,
+                "verify": False,
+                "allow_redirects": False,
+                "max_redirects": 3,
+            }
+        )()
+
+        actual = HttpAsyncHook._process_extra_options_from_connection(conn=conn, extra_options=extra_options)
+
+        assert extra_options == {
+            "proxy": proxy,
+            "timeout": 60,
+            "verify_ssl": False,
+            "allow_redirects": False,
+            "max_redirects": 3,
+        }
+        assert actual == {"bearer": "test"}