Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions cloudinit/sources/DataSourceGCE.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import datetime
import json
from contextlib import suppress as noop

from base64 import b64decode

Expand All @@ -13,6 +14,7 @@
from cloudinit import sources
from cloudinit import url_helper
from cloudinit import util
from cloudinit.net.dhcp import EphemeralDHCPv4

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,6 +60,7 @@ def get_value(self, path, is_text, is_recursive=False):
class DataSourceGCE(sources.DataSource):

dsname = 'GCE'
perform_dhcp_setup = False

def __init__(self, sys_cfg, distro, paths):
sources.DataSource.__init__(self, sys_cfg, distro, paths)
Expand All @@ -73,10 +76,19 @@ def __init__(self, sys_cfg, distro, paths):

def _get_data(self):
url_params = self.get_url_params()
ret = util.log_time(
LOG.debug, 'Crawl of GCE metadata service',
read_md, kwargs={'address': self.metadata_address,
'url_params': url_params})
network_context = noop()
if self.perform_dhcp_setup:
network_context = EphemeralDHCPv4(self.fallback_interface)
Comment thread
TheRealFalcon marked this conversation as resolved.
with network_context:
ret = util.log_time(
LOG.debug,
"Crawl of GCE metadata service",
read_md,
kwargs={
"address": self.metadata_address,
"url_params": url_params,
},
)

if not ret['success']:
if ret['platform_reports_gce']:
Expand Down Expand Up @@ -117,6 +129,10 @@ def region(self):
return self.availability_zone.rsplit('-', 1)[0]


class DataSourceGCELocal(DataSourceGCE):
perform_dhcp_setup = True


def _write_host_key_to_guest_attributes(key_type, key_value):
url = '%s/%s/%s' % (GUEST_ATTRIBUTES_URL, HOSTKEY_NAMESPACE, key_type)
key_value = key_value.encode('utf-8')
Expand Down Expand Up @@ -272,6 +288,7 @@ def platform_reports_gce():

# Used to match classes to dependencies.
datasources = [
(DataSourceGCELocal, (sources.DEP_FILESYSTEM,)),
(DataSourceGCE, (sources.DEP_FILESYSTEM, sources.DEP_NETWORK)),
]

Expand Down
41 changes: 41 additions & 0 deletions tests/integration_tests/modules/test_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,31 @@ def test_no_problems(self, class_client: IntegrationInstance):
log = client.read_from_file('/var/log/cloud-init.log')
verify_clean_log(log)

def test_correct_datasource_detected(
self, class_client: IntegrationInstance
):
"""Test datasource is detected at the proper boot stage."""
client = class_client
status_file = client.read_from_file("/run/cloud-init/status.json")

platform_datasources = {
"azure": "DataSourceAzure [seed=/dev/sr0]",
"ec2": "DataSourceEc2Local",
"gce": "DataSourceGCELocal",
"oci": "DataSourceOracle",
"openstack": "DataSourceOpenStackLocal [net,ver=2]",
"lxd_container": (
"DataSourceNoCloud "
"[seed=/var/lib/cloud/seed/nocloud-net][dsmode=net]"
),
"lxd_vm": "DataSourceNoCloud [seed=/dev/sr0][dsmode=net]",
}

assert (
platform_datasources[client.settings.PLATFORM]
== json.loads(status_file)["v1"]["datasource"]
)

def _check_common_metadata(self, data):
assert data['base64_encoded_keys'] == []
assert data['merged_cfg'] == 'redacted for non-root user'
Expand Down Expand Up @@ -277,3 +302,19 @@ def test_instance_json_ec2(self, class_client: IntegrationInstance):
assert v1_data['instance_id'] == client.instance.name
assert v1_data['local_hostname'].startswith('ip-')
assert v1_data['region'] == client.cloud.cloud_instance.region

@pytest.mark.gce
def test_instance_json_gce(self, class_client: IntegrationInstance):
client = class_client
instance_json_file = client.read_from_file(
"/run/cloud-init/instance-data.json"
)
data = json.loads(instance_json_file)
self._check_common_metadata(data)
v1_data = data["v1"]
assert v1_data["cloud_name"] == "gce"
assert v1_data["platform"] == "gce"
assert v1_data["subplatform"].startswith("metadata")
assert v1_data["availability_zone"] == client.instance.zone
assert v1_data["instance_id"] == client.instance.instance_id
assert v1_data["local_hostname"] == client.instance.name
1 change: 1 addition & 0 deletions tests/unittests/test_datasource/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
CloudSigma.DataSourceCloudSigma,
ConfigDrive.DataSourceConfigDrive,
DigitalOcean.DataSourceDigitalOcean,
GCE.DataSourceGCELocal,
Hetzner.DataSourceHetzner,
IBMCloud.DataSourceIBMCloud,
LXD.DataSourceLXD,
Expand Down
24 changes: 24 additions & 0 deletions tests/unittests/test_datasource/test_gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,5 +360,29 @@ def test_publish_host_keys(self, m_readurl):
self.ds.publish_host_keys(hostkeys)
m_readurl.assert_has_calls(readurl_expected_calls, any_order=True)

@mock.patch(
"cloudinit.sources.DataSourceGCE.EphemeralDHCPv4",
autospec=True,
)
@mock.patch(
"cloudinit.sources.DataSourceGCE.DataSourceGCELocal.fallback_interface"
)
def test_local_datasource_uses_ephemeral_dhcp(self, _m_fallback, m_dhcp):
_set_mock_metadata()
ds = DataSourceGCE.DataSourceGCELocal(
sys_cfg={}, distro=None, paths=None
)
ds._get_data()
assert m_dhcp.call_count == 1

@mock.patch(
"cloudinit.sources.DataSourceGCE.EphemeralDHCPv4",
autospec=True,
)
def test_datasource_doesnt_use_ephemeral_dhcp(self, m_dhcp):
_set_mock_metadata()
ds = DataSourceGCE.DataSourceGCE(sys_cfg={}, distro=None, paths=None)
ds._get_data()
assert m_dhcp.call_count == 0

# vi: ts=4 expandtab