diff --git a/src/main/java/org/tikv/common/TiSession.java b/src/main/java/org/tikv/common/TiSession.java index 63a07ed9a40..87410eab637 100644 --- a/src/main/java/org/tikv/common/TiSession.java +++ b/src/main/java/org/tikv/common/TiSession.java @@ -70,7 +70,7 @@ public class TiSession implements AutoCloseable { private volatile boolean enableGrpcForward; private volatile RegionStoreClient.RegionStoreClientBuilder clientBuilder; private volatile ImporterStoreClient.ImporterStoreClientBuilder importerClientBuilder; - private boolean isClosed = false; + private volatile boolean isClosed = false; private MetricsServer metricsServer; private static final int MAX_SPLIT_REGION_STACK_DEPTH = 6; @@ -106,22 +106,30 @@ public static TiSession getInstance(TiConfiguration conf) { } public RawKVClient createRawClient() { + checkIsClosed(); + RegionStoreClientBuilder builder = new RegionStoreClientBuilder(conf, channelFactory, this.getRegionManager(), client); return new RawKVClient(this, builder); } public KVClient createKVClient() { + checkIsClosed(); + RegionStoreClientBuilder builder = new RegionStoreClientBuilder(conf, channelFactory, this.getRegionManager(), client); return new KVClient(conf, builder); } public TxnKVClient createTxnClient() { + checkIsClosed(); + return new TxnKVClient(conf, this.getRegionStoreClientBuilder(), this.getPDClient()); } public RegionStoreClient.RegionStoreClientBuilder getRegionStoreClientBuilder() { + checkIsClosed(); + RegionStoreClient.RegionStoreClientBuilder res = clientBuilder; if (res == null) { synchronized (this) { @@ -137,6 +145,8 @@ public RegionStoreClient.RegionStoreClientBuilder getRegionStoreClientBuilder() } public ImporterStoreClient.ImporterStoreClientBuilder getImporterRegionStoreClientBuilder() { + checkIsClosed(); + ImporterStoreClient.ImporterStoreClientBuilder res = importerClientBuilder; if (res == null) { synchronized (this) { @@ -156,18 +166,26 @@ public TiConfiguration getConf() { } public TiTimestamp getTimestamp() { + checkIsClosed(); + return getPDClient().getTimestamp(ConcreteBackOffer.newTsoBackOff()); } public Snapshot createSnapshot() { + checkIsClosed(); + return new Snapshot(getTimestamp(), this); } public Snapshot createSnapshot(TiTimestamp ts) { + checkIsClosed(); + return new Snapshot(ts, this); } public PDClient getPDClient() { + checkIsClosed(); + PDClient res = client; if (res == null) { synchronized (this) { @@ -181,6 +199,8 @@ public PDClient getPDClient() { } public Catalog getCatalog() { + checkIsClosed(); + Catalog res = catalog; if (res == null) { synchronized (this) { @@ -194,6 +214,8 @@ public Catalog getCatalog() { } public RegionManager getRegionManager() { + checkIsClosed(); + RegionManager res = regionManager; if (res == null) { synchronized (this) { @@ -207,6 +229,8 @@ public RegionManager getRegionManager() { } public ExecutorService getThreadPoolForIndexScan() { + checkIsClosed(); + ExecutorService res = indexScanThreadPool; if (res == null) { synchronized (this) { @@ -226,6 +250,8 @@ public ExecutorService getThreadPoolForIndexScan() { } public ExecutorService getThreadPoolForTableScan() { + checkIsClosed(); + ExecutorService res = tableScanThreadPool; if (res == null) { synchronized (this) { @@ -242,6 +268,8 @@ public ExecutorService getThreadPoolForTableScan() { } public ExecutorService getThreadPoolForBatchPut() { + checkIsClosed(); + ExecutorService res = batchPutThreadPool; if (res == null) { synchronized (this) { @@ -261,6 +289,8 @@ public ExecutorService getThreadPoolForBatchPut() { } public ExecutorService getThreadPoolForBatchGet() { + checkIsClosed(); + ExecutorService res = batchGetThreadPool; if (res == null) { synchronized (this) { @@ -280,6 +310,8 @@ public ExecutorService getThreadPoolForBatchGet() { } public ExecutorService getThreadPoolForBatchDelete() { + checkIsClosed(); + ExecutorService res = batchDeleteThreadPool; if (res == null) { synchronized (this) { @@ -299,6 +331,8 @@ public ExecutorService getThreadPoolForBatchDelete() { } public ExecutorService getThreadPoolForBatchScan() { + checkIsClosed(); + ExecutorService res = batchScanThreadPool; if (res == null) { synchronized (this) { @@ -318,6 +352,8 @@ public ExecutorService getThreadPoolForBatchScan() { } public ExecutorService getThreadPoolForDeleteRange() { + checkIsClosed(); + ExecutorService res = deleteRangeThreadPool; if (res == null) { synchronized (this) { @@ -338,6 +374,8 @@ public ExecutorService getThreadPoolForDeleteRange() { @VisibleForTesting public ChannelFactory getChannelFactory() { + checkIsClosed(); + return channelFactory; } @@ -347,6 +385,8 @@ public ChannelFactory getChannelFactory() { * @return a SwitchTiKVModeClient */ public SwitchTiKVModeClient getSwitchTiKVModeClient() { + checkIsClosed(); + return new SwitchTiKVModeClient(getPDClient(), getImporterRegionStoreClientBuilder()); } @@ -363,6 +403,8 @@ public void splitRegionAndScatter( int splitRegionBackoffMS, int scatterRegionBackoffMS, int scatterWaitMS) { + checkIsClosed(); + logger.info(String.format("split key's size is %d", splitKeys.size())); long startMS = System.currentTimeMillis(); @@ -412,6 +454,8 @@ public void splitRegionAndScatter( * @param splitKeys */ public void splitRegionAndScatter(List splitKeys) { + checkIsClosed(); + int splitRegionBackoffMS = BackOffer.SPLIT_REGION_BACKOFF; int scatterRegionBackoffMS = BackOffer.SCATTER_REGION_BACKOFF; int scatterWaitMS = conf.getScatterWaitSeconds() * 1000; @@ -475,50 +519,111 @@ private List splitRegion( return regions; } - @Override - public synchronized void close() throws Exception { + private void checkIsClosed() { if (isClosed) { - logger.warn("this TiSession is already closed!"); - return; + throw new RuntimeException("this TiSession is closed!"); + } + } + + public synchronized void closeAwaitTermination(long timeoutMS) throws Exception { + shutdown(false); + + long startMS = System.currentTimeMillis(); + while (true) { + if (isTerminatedExecutorServices()) { + cleanAfterTerminated(); + return; + } + + if (System.currentTimeMillis() - startMS > timeoutMS) { + shutdown(true); + return; + } + Thread.sleep(500); } + } + + @Override + public synchronized void close() throws Exception { + shutdown(true); + } - if (metricsServer != null) { - metricsServer.close(); + private synchronized void shutdown(boolean now) throws Exception { + if (!isClosed) { + isClosed = true; + synchronized (sessionCachedMap) { + sessionCachedMap.remove(conf.getPdAddrsString()); + } + + if (metricsServer != null) { + metricsServer.close(); + } } - isClosed = true; - synchronized (sessionCachedMap) { - sessionCachedMap.remove(conf.getPdAddrsString()); + if (now) { + shutdownNowExecutorServices(); + cleanAfterTerminated(); + } else { + shutdownExecutorServices(); } + } + + private synchronized void cleanAfterTerminated() throws InterruptedException { if (regionManager != null) { regionManager.close(); } + if (client != null) { + client.close(); + } + if (catalog != null) { + catalog.close(); + } + } + + private List getExecutorServices() { + List executorServiceList = new ArrayList<>(); if (tableScanThreadPool != null) { - tableScanThreadPool.shutdownNow(); + executorServiceList.add(tableScanThreadPool); } if (indexScanThreadPool != null) { - indexScanThreadPool.shutdownNow(); + executorServiceList.add(indexScanThreadPool); } if (batchGetThreadPool != null) { - batchGetThreadPool.shutdownNow(); + executorServiceList.add(batchGetThreadPool); } if (batchPutThreadPool != null) { - batchPutThreadPool.shutdownNow(); + executorServiceList.add(batchPutThreadPool); } if (batchDeleteThreadPool != null) { - batchDeleteThreadPool.shutdownNow(); + executorServiceList.add(batchDeleteThreadPool); } if (batchScanThreadPool != null) { - batchScanThreadPool.shutdownNow(); + executorServiceList.add(batchScanThreadPool); } if (deleteRangeThreadPool != null) { - deleteRangeThreadPool.shutdownNow(); + executorServiceList.add(deleteRangeThreadPool); } - if (client != null) { - getPDClient().close(); + return executorServiceList; + } + + private void shutdownExecutorServices() { + for (ExecutorService executorService : getExecutorServices()) { + executorService.shutdown(); } - if (catalog != null) { - getCatalog().close(); + } + + private void shutdownNowExecutorServices() { + for (ExecutorService executorService : getExecutorServices()) { + executorService.shutdownNow(); + } + } + + private boolean isTerminatedExecutorServices() { + for (ExecutorService executorService : getExecutorServices()) { + if (!executorService.isTerminated()) { + return false; + } } + return true; } } diff --git a/src/test/java/org/tikv/common/TiSessionTest.java b/src/test/java/org/tikv/common/TiSessionTest.java new file mode 100644 index 00000000000..4b2cd183c54 --- /dev/null +++ b/src/test/java/org/tikv/common/TiSessionTest.java @@ -0,0 +1,124 @@ +package org.tikv.common; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.protobuf.ByteString; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Test; +import org.tikv.common.region.TiRegion; +import org.tikv.raw.RawKVClient; + +public class TiSessionTest { + private TiSession session; + + @After + public void tearDown() throws Exception { + if (session != null) { + session.close(); + } + } + + @Test + public void closeWithRunningTaskTest() throws Exception { + doCloseWithRunningTaskTest(true, 0); + } + + @Test + public void closeAwaitTerminationWithRunningTaskTest() throws Exception { + doCloseWithRunningTaskTest(false, 10000); + } + + private void doCloseWithRunningTaskTest(boolean now, long timeoutMS) throws Exception { + TiConfiguration conf = TiConfiguration.createRawDefault(); + session = TiSession.create(conf); + + ExecutorService executorService = session.getThreadPoolForBatchGet(); + AtomicReference interruptedException = new AtomicReference<>(); + executorService.submit( + () -> { + int i = 1; + while (true) { + i = i + 1; + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + interruptedException.set(e); + break; + } + } + }); + + Thread.sleep(2000); + + long startMS = System.currentTimeMillis(); + if (now) { + session.close(); + Thread.sleep(1000); + assertNotNull(interruptedException.get()); + assertTrue(System.currentTimeMillis() - startMS < 2000); + } else { + session.closeAwaitTermination(timeoutMS); + assertNotNull(interruptedException.get()); + assertTrue(System.currentTimeMillis() - startMS >= timeoutMS); + } + } + + @Test + public void closeTest() throws Exception { + doCloseTest(true, 0); + } + + @Test + public void closeAwaitTerminationTest() throws Exception { + doCloseTest(false, 10000); + } + + private void doCloseTest(boolean now, long timeoutMS) throws Exception { + TiConfiguration conf = TiConfiguration.createRawDefault(); + session = TiSession.create(conf); + RawKVClient client = session.createRawClient(); + + // test getRegionByKey + ByteString key = ByteString.copyFromUtf8("key"); + ByteString value = ByteString.copyFromUtf8("value"); + TiRegion region = session.getRegionManager().getRegionByKey(key); + assertNotNull(region); + + // test RawKVClient + client.put(key, value); + List keys = new ArrayList<>(); + keys.add(key); + client.batchGet(keys); + + // close TiSession + if (now) { + session.close(); + } else { + session.closeAwaitTermination(timeoutMS); + } + + // test getRegionByKey + try { + session.getRegionManager().getRegionByKey(key); + fail(); + } catch (RuntimeException e) { + assertEquals("this TiSession is closed!", e.getMessage()); + } + + // test RawKVClient + try { + client.batchGet(keys); + fail(); + } catch (RejectedExecutionException e) { + assertTrue(e.getMessage().contains("rejected from java.util.concurrent.ThreadPoolExecutor")); + } + } +}