From 94865db74dd1bf9904eb8431ea7262304e60efef Mon Sep 17 00:00:00 2001 From: Lili Deng Date: Fri, 13 Jan 2023 17:41:10 +0800 Subject: [PATCH] [improvement] support specify user customized env --- lisa/sut_orchestrator/azure/common.py | 50 ++++++- lisa/sut_orchestrator/azure/features.py | 6 +- lisa/sut_orchestrator/azure/platform_.py | 145 ++++---------------- lisa/sut_orchestrator/azure/transformers.py | 9 +- 4 files changed, 78 insertions(+), 132 deletions(-) diff --git a/lisa/sut_orchestrator/azure/common.py b/lisa/sut_orchestrator/azure/common.py index 47d21a07f4..3aa6d92ea7 100644 --- a/lisa/sut_orchestrator/azure/common.py +++ b/lisa/sut_orchestrator/azure/common.py @@ -8,10 +8,11 @@ from pathlib import Path from threading import Lock from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import requests from azure.mgmt.compute import ComputeManagementClient # type: ignore +from azure.mgmt.compute.models import VirtualMachine # type: ignore from azure.mgmt.marketplaceordering import MarketplaceOrderingAgreements # type: ignore from azure.mgmt.network import NetworkManagementClient # type: ignore from azure.mgmt.network.models import ( # type: ignore @@ -47,6 +48,7 @@ from dataclasses_json import dataclass_json from marshmallow import validate from PIL import Image, UnidentifiedImageError +from retry import retry from lisa import schema from lisa.environment import Environment, load_environments @@ -58,6 +60,7 @@ LisaTimeoutException, constants, field_metadata, + get_matched_str, strip_strs, ) from lisa.util.logger import Logger @@ -69,6 +72,10 @@ AZURE_SHARED_RG_NAME = "lisa_shared_resource" +PATTERN_NIC_NAME = re.compile(r"Microsoft.Network/networkInterfaces/(.*)", re.M) +PATTERN_PUBLIC_IP_NAME = re.compile( + r"providers/Microsoft.Network/publicIPAddresses/(.*)", re.M +) # when call sdk APIs, it's easy to have conflict on access auth files. Use lock # to prevent it happens. @@ -88,6 +95,8 @@ class NodeContext: username: str = "" password: str = "" private_key_file: str = "" + public_ip_address: str = "" + private_ip_address: str = "" @dataclass_json() @@ -1107,19 +1116,19 @@ def load_environment( if environment_runbook.nodes_raw is None: environment_runbook.nodes_raw = [] + vms_map: Dict[str, VirtualMachine] = {} compute_client = get_compute_client(platform) vms = compute_client.virtual_machines.list(resource_group_name) for vm in vms: node_schema = schema.RemoteNode(name=vm.name) environment_runbook.nodes_raw.append(node_schema) + vms_map[vm.name] = vm environments = load_environments( schema.EnvironmentRoot(environments=[environment_runbook]) ) environment = next(x for x in environments.values()) - public_ips = platform.load_public_ips_from_resource_group(resource_group_name, log) - platform_runbook: schema.Platform = platform.runbook for node in environment.nodes.list(): assert isinstance(node, RemoteNode) @@ -1131,9 +1140,15 @@ def load_environment( node_context.username = platform_runbook.admin_username node_context.password = platform_runbook.admin_password node_context.private_key_file = platform_runbook.admin_private_key_file - + ( + node_context.public_ip_address, + node_context.private_ip_address, + ) = get_primary_ip_addresses( + platform, resource_group_name, vms_map[node_context.vm_name] + ) node.set_connection_info( - public_address=public_ips[node.name], + address=node_context.private_ip_address, + public_address=node_context.public_ip_address, username=node_context.username, password=node_context.password, private_key_file=node_context.private_key_file, @@ -1158,6 +1173,31 @@ def get_vm(platform: "AzurePlatform", node: Node) -> Any: return vm +@retry(exceptions=LisaException, tries=150, delay=2) +def get_primary_ip_addresses( + platform: "AzurePlatform", resource_group_name: str, vm: VirtualMachine +) -> Tuple[str, str]: + network_client = get_network_client(platform) + for network_interface in vm.network_profile.network_interfaces: + nic_name = get_matched_str(network_interface.id, PATTERN_NIC_NAME) + nic = network_client.network_interfaces.get(resource_group_name, nic_name) + if nic.primary: + if not nic.ip_configurations[0].public_ip_address: + raise LisaException(f"no public address found in nic {nic.name}") + public_ip_name = get_matched_str( + nic.ip_configurations[0].public_ip_address.id, PATTERN_PUBLIC_IP_NAME + ) + public_ip_address = network_client.public_ip_addresses.get( + resource_group_name, + public_ip_name, + ) + return ( + public_ip_address.ip_address, + nic.ip_configurations[0].private_ip_address, + ) + raise LisaException(f"fail to find primary nic for vm {vm.name}") + + # find resource based on type name from resources section in arm template def find_by_name(resources: Any, type_name: str) -> Any: return next(x for x in resources if x["type"] == type_name) diff --git a/lisa/sut_orchestrator/azure/features.py b/lisa/sut_orchestrator/azure/features.py index 5ef74d3756..a796161579 100644 --- a/lisa/sut_orchestrator/azure/features.py +++ b/lisa/sut_orchestrator/azure/features.py @@ -81,6 +81,7 @@ get_network_client, get_node_context, get_or_create_file_share, + get_primary_ip_addresses, get_virtual_networks, get_vm, global_credential_access_lock, @@ -120,7 +121,10 @@ def _start(self, wait: bool = True) -> None: # the public ip address will change, so reload here self._node = cast(RemoteNode, self._node) platform: AzurePlatform = self._platform # type: ignore - public_ip = platform.load_public_ip(self._node, self._log) + + public_ip, _ = get_primary_ip_addresses( + platform, self._resource_group_name, get_vm(platform, self._node) + ) node_info = self._node.connection_info node_info[constants.ENVIRONMENTS_NODES_REMOTE_PUBLIC_ADDRESS] = public_ip self._node.set_connection_info(**node_info) diff --git a/lisa/sut_orchestrator/azure/platform_.py b/lisa/sut_orchestrator/azure/platform_.py index 9b2f7606de..96e4c0ec8e 100644 --- a/lisa/sut_orchestrator/azure/platform_.py +++ b/lisa/sut_orchestrator/azure/platform_.py @@ -31,7 +31,6 @@ VirtualMachineImage, ) from azure.mgmt.marketplaceordering.models import AgreementTerms # type: ignore -from azure.mgmt.network.models import NetworkInterface # type: ignore from azure.mgmt.resource import SubscriptionClient # type: ignore from azure.mgmt.resource.resources.models import ( # type: ignore Deployment, @@ -90,9 +89,9 @@ get_compute_client, get_environment_context, get_marketplace_ordering_client, - get_network_client, get_node_context, get_or_create_storage_container, + get_primary_ip_addresses, get_resource_management_client, get_storage_account_name, get_storage_client, @@ -133,8 +132,6 @@ # names in arm template, they should be changed with template together. RESOURCE_ID_PORT_POSTFIX = "-ssh" -RESOURCE_ID_NIC_PATTERN = re.compile(r"(.+)-nic-0") -RESOURCE_ID_PUBLIC_IP_PATTERN = re.compile(r"(.+)-public-ip") # Ubuntu 18.04: # [ 0.000000] Hyper-V Host Build:18362-10.0-3-0.3198 @@ -1380,138 +1377,53 @@ def _parse_detail_errors(self, error: Any) -> List[str]: # the VM may not be queried after deployed. use retry to mitigate it. @retry(exceptions=LisaException, tries=150, delay=2) def _load_vms( - self, environment: Environment, log: Logger + self, resource_group_name: str, log: Logger ) -> Dict[str, VirtualMachine]: compute_client = get_compute_client(self, api_version="2020-06-01") - environment_context = get_environment_context(environment=environment) - log.debug( - f"listing vm in resource group " - f"'{environment_context.resource_group_name}'" - ) + log.debug(f"listing vm in resource group {resource_group_name}") vms_map: Dict[str, VirtualMachine] = {} - vms = compute_client.virtual_machines.list( - environment_context.resource_group_name - ) + vms = compute_client.virtual_machines.list(resource_group_name) for vm in vms: vms_map[vm.name] = vm log.debug(f" found vm {vm.name}") if not vms_map: raise LisaException( - f"deployment succeeded, but VM not found in 5 minutes " - f"from '{environment_context.resource_group_name}'" + "deployment succeeded, but VM not found in 5 minutes " + f"from '{resource_group_name}'" ) return vms_map - # Use Exception, because there may be credential conflict error. Make it - # retriable. - @retry(exceptions=Exception, tries=150, delay=2) - def _load_nics( - self, environment: Environment, log: Logger - ) -> Dict[str, NetworkInterface]: - network_client = get_network_client(self) - environment_context = get_environment_context(environment=environment) + def initialize_environment(self, environment: Environment, log: Logger) -> None: + vms_map: Dict[str, VirtualMachine] = {} - log.debug( - f"listing network interfaces in resource group " - f"'{environment_context.resource_group_name}'" - ) - # load nics - nics_map: Dict[str, NetworkInterface] = {} - network_interfaces = network_client.network_interfaces.list( - environment_context.resource_group_name - ) - for nic in network_interfaces: - # nic name is like lisa-test-20220316-182126-985-e0-n0-nic-2, get vm - # name part for later pick only find primary nic, which is ended by - # -nic-0 - node_name_from_nic = RESOURCE_ID_NIC_PATTERN.findall(nic.name) - if node_name_from_nic: - name = node_name_from_nic[0] - nics_map[name] = nic - log.debug(f" found nic '{nic.name}', and saved for next step.") - else: - log.debug( - f" found nic '{nic.name}', but dropped, " - "because it's not primary nic." - ) - if not nics_map: - raise LisaException( - f"deployment succeeded, but network interfaces not found in 5 minutes " - f"from '{environment_context.resource_group_name}'" - ) - return nics_map + environment_context = get_environment_context(environment=environment) + resource_group_name = environment_context.resource_group_name + vms_map = self._load_vms(resource_group_name, log) - @retry(exceptions=LisaException, tries=150, delay=2) - def load_public_ips_from_resource_group( - self, resource_group_name: str, log: Logger - ) -> Dict[str, str]: - network_client = get_network_client(self) - log.debug(f"listing public ips in resource group '{resource_group_name}'") - # get public IP - public_ip_addresses = network_client.public_ip_addresses.list( - resource_group_name - ) - public_ips_map: Dict[str, str] = {} - for ip_address in public_ip_addresses: - # nic name is like node-0-nic-2, get vm name part for later pick - # only find primary nic, which is ended by -nic-0 - node_name_from_public_ip = RESOURCE_ID_PUBLIC_IP_PATTERN.findall( - ip_address.name - ) - assert ( - ip_address - ), f"public IP address cannot be empty, ip_address object: {ip_address}" - if node_name_from_public_ip: - name = node_name_from_public_ip[0] - public_ips_map[name] = ip_address.ip_address - log.debug( - f" found public IP '{ip_address.name}', and saved for next step." - ) - else: - log.debug( - f" found public IP '{ip_address.name}', but dropped " - "because it's not primary nic." - ) - if not public_ips_map: + vms_name_list = list(vms_map.keys()) + if len(vms_name_list) < len(environment.nodes): raise LisaException( - f"deployment succeeded, but public ips not found in 5 minutes " - f"from '{resource_group_name}'" + f"{len(vms_name_list)} vms count is less than " + f"requirement count {len(environment.nodes)}" ) - return public_ips_map - def initialize_environment(self, environment: Environment, log: Logger) -> None: - node_context_map: Dict[str, Node] = {} + index = 0 for node in environment.nodes.list(): node_context = get_node_context(node) - node_context_map[node_context.vm_name] = node - - vms_map: Dict[str, VirtualMachine] = self._load_vms(environment, log) - nics_map: Dict[str, NetworkInterface] = self._load_nics(environment, log) - environment_context = get_environment_context(environment=environment) - public_ips_map: Dict[str, str] = self.load_public_ips_from_resource_group( - environment_context.resource_group_name, log - ) - - for vm_name, node in node_context_map.items(): - node_context = get_node_context(node) - vm = vms_map.get(vm_name, None) - if not vm: - raise LisaException( - f"cannot find vm: '{vm_name}', make sure deployment is correct." - ) - nic = nics_map[vm_name] - public_ip = public_ips_map[vm_name] - - address = nic.ip_configurations[0].private_ip_address + vm_name = vms_name_list[index] + node_context.vm_name = vm_name if not node.name: node.name = vm_name - + public_address, private_address = get_primary_ip_addresses( + self, resource_group_name, vms_map[vm_name] + ) + index = index + 1 assert isinstance(node, RemoteNode) node.set_connection_info( - address=address, + address=private_address, port=22, - public_address=public_ip, + public_address=public_address, public_port=22, username=node_context.username, password=node_context.password, @@ -1687,15 +1599,6 @@ def get_sorted_vm_sizes( sorted_capabilities.extend(level_capabilities) return sorted_capabilities - def load_public_ip(self, node: Node, log: Logger) -> str: - node_context = get_node_context(node) - vm_name = node_context.vm_name - resource_group_name = node_context.resource_group_name - public_ips_map: Dict[str, str] = self.load_public_ips_from_resource_group( - resource_group_name=resource_group_name, log=self._log - ) - return public_ips_map[vm_name] - @lru_cache(maxsize=10) # noqa: B019 def _resolve_marketplace_image( self, location: str, marketplace: AzureVmMarketplaceSchema diff --git a/lisa/sut_orchestrator/azure/transformers.py b/lisa/sut_orchestrator/azure/transformers.py index 589ed3045f..081bbc3db7 100644 --- a/lisa/sut_orchestrator/azure/transformers.py +++ b/lisa/sut_orchestrator/azure/transformers.py @@ -30,6 +30,7 @@ get_compute_client, get_environment_context, get_or_create_storage_container, + get_primary_ip_addresses, get_storage_account_name, get_vm, load_environment, @@ -232,12 +233,10 @@ def _get_public_ip_address( self, platform: AzurePlatform, virtual_machine: Any ) -> str: runbook: VhdTransformerSchema = self.runbook - public_ips = platform.load_public_ips_from_resource_group( - runbook.resource_group_name, self._log - ) - - public_ip_address: str = public_ips[runbook.vm_name] + public_ip_address, _ = get_primary_ip_addresses( + platform, runbook.resource_group_name, virtual_machine + ) assert ( public_ip_address ), "cannot find public IP address, make sure the VM is in running status."