diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxn.java b/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxn.java index ed03359f7fe..7663e27f68a 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxn.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxn.java @@ -1289,6 +1289,17 @@ public void run() { "SendThread exited loop for session: 0x" + Long.toHexString(getSessionId())); } + private void abortConnection() { + try { + clientCnxnSocket.testableCloseSocket(); + } catch (IOException e) { + LOG.debug("Fail to close ongoing socket", e); + } + } + + /** + * This is not thread-safe and should only be called inside {@link SendThread}. + */ private void cleanAndNotifyState() { cleanup(); if (state.isAlive()) { @@ -1531,7 +1542,7 @@ public ReplyHeader submitRequest( } } if (r.getErr() == Code.REQUESTTIMEOUT.intValue()) { - sendThread.cleanAndNotifyState(); + sendThread.abortConnection(); } return r; } diff --git a/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxnSocketNIO.java b/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxnSocketNIO.java index ea58b857e7d..e39bee11825 100644 --- a/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxnSocketNIO.java +++ b/zookeeper-server/src/main/java/org/apache/zookeeper/ClientCnxnSocketNIO.java @@ -209,6 +209,12 @@ void cleanup() { } catch (IOException e) { LOG.debug("Ignoring exception during channel close", e); } + try { + selector.wakeup(); + selector.selectNow(); + } catch (IOException e) { + LOG.debug("Ignoring exception during selecting of cancelled socket", e); + } } try { Thread.sleep(100); diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/ClientCnxnSocketFragilityTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/ClientCnxnSocketFragilityTest.java index 54426f0b6e2..2b70a599d6a 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/ClientCnxnSocketFragilityTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/ClientCnxnSocketFragilityTest.java @@ -18,20 +18,32 @@ package org.apache.zookeeper; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.SocketException; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.time.Duration; import java.util.Queue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.apache.zookeeper.ClientCnxn.Packet; import org.apache.zookeeper.Watcher.Event.KeeperState; import org.apache.zookeeper.ZooDefs.Ids; import org.apache.zookeeper.client.HostProvider; import org.apache.zookeeper.client.ZKClientConfig; +import org.apache.zookeeper.common.BusyServer; import org.apache.zookeeper.data.Stat; import org.apache.zookeeper.server.quorum.QuorumPeerTestBase; import org.apache.zookeeper.test.ClientBase; @@ -75,6 +87,40 @@ private void closeZookeeper(ZooKeeper zk) { }); } + @Test + public void testSocketClosedAfterFailure() throws Exception { + Duration sessionTimeout = Duration.ofMillis(1000); + final AtomicReference nioSelector = new AtomicReference<>(); + try ( + // given: busy server + BusyServer server = new BusyServer(); + ZooKeeper zk = new ZooKeeper(server.getHostPort(), (int) sessionTimeout.toMillis(), null) { + @Override + ClientCnxn createConnection(HostProvider hostProvider, int sessionTimeout, ZKClientConfig clientConfig, Watcher defaultWatcher, ClientCnxnSocket clientCnxnSocket, long sessionId, byte[] sessionPasswd, boolean canBeReadOnly) throws IOException { + ClientCnxnSocketNIO socket = spy((ClientCnxnSocketNIO) clientCnxnSocket); + + doAnswer(mock -> { + SocketChannel spy = spy((SocketChannel) mock.callRealMethod()); + // when: connect get exception + // + // this could happen if system's network service is unavailable, + // for examples, "ifdown eth0" or "service network stop" and so on. + doThrow(new SocketException("Network is unreachable")).when(spy).connect(any()); + return spy; + }).when(socket).createSock(); + + nioSelector.set(socket.getSelector()); + return super.createConnection(hostProvider, sessionTimeout, clientConfig, defaultWatcher, socket, sessionId, sessionPasswd, canBeReadOnly); + } + }) { + + Thread.sleep(sessionTimeout.toMillis() * 5); + + // then: sockets of failed connections are closed, so at most one registered socket + assertThat(nioSelector.get().keys().size(), lessThanOrEqualTo(1)); + } + } + @Test public void testClientCnxnSocketFragility() throws Exception { System.setProperty(ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/common/BusyServer.java b/zookeeper-server/src/test/java/org/apache/zookeeper/common/BusyServer.java new file mode 100644 index 00000000000..c2eece3d242 --- /dev/null +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/common/BusyServer.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.zookeeper.common; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.ServerSocket; +import java.net.Socket; + +public class BusyServer implements AutoCloseable { + private final ServerSocket server; + private final Socket client; + + public BusyServer() throws IOException { + this.server = new ServerSocket(0, 1, InetAddress.getByName("127.0.0.1")); + this.client = new Socket("127.0.0.1", server.getLocalPort()); + } + + public int getLocalPort() { + return server.getLocalPort(); + } + + public String getHostPort() { + return String.format("127.0.0.1:%d", getLocalPort()); + } + + @Override + public void close() throws Exception { + client.close(); + server.close(); + } +} diff --git a/zookeeper-server/src/test/java/org/apache/zookeeper/test/SessionTimeoutTest.java b/zookeeper-server/src/test/java/org/apache/zookeeper/test/SessionTimeoutTest.java index 9f5943f6821..86659ba70d7 100644 --- a/zookeeper-server/src/test/java/org/apache/zookeeper/test/SessionTimeoutTest.java +++ b/zookeeper-server/src/test/java/org/apache/zookeeper/test/SessionTimeoutTest.java @@ -27,8 +27,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import java.io.IOException; -import java.net.ServerSocket; -import java.net.Socket; import java.util.Arrays; import java.util.List; import java.util.concurrent.CompletableFuture; @@ -42,6 +40,7 @@ import org.apache.zookeeper.Watcher; import org.apache.zookeeper.ZooDefs; import org.apache.zookeeper.ZooKeeper; +import org.apache.zookeeper.common.BusyServer; import org.apache.zookeeper.common.Time; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -75,30 +74,6 @@ public synchronized void process(WatchedEvent event) { } } - private static class BusyServer implements AutoCloseable { - private final ServerSocket server; - private final Socket client; - - public BusyServer() throws IOException { - this.server = new ServerSocket(0, 1); - this.client = new Socket("127.0.0.1", server.getLocalPort()); - } - - public int getLocalPort() { - return server.getLocalPort(); - } - - public String getHostPort() { - return String.format("127.0.0.1:%d", getLocalPort()); - } - - @Override - public void close() throws Exception { - client.close(); - server.close(); - } - } - @Test public void testSessionExpiration() throws InterruptedException, KeeperException { final CountDownLatch expirationLatch = new CountDownLatch(1);