diff --git a/src/main/java/com/microsoft/graph/authentication/BaseAuthenticationProvider.java b/src/main/java/com/microsoft/graph/authentication/BaseAuthenticationProvider.java index 78a596a31..d2fb35a9f 100644 --- a/src/main/java/com/microsoft/graph/authentication/BaseAuthenticationProvider.java +++ b/src/main/java/com/microsoft/graph/authentication/BaseAuthenticationProvider.java @@ -6,12 +6,35 @@ import java.util.Locale; import javax.annotation.Nonnull; +import javax.annotation.Nullable; /** * Provides basic common methods for all authentication providers */ public abstract class BaseAuthenticationProvider implements IAuthenticationProvider { private static final HashSet validGraphHostNames = new HashSet<>(Arrays.asList("graph.microsoft.com", "graph.microsoft.us", "dod-graph.microsoft.us", "graph.microsoft.de", "microsoftgraph.chinacloudapi.cn", "canary.graph.microsoft.com")); + private HashSet customHosts; + + /** + * Allow the user to add custom hosts by passing in Array + * @param customHosts custom hosts passed in by user. + */ + public void setCustomHosts(@Nonnull String[] customHosts) { + if(this.customHosts == null){ + this.customHosts = new HashSet(); + } + for(String host: customHosts){ + this.customHosts.add(host.toLowerCase(Locale.ROOT)); + } + } + /** + * Get the custom hosts set by user. + * @return the custom hosts set by user. + */ + @Nullable + public String[] getCustomHosts(){ + return customHosts.toArray(new String[customHosts.size()]); + } /** * Determines whether a request should be authenticated or not based on it's url. * If you're implementing a custom provider, call that method first before getting the token @@ -22,6 +45,6 @@ protected boolean shouldAuthenticateRequestWithUrl(@Nonnull final URL requestUrl if (requestUrl == null || !requestUrl.getProtocol().toLowerCase(Locale.ROOT).equals("https")) return false; final String hostName = requestUrl.getHost().toLowerCase(Locale.ROOT); - return validGraphHostNames.contains(hostName); + return customHosts == null ? (validGraphHostNames.contains(hostName)) : (customHosts.contains(hostName)); } } diff --git a/src/main/java/com/microsoft/graph/authentication/TokenCredentialAuthProvider.java b/src/main/java/com/microsoft/graph/authentication/TokenCredentialAuthProvider.java index 555dc34b2..de9a5d9b3 100644 --- a/src/main/java/com/microsoft/graph/authentication/TokenCredentialAuthProvider.java +++ b/src/main/java/com/microsoft/graph/authentication/TokenCredentialAuthProvider.java @@ -21,6 +21,7 @@ public class TokenCredentialAuthProvider extends BaseAuthenticationProvider { private final TokenRequestContext context; /** Default scope to use when no scopes are provided */ private static final String DEFAULT_GRAPH_SCOPE = "https://graph.microsoft.com/.default"; + /** * Creates an Authentication provider using a passed in TokenCredential * diff --git a/src/test/java/com/microsoft/graph/authentication/BaseAuthenticationProviderTest.java b/src/test/java/com/microsoft/graph/authentication/BaseAuthenticationProviderTest.java index 47630890c..48f6f7d98 100644 --- a/src/test/java/com/microsoft/graph/authentication/BaseAuthenticationProviderTest.java +++ b/src/test/java/com/microsoft/graph/authentication/BaseAuthenticationProviderTest.java @@ -1,8 +1,5 @@ package com.microsoft.graph.authentication; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; - import java.net.MalformedURLException; import java.net.URL; import java.util.Arrays; @@ -11,6 +8,8 @@ import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + public class BaseAuthenticationProviderTest { final BaseAuthenticationProvider authProvider = new BaseAuthenticationProvider() { @@ -67,4 +66,18 @@ public void providerDoesNotAddTokenOnNullUrls() throws MalformedURLException { //Assert assertFalse(result); } + @Test + public void providerAddsTokenToCustomHosts() throws MalformedURLException { + //Arrange + final URL url = new URL("https://localhost.com"); + authProvider.setCustomHosts(new String[]{"localHost.com"}); + + //Act + final boolean result = authProvider.shouldAuthenticateRequestWithUrl(url); + + //Assert + assertTrue(result); + assertEquals(authProvider.getCustomHosts().length, 1); + assertEquals(authProvider.getCustomHosts()[0], "localhost.com"); + } }