diff --git a/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java b/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java index f06e56040..7259742c9 100644 --- a/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java +++ b/bifromq-inbox/bifromq-inbox-server/src/main/java/org/apache/bifromq/inbox/server/InboxFetchPipeline.java @@ -25,8 +25,7 @@ import io.grpc.stub.StreamObserver; import io.reactivex.rxjava3.disposables.Disposable; -import java.util.Collections; -import java.util.HashSet; +import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; @@ -90,7 +89,7 @@ public InboxFetchPipeline(StreamObserver responseObserver, fetchHint.getSessionId()); inboxSessionMap.computeIfAbsent( new InboxId(fetchHint.getInboxId(), fetchHint.getIncarnation()), - k1 -> new HashSet<>()).add(fetchHint.getSessionId()); + k1 -> ConcurrentHashMap.newKeySet()).add(fetchHint.getSessionId()); } v.lastFetchQoS0Seq.set( Math.max(fetchHint.getLastFetchQoS0Seq(), v.lastFetchQoS0Seq.get())); @@ -131,16 +130,31 @@ public void send(InboxFetched message) { public boolean signalFetch(String inboxId, long incarnation, long now) { log.trace("Signal fetch: tenantId={}, inboxId={}", tenantId, inboxId); // signal fetch won't refresh expiry - Set sessionIds = inboxSessionMap.getOrDefault(new InboxId(inboxId, incarnation), Collections.emptySet()); - for (Long sessionId : sessionIds) { + InboxId inboxKey = new InboxId(inboxId, incarnation); + Set sessionIds = inboxSessionMap.get(inboxKey); + if (sessionIds == null || sessionIds.isEmpty()) { + return false; + } + boolean triggered = false; + Iterator itr = sessionIds.iterator(); + while (itr.hasNext()) { + Long sessionId = itr.next(); FetchState fetchState = inboxFetchSessions.get(sessionId); - if (fetchState != null && fetchState.signalFetchTS.get() < now) { + if (fetchState == null) { + itr.remove(); + continue; + } + triggered = true; + if (fetchState.signalFetchTS.get() < now) { fetchState.hasMore.set(true); fetchState.signalFetchTS.set(now); fetch(fetchState); } } - return !sessionIds.isEmpty(); + if (sessionIds.isEmpty()) { + inboxSessionMap.remove(inboxKey, sessionIds); + } + return triggered; } @Override diff --git a/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java b/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java index 071fab65c..cff671dfa 100644 --- a/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java +++ b/bifromq-inbox/bifromq-inbox-server/src/test/java/org/apache/bifromq/inbox/server/InboxFetchPipelineMappingTest.java @@ -25,20 +25,27 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; import io.grpc.Context; import io.grpc.stub.ServerCallStreamObserver; import io.micrometer.core.instrument.Timer; import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import java.lang.reflect.Field; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import lombok.SneakyThrows; import org.apache.bifromq.baserpc.RPCContext; import org.apache.bifromq.baserpc.metrics.IRPCMeter; @@ -205,12 +212,96 @@ public void shouldNotRewindStartAfterWhenHintIsStale() { pipeline.close(); } + @Test + public void shouldCleanStaleSessionIdWhenFetchStateMissing() throws Exception { + InboxFetcherRegistry registry = new InboxFetcherRegistry(); + InboxFetchPipeline pipeline = new InboxFetchPipeline(responseObserver, noopFetcher(), registry); + + long sessionId = 4004L; + pipeline.onNext(hint(sessionId, 1)); + + Map fetchSessions = fetchSessions(pipeline); + fetchSessions.remove(sessionId); + + Map> sessionMap = inboxSessionMap(pipeline); + Set sessionIds = sessionMap.values().iterator().next(); + assertTrue(sessionIds.contains(sessionId)); + + boolean signalled = pipeline.signalFetch(INBOX, INCARNATION, System.nanoTime()); + + assertFalse(signalled); + assertTrue(sessionMap.isEmpty()); + } + + @Test + public void shouldNotThrowWhenSignalFetchConcurrentWithSessionRemoval() throws Exception { + InboxFetcherRegistry registry = new InboxFetcherRegistry(); + CountingFetcher fetcher = new CountingFetcher(); + InboxFetchPipeline pipeline = new InboxFetchPipeline(responseObserver, fetcher, registry); + + long sessionA = 5005L; + long sessionB = 6006L; + + pipeline.onNext(hint(sessionA, 5)); + pipeline.onNext(hint(sessionB, 5)); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference error = new AtomicReference<>(); + + Thread signalThread = new Thread(() -> { + try { + latch.await(); + for (int i = 0; i < 500; i++) { + pipeline.signalFetch(INBOX, INCARNATION, System.nanoTime()); + } + } catch (Throwable t) { + error.compareAndSet(null, t); + } + }); + + Thread removeThread = new Thread(() -> { + try { + latch.await(); + for (int i = 0; i < 500; i++) { + pipeline.onNext(hint(sessionB, -1)); + pipeline.onNext(hint(sessionB, 5)); + } + } catch (Throwable t) { + error.compareAndSet(null, t); + } + }); + + signalThread.start(); + removeThread.start(); + latch.countDown(); + signalThread.join(); + removeThread.join(); + + assertNull(error.get()); + await().until(() -> fetcher.fetchCount.get() > 0); + pipeline.close(); + } + private InboxFetched lastReceived() { synchronized (received) { return received.get(received.size() - 1); } } + @SuppressWarnings("unchecked") + private Map fetchSessions(InboxFetchPipeline pipeline) throws Exception { + Field field = InboxFetchPipeline.class.getDeclaredField("inboxFetchSessions"); + field.setAccessible(true); + return (Map) field.get(pipeline); + } + + @SuppressWarnings("unchecked") + private Map> inboxSessionMap(InboxFetchPipeline pipeline) throws Exception { + Field field = InboxFetchPipeline.class.getDeclaredField("inboxSessionMap"); + field.setAccessible(true); + return (Map>) field.get(pipeline); + } + private static class TestFetcher implements InboxFetchPipeline.Fetcher { private final BlockingQueue requests = new LinkedBlockingQueue<>(); private final BlockingQueue> responses = new LinkedBlockingQueue<>(); @@ -237,4 +328,16 @@ void completeNext(Fetched fetched) { future.complete(fetched); } } + + private static class CountingFetcher implements InboxFetchPipeline.Fetcher { + private final AtomicInteger fetchCount = new AtomicInteger(); + + @Override + public CompletableFuture fetch(FetchRequest request) { + fetchCount.incrementAndGet(); + return CompletableFuture.completedFuture(Fetched.newBuilder() + .setResult(Fetched.Result.OK) + .build()); + } + } }