diff --git a/openid/src/main/java/com/inrupt/client/openid/OpenIdSession.java b/openid/src/main/java/com/inrupt/client/openid/OpenIdSession.java index de84055fe82..11180fd3532 100644 --- a/openid/src/main/java/com/inrupt/client/openid/OpenIdSession.java +++ b/openid/src/main/java/com/inrupt/client/openid/OpenIdSession.java @@ -40,6 +40,8 @@ import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinTask; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; @@ -54,6 +56,8 @@ import org.jose4j.jwt.consumer.JwtConsumerBuilder; import org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver; import org.jose4j.keys.resolvers.VerificationKeyResolver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A session implementation for use with OpenID Connect ID Tokens. @@ -61,12 +65,15 @@ */ public final class OpenIdSession implements Session { + private static final Logger LOGGER = LoggerFactory.getLogger(OpenIdSession.class); + public static final URI ID_TOKEN = URI.create("http://openid.net/specs/openid-connect-core-1_0.html#IDToken"); private final String id; private final Set schemes; private final Supplier> authenticator; private final AtomicReference credential = new AtomicReference<>(); + private final ForkJoinPool executor = new ForkJoinPool(1); private final DPoP dpop; private OpenIdSession(final String id, final DPoP dpop, @@ -182,15 +189,11 @@ public Set supportedSchemes() { @Override public Optional getCredential(final URI name, final URI uri) { if (ID_TOKEN.equals(name)) { - final Credential c = credential.get(); - if (!hasExpired(c)) { - return Optional.of(c); - } - final Credential freshCredential = authenticator.get().toCompletableFuture().join(); - if (!hasExpired(freshCredential)) { - credential.set(freshCredential); - return Optional.of(freshCredential); + final Credential cred = credential.get(); + if (!hasExpired(cred)) { + return Optional.of(cred); } + return Optional.ofNullable(executor.invoke(ForkJoinTask.adapt(this::synchronizedFetch))); } return Optional.empty(); } @@ -222,7 +225,7 @@ public Optional fromCache(final Request request) { @Override public CompletionStage> authenticate(final Request request, final Set algorithms) { - return authenticator.get().thenApply(Optional::ofNullable); + return CompletableFuture.completedFuture(getCredential(ID_TOKEN, null)); } boolean hasExpired(final Credential credential) { @@ -232,6 +235,22 @@ boolean hasExpired(final Credential credential) { return true; } + private synchronized Credential synchronizedFetch() { + // Check again inside the synchronized method + final Credential cred = credential.get(); + if (!hasExpired(cred)) { + return cred; + } + + // Fetch the refreshed credentials + final Credential refreshed = authenticator.get().toCompletableFuture().join(); + if (!hasExpired(refreshed)) { + credential.set(refreshed); + return refreshed; + } + return null; + } + static String getSessionIdentifier(final JwtClaims claims) { final String webid = claims.getClaimValueAsString("webid"); if (webid != null) {