diff --git a/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md b/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md
index 662355d51d6..8c6af1b1e2f 100644
--- a/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md
+++ b/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md
@@ -988,6 +988,14 @@ property, when available, is noted below.
**New in 3.6.0:**
The size threshold after which a request is considered a large request. If it is -1, then all requests are considered small, effectively turning off large request throttling. The default is -1.
+* *outstandingHandshake.limit*
+ (Jave system property only: **zookeeper.netty.server.outstandingHandshake.limit**)
+ The maximum in-flight TLS handshake connections could have in ZooKeeper,
+ the connections exceed this limit will be rejected before starting handshake.
+ This setting doesn't limit the max TLS concurrency, but helps avoid herd
+ effect due to TLS handshake timeout when there are too many in-flight TLS
+ handshakes. Set it to something like 250 is good enough to avoid herd effect.
+
#### Cluster Options
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java
index ca89672144d..209066536b7 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java
@@ -70,6 +70,14 @@ public class NettyServerCnxn extends ServerCnxn {
public int readIssuedAfterReadComplete;
+ private volatile HandshakeState handshakeState = HandshakeState.NONE;
+
+ public enum HandshakeState {
+ NONE,
+ STARTED,
+ FINISHED
+ }
+
NettyServerCnxn(Channel channel, ZooKeeperServer zks, NettyServerCnxnFactory factory) {
super(zks);
this.channel = channel;
@@ -629,4 +637,11 @@ public int getQueuedReadableBytes() {
return 0;
}
+ public void setHandshakeState(HandshakeState state) {
+ this.handshakeState = state;
+ }
+
+ public HandshakeState getHandshakeState() {
+ return this.handshakeState;
+ }
}
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java
index ef35837c460..78e03029fa2 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java
@@ -69,6 +69,7 @@
import org.apache.zookeeper.common.SSLContextAndOptions;
import org.apache.zookeeper.common.X509Exception;
import org.apache.zookeeper.common.X509Exception.SSLContextException;
+import org.apache.zookeeper.server.NettyServerCnxn.HandshakeState;
import org.apache.zookeeper.server.auth.ProviderRegistry;
import org.apache.zookeeper.server.auth.X509AuthenticationProvider;
import org.apache.zookeeper.server.quorum.QuorumPeerConfig;
@@ -93,6 +94,18 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
*/
private static final byte TLS_HANDSHAKE_RECORD_TYPE = 0x16;
+ private final AtomicInteger outstandingHandshake = new AtomicInteger();
+ public static final String OUTSTANDING_HANDSHAKE_LIMIT = "zookeeper.netty.server.outstandingHandshake.limit";
+ private int outstandingHandshakeLimit;
+ private boolean handshakeThrottlingEnabled;
+
+ public void setOutstandingHandshakeLimit(int limit) {
+ outstandingHandshakeLimit = limit;
+ handshakeThrottlingEnabled = (secure || shouldUsePortUnification) && outstandingHandshakeLimit > 0;
+ LOG.info("handshakeThrottlingEnabled = {}, {} = {}",
+ handshakeThrottlingEnabled, OUTSTANDING_HANDSHAKE_LIMIT, outstandingHandshakeLimit);
+ }
+
private final ServerBootstrap bootstrap;
private Channel parentChannel;
private final ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns", new DefaultEventExecutor());
@@ -164,6 +177,8 @@ protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext ssl
protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
NettyServerCnxn cnxn = Objects.requireNonNull(context.channel().attr(CONNECTION_ATTRIBUTE).get());
LOG.debug("creating plaintext handler for session {}", cnxn.getSessionId());
+ // Mark handshake finished if it's a insecure cnxn
+ updateHandshakeCountIfStarted(cnxn);
allChannels.add(context.channel());
addCnxn(cnxn);
return super.newNonSslHandler(context);
@@ -171,6 +186,13 @@ protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
}
+ private void updateHandshakeCountIfStarted(NettyServerCnxn cnxn) {
+ if (cnxn != null && cnxn.getHandshakeState() == HandshakeState.STARTED) {
+ cnxn.setHandshakeState(HandshakeState.FINISHED);
+ outstandingHandshake.addAndGet(-1);
+ }
+ }
+
/**
* This is an inner class since we need to extend ChannelDuplexHandler, but
* NettyServerCnxnFactory already extends ServerCnxnFactory. By making it inner
@@ -202,6 +224,23 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
NettyServerCnxn cnxn = new NettyServerCnxn(channel, zkServer, NettyServerCnxnFactory.this);
ctx.channel().attr(CONNECTION_ATTRIBUTE).set(cnxn);
+ if (handshakeThrottlingEnabled) {
+ // Favor to check and throttling even in dual mode which
+ // accepts both secure and insecure connections, since
+ // it's more efficient than throttling when we know it's
+ // a secure connection in DualModeSslHandler.
+ //
+ // From benchmark, this reduced around 15% reconnect time.
+ int outstandingHandshakesNum = outstandingHandshake.addAndGet(1);
+ if (outstandingHandshakesNum > outstandingHandshakeLimit) {
+ outstandingHandshake.addAndGet(-1);
+ channel.close();
+ ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED.add(1);
+ } else {
+ cnxn.setHandshakeState(HandshakeState.STARTED);
+ }
+ }
+
if (secure) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
Future handshakeFuture = sslHandler.handshakeFuture();
@@ -224,6 +263,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (LOG.isTraceEnabled()) {
LOG.trace("Channel inactive caused close {}", cnxn);
}
+ updateHandshakeCountIfStarted(cnxn);
cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_DISCONNECTED);
}
}
@@ -234,6 +274,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
NettyServerCnxn cnxn = ctx.channel().attr(CONNECTION_ATTRIBUTE).getAndSet(null);
if (cnxn != null) {
LOG.debug("Closing {}", cnxn);
+ updateHandshakeCountIfStarted(cnxn);
cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_CLOSED_EXCEPTION);
}
}
@@ -339,6 +380,8 @@ final class CertificateVerifier implements GenericFutureListener
* Only allow the connection to stay open if certificate passes auth
*/
public void operationComplete(Future future) {
+ updateHandshakeCountIfStarted(cnxn);
+
if (future.isSuccess()) {
LOG.debug("Successful handshake with session 0x{}", Long.toHexString(cnxn.getSessionId()));
SSLEngine eng = sslHandler.engine();
@@ -451,6 +494,8 @@ private ServerBootstrap configureBootstrapAllocator(ServerBootstrap bootstrap) {
this.advancedFlowControlEnabled = Boolean.getBoolean(NETTY_ADVANCED_FLOW_CONTROL);
LOG.info("{} = {}", NETTY_ADVANCED_FLOW_CONTROL, this.advancedFlowControlEnabled);
+ setOutstandingHandshakeLimit(Integer.getInteger(OUTSTANDING_HANDSHAKE_LIMIT, -1));
+
EventLoopGroup bossGroup = NettyUtils.newNioOrEpollEventLoopGroup(NettyUtils.getClientReachableLocalInetAddressCount());
EventLoopGroup workerGroup = NettyUtils.newNioOrEpollEventLoopGroup();
ServerBootstrap bootstrap = new ServerBootstrap().group(bossGroup, workerGroup)
@@ -756,4 +801,8 @@ public void setSecure(boolean secure) {
public Channel getParentChannel() {
return parentChannel;
}
+
+ public int getOutstandingHandshakeNum() {
+ return outstandingHandshake.get();
+ }
}
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java
index 1f9855ca260..fe1539dfe75 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java
@@ -228,6 +228,7 @@ private ServerMetrics(MetricsProvider metricsProvider) {
NETTY_QUEUED_BUFFER = metricsContext.getSummary("netty_queued_buffer_capacity", DetailLevel.BASIC);
DIGEST_MISMATCHES_COUNT = metricsContext.getCounter("digest_mismatches_count");
+ TLS_HANDSHAKE_EXCEEDED = metricsContext.getCounter("tls_handshake_exceeded");
}
/**
@@ -434,6 +435,8 @@ private ServerMetrics(MetricsProvider metricsProvider) {
// txns to data tree.
public final Counter DIGEST_MISMATCHES_COUNT;
+ public final Counter TLS_HANDSHAKE_EXCEEDED;
+
private final MetricsProvider metricsProvider;
public void resetAll() {
diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java
index 05bf82e08d9..a14345964bd 100644
--- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java
+++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java
@@ -1801,6 +1801,7 @@ protected void registerMetrics() {
rootContext.registerGauge("max_client_response_size", stats.getClientResponseStats()::getMaxBufferSize);
rootContext.registerGauge("min_client_response_size", stats.getClientResponseStats()::getMinBufferSize);
+ rootContext.registerGauge("outstanding_tls_handshake", this::getOutstandingHandshakeNum);
}
protected void unregisterMetrics() {
@@ -2060,4 +2061,11 @@ private boolean buffer2Record(ByteBuffer request, Record record) {
return rv;
}
+ public int getOutstandingHandshakeNum() {
+ if (serverCnxnFactory instanceof NettyServerCnxnFactory) {
+ return ((NettyServerCnxnFactory) serverCnxnFactory).getOutstandingHandshakeNum();
+ } else {
+ return 0;
+ }
+ }
}
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java
index 144ca3ba2c6..afb97b14ab3 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java
@@ -19,11 +19,49 @@
package org.apache.zookeeper.server;
import java.net.InetSocketAddress;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.zookeeper.PortAssignment;
+import org.apache.zookeeper.WatchedEvent;
+import org.apache.zookeeper.Watcher;
+import org.apache.zookeeper.ZooKeeper;
+import org.apache.zookeeper.common.ClientX509Util;
+import org.apache.zookeeper.server.metric.SimpleCounter;
+import org.apache.zookeeper.test.ClientBase;
+import org.apache.zookeeper.test.SSLAuthTest;
+import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-public class NettyServerCnxnFactoryTest {
+
+public class NettyServerCnxnFactoryTest extends ClientBase {
+
+ private static final Logger LOG = LoggerFactory
+ .getLogger(NettyServerCnxnFactoryTest.class);
+
+ final LinkedBlockingQueue zks = new LinkedBlockingQueue();
+
+ @Override
+ public void setUp() throws Exception {
+ System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY,
+ "org.apache.zookeeper.server.NettyServerCnxnFactory");
+ super.setUp();
+ }
+
+ @Override
+ public void tearDown() throws Exception {
+ System.clearProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY);
+
+ // clean up
+ for (ZooKeeper zk : zks) {
+ zk.close();
+ }
+ super.tearDown();
+ }
@Test
public void testRebind() throws Exception {
@@ -58,4 +96,63 @@ public void testRebindIPv4IPv6() throws Exception {
Assert.assertTrue(factory.getParentChannel().isActive());
}
+ @Test
+ public void testOutstandingHandshakeLimit() throws Exception {
+
+ SimpleCounter tlsHandshakeExceeded = (SimpleCounter) ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED;
+ tlsHandshakeExceeded.reset();
+ Assert.assertEquals(tlsHandshakeExceeded.get(), 0);
+
+ ClientX509Util x509Util = SSLAuthTest.setUpSecure();
+ NettyServerCnxnFactory factory = (NettyServerCnxnFactory) serverFactory;
+ factory.setSecure(true);
+ factory.setOutstandingHandshakeLimit(10);
+
+ int threadNum = 3;
+ int cnxnPerThread = 10;
+ Thread[] cnxnWorker = new Thread[threadNum];
+
+ AtomicInteger cnxnCreated = new AtomicInteger(0);
+ CountDownLatch latch = new CountDownLatch(1);
+
+ for (int i = 0; i < cnxnWorker.length; i++) {
+ cnxnWorker[i] = new Thread() {
+ @Override
+ public void run() {
+ for (int i = 0; i < cnxnPerThread; i++) {
+ try {
+ zks.add(new ZooKeeper(hostPort, 3000, new Watcher() {
+ @Override
+ public void process(WatchedEvent event) {
+ int created = cnxnCreated.addAndGet(1);
+ if (created == threadNum * cnxnPerThread) {
+ latch.countDown();
+ }
+ }
+ }));
+ } catch (Exception e) {
+ LOG.info("Error while creating zk client", e);
+ }
+ }
+ }
+ };
+ cnxnWorker[i].start();
+ }
+
+ Assert.assertThat(latch.await(3, TimeUnit.SECONDS), Matchers.is(true));
+ LOG.info("created {} connections", threadNum * cnxnPerThread);
+
+ // Assert throttling not 0
+ long handshakeThrottledNum = tlsHandshakeExceeded.get();
+ LOG.info("TLS_HANDSHAKE_EXCEEDED: {}", handshakeThrottledNum);
+ Assert.assertThat("The number of handshake throttled should be "
+ + "greater than 0", handshakeThrottledNum, Matchers.greaterThan(0L));
+
+ // Assert there is no outstanding handshake anymore
+ int outstandingHandshakeNum = factory.getOutstandingHandshakeNum();
+ LOG.info("outstanding handshake is {}", outstandingHandshakeNum);
+ Assert.assertThat("The outstanding handshake number should be 0 "
+ + "after all cnxns established", outstandingHandshakeNum, Matchers.is(0));
+
+ }
}
diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java
index 6afbbe2e027..9f62a3fd547 100644
--- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java
+++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java
@@ -155,7 +155,31 @@ public void testLastSnapshot() throws IOException, InterruptedException {
@Test
public void testMonitor() throws IOException, InterruptedException {
- ArrayList fields = new ArrayList<>(Arrays.asList(new Field("version", String.class), new Field("avg_latency", Double.class), new Field("max_latency", Long.class), new Field("min_latency", Long.class), new Field("packets_received", Long.class), new Field("packets_sent", Long.class), new Field("num_alive_connections", Integer.class), new Field("outstanding_requests", Long.class), new Field("server_state", String.class), new Field("znode_count", Integer.class), new Field("watch_count", Integer.class), new Field("ephemerals_count", Integer.class), new Field("approximate_data_size", Long.class), new Field("open_file_descriptor_count", Long.class), new Field("max_file_descriptor_count", Long.class), new Field("last_client_response_size", Integer.class), new Field("max_client_response_size", Integer.class), new Field("min_client_response_size", Integer.class), new Field("uptime", Long.class), new Field("global_sessions", Long.class), new Field("local_sessions", Long.class), new Field("connection_drop_probability", Double.class)));
+ ArrayList fields = new ArrayList<>(Arrays.asList(
+ new Field("version", String.class),
+ new Field("avg_latency", Double.class),
+ new Field("max_latency", Long.class),
+ new Field("min_latency", Long.class),
+ new Field("packets_received", Long.class),
+ new Field("packets_sent", Long.class),
+ new Field("num_alive_connections", Integer.class),
+ new Field("outstanding_requests", Long.class),
+ new Field("server_state", String.class),
+ new Field("znode_count", Integer.class),
+ new Field("watch_count", Integer.class),
+ new Field("ephemerals_count", Integer.class),
+ new Field("approximate_data_size", Long.class),
+ new Field("open_file_descriptor_count", Long.class),
+ new Field("max_file_descriptor_count", Long.class),
+ new Field("last_client_response_size", Integer.class),
+ new Field("max_client_response_size", Integer.class),
+ new Field("min_client_response_size", Integer.class),
+ new Field("uptime", Long.class),
+ new Field("global_sessions", Long.class),
+ new Field("local_sessions", Long.class),
+ new Field("connection_drop_probability", Double.class),
+ new Field("outstanding_tls_handshake", Integer.class)
+ ));
Map metrics = MetricsUtils.currentServerMetrics();
for (String metric : metrics.keySet()) {