Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
*/
package org.apache.pulsar.broker.service;

import static org.apache.bookkeeper.util.SafeRunnable.safeRun;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
Expand All @@ -30,8 +27,6 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -57,15 +52,6 @@ public class PulsarChannelInitializer extends ChannelInitializer<SocketChannel>
private final ServiceConfiguration brokerConf;
private NettySSLContextAutoRefreshBuilder nettySSLContextAutoRefreshBuilder;

// This cache is used to maintain a list of active connections to iterate over them
// We keep weak references to have the cache to be auto cleaned up when the connections
// objects are GCed.
@VisibleForTesting
protected final Cache<SocketAddress, ServerCnx> connections = Caffeine.newBuilder()
.weakKeys()
.weakValues()
.build();

/**
* @param pulsar
* An instance of {@link PulsarService}
Expand Down Expand Up @@ -114,10 +100,6 @@ public PulsarChannelInitializer(PulsarService pulsar, PulsarChannelOptions opts)
this.sslCtxRefresher = null;
}
this.brokerConf = pulsar.getConfiguration();

pulsar.getExecutor().scheduleAtFixedRate(safeRun(this::refreshAuthenticationCredentials),
pulsar.getConfig().getAuthenticationRefreshCheckSeconds(),
pulsar.getConfig().getAuthenticationRefreshCheckSeconds(), TimeUnit.SECONDS);
}

@Override
Expand Down Expand Up @@ -148,18 +130,6 @@ protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast("flowController", new FlowControlHandler());
ServerCnx cnx = newServerCnx(pulsar, listenerName);
ch.pipeline().addLast("handler", cnx);

connections.put(ch.remoteAddress(), cnx);
}

private void refreshAuthenticationCredentials() {
connections.asMap().values().forEach(cnx -> {
try {
cnx.refreshAuthenticationCredentials();
} catch (Throwable t) {
log.warn("[{}] Failed to refresh auth credentials", cnx.clientAddress());
}
});
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.FastThreadLocal;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import io.prometheus.client.Gauge;
import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -199,6 +200,7 @@ public class ServerCnx extends PulsarHandler implements TransportCnx {
// Keep temporarily in order to verify after verifying proxy's authData
private AuthData originalAuthDataCopy;
private boolean pendingAuthChallengeResponse = false;
private ScheduledFuture<?> authRefreshTask;

// Max number of pending requests per connections. If multiple producers are sharing the same connection the flow
// control done by a single producer might not be enough to prevent write spikes on the broker.
Expand Down Expand Up @@ -332,6 +334,9 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}

cnxsPerThread.get().remove(this);
if (authRefreshTask != null) {
authRefreshTask.cancel(false);
}

// Connection is gone, close the producers immediately
producers.forEach((__, producerFuture) -> {
Expand Down Expand Up @@ -665,15 +670,18 @@ ByteBuf createConsumerStatsResponse(Consumer consumer, long requestId) {

// complete the connect and sent newConnected command
private void completeConnect(int clientProtoVersion, String clientVersion) {
if (service.isAuthenticationEnabled() && service.isAuthorizationEnabled()) {
if (!service.getAuthorizationService()
if (service.isAuthenticationEnabled()) {
if (service.isAuthorizationEnabled()) {
if (!service.getAuthorizationService()
.isValidOriginalPrincipal(authRole, originalPrincipal, remoteAddress)) {
state = State.Failed;
service.getPulsarStats().recordConnectionCreateFail();
final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError, "Invalid roles.");
NettyChannelUtil.writeAndFlushWithClosePromise(ctx, msg);
return;
state = State.Failed;
service.getPulsarStats().recordConnectionCreateFail();
final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError, "Invalid roles.");
NettyChannelUtil.writeAndFlushWithClosePromise(ctx, msg);
return;
}
}
maybeScheduleAuthenticationCredentialsRefresh();
}
writeAndFlush(Commands.newConnected(clientProtoVersion, maxMessageSize, enableSubscriptionPatternEvaluation));
state = State.Connected;
Expand Down Expand Up @@ -772,7 +780,7 @@ public void authChallengeSuccessCallback(AuthData authChallenge,
log.debug("[{}] Authentication in progress client by method {}.", remoteAddress, authMethod);
}
}
} catch (Exception e) {
} catch (Exception | AssertionError e) {
authenticationFailed(e);
}
}
Expand All @@ -799,7 +807,7 @@ private void authenticateOriginalData(int clientProtoVersion, String clientVersi
remoteAddress, originalPrincipal);
}
completeConnect(clientProtoVersion, clientVersion);
} catch (Exception e) {
} catch (Exception | AssertionError e) {
authenticationFailed(e);
}
}
Expand All @@ -821,61 +829,75 @@ private void authenticationFailed(Throwable t) {
NettyChannelUtil.writeAndFlushWithClosePromise(ctx, msg);
}

public void refreshAuthenticationCredentials() {
AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState;

/**
* Method to initialize the {@link #authRefreshTask} task.
*/
private void maybeScheduleAuthenticationCredentialsRefresh() {
assert ctx.executor().inEventLoop();
assert authRefreshTask == null;
if (authState == null) {
// Authentication is disabled or there's no local state to refresh
return;
} else if (getState() != State.Connected || !isActive) {
// Connection is either still being established or already closed.
}
authRefreshTask = ctx.executor().scheduleAtFixedRate(this::refreshAuthenticationCredentials,
service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(),
service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(),
TimeUnit.SECONDS);
}

private void refreshAuthenticationCredentials() {
assert ctx.executor().inEventLoop();
AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState;
if (getState() == State.Failed) {
// Happens when an exception is thrown that causes this connection to close.
return;
} else if (!authState.isExpired()) {
// Credentials are still valid. Nothing to do at this point
return;
} else if (originalPrincipal != null && originalAuthState == null) {
// This case is only checked when the authState is expired because we've reached a point where
// authentication needs to be refreshed, but the protocol does not support it unless the proxy forwards
// the originalAuthData.
log.info(
"[{}] Cannot revalidate user credential when using proxy and"
+ " not forwarding the credentials. Closing connection",
remoteAddress);
ctx.close();
return;
}

ctx.executor().execute(SafeRun.safeRun(() -> {
log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}",
remoteAddress, originalPrincipal, this.authRole);

if (!supportsAuthenticationRefresh()) {
log.warn("[{}] Closing connection because client doesn't support auth credentials refresh",
remoteAddress);
ctx.close();
return;
}
if (!supportsAuthenticationRefresh()) {
log.warn("[{}] Closing connection because client doesn't support auth credentials refresh",
remoteAddress);
ctx.close();
return;
}

if (pendingAuthChallengeResponse) {
log.warn("[{}] Closing connection after timeout on refreshing auth credentials",
remoteAddress);
ctx.close();
return;
}
if (pendingAuthChallengeResponse) {
log.warn("[{}] Closing connection after timeout on refreshing auth credentials",
remoteAddress);
ctx.close();
return;
}

try {
AuthData brokerData = authState.refreshAuthentication();
log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}",
remoteAddress, originalPrincipal, this.authRole);
try {
AuthData brokerData = authState.refreshAuthentication();

writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData,
getRemoteEndpointProtocolVersion()));
if (log.isDebugEnabled()) {
log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData,
getRemoteEndpointProtocolVersion()));
if (log.isDebugEnabled()) {
log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
remoteAddress, authMethod);
}
}

pendingAuthChallengeResponse = true;
pendingAuthChallengeResponse = true;

} catch (AuthenticationException e) {
log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e);
ctx.close();
}
}));
} catch (AuthenticationException e) {
log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e);
ctx.close();
}
}

private static final byte[] emptyArray = new byte[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1948,8 +1948,6 @@ protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().remove("handler");
PersistentTopicE2ETest.ServerCnxForTest serverCnxForTest = new PersistentTopicE2ETest.ServerCnxForTest(this.pulsar, this.opts.getListenerName());
ch.pipeline().addAfter("flowController", "testHandler", serverCnxForTest);
//override parent
connections.put(ch.remoteAddress(), serverCnxForTest);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,13 @@ public void testAuthChallengePrincipalChangeFails() throws Exception {
when(brokerService.getAuthenticationService()).thenReturn(authenticationService);
when(authenticationService.getAuthenticationProvider(authMethodName)).thenReturn(authenticationProvider);
svcConfig.setAuthenticationEnabled(true);
svcConfig.setAuthenticationRefreshCheckSeconds(30);

resetChannel();
assertTrue(channel.isActive());
assertEquals(serverCnx.getState(), State.Start);
// Don't want the keep alive task affecting which messages are handled
serverCnx.cancelKeepAliveTask();

ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.client", "");
channel.writeInbound(clientCommand);
Expand All @@ -503,7 +506,7 @@ public void testAuthChallengePrincipalChangeFails() throws Exception {

// Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation)
// and then force channel to run the task
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
channel.runPendingTasks();
Object responseAuthChallenge1 = getResponse();
assertTrue(responseAuthChallenge1 instanceof CommandAuthChallenge);
Expand All @@ -513,7 +516,7 @@ public void testAuthChallengePrincipalChangeFails() throws Exception {
channel.writeInbound(authResponse1);

// Trigger the ServerCnx to check if authentication is expired again
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run.");
channel.runPendingTasks();
Object responseAuthChallenge2 = getResponse();
Expand All @@ -539,10 +542,13 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception {
svcConfig.setAuthenticationEnabled(true);
svcConfig.setAuthenticateOriginalAuthData(true);
svcConfig.setProxyRoles(Collections.singleton("pass.proxy"));
svcConfig.setAuthenticationRefreshCheckSeconds(30);

resetChannel();
assertTrue(channel.isActive());
assertEquals(serverCnx.getState(), State.Start);
// Don't want the keep alive task affecting which messages are handled
serverCnx.cancelKeepAliveTask();

ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.proxy", 1, null,
null, "pass.client", "pass.client", authMethodName);
Expand All @@ -559,7 +565,7 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception {

// Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation)
// and then force channel to run the task
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run.");
channel.runPendingTasks();
Object responseAuthChallenge1 = getResponse();
Expand All @@ -570,7 +576,7 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception {
channel.writeInbound(authResponse1);

// Trigger the ServerCnx to check if authentication is expired again
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
channel.runPendingTasks();
Object responseAuthChallenge2 = getResponse();
assertTrue(responseAuthChallenge2 instanceof CommandAuthChallenge);
Expand Down