diff --git a/cloudinit/sources/DataSourceEc2.py b/cloudinit/sources/DataSourceEc2.py index a030b4987b9..961c5090dc3 100644 --- a/cloudinit/sources/DataSourceEc2.py +++ b/cloudinit/sources/DataSourceEc2.py @@ -55,7 +55,11 @@ class DataSourceEc2(sources.DataSource): # Default metadata urls that will be used if none are provided # They will be checked for 'resolveability' and some of the # following may be discarded if they do not resolve - metadata_urls = ["http://169.254.169.254", "http://instance-data.:8773"] + metadata_urls = [ + "http://169.254.169.254", + "http://[fd00:ec2::254]", + "http://instance-data.:8773", + ] # The minimum supported metadata_version from the ec2 metadata apis min_metadata_version = "2009-04-04" @@ -253,6 +257,7 @@ def _maybe_fetch_api_token(self, mdurls, timeout=None, max_wait=None): exception_cb=self._imds_exception_cb, request_method=request_method, headers_redact=AWS_TOKEN_REDACT, + connect_synchronously=False, ) except uhelp.UrlError: # We use the raised exception to interupt the retry loop. 
diff --git a/cloudinit/url_helper.py b/cloudinit/url_helper.py index b827eba9140..ebb6bc4a0ef 100644 --- a/cloudinit/url_helper.py +++ b/cloudinit/url_helper.py @@ -11,12 +11,15 @@ import copy import json import os +import threading import time +from concurrent.futures import ThreadPoolExecutor, TimeoutError, as_completed from email.utils import parsedate from errno import ENOENT from functools import partial from http.client import NOT_FOUND from itertools import count +from typing import Any, Callable, List, Tuple from urllib.parse import quote, urlparse, urlunparse import requests @@ -187,7 +190,7 @@ def readurl( session=None, infinite=False, log_req_resp=True, - request_method=None, + request_method="", ) -> UrlResponse: """Wrapper around requests.Session to read the url and retry if necessary @@ -347,17 +350,119 @@ def _cb(url): raise excps[-1] +def _run_func_with_delay( + func: Callable[..., Any], + addr: str, + timeout: int, + event: threading.Event, + delay: float = None, +) -> Any: + """Execute func with optional delay""" + if delay: + + # event returns True iff the flag is set to true: indicating that + # another thread has already completed successfully, no need to try + # again - exit early + if event.wait(timeout=delay): + return + return func(addr, timeout) + + +def dual_stack( + func: Callable[..., Any], + addresses: List[str], + stagger_delay: float = 0.150, + timeout: int = 10, +) -> Tuple: + """execute multiple callbacks in parallel + + Run blocking func against two different addresses staggered with a + delay. 
The first call to return successfully is returned from this + function and remaining unfinished calls are cancelled if they have not + yet started + """ + return_result = None + returned_address = None + last_exception = None + exceptions = [] + is_done = threading.Event() + + # future work: add cancel_futures to Python stdlib ThreadPoolExecutor + # context manager implementation + # + # for now we don't use this feature since it only supports python >3.8 + # and doesn't provide a context manager and only marginal benefit + executor = ThreadPoolExecutor(max_workers=len(addresses)) + try: + futures = { + executor.submit( + _run_func_with_delay, + func=func, + addr=addr, + timeout=timeout, + event=is_done, + delay=(i * stagger_delay), + ): addr + for i, addr in enumerate(addresses) + } + + # handle returned requests in order of completion + for future in as_completed(futures, timeout=timeout): + + returned_address = futures[future] + return_exception = future.exception() + if return_exception: + last_exception = return_exception + else: + return_result = future.result() + if return_result: + + # communicate to other threads that they do not need to + # try: this thread has already succeeded + is_done.set() + return (returned_address, return_result) + + # No success, return the last exception but log them all for + # debugging + if last_exception: + LOG.warning( + "Exception(s) %s during request to %s, " + "raising last exception", + " ".join(exceptions), + returned_address, + ) + raise last_exception + else: + LOG.error("Empty result for address %s", returned_address) + raise ValueError("No result returned") + + # when max_wait expires, log but don't throw (retries happen) + except TimeoutError: + LOG.warning( + "Timed out waiting for addresses: %s, " + "exception(s) raised while waiting: %s", + " ".join(addresses), + " ".join(exceptions), + ) + finally: + executor.shutdown(wait=False) + + return (returned_address, return_result) + + def wait_for_url( urls, 
max_wait=None, timeout=None, - status_cb=None, - headers_cb=None, + status_cb: Callable = LOG.debug, # some sources use different log levels + headers_cb: Callable = None, headers_redact=None, - sleep_time=1, - exception_cb=None, - sleep_time_cb=None, - request_method=None, + sleep_time: int = 1, + exception_cb: Callable = None, + sleep_time_cb: Callable = None, + request_method: str = "", + connect_synchronously: bool = True, + async_delay: float = 0.150, ): """ urls: a list of urls to try @@ -375,6 +480,8 @@ def wait_for_url( sleep_time_cb: call method with 2 arguments (response, loop_n) that generates the next sleep time. request_method: indicate the type of HTTP request, GET, PUT, or POST + connect_synchronously: if false, enables executing requests in parallel + async_delay: delay before parallel metadata requests, see RFC 6555 returns: tuple of (url, response contents), on failure, (False, None) the idea of this routine is to wait for the EC2 metadata service to @@ -394,31 +501,94 @@ def wait_for_url( A value of None for max_wait will retry indefinitely. 
""" - start_time = time.time() - def log_status_cb(msg, exc=None): - LOG.debug(msg) - - if status_cb is None: - status_cb = log_status_cb + def default_sleep_time(_, loop_number: int): + return int(loop_number / 5) + 1 def timeup(max_wait, start_time): + """Check if time is up based on start time and max wait""" if max_wait is None: return False return (max_wait <= 0) or (time.time() - start_time > max_wait) - loop_n = 0 - response = None - while True: - if sleep_time_cb is not None: - sleep_time = sleep_time_cb(response, loop_n) + def handle_url_response(response, url): + """Map requests response code/contents to internal "UrlError" type""" + if not response.contents: + reason = "empty response [%s]" % (response.code) + url_exc = UrlError( + ValueError(reason), + code=response.code, + headers=response.headers, + url=url, + ) + elif not response.ok(): + reason = "bad status code [%s]" % (response.code) + url_exc = UrlError( + ValueError(reason), + code=response.code, + headers=response.headers, + url=url, + ) else: - sleep_time = int(loop_n / 5) + 1 + reason = "" + url_exc = None + return (url_exc, reason) + + def read_url_handle_exceptions( + url_reader_cb, urls, start_time, exc_cb, log_cb + ): + """Execute request, handle response, optionally log exception""" + reason = "" + url = None + try: + url, response = url_reader_cb(urls) + url_exc, reason = handle_url_response(response, url) + if not url_exc: + return (url, response) + except UrlError as e: + reason = "request error [%s]" % e + url_exc = e + except Exception as e: + reason = "unexpected error [%s]" % e + url_exc = e + time_taken = int(time.time() - start_time) + max_wait_str = "%ss" % max_wait if max_wait else "unlimited" + status_msg = "Calling '%s' failed [%s/%s]: %s" % ( + url, + time_taken, + max_wait_str, + reason, + ) + log_cb(status_msg) + if exc_cb: + # This can be used to alter the headers that will be sent + # in the future, for example this is what the MAAS datasource + # does. 
+ exc_cb(msg=status_msg, exception=url_exc) + + def read_url_cb(url, timeout): + return readurl( + url, + headers={} if headers_cb is None else headers_cb(url), + headers_redact=headers_redact, + timeout=timeout, + check_status=False, + request_method=request_method, + ) + + def read_url_serial(start_time, timeout, exc_cb, log_cb): + """iterate over list of urls, request each one and handle responses + and thrown exceptions individually per url + """ + + def url_reader_serial(url): + return (url, read_url_cb(url, timeout)) + for url in urls: now = time.time() if loop_n != 0: if timeup(max_wait, start_time): - break + return if ( max_wait is not None and timeout @@ -427,61 +597,52 @@ def timeup(max_wait, start_time): # shorten timeout to not run way over max_time timeout = int((start_time + max_wait) - now) - reason = "" - url_exc = None - try: - if headers_cb is not None: - headers = headers_cb(url) - else: - headers = {} - - response = readurl( - url, - headers=headers, - headers_redact=headers_redact, - timeout=timeout, - check_status=False, - request_method=request_method, - ) - if not response.contents: - reason = "empty response [%s]" % (response.code) - url_exc = UrlError( - ValueError(reason), - code=response.code, - headers=response.headers, - url=url, - ) - elif not response.ok(): - reason = "bad status code [%s]" % (response.code) - url_exc = UrlError( - ValueError(reason), - code=response.code, - headers=response.headers, - url=url, - ) - else: - return url, response.contents - except UrlError as e: - reason = "request error [%s]" % e - url_exc = e - except Exception as e: - reason = "unexpected error [%s]" % e - url_exc = e - - time_taken = int(time.time() - start_time) - max_wait_str = "%ss" % max_wait if max_wait else "unlimited" - status_msg = "Calling '%s' failed [%s/%s]: %s" % ( - url, - time_taken, - max_wait_str, - reason, + out = read_url_handle_exceptions( + url_reader_serial, url, start_time, exc_cb, log_cb ) - status_cb(status_msg) - if 
exception_cb: - # This can be used to alter the headers that will be sent - # in the future, for example this is what the MAAS datasource - # does. - exception_cb(msg=status_msg, exception=url_exc) + if out: + return out + + def read_url_parallel(start_time, timeout, exc_cb, log_cb): + """pass list of urls to dual_stack which sends requests in parallel + handle response and exceptions of the first endpoint to respond + """ + url_reader_parallel = partial( + dual_stack, + read_url_cb, + stagger_delay=async_delay, + timeout=timeout, + ) + out = read_url_handle_exceptions( + url_reader_parallel, urls, start_time, exc_cb, log_cb + ) + if out: + return out + + start_time = time.time() + + # Dual-stack support factored out serial and parallel execution paths to + # allow the retry loop logic to exist separately from the http calls. + # Serial execution should be fundamentally the same as before, but with a + # layer of indirection so that the parallel dual-stack path may use the + # same max timeout logic. 
+ do_read_url = ( + read_url_serial if connect_synchronously else read_url_parallel + ) + + calculate_sleep_time = ( + default_sleep_time if not sleep_time_cb else sleep_time_cb + ) + + loop_n: int = 0 + response = None + while True: + sleep_time = calculate_sleep_time(response, loop_n) + + url = do_read_url(start_time, timeout, exception_cb, status_cb) + if url: + address, response = url + return (address, response.contents) if timeup(max_wait, start_time): break @@ -492,6 +653,11 @@ def timeup(max_wait, start_time): ) time.sleep(sleep_time) + # shorten timeout to not run way over max_time + # timeout=0.0 causes exceptions in urllib, set to None if zero + timeout = int((start_time + max_wait) - time.time()) or None + + LOG.error("Timed out, no response from urls: %s", urls) return False, None diff --git a/integration-requirements.txt b/integration-requirements.txt index 8329eeecefb..ad41a82946c 100644 --- a/integration-requirements.txt +++ b/integration-requirements.txt @@ -1,5 +1,5 @@ # PyPI requirements for cloud-init integration testing # https://cloudinit.readthedocs.io/en/latest/topics/integration_tests.html # -pycloudlib @ git+https://github.com/canonical/pycloudlib.git@44206bb95c49901d994c9eb772eba07f2a1b6661 +pycloudlib @ git+https://github.com/canonical/pycloudlib.git@a2fde24361eeb6ad96db2a1ccb5cd70c7d76aa7f pytest diff --git a/test-requirements.txt b/test-requirements.txt index 06dfbbec156..44a92430d92 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -6,3 +6,4 @@ pytest-cov # Only really needed on older versions of python setuptools jsonschema +responses diff --git a/tests/integration_tests/clouds.py b/tests/integration_tests/clouds.py index e5fe56d1565..0e2e1deb25c 100644 --- a/tests/integration_tests/clouds.py +++ b/tests/integration_tests/clouds.py @@ -207,9 +207,18 @@ def _get_cloud_instance(self): def _perform_launch(self, launch_kwargs, **kwargs): """Use a dual-stack VPC for cloud-init integration testing.""" - launch_kwargs["vpc"] 
= self.cloud_instance.get_or_create_vpc( - name="ec2-cloud-init-integration" - ) + if "vpc" not in launch_kwargs: + launch_kwargs["vpc"] = self.cloud_instance.get_or_create_vpc( + name="ec2-cloud-init-integration" + ) + # Enable IPv6 metadata at http://[fd00:ec2::254] + if "Ipv6AddressCount" not in launch_kwargs: + launch_kwargs["Ipv6AddressCount"] = 1 + if "MetadataOptions" not in launch_kwargs: + launch_kwargs["MetadataOptions"] = {} + if "HttpProtocolIpv6" not in launch_kwargs["MetadataOptions"]: + launch_kwargs["MetadataOptions"] = {"HttpProtocolIpv6": "enabled"} + pycloudlib_instance = self.cloud_instance.launch(**launch_kwargs) return pycloudlib_instance diff --git a/tests/integration_tests/datasources/test_ec2.py b/tests/integration_tests/datasources/test_ec2.py new file mode 100644 index 00000000000..8cde4dc9e08 --- /dev/null +++ b/tests/integration_tests/datasources/test_ec2.py @@ -0,0 +1,43 @@ +import re + +import pytest + +from tests.integration_tests.instances import IntegrationInstance + + +def _test_crawl(client, ip): + assert client.execute("cloud-init clean --logs").ok + assert client.execute("cloud-init init --local").ok + log = client.read_from_file("/var/log/cloud-init.log") + assert f"Using metadata source: '{ip}'" in log + result = re.findall( + r"Crawl of metadata service took (\d+.\d+) seconds", log + ) + if len(result) != 1: + pytest.fail(f"Expected 1 metadata crawl time, got {result}") + # 20 would still be a crazy long time for metadata service to crawl, + # but it's short enough to know we're not waiting for a response + assert float(result[0]) < 20 + + +@pytest.mark.ec2 +def test_dual_stack(client: IntegrationInstance): + # Drop IPv4 responses + assert client.execute("iptables -I INPUT -s 169.254.169.254 -j DROP").ok + _test_crawl(client, "http://[fd00:ec2::254]") + + # Block IPv4 requests + assert client.execute("iptables -I OUTPUT -d 169.254.169.254 -j REJECT").ok + _test_crawl(client, "http://[fd00:ec2::254]") + + # Re-enable IPv4 + 
assert client.execute("iptables -D OUTPUT -d 169.254.169.254 -j REJECT").ok + assert client.execute("iptables -D INPUT -s 169.254.169.254 -j DROP").ok + + # Drop IPv6 responses + assert client.execute("ip6tables -I INPUT -s fd00:ec2::254 -j DROP").ok + _test_crawl(client, "http://169.254.169.254") + + # Block IPv6 requests + assert client.execute("ip6tables -I OUTPUT -d fd00:ec2::254 -j REJECT").ok + _test_crawl(client, "http://169.254.169.254") diff --git a/tests/unittests/sources/test_ec2.py b/tests/unittests/sources/test_ec2.py index 7c8a5ea5ef9..a192bd0488f 100644 --- a/tests/unittests/sources/test_ec2.py +++ b/tests/unittests/sources/test_ec2.py @@ -2,10 +2,11 @@ import copy import json +import threading from unittest import mock -import httpretty import requests +import responses from cloudinit import helpers from cloudinit.sources import DataSourceEc2 as ec2 @@ -300,9 +301,10 @@ def register_helper(register, base_url, body): register(base_url, "not found", status=404) def myreg(*argc, **kwargs): - url = argc[0] - method = httpretty.PUT if ec2.API_TOKEN_ROUTE in url else httpretty.GET - return httpretty.register_uri(method, *argc, **kwargs) + url, body = argc + method = responses.PUT if ec2.API_TOKEN_ROUTE in url else responses.GET + status = kwargs.get("status", 200) + return responses.add(method, url, body, status=status) register_helper(myreg, base_url, data) @@ -339,6 +341,15 @@ def _setup_ds(self, sys_cfg, platform_data, md, md_version=None): if sys_cfg is None: sys_cfg = {} ds = self.datasource(sys_cfg=sys_cfg, distro=distro, paths=paths) + event = threading.Event() + p = mock.patch("time.sleep", event.wait) + p.start() + + def _mock_sleep(): + event.set() + p.stop() + + self.addCleanup(_mock_sleep) if not md_version: md_version = ds.min_metadata_version if platform_data is not None: @@ -382,6 +393,7 @@ def _setup_ds(self, sys_cfg, platform_data, md, md_version=None): register_mock_metaserver(instance_id_url, None) return ds + @responses.activate def 
test_network_config_property_returns_version_2_network_data(self): """network_config property returns network version 2 for metadata""" ds = self._setup_ds( @@ -416,6 +428,7 @@ def test_network_config_property_returns_version_2_network_data(self): m_get_mac.return_value = mac1 self.assertEqual(expected, ds.network_config) + @responses.activate def test_network_config_property_set_dhcp4(self): """network_config property configures dhcp4 on nics with local-ipv4s. @@ -454,6 +467,7 @@ def test_network_config_property_set_dhcp4(self): m_get_mac.return_value = mac1 self.assertEqual(expected, ds.network_config) + @responses.activate def test_network_config_property_secondary_private_ips(self): """network_config property configures any secondary ipv4 addresses. @@ -497,6 +511,7 @@ def test_network_config_property_secondary_private_ips(self): m_get_mac.return_value = mac1 self.assertEqual(expected, ds.network_config) + @responses.activate def test_network_config_property_is_cached_in_datasource(self): """network_config property is cached in DataSourceEc2.""" ds = self._setup_ds( @@ -508,6 +523,7 @@ def test_network_config_property_is_cached_in_datasource(self): self.assertEqual({"cached": "data"}, ds.network_config) @mock.patch("cloudinit.net.dhcp.maybe_perform_dhcp_discovery") + @responses.activate def test_network_config_cached_property_refreshed_on_upgrade(self, m_dhcp): """Refresh the network_config Ec2 cache if network key is absent. @@ -522,6 +538,17 @@ def test_network_config_cached_property_refreshed_on_upgrade(self, m_dhcp): md={"md": old_metadata}, ) self.assertTrue(ds.get_data()) + + # Workaround https://github.com/getsentry/responses/issues/212 + # Can be removed when requests < 0.17.0 is no longer tested + # i.e. 
after Focal is EOL + if hasattr(responses.mock, "_urls"): + for index, url in enumerate(responses.mock._urls): + if url["url"].startswith( + "http://169.254.169.254/2009-04-04/meta-data/" + ): + del responses.mock._urls[index] + # Provide new revision of metadata that contains network data register_mock_metaserver( "http://169.254.169.254/2009-04-04/meta-data/", DEFAULT_METADATA @@ -550,6 +577,7 @@ def test_network_config_cached_property_refreshed_on_upgrade(self, m_dhcp): } self.assertEqual(expected, ds.network_config) + @responses.activate def test_ec2_get_instance_id_refreshes_identity_on_upgrade(self): """get_instance-id gets DataSourceEc2Local.identity if not present. @@ -569,10 +597,11 @@ def test_ec2_get_instance_id_refreshes_identity_on_upgrade(self): ] + ds.extended_metadata_versions for ver in all_versions[:-1]: register_mock_metaserver( - "http://169.254.169.254/{0}/meta-data/instance-id".format(ver), + "http://[fd00:ec2::254]/{0}/meta-data/instance-id".format(ver), None, ) - ds.metadata_address = "http://169.254.169.254" + + ds.metadata_address = "http://[fd00:ec2::254]" register_mock_metaserver( "{0}/{1}/meta-data/".format(ds.metadata_address, all_versions[-1]), DEFAULT_METADATA, @@ -587,6 +616,7 @@ def test_ec2_get_instance_id_refreshes_identity_on_upgrade(self): ds.metadata = DEFAULT_METADATA self.assertEqual("my-identity-id", ds.get_instance_id()) + @responses.activate def test_classic_instance_true(self): """If no vpc-id in metadata, is_classic_instance must return true.""" md_copy = copy.deepcopy(DEFAULT_METADATA) @@ -603,6 +633,7 @@ def test_classic_instance_true(self): self.assertTrue(ds.get_data()) self.assertTrue(ds.is_classic_instance()) + @responses.activate def test_classic_instance_false(self): """If vpc-id in metadata, is_classic_instance must return false.""" ds = self._setup_ds( @@ -613,6 +644,7 @@ def test_classic_instance_false(self): self.assertTrue(ds.get_data()) self.assertFalse(ds.is_classic_instance()) + @responses.activate def 
test_aws_inaccessible_imds_service_fails_with_retries(self): """Inaccessibility of http://169.254.169.254 are retried.""" ds = self._setup_ds( @@ -629,15 +661,37 @@ def test_aws_inaccessible_imds_service_fails_with_retries(self): mock_success.ok.return_value = True with mock.patch("cloudinit.url_helper.readurl") as m_readurl: - m_readurl.side_effect = (conn_error, conn_error, mock_success) + # yikes, this endpoint needs help + m_readurl.side_effect = ( + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + conn_error, + mock_success, + ) with mock.patch("cloudinit.url_helper.time.sleep"): self.assertTrue(ds.wait_for_metadata_service()) # Just one /latest/api/token request - self.assertEqual(3, len(m_readurl.call_args_list)) + self.assertEqual(19, len(m_readurl.call_args_list)) for readurl_call in m_readurl.call_args_list: self.assertIn("latest/api/token", readurl_call[0][0]) + @responses.activate def test_aws_token_403_fails_without_retries(self): """Verify that 403s fetching AWS tokens are not retried.""" ds = self._setup_ds( @@ -645,27 +699,21 @@ def test_aws_token_403_fails_without_retries(self): sys_cfg={"datasource": {"Ec2": {"strict_id": False}}}, md=None, ) + token_url = self.data_url("latest", data_item="api/token") - httpretty.register_uri(httpretty.PUT, token_url, body={}, status=403) + responses.add(responses.PUT, token_url, status=403) self.assertFalse(ds.get_data()) # Just one /latest/api/token request logs = self.logs.getvalue() - failed_put_log = '"PUT /latest/api/token HTTP/1.1" 403 0' expected_logs = [ "WARNING: Ec2 IMDS endpoint returned a 403 error. HTTP endpoint is" " disabled. 
Aborting.", "WARNING: IMDS's HTTP endpoint is probably disabled", - failed_put_log, ] for log in expected_logs: self.assertIn(log, logs) - self.assertEqual( - 1, - len( - [line for line in logs.splitlines() if failed_put_log in line] - ), - ) + @responses.activate def test_aws_token_redacted(self): """Verify that aws tokens are redacted when logged.""" ds = self._setup_ds( @@ -684,6 +732,7 @@ def test_aws_token_redacted(self): self.assertEqual(83, len(logs_with_redacted)) self.assertEqual(0, len(logs_with_token)) + @responses.activate @mock.patch("cloudinit.net.dhcp.maybe_perform_dhcp_discovery") def test_valid_platform_with_strict_true(self, m_dhcp): """Valid platform data should return true with strict_id true.""" @@ -699,6 +748,7 @@ def test_valid_platform_with_strict_true(self, m_dhcp): self.assertEqual("ec2", ds.platform_type) self.assertEqual("metadata (%s)" % ds.metadata_address, ds.subplatform) + @responses.activate def test_valid_platform_with_strict_false(self): """Valid platform data should return true with strict_id false.""" ds = self._setup_ds( @@ -709,6 +759,7 @@ def test_valid_platform_with_strict_false(self): ret = ds.get_data() self.assertTrue(ret) + @responses.activate def test_unknown_platform_with_strict_true(self): """Unknown platform data with strict_id true should return False.""" uuid = "ab439480-72bf-11d3-91fc-b8aded755F9a" @@ -720,6 +771,7 @@ def test_unknown_platform_with_strict_true(self): ret = ds.get_data() self.assertFalse(ret) + @responses.activate def test_unknown_platform_with_strict_false(self): """Unknown platform data with strict_id false should return True.""" uuid = "ab439480-72bf-11d3-91fc-b8aded755F9a" @@ -731,6 +783,7 @@ def test_unknown_platform_with_strict_false(self): ret = ds.get_data() self.assertTrue(ret) + @responses.activate def test_ec2_local_returns_false_on_non_aws(self): """DataSourceEc2Local returns False when platform is not AWS.""" self.datasource = ec2.DataSourceEc2Local @@ -758,6 +811,7 @@ def 
test_ec2_local_returns_false_on_non_aws(self): self.assertIn(message, self.logs.getvalue()) @mock.patch("cloudinit.sources.DataSourceEc2.util.is_FreeBSD") + @responses.activate def test_ec2_local_returns_false_on_bsd(self, m_is_freebsd): """DataSourceEc2Local returns False on BSD. @@ -781,6 +835,7 @@ def test_ec2_local_returns_false_on_bsd(self, m_is_freebsd): @mock.patch("cloudinit.net.find_fallback_nic") @mock.patch("cloudinit.net.dhcp.maybe_perform_dhcp_discovery") @mock.patch("cloudinit.sources.DataSourceEc2.util.is_FreeBSD") + @responses.activate def test_ec2_local_performs_dhcp_on_non_bsd( self, m_is_bsd, m_dhcp, m_fallback_nic, m_net ): @@ -822,6 +877,7 @@ def test_ec2_local_performs_dhcp_on_non_bsd( ) self.assertIn("Crawl of metadata service took", self.logs.getvalue()) + @responses.activate def test_get_instance_tags(self): ds = self._setup_ds( platform_data=self.valid_platform_data, diff --git a/tests/unittests/test_url_helper.py b/tests/unittests/test_url_helper.py index 059809d9e27..5e09221938c 100644 --- a/tests/unittests/test_url_helper.py +++ b/tests/unittests/test_url_helper.py @@ -1,19 +1,25 @@ # This file is part of cloud-init. See LICENSE file for license information. 
def assert_time(func, max_time=1):
    """Run ``func`` and assert its measured time stays below ``max_time``.

    The async tests below are expected to return (or be cancelled) almost
    immediately, well inside the default bound of one second.
    It is possible that this could yield a false positive, but this should
    basically never happen (esp under normal system load).

    NOTE(review): the measurement uses ``process_time()``, which reports CPU
    time and does not include time elapsed during sleep, so time spent
    blocked in waits is not counted against ``max_time`` - confirm that this
    is intended.

    :param func: zero-argument callable to execute
    :param max_time: upper bound, in seconds, for the measured duration
    :return: whatever ``func`` returns
    """
    start = process_time()
    try:
        out = func()
    finally:
        diff = process_time() - start
        assert diff < max_time
    return out


# Shared flag the slow fake endpoints block on; set at the end of each test
# so that waiting worker threads are released.
event = Event()


class TestDualStack:
    """Async testing suggestions welcome - these all rely on time-bounded
    assertions (via threading.Event) to prove ordering
    """

    @pytest.mark.parametrize(
        "func,"
        "addresses,"
        "stagger_delay,"
        "timeout,"
        "expected_val,"
        "expected_exc",
        [
            # Assert order based on timeout
            (lambda x, _: x, ("one", "two"), 1, 1, "one", None),
            # Assert timeout results in (None, None)
            (lambda _a, _b: event.wait(1), ("one", "two"), 1, 0, None, None),
            # Assert that exception in func is raised if all threads
            # raise exception
            # currently if all threads experience exception
            # dual_stack() logs an error containing all exceptions
            # but only raises the last exception to occur
            (
                lambda _a, _b: 1 / 0,
                ("one", "two"),
                0,
                1,
                None,
                ZeroDivisionError,
            ),
            # Verify "best effort behavior"
            # dual_stack will temporarily ignore an exception in any of the
            # request threads in hopes that a later thread will succeed
            # this behavior is intended to allow a requests.ConnectionError
            # exception from one endpoint to occur without preventing another
            # thread from succeeding
            (
                lambda a, _b: 1 / 0 if a == "one" else a,
                ("one", "two"),
                0,
                1,
                "two",
                None,
            ),
            # Assert that exception in func is only raised
            # if neither thread gets a valid result
            (
                lambda a, _b: 1 / 0 if a == "two" else a,
                ("one", "two"),
                0,
                1,
                "one",
                None,
            ),
            # simulate a slow response to verify correct order
            (
                lambda x, _: event.wait(1) if x != "two" else x,
                ("one", "two"),
                0,
                1,
                "two",
                None,
            ),
            # simulate a slow response to verify correct order
            (
                lambda x, _: event.wait(1) if x != "tri" else x,
                ("one", "two", "tri"),
                0,
                1,
                "tri",
                None,
            ),
        ],
    )
    def test_dual_stack(
        self,
        func,
        addresses,
        stagger_delay,
        timeout,
        expected_val,
        expected_exc,
    ):
        """Assert various failure modes behave as expected"""
        event.clear()

        gen = partial(
            dual_stack,
            func,
            addresses,
            stagger_delay=stagger_delay,
            timeout=timeout,
        )
        try:
            if expected_exc:
                # nothing after the raising call would run inside this
                # context manager, so there is no result to compare
                with pytest.raises(expected_exc):
                    assert_time(gen)
            else:
                _, result = assert_time(gen)
                assert expected_val == result
        finally:
            # always release workers still blocked on the shared event,
            # even when an assertion above failed
            event.set()

    def test_dual_stack_staggered(self):
        """Assert expected call intervals occur"""
        stagger = 0.1
        with mock.patch(M_PATH + "_run_func_with_delay") as delay_func:
            dual_stack(
                lambda x, _y: x,
                ["you", "and", "me", "and", "dog"],
                stagger_delay=stagger,
                timeout=1,
            )

            # ensure that stagger delay for each subsequent call is:
            # [ 0 * N, 1 * N, 2 * N, 3 * N, 4 * N] where N = stagger
            # it appears that without an explicit wait/join we can't assert
            # number of calls
            for delay, call_item in enumerate(delay_func.call_args_list):
                _, kwargs = call_item
                assert stagger * delay == kwargs.get("delay")


ADDR1 = "https://addr1/"
SLEEP1 = "https://sleep1/"
SLEEP2 = "https://sleep2/"


class TestUrlHelper:
    """Exercise wait_for_url() in parallel mode against fake endpoints."""

    success = "SUCCESS"
    fail = "FAIL"
    event = Event()

    @classmethod
    def response_wait(cls, _request):
        """Fake slow endpoint: block up to 0.1s, then answer with a 500."""
        cls.event.wait(0.1)
        return (500, {"request-id": "1"}, cls.fail)

    @classmethod
    def response_nowait(cls, _request):
        """Fake healthy endpoint: answer immediately with a 200."""
        return (200, {"request-id": "0"}, cls.success)

    @pytest.mark.parametrize(
        "addresses," "expected_address_index," "response,",
        [
            # Use timeout to test ordering happens as expected
            ((ADDR1, SLEEP1), 0, "SUCCESS"),
            ((SLEEP1, ADDR1), 1, "SUCCESS"),
            ((SLEEP1, SLEEP2, ADDR1), 2, "SUCCESS"),
            ((ADDR1, SLEEP1, SLEEP2), 0, "SUCCESS"),
        ],
    )
    @responses.activate
    def test_order(self, addresses, expected_address_index, response):
        """Check that the first response gets returned. Simulate a
        non-responding endpoint with a response that waits before failing
        (see response_wait).

        If this test proves flaky, increase wait time. Since it is async,
        increasing wait time for the non-responding endpoint should not
        increase total test time, assuming async_delay=0 is used and at least
        one non-waiting endpoint is registered with responses.
        Subsequent tests will continue execution after the first response is
        received.
        """
        self.event.clear()
        for address in set(addresses):
            responses.add_callback(
                responses.GET,
                address,
                callback=(
                    self.response_wait
                    if "sleep" in address
                    else self.response_nowait
                ),
                content_type="application/json",
            )

        # Use async_delay=0.0 to avoid adding unnecessary time to tests
        # In practice a value such as 0.150 is used
        url, response_contents = wait_for_url(
            urls=addresses,
            max_wait=1,
            timeout=1,
            connect_synchronously=False,
            async_delay=0.0,
        )
        self.event.set()

        # The endpoint that answers successfully must be the one returned
        assert addresses[expected_address_index] == url
        assert response.encode() == response_contents

    @responses.activate
    def test_timeout(self):
        """If no endpoint responds in time, expect no response"""

        self.event.clear()
        addresses = [SLEEP1, SLEEP2]
        for address in set(addresses):
            responses.add_callback(
                responses.GET,
                address,
                callback=(
                    self.response_wait
                    if "sleep" in address
                    else self.response_nowait
                ),
                content_type="application/json",
            )

        # Use async_delay=0.0 to avoid adding unnecessary time to tests
        url, response_contents = wait_for_url(
            urls=addresses,
            max_wait=1,
            timeout=1,
            connect_synchronously=False,
            async_delay=0,
        )
        self.event.set()
        assert not url
        assert not response_contents