Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,35 @@
import java.util.Locale;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
* Provides basic common methods for all authentication providers
*/
public abstract class BaseAuthenticationProvider implements IAuthenticationProvider {
private static final HashSet<String> validGraphHostNames = new HashSet<>(Arrays.asList("graph.microsoft.com", "graph.microsoft.us", "dod-graph.microsoft.us", "graph.microsoft.de", "microsoftgraph.chinacloudapi.cn", "canary.graph.microsoft.com"));
private HashSet<String> customHosts;

/**
* Allow the user to add custom hosts by passing in Array
* @param customHosts custom hosts passed in by user.
*/
public void setCustomHosts(@Nonnull String[] customHosts) {
if(this.customHosts == null){
this.customHosts = new HashSet<String>();
}
for(String host: customHosts){
this.customHosts.add(host.toLowerCase(Locale.ROOT));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: According to this SO post, we should set this to Locale.ENGLISH in the very rare case we get to country specific endpoints (Turkey in the example) to avoid potential language specific user input error. With that stated, MG will likely define the host names in English so we should be safe even with Locale.ROOT.

No action required as I don't see this scenario happening, just thought I'd share my learning.

}
}
/**
* Get the custom hosts set by user.
* @return the custom hosts set by user.
*/
@Nullable
public String[] getCustomHosts(){
return customHosts.toArray(new String[customHosts.size()]);
}
/**
* Determines whether a request should be authenticated or not based on it's url.
* If you're implementing a custom provider, call that method first before getting the token
Expand All @@ -22,6 +45,6 @@ protected boolean shouldAuthenticateRequestWithUrl(@Nonnull final URL requestUrl
if (requestUrl == null || !requestUrl.getProtocol().toLowerCase(Locale.ROOT).equals("https"))
return false;
final String hostName = requestUrl.getHost().toLowerCase(Locale.ROOT);
return validGraphHostNames.contains(hostName);
return customHosts == null ? (validGraphHostNames.contains(hostName)) : (customHosts.contains(hostName));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class TokenCredentialAuthProvider extends BaseAuthenticationProvider {
private final TokenRequestContext context;
/** Default scope to use when no scopes are provided */
private static final String DEFAULT_GRAPH_SCOPE = "https://graph.microsoft.com/.default";

/**
* Creates an Authentication provider using a passed in TokenCredential
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package com.microsoft.graph.authentication;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.Arrays;
Expand All @@ -11,6 +8,8 @@

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.*;

public class BaseAuthenticationProviderTest {
final BaseAuthenticationProvider authProvider = new BaseAuthenticationProvider() {

Expand Down Expand Up @@ -67,4 +66,18 @@ public void providerDoesNotAddTokenOnNullUrls() throws MalformedURLException {
//Assert
assertFalse(result);
}
@Test
public void providerAddsTokenToCustomHosts() throws MalformedURLException {
//Arrange
final URL url = new URL("https://localhost.com");
authProvider.setCustomHosts(new String[]{"localHost.com"});

//Act
final boolean result = authProvider.shouldAuthenticateRequestWithUrl(url);

//Assert
assertTrue(result);
assertEquals(authProvider.getCustomHosts().length, 1);
assertEquals(authProvider.getCustomHosts()[0], "localhost.com");
}
}