Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from pathlib import Path
from threading import Lock
from time import sleep
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import requests
from azure.mgmt.compute import ComputeManagementClient # type: ignore
from azure.mgmt.compute.models import VirtualMachine # type: ignore
from azure.mgmt.marketplaceordering import MarketplaceOrderingAgreements # type: ignore
from azure.mgmt.network import NetworkManagementClient # type: ignore
from azure.mgmt.network.models import ( # type: ignore
Expand Down Expand Up @@ -47,6 +48,7 @@
from dataclasses_json import dataclass_json
from marshmallow import validate
from PIL import Image, UnidentifiedImageError
from retry import retry

from lisa import schema
from lisa.environment import Environment, load_environments
Expand All @@ -58,6 +60,7 @@
LisaTimeoutException,
constants,
field_metadata,
get_matched_str,
strip_strs,
)
from lisa.util.logger import Logger
Expand All @@ -69,6 +72,10 @@

AZURE_SHARED_RG_NAME = "lisa_shared_resource"

# Patterns that pull a resource name out of an Azure resource id. The dot in
# "Microsoft.Network" is escaped so it matches only a literal dot, not any
# character. The capture group takes the rest of the line after the resource
# type segment.
PATTERN_NIC_NAME = re.compile(r"Microsoft\.Network/networkInterfaces/(.*)", re.M)
PATTERN_PUBLIC_IP_NAME = re.compile(
    r"providers/Microsoft\.Network/publicIPAddresses/(.*)", re.M
)

# When calling SDK APIs, it's easy to hit conflicts on the auth files. Use a
# lock to prevent that from happening.
Expand All @@ -88,6 +95,8 @@ class NodeContext:
username: str = ""
password: str = ""
private_key_file: str = ""
public_ip_address: str = ""
private_ip_address: str = ""


@dataclass_json()
Expand Down Expand Up @@ -1107,19 +1116,19 @@ def load_environment(
if environment_runbook.nodes_raw is None:
environment_runbook.nodes_raw = []

vms_map: Dict[str, VirtualMachine] = {}
compute_client = get_compute_client(platform)
vms = compute_client.virtual_machines.list(resource_group_name)
for vm in vms:
node_schema = schema.RemoteNode(name=vm.name)
environment_runbook.nodes_raw.append(node_schema)
vms_map[vm.name] = vm

environments = load_environments(
schema.EnvironmentRoot(environments=[environment_runbook])
)
environment = next(x for x in environments.values())

public_ips = platform.load_public_ips_from_resource_group(resource_group_name, log)

platform_runbook: schema.Platform = platform.runbook
for node in environment.nodes.list():
assert isinstance(node, RemoteNode)
Expand All @@ -1131,9 +1140,15 @@ def load_environment(
node_context.username = platform_runbook.admin_username
node_context.password = platform_runbook.admin_password
node_context.private_key_file = platform_runbook.admin_private_key_file

(
node_context.public_ip_address,
node_context.private_ip_address,
) = get_primary_ip_addresses(
platform, resource_group_name, vms_map[node_context.vm_name]
)
node.set_connection_info(
public_address=public_ips[node.name],
address=node_context.private_ip_address,
public_address=node_context.public_ip_address,
username=node_context.username,
password=node_context.password,
private_key_file=node_context.private_key_file,
Expand All @@ -1158,6 +1173,31 @@ def get_vm(platform: "AzurePlatform", node: Node) -> Any:
return vm


@retry(exceptions=LisaException, tries=150, delay=2)
def get_primary_ip_addresses(
    platform: "AzurePlatform", resource_group_name: str, vm: VirtualMachine
) -> Tuple[str, str]:
    """Look up the public and private IP addresses of a VM's primary NIC.

    Each NIC referenced by the VM's network profile is fetched from
    ``resource_group_name``. The first one flagged as primary supplies the
    result, which is read from its first IP configuration.

    Returns:
        A ``(public_ip_address, private_ip_address)`` tuple.

    Raises:
        LisaException: the primary NIC has no public IP address attached, or
            none of the VM's NICs is primary. The ``retry`` decorator is
            configured to retry on this exception (tries=150, delay=2).
    """
    client = get_network_client(platform)
    for nic_reference in vm.network_profile.network_interfaces:
        # Resolve the NIC name from its resource id, then fetch the NIC.
        name_of_nic = get_matched_str(nic_reference.id, PATTERN_NIC_NAME)
        nic = client.network_interfaces.get(resource_group_name, name_of_nic)
        if not nic.primary:
            continue

        ip_configuration = nic.ip_configurations[0]
        public_ip_reference = ip_configuration.public_ip_address
        if not public_ip_reference:
            raise LisaException(f"no public address found in nic {nic.name}")

        # The NIC only links to the public IP resource; query it by name to
        # read the address itself.
        name_of_public_ip = get_matched_str(
            public_ip_reference.id, PATTERN_PUBLIC_IP_NAME
        )
        public_ip = client.public_ip_addresses.get(
            resource_group_name,
            name_of_public_ip,
        )
        return public_ip.ip_address, ip_configuration.private_ip_address

    raise LisaException(f"fail to find primary nic for vm {vm.name}")


# locate an entry of the ARM template's resources section by its type name
def find_by_name(resources: Any, type_name: str) -> Any:
    """Return the first item of ``resources`` whose ``"type"`` is ``type_name``.

    Raises:
        StopIteration: no item has the requested type.
    """
    matches = filter(lambda item: item["type"] == type_name, resources)
    return next(matches)
Expand Down
6 changes: 5 additions & 1 deletion lisa/sut_orchestrator/azure/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
get_network_client,
get_node_context,
get_or_create_file_share,
get_primary_ip_addresses,
get_virtual_networks,
get_vm,
global_credential_access_lock,
Expand Down Expand Up @@ -120,7 +121,10 @@ def _start(self, wait: bool = True) -> None:
# the public ip address will change, so reload here
self._node = cast(RemoteNode, self._node)
platform: AzurePlatform = self._platform # type: ignore
public_ip = platform.load_public_ip(self._node, self._log)

public_ip, _ = get_primary_ip_addresses(
platform, self._resource_group_name, get_vm(platform, self._node)
)
node_info = self._node.connection_info
node_info[constants.ENVIRONMENTS_NODES_REMOTE_PUBLIC_ADDRESS] = public_ip
self._node.set_connection_info(**node_info)
Expand Down
145 changes: 24 additions & 121 deletions lisa/sut_orchestrator/azure/platform_.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
VirtualMachineImage,
)
from azure.mgmt.marketplaceordering.models import AgreementTerms # type: ignore
from azure.mgmt.network.models import NetworkInterface # type: ignore
from azure.mgmt.resource import SubscriptionClient # type: ignore
from azure.mgmt.resource.resources.models import ( # type: ignore
Deployment,
Expand Down Expand Up @@ -90,9 +89,9 @@
get_compute_client,
get_environment_context,
get_marketplace_ordering_client,
get_network_client,
get_node_context,
get_or_create_storage_container,
get_primary_ip_addresses,
get_resource_management_client,
get_storage_account_name,
get_storage_client,
Expand Down Expand Up @@ -133,8 +132,6 @@

# names in arm template, they should be changed with template together.
RESOURCE_ID_PORT_POSTFIX = "-ssh"
RESOURCE_ID_NIC_PATTERN = re.compile(r"(.+)-nic-0")
RESOURCE_ID_PUBLIC_IP_PATTERN = re.compile(r"(.+)-public-ip")

# Ubuntu 18.04:
# [ 0.000000] Hyper-V Host Build:18362-10.0-3-0.3198
Expand Down Expand Up @@ -1380,138 +1377,53 @@ def _parse_detail_errors(self, error: Any) -> List[str]:
# The VM may not be queryable right after deployment. Use retry to mitigate it.
@retry(exceptions=LisaException, tries=150, delay=2)
def _load_vms(
self, environment: Environment, log: Logger
self, resource_group_name: str, log: Logger
) -> Dict[str, VirtualMachine]:
compute_client = get_compute_client(self, api_version="2020-06-01")
environment_context = get_environment_context(environment=environment)

log.debug(
f"listing vm in resource group "
f"'{environment_context.resource_group_name}'"
)
log.debug(f"listing vm in resource group {resource_group_name}")
vms_map: Dict[str, VirtualMachine] = {}
vms = compute_client.virtual_machines.list(
environment_context.resource_group_name
)
vms = compute_client.virtual_machines.list(resource_group_name)
for vm in vms:
vms_map[vm.name] = vm
log.debug(f" found vm {vm.name}")
if not vms_map:
raise LisaException(
f"deployment succeeded, but VM not found in 5 minutes "
f"from '{environment_context.resource_group_name}'"
"deployment succeeded, but VM not found in 5 minutes "
f"from '{resource_group_name}'"
)
return vms_map

# Use Exception, because there may be credential conflict error. Make it
# retriable.
@retry(exceptions=Exception, tries=150, delay=2)
def _load_nics(
self, environment: Environment, log: Logger
) -> Dict[str, NetworkInterface]:
network_client = get_network_client(self)
environment_context = get_environment_context(environment=environment)
def initialize_environment(self, environment: Environment, log: Logger) -> None:
vms_map: Dict[str, VirtualMachine] = {}

log.debug(
f"listing network interfaces in resource group "
f"'{environment_context.resource_group_name}'"
)
# load nics
nics_map: Dict[str, NetworkInterface] = {}
network_interfaces = network_client.network_interfaces.list(
environment_context.resource_group_name
)
for nic in network_interfaces:
# nic name is like lisa-test-20220316-182126-985-e0-n0-nic-2, get vm
# name part for later pick only find primary nic, which is ended by
# -nic-0
node_name_from_nic = RESOURCE_ID_NIC_PATTERN.findall(nic.name)
Comment thread
squirrelsc marked this conversation as resolved.
if node_name_from_nic:
name = node_name_from_nic[0]
nics_map[name] = nic
log.debug(f" found nic '{nic.name}', and saved for next step.")
else:
log.debug(
f" found nic '{nic.name}', but dropped, "
"because it's not primary nic."
)
if not nics_map:
raise LisaException(
f"deployment succeeded, but network interfaces not found in 5 minutes "
f"from '{environment_context.resource_group_name}'"
)
return nics_map
environment_context = get_environment_context(environment=environment)
resource_group_name = environment_context.resource_group_name
vms_map = self._load_vms(resource_group_name, log)

@retry(exceptions=LisaException, tries=150, delay=2)
def load_public_ips_from_resource_group(
self, resource_group_name: str, log: Logger
) -> Dict[str, str]:
network_client = get_network_client(self)
log.debug(f"listing public ips in resource group '{resource_group_name}'")
# get public IP
public_ip_addresses = network_client.public_ip_addresses.list(
resource_group_name
)
public_ips_map: Dict[str, str] = {}
for ip_address in public_ip_addresses:
# nic name is like node-0-nic-2, get vm name part for later pick
# only find primary nic, which is ended by -nic-0
node_name_from_public_ip = RESOURCE_ID_PUBLIC_IP_PATTERN.findall(
ip_address.name
)
assert (
ip_address
), f"public IP address cannot be empty, ip_address object: {ip_address}"
if node_name_from_public_ip:
name = node_name_from_public_ip[0]
public_ips_map[name] = ip_address.ip_address
log.debug(
f" found public IP '{ip_address.name}', and saved for next step."
)
else:
log.debug(
f" found public IP '{ip_address.name}', but dropped "
"because it's not primary nic."
)
if not public_ips_map:
vms_name_list = list(vms_map.keys())
if len(vms_name_list) < len(environment.nodes):
raise LisaException(
f"deployment succeeded, but public ips not found in 5 minutes "
f"from '{resource_group_name}'"
f"{len(vms_name_list)} vms count is less than "
f"requirement count {len(environment.nodes)}"
)
return public_ips_map

def initialize_environment(self, environment: Environment, log: Logger) -> None:
node_context_map: Dict[str, Node] = {}
index = 0
for node in environment.nodes.list():
node_context = get_node_context(node)
node_context_map[node_context.vm_name] = node

vms_map: Dict[str, VirtualMachine] = self._load_vms(environment, log)
nics_map: Dict[str, NetworkInterface] = self._load_nics(environment, log)
environment_context = get_environment_context(environment=environment)
public_ips_map: Dict[str, str] = self.load_public_ips_from_resource_group(
environment_context.resource_group_name, log
)

for vm_name, node in node_context_map.items():
node_context = get_node_context(node)
vm = vms_map.get(vm_name, None)
if not vm:
raise LisaException(
f"cannot find vm: '{vm_name}', make sure deployment is correct."
)
nic = nics_map[vm_name]
public_ip = public_ips_map[vm_name]

address = nic.ip_configurations[0].private_ip_address
vm_name = vms_name_list[index]
node_context.vm_name = vm_name
if not node.name:
node.name = vm_name

public_address, private_address = get_primary_ip_addresses(
self, resource_group_name, vms_map[vm_name]
)
index = index + 1
assert isinstance(node, RemoteNode)
node.set_connection_info(
address=address,
address=private_address,
port=22,
public_address=public_ip,
public_address=public_address,
public_port=22,
username=node_context.username,
password=node_context.password,
Expand Down Expand Up @@ -1687,15 +1599,6 @@ def get_sorted_vm_sizes(
sorted_capabilities.extend(level_capabilities)
return sorted_capabilities

def load_public_ip(self, node: Node, log: Logger) -> str:
    """Return the public IP address recorded for the VM behind ``node``.

    The VM name and resource group come from the node's context. The public
    IPs of that resource group are loaded and the entry for the VM name is
    returned.

    Args:
        node: the node whose public IP address is wanted.
        log: logger passed on to the resource-group lookup. It was
            previously ignored in favor of ``self._log``.

    Raises:
        KeyError: the loaded map has no entry for the VM name.
    """
    node_context = get_node_context(node)
    vm_name = node_context.vm_name
    resource_group_name = node_context.resource_group_name
    public_ips_map: Dict[str, str] = self.load_public_ips_from_resource_group(
        resource_group_name=resource_group_name, log=log
    )
    return public_ips_map[vm_name]

@lru_cache(maxsize=10) # noqa: B019
def _resolve_marketplace_image(
self, location: str, marketplace: AzureVmMarketplaceSchema
Expand Down
9 changes: 4 additions & 5 deletions lisa/sut_orchestrator/azure/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
get_compute_client,
get_environment_context,
get_or_create_storage_container,
get_primary_ip_addresses,
get_storage_account_name,
get_vm,
load_environment,
Expand Down Expand Up @@ -232,12 +233,10 @@ def _get_public_ip_address(
self, platform: AzurePlatform, virtual_machine: Any
) -> str:
runbook: VhdTransformerSchema = self.runbook
public_ips = platform.load_public_ips_from_resource_group(
runbook.resource_group_name, self._log
)

public_ip_address: str = public_ips[runbook.vm_name]

public_ip_address, _ = get_primary_ip_addresses(
platform, runbook.resource_group_name, virtual_machine
)
assert (
public_ip_address
), "cannot find public IP address, make sure the VM is in running status."
Expand Down