diff --git a/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md b/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md index 662355d51d6..8c6af1b1e2f 100644 --- a/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md +++ b/zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md @@ -988,6 +988,14 @@ property, when available, is noted below. **New in 3.6.0:** The size threshold after which a request is considered a large request. If it is -1, then all requests are considered small, effectively turning off large request throttling. The default is -1. +* *outstandingHandshake.limit* + (Java system property only: **zookeeper.netty.server.outstandingHandshake.limit**) + The maximum number of in-flight TLS handshakes that ZooKeeper allows; + connections that exceed this limit are rejected before the handshake starts. + This setting doesn't limit the max TLS concurrency, but helps avoid the herd + effect caused by TLS handshake timeouts when there are too many in-flight TLS + handshakes. Setting it to something like 250 is good enough to avoid the herd effect. The default is -1, which disables this throttling. 
+ #### Cluster Options diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java index ca89672144d..209066536b7 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxn.java @@ -70,6 +70,14 @@ public class NettyServerCnxn extends ServerCnxn { public int readIssuedAfterReadComplete; + /* Handshake progress recorded for this connection; the initial value is NONE. */ private volatile HandshakeState handshakeState = HandshakeState.NONE; + + /** Handshake progress values that can be recorded on a connection. */ public enum HandshakeState { + NONE, + STARTED, + FINISHED + } + NettyServerCnxn(Channel channel, ZooKeeperServer zks, NettyServerCnxnFactory factory) { super(zks); this.channel = channel; @@ -629,4 +637,11 @@ public int getQueuedReadableBytes() { return 0; } + /** Records the given handshake state on this connection. */ public void setHandshakeState(HandshakeState state) { + this.handshakeState = state; + } + + /** Returns the handshake state last recorded on this connection. */ public HandshakeState getHandshakeState() { + return this.handshakeState; + } } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java index ef35837c460..78e03029fa2 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/NettyServerCnxnFactory.java @@ -69,6 +69,7 @@ import org.apache.zookeeper.common.SSLContextAndOptions; import org.apache.zookeeper.common.X509Exception; import org.apache.zookeeper.common.X509Exception.SSLContextException; +import org.apache.zookeeper.server.NettyServerCnxn.HandshakeState; import org.apache.zookeeper.server.auth.ProviderRegistry; import org.apache.zookeeper.server.auth.X509AuthenticationProvider; import org.apache.zookeeper.server.quorum.QuorumPeerConfig; @@ -93,6 +94,18 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory { */ private static final 
byte TLS_HANDSHAKE_RECORD_TYPE = 0x16; + /* Number of connections currently counted as having a handshake in flight. */ private final AtomicInteger outstandingHandshake = new AtomicInteger(); + public static final String OUTSTANDING_HANDSHAKE_LIMIT = "zookeeper.netty.server.outstandingHandshake.limit"; + private int outstandingHandshakeLimit; + private boolean handshakeThrottlingEnabled; + + /** Stores the outstanding-handshake limit and recomputes whether throttling is enabled: it is enabled only when the factory is secure or uses port unification AND the limit is greater than 0; the result is logged. NOTE(review): the flag is derived from the value of `secure` at the time of this call and is not recomputed in this method if `secure` is assigned afterwards - confirm the call ordering against setSecure/configure. */ public void setOutstandingHandshakeLimit(int limit) { + outstandingHandshakeLimit = limit; + handshakeThrottlingEnabled = (secure || shouldUsePortUnification) && outstandingHandshakeLimit > 0; + LOG.info("handshakeThrottlingEnabled = {}, {} = {}", + handshakeThrottlingEnabled, OUTSTANDING_HANDSHAKE_LIMIT, outstandingHandshakeLimit); + } + private final ServerBootstrap bootstrap; private Channel parentChannel; private final ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns", new DefaultEventExecutor()); @@ -164,6 +177,8 @@ protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext ssl protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) { NettyServerCnxn cnxn = Objects.requireNonNull(context.channel().attr(CONNECTION_ATTRIBUTE).get()); LOG.debug("creating plaintext handler for session {}", cnxn.getSessionId()); + // Mark handshake finished if it's an insecure cnxn + updateHandshakeCountIfStarted(cnxn); allChannels.add(context.channel()); addCnxn(cnxn); return super.newNonSslHandler(context); @@ -171,6 +186,13 @@ protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) { } + /** If cnxn is non-null and its handshake state is STARTED, marks it FINISHED and decrements the outstanding-handshake counter; otherwise does nothing, so a connection releases its slot at most once (the check-then-set is not atomic; presumably all callers run on the channel's event loop - TODO confirm). */ private void updateHandshakeCountIfStarted(NettyServerCnxn cnxn) { + if (cnxn != null && cnxn.getHandshakeState() == HandshakeState.STARTED) { + cnxn.setHandshakeState(HandshakeState.FINISHED); + outstandingHandshake.addAndGet(-1); + } + } + /** * This is an inner class since we need to extend ChannelDuplexHandler, but * NettyServerCnxnFactory already extends ServerCnxnFactory. 
By making it inner @@ -202,6 +224,23 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { NettyServerCnxn cnxn = new NettyServerCnxn(channel, zkServer, NettyServerCnxnFactory.this); ctx.channel().attr(CONNECTION_ATTRIBUTE).set(cnxn); + /* Handshake throttling: count this connection as an outstanding handshake; if that pushes the count above the limit, undo the increment, close the channel and bump TLS_HANDSHAKE_EXCEEDED; otherwise mark the cnxn STARTED so its slot is released later by updateHandshakeCountIfStarted. NOTE(review): there is no return after channel.close(), so the code that follows this block still runs for a rejected channel - confirm this is intended. */ if (handshakeThrottlingEnabled) { + // Favor to check and throttling even in dual mode which + // accepts both secure and insecure connections, since + // it's more efficient than throttling when we know it's + // a secure connection in DualModeSslHandler. + // + // From benchmark, this reduced around 15% reconnect time. + int outstandingHandshakesNum = outstandingHandshake.addAndGet(1); + if (outstandingHandshakesNum > outstandingHandshakeLimit) { + outstandingHandshake.addAndGet(-1); + channel.close(); + ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED.add(1); + } else { + cnxn.setHandshakeState(HandshakeState.STARTED); + } + } + if (secure) { SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); Future handshakeFuture = sslHandler.handshakeFuture(); @@ -224,6 +263,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { if (LOG.isTraceEnabled()) { LOG.trace("Channel inactive caused close {}", cnxn); } + /* Release the handshake slot if the channel went inactive while its handshake was still counted. */ updateHandshakeCountIfStarted(cnxn); cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_DISCONNECTED); } } @@ -234,6 +274,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E NettyServerCnxn cnxn = ctx.channel().attr(CONNECTION_ATTRIBUTE).getAndSet(null); if (cnxn != null) { LOG.debug("Closing {}", cnxn); + /* Release the handshake slot before closing a connection that failed with an exception. */ updateHandshakeCountIfStarted(cnxn); cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_CLOSED_EXCEPTION); } } @@ -339,6 +380,8 @@ final class CertificateVerifier implements GenericFutureListener * Only allow the connection to stay open if certificate passes auth */ public void operationComplete(Future future) { + /* The handshake future completed (success or failure), so the slot is released in both cases. */ updateHandshakeCountIfStarted(cnxn); + + if (future.isSuccess()) { LOG.debug("Successful handshake with session 0x{}", 
Long.toHexString(cnxn.getSessionId())); SSLEngine eng = sslHandler.engine(); @@ -451,6 +494,8 @@ private ServerBootstrap configureBootstrapAllocator(ServerBootstrap bootstrap) { this.advancedFlowControlEnabled = Boolean.getBoolean(NETTY_ADVANCED_FLOW_CONTROL); LOG.info("{} = {}", NETTY_ADVANCED_FLOW_CONTROL, this.advancedFlowControlEnabled); + /* Read the limit from the OUTSTANDING_HANDSHAKE_LIMIT system property, falling back to -1 when it is not set. */ setOutstandingHandshakeLimit(Integer.getInteger(OUTSTANDING_HANDSHAKE_LIMIT, -1)); + + EventLoopGroup bossGroup = NettyUtils.newNioOrEpollEventLoopGroup(NettyUtils.getClientReachableLocalInetAddressCount()); EventLoopGroup workerGroup = NettyUtils.newNioOrEpollEventLoopGroup(); ServerBootstrap bootstrap = new ServerBootstrap().group(bossGroup, workerGroup) @@ -756,4 +801,8 @@ public void setSecure(boolean secure) { public Channel getParentChannel() { return parentChannel; } + + /** Returns the current value of the outstanding-handshake counter. */ public int getOutstandingHandshakeNum() { + return outstandingHandshake.get(); + } } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java index 1f9855ca260..fe1539dfe75 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ServerMetrics.java @@ -228,6 +228,7 @@ private ServerMetrics(MetricsProvider metricsProvider) { NETTY_QUEUED_BUFFER = metricsContext.getSummary("netty_queued_buffer_capacity", DetailLevel.BASIC); DIGEST_MISMATCHES_COUNT = metricsContext.getCounter("digest_mismatches_count"); + TLS_HANDSHAKE_EXCEEDED = metricsContext.getCounter("tls_handshake_exceeded"); } /** @@ -434,6 +435,8 @@ private ServerMetrics(MetricsProvider metricsProvider) { // txns to data tree. 
public final Counter DIGEST_MISMATCHES_COUNT; + /* Counts connections that were closed because the outstanding-handshake limit was exceeded. */ public final Counter TLS_HANDSHAKE_EXCEEDED; + private final MetricsProvider metricsProvider; public void resetAll() { diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java index 05bf82e08d9..a14345964bd 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java @@ -1801,6 +1801,7 @@ protected void registerMetrics() { rootContext.registerGauge("max_client_response_size", stats.getClientResponseStats()::getMaxBufferSize); rootContext.registerGauge("min_client_response_size", stats.getClientResponseStats()::getMinBufferSize); + rootContext.registerGauge("outstanding_tls_handshake", this::getOutstandingHandshakeNum); } protected void unregisterMetrics() { @@ -2060,4 +2061,11 @@ private boolean buffer2Record(ByteBuffer request, Record record) { return rv; } + /** Returns the outstanding-handshake count reported by serverCnxnFactory when it is a NettyServerCnxnFactory, and 0 for any other factory type. NOTE(review): only serverCnxnFactory is consulted; if the server can also hold a separate secure connection factory, its handshakes would not be reflected here - confirm. */ public int getOutstandingHandshakeNum() { + if (serverCnxnFactory instanceof NettyServerCnxnFactory) { + return ((NettyServerCnxnFactory) serverCnxnFactory).getOutstandingHandshakeNum(); + } else { + return 0; + } + } } diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java index 144ca3ba2c6..afb97b14ab3 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/NettyServerCnxnFactoryTest.java @@ -19,11 +19,49 @@ package org.apache.zookeeper.server; import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import 
org.apache.zookeeper.PortAssignment; +import org.apache.zookeeper.WatchedEvent; +import org.apache.zookeeper.Watcher; +import org.apache.zookeeper.ZooKeeper; +import org.apache.zookeeper.common.ClientX509Util; +import org.apache.zookeeper.server.metric.SimpleCounter; +import org.apache.zookeeper.test.ClientBase; +import org.apache.zookeeper.test.SSLAuthTest; +import org.hamcrest.Matchers; import org.junit.Assert; import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -public class NettyServerCnxnFactoryTest { + +public class NettyServerCnxnFactoryTest extends ClientBase { + + private static final Logger LOG = LoggerFactory + .getLogger(NettyServerCnxnFactoryTest.class); + + /* Clients created by the tests; closed in tearDown. */ final LinkedBlockingQueue zks = new LinkedBlockingQueue(); + + /** Selects NettyServerCnxnFactory through the factory system property before the base class set-up runs. */ @Override + public void setUp() throws Exception { + System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, + "org.apache.zookeeper.server.NettyServerCnxnFactory"); + super.setUp(); + } + + /** Clears the factory system property, closes every client collected in zks, then runs the base class tear-down. */ @Override + public void tearDown() throws Exception { + System.clearProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY); + + // clean up + for (ZooKeeper zk : zks) { + zk.close(); + } + super.tearDown(); + } @Test public void testRebind() throws Exception { @@ -58,4 +96,63 @@ public void testRebindIPv4IPv6() throws Exception { Assert.assertTrue(factory.getParentChannel().isActive()); } + /** Creates threadNum * cnxnPerThread (3 * 10) clients from three threads against a factory switched to secure mode with an outstanding-handshake limit of 10, then asserts that TLS_HANDSHAKE_EXCEEDED is greater than 0 and that no handshake remains outstanding. NOTE(review): x509Util is assigned but not used or closed in the visible code - confirm whether the secure settings need cleanup. NOTE(review): the latch counts every watcher callback regardless of event type and is awaited with a fixed 3-second timeout, so the test may be timing-sensitive - confirm. */ @Test + public void testOutstandingHandshakeLimit() throws Exception { + + SimpleCounter tlsHandshakeExceeded = (SimpleCounter) ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED; + tlsHandshakeExceeded.reset(); + Assert.assertEquals(tlsHandshakeExceeded.get(), 0); + + ClientX509Util x509Util = SSLAuthTest.setUpSecure(); + NettyServerCnxnFactory factory = (NettyServerCnxnFactory) serverFactory; + factory.setSecure(true); + factory.setOutstandingHandshakeLimit(10); + + int threadNum = 3; + int cnxnPerThread = 10; + Thread[] cnxnWorker = new Thread[threadNum]; + + AtomicInteger cnxnCreated = new AtomicInteger(0); + 
CountDownLatch latch = new CountDownLatch(1); + + /* Each worker thread creates cnxnPerThread clients; the shared watcher releases the latch once it has been invoked threadNum * cnxnPerThread times in total. */ for (int i = 0; i < cnxnWorker.length; i++) { + cnxnWorker[i] = new Thread() { + @Override + public void run() { + for (int i = 0; i < cnxnPerThread; i++) { + try { + zks.add(new ZooKeeper(hostPort, 3000, new Watcher() { + @Override + public void process(WatchedEvent event) { + int created = cnxnCreated.addAndGet(1); + if (created == threadNum * cnxnPerThread) { + latch.countDown(); + } + } + })); + } catch (Exception e) { + LOG.info("Error while creating zk client", e); + } + } + } + }; + cnxnWorker[i].start(); + } + + Assert.assertThat(latch.await(3, TimeUnit.SECONDS), Matchers.is(true)); + LOG.info("created {} connections", threadNum * cnxnPerThread); + + // Assert throttling not 0 + long handshakeThrottledNum = tlsHandshakeExceeded.get(); + LOG.info("TLS_HANDSHAKE_EXCEEDED: {}", handshakeThrottledNum); + Assert.assertThat("The number of handshake throttled should be " + + "greater than 0", handshakeThrottledNum, Matchers.greaterThan(0L)); + + // Assert there is no outstanding handshake anymore + int outstandingHandshakeNum = factory.getOutstandingHandshakeNum(); + LOG.info("outstanding handshake is {}", outstandingHandshakeNum); + Assert.assertThat("The outstanding handshake number should be 0 " + + "after all cnxns established", outstandingHandshakeNum, Matchers.is(0)); + + } } diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java index 6afbbe2e027..9f62a3fd547 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/server/admin/CommandsTest.java @@ -155,7 +155,31 @@ public void testLastSnapshot() throws IOException, InterruptedException { @Test public void testMonitor() throws IOException, InterruptedException { - ArrayList fields = new ArrayList<>(Arrays.asList(new Field("version", 
String.class), new Field("avg_latency", Double.class), new Field("max_latency", Long.class), new Field("min_latency", Long.class), new Field("packets_received", Long.class), new Field("packets_sent", Long.class), new Field("num_alive_connections", Integer.class), new Field("outstanding_requests", Long.class), new Field("server_state", String.class), new Field("znode_count", Integer.class), new Field("watch_count", Integer.class), new Field("ephemerals_count", Integer.class), new Field("approximate_data_size", Long.class), new Field("open_file_descriptor_count", Long.class), new Field("max_file_descriptor_count", Long.class), new Field("last_client_response_size", Integer.class), new Field("max_client_response_size", Integer.class), new Field("min_client_response_size", Integer.class), new Field("uptime", Long.class), new Field("global_sessions", Long.class), new Field("local_sessions", Long.class), new Field("connection_drop_probability", Double.class))); + ArrayList fields = new ArrayList<>(Arrays.asList( + new Field("version", String.class), + new Field("avg_latency", Double.class), + new Field("max_latency", Long.class), + new Field("min_latency", Long.class), + new Field("packets_received", Long.class), + new Field("packets_sent", Long.class), + new Field("num_alive_connections", Integer.class), + new Field("outstanding_requests", Long.class), + new Field("server_state", String.class), + new Field("znode_count", Integer.class), + new Field("watch_count", Integer.class), + new Field("ephemerals_count", Integer.class), + new Field("approximate_data_size", Long.class), + new Field("open_file_descriptor_count", Long.class), + new Field("max_file_descriptor_count", Long.class), + new Field("last_client_response_size", Integer.class), + new Field("max_client_response_size", Integer.class), + new Field("min_client_response_size", Integer.class), + new Field("uptime", Long.class), + new Field("global_sessions", Long.class), + new Field("local_sessions", Long.class), + 
new Field("connection_drop_probability", Double.class), + new Field("outstanding_tls_handshake", Integer.class) + )); Map metrics = MetricsUtils.currentServerMetrics(); for (String metric : metrics.keySet()) {