Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions zookeeper-docs/src/main/resources/markdown/zookeeperAdmin.md
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,14 @@ property, when available, is noted below.
**New in 3.6.0:**
The size threshold after which a request is considered a large request. If it is -1, then all requests are considered small, effectively turning off large request throttling. The default is -1.

* *outstandingHandshake.limit*
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds like something that should be turned on all the time with some meaningful default value. What do you think about setting it to 250 by default and letting this property override it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feature hasn't been fully rolled out internally. I would suggest enabling it by default only after we have verified it in our internal environment for a while, so that if there is a bug we can find and fix it before it affects the community.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a master-only patch which doesn't even have an alpha-release, so I believe we could be a little bit more flexible about that, but I'm also happy with leaving it off.

(Java system property only: **zookeeper.netty.server.outstandingHandshake.limit**)
The maximum number of in-flight TLS handshakes that ZooKeeper allows at one time;
connections that exceed this limit are rejected before the handshake starts.
This setting doesn't limit the maximum TLS concurrency, but it helps avoid the herd
effect caused by TLS handshake timeouts when there are too many in-flight TLS
handshakes. Setting it to something like 250 is good enough to avoid the herd effect.

<a name="sc_clusterOptions"></a>

#### Cluster Options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ public class NettyServerCnxn extends ServerCnxn {

public int readIssuedAfterReadComplete;

private volatile HandshakeState handshakeState = HandshakeState.NONE;

/**
 * Progress of the handshake on a connection, used by the connection factory
 * to keep its count of outstanding handshakes accurate.
 */
public enum HandshakeState {
// Initial value: the connection is not counted as an outstanding handshake.
NONE,
// The connection was admitted by handshake throttling and is counted as outstanding.
STARTED,
// The connection has been removed from the outstanding handshake count.
FINISHED
}

NettyServerCnxn(Channel channel, ZooKeeperServer zks, NettyServerCnxnFactory factory) {
super(zks);
this.channel = channel;
Expand Down Expand Up @@ -629,4 +637,11 @@ public int getQueuedReadableBytes() {
return 0;
}

/**
 * Records the handshake progress of this connection.
 *
 * @param state the new handshake state
 */
public void setHandshakeState(HandshakeState state) {
    handshakeState = state;
}

/**
 * Returns the handshake progress last recorded for this connection.
 *
 * @return the current handshake state
 */
public HandshakeState getHandshakeState() {
    return handshakeState;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
import org.apache.zookeeper.common.SSLContextAndOptions;
import org.apache.zookeeper.common.X509Exception;
import org.apache.zookeeper.common.X509Exception.SSLContextException;
import org.apache.zookeeper.server.NettyServerCnxn.HandshakeState;
import org.apache.zookeeper.server.auth.ProviderRegistry;
import org.apache.zookeeper.server.auth.X509AuthenticationProvider;
import org.apache.zookeeper.server.quorum.QuorumPeerConfig;
Expand All @@ -93,6 +94,18 @@ public class NettyServerCnxnFactory extends ServerCnxnFactory {
*/
private static final byte TLS_HANDSHAKE_RECORD_TYPE = 0x16;

// Number of connections currently counted as having a handshake in flight.
private final AtomicInteger outstandingHandshake = new AtomicInteger();
// Name of the Java system property that supplies the outstanding handshake limit.
public static final String OUTSTANDING_HANDSHAKE_LIMIT = "zookeeper.netty.server.outstandingHandshake.limit";
// New connections that would push the in-flight count above this value are closed.
private int outstandingHandshakeLimit;
// True only when the factory is secure or uses port unification AND the limit is positive.
private boolean handshakeThrottlingEnabled;

/**
 * Sets the maximum number of in-flight handshakes and recomputes whether
 * handshake throttling is active.
 *
 * <p>Throttling is enabled only when the factory is secure or uses port
 * unification at the time of this call, and the limit is greater than zero.
 *
 * @param limit the new limit; a value of zero or less disables throttling
 */
public void setOutstandingHandshakeLimit(int limit) {
    outstandingHandshakeLimit = limit;

    final boolean tlsPossible = secure || shouldUsePortUnification;
    handshakeThrottlingEnabled = tlsPossible && limit > 0;

    LOG.info("handshakeThrottlingEnabled = {}, {} = {}",
        handshakeThrottlingEnabled, OUTSTANDING_HANDSHAKE_LIMIT, outstandingHandshakeLimit);
}

private final ServerBootstrap bootstrap;
private Channel parentChannel;
private final ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns", new DefaultEventExecutor());
Expand Down Expand Up @@ -164,13 +177,22 @@ protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext ssl
/**
 * Creates the handler used once a connection is identified as plaintext.
 *
 * <p>Before delegating to the superclass, this stops counting the connection
 * as an outstanding handshake, adds its channel to the channel group, and
 * registers the connection with the factory.
 *
 * @param context the handler context of the connection
 * @return the handler produced by the superclass
 * @throws NullPointerException if no connection is attached to the channel
 */
protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
    NettyServerCnxn cnxn = context.channel().attr(CONNECTION_ATTRIBUTE).get();
    Objects.requireNonNull(cnxn);
    LOG.debug("creating plaintext handler for session {}", cnxn.getSessionId());

    // An insecure connection is no longer an outstanding handshake.
    updateHandshakeCountIfStarted(cnxn);

    allChannels.add(context.channel());
    addCnxn(cnxn);

    ChannelHandler plaintextHandler = super.newNonSslHandler(context);
    return plaintextHandler;
}

}

/**
 * Marks the connection's handshake as finished and releases its slot in the
 * outstanding handshake count, but only if the connection is currently in
 * the STARTED state. Null connections and other states are ignored, so
 * calling this more than once for a connection decrements the count once.
 *
 * @param cnxn the connection to update; may be null
 */
private void updateHandshakeCountIfStarted(NettyServerCnxn cnxn) {
    if (cnxn == null) {
        return;
    }
    if (cnxn.getHandshakeState() != HandshakeState.STARTED) {
        return;
    }
    cnxn.setHandshakeState(HandshakeState.FINISHED);
    outstandingHandshake.decrementAndGet();
}

/**
* This is an inner class since we need to extend ChannelDuplexHandler, but
* NettyServerCnxnFactory already extends ServerCnxnFactory. By making it inner
Expand Down Expand Up @@ -202,6 +224,23 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception {
NettyServerCnxn cnxn = new NettyServerCnxn(channel, zkServer, NettyServerCnxnFactory.this);
ctx.channel().attr(CONNECTION_ATTRIBUTE).set(cnxn);

if (handshakeThrottlingEnabled) {
// Favor to check and throttling even in dual mode which
// accepts both secure and insecure connections, since
// it's more efficient than throttling when we know it's
// a secure connection in DualModeSslHandler.
//
// From benchmark, this reduced around 15% reconnect time.
int outstandingHandshakesNum = outstandingHandshake.addAndGet(1);
if (outstandingHandshakesNum > outstandingHandshakeLimit) {
outstandingHandshake.addAndGet(-1);
channel.close();
ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED.add(1);
} else {
cnxn.setHandshakeState(HandshakeState.STARTED);
}
}

if (secure) {
SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
Future<Channel> handshakeFuture = sslHandler.handshakeFuture();
Expand All @@ -224,6 +263,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (LOG.isTraceEnabled()) {
LOG.trace("Channel inactive caused close {}", cnxn);
}
updateHandshakeCountIfStarted(cnxn);
cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_DISCONNECTED);
}
}
Expand All @@ -234,6 +274,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
NettyServerCnxn cnxn = ctx.channel().attr(CONNECTION_ATTRIBUTE).getAndSet(null);
if (cnxn != null) {
LOG.debug("Closing {}", cnxn);
updateHandshakeCountIfStarted(cnxn);
cnxn.close(ServerCnxn.DisconnectReason.CHANNEL_CLOSED_EXCEPTION);
}
}
Expand Down Expand Up @@ -339,6 +380,8 @@ final class CertificateVerifier implements GenericFutureListener<Future<Channel>
* Only allow the connection to stay open if certificate passes auth
*/
public void operationComplete(Future<Channel> future) {
updateHandshakeCountIfStarted(cnxn);

if (future.isSuccess()) {
LOG.debug("Successful handshake with session 0x{}", Long.toHexString(cnxn.getSessionId()));
SSLEngine eng = sslHandler.engine();
Expand Down Expand Up @@ -451,6 +494,8 @@ private ServerBootstrap configureBootstrapAllocator(ServerBootstrap bootstrap) {
this.advancedFlowControlEnabled = Boolean.getBoolean(NETTY_ADVANCED_FLOW_CONTROL);
LOG.info("{} = {}", NETTY_ADVANCED_FLOW_CONTROL, this.advancedFlowControlEnabled);

setOutstandingHandshakeLimit(Integer.getInteger(OUTSTANDING_HANDSHAKE_LIMIT, -1));

EventLoopGroup bossGroup = NettyUtils.newNioOrEpollEventLoopGroup(NettyUtils.getClientReachableLocalInetAddressCount());
EventLoopGroup workerGroup = NettyUtils.newNioOrEpollEventLoopGroup();
ServerBootstrap bootstrap = new ServerBootstrap().group(bossGroup, workerGroup)
Expand Down Expand Up @@ -756,4 +801,8 @@ public void setSecure(boolean secure) {
/**
 * Returns the channel stored in this factory's parentChannel field.
 *
 * @return the parent channel
 */
public Channel getParentChannel() {
    return this.parentChannel;
}

/**
 * Returns the number of connections currently counted as having a
 * handshake in flight.
 *
 * @return the current outstanding handshake count
 */
public int getOutstandingHandshakeNum() {
    final int inFlight = outstandingHandshake.get();
    return inFlight;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ private ServerMetrics(MetricsProvider metricsProvider) {
NETTY_QUEUED_BUFFER = metricsContext.getSummary("netty_queued_buffer_capacity", DetailLevel.BASIC);

DIGEST_MISMATCHES_COUNT = metricsContext.getCounter("digest_mismatches_count");
TLS_HANDSHAKE_EXCEEDED = metricsContext.getCounter("tls_handshake_exceeded");
}

/**
Expand Down Expand Up @@ -434,6 +435,8 @@ private ServerMetrics(MetricsProvider metricsProvider) {
// txns to data tree.
public final Counter DIGEST_MISMATCHES_COUNT;

public final Counter TLS_HANDSHAKE_EXCEEDED;

private final MetricsProvider metricsProvider;

public void resetAll() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,7 @@ protected void registerMetrics() {
rootContext.registerGauge("max_client_response_size", stats.getClientResponseStats()::getMaxBufferSize);
rootContext.registerGauge("min_client_response_size", stats.getClientResponseStats()::getMinBufferSize);

rootContext.registerGauge("outstanding_tls_handshake", this::getOutstandingHandshakeNum);
}

protected void unregisterMetrics() {
Expand Down Expand Up @@ -2060,4 +2061,11 @@ private boolean buffer2Record(ByteBuffer request, Record record) {
return rv;
}

/**
 * Returns the number of outstanding handshakes reported by the connection
 * factory.
 *
 * @return the factory's outstanding handshake count, or 0 when the factory
 *         is not a NettyServerCnxnFactory
 */
public int getOutstandingHandshakeNum() {
    if (!(serverCnxnFactory instanceof NettyServerCnxnFactory)) {
        return 0;
    }
    NettyServerCnxnFactory nettyFactory = (NettyServerCnxnFactory) serverCnxnFactory;
    return nettyFactory.getOutstandingHandshakeNum();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,49 @@
package org.apache.zookeeper.server;

import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.zookeeper.PortAssignment;
import org.apache.zookeeper.WatchedEvent;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.ZooKeeper;
import org.apache.zookeeper.common.ClientX509Util;
import org.apache.zookeeper.server.metric.SimpleCounter;
import org.apache.zookeeper.test.ClientBase;
import org.apache.zookeeper.test.SSLAuthTest;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class NettyServerCnxnFactoryTest {

public class NettyServerCnxnFactoryTest extends ClientBase {

private static final Logger LOG = LoggerFactory
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you use logging in tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To print information that is useful for debugging unexpected test behavior. Do you suggest using System.out instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually don't output anything in (unit)tests, but feel free to leave it there if it's comfortable.

.getLogger(NettyServerCnxnFactoryTest.class);

final LinkedBlockingQueue<ZooKeeper> zks = new LinkedBlockingQueue<ZooKeeper>();

/**
 * Selects the Netty connection factory through a system property, then runs
 * the base class set-up.
 */
@Override
public void setUp() throws Exception {
    final String nettyFactoryClass = "org.apache.zookeeper.server.NettyServerCnxnFactory";
    System.setProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY, nettyFactoryClass);
    super.setUp();
}

/**
 * Clears the connection factory system property, closes every client
 * created by the test, and runs the base class tear-down.
 *
 * <p>The base class tear-down runs in a finally block so that it still
 * executes when closing a client throws.
 */
@Override
public void tearDown() throws Exception {
    System.clearProperty(ServerCnxnFactory.ZOOKEEPER_SERVER_CNXN_FACTORY);

    try {
        // clean up
        for (ZooKeeper zk : zks) {
            zk.close();
        }
    } finally {
        // Must run even if a close() above failed, otherwise the base fixture is never torn down.
        super.tearDown();
    }
}

@Test
public void testRebind() throws Exception {
Expand Down Expand Up @@ -58,4 +96,63 @@ public void testRebindIPv4IPv6() throws Exception {
Assert.assertTrue(factory.getParentChannel().isActive());
}

/**
 * Opens many secure client connections concurrently against a factory whose
 * outstanding handshake limit is 10, then verifies that at least one
 * handshake was throttled and that no handshake remains outstanding once
 * the expected number of watcher events has been received.
 */
@Test
public void testOutstandingHandshakeLimit() throws Exception {

    SimpleCounter tlsHandshakeExceeded = (SimpleCounter) ServerMetrics.getMetrics().TLS_HANDSHAKE_EXCEEDED;
    tlsHandshakeExceeded.reset();
    // JUnit expects (expected, actual); the reverse order yields a misleading failure message.
    Assert.assertEquals(0L, tlsHandshakeExceeded.get());

    ClientX509Util x509Util = SSLAuthTest.setUpSecure();
    NettyServerCnxnFactory factory = (NettyServerCnxnFactory) serverFactory;
    factory.setSecure(true);
    factory.setOutstandingHandshakeLimit(10);

    int threadNum = 3;
    int cnxnPerThread = 10;
    Thread[] cnxnWorker = new Thread[threadNum];

    AtomicInteger cnxnCreated = new AtomicInteger(0);
    CountDownLatch latch = new CountDownLatch(1);

    for (int i = 0; i < cnxnWorker.length; i++) {
        cnxnWorker[i] = new Thread() {
            @Override
            public void run() {
                // Distinct index name: the original inner "i" shadowed the outer loop variable.
                for (int j = 0; j < cnxnPerThread; j++) {
                    try {
                        zks.add(new ZooKeeper(hostPort, 3000, new Watcher() {
                            @Override
                            public void process(WatchedEvent event) {
                                int created = cnxnCreated.incrementAndGet();
                                if (created == threadNum * cnxnPerThread) {
                                    latch.countDown();
                                }
                            }
                        }));
                    } catch (Exception e) {
                        LOG.info("Error while creating zk client", e);
                    }
                }
            }
        };
        cnxnWorker[i].start();
    }

    Assert.assertTrue("Timed out waiting for all connections to report an event",
        latch.await(3, TimeUnit.SECONDS));
    LOG.info("created {} connections", threadNum * cnxnPerThread);

    // Assert throttling not 0
    long handshakeThrottledNum = tlsHandshakeExceeded.get();
    LOG.info("TLS_HANDSHAKE_EXCEEDED: {}", handshakeThrottledNum);
    Assert.assertThat("The number of handshake throttled should be "
        + "greater than 0", handshakeThrottledNum, Matchers.greaterThan(0L));

    // Assert there is no outstanding handshake anymore
    int outstandingHandshakeNum = factory.getOutstandingHandshakeNum();
    LOG.info("outstanding handshake is {}", outstandingHandshakeNum);
    Assert.assertThat("The outstanding handshake number should be 0 "
        + "after all cnxns established", outstandingHandshakeNum, Matchers.is(0));

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,31 @@ public void testLastSnapshot() throws IOException, InterruptedException {

@Test
public void testMonitor() throws IOException, InterruptedException {
ArrayList<Field> fields = new ArrayList<>(Arrays.asList(new Field("version", String.class), new Field("avg_latency", Double.class), new Field("max_latency", Long.class), new Field("min_latency", Long.class), new Field("packets_received", Long.class), new Field("packets_sent", Long.class), new Field("num_alive_connections", Integer.class), new Field("outstanding_requests", Long.class), new Field("server_state", String.class), new Field("znode_count", Integer.class), new Field("watch_count", Integer.class), new Field("ephemerals_count", Integer.class), new Field("approximate_data_size", Long.class), new Field("open_file_descriptor_count", Long.class), new Field("max_file_descriptor_count", Long.class), new Field("last_client_response_size", Integer.class), new Field("max_client_response_size", Integer.class), new Field("min_client_response_size", Integer.class), new Field("uptime", Long.class), new Field("global_sessions", Long.class), new Field("local_sessions", Long.class), new Field("connection_drop_probability", Double.class)));
ArrayList<Field> fields = new ArrayList<>(Arrays.asList(
new Field("version", String.class),
new Field("avg_latency", Double.class),
new Field("max_latency", Long.class),
new Field("min_latency", Long.class),
new Field("packets_received", Long.class),
new Field("packets_sent", Long.class),
new Field("num_alive_connections", Integer.class),
new Field("outstanding_requests", Long.class),
new Field("server_state", String.class),
new Field("znode_count", Integer.class),
new Field("watch_count", Integer.class),
new Field("ephemerals_count", Integer.class),
new Field("approximate_data_size", Long.class),
new Field("open_file_descriptor_count", Long.class),
new Field("max_file_descriptor_count", Long.class),
new Field("last_client_response_size", Integer.class),
new Field("max_client_response_size", Integer.class),
new Field("min_client_response_size", Integer.class),
new Field("uptime", Long.class),
new Field("global_sessions", Long.class),
new Field("local_sessions", Long.class),
new Field("connection_drop_probability", Double.class),
new Field("outstanding_tls_handshake", Integer.class)
));
Map<String, Object> metrics = MetricsUtils.currentServerMetrics();

for (String metric : metrics.keySet()) {
Expand Down