From e914e3fa37221e9fea4c9b6bb38a199c4196e789 Mon Sep 17 00:00:00 2001 From: Sijie Guo Date: Sat, 3 Aug 2019 16:56:53 +0800 Subject: [PATCH 1/2] Port Kafka timer classes *Motivation* Need to port Kafka's coordinator algorithm. It requires using Kafka's timer and delayed operations. --- .../kop/utils/timer/SystemTimer.java | 181 +++++++++++++++ .../streamnative/kop/utils/timer/Timer.java | 51 +++++ .../kop/utils/timer/TimerTask.java | 50 +++++ .../kop/utils/timer/TimerTaskList.java | 209 ++++++++++++++++++ .../kop/utils/timer/TimingWheel.java | 206 +++++++++++++++++ .../kop/utils/timer/package-info.java | 19 ++ .../io/streamnative/kop/utils/MockTime.java | 111 ++++++++++ .../kop/utils/timer/MockTimer.java | 85 +++++++ .../kop/utils/timer/TimerTaskListTest.java | 108 +++++++++ .../kop/utils/timer/TimerTest.java | 161 ++++++++++++++ 10 files changed, 1181 insertions(+) create mode 100644 src/main/java/io/streamnative/kop/utils/timer/SystemTimer.java create mode 100644 src/main/java/io/streamnative/kop/utils/timer/Timer.java create mode 100644 src/main/java/io/streamnative/kop/utils/timer/TimerTask.java create mode 100644 src/main/java/io/streamnative/kop/utils/timer/TimerTaskList.java create mode 100644 src/main/java/io/streamnative/kop/utils/timer/TimingWheel.java create mode 100644 src/main/java/io/streamnative/kop/utils/timer/package-info.java create mode 100644 src/test/java/io/streamnative/kop/utils/MockTime.java create mode 100644 src/test/java/io/streamnative/kop/utils/timer/MockTimer.java create mode 100644 src/test/java/io/streamnative/kop/utils/timer/TimerTaskListTest.java create mode 100644 src/test/java/io/streamnative/kop/utils/timer/TimerTest.java diff --git a/src/main/java/io/streamnative/kop/utils/timer/SystemTimer.java b/src/main/java/io/streamnative/kop/utils/timer/SystemTimer.java new file mode 100644 index 0000000000..9056db31fc --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/timer/SystemTimer.java @@ -0,0 +1,181 @@ +/** + * Licensed 
under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.streamnative.kop.utils.timer.TimerTaskList.TimerTaskEntry; +import java.util.Objects; +import java.util.concurrent.DelayQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Consumer; +import javax.annotation.concurrent.ThreadSafe; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.apache.kafka.common.utils.Time; + +/** + * A system timer implementation. + */ +@Slf4j +@ThreadSafe +public class SystemTimer implements Timer { + + /** + * Create a system timer builder. + * + * @return a system timer builder. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder to build a system timer. 
+ */ + public static class Builder { + + private String executorName; + private long tickMs = 1; + private int wheelSize = 20; + private long startMs = Time.SYSTEM.hiResClockMs(); + + private Builder() {} + + public Builder executorName(String executorName) { + this.executorName = executorName; + return this; + } + + public Builder tickMs(long tickMs) { + this.tickMs = tickMs; + return this; + } + + public Builder wheelSize(int wheelSize) { + this.wheelSize = wheelSize; + return this; + } + + public Builder startMs(long startMs) { + this.startMs = startMs; + return this; + } + + public SystemTimer build() { + Objects.requireNonNull(executorName, "No executor name is provided"); + + return new SystemTimer( + executorName, + tickMs, + wheelSize, + startMs + ); + } + + } + + private final ExecutorService taskExecutor; + private final DelayQueue<TimerTaskList> delayQueue; + private final AtomicInteger taskCounter; + private final TimingWheel timingWheel; + + // Locks used to protect data structures while ticking + private final ReentrantReadWriteLock readWriteLock; + private final Lock readLock; + private final Lock writeLock; + private final Consumer<TimerTaskEntry> reinsert; + + /** + * Create a system timer backed by a single non-daemon worker thread. + * + * @param executorName name embedded in the worker thread's name + * @param tickMs tick duration, in milliseconds, of the lowest-level timing wheel + * @param wheelSize number of buckets per timing wheel + * @param startMs start time of the timing wheel, in milliseconds + */ + private SystemTimer(String executorName, + long tickMs, + int wheelSize, + long startMs) { + // Name the worker thread after the configured executor name; '%' is escaped because + // setNameFormat() treats its argument as a String.format() pattern. + this.taskExecutor = Executors.newFixedThreadPool( + 1, new ThreadFactoryBuilder() + .setDaemon(false) + .setNameFormat("system-timer-" + executorName.replace("%", "%%") + "-%d") + .build() + ); + this.delayQueue = new DelayQueue<>(); + this.taskCounter = new AtomicInteger(0); + this.timingWheel = new TimingWheel( + tickMs, + wheelSize, + startMs, + taskCounter, + delayQueue + ); + this.readWriteLock = new ReentrantReadWriteLock(); + this.readLock = readWriteLock.readLock(); + this.writeLock = readWriteLock.writeLock(); + this.reinsert = timerTaskEntry -> addTimerTaskEntry(timerTaskEntry); + } + + @Override + public void add(TimerTask timerTask) { + readLock.lock(); + try { + addTimerTaskEntry(new TimerTaskEntry( + timerTask, timerTask.delayMs + Time.SYSTEM.hiResClockMs() 
)); + } finally { + readLock.unlock(); + } + } + + private void addTimerTaskEntry(TimerTaskEntry timerTaskEntry) { + if (!timingWheel.add(timerTaskEntry)) { + // Already expired or cancelled + if (!timerTaskEntry.cancelled()) { + taskExecutor.submit(timerTaskEntry.timerTask()); + } + } + } + + @SneakyThrows + @Override + public boolean advanceClock(long timeoutMs) { + TimerTaskList bucket = delayQueue.poll(timeoutMs, TimeUnit.MILLISECONDS); + if (null != bucket) { + writeLock.lock(); + try { + while (null != bucket) { + timingWheel.advanceClock(bucket.getExpiration()); + bucket.flush(reinsert); + bucket = delayQueue.poll(); + } + } finally { + writeLock.unlock(); + } + return true; + } else { + return false; + } + } + + @Override + public int size() { + return taskCounter.get(); + } + + @Override + public void shutdown() { + taskExecutor.shutdown(); + } + +} diff --git a/src/main/java/io/streamnative/kop/utils/timer/Timer.java b/src/main/java/io/streamnative/kop/utils/timer/Timer.java new file mode 100644 index 0000000000..8f17321449 --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/timer/Timer.java @@ -0,0 +1,51 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +/** + * The timer interface to execute delayed operations. + */ +public interface Timer { + + /** + * Add a new task to this executor. 
It will be executed after the task's delay + * (beginning from the time of submission) + * + * @param timerTask the task to add + */ + void add(TimerTask timerTask); + + /** + * Advance the internal clock, executing any tasks whose expiration has been + * reached within the duration of the passed timeout. + * + * @param timeoutMs + * @return whether or not any tasks were executed + */ + boolean advanceClock(long timeoutMs); + + /** + * Get the number of tasks pending execution. + * + * @return the number of tasks + */ + int size(); + + /** + * Shutdown the timer service, leaving pending tasks unexecuted. + */ + void shutdown(); + + +} diff --git a/src/main/java/io/streamnative/kop/utils/timer/TimerTask.java b/src/main/java/io/streamnative/kop/utils/timer/TimerTask.java new file mode 100644 index 0000000000..0adbc766be --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/timer/TimerTask.java @@ -0,0 +1,50 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +import io.streamnative.kop.utils.timer.TimerTaskList.TimerTaskEntry; + +/** + * Timer task. 
+ */ +public abstract class TimerTask implements Runnable { + + protected final long delayMs; + private TimerTaskEntry timerTaskEntry = null; + + protected TimerTask(long delayMs) { + this.delayMs = delayMs; + } + + public synchronized void cancel() { + if (null != timerTaskEntry) { + timerTaskEntry.remove(); + timerTaskEntry = null; + } + } + + synchronized void setTimerTaskEntry(TimerTaskEntry entry) { + // if this timerTask is already held by an existing timer task entry, + // we will remove such an entry first. + if (null != timerTaskEntry && timerTaskEntry != entry) { + timerTaskEntry.remove(); + } + timerTaskEntry = entry; + } + + synchronized TimerTaskEntry getTimerTaskEntry() { + return timerTaskEntry; + } + +} diff --git a/src/main/java/io/streamnative/kop/utils/timer/TimerTaskList.java b/src/main/java/io/streamnative/kop/utils/timer/TimerTaskList.java new file mode 100644 index 0000000000..1f6451883d --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/timer/TimerTaskList.java @@ -0,0 +1,209 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.streamnative.kop.utils.timer; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.util.concurrent.Delayed; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import javax.annotation.concurrent.ThreadSafe; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.extern.slf4j.Slf4j; +import org.apache.kafka.common.utils.Time; + +/** + * The timer task list is a java implementation of Kafka implementation. + */ +@SuppressFBWarnings({ + "EQ_COMPARETO_USE_OBJECT_EQUALS", + "HE_EQUALS_USE_HASHCODE" +}) +@Slf4j +@ThreadSafe +public class TimerTaskList implements Delayed { + + private final AtomicInteger taskCounter; + private final AtomicLong expiration; + + // TimerTaskList forms a doubly linked cyclic list using a dummy root entry + // root.next points to the head + // root.prev points to the tail + private final TimerTaskEntry root; + + public TimerTaskList(AtomicInteger taskCounter) { + this.taskCounter = taskCounter; + this.root = new TimerTaskEntry(null, -1); + this.root.next = root; + this.root.prev = root; + this.expiration = new AtomicLong(-1L); + } + + // Set the bucket's expiration time + // Returns true if the expiration time is changed + public boolean setExpiration(long expirationMs) { + return expiration.getAndSet(expirationMs) != expirationMs; + } + + // Get the bucket's expiration time + public long getExpiration() { + return expiration.get(); + } + + public synchronized void forEach(Consumer f) { + TimerTaskEntry entry = root.next; + while (entry != root) { + final TimerTaskEntry nextEntry = entry.next; + if (!entry.cancelled()) { + f.accept(entry.timerTask); + } + entry = nextEntry; + } + } + + // add a timer task entry to this list + public void add(TimerTaskEntry timerTaskEntry) { + boolean done = false; + while (!done) { + // Remove the timer task entry if it is 
already in any other list + // We do this outside of the sync block below to avoid deadlocking. + // We may retry until timerTaskEntry.list becomes null. + timerTaskEntry.remove(); + + synchronized (this) { + synchronized (timerTaskEntry) { + if (timerTaskEntry.list == null) { + // put the timer task entry to the end of the list. (root.prev points to the tail entry) + TimerTaskEntry tail = root.prev; + timerTaskEntry.next = root; + timerTaskEntry.prev = tail; + timerTaskEntry.list = this; + tail.next = timerTaskEntry; + root.prev = timerTaskEntry; + taskCounter.incrementAndGet(); + done = true; + } + } + } + } + } + + // Remove the specified timer task entry from this list + public void remove(TimerTaskEntry timerTaskEntry) { + synchronized (this) { + synchronized (timerTaskEntry) { + if (timerTaskEntry.list == this) { + timerTaskEntry.next.prev = timerTaskEntry.prev; + timerTaskEntry.prev.next = timerTaskEntry.next; + timerTaskEntry.next = null; + timerTaskEntry.prev = null; + timerTaskEntry.list = null; + taskCounter.decrementAndGet(); + } + } + } + } + + // Remove all task entries and apply the supplied function to each of them + public synchronized void flush(Consumer f) { + TimerTaskEntry head = root.next; + while (head != root) { + remove(head); + f.accept(head); + head = root.next; + } + expiration.set(-1L); + } + + public long getDelay(TimeUnit unit) { + return unit.convert(Math.max(getExpiration() - Time.SYSTEM.hiResClockMs(), 0), TimeUnit.MILLISECONDS); + } + + @Override + public int compareTo(Delayed o) { + TimerTaskList other = (TimerTaskList) o; + + if (getExpiration() < other.getExpiration()) { + return -1; + } else if (getExpiration() > other.getExpiration()) { + return 1; + } else { + return 0; + } + } + + /** + * A timer task entry in the timer task list. 
+ */ + @Accessors(fluent = true) + protected static class TimerTaskEntry implements Comparable { + + @Getter + private final TimerTask timerTask; + @Getter + private final long expirationMs; + private volatile TimerTaskList list = null; + private TimerTaskEntry next = null; + private TimerTaskEntry prev = null; + + public TimerTaskEntry(TimerTask timerTask, + long expirationMs) { + this.timerTask = timerTask; + this.expirationMs = expirationMs; + // if this timerTask is already held by an existing timer task entry, + // setTimerTaskEntry will remove it. + if (null != timerTask) { + timerTask.setTimerTaskEntry(this); + } + } + + public boolean cancelled() { + return timerTask.getTimerTaskEntry() != this; + } + + public void remove() { + TimerTaskList currentList = list; + // If remove is called when another thread is moving the entry from a task entry list to another, + // this may fail to remove the entry due to the change of value of list. Thus, we retry until the + // list becomes null. In a rare case, this thread sees null and exits the loop, but the other thread + // insert the entry to another list later. 
+ while (currentList != null) { + currentList.remove(this); + currentList = list; + } + } + + @Override + public int compareTo(TimerTaskEntry o) { + return Long.compare(this.expirationMs, o.expirationMs); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof TimerTaskEntry)) { + return false; + } + TimerTaskEntry other = (TimerTaskEntry) obj; + return compareTo(other) == 0 + && list == other.list + && next == other.next + && prev == other.prev + && timerTask == other.timerTask; + } + } + + +} diff --git a/src/main/java/io/streamnative/kop/utils/timer/TimingWheel.java b/src/main/java/io/streamnative/kop/utils/timer/TimingWheel.java new file mode 100644 index 0000000000..051d66544b --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/timer/TimingWheel.java @@ -0,0 +1,206 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +import io.streamnative.kop.utils.timer.TimerTaskList.TimerTaskEntry; +import java.util.List; +import java.util.concurrent.DelayQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * Hierarchical Timing Wheels + * + *

A simple timing wheel is a circular list of buckets of timer tasks. Let u be the time unit. + * A timing wheel with size n has n buckets and can hold timer tasks in n * u time interval. + * Each bucket holds timer tasks that fall into the corresponding time range. At the beginning, + * the first bucket holds tasks for [0, u), the second bucket holds tasks for [u, 2u), …, + * the n-th bucket for [u * (n -1), u * n). Every interval of time unit u, the timer ticks and + * moves to the next bucket, then expires all timer tasks in it. So, the timer never inserts a task + * into the bucket for the current time since it is already expired. The timer immediately runs + * the expired task. The emptied bucket is then available for the next round, so if the current + * bucket is for the time t, it becomes the bucket for [t + u * n, t + (n + 1) * u) after a tick. + * A timing wheel has O(1) cost for insert/delete (start-timer/stop-timer) whereas priority queue + * based timers, such as java.util.concurrent.DelayQueue and java.util.Timer, have O(log n) + * insert/delete cost. + * + *

A major drawback of a simple timing wheel is that it assumes that a timer request is within + * the time interval of n * u from the current time. If a timer request is out of this interval, + * it is an overflow. A hierarchical timing wheel deals with such overflows. It is a set of hierarchically + * organized timing wheels. The lowest level has the finest time resolution. As one moves up the + * hierarchy, time resolutions become coarser. If the resolution of a wheel at one level is u and + * the size is n, the resolution of the next level should be n * u. At each level overflows are + * delegated to the wheel in one level higher. When the wheel in the higher level ticks, it reinserts + * timer tasks to the lower level. An overflow wheel can be created on-demand. When a bucket in an + * overflow wheel expires, all tasks in it are reinserted into the timer recursively. The tasks + * are then moved to the finer grain wheels or are executed. The insert (start-timer) cost is O(m) + * where m is the number of wheels, which is usually very small compared to the number of requests + * in the system, and the delete (stop-timer) cost is still O(1). + * + *

Example + * Let's say that u is 1 and n is 3. If the start time is c, + * then the buckets at different levels are: + * + *

+ * level buckets + * 1 [c,c] [c+1,c+1] [c+2,c+2] + * 2 [c,c+2] [c+3,c+5] [c+6,c+8] + * 3 [c,c+8] [c+9,c+17] [c+18,c+26] + *

+ * + *

The bucket expiration is at the time of bucket beginning. + * So at time = c+1, buckets [c,c], [c,c+2] and [c,c+8] are expired. + * Level 1's clock moves to c+1, and [c+3,c+3] is created. + * Level 2 and level 3's clocks stay at c since their clocks move in units of 3 and 9, respectively. + * So, no new buckets are created in level 2 and 3. + * + *

Note that bucket [c,c+2] in level 2 won't receive any task since that range is already covered in level 1. + * The same is true for the bucket [c,c+8] in level 3 since its range is covered in level 2. + * This is a bit wasteful, but simplifies the implementation. + * + *

+ * 1 [c+1,c+1] [c+2,c+2] [c+3,c+3] + * 2 [c,c+2] [c+3,c+5] [c+6,c+8] + * 3 [c,c+8] [c+9,c+17] [c+18,c+26] + *

+ * + *

At time = c+2, [c+1,c+1] is newly expired. + * Level 1 moves to c+2, and [c+4,c+4] is created. + * + *

+ * 1 [c+2,c+2] [c+3,c+3] [c+4,c+4] + * 2 [c,c+2] [c+3,c+5] [c+6,c+8] + * 3 [c,c+8] [c+9,c+17] [c+18,c+26] + *

+ * + *

+ * At time = c+3, [c+2,c+2] is newly expired. + * Levels 1 and 2 move to c+3, and [c+5,c+5] and [c+9,c+11] are created, respectively. + * Level 3 stays at c. + *

+ * + *

+ * 1 [c+3,c+3] [c+4,c+4] [c+5,c+5] + * 2 [c+3,c+5] [c+6,c+8] [c+9,c+11] + * 3 [c,c+8] [c+9,c+17] [c+18,c+26] + *

+ * + *

Hierarchical timing wheels work especially well when operations are completed before they time out. + * Even when everything times out, they are still advantageous when there are many items in the timer. + * Their insert cost (including reinsert) and delete cost are O(m) and O(1), respectively, while priority + * queue based timers take O(log N) for both insert and delete where N is the number of items in the queue. + * + *

This class is not thread-safe. There should not be any add calls while advanceClock is executing. + * It is the caller's responsibility to enforce this. Simultaneous add calls are thread-safe. + * + *

Note: this is the implementation from Kafka. + */ +class TimingWheel { + + private final long tickMs; + private final int wheelSize; + private final long startMs; + private final AtomicInteger taskCounter; + private final DelayQueue queue; + + private final long interval; + private final List buckets; + private long currentTime; + + // overflowWheel can potentially be updated and read by two concurrent threads through add(). + // Therefore, it needs to be volatile due to the issue of Double-Checked Locking pattern with JVM + private volatile TimingWheel overflowWheel = null; + + public TimingWheel( + long tickMs, + int wheelSize, + long startMs, + AtomicInteger taskCounter, + DelayQueue queue + ) { + this.tickMs = tickMs; + this.wheelSize = wheelSize; + this.startMs = startMs; + this.taskCounter = taskCounter; + this.queue = queue; + + this.interval = tickMs * wheelSize; + this.buckets = IntStream.range(0, wheelSize) + .mapToObj(i -> new TimerTaskList(taskCounter)) + .collect(Collectors.toList()); + this.currentTime = startMs - (startMs % tickMs); // rounding down to multiple of tickMs + } + + private synchronized void addOverflowWheel() { + if (null == overflowWheel) { + overflowWheel = new TimingWheel( + interval, + wheelSize, + currentTime, + taskCounter, + queue + ); + } + } + + public boolean add(TimerTaskEntry timerTaskEntry) { + final long expiration = timerTaskEntry.expirationMs(); + + if (timerTaskEntry.cancelled()) { + // cancelled + return false; + } else if (expiration < currentTime + tickMs) { + // Already expired + return false; + } else if (expiration < currentTime + interval) { + // Put in its own bucket + final long virtualId = expiration / tickMs; + TimerTaskList bucket = buckets.get( + (int) (virtualId % (long) wheelSize) + ); + bucket.add(timerTaskEntry); + + // Set the bucket expiration time + if (bucket.setExpiration(virtualId * tickMs)) { + // The bucket needs to be enqueued because it was an expired bucket + // We only need to enqueue the 
bucket when its expiration time has changed, i.e. the wheel has advanced + // and the previous buckets gets reused; further calls to set the expiration within the same wheel cycle + // will pass in the same value and hence return false, thus the bucket with the same expiration will not + // be enqueued multiple times. + queue.offer(bucket); + } + return true; + } else { + // Out of the interval. Put it into the parent timer + if (null == overflowWheel) { + addOverflowWheel(); + } + return overflowWheel.add(timerTaskEntry); + } + } + + // Try to advance the clock + public void advanceClock(long timeMs) { + if (timeMs >= currentTime + tickMs) { + currentTime = timeMs - (timeMs % tickMs); + + // Try to advance the clock of the overflow wheel if present + if (null != overflowWheel) { + overflowWheel.advanceClock(currentTime); + } + } + } + +} diff --git a/src/main/java/io/streamnative/kop/utils/timer/package-info.java b/src/main/java/io/streamnative/kop/utils/timer/package-info.java new file mode 100644 index 0000000000..ae50afe3b5 --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/timer/package-info.java @@ -0,0 +1,19 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Timer related classes. + * + *

The classes under this package are ported from Kafka. + */ +package io.streamnative.kop.utils.timer; diff --git a/src/test/java/io/streamnative/kop/utils/MockTime.java b/src/test/java/io/streamnative/kop/utils/MockTime.java new file mode 100644 index 0000000000..ffc3086a9a --- /dev/null +++ b/src/test/java/io/streamnative/kop/utils/MockTime.java @@ -0,0 +1,111 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils; + +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.apache.kafka.common.utils.Time; + +/** + * A clock that you can manually advance by calling sleep. + */ +public class MockTime implements Time { + + /** + * Mock time listener. + */ + interface MockTimeListener { + void tick(); + } + + /** + * Listeners which are waiting for time changes. 
+ */ + private final CopyOnWriteArrayList<MockTimeListener> listeners = new CopyOnWriteArrayList<>(); + + private final long autoTickMs; + + // Values from `nanoTime` and `currentTimeMillis` are not comparable, so we store them separately to allow tests + // using this class to detect bugs where this is incorrectly assumed to be true + private final AtomicLong timeMs; + private final AtomicLong highResTimeNs; + + public MockTime() { + this(0); + } + + public MockTime(long autoTickMs) { + this(autoTickMs, System.currentTimeMillis(), System.nanoTime()); + } + + public MockTime(long autoTickMs, long currentTimeMs, long currentHighResTimeNs) { + this.timeMs = new AtomicLong(currentTimeMs); + this.highResTimeNs = new AtomicLong(currentHighResTimeNs); + this.autoTickMs = autoTickMs; + } + + public void addListener(MockTimeListener listener) { + listeners.add(listener); + } + + @Override + public long milliseconds() { + maybeSleep(autoTickMs); + return timeMs.get(); + } + + @Override + public long nanoseconds() { + maybeSleep(autoTickMs); + return highResTimeNs.get(); + } + + @Override + public long hiResClockMs() { + return TimeUnit.NANOSECONDS.toMillis(nanoseconds()); + } + + private void maybeSleep(long ms) { + if (ms != 0) { + sleep(ms); + } + } + + @Override + public void sleep(long ms) { + timeMs.addAndGet(ms); + highResTimeNs.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms)); + tick(); + } + + /** + * Move the millisecond clock forward to {@code newMs} and notify listeners. + * + * @param newMs the new current time in milliseconds + * @throws IllegalArgumentException if {@code newMs} is older than the current time; + * the clock is left unchanged in that case + */ + public void setCurrentTimeMs(long newMs) { + // Validate inside the atomic update so a rejected call never rewinds the clock. + timeMs.updateAndGet(oldMs -> { + // does not allow to set to an older timestamp + if (oldMs > newMs) { + throw new IllegalArgumentException("Setting the time to " + newMs + " while current time " + + oldMs + " is newer; this is not allowed"); + } + return newMs; + }); + + highResTimeNs.set(TimeUnit.MILLISECONDS.toNanos(newMs)); + tick(); + } + + private void tick() { + for (MockTimeListener listener : listeners) { + listener.tick(); + } + } +} diff --git a/src/test/java/io/streamnative/kop/utils/timer/MockTimer.java 
b/src/test/java/io/streamnative/kop/utils/timer/MockTimer.java new file mode 100644 index 0000000000..27d6b2c0f5 --- /dev/null +++ b/src/test/java/io/streamnative/kop/utils/timer/MockTimer.java @@ -0,0 +1,85 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +import io.streamnative.kop.utils.MockTime; +import io.streamnative.kop.utils.timer.TimerTaskList.TimerTaskEntry; +import java.util.Comparator; +import java.util.PriorityQueue; + +/** + * A mock implementation of {@link Timer}. 
+ */ +public class MockTimer implements Timer { + + private final MockTime time = new MockTime(); + private final PriorityQueue taskQueue = new PriorityQueue<>(Comparator.reverseOrder()); + + @Override + public void add(TimerTask timerTask) { + if (timerTask.delayMs <= 0) { + timerTask.run(); + } else { + synchronized (taskQueue) { + taskQueue.add( + new TimerTaskEntry( + timerTask, timerTask.delayMs + time.milliseconds())); + } + } + } + + @Override + public boolean advanceClock(long timeoutMs) { + time.sleep(timeoutMs); + + boolean executed = false; + final long now = time.milliseconds(); + boolean hasMore = true; + + while (hasMore) { + hasMore = false; + TimerTaskEntry head; + synchronized (taskQueue) { + head = taskQueue.peek(); + if (null != head && now > head.expirationMs()) { + head = taskQueue.poll(); + hasMore = !taskQueue.isEmpty(); + } else { + head = null; + } + } + if (null != head) { + if (!head.cancelled()) { + TimerTask task = head.timerTask(); + task.run(); + executed = true; + } + } + } + + return executed; + } + + @Override + public int size() { + synchronized (taskQueue) { + return taskQueue.size(); + } + } + + @Override + public void shutdown() { + // no-op + } +} diff --git a/src/test/java/io/streamnative/kop/utils/timer/TimerTaskListTest.java b/src/test/java/io/streamnative/kop/utils/timer/TimerTaskListTest.java new file mode 100644 index 0000000000..89a512ad17 --- /dev/null +++ b/src/test/java/io/streamnative/kop/utils/timer/TimerTaskListTest.java @@ -0,0 +1,108 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +import static org.junit.Assert.assertEquals; + +import io.streamnative.kop.utils.timer.TimerTaskList.TimerTaskEntry; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.Test; + +/** + * Unit test {@link TimerTaskList}. + */ +public class TimerTaskListTest { + + /** + * Test task. + */ + private static class TestTask extends TimerTask { + + protected TestTask(long delayMs) { + super(delayMs); + } + + @Override + public void run() { + + } + + } + + private int size(TimerTaskList list) { + AtomicInteger count = new AtomicInteger(0); + list.forEach(ignored -> count.incrementAndGet()); + return count.get(); + } + + @Test + public void testAll() { + AtomicInteger sharedCounter = new AtomicInteger(0); + TimerTaskList list1 = new TimerTaskList(sharedCounter); + TimerTaskList list2 = new TimerTaskList(sharedCounter); + TimerTaskList list3 = new TimerTaskList(sharedCounter); + + List tasks = IntStream.rangeClosed(1, 10).mapToObj(i -> { + TestTask task = new TestTask(0L); + list1.add(new TimerTaskEntry(task, 10L)); + assertEquals(i, sharedCounter.get()); + return task; + }).collect(Collectors.toList()); + + assertEquals(tasks.size(), sharedCounter.get()); + + // reinserting the existing tasks shouldn't change the task count. 
+ tasks.subList(0, 4).forEach(task -> { + int prevCount = sharedCounter.get(); + // new TimerTaskEntry(task) will remove the existing entry from the list + list2.add(new TimerTaskEntry(task, 10L)); + assertEquals(prevCount, sharedCounter.get()); + }); + assertEquals(10 - 4, size(list1)); + assertEquals(4, size(list2)); + assertEquals(tasks.size(), sharedCounter.get()); + + // reinserting the existing tasks shouldn't change the task count + tasks.subList(4, 10).forEach(task -> { + int prevCount = sharedCounter.get(); + // new TimerTaskEntry(task) will remove the existing entry from the list + list3.add(new TimerTaskEntry(task, 10L)); + assertEquals(prevCount, sharedCounter.get()); + }); + assertEquals(0, size(list1)); + assertEquals(4, size(list2)); + assertEquals(6, size(list3)); + assertEquals(tasks.size(), sharedCounter.get()); + + // cancel tasks in the lists + list1.forEach(TimerTask::cancel); + assertEquals(0, size(list1)); + assertEquals(4, size(list2)); + assertEquals(6, size(list3)); + + list2.forEach(TimerTask::cancel); + assertEquals(0, size(list1)); + assertEquals(0, size(list2)); + assertEquals(6, size(list3)); + + list3.forEach(TimerTask::cancel); + assertEquals(0, size(list1)); + assertEquals(0, size(list2)); + assertEquals(0, size(list3)); + } + +} diff --git a/src/test/java/io/streamnative/kop/utils/timer/TimerTest.java b/src/test/java/io/streamnative/kop/utils/timer/TimerTest.java new file mode 100644 index 0000000000..5bbe356d11 --- /dev/null +++ b/src/test/java/io/streamnative/kop/utils/timer/TimerTest.java @@ -0,0 +1,161 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.timer; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.common.collect.Lists; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import lombok.extern.slf4j.Slf4j; +import org.apache.kafka.common.utils.Time; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit test {@link Timer}. 
+ */ +@Slf4j +public class TimerTest { + + private static class TestTask extends TimerTask { + private final int id; + private final CountDownLatch latch; + private final List output; + private final AtomicBoolean completed = new AtomicBoolean(false); + + public TestTask(long delayMs, + int id, + CountDownLatch latch, + List output) { + super(delayMs); + this.id = id; + this.latch = latch; + this.output = output; + } + + public void run() { + if (completed.compareAndSet(false, true)) { + synchronized (output) { + output.add(id); + } + latch.countDown(); + } + } + } + + private Timer timer = null; + + @Before + public void setup() { + this.timer = SystemTimer.builder() + .executorName("test") + .tickMs(1) + .wheelSize(3) + .startMs(Time.SYSTEM.hiResClockMs()) + .build(); + } + + @After + public void teardown() { + timer.shutdown(); + } + + @Test + public void testAlreadyExpiredTask() { + final List output = new ArrayList<>(); + final List latches = IntStream.range(-5, 0).mapToObj(i -> { + CountDownLatch latch = new CountDownLatch(1); + timer.add(new TestTask(i, i, latch, output)); + return latch; + }).collect(Collectors.toList()); + + timer.advanceClock(0); + + latches.forEach(latch -> { + try { + assertEquals( + "already expired tasks should run immediately", + true, + latch.await(3, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + fail("Should not reach here"); + } + }); + + assertEquals( + "Output of already expired tasks", + Lists.newArrayList(-5, -4, -3, -2, -1), + output + ); + } + + @Test + public void testTaskExpiration() { + final List output = new ArrayList<>(); + final List tasks = new ArrayList<>(); + final List ids = new ArrayList<>(); + final List latches = IntStream.range(0, 5).mapToObj(i -> { + CountDownLatch latch = new CountDownLatch(1); + tasks.add(new TestTask(i, i, latch, output)); + ids.add(i); + return latch; + }).collect(Collectors.toList()); + latches.addAll(IntStream.range(10, 100).mapToObj(i 
-> { + CountDownLatch latch = new CountDownLatch(1); + tasks.add(new TestTask(i, i, latch, output)); + tasks.add(new TestTask(i, i, latch, output)); + ids.add(i); + ids.add(i); + return latch; + }).collect(Collectors.toList())); + latches.addAll(IntStream.range(100, 500).mapToObj(i -> { + CountDownLatch latch = new CountDownLatch(1); + tasks.add(new TestTask(i, i, latch, output)); + ids.add(i); + return latch; + }).collect(Collectors.toList())); + + // randomly submit requests + tasks.forEach(task -> timer.add(task)); + + while (timer.advanceClock(2000)) {} + + latches.forEach(latch -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + fail("Should not reach here"); + } + }); + + Collections.sort(ids); + assertEquals( + "output should match", + ids, + output + ); + } + +} From eb70d58ecb313a943b909b1ba2654a091687da5f Mon Sep 17 00:00:00 2001 From: Sijie Guo Date: Sat, 3 Aug 2019 21:48:08 +0800 Subject: [PATCH 2/2] Port delayed operations implementation from Kafka --- .../io/streamnative/kop/utils/CoreUtils.java | 44 ++ .../kop/utils/ShutdownableThread.java | 120 +++++ .../kop/utils/delayed/DelayedOperation.java | 147 ++++++ .../utils/delayed/DelayedOperationKey.java | 98 ++++ .../delayed/DelayedOperationPurgatory.java | 410 +++++++++++++++ .../kop/utils/delayed/package-info.java | 17 + .../io/streamnative/kop/utils/TestUtils.java | 43 ++ .../utils/delayed/DelayedOperationTest.java | 487 ++++++++++++++++++ 8 files changed, 1366 insertions(+) create mode 100644 src/main/java/io/streamnative/kop/utils/CoreUtils.java create mode 100644 src/main/java/io/streamnative/kop/utils/ShutdownableThread.java create mode 100644 src/main/java/io/streamnative/kop/utils/delayed/DelayedOperation.java create mode 100644 src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationKey.java create mode 100644 src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationPurgatory.java create mode 100644 
src/main/java/io/streamnative/kop/utils/delayed/package-info.java create mode 100644 src/test/java/io/streamnative/kop/utils/TestUtils.java create mode 100644 src/test/java/io/streamnative/kop/utils/delayed/DelayedOperationTest.java diff --git a/src/main/java/io/streamnative/kop/utils/CoreUtils.java b/src/main/java/io/streamnative/kop/utils/CoreUtils.java new file mode 100644 index 0000000000..3987086b31 --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/CoreUtils.java @@ -0,0 +1,44 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils; + +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.function.Supplier; +import lombok.experimental.UtilityClass; + +/** + * Core utils. 
/**
 * Static helpers that run a computation while holding a lock and always release it afterwards.
 *
 * <p>The class is written as an explicit {@code final} class with a private constructor so that
 * callers may use single-member static imports of these methods without depending on annotation
 * processing.
 */
public final class CoreUtils {

    private CoreUtils() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Acquires {@code lock}, evaluates {@code supplier}, and releases the lock even if the supplier
     * throws.
     *
     * @param lock     the lock to hold while the supplier runs
     * @param supplier the computation to run under the lock
     * @param <T>      the type of the computed value
     * @return whatever {@code supplier.get()} returns
     */
    public static <T> T inLock(Lock lock, Supplier<T> supplier) {
        lock.lock();
        try {
            return supplier.get();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Runs {@code supplier} while holding the read lock of {@code lock}.
     *
     * @param lock     the read-write lock whose read lock is held
     * @param supplier the computation to run under the lock
     * @param <T>      the type of the computed value
     * @return whatever {@code supplier.get()} returns
     */
    public static <T> T inReadLock(ReadWriteLock lock, Supplier<T> supplier) {
        return inLock(lock.readLock(), supplier);
    }

    /**
     * Runs {@code supplier} while holding the write lock of {@code lock}.
     *
     * @param lock     the read-write lock whose write lock is held
     * @param supplier the computation to run under the lock
     * @param <T>      the type of the computed value
     * @return whatever {@code supplier.get()} returns
     */
    public static <T> T inWriteLock(ReadWriteLock lock, Supplier<T> supplier) {
        return inLock(lock.writeLock(), supplier);
    }

}
+ */ +@Slf4j +public abstract class ShutdownableThread extends Thread { + + private final boolean isInterruptible; + private final String logIdent; + private final CountDownLatch shutdownInitiated = new CountDownLatch(1); + private final CountDownLatch shutdownComplete = new CountDownLatch(1); + + public ShutdownableThread(String name) { + this(name, true); + } + + public ShutdownableThread(String name, + boolean isInterruptible) { + super(name); + this.isInterruptible = isInterruptible; + this.setDaemon(false); + this.logIdent = "[" + name + "]"; + } + + public boolean isRunning() { + return shutdownInitiated.getCount() != 0; + } + + public void shutdown() throws InterruptedException { + initiateShutdown(); + awaitShutdown(); + } + + public boolean isShutdownComplete() { + return shutdownComplete.getCount() == 0; + } + + public synchronized boolean initiateShutdown() { + if (isRunning()) { + log.info("{} Shutting down", logIdent); + } + shutdownInitiated.countDown(); + if (isInterruptible) { + interrupt(); + return true; + } else { + return false; + } + } + + /** + * After calling initiateShutdown(), use this API to wait until the shutdown is complete. + */ + public void awaitShutdown() throws InterruptedException { + shutdownComplete.await(); + log.info("{} Shutdown completed", logIdent); + } + + /** + * Causes the current thread to wait until the shutdown is initiated, + * or the specified waiting time elapses. + * + * @param timeout + * @param unit + */ + public void pause(long timeout, TimeUnit unit) throws InterruptedException { + if (shutdownInitiated.await(timeout, unit)) { + if (log.isTraceEnabled()) { + log.trace("{} shutdownInitiated latch count reached zero. Shutdown called.", logIdent); + } + } + } + + /** + * This method is repeatedly invoked until the thread shuts down or this method throws an exception. 
+ */ + protected abstract void doWork(); + + @Override + public void run() { + log.info("{} Starting", logIdent); + try { + while (isRunning()) { + doWork(); + } + } catch (FatalExitError e) { + shutdownInitiated.countDown(); + shutdownComplete.countDown(); + log.info("{} Stopped", logIdent); + Exit.exit(e.statusCode()); + } catch (Throwable cause) { + if (isRunning()) { + log.error("{} Error due to", logIdent, cause); + } + } finally { + shutdownComplete.countDown(); + } + } + +} diff --git a/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperation.java b/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperation.java new file mode 100644 index 0000000000..124afb7e35 --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperation.java @@ -0,0 +1,147 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.delayed; + +import io.streamnative.kop.utils.timer.TimerTask; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import lombok.extern.slf4j.Slf4j; + +/** + * An operation whose processing needs to be delayed for at most the given delayMs. For example + * a delayed produce operation could be waiting for specified number of acks; or + * a delayed fetch operation could be waiting for a given number of bytes to accumulate. + * + *

The logic upon completing a delayed operation is defined in onComplete() and will be called exactly once. + * Once an operation is completed, isCompleted() will return true. onComplete() can be triggered by either + * forceComplete(), which forces calling onComplete() after delayMs if the operation is not yet completed, + * or tryComplete(), which first checks if the operation can be completed or not now, and if yes calls + * forceComplete(). + * + *

A subclass of DelayedOperation needs to provide an implementation of both onComplete() and tryComplete(). + */ +@Slf4j +public abstract class DelayedOperation extends TimerTask { + + protected final Optional lockOpt; + private final AtomicBoolean completed = new AtomicBoolean(false); + private final AtomicBoolean tryCompletePending = new AtomicBoolean(false); + final Lock lock; + + protected DelayedOperation(long delayMs, + Optional lockOpt) { + super(delayMs); + this.lockOpt = lockOpt; + this.lock = lockOpt.orElseGet(() -> new ReentrantLock()); + } + + /** + * Force completing the delayed operation, if not already completed. + * This function can be triggered when + * + *

1. The operation has been verified to be completable inside tryComplete() + * 2. The operation has expired and hence needs to be completed right now + * + *

Return true iff the operation is completed by the caller: note that + * concurrent threads can try to complete the same operation, but only + * the first thread will succeed in completing the operation and return + * true, others will still return false + */ + public boolean forceComplete() { + if (completed.compareAndSet(false, true)) { + // cancel the timeout timer + cancel(); + onComplete(); + return true; + } else { + return false; + } + } + + /** + * Check if the delayed operation is already completed. + */ + public boolean isCompleted() { + return completed.get(); + } + + /** + * Call-back to execute when a delayed operation gets expired and hence forced to complete. + */ + public abstract void onExpiration(); + + /** + * Process for completing an operation. This function needs to be defined + * in subclasses and will be called exactly once in forceComplete() + */ + public abstract void onComplete(); + + /** + * Try to complete the delayed operation by first checking if the operation + * can be completed by now. If yes execute the completion logic by calling + * forceComplete() and return true iff forceComplete returns true; otherwise return false + * + *

This function needs to be defined in subclasses. + */ + public abstract boolean tryComplete(); + + /** + * Thread-safe variant of tryComplete() that attempts completion only if the lock can be acquired + * without blocking. + * + *

If threadA acquires the lock and performs the check for completion before completion criteria is met + * and threadB satisfies the completion criteria, but fails to acquire the lock because threadA has not + * yet released the lock, we need to ensure that completion is attempted again without blocking threadA + * or threadB. `tryCompletePending` is set by threadB when it fails to acquire the lock and at least one + * of threadA or threadB will attempt completion of the operation if this flag is set. This ensures that + * every invocation of `maybeTryComplete` is followed by at least one invocation of `tryComplete` until + * the operation is actually completed. + */ + boolean maybeTryComplete() { + boolean retry = false; + boolean done = false; + do { + if (lock.tryLock()) { + try { + tryCompletePending.set(false); + done = tryComplete(); + } finally { + lock.unlock(); + } + // While we were holding the lock, another thread may have invoked `maybeTryComplete` and set + // `tryCompletePending`. In this case we should retry. + retry = tryCompletePending.get(); + } else { + // Another thread is holding the lock. If `tryCompletePending` is already set and this thread failed to + // acquire the lock, then the thread that is holding the lock is guaranteed to see the flag and retry. + // Otherwise, we should set the flag and retry on this thread since the thread holding the lock may have + // released the lock and returned by the time the flag is set. + retry = !tryCompletePending.getAndSet(true); + } + } while (!isCompleted() && retry); + return done; + } + + /** + * run() method defines a task that is executed on timeout. 
+ */ + @Override + public void run() { + if (forceComplete()) { + onExpiration(); + } + } +} diff --git a/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationKey.java b/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationKey.java new file mode 100644 index 0000000000..03e7839560 --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationKey.java @@ -0,0 +1,98 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.delayed; + +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.experimental.Accessors; + +/** + * Delayed operation key. + */ +public interface DelayedOperationKey { + + /** + * Key label. + * + * @return key label. + */ + String keyLabel(); + + /** + * Member key. + */ + @Data + @Accessors(fluent = true) + @RequiredArgsConstructor + class MemberKey implements DelayedOperationKey { + + private final String groupId; + private final String consumerId; + + @Override + public String keyLabel() { + return String.format("%s-%s", groupId, consumerId); + } + } + + /** + * Group key. + */ + @Data + @Accessors(fluent = true) + @RequiredArgsConstructor + class GroupKey implements DelayedOperationKey { + + private final String groupId; + + @Override + public String keyLabel() { + return groupId; + } + } + + /** + * Topic key. 
+ */ + @Data + @Accessors(fluent = true) + @RequiredArgsConstructor + class TopicKey implements DelayedOperationKey { + + private final String topic; + + @Override + public String keyLabel() { + return topic; + } + } + + /** + * Topic partition key. + */ + @Data + @Accessors(fluent = true) + @RequiredArgsConstructor + class TopicPartitionOperationKey implements DelayedOperationKey { + + private final String topic; + private final int partition; + + @Override + public String keyLabel() { + return String.format("%s-%d", topic, partition); + } + } + +} diff --git a/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationPurgatory.java b/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationPurgatory.java new file mode 100644 index 0000000000..89170e315c --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/delayed/DelayedOperationPurgatory.java @@ -0,0 +1,410 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.streamnative.kop.utils.delayed; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.streamnative.kop.utils.CoreUtils.inReadLock; +import static io.streamnative.kop.utils.CoreUtils.inWriteLock; + +import io.streamnative.kop.utils.ShutdownableThread; +import io.streamnative.kop.utils.timer.SystemTimer; +import io.streamnative.kop.utils.timer.Timer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import lombok.extern.slf4j.Slf4j; + +/** + * A helper purgatory class for bookkeeping delayed operations with a timeout, and expiring timed out operations. + */ +@Slf4j +public class DelayedOperationPurgatory { + + public static Builder builder() { + return new Builder<>(); + } + + /** + * Builder to build a delayed operation purgatory. 
+ */ + public static class Builder { + + private String purgatoryName; + private Timer timer; + private int purgeInterval = 1000; + private boolean reaperEnabled = true; + private boolean timerEnabled = true; + + private Builder() {} + + public Builder purgatoryName(String purgatoryName) { + this.purgatoryName = purgatoryName; + return this; + } + + public Builder timeoutTimer(Timer timer) { + this.timer = timer; + return this; + } + + public Builder purgeInterval(int purgeInterval) { + this.purgeInterval = purgeInterval; + return this; + } + + public Builder reaperEnabled(boolean reaperEnabled) { + this.reaperEnabled = reaperEnabled; + return this; + } + + public Builder timerEnabled(boolean timerEnabled) { + this.timerEnabled = timerEnabled; + return this; + } + + public DelayedOperationPurgatory build() { + if (null == timer) { + timer = SystemTimer.builder().executorName(purgatoryName).build(); + } + return new DelayedOperationPurgatory<>( + purgatoryName, + timer, + purgeInterval, + reaperEnabled, + timerEnabled + ); + } + } + + private final String purgatoryName; + private final Timer timeoutTimer; + private final int purgeInterval; + private final boolean reaperEnabled; + private final boolean timerEnabled; + + /* a list of operation watching keys */ + private final ConcurrentMap watchersForKey; + + private final ReentrantReadWriteLock removeWatchersLock = new ReentrantReadWriteLock(); + + // the number of estimated total operations in the purgatory + private final AtomicInteger estimatedTotalOperations = new AtomicInteger(0); + + /* background thread expiring operations that have timed out */ + private final ShutdownableThread expirationReaper; + + public DelayedOperationPurgatory( + String purgatoryName, + Timer timeoutTimer, + int purgeInterval, + boolean reaperEnabled, + boolean timerEnabled + ) { + this.purgatoryName = purgatoryName; + this.timeoutTimer = timeoutTimer; + this.purgeInterval = purgeInterval; + this.reaperEnabled = reaperEnabled; + 
this.timerEnabled = timerEnabled; + + this.watchersForKey = new ConcurrentHashMap<>(); + this.expirationReaper = new ShutdownableThread( + String.format("ExpirationReaper-%s", purgatoryName) + ) { + @Override + protected void doWork() { + advanceClock(200L); + } + }; + + if (reaperEnabled) { + expirationReaper.start(); + } + } + + /** + * Check if the operation can be completed, if not watch it based on the given watch keys + * + *

Note that a delayed operation can be watched on multiple keys. It is possible that + * an operation is completed after it has been added to the watch list for some, but + * not all of the keys. In this case, the operation is considered completed and won't + * be added to the watch list of the remaining keys. The expiration reaper thread will + * remove this operation from any watcher list in which the operation exists. + * + * @param operation the delayed operation to be checked + * @param watchKeys keys for bookkeeping the operation + * @return true iff the delayed operations can be completed by the caller + */ + public boolean tryCompleteElseWatch(T operation, List watchKeys) { + checkArgument(!watchKeys.isEmpty(), "The watch key list can't be empty"); + + // The cost of tryComplete() is typically proportional to the number of keys. Calling + // tryComplete() for each key is going to be expensive if there are many keys. Instead, + // we do the check in the following way. Call tryComplete(). If the operation is not completed, + // we just add the operation to all keys. Then we call tryComplete() again. At this time, if + // the operation is still not completed, we are guaranteed that it won't miss any future triggering + // event since the operation is already on the watcher list for all keys. This does mean that + // if the operation is completed (by another thread) between the two tryComplete() calls, the + // operation is unnecessarily added for watch. However, this is a less severe issue since the + // expire reaper will clean it up periodically. + + // At this point the only thread that can attempt this operation is this current thread + // Hence it is safe to tryComplete() without a lock + boolean isCompletedByMe = operation.tryComplete(); + if (isCompletedByMe) { + return true; + } + + boolean watchCreated = false; + for (Object key : watchKeys) { + // If the operation is already completed, stop adding it to the rest of the watcher list. 
+ if (operation.isCompleted()) { + return false; + } + watchForOperation(key, operation); + + if (!watchCreated) { + watchCreated = true; + estimatedTotalOperations.incrementAndGet(); + } + } + + isCompletedByMe = operation.maybeTryComplete(); + if (isCompletedByMe) { + return true; + } + + // if it cannot be completed by now and hence is watched, add to the expire queue also + if (!operation.isCompleted()) { + if (timerEnabled) { + timeoutTimer.add(operation); + } + if (operation.isCompleted()) { + // cancel the timer task + operation.cancel(); + } + } + + return false; + } + + /** + * Check if some delayed operations can be completed with the given watch key, + * and if yes complete them. + * + * @return the number of completed operations during this process + */ + public int checkAndComplete(Object key) { + Watchers watchers = inReadLock( + removeWatchersLock, + () -> watchersForKey.get(key)); + if (null == watchers) { + return 0; + } else { + return watchers.tryCompleteWatched(); + } + } + + /** + * Return the total size of watch lists the purgatory. Since an operation may be watched + * on multiple lists, and some of its watched entries may still be in the watch lists + * even when it has been completed, this number may be larger than the number of real operations watched + */ + public int watched() { + return allWatchers().stream().mapToInt(Watchers::countWatched).sum(); + } + + /** + * Return the number of delayed operations in the expiry queue. + */ + public int delayed() { + return timeoutTimer.size(); + } + + /** + * Cancel watching on any delayed operations for the given key. 
Note the operation will not be completed + */ + public List cancelForKey(Object key) { + return inWriteLock(removeWatchersLock, () -> { + Watchers watchers = watchersForKey.remove(key); + if (watchers != null) { + return watchers.cancel(); + } else { + return Collections.emptyList(); + } + }); + } + /* + * Return all the current watcher lists, + * note that the returned watchers may be removed from the list by other threads + */ + private Collection allWatchers() { + return inReadLock(removeWatchersLock, () -> watchersForKey.values()); + } + + /* + * Return the watch list of the given key, note that we need to + * grab the removeWatchersLock to avoid the operation being added to a removed watcher list + */ + private void watchForOperation(Object key, T operation) { + inReadLock(removeWatchersLock, () -> { + watchersForKey.computeIfAbsent(key, (k) -> new Watchers(k)) + .watch(operation); + return null; + }); + } + + /** + * Remove the key from watcher lists if its list is empty. + */ + private void removeKeyIfEmpty(Object key, Watchers watchers) { + inWriteLock(removeWatchersLock, () -> { + // if the current key is no longer correlated to the watchers to remove, skip + if (watchersForKey.get(key) != watchers) { + return null; + } + + if (watchers != null && watchers.isEmpty()) { + watchersForKey.remove(key); + } + return null; + }); + } + + /** + * Shutdown the expire reaper thread. + */ + public void shutdown() { + if (reaperEnabled) { + try { + expirationReaper.shutdown(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.error("Interrupted at shutting down expiration reaper for {}", purgatoryName); + } + } + timeoutTimer.shutdown(); + } + + /** + * A linked list of watched delayed operations based on some key. 
+ */ + private class Watchers { + + private final Object key; + private final ConcurrentLinkedQueue operations = new ConcurrentLinkedQueue<>(); + + Watchers(Object key) { + this.key = key; + } + + // count the current number of watched operations. This is O(n), so use isEmpty() if possible + public int countWatched() { + return operations.size(); + } + + public boolean isEmpty() { + return operations.isEmpty(); + } + + // add the element to watch + public void watch(T t) { + operations.add(t); + } + + // traverse the list and try to complete some watched elements + public int tryCompleteWatched() { + int completed = 0; + + Iterator iter = operations.iterator(); + while (iter.hasNext()) { + T curr = iter.next(); + if (curr.isCompleted()) { + // another thread has completed this operation, just remove it + iter.remove(); + } else if (curr.maybeTryComplete()) { + iter.remove(); + completed += 1; + } + } + + if (operations.isEmpty()) { + removeKeyIfEmpty(key, this); + } + + return completed; + } + + public List cancel() { + Iterator iter = operations.iterator(); + List cancelled = new ArrayList<>(); + while (iter.hasNext()) { + T curr = iter.next(); + curr.cancel(); + iter.remove(); + cancelled.add(curr); + } + return cancelled; + } + + // traverse the list and purge elements that are already completed by others + int purgeCompleted() { + int purged = 0; + + Iterator iter = operations.iterator(); + while (iter.hasNext()) { + T curr = iter.next(); + if (curr.isCompleted()) { + iter.remove(); + purged += 1; + } + } + + if (operations.isEmpty()) { + removeKeyIfEmpty(key, this); + } + + return purged; + } + } + + public void advanceClock(long timeoutMs) { + timeoutTimer.advanceClock(timeoutMs); + + // Trigger a purge if the number of completed but still being watched operations is larger than + // the purge threshold. That number is computed by the difference btw the estimated total number of + // operations and the number of pending delayed operations. 
+ if (estimatedTotalOperations.get() - delayed() > purgeInterval) { + // now set estimatedTotalOperations to delayed (the number of pending operations) since we are going to + // clean up watchers. Note that, if more operations are completed during the clean up, we may end up with + // a slightly overestimated total number of operations. + estimatedTotalOperations.getAndSet(delayed()); + if (log.isDebugEnabled()) { + log.debug("{} Begin purging watch lists", purgatoryName); + } + int purged = allWatchers().stream().mapToInt(Watchers::purgeCompleted).sum(); + if (log.isDebugEnabled()) { + log.debug("{} Purged {} elements from watch lists.", purgatoryName, purged); + } + } + } + + +} diff --git a/src/main/java/io/streamnative/kop/utils/delayed/package-info.java b/src/main/java/io/streamnative/kop/utils/delayed/package-info.java new file mode 100644 index 0000000000..d8ec5f0e54 --- /dev/null +++ b/src/main/java/io/streamnative/kop/utils/delayed/package-info.java @@ -0,0 +1,17 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Classes related to delayed operations.
+ */ +package io.streamnative.kop.utils.delayed; \ No newline at end of file diff --git a/src/test/java/io/streamnative/kop/utils/TestUtils.java b/src/test/java/io/streamnative/kop/utils/TestUtils.java new file mode 100644 index 0000000000..ea2a19d1a9 --- /dev/null +++ b/src/test/java/io/streamnative/kop/utils/TestUtils.java @@ -0,0 +1,43 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils; + +import static org.junit.Assert.fail; + +import java.util.function.Supplier; +import lombok.SneakyThrows; + +/** + * Test utilities. {@code waitUntilTrue} polls a condition and fails the test with the given message on timeout. + */ +public class TestUtils { + + @SneakyThrows + public static void waitUntilTrue(Supplier condition, + Supplier msg, + long waitTime, + long pause) { + long startTime = System.currentTimeMillis(); + while (true) { + if (condition.get()) { + return; + } + if (System.currentTimeMillis() > startTime + waitTime) { + fail(msg.get()); + } + Thread.sleep(Math.min(waitTime, pause)); + } + } + +} diff --git a/src/test/java/io/streamnative/kop/utils/delayed/DelayedOperationTest.java b/src/test/java/io/streamnative/kop/utils/delayed/DelayedOperationTest.java new file mode 100644 index 0000000000..0ef3eb15f9 --- /dev/null +++ b/src/test/java/io/streamnative/kop/utils/delayed/DelayedOperationTest.java @@ -0,0 +1,487 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.streamnative.kop.utils.delayed; + +import static io.streamnative.kop.utils.CoreUtils.inLock; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.Lists; +import io.streamnative.kop.utils.TestUtils; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.Semaphore; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import lombok.SneakyThrows; +import org.apache.kafka.common.utils.Time; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Unit tests for {@link DelayedOperation} and {@link DelayedOperationPurgatory}. + */ +public class DelayedOperationTest { + + /** + * A mock delayed operation: {@code tryComplete()} calls {@code forceComplete()} only once {@code completable} is true.
+ */ + static class MockDelayedOperation extends DelayedOperation { + + private final Optional responseLockOpt; + boolean completable = false; + + protected MockDelayedOperation(long delayMs) { + this(delayMs, Optional.empty(), Optional.empty()); + } + + protected MockDelayedOperation(long delayMs, + Optional lockOpt, + Optional responseLockOpt) { + super(delayMs, lockOpt); + this.responseLockOpt = responseLockOpt; + } + + public synchronized void awaitExpiration() throws InterruptedException { + wait(); /* woken by the notify() in onComplete() */ + } + + @Override + public void onExpiration() { + // no-op + } + + @Override + public void onComplete() { + responseLockOpt.map(lock -> { // map() is used only for its side effect; the result is ignored + if (!lock.tryLock()) { + throw new IllegalStateException("Response callback lock could not be acquired in callback"); + } + return null; + }); + synchronized (this) { + notify(); // wakes a thread blocked in awaitExpiration() + } + } + + @Override + public boolean tryComplete() { + if (completable) { + return forceComplete(); + } else { + return false; + } + } + } + + static class TestDelayOperation extends MockDelayedOperation { + + private final int index; + private final Object key; + private final AtomicInteger completionAttemptsRemaining; + private final int maxDelayMs; + + TestDelayOperation(int index, + int completionAttempts, + int maxDelayMs) { + super(10000L); + this.index = index; + this.key = "key" + index; + this.maxDelayMs = maxDelayMs; + this.completionAttemptsRemaining = new AtomicInteger(completionAttempts); + } + + @SneakyThrows + @Override + public boolean tryComplete() { + boolean shouldComplete = completable; // read the flag before sleeping; a later change is not seen by this attempt + Thread.sleep(ThreadLocalRandom.current().nextInt(maxDelayMs)); + if (shouldComplete) { + return forceComplete(); + } else { + return false; + } + } + } + + DelayedOperationPurgatory purgatory = null; + ScheduledExecutorService executorService = null; + + @Before + public void setup() { + purgatory = DelayedOperationPurgatory.builder() + .purgatoryName("mock") + .build(); + } + + @After + public void teardown() { + purgatory.shutdown(); + if
(null != executorService) { + executorService.shutdown(); + } + } + + @Test + public void testRequestSatisfaction() { + MockDelayedOperation r1 = new MockDelayedOperation(100000L); + MockDelayedOperation r2 = new MockDelayedOperation(100000L); + assertEquals( + "With no waiting requests, nothing should be satisfied", + 0, purgatory.checkAndComplete("test1")); + assertFalse( + "r1 not satisfied and hence watched", + purgatory.tryCompleteElseWatch(r1, Lists.newArrayList("test1"))); + assertEquals( + "Still nothing satisfied", + 0, purgatory.checkAndComplete("test1")); + assertFalse( + "r2 not satisfied and hence watched", + purgatory.tryCompleteElseWatch(r2, Lists.newArrayList("test2"))); + assertEquals( + "Still nothing satisfied", + 0, purgatory.checkAndComplete("test2")); + r1.completable = true; // from now on r1.tryComplete() calls forceComplete() + assertEquals( + "r1 satisfied", + 1, purgatory.checkAndComplete("test1")); + assertEquals( + "Nothing satisfied", + 0, purgatory.checkAndComplete("test1")); + r2.completable = true; + assertEquals( + "r2 satisfied", + 1, purgatory.checkAndComplete("test2")); + assertEquals( + "Nothing satisfied", + 0, purgatory.checkAndComplete("test2")); + } + + @Test + public void testRequestExpiry() throws Exception { + long expiration = 20L; + long start = Time.SYSTEM.hiResClockMs(); + MockDelayedOperation r1 = new MockDelayedOperation(expiration); + MockDelayedOperation r2 = new MockDelayedOperation(200000L); + assertFalse( + "r1 not satisfied and hence watched", + purgatory.tryCompleteElseWatch(r1, Lists.newArrayList("test1"))); + assertFalse( + "r2 not satisfied and hence watched", + purgatory.tryCompleteElseWatch(r2, Lists.newArrayList("test2"))); + r1.awaitExpiration(); // blocks until r1.onComplete() calls notify() + long elapsed = Time.SYSTEM.hiResClockMs() - start; // time from before r1 was watched until it completed + assertTrue( + "r1 completed due to expiration", + r1.isCompleted()); + assertFalse("r2 hasn't completed", r2.isCompleted()); + assertTrue( + "Time for expiration " + elapsed + " should at least " + expiration, + elapsed >= expiration); + } + + @Test + public void
testRequestPurge() { + MockDelayedOperation r1 = new MockDelayedOperation(100000L); + MockDelayedOperation r2 = new MockDelayedOperation(100000L); + MockDelayedOperation r3 = new MockDelayedOperation(100000L); + purgatory.tryCompleteElseWatch(r1, Lists.newArrayList("test1")); + purgatory.tryCompleteElseWatch(r2, Lists.newArrayList("test1", "test2")); + purgatory.tryCompleteElseWatch(r3, Lists.newArrayList("test1", "test2", "test3")); + + assertEquals( + "Purgatory should have 3 total delayed operations", + 3, purgatory.delayed()); + assertEquals( + "Purgatory should have 6 watched elements", + 6, purgatory.watched()); + + // completing an operation directly should immediately drop it from the delayed-operation count + r2.completable = true; + r2.tryComplete(); + assertEquals( + "Purgatory should have 2 total delayed operations instead of " + purgatory.delayed(), + 2, purgatory.delayed()); + + r3.completable = true; + r3.tryComplete(); + assertEquals( + "Purgatory should have 1 total delayed operations instead of " + purgatory.delayed(), + 1, purgatory.delayed()); + + // checking a key should purge the already completed operations from that key's watch list + purgatory.checkAndComplete("test1"); + assertEquals( + "Purgatory should have 4 watched elements instead of " + purgatory.watched(), + 4, purgatory.watched()); + + purgatory.checkAndComplete("test2"); + assertEquals( + "Purgatory should have 2 watched elements instead of " + purgatory.watched(), + 2, purgatory.watched()); + + purgatory.checkAndComplete("test3"); + assertEquals( + "Purgatory should have 1 watched elements instead of " + purgatory.watched(), + 1, purgatory.watched()); + } + + @Test + public void shouldCancelForKeyReturningCancelledOperations() { + purgatory.tryCompleteElseWatch(new MockDelayedOperation(10000L), Lists.newArrayList("key")); + purgatory.tryCompleteElseWatch(new MockDelayedOperation(10000L), Lists.newArrayList("key")); + purgatory.tryCompleteElseWatch(new MockDelayedOperation(10000L), Lists.newArrayList("key2")); + + List
cancelledOperations = purgatory.cancelForKey("key"); + assertEquals(2, cancelledOperations.size()); + assertEquals(1, purgatory.delayed()); + assertEquals(1, purgatory.watched()); + } + + @Test + public void shouldReturnNilOperationsOnCancelForKeyWhenKeyDoesntExist() { + List cancelledOperations = purgatory.cancelForKey("key"); + assertTrue(cancelledOperations.isEmpty()); + } + + /** + * Verify that if there is lock contention between two threads attempting to complete, completion is performed + * without any blocking in either thread. A semaphore held by the test parks the other thread inside tryComplete. + */ + @Test + public void testTryCompleteLockContention() throws Exception { + executorService = Executors.newSingleThreadScheduledExecutor(); + AtomicInteger completionAttemptsRemaining = new AtomicInteger(Integer.MAX_VALUE); + Semaphore tryCompleteSemaphore = new Semaphore(1); + String key = "key"; + + MockDelayedOperation op = new MockDelayedOperation(100000L) { + @SneakyThrows + @Override + public boolean tryComplete() { + boolean shouldComplete = completionAttemptsRemaining.decrementAndGet() <= 0; + tryCompleteSemaphore.acquire(); + try { + if (shouldComplete) { + return forceComplete(); + } else { + return false; + } + } finally { + tryCompleteSemaphore.release(); + } + } + }; + + purgatory.tryCompleteElseWatch(op, Lists.newArrayList(key)); + completionAttemptsRemaining.set(2); + tryCompleteSemaphore.acquire(); + Future future = runOnAnotherThread(() -> purgatory.checkAndComplete(key), false); + TestUtils.waitUntilTrue( + () -> tryCompleteSemaphore.hasQueuedThreads(), + () -> "Not attempting to complete", + 10000, + 200); + purgatory.checkAndComplete(key); // this should not block even though lock is not free + assertFalse("Operation should not have completed", op.isCompleted()); + tryCompleteSemaphore.release(); + future.get(10, TimeUnit.SECONDS); + assertTrue("Operation should have completed", op.isCompleted()); + } + + /** + * Test {@code tryComplete} with multiple threads to verify that there are no timing windows + * when
completion is not performed even if the thread that makes the operation completable + * may not be able to acquire the operation lock. Since it is difficult to test all scenarios, + * this test uses random delays with a pool of 20 threads working on 100 operations. + */ + @Test + public void testTryCompleteWithMultipleThreads() { + ScheduledExecutorService executor = Executors.newScheduledThreadPool(20); + this.executorService = executor; + Random random = ThreadLocalRandom.current(); + int maxDelayMs = 10; + final int completionAttempts = 20; + + List ops = IntStream.range(0, 100).mapToObj(index -> { + TestDelayOperation op = new TestDelayOperation(index, completionAttempts, maxDelayMs); + purgatory.tryCompleteElseWatch(op, Lists.newArrayList(op.key)); + return op; + }).collect(Collectors.toList()); + + List> futures = IntStream.rangeClosed(1, completionAttempts) + .mapToObj(i -> + ops.stream().map( + op -> scheduleTryComplete(op, random.nextInt(maxDelayMs))) + .collect(Collectors.toList())) + .flatMap(List::stream) + .collect(Collectors.toList()); + futures.forEach(future -> { + try { + future.get(); + } catch (InterruptedException | ExecutionException e) { + // no-op; NOTE(review): task failures are swallowed here, only the isCompleted() check below can reveal them + } + }); + + ops.forEach(op -> assertTrue("Operation should have completed", op.isCompleted())); + } + + Future scheduleTryComplete(TestDelayOperation op, long delayMs) { + return executorService.schedule(() -> { + if (op.completionAttemptsRemaining.decrementAndGet() == 0) { // the attempt that brings the counter to zero + op.completable = true; + } + purgatory.checkAndComplete(op.key); + }, delayMs, TimeUnit.MILLISECONDS); + } + + @Test + public void testDelayedOperationLock() throws Exception { + verifyDelayedOperationLock(() -> new MockDelayedOperation(100000L), false); + } + + @Test + public void testDelayedOperationLockOverride() throws Exception { + verifyDelayedOperationLock(() -> { + ReentrantLock lock = new ReentrantLock(); + return new MockDelayedOperation(100000L, Optional.of(lock), Optional.of(lock)); + }, false); + + verifyDelayedOperationLock(() ->
new MockDelayedOperation( + 100000L, + Optional.empty(), + Optional.of(new ReentrantLock()) + ), true); + } + + void verifyDelayedOperationLock(Supplier mockDelayedOperation, boolean mismatchedLocks) + throws Exception { + String key = "key"; + executorService = Executors.newSingleThreadScheduledExecutor(); /* NOTE(review): an executor assigned by an earlier call is replaced here without being shut down */ + + Function> createDelayedOperations = count -> + IntStream.rangeClosed(1, count).mapToObj(i -> { + MockDelayedOperation op = mockDelayedOperation.get(); + purgatory.tryCompleteElseWatch(op, Lists.newArrayList(key)); + assertFalse("Not completable", op.isCompleted()); + return op; + }).collect(Collectors.toList()); + + Function> createCompletableOperations = count -> + IntStream.rangeClosed(1, count).mapToObj(i -> { + MockDelayedOperation op = mockDelayedOperation.get(); + op.completable = true; + return op; + }).collect(Collectors.toList()); + + BiFunction, List, Void> checkAndComplete = + (completableOps, expectedComplete) -> { + completableOps.forEach(op -> op.completable = true); + int completed = purgatory.checkAndComplete(key); + assertEquals(expectedComplete.size(), completed); + expectedComplete.forEach(op -> assertTrue( + "Should have completed", + op.isCompleted() + )); + Set expectedNotComplete = completableOps.stream().collect(Collectors.toSet()); + expectedComplete.forEach(op -> expectedNotComplete.remove(op)); + expectedNotComplete.forEach(op -> assertFalse("Should not have completed", op.isCompleted())); + return null; + }; + + // If locks are free all completable operations should complete + List ops = createDelayedOperations.apply(2); + checkAndComplete.apply(ops, ops); + + // Lock of the second operation held by current thread (via inLock), completable operations should complete + ops = createDelayedOperations.apply(2); + final List ops2 = ops; + inLock(ops.get(1).lock, () -> { + checkAndComplete.apply(ops2, ops2); + return null; + }); + + // Lock held by another thread, should not block, only operations that can be + // locked without blocking on the current thread should
complete + ops = createDelayedOperations.apply(2); + final List ops3 = ops; + runOnAnotherThread(() -> ops3.get(0).lock.lock(), true); /* the executor thread now holds the first operation's lock */ + try { + checkAndComplete.apply(ops, Lists.newArrayList(ops.get(1))); + } finally { + runOnAnotherThread(() -> ops3.get(0).lock.unlock(), true); + checkAndComplete.apply(Lists.newArrayList(ops.get(0)), Lists.newArrayList(ops.get(0))); + } + + // Lock acquired by response callback held by another thread, should not block + // if the response lock is used as operation lock, only operations + // that can be locked without blocking on the current thread should complete + ops = createDelayedOperations.apply(2); + final List ops4 = ops; + ops.get(0).responseLockOpt.map(lock -> { // runs only when a response lock is present; map() is used for its side effect + try { + runOnAnotherThread(() -> lock.lock(), true); + try { + try { + checkAndComplete.apply(ops4, Lists.newArrayList(ops4.get(1))); + assertFalse("Should have failed with mismatched locks", mismatchedLocks); + } catch (IllegalStateException e) { + assertTrue("Should not have failed with valid locks", mismatchedLocks); + } + } finally { + runOnAnotherThread(() -> lock.unlock(), true); + checkAndComplete.apply(Lists.newArrayList(ops4.get(0)), Lists.newArrayList(ops4.get(0))); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + }); + + // Immediately completable operations should complete without locking + ops = createCompletableOperations.apply(2); + ops.forEach(op -> { + assertTrue("Should have completed", purgatory.tryCompleteElseWatch(op, Lists.newArrayList(key))); + assertTrue("Should have completed", op.isCompleted()); + }); + } + + private Future runOnAnotherThread(Runnable f, boolean shouldComplete) throws Exception { + Future future = executorService.submit(f); + if (shouldComplete) { + future.get(); // wait for f to finish on the executor thread + } else { + assertFalse("Should not have completed", future.isDone()); // f is expected to still be running or blocked + } + return future; + } + +}