diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py
index 89312b9e9bd..4a52c42cca4 100755
--- a/cloudinit/sources/DataSourceAzure.py
+++ b/cloudinit/sources/DataSourceAzure.py
@@ -405,6 +405,10 @@ def crawl_metadata(self):
candidates = [self.seed_dir]
if os.path.isfile(REPROVISION_MARKER_FILE):
candidates.insert(0, "IMDS")
+ msg = ('Reprovision marker file already exists '
+ 'before crawl of Azure metadata')
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
candidates.extend(list_possible_azure_ds_devs())
if ddir:
candidates.append(ddir)
@@ -426,17 +430,19 @@ def crawl_metadata(self):
ret = load_azure_ds_dir(cdev)
except NonAzureDataSource:
- report_diagnostic_event(
- "Did not find Azure data source in %s" % cdev)
+ msg = "Did not find Azure data source in %s" % cdev
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
continue
except BrokenAzureDataSource as exc:
msg = 'BrokenAzureDataSource: %s' % exc
+ LOG.error(msg)
report_diagnostic_event(msg)
raise sources.InvalidMetaDataException(msg)
except util.MountFailedError:
msg = '%s was not mountable' % cdev
- report_diagnostic_event(msg)
LOG.warning(msg)
+ report_diagnostic_event(msg)
continue
perform_reprovision = reprovision or self._should_reprovision(ret)
@@ -459,16 +465,24 @@ def crawl_metadata(self):
'userdata_raw': userdata_raw})
found = cdev
- LOG.debug("found datasource in %s", cdev)
+ if perform_reprovision:
+ msg = "found datasource in IMDS"
+ else:
+ msg = "found datasource in %s" % cdev
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
break
if not found:
msg = 'No Azure metadata found'
+ LOG.error(msg)
report_diagnostic_event(msg)
raise sources.InvalidMetaDataException(msg)
if found == ddir:
- LOG.debug("using files cached in %s", ddir)
+ msg = "using files cached in %s" % ddir
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
seed = _get_random_seed()
if seed:
@@ -488,8 +502,9 @@ def crawl_metadata(self):
azure_ds_reporter) as lease:
self._report_ready(lease=lease)
except Exception as e:
- report_diagnostic_event(
- "exception while reporting ready: %s" % e)
+ msg = "exception while reporting ready: %s" % e
+ LOG.error(msg)
+ report_diagnostic_event(msg)
raise
return crawled_data
@@ -617,10 +632,11 @@ def exc_cb(msg, exception):
else:
# If we get an exception while trying to call IMDS, we call
# DHCP and setup the ephemeral network to acquire a new IP.
- report_diagnostic_event("poll IMDS with %s failed. "
- "Exception: %s and code: %s" %
- (msg, exception.cause,
- exception.code))
+ evt_msg = ("poll IMDS with %s failed. "
+ "Exception: %s and code: %s" %
+ (msg, exception.cause, exception.code))
+ LOG.warning(evt_msg)
+ report_diagnostic_event(evt_msg)
return False
LOG.debug("poll IMDS failed with an unexpected exception: %s",
@@ -644,8 +660,8 @@ def exc_cb(msg, exception):
try:
nl_sock = netlink.create_bound_netlink_socket()
except netlink.NetlinkCreateSocketError as e:
- report_diagnostic_event(e)
LOG.warning(e)
+ report_diagnostic_event(e)
self._ephemeral_dhcp_ctx.clean_network()
break
@@ -665,8 +681,8 @@ def exc_cb(msg, exception):
netlink.wait_for_media_disconnect_connect(
nl_sock, lease['interface'])
except AssertionError as error:
- report_diagnostic_event(error)
LOG.error(error)
+ report_diagnostic_event(error)
break
vnet_switched = True
@@ -692,10 +708,12 @@ def exc_cb(msg, exception):
nl_sock.close()
if vnet_switched:
- report_diagnostic_event("attempted dhcp %d times after reuse" %
- dhcp_attempts)
- report_diagnostic_event("polled imds %d times after reuse" %
- self.imds_poll_counter)
+ msg = "attempted dhcp %d times after reuse" % dhcp_attempts
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
+ msg = "polled imds %d times after reuse" % self.imds_poll_counter
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
return return_val
@@ -768,12 +786,12 @@ def _negotiate(self):
try:
fabric_data = metadata_func()
except Exception as e:
+ LOG.error(
+ "Error communicating with Azure fabric; You may experience "
+ "connectivity issues.", exc_info=True)
report_diagnostic_event(
"Error communicating with Azure fabric; You may experience "
"connectivity issues: %s" % e)
- LOG.warning(
- "Error communicating with Azure fabric; You may experience "
- "connectivity issues.", exc_info=True)
return False
util.del_file(REPORTED_READY_MARKER_FILE)
@@ -1132,6 +1150,7 @@ def read_azure_ovf(contents):
dom = minidom.parseString(contents)
except Exception as e:
error_str = "Invalid ovf-env.xml: %s" % e
+ LOG.error(error_str)
report_diagnostic_event(error_str)
raise BrokenAzureDataSource(error_str)
@@ -1414,7 +1433,9 @@ def get_metadata_from_imds(fallback_nic, retries):
azure_ds_reporter, fallback_nic):
return util.log_time(**kwargs)
except Exception as e:
- report_diagnostic_event("exception while getting metadata: %s" % e)
+ msg = "exception while getting metadata: %s" % e
+ LOG.error(msg)
+ report_diagnostic_event(msg)
raise
@@ -1429,15 +1450,15 @@ def _get_metadata_from_imds(retries):
retries=retries, exception_cb=retry_on_url_exc)
except Exception as e:
msg = 'Ignoring IMDS instance metadata: %s' % e
+ LOG.error(msg)
report_diagnostic_event(msg)
- LOG.debug(msg)
return {}
try:
return util.load_json(str(response))
except json.decoder.JSONDecodeError as e:
- report_diagnostic_event('non-json imds response' % e)
- LOG.warning(
+ LOG.error(
'Ignoring non-json IMDS instance metadata: %s', str(response))
+ report_diagnostic_event('non-json imds response' % e)
return {}
diff --git a/cloudinit/sources/helpers/azure.py b/cloudinit/sources/helpers/azure.py
index 7bace8ca294..bb2dbacfbad 100755
--- a/cloudinit/sources/helpers/azure.py
+++ b/cloudinit/sources/helpers/azure.py
@@ -213,24 +213,70 @@ def get(self, url, secure=False):
if secure:
headers = self.headers.copy()
headers.update(self.extra_secure_headers)
- return url_helper.read_file_or_url(url, headers=headers, timeout=5,
- retries=10)
+ return url_helper.readurl(url, headers=headers,
+ timeout=5, retries=10, sec_between=5)
def post(self, url, data=None, extra_headers=None):
headers = self.headers
if extra_headers is not None:
headers = self.headers.copy()
headers.update(extra_headers)
- return url_helper.read_file_or_url(url, data=data, headers=headers,
- timeout=5, retries=10)
+ return url_helper.readurl(url, data=data, headers=headers,
+ timeout=5, retries=10, sec_between=5)
+
+
+class InvalidGoalStateXMLException(Exception):
+ """Raised when GoalState XML is invalid or has missing data."""
class GoalState(object):
- def __init__(self, xml, http_client):
- self.http_client = http_client
- self.root = ElementTree.fromstring(xml)
+ def __init__(self, unparsed_xml, azure_endpoint_client):
+ """Parses a GoalState XML string and returns a GoalState object.
+
+ @param unparsed_xml: string representing a GoalState XML.
+ @param azure_endpoint_client: instance of AzureEndpointHttpClient
+ @return: GoalState object representing the GoalState XML string.
+ """
+ self.azure_endpoint_client = azure_endpoint_client
+
+ try:
+ self.root = ElementTree.fromstring(unparsed_xml)
+ except ElementTree.ParseError as e:
+ msg = 'Failed to parse GoalState XML: %s' % e
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
+ raise
+
self._certificates_xml = None
+ self._container_id = self._text_from_xpath('./Container/ContainerId')
+ self._instance_id = self._text_from_xpath(
+ './Container/RoleInstanceList/RoleInstance/InstanceId')
+ self._incarnation = self._text_from_xpath('./Incarnation')
+
+ if any(x is None for x in (self.container_id,
+ self.instance_id, self.incarnation)):
+ msg = ('Missing container id/instance id/incarnation in '
+ 'GoalState XML.')
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
+ raise InvalidGoalStateXMLException(msg)
+
+ self._certificates_xml = None
+ url = self._text_from_xpath(
+ './Container/RoleInstanceList/RoleInstance'
+ '/Configuration/Certificates')
+ if url is not None:
+ with events.ReportEventStack(
+ name="get-certificates-xml",
+ description="get certificates xml",
+ parent=azure_ds_reporter):
+ self._certificates_xml = \
+ self.azure_endpoint_client.get(
+ url, secure=True).contents
+ if self._certificates_xml is None:
+ raise InvalidGoalStateXMLException(
+ 'Azure endpoint returned empty certificates xml.')
def _text_from_xpath(self, xpath):
element = self.root.find(xpath)
@@ -240,26 +286,18 @@ def _text_from_xpath(self, xpath):
@property
def container_id(self):
- return self._text_from_xpath('./Container/ContainerId')
+ return self._container_id
@property
- def incarnation(self):
- return self._text_from_xpath('./Incarnation')
+ def instance_id(self):
+ return self._instance_id
@property
- def instance_id(self):
- return self._text_from_xpath(
- './Container/RoleInstanceList/RoleInstance/InstanceId')
+ def incarnation(self):
+ return self._incarnation
@property
def certificates_xml(self):
- if self._certificates_xml is None:
- url = self._text_from_xpath(
- './Container/RoleInstanceList/RoleInstance'
- '/Configuration/Certificates')
- if url is not None:
- self._certificates_xml = self.http_client.get(
- url, secure=True).contents
return self._certificates_xml
@@ -370,25 +408,105 @@ def parse_certificates(self, certificates_xml):
return keys
-class WALinuxAgentShim(object):
+class GoalStateHealthReporter(object):
+
+ HEALTH_REPORT_XML_TEMPLATE = textwrap.dedent('''\
+
+
+ {incarnation}
+
+ {container_id}
+
+
+ {instance_id}
+
+ {health_status}
+ {health_detail_subsection}
+
+
+
+
+
+ ''')
+
+ HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE = textwrap.dedent('''\
+
+ {health_substatus}
+ {health_description}
+
+ ''')
+
+ PROVISIONING_SUCCESS_STATUS = 'Ready'
+
+ def __init__(self, goal_state, azure_endpoint_client, endpoint):
+ """Creates instance that will report provisioning status to an endpoint
+
+ @param goal_state: An instance of class GoalState that contains
+ goal state info such as incarnation, container id, and instance id.
+ These 3 values are needed when reporting the provisioning status
+ to Azure
+ @param azure_endpoint_client: Instance of class AzureEndpointHttpClient
+ @param endpoint: Endpoint (string) where the provisioning status report
+ will be sent to
+ @return: Instance of class GoalStateHealthReporter
+ """
+ self._goal_state = goal_state
+ self._azure_endpoint_client = azure_endpoint_client
+ self._endpoint = endpoint
+
+ @azure_ds_telemetry_reporter
+ def send_ready_signal(self):
+ document = self.build_report(
+ incarnation=self._goal_state.incarnation,
+ container_id=self._goal_state.container_id,
+ instance_id=self._goal_state.instance_id,
+ status=self.PROVISIONING_SUCCESS_STATUS)
+ LOG.debug('Reporting ready to Azure fabric.')
+ try:
+ self._post_health_report(document=document)
+ except Exception as e:
+ msg = "exception while reporting ready: %s" % e
+ LOG.error(msg)
+ report_diagnostic_event(msg)
+ raise
+
+ LOG.info('Reported ready to Azure fabric.')
- REPORT_READY_XML_TEMPLATE = '\n'.join([
- '',
- '',
- ' {incarnation}',
- ' ',
- ' {container_id}',
- ' ',
- ' ',
- ' {instance_id}',
- ' ',
- ' Ready',
- ' ',
- ' ',
- ' ',
- ' ',
- ''])
+ def build_report(self, incarnation, container_id, instance_id,
+ status, substatus=None, description=None):
+ health_detail = ''
+ if substatus is not None:
+ health_detail = self.HEALTH_DETAIL_SUBSECTION_XML_TEMPLATE.format(
+ health_substatus=substatus, health_description=description)
+
+ health_report = self.HEALTH_REPORT_XML_TEMPLATE.format(
+ incarnation=incarnation,
+ container_id=container_id,
+ instance_id=instance_id,
+ health_status=status,
+ health_detail_subsection=health_detail)
+
+ return health_report
+
+ @azure_ds_telemetry_reporter
+ def _post_health_report(self, document):
+ """Host will collect kvps when cloud-init reports to fabric.
+ Some kvps might still be in the queue. We yield the scheduler
+ to make sure we process all kvps up till this point.
+ """
+ time.sleep(0)
+
+ LOG.debug('Sending health report to Azure fabric.')
+ url = "http://{0}/machine?comp=health".format(self._endpoint)
+ self._azure_endpoint_client.post(
+ url,
+ data=document,
+ extra_headers={'Content-Type': 'text/xml; charset=utf-8'})
+ LOG.debug('Successfully sent health report to Azure fabric')
+
+
+class WALinuxAgentShim(object):
def __init__(self, fallback_lease_file=None, dhcp_options=None):
LOG.debug('WALinuxAgentShim instantiated, fallback_lease_file=%s',
@@ -494,84 +612,201 @@ def _get_value_from_dhcpoptions(dhcp_options):
@staticmethod
@azure_ds_telemetry_reporter
def find_endpoint(fallback_lease_file=None, dhcp245=None):
+ """Finds and returns the Azure endpoint using various methods.
+
+ The Azure endpoint is searched in the following order:
+ 1. Endpoint from dhcp options (dhcp option 245).
+ 2. Endpoint from networkd.
+ 3. Endpoint from dhclient hook json.
+ 4. Endpoint from fallback lease file.
+ 5. The default Azure endpoint.
+
+ @param fallback_lease_file: Fallback lease file that will be used
+ during endpoint search.
+ @param dhcp245: dhcp options that will be used during endpoint search.
+ @return: Azure endpoint IP address.
+ """
value = None
+
if dhcp245 is not None:
value = dhcp245
- LOG.debug("Using Azure Endpoint from dhcp options")
+ LOG.debug("Using Azure Endpoint from dhcp options.")
+
if value is None:
- report_diagnostic_event("No Azure endpoint from dhcp options")
- LOG.debug('Finding Azure endpoint from networkd...')
+ msg = ("No Azure endpoint from dhcp options. "
+ "Finding Azure endpoint from networkd...")
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
value = WALinuxAgentShim._networkd_get_value_from_leases()
+
if value is None:
# Option-245 stored in /run/cloud-init/dhclient.hooks/.json
# a dhclient exit hook that calls cloud-init-dhclient-hook
- report_diagnostic_event("No Azure endpoint from networkd")
- LOG.debug('Finding Azure endpoint from hook json...')
+ msg = ("No Azure endpoint from networkd. "
+ "Finding Azure endpoint from hook json...")
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
dhcp_options = WALinuxAgentShim._load_dhclient_json()
value = WALinuxAgentShim._get_value_from_dhcpoptions(dhcp_options)
+
if value is None:
# Fallback and check the leases file if unsuccessful
- report_diagnostic_event("No Azure endpoint from dhclient logs")
- LOG.debug("Unable to find endpoint in dhclient logs. "
- " Falling back to check lease files")
+ msg = ("No Azure endpoint from dhclient logs. "
+ "Unable to find endpoint in dhclient logs. "
+ "Falling back to check lease files.")
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
+
if fallback_lease_file is None:
- LOG.warning("No fallback lease file was specified.")
+ msg = "No fallback lease file was specified."
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
value = None
else:
LOG.debug("Looking for endpoint in lease file %s",
fallback_lease_file)
value = WALinuxAgentShim._get_value_from_leases_file(
fallback_lease_file)
+
if value is None:
msg = "No lease found; using default endpoint"
- report_diagnostic_event(msg)
LOG.warning(msg)
+ report_diagnostic_event(msg)
value = DEFAULT_WIRESERVER_ENDPOINT
endpoint_ip_address = WALinuxAgentShim.get_ip_from_lease_value(value)
msg = 'Azure endpoint found at %s' % endpoint_ip_address
- report_diagnostic_event(msg)
LOG.debug(msg)
+ report_diagnostic_event(msg)
return endpoint_ip_address
@azure_ds_telemetry_reporter
def register_with_azure_and_fetch_data(self, pubkey_info=None):
+ """Gets the VM's GoalState from Azure, uses the GoalState information
+ to report ready/send the ready signal/provisioning complete signal to
+ Azure, and then uses pubkey_info to filter and obtain the user's
+ pubkeys from the GoalState.
+
+ @param pubkey_info: List of pubkey values and fingerprints which are
+ used to filter and obtain the user's pubkey values from the
+ GoalState.
+ @return: The list of user's authorized pubkey values.
+ """
if self.openssl_manager is None:
self.openssl_manager = OpenSSLManager()
- http_client = AzureEndpointHttpClient(self.openssl_manager.certificate)
+ azure_endpoint_client = AzureEndpointHttpClient(
+ self.openssl_manager.certificate)
+ goal_state = self._fetch_goal_state_from_azure(azure_endpoint_client)
+ ssh_keys = self._get_user_pubkeys(goal_state, pubkey_info)
+ health_reporter = GoalStateHealthReporter(
+ goal_state, azure_endpoint_client, self.endpoint)
+ health_reporter.send_ready_signal()
+ return {'public-keys': ssh_keys}
+
+ @azure_ds_telemetry_reporter
+ def _fetch_goal_state_from_azure(self, azure_endpoint_client):
+ """Fetches the GoalState XML from the Azure endpoint, parses the XML,
+ and returns a GoalState object.
+
+ @param azure_endpoint_client: instance of AzureEndpointHttpClient
+ @return: GoalState object representing the GoalState XML
+ """
+ unparsed_goal_state_xml = self._get_goal_state_from_azure(
+ azure_endpoint_client)
+ return self._parse_goal_state(
+ unparsed_goal_state_xml, azure_endpoint_client)
+
+ @azure_ds_telemetry_reporter
+ def _get_goal_state_from_azure(self, azure_endpoint_client):
+ """Fetches the GoalState XML from the Azure endpoint and returns
+ the XML as a string.
+
+ @param azure_endpoint_client: instance of AzureEndpointHttpClient
+ @return: GoalState XML string
+ """
+
LOG.info('Registering with Azure...')
- attempts = 0
- while True:
- try:
- response = http_client.get(
- 'http://{0}/machine/?comp=goalstate'.format(self.endpoint))
- except Exception as e:
- if attempts < 10:
- time.sleep(attempts + 1)
- else:
- report_diagnostic_event(
- "failed to register with Azure: %s" % e)
- raise
- else:
- break
- attempts += 1
+ url = 'http://{0}/machine/?comp=goalstate'.format(self.endpoint)
+ try:
+ response = azure_endpoint_client.get(url)
+ except Exception as e:
+ msg = 'failed to register with Azure: %s' % e
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
+ raise
LOG.debug('Successfully fetched GoalState XML.')
- goal_state = GoalState(response.contents, http_client)
- report_diagnostic_event("container_id %s" % goal_state.container_id)
+ return response.contents
+
+ @azure_ds_telemetry_reporter
+ def _parse_goal_state(self,
+ unparsed_goal_state_xml,
+ azure_endpoint_client):
+ """Parses a GoalState XML string and returns a GoalState object.
+
+ @param unparsed_goal_state_xml: GoalState XML string
+ @param azure_endpoint_client: instance of AzureEndpointHttpClient
+ @return: GoalState object representing the GoalState XML
+ """
+ try:
+ goal_state = GoalState(
+ unparsed_goal_state_xml, azure_endpoint_client)
+ except Exception as e:
+ msg = 'Error processing GoalState XML: %s' % e
+ LOG.warning(msg)
+ report_diagnostic_event(msg)
+ raise
+ msg = ', '.join([
+ 'GoalState XML container id: %s' % goal_state.container_id,
+ 'GoalState XML instance id: %s' % goal_state.instance_id,
+ 'GoalState XML incarnation: %s' % goal_state.incarnation])
+ LOG.debug(msg)
+ report_diagnostic_event(msg)
+ return goal_state
+
+ @azure_ds_telemetry_reporter
+ def _get_user_pubkeys(self, goal_state, pubkey_info):
+ """Gets and filters the VM user's authorized pubkeys.
+
+ cloud-init expects a straightforward array of keys to be dropped
+ into the user's authorized_keys file. Azure control plane exposes
+ multiple public keys to the VM via wireserver. Select just the
+ user's key(s) and return them, ignoring any other certs.
+
+ @param goal_state: GoalState object. The GoalState object contains
+ a certificate XML, which contains both the VM user's authorized
+ pubkeys and other non-user pubkeys, which are used for
+ MSI and protected extension handling.
+ @param pubkey_info: List of VM user pubkey dicts that were previously
+ obtained from provisioning data.
+ Each pubkey dict in this list can either have the format
+ pubkey['value'] or pubkey['fingerprint'].
+ Each pubkey['fingerprint'] in the list is used to filter
+ and obtain the actual pubkey value from the GoalState
+ certificates XML.
+ Each pubkey['value'] requires no further processing and is
+ immediately added to the return list.
+ @return: A list of the VM user's authorized pubkey values.
+ """
ssh_keys = []
if goal_state.certificates_xml is not None and pubkey_info is not None:
LOG.debug('Certificate XML found; parsing out public keys.')
keys_by_fingerprint = self.openssl_manager.parse_certificates(
goal_state.certificates_xml)
ssh_keys = self._filter_pubkeys(keys_by_fingerprint, pubkey_info)
- self._report_ready(goal_state, http_client)
- return {'public-keys': ssh_keys}
+ return ssh_keys
- def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info):
- """cloud-init expects a straightforward array of keys to be dropped
- into the user's authorized_keys file. Azure control plane exposes
- multiple public keys to the VM via wireserver. Select just the
- user's key(s) and return them, ignoring any other certs.
+ @staticmethod
+ def _filter_pubkeys(keys_by_fingerprint, pubkey_info):
+ """ Filter and return only the user's actual pubkeys.
+
+ @param keys_by_fingerprint: pubkey fingerprint -> pubkey value dict
+ that was obtained from GoalState Certificates XML. May contain
+ non-user pubkeys.
+ @param pubkey_info: List of VM user pubkeys. Pubkey values are added
+ to the return list without further processing. Pubkey fingerprints
+ are used to filter and obtain the actual pubkey values from
+ keys_by_fingerprint.
+ @return: A list of the VM user's authorized pubkey values.
"""
keys = []
for pubkey in pubkey_info:
@@ -590,30 +825,6 @@ def _filter_pubkeys(self, keys_by_fingerprint, pubkey_info):
return keys
- @azure_ds_telemetry_reporter
- def _report_ready(self, goal_state, http_client):
- LOG.debug('Reporting ready to Azure fabric.')
- document = self.REPORT_READY_XML_TEMPLATE.format(
- incarnation=goal_state.incarnation,
- container_id=goal_state.container_id,
- instance_id=goal_state.instance_id,
- )
- # Host will collect kvps when cloud-init reports ready.
- # some kvps might still be in the queue. We yield the scheduler
- # to make sure we process all kvps up till this point.
- time.sleep(0)
- try:
- http_client.post(
- "http://{0}/machine?comp=health".format(self.endpoint),
- data=document,
- extra_headers={'Content-Type': 'text/xml; charset=utf-8'},
- )
- except Exception as e:
- report_diagnostic_event("exception while reporting ready: %s" % e)
- raise
-
- LOG.info('Reported ready to Azure fabric.')
-
@azure_ds_telemetry_reporter
def get_metadata_from_fabric(fallback_lease_file=None, dhcp_opts=None,
diff --git a/tests/unittests/test_datasource/test_azure_helper.py b/tests/unittests/test_datasource/test_azure_helper.py
index 71ef57f0c4d..1954b53336b 100644
--- a/tests/unittests/test_datasource/test_azure_helper.py
+++ b/tests/unittests/test_datasource/test_azure_helper.py
@@ -1,8 +1,10 @@
# This file is part of cloud-init. See LICENSE file for license information.
import os
+import re
import unittest
from textwrap import dedent
+from xml.etree import ElementTree
from cloudinit.sources.helpers import azure as azure_helper
from cloudinit.tests.helpers import CiTestCase, ExitStack, mock, populate_dir
@@ -48,6 +50,30 @@
"""
+HEALTH_REPORT_XML_TEMPLATE = '''\
+
+
+ {incarnation}
+
+ {container_id}
+
+
+ {instance_id}
+
+ {health_status}
+ {health_detail_subsection}
+
+
+
+
+
+'''
+
+
+class SentinelException(Exception):
+ pass
+
class TestFindEndpoint(CiTestCase):
@@ -140,9 +166,7 @@ class TestGoalStateParsing(CiTestCase):
'certificates_url': 'MyCertificatesUrl',
}
- def _get_goal_state(self, http_client=None, **kwargs):
- if http_client is None:
- http_client = mock.MagicMock()
+ def _get_formatted_goal_state_xml_string(self, **kwargs):
parameters = self.default_parameters.copy()
parameters.update(kwargs)
xml = GOAL_STATE_TEMPLATE.format(**parameters)
@@ -153,7 +177,13 @@ def _get_goal_state(self, http_client=None, **kwargs):
continue
new_xml_lines.append(line)
xml = '\n'.join(new_xml_lines)
- return azure_helper.GoalState(xml, http_client)
+ return xml
+
+ def _get_goal_state(self, azure_endpoint_client=None, **kwargs):
+ if azure_endpoint_client is None:
+ azure_endpoint_client = mock.MagicMock()
+ xml = self._get_formatted_goal_state_xml_string(**kwargs)
+ return azure_helper.GoalState(xml, azure_endpoint_client)
def test_incarnation_parsed_correctly(self):
incarnation = '123'
@@ -190,25 +220,61 @@ def test_instance_id_no_byte_swap_diff_instance_id(self):
azure_helper.is_byte_swapped(previous_iid, current_iid))
def test_certificates_xml_parsed_and_fetched_correctly(self):
- http_client = mock.MagicMock()
+ azure_endpoint_client = mock.MagicMock()
certificates_url = 'TestCertificatesUrl'
goal_state = self._get_goal_state(
- http_client=http_client, certificates_url=certificates_url)
+ azure_endpoint_client=azure_endpoint_client,
+ certificates_url=certificates_url)
certificates_xml = goal_state.certificates_xml
- self.assertEqual(1, http_client.get.call_count)
- self.assertEqual(certificates_url, http_client.get.call_args[0][0])
- self.assertTrue(http_client.get.call_args[1].get('secure', False))
- self.assertEqual(http_client.get.return_value.contents,
- certificates_xml)
+ self.assertEqual(1, azure_endpoint_client.get.call_count)
+ self.assertEqual(
+ certificates_url,
+ azure_endpoint_client.get.call_args[0][0])
+ self.assertTrue(
+ azure_endpoint_client.get.call_args[1].get(
+ 'secure', False))
+ self.assertEqual(
+ azure_endpoint_client.get.return_value.contents,
+ certificates_xml)
def test_missing_certificates_skips_http_get(self):
- http_client = mock.MagicMock()
+ azure_endpoint_client = mock.MagicMock()
goal_state = self._get_goal_state(
- http_client=http_client, certificates_url=None)
+ azure_endpoint_client=azure_endpoint_client, certificates_url=None)
certificates_xml = goal_state.certificates_xml
- self.assertEqual(0, http_client.get.call_count)
+ self.assertEqual(0, azure_endpoint_client.get.call_count)
self.assertIsNone(certificates_xml)
+ def test_invalid_goal_state_xml_raises_parse_error(self):
+ azure_endpoint_client = mock.MagicMock()
+ xml = 'random non-xml data'
+ with self.assertRaises(ElementTree.ParseError):
+ azure_helper.GoalState(xml, azure_endpoint_client)
+
+ def test_missing_container_id_in_goal_state_xml_raises_exc(self):
+ azure_endpoint_client = mock.MagicMock()
+ xml = self._get_formatted_goal_state_xml_string(
+ azure_endpoint_client=azure_endpoint_client)
+ xml = re.sub('.*', '', xml)
+ with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+ azure_helper.GoalState(xml, azure_endpoint_client)
+
+ def test_missing_instance_id_in_goal_state_xml_raises_exc(self):
+ azure_endpoint_client = mock.MagicMock()
+ xml = self._get_formatted_goal_state_xml_string(
+ azure_endpoint_client=azure_endpoint_client)
+ xml = re.sub('.*', '', xml)
+ with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+ azure_helper.GoalState(xml, azure_endpoint_client)
+
+ def test_missing_incarnation_in_goal_state_xml_raises_exc(self):
+ azure_endpoint_client = mock.MagicMock()
+ xml = self._get_formatted_goal_state_xml_string(
+ azure_endpoint_client=azure_endpoint_client)
+ xml = re.sub('.*', '', xml)
+ with self.assertRaises(azure_helper.InvalidGoalStateXMLException):
+ azure_helper.GoalState(xml, azure_endpoint_client)
+
class TestAzureEndpointHttpClient(CiTestCase):
@@ -222,19 +288,28 @@ def setUp(self):
patches = ExitStack()
self.addCleanup(patches.close)
- self.read_file_or_url = patches.enter_context(
- mock.patch.object(azure_helper.url_helper, 'read_file_or_url'))
+ self.readurl = patches.enter_context(
+ mock.patch.object(azure_helper.url_helper, 'readurl'))
+ patches.enter_context(
+ mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
def test_non_secure_get(self):
client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
url = 'MyTestUrl'
response = client.get(url, secure=False)
- self.assertEqual(1, self.read_file_or_url.call_count)
- self.assertEqual(self.read_file_or_url.return_value, response)
+ self.assertEqual(1, self.readurl.call_count)
+ self.assertEqual(self.readurl.return_value, response)
self.assertEqual(
- mock.call(url, headers=self.regular_headers, retries=10,
- timeout=5),
- self.read_file_or_url.call_args)
+ mock.call(url, headers=self.regular_headers,
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_non_secure_get_raises_exception(self):
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ url = 'MyTestUrl'
+ with self.assertRaises(SentinelException):
+ client.get(url, secure=False)
def test_secure_get(self):
url = 'MyTestUrl'
@@ -246,37 +321,62 @@ def test_secure_get(self):
})
client = azure_helper.AzureEndpointHttpClient(certificate)
response = client.get(url, secure=True)
- self.assertEqual(1, self.read_file_or_url.call_count)
- self.assertEqual(self.read_file_or_url.return_value, response)
+ self.assertEqual(1, self.readurl.call_count)
+ self.assertEqual(self.readurl.return_value, response)
self.assertEqual(
- mock.call(url, headers=expected_headers, retries=10,
- timeout=5),
- self.read_file_or_url.call_args)
+ mock.call(url, headers=expected_headers,
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_secure_get_raises_exception(self):
+ url = 'MyTestUrl'
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ with self.assertRaises(SentinelException):
+ client.get(url, secure=True)
def test_post(self):
data = mock.MagicMock()
url = 'MyTestUrl'
client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
response = client.post(url, data=data)
- self.assertEqual(1, self.read_file_or_url.call_count)
- self.assertEqual(self.read_file_or_url.return_value, response)
+ self.assertEqual(1, self.readurl.call_count)
+ self.assertEqual(self.readurl.return_value, response)
self.assertEqual(
- mock.call(url, data=data, headers=self.regular_headers, retries=10,
- timeout=5),
- self.read_file_or_url.call_args)
+ mock.call(url, data=data, headers=self.regular_headers,
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_post_raises_exception(self):
+ data = mock.MagicMock()
+ url = 'MyTestUrl'
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ with self.assertRaises(SentinelException):
+ client.post(url, data=data)
def test_post_with_extra_headers(self):
url = 'MyTestUrl'
client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
extra_headers = {'test': 'header'}
client.post(url, extra_headers=extra_headers)
- self.assertEqual(1, self.read_file_or_url.call_count)
expected_headers = self.regular_headers.copy()
expected_headers.update(extra_headers)
+ self.assertEqual(1, self.readurl.call_count)
self.assertEqual(
mock.call(mock.ANY, data=mock.ANY, headers=expected_headers,
- retries=10, timeout=5),
- self.read_file_or_url.call_args)
+ timeout=5, retries=10, sec_between=5),
+ self.readurl.call_args)
+
+ def test_post_with_sleep_with_extra_headers_raises_exception(self):
+ data = mock.MagicMock()
+ url = 'MyTestUrl'
+ extra_headers = {'test': 'header'}
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ self.readurl.side_effect = SentinelException
+ with self.assertRaises(SentinelException):
+ client.post(
+ url, data=data, extra_headers=extra_headers)
class TestOpenSSLManager(CiTestCase):
@@ -365,6 +465,116 @@ def test_parse_certificates(self, mock_decrypt_certs):
self.assertIn(fp, keys_by_fp)
+class TestGoalStateHealthReporter(CiTestCase):
+
+ default_parameters = {
+ 'incarnation': 1,
+ 'container_id': 'MyContainerId',
+ 'instance_id': 'MyInstanceId'
+ }
+
+ test_endpoint = 'TestEndpoint'
+ test_url = 'http://{0}/machine?comp=health'.format(test_endpoint)
+ test_default_headers = {'Content-Type': 'text/xml; charset=utf-8'}
+
+ provisioning_success_status = 'Ready'
+
+ def setUp(self):
+ super(TestGoalStateHealthReporter, self).setUp()
+ patches = ExitStack()
+ self.addCleanup(patches.close)
+
+ patches.enter_context(
+ mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
+ self.read_file_or_url = patches.enter_context(
+ mock.patch.object(azure_helper.url_helper, 'read_file_or_url'))
+
+ self.post = patches.enter_context(
+ mock.patch.object(azure_helper.AzureEndpointHttpClient,
+ 'post'))
+
+ self.GoalState = patches.enter_context(
+ mock.patch.object(azure_helper, 'GoalState'))
+ self.GoalState.return_value.container_id = \
+ self.default_parameters['container_id']
+ self.GoalState.return_value.instance_id = \
+ self.default_parameters['instance_id']
+ self.GoalState.return_value.incarnation = \
+ self.default_parameters['incarnation']
+
+ def _get_formatted_health_report_xml_string(self, **kwargs):
+ return HEALTH_REPORT_XML_TEMPLATE.format(**kwargs)
+
+ def _get_report_ready_health_document(self):
+ return self._get_formatted_health_report_xml_string(
+ incarnation=self.default_parameters['incarnation'],
+ container_id=self.default_parameters['container_id'],
+ instance_id=self.default_parameters['instance_id'],
+ health_status=self.provisioning_success_status,
+ health_detail_subsection='')
+
+ def test_send_ready_signal_sends_post_request(self):
+ health_document = self._get_report_ready_health_document()
+ client = azure_helper.AzureEndpointHttpClient(mock.MagicMock())
+ reporter = azure_helper.GoalStateHealthReporter(
+ azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+ client, self.test_endpoint)
+ reporter.send_ready_signal()
+ self.assertEqual(1, self.post.call_count)
+ self.assertEqual(
+ mock.call(
+ self.test_url,
+ data=health_document,
+ extra_headers=self.test_default_headers),
+ self.post.call_args)
+
+ def test_send_ready_signal_health_document(self):
+ health_document = self._get_report_ready_health_document()
+ reporter = azure_helper.GoalStateHealthReporter(
+ azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+ azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+ self.test_endpoint)
+ generated_health_document = reporter.build_report(
+ incarnation=self.default_parameters['incarnation'],
+ container_id=self.default_parameters['container_id'],
+ instance_id=self.default_parameters['instance_id'],
+ status=self.provisioning_success_status)
+ self.assertEqual(health_document, generated_health_document)
+ self.assertIn(str(self.default_parameters['incarnation']),
+ generated_health_document)
+ self.assertIn(self.default_parameters['container_id'],
+ generated_health_document)
+ self.assertIn(self.default_parameters['instance_id'],
+ generated_health_document)
+ self.assertIn(self.provisioning_success_status,
+ generated_health_document)
+ self.assertNotIn('', generated_health_document)
+ self.assertNotIn('', generated_health_document)
+ self.assertNotIn('', generated_health_document)
+
+ def test_send_ready_signal_calls_build_report(self):
+ patches = ExitStack()
+ self.addCleanup(patches.close)
+ build_report = patches.enter_context(
+ mock.patch.object(
+ azure_helper.GoalStateHealthReporter, 'build_report'))
+
+ reporter = azure_helper.GoalStateHealthReporter(
+ azure_helper.GoalState(mock.MagicMock(), mock.MagicMock()),
+ azure_helper.AzureEndpointHttpClient(mock.MagicMock()),
+ self.test_endpoint)
+ reporter.send_ready_signal()
+
+ self.assertEqual(1, build_report.call_count)
+ self.assertEqual(
+ mock.call(
+ incarnation=self.default_parameters['incarnation'],
+ container_id=self.default_parameters['container_id'],
+ instance_id=self.default_parameters['instance_id'],
+ status=self.provisioning_success_status),
+ build_report.call_args)
+
+
class TestWALinuxAgentShim(CiTestCase):
def setUp(self):
@@ -383,14 +593,21 @@ def setUp(self):
patches.enter_context(
mock.patch.object(azure_helper.time, 'sleep', mock.MagicMock()))
- def test_http_client_uses_certificate(self):
+ self.test_incarnation = 'TestIncarnation'
+ self.test_container_id = 'TestContainerId'
+ self.test_instance_id = 'TestInstanceId'
+ self.GoalState.return_value.incarnation = self.test_incarnation
+ self.GoalState.return_value.container_id = self.test_container_id
+ self.GoalState.return_value.instance_id = self.test_instance_id
+
+ def test_azure_endpoint_client_uses_certificate_during_report_ready(self):
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
self.assertEqual(
[mock.call(self.OpenSSLManager.return_value.certificate)],
self.AzureEndpointHttpClient.call_args_list)
- def test_correct_url_used_for_goalstate(self):
+ def test_correct_url_used_for_goalstate_during_report_ready(self):
self.find_endpoint.return_value = 'test_endpoint'
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
@@ -439,43 +656,67 @@ def test_correct_url_used_for_report_ready(self):
expected_url = 'http://test_endpoint/machine?comp=health'
self.assertEqual(
[mock.call(expected_url, data=mock.ANY, extra_headers=mock.ANY)],
- self.AzureEndpointHttpClient.return_value.post.call_args_list)
+ self.AzureEndpointHttpClient.return_value.post
+ .call_args_list)
def test_goal_state_values_used_for_report_ready(self):
- self.GoalState.return_value.incarnation = 'TestIncarnation'
- self.GoalState.return_value.container_id = 'TestContainerId'
- self.GoalState.return_value.instance_id = 'TestInstanceId'
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
posted_document = (
- self.AzureEndpointHttpClient.return_value.post.call_args[1]['data']
+ self.AzureEndpointHttpClient.return_value.post
+ .call_args[1]['data']
)
- self.assertIn('TestIncarnation', posted_document)
- self.assertIn('TestContainerId', posted_document)
- self.assertIn('TestInstanceId', posted_document)
+ self.assertIn(self.test_incarnation, posted_document)
+ self.assertIn(self.test_container_id, posted_document)
+ self.assertIn(self.test_instance_id, posted_document)
+
+ def test_xml_elems_in_report_ready(self):
+ shim = wa_shim()
+ shim.register_with_azure_and_fetch_data()
+ health_document = HEALTH_REPORT_XML_TEMPLATE.format(
+ incarnation=self.test_incarnation,
+ container_id=self.test_container_id,
+ instance_id=self.test_instance_id,
+ health_status='Ready',
+ health_detail_subsection='')
+ posted_document = (
+ self.AzureEndpointHttpClient.return_value.post
+ .call_args[1]['data'])
+ self.assertEqual(health_document, posted_document)
def test_clean_up_can_be_called_at_any_time(self):
shim = wa_shim()
shim.clean_up()
- def test_clean_up_will_clean_up_openssl_manager_if_instantiated(self):
+ def test_clean_up_after_report_ready(self):
shim = wa_shim()
shim.register_with_azure_and_fetch_data()
shim.clean_up()
self.assertEqual(
1, self.OpenSSLManager.return_value.clean_up.call_count)
- def test_failure_to_fetch_goalstate_bubbles_up(self):
- class SentinelException(Exception):
- pass
- self.AzureEndpointHttpClient.return_value.get.side_effect = (
- SentinelException)
+ def test_fetch_goalstate_during_report_ready_raises_exc_on_get_exc(self):
+ self.AzureEndpointHttpClient.return_value.get \
+ .side_effect = (SentinelException)
shim = wa_shim()
self.assertRaises(SentinelException,
shim.register_with_azure_and_fetch_data)
+ def test_fetch_goalstate_during_report_ready_raises_exc_on_parse_exc(self):
+ self.GoalState.side_effect = SentinelException
+ shim = wa_shim()
+ self.assertRaises(SentinelException,
+ shim.register_with_azure_and_fetch_data)
-class TestGetMetadataFromFabric(CiTestCase):
+ def test_failure_to_send_report_ready_health_doc_bubbles_up(self):
+ self.AzureEndpointHttpClient.return_value.post \
+ .side_effect = SentinelException
+ shim = wa_shim()
+ self.assertRaises(SentinelException,
+ shim.register_with_azure_and_fetch_data)
+
+
+class TestGetMetadataGoalStateXMLAndReportReadyToFabric(CiTestCase):
@mock.patch.object(azure_helper, 'WALinuxAgentShim')
def test_data_from_shim_returned(self, shim):
@@ -491,14 +732,39 @@ def test_success_calls_clean_up(self, shim):
@mock.patch.object(azure_helper, 'WALinuxAgentShim')
def test_failure_in_registration_calls_clean_up(self, shim):
- class SentinelException(Exception):
- pass
shim.return_value.register_with_azure_and_fetch_data.side_effect = (
SentinelException)
self.assertRaises(SentinelException,
azure_helper.get_metadata_from_fabric)
self.assertEqual(1, shim.return_value.clean_up.call_count)
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_calls_shim_register_with_azure_and_fetch_data(self, shim):
+ pubkey_info = mock.MagicMock()
+ azure_helper.get_metadata_from_fabric(pubkey_info=pubkey_info)
+ self.assertEqual(
+ 1,
+ shim.return_value
+ .register_with_azure_and_fetch_data.call_count)
+ self.assertEqual(
+ mock.call(pubkey_info=pubkey_info),
+ shim.return_value
+ .register_with_azure_and_fetch_data.call_args)
+
+ @mock.patch.object(azure_helper, 'WALinuxAgentShim')
+ def test_instantiates_shim_with_kwargs(self, shim):
+ fallback_lease_file = mock.MagicMock()
+ dhcp_options = mock.MagicMock()
+ azure_helper.get_metadata_from_fabric(
+ fallback_lease_file=fallback_lease_file,
+ dhcp_opts=dhcp_options)
+ self.assertEqual(1, shim.call_count)
+ self.assertEqual(
+ mock.call(
+ fallback_lease_file=fallback_lease_file,
+ dhcp_options=dhcp_options),
+ shim.call_args)
+
class TestExtractIpAddressFromNetworkd(CiTestCase):
diff --git a/tools/.github-cla-signers b/tools/.github-cla-signers
index e6690e5b3a8..9012156e4b3 100644
--- a/tools/.github-cla-signers
+++ b/tools/.github-cla-signers
@@ -2,6 +2,7 @@ beezly
bipinbachhao
candlerb
dhensby
+johnsonshi
lucasmoura
matthewruffell
nishigori