diff --git a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java index 8b0c859037..fec82f0999 100644 --- a/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java +++ b/ratis-common/src/main/java/org/apache/ratis/util/ReferenceCountedObject.java @@ -62,6 +62,11 @@ public interface ReferenceCountedObject { */ boolean release(); + /** The same as wrap(value, EMPTY, EMPTY), where EMPTY is an empty method. */ + static ReferenceCountedObject wrap(V value) { + return wrap(value, () -> {}, () -> {}); + } + /** * Wrap the given value as a {@link ReferenceCountedObject}. * @@ -81,8 +86,11 @@ static ReferenceCountedObject wrap(V value, Runnable retainMethod, Runnab @Override public V get() { - if (count.get() < 0) { + final int previous = count.get(); + if (previous < 0) { throw new IllegalStateException("Failed to get: object has already been completely released."); + } else if (previous == 0) { + throw new IllegalStateException("Failed to get: object has not yet been retained."); } return value; } diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyConfigKeys.java b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyConfigKeys.java index 98b1a6d747..be3ad8ee67 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/NettyConfigKeys.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/NettyConfigKeys.java @@ -158,7 +158,7 @@ static void setWorkerGroupSize(RaftProperties properties, int clientWorkerGroupS } String WORKER_GROUP_SHARE_KEY = PREFIX + ".worker-group.share"; - boolean WORKER_GROUP_SHARE_DEFAULT = false; + boolean WORKER_GROUP_SHARE_DEFAULT = true; static boolean workerGroupShare(RaftProperties properties) { return getBoolean(properties::getBoolean, WORKER_GROUP_SHARE_KEY, WORKER_GROUP_SHARE_DEFAULT, getDefaultLog()); diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java index e4c154fd21..41d862a302 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java @@ -56,6 +56,8 @@ import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.NetUtils; +import org.apache.ratis.util.Preconditions; +import org.apache.ratis.util.ReferenceCountedObject; import org.apache.ratis.util.SizeInBytes; import org.apache.ratis.util.TimeDuration; import org.slf4j.Logger; @@ -78,7 +80,36 @@ public class NettyClientStreamRpc implements DataStreamClientRpc { public static final Logger LOG = LoggerFactory.getLogger(NettyClientStreamRpc.class); private static class WorkerGroupGetter implements Supplier { - private static final AtomicReference SHARED_WORKER_GROUP = new AtomicReference<>(); + + private static final AtomicReference>> SHARED_WORKER_GROUP + = new AtomicReference<>(); + + static WorkerGroupGetter newInstance(RaftProperties properties) { + final boolean shared = NettyConfigKeys.DataStream.Client.workerGroupShare(properties); + if (shared) { + final CompletableFuture> created = new CompletableFuture<>(); + final CompletableFuture> current + = SHARED_WORKER_GROUP.updateAndGet(g -> g != null ? g : created); + if (current == created) { + created.complete(ReferenceCountedObject.wrap(newWorkerGroup(properties))); + } + return new WorkerGroupGetter(current.join().retain()) { + @Override + void shutdownGracefully() { + final CompletableFuture> returned + = SHARED_WORKER_GROUP.updateAndGet(previous -> { + Preconditions.assertSame(current, previous, "SHARED_WORKER_GROUP"); + return previous.join().release() ? null : previous; + }); + if (returned == null) { + get().shutdownGracefully(); + } + } + }; + } else { + return new WorkerGroupGetter(newWorkerGroup(properties)); + } + } static EventLoopGroup newWorkerGroup(RaftProperties properties) { return NettyUtils.newEventLoopGroup( @@ -88,27 +119,18 @@ static EventLoopGroup newWorkerGroup(RaftProperties properties) { } private final EventLoopGroup workerGroup; - private final boolean ignoreShutdown; - WorkerGroupGetter(RaftProperties properties) { - if (NettyConfigKeys.DataStream.Client.workerGroupShare(properties)) { - workerGroup = SHARED_WORKER_GROUP.updateAndGet(g -> g != null? g: newWorkerGroup(properties)); - ignoreShutdown = true; - } else { - workerGroup = newWorkerGroup(properties); - ignoreShutdown = false; - } + private WorkerGroupGetter(EventLoopGroup workerGroup) { + this.workerGroup = workerGroup; } @Override - public EventLoopGroup get() { + public final EventLoopGroup get() { return workerGroup; } void shutdownGracefully() { - if (!ignoreShutdown) { - workerGroup.shutdownGracefully(); - } + workerGroup.shutdownGracefully(); } } @@ -257,8 +279,7 @@ public NettyClientStreamRpc(RaftPeer server, TlsConf tlsConf, RaftProperties pro final InetSocketAddress address = NetUtils.createSocketAddr(server.getDataStreamAddress()); final SslContext sslContext = NettyUtils.buildSslContextForClient(tlsConf); - this.connection = new Connection(address, - new WorkerGroupGetter(properties), + this.connection = new Connection(address, WorkerGroupGetter.newInstance(properties), () -> newChannelInitializer(address, sslContext, getClientHandler())); } diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java index a4cc537ddc..ad6bc4e02e 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java @@ -325,10 +325,13 @@ static long writeTo(ByteBuf buf, Iterable options, long byteWritten = 0; for (ByteBuffer buffer : buf.nioBuffers()) { final ReferenceCountedObject wrapped = ReferenceCountedObject.wrap(buffer, buf::retain, buf::release); + wrapped.retain(); try { byteWritten += channel.write(wrapped); } catch (Throwable t) { throw new CompletionException(t); + } finally { + wrapped.release(); } } diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamChainTopologyWithGrpcCluster.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamChainTopologyWithGrpcCluster.java index e4e9fef575..31b28b4c2d 100644 --- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamChainTopologyWithGrpcCluster.java +++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamChainTopologyWithGrpcCluster.java @@ -17,7 +17,25 @@ */ package org.apache.ratis.datastream; +import org.apache.ratis.client.RaftClientConfigKeys; +import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.netty.NettyConfigKeys; +import org.apache.ratis.util.SizeInBytes; +import org.apache.ratis.util.TimeDuration; +import org.junit.Before; + public class TestNettyDataStreamChainTopologyWithGrpcCluster extends DataStreamAsyncClusterTests implements MiniRaftClusterWithRpcTypeGrpcAndDataStreamTypeNetty.FactoryGet { + + @Before + public void setup() { + final RaftProperties p = getProperties(); + RaftClientConfigKeys.DataStream.setRequestTimeout(p, TimeDuration.ONE_MINUTE); + RaftClientConfigKeys.DataStream.setFlushRequestCountMin(p, 4); + RaftClientConfigKeys.DataStream.setFlushRequestBytesMin(p, SizeInBytes.valueOf("10MB")); + RaftClientConfigKeys.DataStream.setOutstandingRequestsMax(p, 2 << 16); + + NettyConfigKeys.DataStream.Client.setWorkerGroupSize(p,100); + } } diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamStarTopologyWithGrpcCluster.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamStarTopologyWithGrpcCluster.java index 14c62b74f6..45247d489a 100644 --- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamStarTopologyWithGrpcCluster.java +++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamStarTopologyWithGrpcCluster.java @@ -19,6 +19,7 @@ import org.apache.ratis.client.RaftClientConfigKeys; import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.netty.NettyConfigKeys; import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.protocol.RoutingTable; @@ -41,6 +42,8 @@ public void setup() { RaftClientConfigKeys.DataStream.setFlushRequestCountMin(p, 4); RaftClientConfigKeys.DataStream.setFlushRequestBytesMin(p, SizeInBytes.valueOf("10MB")); RaftClientConfigKeys.DataStream.setOutstandingRequestsMax(p, 2 << 16); + + NettyConfigKeys.DataStream.Client.setWorkerGroupSize(p,100); } @Override diff --git a/ratis-test/src/test/java/org/apache/ratis/util/TestReferenceCountedObject.java b/ratis-test/src/test/java/org/apache/ratis/util/TestReferenceCountedObject.java index 448212154c..5a855857a7 100644 --- a/ratis-test/src/test/java/org/apache/ratis/util/TestReferenceCountedObject.java +++ b/ratis-test/src/test/java/org/apache/ratis/util/TestReferenceCountedObject.java @@ -47,7 +47,12 @@ public void testWrap() { value, retained::getAndIncrement, released::getAndIncrement); assertValues(retained, 0, released, 0); - Assert.assertEquals(value, ref.get()); + try { + ref.get(); + Assert.fail(); + } catch (IllegalStateException e) { + e.printStackTrace(System.out); + } assertValues(retained, 0, released, 0); Assert.assertEquals(value, ref.retain());