From c410b9c6b74d82f83ea8a21db0e849f6ef868267 Mon Sep 17 00:00:00 2001 From: Chris Patterson Date: Mon, 10 Jan 2022 11:00:21 -0500 Subject: [PATCH] sources/azure: ensure retries on IMDS request failure There are two issues with IMDS retries: 1. IMDS_VER_WANT will never be attempted if retries=0, such as when fetching network metadata with infinite=True. 2. get_imds_data_with_api_fallback() will attempt only one request with IMDS_VER_WANT. If that request fails due to a timeout, connection issue, or error code other than 400, an empty dictionary will be returned without attempting the requested number of retries. This PR: - Updates get_imds_data_with_api_fallback() to invoke get_metadata_from_imds() with the specified retries and infinite parameters. - Updates retry_on_url_exc() to take a configurable set of HTTP error codes and exception types to retry on. - Adds the IMDS_RETRY_CODES tuple of HTTP status codes to retry on when fetching data from IMDS: - 404 not found (yet) - 410 gone / unavailable (yet) - 429 rate-limited/throttled - 500 server error - Replaces the default exception callback with imds_readurl_exception_callback, which configures retry_on_url_exc() with these error codes and with requests.Timeout as the exception type to retry on. - Adds new pytests for IMDS to eventually replace the unittest equivalents and improve existing coverage. 
Signed-off-by: Chris Patterson --- cloudinit/sources/DataSourceAzure.py | 93 ++++++----- cloudinit/url_helper.py | 16 +- tests/unittests/sources/test_azure.py | 230 ++++++++++++++++++++++++-- 3 files changed, 283 insertions(+), 56 deletions(-) diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 359dfbdee7e..bf2854ad988 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -7,6 +7,7 @@ import base64 import crypt import datetime +import functools import os import os.path import re @@ -68,6 +69,17 @@ IMDS_VER_MIN = "2019-06-01" IMDS_VER_WANT = "2021-08-01" IMDS_EXTENDED_VER_MIN = "2021-03-01" +IMDS_RETRY_CODES = ( + 404, # not found (yet) + 410, # gone / unavailable (yet) + 429, # rate-limited/throttled + 500, # server error +) +imds_readurl_exception_callback = functools.partial( + retry_on_url_exc, + retry_codes=IMDS_RETRY_CODES, + retry_instances=(requests.Timeout,), +) class MetadataType(Enum): @@ -726,44 +738,49 @@ def _get_data(self): def get_imds_data_with_api_fallback( self, *, - retries, - md_type=MetadataType.ALL, - exc_cb=retry_on_url_exc, - infinite=False, - ): - """ - Wrapper for get_metadata_from_imds so that we can have flexibility - in which IMDS api-version we use. If a particular instance of IMDS - does not have the api version that is desired, we want to make - this fault tolerant and fall back to a good known minimum api - version. 
- """ - for _ in range(retries): - try: - LOG.info("Attempting IMDS api-version: %s", IMDS_VER_WANT) - return get_metadata_from_imds( - retries=0, - md_type=md_type, - api_version=IMDS_VER_WANT, - exc_cb=exc_cb, - ) - except UrlError as err: - LOG.info("UrlError with IMDS api-version: %s", IMDS_VER_WANT) - if err.code == 400: - log_msg = "Fall back to IMDS api-version: {}".format( - IMDS_VER_MIN - ) - report_diagnostic_event(log_msg, logger_func=LOG.info) - break + retries: int, + md_type: MetadataType = MetadataType.ALL, + exc_cb=imds_readurl_exception_callback, + infinite: bool = False, + ) -> dict: + """Fetch metadata from IMDS using IMDS_VER_WANT API version. - LOG.info("Using IMDS api-version: %s", IMDS_VER_MIN) - return get_metadata_from_imds( - retries=retries, - md_type=md_type, - api_version=IMDS_VER_MIN, - exc_cb=exc_cb, - infinite=infinite, - ) + Falls back to IMDS_VER_MIN version if IMDS returns a 400 error code, + indicating that IMDS_VER_WANT is unsupported. + + :return: Parsed metadata dictionary or empty dict on error. + """ + LOG.info("Attempting IMDS api-version: %s", IMDS_VER_WANT) + try: + return get_metadata_from_imds( + retries=retries, + md_type=md_type, + api_version=IMDS_VER_WANT, + exc_cb=exc_cb, + infinite=infinite, + ) + except UrlError as error: + LOG.info("UrlError with IMDS api-version: %s", IMDS_VER_WANT) + # Fall back if HTTP code is 400, otherwise return empty dict. 
+ if error.code != 400: + return {} + + log_msg = "Fall back to IMDS api-version: {}".format(IMDS_VER_MIN) + report_diagnostic_event(log_msg, logger_func=LOG.info) + try: + return get_metadata_from_imds( + retries=retries, + md_type=md_type, + api_version=IMDS_VER_MIN, + exc_cb=exc_cb, + infinite=infinite, + ) + except UrlError as error: + report_diagnostic_event( + "Failed to fetch IMDS metadata: %s" % error, + logger_func=LOG.error, + ) + return {} def device_name_to_device(self, name): return self.ds_cfg["disk_aliases"].get(name) @@ -2310,7 +2327,7 @@ def get_metadata_from_imds( retries, md_type=MetadataType.ALL, api_version=IMDS_VER_MIN, - exc_cb=retry_on_url_exc, + exc_cb=imds_readurl_exception_callback, infinite=False, ): """Query Azure's instance metadata service, returning a dictionary. diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index 5b2f2ef9fd2..c577e8da3ac 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -639,16 +639,22 @@ def oauth_headers( return signed_headers -def retry_on_url_exc(msg, exc): - """readurl exception_cb that will retry on NOT_FOUND and Timeout. +def retry_on_url_exc( + msg, exc, *, retry_codes=(NOT_FOUND,), retry_instances=(requests.Timeout,) +): + """Configurable retry exception callback for readurl(). + + :param retry_codes: Codes to retry on. Defaults to 404. + :param retry_instances: Exception types to retry on. Defaults to + requests.Timeout. - Returns False to raise the exception from readurl, True to retry. + :returns: False to raise the exception from readurl(), True to retry. 
""" if not isinstance(exc, UrlError): return False - if exc.code == NOT_FOUND: + if exc.code in retry_codes: return True - if exc.cause and isinstance(exc.cause, requests.Timeout): + if exc.cause and isinstance(exc.cause, retry_instances): return True return False diff --git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py index 5f956a63ae3..6f720e4e7d4 100644 --- a/tests/unittests/sources/test_azure.py +++ b/tests/unittests/sources/test_azure.py @@ -3,6 +3,7 @@ import copy import crypt import json +import logging import os import stat import xml.etree.ElementTree as ET @@ -151,12 +152,24 @@ def mock_readurl(): yield m +@pytest.fixture +def mock_requests_session_request(): + with mock.patch("requests.Session.request", autospec=True) as m: + yield m + + @pytest.fixture def mock_subp_subp(): with mock.patch(MOCKPATH + "subp.subp", side_effect=[]) as m: yield m +@pytest.fixture +def mock_url_helper_time_sleep(): + with mock.patch("cloudinit.url_helper.time.sleep", autospec=True) as m: + yield m + + @pytest.fixture def mock_util_ensure_dir(): with mock.patch( @@ -2220,10 +2233,11 @@ def get_metadata_from_imds_side_eff(*args, **kwargs): assert m_get_metadata_from_imds.mock_calls == [ mock.call( - retries=0, + retries=10, md_type=dsaz.MetadataType.ALL, api_version="2021-08-01", exc_cb=mock.ANY, + infinite=False, ), mock.call( retries=10, @@ -2250,10 +2264,11 @@ def test_imds_api_version_wanted_exists(self, m_get_metadata_from_imds): assert m_get_metadata_from_imds.mock_calls == [ mock.call( - retries=0, + retries=10, md_type=dsaz.MetadataType.ALL, api_version="2021-08-01", exc_cb=mock.ANY, + infinite=False, ) ] @@ -3720,6 +3735,195 @@ def test_non_ascii_seed_is_serializable(self): self.assertEqual(deserialized["seed"], result) +def fake_http_error_for_code(status_code: int): + response_failure = requests.Response() + response_failure.status_code = status_code + return requests.exceptions.HTTPError( + "fake error", + 
response=response_failure, + ) + + +@pytest.mark.parametrize( + "md_type,expected_url", + [ + ( + dsaz.MetadataType.ALL, + "http://169.254.169.254/metadata/instance?" + "api-version=2021-08-01&extended=true", + ), + ( + dsaz.MetadataType.NETWORK, + "http://169.254.169.254/metadata/instance/network?" + "api-version=2021-08-01", + ), + ( + dsaz.MetadataType.REPROVISION_DATA, + "http://169.254.169.254/metadata/reprovisiondata?" + "api-version=2021-08-01", + ), + ], +) +class TestIMDS: + def test_basic_scenarios( + self, azure_ds, caplog, mock_readurl, md_type, expected_url + ): + fake_md = {"foo": {"bar": []}} + mock_readurl.side_effect = [ + mock.MagicMock(contents=json.dumps(fake_md).encode()), + ] + + md = azure_ds.get_imds_data_with_api_fallback( + retries=5, + md_type=md_type, + ) + + assert md == fake_md + assert mock_readurl.mock_calls == [ + mock.call( + expected_url, + timeout=2, + headers={"Metadata": "true"}, + retries=5, + exception_cb=dsaz.imds_readurl_exception_callback, + infinite=False, + ), + ] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(404), + fake_http_error_for_code(410), + fake_http_error_for_code(429), + fake_http_error_for_code(500), + requests.Timeout("Fake connection timeout"), + ], + ) + def test_will_retry_errors( + self, + azure_ds, + caplog, + md_type, + expected_url, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + ): + fake_md = {"foo": {"bar": []}} + mock_requests_session_request.side_effect = [ + error, + mock.Mock(content=json.dumps(fake_md)), + ] + + md = azure_ds.get_imds_data_with_api_fallback( + retries=5, + md_type=md_type, + ) + + assert md == fake_md + assert len(mock_requests_session_request.mock_calls) == 2 + assert mock_url_helper_time_sleep.mock_calls == [mock.call(1)] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + 
assert warnings == [] + + @pytest.mark.parametrize("retries", [0, 1, 5, 10]) + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(404), + fake_http_error_for_code(410), + fake_http_error_for_code(429), + fake_http_error_for_code(500), + requests.Timeout("Fake connection timeout"), + ], + ) + def test_retry_until_failure( + self, + azure_ds, + caplog, + md_type, + expected_url, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + retries, + ): + mock_requests_session_request.side_effect = [error] * (retries + 1) + + assert ( + azure_ds.get_imds_data_with_api_fallback( + retries=retries, + md_type=md_type, + ) + == {} + ) + + assert len(mock_requests_session_request.mock_calls) == (retries + 1) + assert ( + mock_url_helper_time_sleep.mock_calls == [mock.call(1)] * retries + ) + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [ + "Ignoring IMDS instance metadata. " + "Get metadata from IMDS failed: %s" % error + ] + + @pytest.mark.parametrize( + "error", + [ + fake_http_error_for_code(403), + fake_http_error_for_code(501), + requests.ConnectionError("Fake Network Unreachable"), + ], + ) + def test_will_not_retry_errors( + self, + azure_ds, + caplog, + md_type, + expected_url, + mock_requests_session_request, + mock_url_helper_time_sleep, + error, + ): + fake_md = {"foo": {"bar": []}} + mock_requests_session_request.side_effect = [ + error, + mock.Mock(content=json.dumps(fake_md)), + ] + + assert ( + azure_ds.get_imds_data_with_api_fallback( + retries=5, + md_type=md_type, + ) + == {} + ) + + assert len(mock_requests_session_request.mock_calls) == 1 + assert mock_url_helper_time_sleep.mock_calls == [] + + warnings = [ + x.message for x in caplog.records if x.levelno == logging.WARNING + ] + assert warnings == [ + "Ignoring IMDS instance metadata. 
" + "Get metadata from IMDS failed: %s" % error + ] + + class TestProvisioning: @pytest.fixture(autouse=True) def provisioning_setup( @@ -3816,8 +4020,8 @@ def test_no_pps(self): "api-version=2021-08-01&extended=true", timeout=2, headers={"Metadata": "true"}, - retries=0, - exception_cb=dsaz.retry_on_url_exc, + retries=10, + exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, ), ] @@ -3886,8 +4090,8 @@ def test_running_pps(self): "api-version=2021-08-01&extended=true", timeout=2, headers={"Metadata": "true"}, - retries=0, - exception_cb=dsaz.retry_on_url_exc, + retries=10, + exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, ), mock.call( @@ -3904,8 +4108,8 @@ def test_running_pps(self): "api-version=2021-08-01&extended=true", timeout=2, headers={"Metadata": "true"}, - retries=0, - exception_cb=dsaz.retry_on_url_exc, + retries=10, + exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, ), ] @@ -4005,13 +4209,13 @@ def test_savable_pps(self): "api-version=2021-08-01&extended=true", timeout=2, headers={"Metadata": "true"}, - retries=0, - exception_cb=dsaz.retry_on_url_exc, + retries=10, + exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, ), mock.call( "http://169.254.169.254/metadata/instance/network?" - "api-version=2019-06-01", + "api-version=2021-08-01", timeout=2, headers={"Metadata": "true"}, retries=0, @@ -4032,8 +4236,8 @@ def test_savable_pps(self): "api-version=2021-08-01&extended=true", timeout=2, headers={"Metadata": "true"}, - retries=0, - exception_cb=dsaz.retry_on_url_exc, + retries=10, + exception_cb=dsaz.imds_readurl_exception_callback, infinite=False, ), ]