From b9e35b940ea13e46cf0875674d9353f7d91f3c0f Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Tue, 9 Jun 2020 02:51:55 +0000 Subject: [PATCH 1/4] Refactor Azure report ready code. This PR refactors Azure report ready code to include more robust tests and telemetry. --- cloudinit/sources/DataSourceAzure.py | 54 ++- cloudinit/sources/helpers/azure.py | 403 +++++++++++++----- .../test_datasource/test_azure_helper.py | 372 +++++++++++++--- 3 files changed, 656 insertions(+), 173 deletions(-) diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 89312b9e9bd..ca63f3979dc 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -426,17 +426,19 @@ def crawl_metadata(self): ret = load_azure_ds_dir(cdev) except NonAzureDataSource: - report_diagnostic_event( - "Did not find Azure data source in %s" % cdev) + msg = "Did not find Azure data source in %s" % cdev + LOG.debug(msg) + report_diagnostic_event(msg) continue except BrokenAzureDataSource as exc: msg = 'BrokenAzureDataSource: %s' % exc + LOG.error(msg) report_diagnostic_event(msg) raise sources.InvalidMetaDataException(msg) except util.MountFailedError: msg = '%s was not mountable' % cdev - report_diagnostic_event(msg) LOG.warning(msg) + report_diagnostic_event(msg) continue perform_reprovision = reprovision or self._should_reprovision(ret) @@ -464,6 +466,7 @@ def crawl_metadata(self): if not found: msg = 'No Azure metadata found' + LOG.error(msg) report_diagnostic_event(msg) raise sources.InvalidMetaDataException(msg) @@ -488,8 +491,9 @@ def crawl_metadata(self): azure_ds_reporter) as lease: self._report_ready(lease=lease) except Exception as e: - report_diagnostic_event( - "exception while reporting ready: %s" % e) + msg = "exception while reporting ready: %s" % e + LOG.error(msg) + report_diagnostic_event(msg) raise return crawled_data @@ -617,10 +621,11 @@ def exc_cb(msg, exception): else: # If we get an exception while trying to call IMDS, we call # DHCP and setup the ephemeral network to acquire a new IP. - report_diagnostic_event("poll IMDS with %s failed. " - "Exception: %s and code: %s" % - (msg, exception.cause, - exception.code)) + evt_msg = ("poll IMDS with %s failed. " + "Exception: %s and code: %s" % + (msg, exception.cause, exception.code)) + LOG.warning(evt_msg) + report_diagnostic_event(evt_msg) return False LOG.debug("poll IMDS failed with an unexpected exception: %s", @@ -644,8 +649,8 @@ def exc_cb(msg, exception): try: nl_sock = netlink.create_bound_netlink_socket() except netlink.NetlinkCreateSocketError as e: - report_diagnostic_event(e) LOG.warning(e) + report_diagnostic_event(e) self._ephemeral_dhcp_ctx.clean_network() break @@ -665,8 +670,8 @@ def exc_cb(msg, exception): netlink.wait_for_media_disconnect_connect( nl_sock, lease['interface']) except AssertionError as error: - report_diagnostic_event(error) LOG.error(error) + report_diagnostic_event(error) break vnet_switched = True @@ -692,10 +697,12 @@ def exc_cb(msg, exception): nl_sock.close() if vnet_switched: - report_diagnostic_event("attempted dhcp %d times after reuse" % - dhcp_attempts) - report_diagnostic_event("polled imds %d times after reuse" % - self.imds_poll_counter) + msg = "attempted dhcp %d times after reuse" % dhcp_attempts + LOG.debug(msg) + report_diagnostic_event(msg) + msg = "polled imds %d times after reuse" % self.imds_poll_counter + LOG.debug(msg) + report_diagnostic_event(msg) return return_val @@ -768,12 +775,12 @@ def _negotiate(self): try: fabric_data = metadata_func() except Exception as e: + LOG.error( + "Error communicating with Azure fabric; You may experience " + "connectivity issues.", exc_info=True) report_diagnostic_event( "Error communicating with Azure fabric; You may experience " "connectivity issues: %s" % e) - LOG.warning( - "Error communicating with Azure fabric; You may experience " - "connectivity issues.", exc_info=True) return False util.del_file(REPORTED_READY_MARKER_FILE) @@ -1132,6 +1139,7 @@ def read_azure_ovf(contents): dom = minidom.parseString(contents) except Exception as e: error_str = "Invalid ovf-env.xml: %s" % e + LOG.error(error_str) report_diagnostic_event(error_str) raise BrokenAzureDataSource(error_str) @@ -1414,7 +1422,9 @@ def get_metadata_from_imds(fallback_nic, retries): azure_ds_reporter, fallback_nic): return util.log_time(**kwargs) except Exception as e: - report_diagnostic_event("exception while getting metadata: %s" % e) + msg = "exception while getting metadata: %s" % e + LOG.error(msg) + report_diagnostic_event(msg) raise @@ -1429,15 +1439,15 @@ def _get_metadata_from_imds(retries): retries=retries, exception_cb=retry_on_url_exc) except Exception as e: msg = 'Ignoring IMDS instance metadata: %s' % e + LOG.error(msg) report_diagnostic_event(msg) - LOG.debug(msg) return {} try: return util.load_json(str(response)) except json.decoder.JSONDecodeError as e: - report_diagnostic_event('non-json imds response' % e) - LOG.warning( + LOG.error( 'Ignoring non-json IMDS instance metadata: %s', str(response)) + report_diagnostic_event('non-json imds response' % e) return {} diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 7bace8ca294..5674111f857 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -213,24 +213,66 @@ def get(self, url, secure=False): if secure: headers = self.headers.copy() headers.update(self.extra_secure_headers) - return url_helper.read_file_or_url(url, headers=headers, timeout=5, - retries=10) + return url_helper.readurl(url, headers=headers, + timeout=5, retries=10, sec_between=5) def post(self, url, data=None, extra_headers=None): headers = self.headers if extra_headers is not None: headers = self.headers.copy() headers.update(extra_headers) - return url_helper.read_file_or_url(url, data=data, headers=headers, - timeout=5, retries=10) + return url_helper.readurl(url, data=data, headers=headers, + timeout=5, retries=10, sec_between=5) + + +class InvalidGoalStateXMLException(Exception): + """Raised when GoalState XML is invalid or has missing data.""" class GoalState(object): - def __init__(self, xml, http_client): - self.http_client = http_client - self.root = ElementTree.fromstring(xml) + def __init__(self, unparsed_xml, azure_endpoint_client): + """Parses a GoalState XML string and returns a GoalState object. + + @param unparsed_xml: string representing a GoalState XML. + @param azure_endpoint_client: instance of AzureEndpointHttpClient + @return: GoalState object representing the GoalState XML string. + """ + self.azure_endpoint_client = azure_endpoint_client + + try: + self.root = ElementTree.fromstring(unparsed_xml) + except ElementTree.ParseError as e: + msg = 'Failed to parse GoalState XML: %s' % e + LOG.warning(msg) + report_diagnostic_event(msg) + raise + + self._certificates_xml = None + self._container_id = self._text_from_xpath('./Container/ContainerId') + self._instance_id = self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance/InstanceId') + self._incarnation = self._text_from_xpath('./Incarnation') + + if any(x is None for x in (self.container_id, + self.instance_id, self.incarnation)): + msg = ('Missing container id/instance id/incarnation in ' + 'GoalState XML.') + LOG.warning(msg) + report_diagnostic_event(msg) + raise InvalidGoalStateXMLException(msg) + self._certificates_xml = None + url = self._text_from_xpath( + './Container/RoleInstanceList/RoleInstance' + '/Configuration/Certificates') + if url is not None: + self._certificates_xml = \ + self.azure_endpoint_client.get( + url, secure=True).contents + if self._certificates_xml is None: + raise InvalidGoalStateXMLException( + 'Azure endpoint returned empty certificates xml.') def _text_from_xpath(self, xpath): element = self.root.find(xpath) @@ -240,26 +282,18 @@ def _text_from_xpath(self, xpath): @property def container_id(self): - return self._text_from_xpath('./Container/ContainerId') + return self._container_id @property - def incarnation(self): - return self._text_from_xpath('./Incarnation') + def instance_id(self): + return self._instance_id @property - def instance_id(self): - return self._text_from_xpath( - './Container/RoleInstanceList/RoleInstance/InstanceId') + def incarnation(self): + return self._incarnation @property def certificates_xml(self): - if self._certificates_xml is None: - url = self._text_from_xpath( - './Container/RoleInstanceList/RoleInstance' - '/Configuration/Certificates') - if url is not None: - self._certificates_xml = self.http_client.get( - url, secure=True).contents return self._certificates_xml @@ -370,25 +404,105 @@ def parse_certificates(self, certificates_xml): return keys -class WALinuxAgentShim(object): +class GoalStateHealthReporter(object): + + HEALTH_REPORT_XML_TEMPLATE = textwrap.dedent('''\ + + + {incarnation} + + {container_id} + + + {instance_id} + + {health_status} + {health_detail_subsection} + + + + + + ''') + + HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = textwrap.dedent('''\ +
+ {health_substatus} + {health_description} +
+ ''') + + PROVISIONING_SUCCESS_STATUS = 'Ready' + + def __init__(self, goal_state, azure_endpoint_client, endpoint): + """Creates instance that will report provisioning status to an endpoint + + @param goal_state: An instance of class GoalState that contains + goal state info such as incarnation, container id, and instance id. + These 3 values are needed when reporting the provisioning status + to Azure + @param azure_endpoint_client: Instance of class AzureEndpointHttpClient + @param endpoint: Endpoint (string) where the provisioning status report + will be sent to + @return: Instance of class GoalStateHealthReporter + """ + self._goal_state = goal_state + self._azure_endpoint_client = azure_endpoint_client + self._endpoint = endpoint - REPORT_READY_XML_TEMPLATE = '\n'.join([ - '', - '', - ' {incarnation}', - ' ', - ' {container_id}', - ' ', - ' ', - ' {instance_id}', - ' ', - ' Ready', - ' ', - ' ', - ' ', - ' ', - '']) + @azure_ds_telemetry_reporter + def send_ready_signal(self): + document = self.build_report( + incarnation=self._goal_state.incarnation, + container_id=self._goal_state.container_id, + instance_id=self._goal_state.instance_id, + status=self.PROVISIONING_SUCCESS_STATUS) + LOG.debug('Reporting ready to Azure fabric.') + try: + self._post_health_report(document=document) + except Exception as e: + msg = "exception while reporting ready: %s" % e + LOG.error(msg) + report_diagnostic_event(msg) + raise + + LOG.info('Reported ready to Azure fabric.') + + def build_report(self, incarnation, container_id, instance_id, + status, substatus=None, description=None): + health_detail = '' + if substatus is not None: + health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format( + health_substatus=substatus, health_description=description) + + health_report = self.HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=incarnation, + container_id=container_id, + instance_id=instance_id, + health_status=status, + health_detail_subsection=health_detail) + + return health_report + + @azure_ds_telemetry_reporter + def _post_health_report(self, document): + """Host will collect kvps when cloud-init reports to fabric. + Some kvps might still be in the queue. We yield the scheduler + to make sure we process all kvps up till this point. + """ + time.sleep(0) + + LOG.debug('Sending health report to Azure fabric.') + url = "http://{0}/machine?comp=health".format(self._endpoint) + self._azure_endpoint_client.post( + url, + data=document, + extra_headers={'Content-Type': 'text/xml; charset=utf-8'}) + LOG.debug('Successfully sent health report to Azure fabric') + + +class WALinuxAgentShim(object): def __init__(self, fallback_lease_file=None, dhcp_options=None): LOG.debug('WALinuxAgentShim instantiated, fallback_lease_file=%s', @@ -494,84 +608,201 @@ def _get_value_from_dhcpoptions(dhcp_options): @staticmethod @azure_ds_telemetry_reporter def find_endpoint(fallback_lease_file=None, dhcp245=None): + """Finds and returns the Azure endpoint using various methods. + + The Azure endpoint is searched in the following order: + 1. Endpoint from dhcp options (dhcp option 245). + 2. Endpoint from networkd. + 3. Endpoint from dhclient hook json. + 4. Endpoint from fallback lease file. + 5. The default Azure endpoint. + + @param fallback_lease_file: Fallback lease file that will be used + during endpoint search. + @param dhcp245: dhcp options that will be used during endpoint search. + @return: Azure endpoint IP address. + """ value = None + if dhcp245 is not None: value = dhcp245 - LOG.debug("Using Azure Endpoint from dhcp options") + LOG.debug("Using Azure Endpoint from dhcp options.") + if value is None: - report_diagnostic_event("No Azure endpoint from dhcp options") - LOG.debug('Finding Azure endpoint from networkd...') + msg = ("No Azure endpoint from dhcp options. " + "Finding Azure endpoint from networkd...") + LOG.debug(msg) + report_diagnostic_event(msg) value = WALinuxAgentShim._networkd_get_value_from_leases() + if value is None: # Option-245 stored in /run/cloud-init/dhclient.hooks/.json # a dhclient exit hook that calls cloud-init-dhclient-hook - report_diagnostic_event("No Azure endpoint from networkd") - LOG.debug('Finding Azure endpoint from hook json...') + msg = ("No Azure endpoint from networkd. " + "Finding Azure endpoint from hook json...") + LOG.debug(msg) + report_diagnostic_event(msg) dhcp_options = WALinuxAgentShim._load_dhclient_json() value = WALinuxAgentShim._get_value_from_dhcpoptions(dhcp_options) + if value is None: # Fallback and check the leases file if unsuccessful - report_diagnostic_event("No Azure endpoint from dhclient logs") - LOG.debug("Unable to find endpoint in dhclient logs. " - " Falling back to check lease files") + msg = ("No Azure endpoint from dhclient logs. " + "Unable to find endpoint in dhclient logs. " + "Falling back to check lease files.") + LOG.debug(msg) + report_diagnostic_event(msg) + if fallback_lease_file is None: - LOG.warning("No fallback lease file was specified.") + msg = "No fallback lease file was specified." + LOG.warning(msg) + report_diagnostic_event(msg) value = None else: LOG.debug("Looking for endpoint in lease file %s", fallback_lease_file) value = WALinuxAgentShim._get_value_from_leases_file( fallback_lease_file) + if value is None: msg = "No lease found; using default endpoint" - report_diagnostic_event(msg) LOG.warning(msg) + report_diagnostic_event(msg) value = DEFAULT_WIRESERVER_ENDPOINT endpoint_ip_address = WALinuxAgentShim.get_ip_from_lease_value(value) msg = 'Azure endpoint found at %s' % endpoint_ip_address - report_diagnostic_event(msg) LOG.debug(msg) + report_diagnostic_event(msg) return endpoint_ip_address @azure_ds_telemetry_reporter def register_with_azure_and_fetch_data(self, pubkey_info=None): + """Gets the VM's GoalState from Azure, uses the GoalState information + to report ready/send the ready signal/provisioning complete signal to + Azure, and then uses pubkey_info to filter and obtain the user's + pubkeys from the GoalState. + + @param pubkey_info: List of pubkey values and fingerprints which are + used to filter and obtain the user's pubkey values from the + GoalState. + @return: The list of user's authorized pubkey values. + """ if self.openssl_manager is None: self.openssl_manager = OpenSSLManager() - http_client = AzureEndpointHttpClient(self.openssl_manager.certificate) + azure_endpoint_client = AzureEndpointHttpClient( + self.openssl_manager.certificate) + goal_state = self._fetch_goal_state_from_azure(azure_endpoint_client) + ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info) + health_reporter = GoalStateHealthReporter( + goal_state, azure_endpoint_client, self.endpoint) + health_reporter.send_ready_signal() + return {'public-keys': ssh_keys} + + @azure_ds_telemetry_reporter + def _fetch_goal_state_from_azure(self, azure_endpoint_client): + """Fetches the GoalState XML from the Azure endpoint, parses the XML, + and returns a GoalState object. + + @param azure_endpoint_client: instance of AzureEndpointHttpClient + @return: GoalState object representing the GoalState XML + """ + unparsed_goal_state_xml = self._get_goal_state_from_azure( + azure_endpoint_client) + return self._parse_goal_state( + unparsed_goal_state_xml, azure_endpoint_client) + + @azure_ds_telemetry_reporter + def _get_goal_state_from_azure(self, azure_endpoint_client): + """Fetches the GoalState XML from the Azure endpoint and returns + the XML as a string. + + @param azure_endpoint_client: instance of AzureEndpointHttpClient + @return: GoalState XML string + """ + LOG.info('Registering with Azure...') - attempts = 0 - while True: - try: - response = http_client.get( - 'http://{0}/machine/?comp=goalstate'.format(self.endpoint)) - except Exception as e: - if attempts < 10: - time.sleep(attempts + 1) - else: - report_diagnostic_event( - "failed to register with Azure: %s" % e) - raise - else: - break - attempts += 1 + url = 'http://{0}/machine/?comp=goalstate'.format(self.endpoint) + try: + response = azure_endpoint_client.get(url) + except Exception as e: + msg = 'failed to register with Azure: %s' % e + LOG.warning(msg) + report_diagnostic_event(msg) + raise LOG.debug('Successfully fetched GoalState XML.') - goal_state = GoalState(response.contents, http_client) - report_diagnostic_event("container_id %s" % goal_state.container_id) + return response.contents + + @azure_ds_telemetry_reporter + def _parse_goal_state(self, + unparsed_goal_state_xml, + azure_endpoint_client): + """Parses a GoalState XML string and returns a GoalState object. + + @param unparsed_goal_state_xml: GoalState XML string + @param azure_endpoint_client: instance of AzureEndpointHttpClient + @return: GoalState object representing the GoalState XML + """ + try: + goal_state = GoalState( + unparsed_goal_state_xml, azure_endpoint_client) + except Exception as e: + msg = 'Error processing GoalState XML: %s' % e + LOG.warning(msg) + report_diagnostic_event(msg) + raise + msg = ', '.join([ + 'GoalState XML container id: %s' % goal_state.container_id, + 'GoalState XML instance id: %s' % goal_state.instance_id, + 'GoalState XML incarnation: %s' % goal_state.incarnation]) + LOG.debug(msg) + report_diagnostic_event(msg) + return goal_state + + @azure_ds_telemetry_reporter + def _get_user_pubkeys(self, goal_state, pubkey_info): + """Gets and filters the VM user's authorized pubkeys. + + cloud-init expects a straightforward array of keys to be dropped + into the user's authorized_keys file. Azure control plane exposes + multiple public keys to the VM via wireserver. Select just the + user's key(s) and return them, ignoring any other certs. + + @param goal_state: GoalState object. The GoalState object contains + a certificate XML, which contains both the VM user's authorized + pubkeys and other non-user pubkeys, which are used for + MSI and protected extension handling. + @param pubkey_info: List of VM user pubkey dicts that were previously + obtained from provisioning data. + Each pubkey dict in this list can either have the format + pubkey['value'] or pubkey['fingerprint']. + Each pubkey['fingerprint'] in the list is used to filter + and obtain the actual pubkey value from the GoalState + certificates XML. + Each pubkey['value'] requires no further processing and is + immediately added to the return list. + @return: A list of the VM user's authorized pubkey values. + """ ssh_keys = [] if goal_state.certificates_xml is not None and pubkey_info is not None: LOG.debug('Certificate XML found; parsing out public keys.') keys_by_fingerprint = self.openssl_manager.parse_certificates( goal_state.certificates_xml) ssh_keys = self._filter_pubkeys(keys_by_fingerprint, pubkey_info) - self._report_ready(goal_state, http_client) - return {'public-keys': ssh_keys} + return ssh_keys - def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info): - """cloud-init expects a straightforward array of keys to be dropped - into the user's authorized_keys file. Azure control plane exposes - multiple public keys to the VM via wireserver. Select just the - user's key(s) and return them, ignoring any other certs. + @staticmethod + def _filter_pubkeys(keys_by_fingerprint, pubkey_info): + """ Filter and return only the user's actual pubkeys. + + @param keys_by_fingerprint: pubkey fingerprint -> pubkey value dict + that was obtained from GoalState Certificates XML. May contain + non-user pubkeys. + @param pubkey_info: List of VM user pubkeys. Pubkey values are added + to the return list without further processing. Pubkey fingerprints + are used to filter and obtain the actual pubkey values from + keys_by_fingerprint. + @return: A list of the VM user's authorized pubkey values. """ keys = [] for pubkey in pubkey_info: @@ -590,30 +821,6 @@ def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info): return keys - @azure_ds_telemetry_reporter - def _report_ready(self, goal_state, http_client): - LOG.debug('Reporting ready to Azure fabric.') - document = self.REPORT_READY_XML_TEMPLATE.format( - incarnation=goal_state.incarnation, - container_id=goal_state.container_id, - instance_id=goal_state.instance_id, - ) - # Host will collect kvps when cloud-init reports ready. - # some kvps might still be in the queue. We yield the scheduler - # to make sure we process all kvps up till this point. - time.sleep(0) - try: - http_client.post( - "http://{0}/machine?comp=health".format(self.endpoint), - data=document, - extra_headers={'Content-Type': 'text/xml; charset=utf-8'}, - ) - except Exception as e: - report_diagnostic_event("exception while reporting ready: %s" % e) - raise - - LOG.info('Reported ready to Azure fabric.') - @azure_ds_telemetry_reporter def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None, diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py index 71ef57f0c4d..1954b53336b 100644 --- a/tests/unittests/test_datasource/test_azure_helper.py +++ b/tests/unittests/test_datasource/test_azure_helper.py @@ -1,8 +1,10 @@ # This file is part of cloud-init. See LICENSE file for license information. import os +import re import unittest from textwrap import dedent +from xml.etree import ElementTree from cloudinit.sources.helpers import azure as azure_helper from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir @@ -48,6 +50,30 @@ """ +HEALTH_REPORT_XML_TEMPLATE = '''\ + + + {incarnation} + + {container_id} + + + {instance_id} + + {health_status} + {health_detail_subsection} + + + + + +''' + + +class SentinelException(Exception): + pass + class TestFindEndpoint(CiTestCase): @@ -140,9 +166,7 @@ class TestGoalStateParsing(CiTestCase): 'certificates_url': 'MyCertificatesUrl', } - def _get_goal_state(self, http_client=None, **kwargs): - if http_client is None: - http_client = mock.MagicMock() + def _get_formatted_goal_state_xml_string(self, **kwargs): parameters = self.default_parameters.copy() parameters.update(kwargs) xml = GOAL_STATE_TEMPLATE.format(**parameters) @@ -153,7 +177,13 @@ def _get_goal_state(self, http_client=None, **kwargs): continue new_xml_lines.append(line) xml = '\n'.join(new_xml_lines) - return azure_helper.GoalState(xml, http_client) + return xml + + def _get_goal_state(self, azure_endpoint_client=None, **kwargs): + if azure_endpoint_client is None: + azure_endpoint_client = mock.MagicMock() + xml = self._get_formatted_goal_state_xml_string(**kwargs) + return azure_helper.GoalState(xml, azure_endpoint_client) def test_incarnation_parsed_correctly(self): incarnation = '123' @@ -190,25 +220,61 @@ def test_instance_id_no_byte_swap_diff_instance_id(self): azure_helper.is_byte_swapped(previous_iid, current_iid)) def test_certificates_xml_parsed_and_fetched_correctly(self): - http_client = mock.MagicMock() + azure_endpoint_client = mock.MagicMock() certificates_url = 'TestCertificatesUrl' goal_state = self._get_goal_state( - http_client=http_client, certificates_url=certificates_url) + azure_endpoint_client=azure_endpoint_client, + certificates_url=certificates_url) certificates_xml = goal_state.certificates_xml - self.assertEqual(1, http_client.get.call_count) - self.assertEqual(certificates_url, http_client.get.call_args[0][0]) - self.assertTrue(http_client.get.call_args[1].get('secure', False)) - self.assertEqual(http_client.get.return_value.contents, - certificates_xml) + self.assertEqual(1, azure_endpoint_client.get.call_count) + self.assertEqual( + certificates_url, + azure_endpoint_client.get.call_args[0][0]) + self.assertTrue( + azure_endpoint_client.get.call_args[1].get( + 'secure', False)) + self.assertEqual( + azure_endpoint_client.get.return_value.contents, + certificates_xml) def test_missing_certificates_skips_http_get(self): - http_client = mock.MagicMock() + azure_endpoint_client = mock.MagicMock() goal_state = self._get_goal_state( - http_client=http_client, certificates_url=None) + azure_endpoint_client=azure_endpoint_client, certificates_url=None) certificates_xml = goal_state.certificates_xml - self.assertEqual(0, http_client.get.call_count) + self.assertEqual(0, azure_endpoint_client.get.call_count) self.assertIsNone(certificates_xml) + def test_invalid_goal_state_xml_raises_parse_error(self): + azure_endpoint_client = mock.MagicMock() + xml = 'random non-xml data' + with self.assertRaises(ElementTree.ParseError): + azure_helper.GoalState(xml, azure_endpoint_client) + + def test_missing_container_id_in_goal_state_xml_raises_exc(self): + azure_endpoint_client = mock.MagicMock() + xml = self._get_formatted_goal_state_xml_string( + azure_endpoint_client=azure_endpoint_client) + xml = re.sub('.*', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, azure_endpoint_client) + + def test_missing_instance_id_in_goal_state_xml_raises_exc(self): + azure_endpoint_client = mock.MagicMock() + xml = self._get_formatted_goal_state_xml_string( + azure_endpoint_client=azure_endpoint_client) + xml = re.sub('.*', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, azure_endpoint_client) + + def test_missing_incarnation_in_goal_state_xml_raises_exc(self): + azure_endpoint_client = mock.MagicMock() + xml = self._get_formatted_goal_state_xml_string( + azure_endpoint_client=azure_endpoint_client) + xml = re.sub('.*', '', xml) + with self.assertRaises(azure_helper.InvalidGoalStateXMLException): + azure_helper.GoalState(xml, azure_endpoint_client) + class TestAzureEndpointHttpClient(CiTestCase): @@ -222,19 +288,28 @@ def setUp(self): patches = ExitStack() self.addCleanup(patches.close) - self.read_file_or_url = patches.enter_context( - mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + self.readurl = patches.enter_context( + mock.patch.object(azure_helper.url_helper, 'readurl')) + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) def test_non_secure_get(self): client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) url = 'MyTestUrl' response = client.get(url, secure=False) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, headers=self.regular_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, headers=self.regular_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_non_secure_get_raises_exception(self): + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + url = 'MyTestUrl' + with self.assertRaises(SentinelException): + client.get(url, secure=False) def test_secure_get(self): url = 'MyTestUrl' @@ -246,37 +321,62 @@ def test_secure_get(self): }) client = azure_helper.AzureEndpointHttpClient(certificate) response = client.get(url, secure=True) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, headers=expected_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, headers=expected_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_secure_get_raises_exception(self): + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.get(url, secure=True) def test_post(self): data = mock.MagicMock() url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) response = client.post(url, data=data) - self.assertEqual(1, self.read_file_or_url.call_count) - self.assertEqual(self.read_file_or_url.return_value, response) + self.assertEqual(1, self.readurl.call_count) + self.assertEqual(self.readurl.return_value, response) self.assertEqual( - mock.call(url, data=data, headers=self.regular_headers, retries=10, - timeout=5), - self.read_file_or_url.call_args) + mock.call(url, data=data, headers=self.regular_headers, + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_post_raises_exception(self): + data = mock.MagicMock() + url = 'MyTestUrl' + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.post(url, data=data) def test_post_with_extra_headers(self): url = 'MyTestUrl' client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) extra_headers = {'test': 'header'} client.post(url, extra_headers=extra_headers) - self.assertEqual(1, self.read_file_or_url.call_count) expected_headers = self.regular_headers.copy() expected_headers.update(extra_headers) + self.assertEqual(1, self.readurl.call_count) self.assertEqual( mock.call(mock.ANY, data=mock.ANY, headers=expected_headers, - retries=10, timeout=5), - self.read_file_or_url.call_args) + timeout=5, retries=10, sec_between=5), + self.readurl.call_args) + + def test_post_with_sleep_with_extra_headers_raises_exception(self): + data = mock.MagicMock() + url = 'MyTestUrl' + extra_headers = {'test': 'header'} + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + self.readurl.side_effect = SentinelException + with self.assertRaises(SentinelException): + client.post( + url, data=data, extra_headers=extra_headers) class TestOpenSSLManager(CiTestCase): @@ -365,6 +465,116 @@ def test_parse_certificates(self, mock_decrypt_certs): self.assertIn(fp, keys_by_fp) +class TestGoalStateHealthReporter(CiTestCase): + + default_parameters = { + 'incarnation': 1, + 'container_id': 'MyContainerId', + 'instance_id': 'MyInstanceId' + } + + test_endpoint = 'TestEndpoint' + test_url = 'http://{0}/machine?comp=health'.format(test_endpoint) + test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'} + + provisioning_success_status = 'Ready' + + def setUp(self): + super(TestGoalStateHealthReporter, self).setUp() + patches = ExitStack() + self.addCleanup(patches.close) + + patches.enter_context( + mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) + self.read_file_or_url = patches.enter_context( + mock.patch.object(azure_helper.url_helper, 'read_file_or_url')) + + self.post = patches.enter_context( + mock.patch.object(azure_helper.AzureEndpointHttpClient, + 'post')) + + self.GoalState = patches.enter_context( + mock.patch.object(azure_helper, 'GoalState')) + self.GoalState.return_value.container_id = \ + self.default_parameters['container_id'] + self.GoalState.return_value.instance_id = \ + self.default_parameters['instance_id'] + self.GoalState.return_value.incarnation = \ + self.default_parameters['incarnation'] + + def _get_formatted_health_report_xml_string(self, **kwargs): + return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs) + + def _get_report_ready_health_document(self): + return self._get_formatted_health_report_xml_string( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + health_status=self.provisioning_success_status, + health_detail_subsection='') + + def test_send_ready_signal_sends_post_request(self): + health_document = self._get_report_ready_health_document() + client = azure_helper.AzureEndpointHttpClient(mock.MagicMock()) + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + client, self.test_endpoint) + reporter.send_ready_signal() + self.assertEqual(1, self.post.call_count) + self.assertEqual( + mock.call( + self.test_url, + data=health_document, + extra_headers=self.test_default_headers), + self.post.call_args) + + def test_send_ready_signal_health_document(self): + health_document = self._get_report_ready_health_document() + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_endpoint) + generated_health_document = reporter.build_report( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_success_status) + self.assertEqual(health_document, generated_health_document) + self.assertIn(str(self.default_parameters['incarnation']), + generated_health_document) + self.assertIn(self.default_parameters['container_id'], + generated_health_document) + self.assertIn(self.default_parameters['instance_id'], + generated_health_document) + self.assertIn(self.provisioning_success_status, + generated_health_document) + self.assertNotIn('
', generated_health_document) + self.assertNotIn('', generated_health_document) + self.assertNotIn('', generated_health_document) + + def test_send_ready_signal_calls_build_report(self): + patches = ExitStack() + self.addCleanup(patches.close) + build_report = patches.enter_context( + mock.patch.object( + azure_helper.GoalStateHealthReporter, 'build_report')) + + reporter = azure_helper.GoalStateHealthReporter( + azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()), + azure_helper.AzureEndpointHttpClient(mock.MagicMock()), + self.test_endpoint) + reporter.send_ready_signal() + + self.assertEqual(1, build_report.call_count) + self.assertEqual( + mock.call( + incarnation=self.default_parameters['incarnation'], + container_id=self.default_parameters['container_id'], + instance_id=self.default_parameters['instance_id'], + status=self.provisioning_success_status), + build_report.call_args) + + class TestWALinuxAgentShim(CiTestCase): def setUp(self): @@ -383,14 +593,21 @@ def setUp(self): patches.enter_context( mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock())) - def test_http_client_uses_certificate(self): + self.test_incarnation = 'TestIncarnation' + self.test_container_id = 'TestContainerId' + self.test_instance_id = 'TestInstanceId' + self.GoalState.return_value.incarnation = self.test_incarnation + self.GoalState.return_value.container_id = self.test_container_id + self.GoalState.return_value.instance_id = self.test_instance_id + + def test_azure_endpoint_client_uses_certificate_during_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() self.assertEqual( [mock.call(self.OpenSSLManager.return_value.certificate)], self.AzureEndpointHttpClient.call_args_list) - def test_correct_url_used_for_goalstate(self): + def test_correct_url_used_for_goalstate_during_report_ready(self): self.find_endpoint.return_value = 'test_endpoint' shim = wa_shim() shim.register_with_azure_and_fetch_data() @@ -439,43 +656,67 @@ def test_correct_url_used_for_report_ready(self): expected_url = 'http://test_endpoint/machine?comp=health' self.assertEqual( [mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)], - self.AzureEndpointHttpClient.return_value.post.call_args_list) + self.AzureEndpointHttpClient.return_value.post + .call_args_list) def test_goal_state_values_used_for_report_ready(self): - self.GoalState.return_value.incarnation = 'TestIncarnation' - self.GoalState.return_value.container_id = 'TestContainerId' - self.GoalState.return_value.instance_id = 'TestInstanceId' shim = wa_shim() shim.register_with_azure_and_fetch_data() posted_document = ( - self.AzureEndpointHttpClient.return_value.post.call_args[1]['data'] + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data'] ) - self.assertIn('TestIncarnation', posted_document) - self.assertIn('TestContainerId', posted_document) - self.assertIn('TestInstanceId', posted_document) + self.assertIn(self.test_incarnation, posted_document) + self.assertIn(self.test_container_id, posted_document) + self.assertIn(self.test_instance_id, posted_document) + + def test_xml_elems_in_report_ready(self): + shim = wa_shim() + shim.register_with_azure_and_fetch_data() + health_document = HEALTH_REPORT_XML_TEMPLATE.format( + incarnation=self.test_incarnation, + container_id=self.test_container_id, + instance_id=self.test_instance_id, + health_status='Ready', + health_detail_subsection='') + posted_document = ( + self.AzureEndpointHttpClient.return_value.post + .call_args[1]['data']) + self.assertEqual(health_document, posted_document) def test_clean_up_can_be_called_at_any_time(self): shim = wa_shim() shim.clean_up() - def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self): + def test_clean_up_after_report_ready(self): shim = wa_shim() shim.register_with_azure_and_fetch_data() shim.clean_up() self.assertEqual( 1, self.OpenSSLManager.return_value.clean_up.call_count) - def test_failure_to_fetch_goalstate_bubbles_up(self): - class SentinelException(Exception): - pass - self.AzureEndpointHttpClient.return_value.get.side_effect = ( - SentinelException) + def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self): + self.AzureEndpointHttpClient.return_value.get \ + .side_effect = (SentinelException) shim = wa_shim() self.assertRaises(SentinelException, shim.register_with_azure_and_fetch_data) + def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self): + self.GoalState.side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) -class TestGetMetadataFromFabric(CiTestCase): + def test_failure_to_send_report_ready_health_doc_bubbles_up(self): + self.AzureEndpointHttpClient.return_value.post \ + .side_effect = SentinelException + shim = wa_shim() + self.assertRaises(SentinelException, + shim.register_with_azure_and_fetch_data) + + +class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase): @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_data_from_shim_returned(self, shim): @@ -491,14 +732,39 @@ def test_success_calls_clean_up(self, shim): @mock.patch.object(azure_helper, 'WALinuxAgentShim') def test_failure_in_registration_calls_clean_up(self, shim): - class SentinelException(Exception): - pass shim.return_value.register_with_azure_and_fetch_data.side_effect = ( SentinelException) self.assertRaises(SentinelException, azure_helper.get_metadata_from_fabric) self.assertEqual(1, shim.return_value.clean_up.call_count) + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_calls_shim_register_with_azure_and_fetch_data(self, shim): + pubkey_info = mock.MagicMock() + azure_helper.get_metadata_from_fabric(pubkey_info=pubkey_info) + self.assertEqual( + 1, + shim.return_value + .register_with_azure_and_fetch_data.call_count) + self.assertEqual( + mock.call(pubkey_info=pubkey_info), + shim.return_value + .register_with_azure_and_fetch_data.call_args) + + @mock.patch.object(azure_helper, 'WALinuxAgentShim') + def test_instantiates_shim_with_kwargs(self, shim): + fallback_lease_file = mock.MagicMock() + dhcp_options = mock.MagicMock() + azure_helper.get_metadata_from_fabric( + fallback_lease_file=fallback_lease_file, + dhcp_opts=dhcp_options) + self.assertEqual(1, shim.call_count) + self.assertEqual( + mock.call( + fallback_lease_file=fallback_lease_file, + dhcp_options=dhcp_options), + shim.call_args) + class TestExtractIpAddressFromNetworkd(CiTestCase): From 4369c5cf1b86b3f8ca77092c87c0a850b398d398 Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Tue, 9 Jun 2020 03:40:50 +0000 Subject: [PATCH 2/4] Add telemetry for obtaining Azure certificates XML --- cloudinit/sources/helpers/azure.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py index 5674111f857..bb2dbacfbad 100755 --- a/cloudinit/sources/helpers/azure.py +++ b/cloudinit/sources/helpers/azure.py @@ -267,12 +267,16 @@ def __init__(self, unparsed_xml, azure_endpoint_client): './Container/RoleInstanceList/RoleInstance' '/Configuration/Certificates') if url is not None: - self._certificates_xml = \ - self.azure_endpoint_client.get( - url, secure=True).contents - if self._certificates_xml is None: - raise InvalidGoalStateXMLException( - 'Azure endpoint returned empty certificates xml.') + with events.ReportEventStack( + name="get-certificates-xml", + description="get certificates xml", + parent=azure_ds_reporter): + self._certificates_xml = \ + self.azure_endpoint_client.get( + url, secure=True).contents + if self._certificates_xml is None: + raise InvalidGoalStateXMLException( + 'Azure endpoint returned empty certificates xml.') def _text_from_xpath(self, xpath): element = self.root.find(xpath) From 96964527da773adad5fc9b96fdaaef1c48ae424a Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Tue, 9 Jun 2020 05:39:28 +0000 Subject: [PATCH 3/4] Add extra telemetry for marker files --- cloudinit/sources/DataSourceAzure.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index ca63f3979dc..4a52c42cca4 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -405,6 +405,10 @@ def crawl_metadata(self): candidates = [self.seed_dir] if os.path.isfile(REPROVISION_MARKER_FILE): candidates.insert(0, "IMDS") + msg = ('Reprovision marker file already exists ' + 'before crawl of Azure metadata') + LOG.warning(msg) + report_diagnostic_event(msg) candidates.extend(list_possible_azure_ds_devs()) if ddir: candidates.append(ddir) @@ -461,7 +465,12 @@ def crawl_metadata(self): 'userdata_raw': userdata_raw}) found = cdev - LOG.debug("found datasource in %s", cdev) + if perform_reprovision: + msg = "found datasource in IMDS" + else: + msg = "found datasource in %s" % cdev + LOG.debug(msg) + report_diagnostic_event(msg) break if not found: @@ -471,7 +480,9 @@ def crawl_metadata(self): raise sources.InvalidMetaDataException(msg) if found == ddir: - LOG.debug("using files cached in %s", ddir) + msg = "using files cached in %s" % ddir + LOG.debug(msg) + report_diagnostic_event(msg) seed = _get_random_seed() if seed: From 7027a42e661c936582808c32be97c818c7058c8b Mon Sep 17 00:00:00 2001 From: Johnson Shi Date: Tue, 9 Jun 2020 19:52:00 +0000 Subject: [PATCH 4/4] github-cla-signers: add johnsonshi --- tools/.github-cla-signers | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/.github-cla-signers b/tools/.github-cla-signers index f98c7a4588f..a09ff2d19c6 100644 --- a/tools/.github-cla-signers +++ b/tools/.github-cla-signers @@ -1,6 +1,7 @@ beezly bipinbachhao dhensby +johnsonshi lucasmoura matthewruffell nishigori