Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import labapi.AzureEnvironment;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.io.IOException;
Expand Down Expand Up @@ -118,13 +119,18 @@ public void acquireTokenClientCredentials_DefaultCacheLookup() throws Exception
Assert.assertNotEquals(result2.accessToken(), result3.accessToken());
}

@Test
public void acquireTokenClientCredentials_Regional() throws Exception {
@DataProvider(name = "regionWithAuthority")
public static Object[][] createData() {
return new Object[][]{{"westus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS},
{"eastus", TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS}};
}

@Test(dataProvider = "regionWithAuthority")
public void acquireTokenClientCredentials_Regional(String[] regionWithAuthority) throws Exception {
String clientId = "2afb0add-2f32-4946-ac90-81a02aa4550e";

assertAcquireTokenCommon_withRegion(clientId, certificate);
assertAcquireTokenCommon_withRegion(clientId, certificate, regionWithAuthority[0], regionWithAuthority[1]);
}

private ClientAssertion getClientAssertion(String clientId) {
return JwtHelper.buildJwt(
clientId,
Expand Down Expand Up @@ -164,15 +170,15 @@ private void assertAcquireTokenCommon_withParameters(String clientId, IClientCre
Assert.assertNotNull(result.accessToken());
}

private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential) throws Exception {
private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredential credential, String region, String regionalAuthority) throws Exception {
ConfidentialClientApplication ccaNoRegion = ConfidentialClientApplication.builder(
clientId, credential).
authority(TestConstants.MICROSOFT_AUTHORITY).
build();

ConfidentialClientApplication ccaRegion = ConfidentialClientApplication.builder(
clientId, credential).
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion("westus").
authority("https://login.microsoft.com/microsoft.onmicrosoft.com").azureRegion(region).
build();

//Ensure behavior when region not specified
Expand All @@ -193,7 +199,7 @@ private void assertAcquireTokenCommon_withRegion(String clientId, IClientCredent

Assert.assertNotNull(resultRegion);
Assert.assertNotNull(resultRegion.accessToken());
Assert.assertEquals(resultRegion.environment(), TestConstants.REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS);
Assert.assertEquals(resultRegion.environment(), regionalAuthority);

IAuthenticationResult resultRegionCached = ccaRegion.acquireToken(ClientCredentialParameters
.builder(Collections.singleton(KEYVAULT_DEFAULT_SCOPE))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ public class TestConstants {
public final static String TENANT_SPECIFIC_AUTHORITY = MICROSOFT_AUTHORITY_HOST + MICROSOFT_AUTHORITY_TENANT;
public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_WESTUS = "westus.login.microsoft.com";

public final static String REGIONAL_MICROSOFT_AUTHORITY_BASIC_HOST_EASTUS = "eastus.login.microsoft.com";

public final static String ARLINGTON_ORGANIZATIONS_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "organizations/";
public final static String ARLINGTON_COMMON_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + "common/";
public final static String ARLINGTON_TENANT_SPECIFIC_AUTHORITY = ARLINGTON_MICROSOFT_AUTHORITY_HOST + ARLINGTON_AUTHORITY_TENANT;
public final static String ARLINGTON_GRAPH_DEFAULT_SCOPE = "https://graph.microsoft.us/.default";


public final static String B2C_AUTHORITY = "https://msidlabb2c.b2clogin.com/msidlabb2c.onmicrosoft.com/";
public final static String B2C_AUTHORITY_LEGACY_FORMAT = "https://msidlabb2c.b2clogin.com/tfp/msidlabb2c.onmicrosoft.com/";

public final static String B2C_ROPC_POLICY = "B2C_1_ROPC_Auth";
public final static String B2C_SIGN_IN_POLICY = "B2C_1_SignInPolicy";
public final static String B2C_AUTHORITY_SIGN_IN = B2C_AUTHORITY + B2C_SIGN_IN_POLICY;
Expand All @@ -49,19 +50,13 @@ public class TestConstants {
public final static String B2C_MICROSOFTLOGIN_ROPC = B2C_MICROSOFTLOGIN_AUTHORITY + B2C_ROPC_POLICY;

public final static String LOCALHOST = "http://localhost:";
public final static String LOCAL_FLAG_ENV_VAR = "MSAL_JAVA_RUN_LOCAL";

public final static String ADFS_AUTHORITY = "https://fs.msidlab8.com/adfs/";
public final static String ADFS_SCOPE = USER_READ_SCOPE;
public final static String ADFS_APP_ID = "PublicClientId";

public final static String CLAIMS = "{\"id_token\":{\"auth_time\":{\"essential\":true}}}";
public final static Set<String> CLIENT_CAPABILITIES_EMPTY = new HashSet<>(Collections.emptySet());
public final static Set<String> CLIENT_CAPABILITIES_LLT = new HashSet<>(Collections.singletonList("llt"));

// cross cloud b2b settings
public final static String AUTHORITY_ARLINGTON = "https://login.microsoftonline.us/" + ARLINGTON_AUTHORITY_TENANT;
public final static String AUTHORITY_MOONCAKE = "https://login.chinacloudapi.cn/mncmsidlab1.partner.onmschina.cn";
public final static String AUTHORITY_PUBLIC_TENANT_SPECIFIC = "https://login.microsoftonline.com/" + MICROSOFT_AUTHORITY_TENANT;

public final static String DEFAULT_ACCESS_TOKEN = "defaultAccessToken";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import java.util.TreeSet;
import java.util.Map;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.*;

class AadInstanceDiscoveryProvider {

Expand All @@ -31,6 +31,8 @@ class AadInstanceDiscoveryProvider {
private static final String DEFAULT_API_VERSION = "2020-06-01";
private static final String IMDS_ENDPOINT = "https://169.254.169.254/metadata/instance/compute/location?" + DEFAULT_API_VERSION + "&format=text";

private static final int IMDS_TIMEOUT = 2;
private static final TimeUnit IMDS_TIMEOUT_UNIT = TimeUnit.SECONDS;
static final TreeSet<String> TRUSTED_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
static final TreeSet<String> TRUSTED_SOVEREIGN_HOSTS_SET = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);

Expand Down Expand Up @@ -71,8 +73,8 @@ static InstanceDiscoveryMetadataEntry getMetadataEntry(URL authorityUrl,
//If region autodetection is enabled and a specific region not already set,
// set the application's region to the discovered region so that future requests can skip the IMDS endpoint call
if (null == msalRequest.application().azureRegion() && msalRequest.application().autoDetectRegion()
&& null != detectedRegion) {
msalRequest.application().azureRegion = detectedRegion;
&& null != detectedRegion) {
msalRequest.application().azureRegion = detectedRegion;
}
cacheRegionInstanceMetadata(authorityUrl.getHost(), msalRequest.application().azureRegion());
serviceBundle.getServerSideTelemetry().getCurrentRequest().regionOutcome(
Expand Down Expand Up @@ -291,33 +293,39 @@ private static String discoverRegion(MsalRequest msalRequest, ServiceBundle serv
return System.getenv(REGION_NAME);
}

try {
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
Map<String, String> headers = new HashMap<>();
headers.put("Metadata", "true");
IHttpResponse httpResponse = executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle);
//Check the IMDS endpoint to retrieve current region (will only work if application is running in an Azure VM)
Map<String, String> headers = new HashMap<>();
headers.put("Metadata", "true");

ExecutorService executor = Executors.newSingleThreadExecutor();
Future<IHttpResponse> future = executor.submit(() -> executeRequest(IMDS_ENDPOINT, headers, msalRequest, serviceBundle));

try {
log.info("Starting call to IMDS endpoint.");
IHttpResponse httpResponse = future.get(IMDS_TIMEOUT, IMDS_TIMEOUT_UNIT);
//If call to IMDS endpoint was successful, return region from response body
if (httpResponse.statusCode() == HttpHelper.HTTP_STATUS_200 && !httpResponse.body().isEmpty()) {
log.info("Region retrieved from IMDS endpoint: " + httpResponse.body());
log.info(String.format("Region retrieved from IMDS endpoint: %s", httpResponse.body()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_IMDS.telemetryValue);

return httpResponse.body();
}

log.warn(String.format("Call to local IMDS failed with status code: %s, or response was empty", httpResponse.statusCode()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);

return null;
} catch (Exception e) {
} catch (Exception ex) {
// handle other exceptions
//IMDS call failed, cannot find region
//The IMDS endpoint is only available from within an Azure environment, so the most common cause of this
// exception will likely be java.net.SocketException: Network is unreachable: connect
log.warn(String.format("Exception during call to local IMDS endpoint: %s", e.getMessage()));
log.warn(String.format("Exception during call to local IMDS endpoint: %s", ex.getMessage()));
currentRequest.regionSource(RegionTelemetry.REGION_SOURCE_FAILED_AUTODETECT.telemetryValue);
future.cancel(true);

return null;
} finally {
executor.shutdownNow();
}

return null;
}

private static void doInstanceDiscoveryAndCache(URL authorityUrl,
Expand Down