diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index c3181e5bf2833..d8852fb16b86a 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -858,36 +858,49 @@ class SocketServerTest { @Test def testConnectionRatePerIp(): Unit = { + val defaultTimeoutMs = 2000 val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) overrideProps.remove(KafkaConfig.MaxConnectionsPerIpProp) overrideProps.put(KafkaConfig.NumQuotaSamplesProp, String.valueOf(2)) val connectionRate = 5 val time = new MockTime() val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), new Metrics(), time, credentialProvider) + // update the connection rate to 5 overrideServer.connectionQuotas.updateIpConnectionRateQuota(None, Some(connectionRate)) try { overrideServer.startup() - // make the maximum allowable number of connections - (0 until connectionRate).map(_ => connect(overrideServer)) - // now try one more (should get throttled) - var conn = connect(overrideServer) + // make the (maximum allowable number + 1) of connections + (0 to connectionRate).map(_ => connect(overrideServer)) + val acceptors = overrideServer.dataPlaneAcceptors.asScala.values - TestUtils.waitUntilTrue(() => acceptors.exists(_.throttledSockets.nonEmpty), - "timeout waiting for connection to get throttled", - 1000) + // waiting for 5 connections got accepted and 1 connection got throttled + TestUtils.waitUntilTrue( + () => acceptors.foldLeft(0)((accumulator, acceptor) => accumulator + acceptor.throttledSockets.size) == 1, + "timeout waiting for 1 connection to get throttled", + defaultTimeoutMs) + + // now try one more, so that we can make sure this connection will get throttled + var conn = connect(overrideServer) + // there should be total 2 connection got throttled now + TestUtils.waitUntilTrue( + () => acceptors.foldLeft(0)((accumulator, acceptor) => accumulator + acceptor.throttledSockets.size) == 2, + "timeout waiting for 2 connection to get throttled", + defaultTimeoutMs) // advance time to unthrottle connections - time.sleep(2000) + time.sleep(defaultTimeoutMs) acceptors.foreach(_.wakeup()) + // make sure there are no connection got throttled now(and the throttled connections should be closed) TestUtils.waitUntilTrue(() => acceptors.forall(_.throttledSockets.isEmpty), "timeout waiting for connection to be unthrottled", - 1000) + defaultTimeoutMs) + // verify the connection is closed now verifyRemoteConnectionClosed(conn) // new connection should succeed after previous connection closed, and previous samples have been expired conn = connect(overrideServer) val serializedBytes = producerRequestBytes() sendRequest(conn, serializedBytes) - val request = overrideServer.dataPlaneRequestChannel.receiveRequest(2000) + val request = overrideServer.dataPlaneRequestChannel.receiveRequest(defaultTimeoutMs) assertNotNull(request) } finally { shutdownServerAndMetrics(overrideServer)