Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@

import io.grpc.stub.StreamObserver;
import io.reactivex.rxjava3.disposables.Disposable;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -90,7 +89,7 @@ public InboxFetchPipeline(StreamObserver<InboxFetched> responseObserver,
fetchHint.getSessionId());
inboxSessionMap.computeIfAbsent(
new InboxId(fetchHint.getInboxId(), fetchHint.getIncarnation()),
k1 -> new HashSet<>()).add(fetchHint.getSessionId());
k1 -> ConcurrentHashMap.newKeySet()).add(fetchHint.getSessionId());
}
v.lastFetchQoS0Seq.set(
Math.max(fetchHint.getLastFetchQoS0Seq(), v.lastFetchQoS0Seq.get()));
Expand Down Expand Up @@ -131,16 +130,31 @@ public void send(InboxFetched message) {
public boolean signalFetch(String inboxId, long incarnation, long now) {
log.trace("Signal fetch: tenantId={}, inboxId={}", tenantId, inboxId);
// signal fetch won't refresh expiry
Set<Long> sessionIds = inboxSessionMap.getOrDefault(new InboxId(inboxId, incarnation), Collections.emptySet());
for (Long sessionId : sessionIds) {
InboxId inboxKey = new InboxId(inboxId, incarnation);
Set<Long> sessionIds = inboxSessionMap.get(inboxKey);
if (sessionIds == null || sessionIds.isEmpty()) {
return false;
}
boolean triggered = false;
Iterator<Long> itr = sessionIds.iterator();
while (itr.hasNext()) {
Long sessionId = itr.next();
FetchState fetchState = inboxFetchSessions.get(sessionId);
if (fetchState != null && fetchState.signalFetchTS.get() < now) {
if (fetchState == null) {
itr.remove();
continue;
}
triggered = true;
if (fetchState.signalFetchTS.get() < now) {
fetchState.hasMore.set(true);
fetchState.signalFetchTS.set(now);
fetch(fetchState);
}
}
return !sessionIds.isEmpty();
if (sessionIds.isEmpty()) {
inboxSessionMap.remove(inboxKey, sessionIds);
}
return triggered;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,27 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;

import io.grpc.Context;
import io.grpc.stub.ServerCallStreamObserver;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import lombok.SneakyThrows;
import org.apache.bifromq.baserpc.RPCContext;
import org.apache.bifromq.baserpc.metrics.IRPCMeter;
Expand Down Expand Up @@ -205,12 +212,96 @@ public void shouldNotRewindStartAfterWhenHintIsStale() {
pipeline.close();
}

@Test
public void shouldCleanStaleSessionIdWhenFetchStateMissing() throws Exception {
InboxFetcherRegistry registry = new InboxFetcherRegistry();
InboxFetchPipeline pipeline = new InboxFetchPipeline(responseObserver, noopFetcher(), registry);

long sessionId = 4004L;
pipeline.onNext(hint(sessionId, 1));

Map<Long, ?> fetchSessions = fetchSessions(pipeline);
fetchSessions.remove(sessionId);

Map<?, Set<Long>> sessionMap = inboxSessionMap(pipeline);
Set<Long> sessionIds = sessionMap.values().iterator().next();
assertTrue(sessionIds.contains(sessionId));

boolean signalled = pipeline.signalFetch(INBOX, INCARNATION, System.nanoTime());

assertFalse(signalled);
assertTrue(sessionMap.isEmpty());
}

@Test
public void shouldNotThrowWhenSignalFetchConcurrentWithSessionRemoval() throws Exception {
InboxFetcherRegistry registry = new InboxFetcherRegistry();
CountingFetcher fetcher = new CountingFetcher();
InboxFetchPipeline pipeline = new InboxFetchPipeline(responseObserver, fetcher, registry);

long sessionA = 5005L;
long sessionB = 6006L;

pipeline.onNext(hint(sessionA, 5));
pipeline.onNext(hint(sessionB, 5));

CountDownLatch latch = new CountDownLatch(1);
AtomicReference<Throwable> error = new AtomicReference<>();

Thread signalThread = new Thread(() -> {
try {
latch.await();
for (int i = 0; i < 500; i++) {
pipeline.signalFetch(INBOX, INCARNATION, System.nanoTime());
}
} catch (Throwable t) {
error.compareAndSet(null, t);
}
});

Thread removeThread = new Thread(() -> {
try {
latch.await();
for (int i = 0; i < 500; i++) {
pipeline.onNext(hint(sessionB, -1));
pipeline.onNext(hint(sessionB, 5));
}
} catch (Throwable t) {
error.compareAndSet(null, t);
}
});

signalThread.start();
removeThread.start();
latch.countDown();
signalThread.join();
removeThread.join();

assertNull(error.get());
await().until(() -> fetcher.fetchCount.get() > 0);
pipeline.close();
}

private InboxFetched lastReceived() {
synchronized (received) {
return received.get(received.size() - 1);
}
}

@SuppressWarnings("unchecked")
private Map<Long, ?> fetchSessions(InboxFetchPipeline pipeline) throws Exception {
Field field = InboxFetchPipeline.class.getDeclaredField("inboxFetchSessions");
field.setAccessible(true);
return (Map<Long, ?>) field.get(pipeline);
}

@SuppressWarnings("unchecked")
private Map<?, Set<Long>> inboxSessionMap(InboxFetchPipeline pipeline) throws Exception {
Field field = InboxFetchPipeline.class.getDeclaredField("inboxSessionMap");
field.setAccessible(true);
return (Map<?, Set<Long>>) field.get(pipeline);
}

private static class TestFetcher implements InboxFetchPipeline.Fetcher {
private final BlockingQueue<FetchRequest> requests = new LinkedBlockingQueue<>();
private final BlockingQueue<CompletableFuture<Fetched>> responses = new LinkedBlockingQueue<>();
Expand All @@ -237,4 +328,16 @@ void completeNext(Fetched fetched) {
future.complete(fetched);
}
}

private static class CountingFetcher implements InboxFetchPipeline.Fetcher {
private final AtomicInteger fetchCount = new AtomicInteger();

@Override
public CompletableFuture<Fetched> fetch(FetchRequest request) {
fetchCount.incrementAndGet();
return CompletableFuture.completedFuture(Fetched.newBuilder()
.setResult(Fetched.Result.OK)
.build());
}
}
}
Loading