From 522f1dfb7ed927f28c537e1db4186795593d6174 Mon Sep 17 00:00:00 2001 From: Amatya Date: Mon, 11 Apr 2022 20:01:01 +0530 Subject: [PATCH 01/13] Add Treap implementation --- .../org/apache/druid/timeline/SegmentId.java | 5 +- .../druid/timeline/DataSegmentTest.java | 30 ++ .../coordinator/cost/SegmentsCostCache.java | 394 +++++++++++++++++- .../cost/SegmentsCostCacheTest.java | 73 +++- 4 files changed, 492 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/apache/druid/timeline/SegmentId.java b/core/src/main/java/org/apache/druid/timeline/SegmentId.java index 8430524021c0..e1d3a97c24a0 100644 --- a/core/src/main/java/org/apache/druid/timeline/SegmentId.java +++ b/core/src/main/java/org/apache/druid/timeline/SegmentId.java @@ -266,6 +266,7 @@ public static SegmentId dummy(String dataSource, int partitionNum) private final long intervalEndMillis; @Nullable private final Chronology intervalChronology; + private final Interval interval; private final String version; private final int partitionNum; @@ -281,6 +282,7 @@ private SegmentId(String dataSource, Interval interval, String version, int part this.intervalStartMillis = interval.getStartMillis(); this.intervalEndMillis = interval.getEndMillis(); this.intervalChronology = interval.getChronology(); + this.interval = new Interval(intervalStartMillis, intervalEndMillis, intervalChronology); // Versions are timestamp-based Strings, interning of them doesn't make sense. If this is not the case, interning // could be conditionally allowed via a system property. 
this.version = Objects.requireNonNull(version); @@ -320,7 +322,8 @@ public DateTime getIntervalEnd() public Interval getInterval() { - return new Interval(intervalStartMillis, intervalEndMillis, intervalChronology); + return interval; + //return new Interval(intervalStartMillis, intervalEndMillis, intervalChronology); } public String getVersion() diff --git a/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java b/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java index 87ec7b17869b..d16686a62734 100644 --- a/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java +++ b/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java @@ -411,6 +411,36 @@ public void testTombstoneType() } + @Test + public void getIntervalBenchmarkTest() + { + final DataSegment segment = DataSegment.builder() + .dataSource("foo") + .interval(Intervals.of("2012-01-01/2012-01-02")) + .version(DateTimes.of("2012-01-01T11:22:33.444Z").toString()) + .shardSpec(new TombstoneShardSpec()) + .loadSpec(Collections.singletonMap( + "type", + DataSegment.TOMBSTONE_LOADSPEC_TYPE + )) + .size(0) + .build(); + + long start = System.currentTimeMillis(); + int cnt = 0; + + for (int i = 0; i < 1000000000; i++) { + Interval interval = segment.getInterval(); + cnt++; + if (cnt == 100000000) { + cnt = 0; + System.out.println(interval); + } + } + long end = System.currentTimeMillis(); + System.out.println(end - start); + } + private DataSegment makeDataSegment(String dataSource, String interval, String version) { return DataSegment.builder() diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 9271de28425b..cab69ab2e966 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -24,12 +24,14 @@ import 
org.apache.commons.math3.util.FastMath; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.granularity.DurationGranularity; import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.server.coordinator.CostBalancerStrategy; import org.apache.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -48,18 +50,18 @@ * Joint cost for two segments (you can make formulas below readable by copy-pasting to * https://www.codecogs.com/latex/eqneditor.php): * - * cost(X, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy + * cost(Y, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy * or - * cost(X, Y) = e^{y_0 + y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) + * cost(Y, Y) = e^{y_0 + y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) * if x_0 <= x_1 <= y_0 <= y_1 * (*) lambda coefficient is omitted for simplicity. 
* * For a group of segments {S_xi}, i = {0, n} total joint cost with segment S_y could be calculated as: * - * cost(X, Y) = \sum cost(X_i, Y) = e^{y_0 + y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) + * cost(Y, Y) = \sum cost(X_i, Y) = e^{y_0 + y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) * if xi_0 <= xi_1 <= y_0 <= y_1 * and - * cost(X, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) + * cost(Y, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) * if y_0 <= y_1 <= xi_0 <= xi_1 * * SegmentsCostCache stores pre-computed sums for a group of segments {S_xi}: @@ -99,7 +101,7 @@ public class SegmentsCostCache /** * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" - * from each other and thus cost(X,Y) ~ 0 for these segments + * from each other and thus cost(Y,Y) ~ 0 for these segments */ private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30); @@ -196,9 +198,9 @@ public SegmentsCostCache build() { return new SegmentsCostCache( buckets - .entrySet() + .values() .stream() - .map(entry -> entry.getValue().build()) + .map(Bucket.Builder::build) .collect(Collectors.toCollection(ArrayList::new)) ); } @@ -317,7 +319,7 @@ public static Builder builder(Interval interval) static class Builder { - private final Interval interval; + protected final Interval interval; private final NavigableSet segments = new TreeSet<>(); public Builder(Interval interval) @@ -440,4 +442,380 @@ public int hashCode() throw new UnsupportedOperationException(); } } + + + abstract static class Treap> + { + protected TreapNode root; + protected final TreapNode NULL; + + public Treap() + { + NULL = new TreapNode(null); + NULL.left = NULL.right = NULL; + NULL.priority = Double.POSITIVE_INFINITY; + root = NULL; + } + + public boolean isEmpty() + { + return NULL.equals(root); + } + + public boolean contains(X val) + { + return contains(val, root); + } + + public TreapNode lower(X 
val) + { + return lower(val, root); + } + + public TreapNode upper(X val) + { + return upper(val, root); + } + + public TreapNode floor(X val) + { + return floor(val, root); + } + + public TreapNode ceil(X val) + { + return ceil(val, root); + } + + public void insert(X val) + { + root = insert(new TreapNode(val), root); + } + + public void remove(X val) + { + root = remove(val, root); + } + + public TreapNode getMin() + { + TreapNode node = root; + while (!NULL.equals(node.left)) { + node = node.left; + } + return node; + } + + public TreapNode getMax() + { + TreapNode node = root; + while (!NULL.equals(node.right)) { + node = node.right; + } + return node; + } + + public double query() + { + return root.sum; + } + + public void update(X val, double lazy, boolean dir) + { + if (dir) { + root = update(root, val, null, lazy); + } else { + root = update(root, null, val, lazy); + } + } + + protected abstract double getVal(X val); + + protected abstract X setVal(X val, double add); + + private boolean contains(X val, TreapNode node) + { + if (NULL.equals(node)) { + return false; + } + final int cmp = val.compareTo(node.val); + if (cmp < 0) { + return contains(val, node.left); + } + if (cmp > 0) { + return contains(val, node.right); + } + return true; + } + + private TreapNode lower(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp <= 0) { + return lower(val, node.left); + } else { + TreapNode ret = lower(val, node.right); + return (NULL.equals(ret)) ? node : ret; + } + } + + private TreapNode upper(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp >= 0) { + return upper(val, node.right); + } else { + TreapNode ret = upper(val, node.left); + return (NULL.equals(ret)) ? 
node : ret; + } + } + + private TreapNode floor(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp < 0) { + return floor(val, node.left); + } else { + TreapNode ret = floor(val, node.right); + return (NULL.equals(ret)) ? node : ret; + } + } + + private TreapNode ceil(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp > 0) { + return ceil(val, node.right); + } else { + TreapNode ret = ceil(val, node.left); + return (NULL.equals(ret)) ? node : ret; + } + } + + private TreapNode insert(TreapNode val, TreapNode node) + { + if (NULL.equals(node)) { + return val; + } + Pair pair = split(node, val.val); + node = merge(pair.lhs, val); + node = merge(node, pair.rhs); + return node; + } + + private TreapNode remove(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + Pair pair = split(node, val); + TreapNode lower = lower(val, pair.lhs); + if (NULL.equals(lower)) { + return pair.rhs; + } + return merge(split(pair.lhs, lower.val).lhs, pair.rhs); + } + + private Pair split(TreapNode node, X val) + { + if (NULL.equals(node)) { + return Pair.of(NULL, NULL); + } + node.lazyPropogate(); + final int cmp = val.compareTo(node.val); + Pair pair; + if (cmp < 0) { + pair = split(node.left, val); + node.left = pair.rhs; + pair = Pair.of(pair.lhs, node); + } else { + pair = split(node.right, val); + node.right = pair.lhs; + pair = Pair.of(node, pair.rhs); + } + node.recompute(); + return pair; + } + + private TreapNode merge(TreapNode left, TreapNode right) + { + if (NULL.equals(left)) { + return right; + } + if (NULL.equals(right)) { + return left; + } + left.lazyPropogate(); + right.lazyPropogate(); + TreapNode node; + if (left.priority < right.priority) { + left.right = merge(left.right, right); + node = left; + } else { + right.left = merge(left, right.left); + node = right; + } + node.recompute(); + return node; + 
} + + private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, double lazy) + { + TreapNode left = NULL; + TreapNode right = NULL; + if (begin != null) { + Pair pair = split(node, begin); + left = pair.lhs; + node = pair.rhs; + } + if (end != null) { + Pair pair = split(node, end); + node = pair.lhs; + right = pair.rhs; + } + node.lazy += lazy; + node = merge(left, node); + node = merge(node, right); + return node; + } + + class TreapNode + { + X val; + TreapNode left; + TreapNode right; + double priority; + double sum; + double lazy; + int size; + + TreapNode(@Nullable X val) + { + this(val, NULL, NULL); + if (val != null) { + sum = getVal(val); + size = 1; + } + } + + TreapNode(@Nullable X val, @Nullable TreapNode left, @Nullable TreapNode right) + { + this.val = val; + this.left = left; + this.right = right; + this.priority = Math.random(); + } + + void recompute() + { + if (NULL.equals(this)) { + return; + } + size = 1 + left.size + right.size; + sum = getVal(val); + left.lazyPropogate(); + right.lazyPropogate(); + sum += left.sum + right.sum; + } + + void lazyPropogate() + { + if (NULL.equals(this)) { + return; + } + val = setVal(val, lazy); + sum += size * lazy; + if (!NULL.equals(left)) { + left.lazy += lazy; + } + if (!NULL.equals(right)) { + right.lazy += lazy; + } + lazy = 0; + } + + @Override + public boolean equals(Object that) + { + return this == that; + } + } + } + + public static class TestVal implements Comparable + { + final String a; + double b; + + public TestVal(String a, double b) + { + this.a = a; + this.b = b; + } + + public String getA() + { + return a; + } + + public double getB() + { + return b; + } + + public void setB(double b) + { + this.b = b; + } + + @Override + public int compareTo(TestVal that) + { + return a.compareTo(that.getA()); + } + } + + public static class TestTreap extends Treap + { + @Override + protected double getVal(TestVal val) + { + return val.getB(); + } + + @Override + protected TestVal 
setVal(TestVal val, double lazy) + { + val.setB(val.getB() + lazy); + return val; + } + + public void print() + { + print(this.root); + System.out.println(); + } + + private void print(TreapNode node) + { + if (NULL.equals(node)) { + return; + } + print(node.left); + System.out.println(node.val.getA() + ", " + node.val.getB() + ", " + node.sum + ", " + node.lazy + ", " + node.priority); + print(node.right); + } + } } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java index f0ae22094fe4..56b8deecaa96 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java @@ -28,7 +28,9 @@ import java.util.ArrayList; import java.util.List; +import java.util.NavigableSet; import java.util.Random; +import java.util.TreeSet; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -37,7 +39,7 @@ public class SegmentsCostCacheTest { private static final String DATA_SOURCE = "dataSource"; - private static final DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); + private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); private static final double EPSILON = 0.00000001; @Test @@ -145,6 +147,75 @@ public void multipleSegmentsCostTest() Assert.assertEquals(0.001574717989780039, segmentCost, EPSILON); } + @Test + public void treapBenchmarkTest() + { + final int n = 20000; + + List ids = new ArrayList<>(); + List vals = new ArrayList<>(); + for (int i = 0; i < n; i++) { + ids.add(UUID.randomUUID().toString()); + vals.add((double) i); + } + + System.out.println("Treap:"); + long start = System.currentTimeMillis(); + SegmentsCostCache.TestTreap treap = new SegmentsCostCache.TestTreap(); + for (int i = 0; i < n; i++) { + 
SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + treap.update(val, 1.0, false); + treap.update(val, 3.0, true); + treap.insert(val); + } + System.out.println(treap.query()); + long end = System.currentTimeMillis(); + for (int i = n - 1; i >= 0; i -= 3) { + SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + treap.remove(val); + treap.update(val, -1.0, false); + treap.update(val, -3.0, true); + } + System.out.println(treap.query()); + System.out.println(end - start + " ms"); + + System.out.println("TreeSet:"); + start = System.currentTimeMillis(); + NavigableSet set = new TreeSet<>(); + for (int i = 0; i < n; i++) { + SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + for (SegmentsCostCache.TestVal l : set.headSet(val)) { + l.setB(l.getB() + 1.0); + } + for (SegmentsCostCache.TestVal u : set.tailSet(val)) { + u.setB(u.getB() + 3.0); + } + set.add(val); + } + double ans = 0; + for (SegmentsCostCache.TestVal val : set) { + ans += val.getB(); + } + System.out.println(ans); + for (int i = n - 1; i >= 0; i -= 3) { + SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + set.remove(val); + for (SegmentsCostCache.TestVal l : set.headSet(val)) { + l.setB(l.getB() - 1.0); + } + for (SegmentsCostCache.TestVal u : set.tailSet(val)) { + u.setB(u.getB() - 3.0); + } + } + ans = 0; + for (SegmentsCostCache.TestVal val : set) { + ans += val.getB(); + } + System.out.println(ans); + end = System.currentTimeMillis(); + System.out.println(end - start + " ms"); + } + @Test public void randomSegmentsCostTest() { From 3ce21a516b50f1aa05a58b7e708cee4684930738 Mon Sep 17 00:00:00 2001 From: Amatya Date: Mon, 11 Apr 2022 21:36:23 +0530 Subject: [PATCH 02/13] Make Treap generic --- .../coordinator/cost/SegmentsCostCache.java | 122 +++++++++++++++--- .../cost/SegmentsCostCacheTest.java | 22 ++-- 2 files changed, 113 insertions(+), 31 
deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index cab69ab2e966..1d9be04ede6a 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -35,6 +35,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; +import java.util.List; import java.util.ListIterator; import java.util.NavigableMap; import java.util.NavigableSet; @@ -444,7 +445,7 @@ public int hashCode() } - abstract static class Treap> + abstract static class Treap, Y> { protected TreapNode root; protected final TreapNode NULL; @@ -515,12 +516,12 @@ public TreapNode getMax() return node; } - public double query() + public Y query() { return root.sum; } - public void update(X val, double lazy, boolean dir) + public void update(X val, Y lazy, boolean dir) { if (dir) { root = update(root, val, null, lazy); @@ -529,9 +530,15 @@ public void update(X val, double lazy, boolean dir) } } - protected abstract double getVal(X val); + protected abstract Y getVal(X val); - protected abstract X setVal(X val, double add); + protected abstract X setVal(X val, Y lazy); + + protected abstract Y add(Y a, Y b); + + protected abstract Y multiply(int a, Y b); + + protected abstract Y zero(); private boolean contains(X val, TreapNode node) { @@ -671,7 +678,7 @@ private TreapNode merge(TreapNode left, TreapNode right) return node; } - private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, double lazy) + private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, Y lazy) { TreapNode left = NULL; TreapNode right = NULL; @@ -685,7 +692,7 @@ private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, dou node = pair.lhs; right = pair.rhs; } - node.lazy += lazy; + 
node.lazy = add(node.lazy, lazy); node = merge(left, node); node = merge(node, right); return node; @@ -694,11 +701,11 @@ private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, dou class TreapNode { X val; + Y sum; + Y lazy; TreapNode left; TreapNode right; double priority; - double sum; - double lazy; int size; TreapNode(@Nullable X val) @@ -716,6 +723,8 @@ class TreapNode this.left = left; this.right = right; this.priority = Math.random(); + this.sum = zero(); + this.lazy = zero(); } void recompute() @@ -727,7 +736,7 @@ void recompute() sum = getVal(val); left.lazyPropogate(); right.lazyPropogate(); - sum += left.sum + right.sum; + sum = add(sum, add(left.sum, right.sum)); } void lazyPropogate() @@ -736,14 +745,14 @@ void lazyPropogate() return; } val = setVal(val, lazy); - sum += size * lazy; + sum = add(sum, multiply(size, lazy)); if (!NULL.equals(left)) { - left.lazy += lazy; + left.lazy = add(left.lazy, lazy); } if (!NULL.equals(right)) { - right.lazy += lazy; + right.lazy = add(right.lazy, lazy); } - lazy = 0; + lazy = zero(); } @Override @@ -754,12 +763,12 @@ public boolean equals(Object that) } } - public static class TestVal implements Comparable + public static class TestX implements Comparable { final String a; double b; - public TestVal(String a, double b) + public TestX(String a, double b) { this.a = a; this.b = b; @@ -781,27 +790,100 @@ public void setB(double b) } @Override - public int compareTo(TestVal that) + public int compareTo(TestX that) { return a.compareTo(that.getA()); } } - public static class TestTreap extends Treap + public static class SegmentTreap extends Treap> { + + final Pair ZERO = Pair.of(0.0, 0.0); + @Override - protected double getVal(TestVal val) + protected Pair getVal(SegmentAndSum val) + { + return Pair.of(val.leftSum, val.rightSum); + } + + @Override + protected SegmentAndSum setVal(SegmentAndSum val, Pair lazy) + { + val.leftSum += lazy.lhs; + val.rightSum += lazy.rhs; + return val; + } + + @Override 
+ protected Pair zero() + { + return ZERO; + } + + @Override + protected Pair add(Pair a, Pair b) + { + return Pair.of(a.lhs + b.lhs, a.rhs + b.rhs); + } + + @Override + protected Pair multiply(int a, Pair b) + { + return Pair.of(a * b.lhs, a * b.rhs); + } + + public List toList() + { + List list = new ArrayList<>(); + accumulate(list, root); + return list; + } + + private void accumulate(List list, TreapNode node) + { + if (NULL.equals(node)) { + return; + } + accumulate(list, node.left); + list.add(node.val); + accumulate(list, node.right); + } + } + + public static class TestTreap extends Treap + { + @Override + protected Double getVal(TestX val) { return val.getB(); } @Override - protected TestVal setVal(TestVal val, double lazy) + protected TestX setVal(TestX val, Double lazy) { val.setB(val.getB() + lazy); return val; } + @Override + protected Double add(Double a, Double b) + { + return Double.sum(a, b); + } + + @Override + protected Double multiply(int a, Double b) + { + return a * b; + } + + @Override + protected Double zero() + { + return 0.0; + } + public void print() { print(this.root); diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java index 56b8deecaa96..446ac5cf3359 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java @@ -163,7 +163,7 @@ public void treapBenchmarkTest() long start = System.currentTimeMillis(); SegmentsCostCache.TestTreap treap = new SegmentsCostCache.TestTreap(); for (int i = 0; i < n; i++) { - SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); treap.update(val, 1.0, false); treap.update(val, 3.0, true); treap.insert(val); @@ -171,7 
+171,7 @@ public void treapBenchmarkTest() System.out.println(treap.query()); long end = System.currentTimeMillis(); for (int i = n - 1; i >= 0; i -= 3) { - SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); treap.remove(val); treap.update(val, -1.0, false); treap.update(val, -3.0, true); @@ -181,34 +181,34 @@ public void treapBenchmarkTest() System.out.println("TreeSet:"); start = System.currentTimeMillis(); - NavigableSet set = new TreeSet<>(); + NavigableSet set = new TreeSet<>(); for (int i = 0; i < n; i++) { - SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); - for (SegmentsCostCache.TestVal l : set.headSet(val)) { + SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); + for (SegmentsCostCache.TestX l : set.headSet(val)) { l.setB(l.getB() + 1.0); } - for (SegmentsCostCache.TestVal u : set.tailSet(val)) { + for (SegmentsCostCache.TestX u : set.tailSet(val)) { u.setB(u.getB() + 3.0); } set.add(val); } double ans = 0; - for (SegmentsCostCache.TestVal val : set) { + for (SegmentsCostCache.TestX val : set) { ans += val.getB(); } System.out.println(ans); for (int i = n - 1; i >= 0; i -= 3) { - SegmentsCostCache.TestVal val = new SegmentsCostCache.TestVal(ids.get(i), vals.get(i)); + SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); set.remove(val); - for (SegmentsCostCache.TestVal l : set.headSet(val)) { + for (SegmentsCostCache.TestX l : set.headSet(val)) { l.setB(l.getB() - 1.0); } - for (SegmentsCostCache.TestVal u : set.tailSet(val)) { + for (SegmentsCostCache.TestX u : set.tailSet(val)) { u.setB(u.getB() - 3.0); } } ans = 0; - for (SegmentsCostCache.TestVal val : set) { + for (SegmentsCostCache.TestX val : set) { ans += val.getB(); } System.out.println(ans); From 10f9aecef5d8d4d5949e9f903686f3575fe6fca9 Mon Sep 17 00:00:00 2001 From: Amatya 
Date: Mon, 11 Apr 2022 23:32:31 +0530 Subject: [PATCH 03/13] Segment cost strategy using treap --- .../org/apache/druid/timeline/SegmentId.java | 1 - .../coordinator/cost/SegmentsCostCache.java | 42 ++++++++++--------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/core/src/main/java/org/apache/druid/timeline/SegmentId.java b/core/src/main/java/org/apache/druid/timeline/SegmentId.java index e1d3a97c24a0..587669f3c481 100644 --- a/core/src/main/java/org/apache/druid/timeline/SegmentId.java +++ b/core/src/main/java/org/apache/druid/timeline/SegmentId.java @@ -323,7 +323,6 @@ public DateTime getIntervalEnd() public Interval getInterval() { return interval; - //return new Interval(intervalStartMillis, intervalEndMillis, intervalChronology); } public String getVersion() diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 1d9be04ede6a..1df4f3a63f98 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -38,9 +38,7 @@ import java.util.List; import java.util.ListIterator; import java.util.NavigableMap; -import java.util.NavigableSet; import java.util.TreeMap; -import java.util.TreeSet; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -321,7 +319,7 @@ public static Builder builder(Interval interval) static class Builder { protected final Interval interval; - private final NavigableSet segments = new TreeSet<>(); + private final SegmentTreap treap = new SegmentTreap(); public Builder(Interval interval) { @@ -344,22 +342,22 @@ public Builder addSegment(DataSegment dataSegment) SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, leftValue, rightValue); // left/right value should be added to left/right sums for elements greater/lower than current segment 
- segments.tailSet(segmentAndSum).forEach(v -> v.leftSum += leftValue); - segments.headSet(segmentAndSum).forEach(v -> v.rightSum += rightValue); + treap.update(segmentAndSum, Pair.of(leftValue, 0.0), true); + treap.update(segmentAndSum, Pair.of(0.0, rightValue), false); // leftSum_i = leftValue_i + \sum leftValue_j = leftValue_i + leftSum_{i-1} , j < i - SegmentAndSum lower = segments.lower(segmentAndSum); + SegmentAndSum lower = treap.lower(segmentAndSum).val; if (lower != null) { segmentAndSum.leftSum = leftValue + lower.leftSum; } // rightSum_i = rightValue_i + \sum rightValue_j = rightValue_i + rightSum_{i+1} , j > i - SegmentAndSum higher = segments.higher(segmentAndSum); + SegmentAndSum higher = treap.upper(segmentAndSum).val; if (higher != null) { segmentAndSum.rightSum = rightValue + higher.rightSum; } - if (!segments.add(segmentAndSum)) { + if (!treap.insert(segmentAndSum)) { throw new ISE("expect new segment"); } return this; @@ -369,7 +367,7 @@ public Builder removeSegment(DataSegment dataSegment) { SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, 0.0, 0.0); - if (!segments.remove(segmentAndSum)) { + if (!treap.remove(segmentAndSum)) { return this; } @@ -379,23 +377,24 @@ public Builder removeSegment(DataSegment dataSegment) double leftValue = FastMath.exp(t0) - FastMath.exp(t1); double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); - segments.tailSet(segmentAndSum).forEach(v -> v.leftSum -= leftValue); - segments.headSet(segmentAndSum).forEach(v -> v.rightSum -= rightValue); + treap.update(segmentAndSum, Pair.of(-leftValue, 0.0), true); + treap.update(segmentAndSum, Pair.of(0.0, -rightValue), false); + return this; } public boolean isEmpty() { - return segments.isEmpty(); + return treap.isEmpty(); } public Bucket build() { - ArrayList segmentsList = new ArrayList<>(segments.size()); - double[] leftSum = new double[segments.size()]; - double[] rightSum = new double[segments.size()]; + ArrayList segmentsList = new ArrayList<>(); + 
double[] leftSum = new double[treap.root.size]; + double[] rightSum = new double[treap.root.size]; int i = 0; - for (SegmentAndSum segmentAndSum : segments) { + for (SegmentAndSum segmentAndSum : treap.toList()) { segmentsList.add(segmentAndSum.dataSegment); leftSum[i] = segmentAndSum.leftSum; rightSum[i] = segmentAndSum.rightSum; @@ -488,14 +487,18 @@ public TreapNode ceil(X val) return ceil(val, root); } - public void insert(X val) + public boolean insert(X val) { + int oldSize = root.size; root = insert(new TreapNode(val), root); + return root.size > oldSize; } - public void remove(X val) + public boolean remove(X val) { + int oldSize = root.size; root = remove(val, root); + return root.size < oldSize; } public TreapNode getMin() @@ -799,7 +802,7 @@ public int compareTo(TestX that) public static class SegmentTreap extends Treap> { - final Pair ZERO = Pair.of(0.0, 0.0); + static final Pair ZERO = Pair.of(0.0, 0.0); @Override protected Pair getVal(SegmentAndSum val) @@ -845,6 +848,7 @@ private void accumulate(List list, TreapNode node) if (NULL.equals(node)) { return; } + node.lazyPropogate(); accumulate(list, node.left); list.add(node.val); accumulate(list, node.right); From 93c9d0bc85dee4fcd523c686750d92f32cb40840 Mon Sep 17 00:00:00 2001 From: Amatya Date: Thu, 14 Apr 2022 05:04:59 +0530 Subject: [PATCH 04/13] Store Intervals instead of DataSegments in Buckets --- .../coordinator/cost/SegmentsCostCache.java | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 1df4f3a63f98..148e73c3020d 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -113,12 +113,15 @@ public class SegmentsCostCache private static final long 
BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); + private static final Comparator INTERVAL_COMPARATOR = Comparators.intervalsByStartThenEnd(); + private static final Comparator SEGMENT_INTERVAL_COMPARATOR = Comparator.comparing(DataSegment::getInterval, Comparators.intervalsByStartThenEnd()); private static final Comparator BUCKET_INTERVAL_COMPARATOR = Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); + private static final Ordering INTERVAL_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); private static final Ordering SEGMENT_ORDERING = Ordering.from(SEGMENT_INTERVAL_COMPARATOR); private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); @@ -214,18 +217,18 @@ static class Bucket { private final Interval interval; private final Interval calculationInterval; - private final ArrayList sortedSegments; + private final ArrayList sortedIntervals; private final double[] leftSum; private final double[] rightSum; - Bucket(Interval interval, ArrayList sortedSegments, double[] leftSum, double[] rightSum) + Bucket(Interval interval, ArrayList sortedIntervals, double[] leftSum, double[] rightSum) { this.interval = Preconditions.checkNotNull(interval, "interval"); - this.sortedSegments = Preconditions.checkNotNull(sortedSegments, "sortedSegments"); + this.sortedIntervals = Preconditions.checkNotNull(sortedIntervals, "sortedSegments"); this.leftSum = Preconditions.checkNotNull(leftSum, "leftSum"); this.rightSum = Preconditions.checkNotNull(rightSum, "rightSum"); - Preconditions.checkArgument(sortedSegments.size() == leftSum.length && sortedSegments.size() == rightSum.length); - Preconditions.checkArgument(SEGMENT_ORDERING.isOrdered(sortedSegments)); + Preconditions.checkArgument(sortedIntervals.size() == leftSum.length && sortedIntervals.size() == rightSum.length); + 
Preconditions.checkArgument(INTERVAL_ORDERING.isOrdered(sortedIntervals)); this.calculationInterval = new Interval( interval.getStart().minus(LIFE_THRESHOLD), interval.getEnd().plus(LIFE_THRESHOLD) @@ -245,15 +248,15 @@ boolean inCalculationInterval(DataSegment dataSegment) double cost(DataSegment dataSegment) { // cost is calculated relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment, interval); - double t1 = convertEnd(dataSegment, interval); + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); // avoid calculation for segments outside of LIFE_THRESHOLD if (!inCalculationInterval(dataSegment)) { throw new ISE("Segment is not within calculation interval"); } - int index = Collections.binarySearch(sortedSegments, dataSegment, SEGMENT_INTERVAL_COMPARATOR); + int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); index = (index >= 0) ? 
index : -index - 1; return addLeftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); } @@ -264,9 +267,9 @@ private double addLeftCost(DataSegment dataSegment, double t0, double t1, int in // add to cost all left-overlapping segments int leftIndex = index - 1; while (leftIndex >= 0 - && sortedSegments.get(leftIndex).getInterval().overlaps(dataSegment.getInterval())) { - double start = convertStart(sortedSegments.get(leftIndex), interval); - double end = convertEnd(sortedSegments.get(leftIndex), interval); + && sortedIntervals.get(leftIndex).overlaps(dataSegment.getInterval())) { + double start = convertStart(sortedIntervals.get(leftIndex), interval); + double end = convertEnd(sortedIntervals.get(leftIndex), interval); leftCost += CostBalancerStrategy.intervalCost(end - start, t0 - start, t1 - start); --leftIndex; } @@ -282,28 +285,28 @@ private double rightCost(DataSegment dataSegment, double t0, double t1, int inde double rightCost = 0.0; // add all right-overlapping segments int rightIndex = index; - while (rightIndex < sortedSegments.size() && - sortedSegments.get(rightIndex).getInterval().overlaps(dataSegment.getInterval())) { - double start = convertStart(sortedSegments.get(rightIndex), interval); - double end = convertEnd(sortedSegments.get(rightIndex), interval); + while (rightIndex < sortedIntervals.size() && + sortedIntervals.get(rightIndex).overlaps(dataSegment.getInterval())) { + double start = convertStart(sortedIntervals.get(rightIndex), interval); + double end = convertEnd(sortedIntervals.get(rightIndex), interval); rightCost += CostBalancerStrategy.intervalCost(t1 - t0, start - t0, end - t0); ++rightIndex; } // add right-non-overlapping segments - if (rightIndex < sortedSegments.size()) { + if (rightIndex < sortedIntervals.size()) { rightCost += rightSum[rightIndex] * (FastMath.exp(t0) - FastMath.exp(t1)); } return rightCost; } - private static double convertStart(DataSegment dataSegment, Interval interval) + private static 
double convertStart(Interval interval, Interval reference) { - return toLocalInterval(dataSegment.getInterval().getStartMillis(), interval); + return toLocalInterval(interval.getStartMillis(), reference); } - private static double convertEnd(DataSegment dataSegment, Interval interval) + private static double convertEnd(Interval interval, Interval reference) { - return toLocalInterval(dataSegment.getInterval().getEndMillis(), interval); + return toLocalInterval(interval.getEndMillis(), reference); } private static double toLocalInterval(long millis, Interval interval) @@ -333,8 +336,8 @@ public Builder addSegment(DataSegment dataSegment) } // all values are pre-computed relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment, interval); - double t1 = convertEnd(dataSegment, interval); + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); double leftValue = FastMath.exp(t0) - FastMath.exp(t1); double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); @@ -371,8 +374,8 @@ public Builder removeSegment(DataSegment dataSegment) return this; } - double t0 = convertStart(dataSegment, interval); - double t1 = convertEnd(dataSegment, interval); + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); double leftValue = FastMath.exp(t0) - FastMath.exp(t1); double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); @@ -390,22 +393,22 @@ public boolean isEmpty() public Bucket build() { - ArrayList segmentsList = new ArrayList<>(); + ArrayList intervalsList = new ArrayList<>(); double[] leftSum = new double[treap.root.size]; double[] rightSum = new double[treap.root.size]; int i = 0; for (SegmentAndSum segmentAndSum : treap.toList()) { - segmentsList.add(segmentAndSum.dataSegment); + intervalsList.add(segmentAndSum.dataSegment.getInterval()); leftSum[i] = segmentAndSum.leftSum; rightSum[i] = 
segmentAndSum.rightSum; ++i; } - long bucketEndMillis = segmentsList + long bucketEndMillis = intervalsList .stream() - .mapToLong(s -> s.getInterval().getEndMillis()) + .mapToLong(interval -> interval.getEndMillis()) .max() .orElseGet(interval::getEndMillis); - return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), segmentsList, leftSum, rightSum); + return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), intervalsList, leftSum, rightSum); } } } From 4575bbacfe0fdda0de0ca3b2e68179a8cf5c4abb Mon Sep 17 00:00:00 2001 From: Amatya Date: Thu, 14 Apr 2022 14:56:22 +0530 Subject: [PATCH 05/13] Optimize computation of segment bucket cost --- .../coordinator/cost/SegmentsCostCache.java | 130 +++++++++++++++++- 1 file changed, 125 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 148e73c3020d..e742eeed3e09 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -122,7 +122,6 @@ public class SegmentsCostCache Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); private static final Ordering INTERVAL_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); - private static final Ordering SEGMENT_ORDERING = Ordering.from(SEGMENT_INTERVAL_COMPARATOR); private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); private final ArrayList sortedBuckets; @@ -221,6 +220,13 @@ static class Bucket private final double[] leftSum; private final double[] rightSum; + private final double[] cumStart; + private final double[] cumStartExp; + private final double[] cumStartExpInv; + private final double[] cumEnd; + private final double[] cumEndExp; + private final double[] 
cumEndExpInv; + Bucket(Interval interval, ArrayList sortedIntervals, double[] leftSum, double[] rightSum) { this.interval = Preconditions.checkNotNull(interval, "interval"); @@ -233,6 +239,26 @@ static class Bucket interval.getStart().minus(LIFE_THRESHOLD), interval.getEnd().plus(LIFE_THRESHOLD) ); + + int n = leftSum.length; + + cumStart = new double[n + 1]; + cumStartExp = new double[n + 1]; + cumStartExpInv = new double[n + 1]; + cumEnd = new double[n + 1]; + cumEndExp = new double[n + 1]; + cumEndExpInv = new double[n + 1]; + for (int i = 0; i < n; i++) { + double start = convertStart(sortedIntervals.get(i), interval); + cumStart[i + 1] = cumStart[i] + start; + cumStartExp[i + 1] = cumStartExp[i] + FastMath.exp(start); + cumStartExpInv[i + 1] = cumStartExpInv[i] + FastMath.exp(-start); + + double end = convertEnd(sortedIntervals.get(i), interval); + cumEnd[i + 1] = cumEnd[i] + end; + cumEndExp[i + 1] = cumEndExp[i] + FastMath.exp(end); + cumEndExpInv[i + 1] = cumEndExpInv[i] + FastMath.exp(-end); + } } Interval getInterval() @@ -245,7 +271,7 @@ boolean inCalculationInterval(DataSegment dataSegment) return calculationInterval.overlaps(dataSegment.getInterval()); } - double cost(DataSegment dataSegment) + double costOld(DataSegment dataSegment) { // cost is calculated relatively to bucket start (which is considered as 0) double t0 = convertStart(dataSegment.getInterval(), interval); @@ -258,10 +284,10 @@ boolean inCalculationInterval(DataSegment dataSegment) int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); index = (index >= 0) ? 
index : -index - 1; - return addLeftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); + return leftCostOld(dataSegment, t0, t1, index) + rightCostOld(dataSegment, t0, t1, index); } - private double addLeftCost(DataSegment dataSegment, double t0, double t1, int index) + private double leftCostOld(DataSegment dataSegment, double t0, double t1, int index) { double leftCost = 0.0; // add to cost all left-overlapping segments @@ -280,7 +306,7 @@ private double addLeftCost(DataSegment dataSegment, double t0, double t1, int in return leftCost; } - private double rightCost(DataSegment dataSegment, double t0, double t1, int index) + private double rightCostOld(DataSegment dataSegment, double t0, double t1, int index) { double rightCost = 0.0; // add all right-overlapping segments @@ -299,6 +325,100 @@ private double rightCost(DataSegment dataSegment, double t0, double t1, int inde return rightCost; } + double cost(DataSegment dataSegment) + { + // cost is calculated relatively to bucket start (which is considered as 0) + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); + + // avoid calculation for segments outside of LIFE_THRESHOLD + if (!inCalculationInterval(dataSegment)) { + throw new ISE("Segment is not within calculation interval"); + } + + int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); + index = (index >= 0) ? 
index : -index - 1; + return leftCostOld(dataSegment, t0, t1, index) + rightCostOld(dataSegment, t0, t1, index); + } + + private double leftCost(DataSegment dataSegment, double t0, double t1, int index) + { + if (index - 1 < 0) { + return 0; + } + double exp0 = FastMath.exp(t0); + double expInv0 = 1 / exp0; + double exp1 = FastMath.exp(t1); + double expInv1 = 1 / exp1; + double leftCost = 0.0; + // add to cost all left-overlapping segments + int rightBound = index - 1; + int leftBound = leftBoundary(0, index - 1, dataSegment.getInterval()); + leftCost += 2 * (cumEnd[rightBound + 1] - cumEnd[leftBound]); + leftCost -= 2 * (rightBound - leftBound + 1) * t0; + leftCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + leftCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + leftCost -= expInv0 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + leftCost -= expInv1 * (cumEndExp[rightBound + 1] - cumEndExp[leftBound]); + // add left-non-overlapping segments + if (leftBound > 0) { + leftCost += leftSum[leftBound - 1] * (expInv1 - expInv0); + } + return leftCost; + } + + private double rightCost(DataSegment dataSegment, double t0, double t1, int index) + { + int n = leftSum.length; + if (index >= n) { + return 0; + } + double exp0 = FastMath.exp(t0); + double exp1 = FastMath.exp(t1); + double expInv1 = 1 / exp1; + double rightCost = 0.0; + int leftBound = index; + int rightBound = rightBoundary(index, n - 1, dataSegment.getInterval()); + // add all right-overlapping segments + rightCost += 2 * (rightBound - leftBound + 1) * t1; + rightCost -= 2 * (cumStart[rightBound + 1] - cumStart[leftBound]); + rightCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + rightCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + rightCost -= exp0 * (cumStartExpInv[rightBound + 1] - cumStartExpInv[leftBound]); + rightCost -= exp1 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + // 
add right-non-overlapping segments + if (rightBound + 1 < n) { + rightCost += rightSum[rightBound + 1] * (exp0 - exp1); + } + return rightCost; + } + + private int leftBoundary(int l, int r, Interval interval) + { + if (l == r) { + return interval.overlaps(sortedIntervals.get(l)) ? l : r + 1; + } + int m = (l + r) / 2; + if (interval.overlaps(sortedIntervals.get(m))) { + return leftBoundary(l, m, interval); + } else { + return leftBoundary(m + 1, r, interval); + } + } + + private int rightBoundary(int l, int r, Interval interval) + { + if (l == r) { + return interval.overlaps(sortedIntervals.get(r)) ? r : l - 1; + } + int m = (l + r + 1) / 2; + if (interval.overlaps(sortedIntervals.get(m))) { + return rightBoundary(m, r, interval); + } else { + return rightBoundary(l, m - 1, interval); + } + } + private static double convertStart(Interval interval, Interval reference) { return toLocalInterval(interval.getStartMillis(), reference); From 18baa6388f62ac61f1dcfb5b72fd8f445e342c4f Mon Sep 17 00:00:00 2001 From: Amatya Date: Thu, 14 Apr 2022 15:44:36 +0530 Subject: [PATCH 06/13] Use new cost code and cleanup --- .../coordinator/cost/SegmentsCostCache.java | 56 +------------------ 1 file changed, 1 insertion(+), 55 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index e742eeed3e09..c64c0318e6e1 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -271,60 +271,6 @@ boolean inCalculationInterval(DataSegment dataSegment) return calculationInterval.overlaps(dataSegment.getInterval()); } - double costOld(DataSegment dataSegment) - { - // cost is calculated relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = 
convertEnd(dataSegment.getInterval(), interval); - - // avoid calculation for segments outside of LIFE_THRESHOLD - if (!inCalculationInterval(dataSegment)) { - throw new ISE("Segment is not within calculation interval"); - } - - int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); - index = (index >= 0) ? index : -index - 1; - return leftCostOld(dataSegment, t0, t1, index) + rightCostOld(dataSegment, t0, t1, index); - } - - private double leftCostOld(DataSegment dataSegment, double t0, double t1, int index) - { - double leftCost = 0.0; - // add to cost all left-overlapping segments - int leftIndex = index - 1; - while (leftIndex >= 0 - && sortedIntervals.get(leftIndex).overlaps(dataSegment.getInterval())) { - double start = convertStart(sortedIntervals.get(leftIndex), interval); - double end = convertEnd(sortedIntervals.get(leftIndex), interval); - leftCost += CostBalancerStrategy.intervalCost(end - start, t0 - start, t1 - start); - --leftIndex; - } - // add left-non-overlapping segments - if (leftIndex >= 0) { - leftCost += leftSum[leftIndex] * (FastMath.exp(-t1) - FastMath.exp(-t0)); - } - return leftCost; - } - - private double rightCostOld(DataSegment dataSegment, double t0, double t1, int index) - { - double rightCost = 0.0; - // add all right-overlapping segments - int rightIndex = index; - while (rightIndex < sortedIntervals.size() && - sortedIntervals.get(rightIndex).overlaps(dataSegment.getInterval())) { - double start = convertStart(sortedIntervals.get(rightIndex), interval); - double end = convertEnd(sortedIntervals.get(rightIndex), interval); - rightCost += CostBalancerStrategy.intervalCost(t1 - t0, start - t0, end - t0); - ++rightIndex; - } - // add right-non-overlapping segments - if (rightIndex < sortedIntervals.size()) { - rightCost += rightSum[rightIndex] * (FastMath.exp(t0) - FastMath.exp(t1)); - } - return rightCost; - } - double cost(DataSegment dataSegment) { // cost is calculated relatively to 
bucket start (which is considered as 0) @@ -338,7 +284,7 @@ private double rightCostOld(DataSegment dataSegment, double t0, double t1, int i int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); index = (index >= 0) ? index : -index - 1; - return leftCostOld(dataSegment, t0, t1, index) + rightCostOld(dataSegment, t0, t1, index); + return leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); } private double leftCost(DataSegment dataSegment, double t0, double t1, int index) From 65f085acce05765f432f925e0c525b84f2660531 Mon Sep 17 00:00:00 2001 From: Amatya Date: Sat, 16 Apr 2022 02:05:18 +0530 Subject: [PATCH 07/13] Refactoring, cleaning, improved V3 --- .idea/misc.xml | 14 +- .../apache/druid/java/util/common/Treap.java | 345 +++++++++ .../org/apache/druid/timeline/SegmentId.java | 4 +- .../druid/timeline/DataSegmentTest.java | 30 - .../coordinator/cost/SegmentsCostCache.java | 669 ++---------------- .../coordinator/cost/SegmentsCostCacheV2.java | 551 +++++++++++++++ .../coordinator/cost/SegmentsCostCacheV3.java | 455 ++++++++++++ .../cost/SegmentsCostCacheTest.java | 77 +- .../cost/SegmentsCostCacheV2Test.java | 198 ++++++ .../cost/SegmentsCostCacheV3Test.java | 198 ++++++ 10 files changed, 1832 insertions(+), 709 deletions(-) create mode 100644 core/src/main/java/org/apache/druid/java/util/common/Treap.java create mode 100644 server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java create mode 100644 server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java create mode 100644 server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java create mode 100644 server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java diff --git a/.idea/misc.xml b/.idea/misc.xml index bf2061d7392d..c6a0d654dbd4 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -46,7 +46,7 @@ @@ -84,7 +87,10 @@ 
- + - + + + \ No newline at end of file diff --git a/core/src/main/java/org/apache/druid/java/util/common/Treap.java b/core/src/main/java/org/apache/druid/java/util/common/Treap.java new file mode 100644 index 000000000000..b3e979fdc2cc --- /dev/null +++ b/core/src/main/java/org/apache/druid/java/util/common/Treap.java @@ -0,0 +1,345 @@ +package org.apache.druid.java.util.common; + +import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.List; + +public abstract class Treap, Y> +{ + protected TreapNode root; + protected final TreapNode NULL; + + public Treap() + { + NULL = new TreapNode(null); + NULL.left = NULL.right = NULL; + NULL.priority = Double.POSITIVE_INFINITY; + root = NULL; + } + + public boolean isEmpty() + { + return NULL.equals(root); + } + + public int size() + { + return root.size; + } + + public boolean contains(X val) + { + return contains(val, root); + } + + public X lower(X val) + { + return lower(val, root).val; + } + + public X upper(X val) + { + return upper(val, root).val; + } + + public X floor(X val) + { + return floor(val, root).val; + } + + public X ceil(X val) + { + return ceil(val, root).val; + } + + public boolean insert(X val) + { + int oldSize = root.size; + root = insert(new TreapNode(val), root); + return root.size > oldSize; + } + + public boolean remove(X val) + { + int oldSize = root.size; + root = remove(val, root); + return root.size < oldSize; + } + + public X getMin() + { + TreapNode node = root; + while (!NULL.equals(node.left)) { + node = node.left; + } + return node.val; + } + + public X getMax() + { + TreapNode node = root; + while (!NULL.equals(node.right)) { + node = node.right; + } + return node.val; + } + + public void update(X val, Y lazy, boolean dir) + { + if (dir) { + root = update(root, val, null, lazy); + } else { + root = update(root, null, val, lazy); + } + } + + public List toList() + { + List list = new ArrayList<>(); + accumulate(root, list); + return list; + } + + protected 
abstract Y getVal(X val); + + protected abstract X setVal(X val, Y lazy); + + protected abstract Y add(Y a, Y b); + + protected abstract Y multiply(int a, Y b); + + protected abstract Y zero(); + + private boolean contains(X val, TreapNode node) + { + if (NULL.equals(node)) { + return false; + } + final int cmp = val.compareTo(node.val); + if (cmp < 0) { + return contains(val, node.left); + } + if (cmp > 0) { + return contains(val, node.right); + } + return true; + } + + private TreapNode lower(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp <= 0) { + return lower(val, node.left); + } else { + TreapNode ret = lower(val, node.right); + return (NULL.equals(ret)) ? node : ret; + } + } + + private TreapNode upper(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp >= 0) { + return upper(val, node.right); + } else { + TreapNode ret = upper(val, node.left); + return (NULL.equals(ret)) ? node : ret; + } + } + + private TreapNode floor(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp < 0) { + return floor(val, node.left); + } else { + TreapNode ret = floor(val, node.right); + return (NULL.equals(ret)) ? node : ret; + } + } + + private TreapNode ceil(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + final int cmp = val.compareTo(node.val); + if (cmp > 0) { + return ceil(val, node.right); + } else { + TreapNode ret = ceil(val, node.left); + return (NULL.equals(ret)) ? 
node : ret; + } + } + + private TreapNode insert(TreapNode val, TreapNode node) + { + if (NULL.equals(node)) { + return val; + } + Pair pair = split(node, val.val); + node = merge(pair.lhs, val); + node = merge(node, pair.rhs); + return node; + } + + private TreapNode remove(X val, TreapNode node) + { + if (NULL.equals(node)) { + return node; + } + Pair pair = split(node, val); + TreapNode lower = lower(val, pair.lhs); + if (NULL.equals(lower)) { + return pair.rhs; + } + return merge(split(pair.lhs, lower.val).lhs, pair.rhs); + } + + private Pair split(TreapNode node, X val) + { + if (NULL.equals(node)) { + return Pair.of(NULL, NULL); + } + node.lazyPropogate(); + final int cmp = val.compareTo(node.val); + Pair pair; + if (cmp < 0) { + pair = split(node.left, val); + node.left = pair.rhs; + pair = Pair.of(pair.lhs, node); + } else { + pair = split(node.right, val); + node.right = pair.lhs; + pair = Pair.of(node, pair.rhs); + } + node.recompute(); + return pair; + } + + private TreapNode merge(TreapNode left, TreapNode right) + { + if (NULL.equals(left)) { + return right; + } + if (NULL.equals(right)) { + return left; + } + left.lazyPropogate(); + right.lazyPropogate(); + TreapNode node; + if (left.priority < right.priority) { + left.right = merge(left.right, right); + node = left; + } else { + right.left = merge(left, right.left); + node = right; + } + node.recompute(); + return node; + } + + private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, Y lazy) + { + TreapNode left = NULL; + TreapNode right = NULL; + if (begin != null) { + Pair pair = split(node, begin); + left = pair.lhs; + node = pair.rhs; + } + if (end != null) { + Pair pair = split(node, end); + node = pair.lhs; + right = pair.rhs; + } + node.lazy = add(node.lazy, lazy); + node = merge(left, node); + node = merge(node, right); + return node; + } + + private void accumulate(TreapNode node, List list) + { + if (NULL.equals(node)) { + return; + } + node.lazyPropogate(); + 
accumulate(node.left, list); + list.add(node.val); + accumulate(node.right, list); + } + + class TreapNode + { + X val; + Y sum; + Y lazy; + TreapNode left; + TreapNode right; + double priority; + int size; + + TreapNode(@Nullable X val) + { + this(val, NULL, NULL); + if (val != null) { + sum = getVal(val); + size = 1; + } + } + + TreapNode(@Nullable X val, @Nullable TreapNode left, @Nullable TreapNode right) + { + this.val = val; + this.left = left; + this.right = right; + this.priority = Math.random(); + this.sum = zero(); + this.lazy = zero(); + } + + public void recompute() + { + if (NULL.equals(this)) { + return; + } + size = 1 + left.size + right.size; + sum = getVal(val); + left.lazyPropogate(); + right.lazyPropogate(); + sum = add(sum, add(left.sum, right.sum)); + } + + public void lazyPropogate() + { + if (NULL.equals(this)) { + return; + } + val = setVal(val, lazy); + sum = add(sum, multiply(size, lazy)); + if (!NULL.equals(left)) { + left.lazy = add(left.lazy, lazy); + } + if (!NULL.equals(right)) { + right.lazy = add(right.lazy, lazy); + } + lazy = zero(); + } + + @Override + public boolean equals(Object that) + { + return this == that; + } + } +} diff --git a/core/src/main/java/org/apache/druid/timeline/SegmentId.java b/core/src/main/java/org/apache/druid/timeline/SegmentId.java index 587669f3c481..8430524021c0 100644 --- a/core/src/main/java/org/apache/druid/timeline/SegmentId.java +++ b/core/src/main/java/org/apache/druid/timeline/SegmentId.java @@ -266,7 +266,6 @@ public static SegmentId dummy(String dataSource, int partitionNum) private final long intervalEndMillis; @Nullable private final Chronology intervalChronology; - private final Interval interval; private final String version; private final int partitionNum; @@ -282,7 +281,6 @@ private SegmentId(String dataSource, Interval interval, String version, int part this.intervalStartMillis = interval.getStartMillis(); this.intervalEndMillis = interval.getEndMillis(); this.intervalChronology = 
interval.getChronology(); - this.interval = new Interval(intervalStartMillis, intervalEndMillis, intervalChronology); // Versions are timestamp-based Strings, interning of them doesn't make sense. If this is not the case, interning // could be conditionally allowed via a system property. this.version = Objects.requireNonNull(version); @@ -322,7 +320,7 @@ public DateTime getIntervalEnd() public Interval getInterval() { - return interval; + return new Interval(intervalStartMillis, intervalEndMillis, intervalChronology); } public String getVersion() diff --git a/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java b/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java index d16686a62734..87ec7b17869b 100644 --- a/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java +++ b/core/src/test/java/org/apache/druid/timeline/DataSegmentTest.java @@ -411,36 +411,6 @@ public void testTombstoneType() } - @Test - public void getIntervalBenchmarkTest() - { - final DataSegment segment = DataSegment.builder() - .dataSource("foo") - .interval(Intervals.of("2012-01-01/2012-01-02")) - .version(DateTimes.of("2012-01-01T11:22:33.444Z").toString()) - .shardSpec(new TombstoneShardSpec()) - .loadSpec(Collections.singletonMap( - "type", - DataSegment.TOMBSTONE_LOADSPEC_TYPE - )) - .size(0) - .build(); - - long start = System.currentTimeMillis(); - int cnt = 0; - - for (int i = 0; i < 1000000000; i++) { - Interval interval = segment.getInterval(); - cnt++; - if (cnt == 100000000) { - cnt = 0; - System.out.println(interval); - } - } - long end = System.currentTimeMillis(); - System.out.println(end - start); - } - private DataSegment makeDataSegment(String dataSource, String interval, String version) { return DataSegment.builder() diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index c64c0318e6e1..4b5432c8cd11 100644 --- 
a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -24,21 +24,20 @@ import org.apache.commons.math3.util.FastMath; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.granularity.DurationGranularity; import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.server.coordinator.CostBalancerStrategy; import org.apache.druid.timeline.DataSegment; import org.joda.time.Interval; -import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; -import java.util.List; import java.util.ListIterator; import java.util.NavigableMap; +import java.util.NavigableSet; import java.util.TreeMap; +import java.util.TreeSet; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -49,18 +48,18 @@ * Joint cost for two segments (you can make formulas below readable by copy-pasting to * https://www.codecogs.com/latex/eqneditor.php): * - * cost(Y, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy + * cost(X, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy * or - * cost(Y, Y) = e^{y_0 + y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) + * cost(X, Y) = e^{y_0 + y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) * if x_0 <= x_1 <= y_0 <= y_1 * (*) lambda coefficient is omitted for simplicity. 
* * For a group of segments {S_xi}, i = {0, n} total joint cost with segment S_y could be calculated as: * - * cost(Y, Y) = \sum cost(X_i, Y) = e^{y_0 + y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) + * cost(X, Y) = \sum cost(X_i, Y) = e^{y_0 + y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) * if xi_0 <= xi_1 <= y_0 <= y_1 * and - * cost(Y, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) + * cost(X, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) * if y_0 <= y_1 <= xi_0 <= xi_1 * * SegmentsCostCache stores pre-computed sums for a group of segments {S_xi}: @@ -100,7 +99,7 @@ public class SegmentsCostCache /** * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" - * from each other and thus cost(Y,Y) ~ 0 for these segments + * from each other and thus cost(X,Y) ~ 0 for these segments */ private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30); @@ -113,15 +112,13 @@ public class SegmentsCostCache private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); - private static final Comparator INTERVAL_COMPARATOR = Comparators.intervalsByStartThenEnd(); - private static final Comparator SEGMENT_INTERVAL_COMPARATOR = Comparator.comparing(DataSegment::getInterval, Comparators.intervalsByStartThenEnd()); private static final Comparator BUCKET_INTERVAL_COMPARATOR = Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); - private static final Ordering INTERVAL_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); + private static final Ordering SEGMENT_ORDERING = Ordering.from(SEGMENT_INTERVAL_COMPARATOR); private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); private final ArrayList sortedBuckets; @@ -199,9 +196,9 @@ public SegmentsCostCache build() { return new 
SegmentsCostCache( buckets - .values() + .entrySet() .stream() - .map(Bucket.Builder::build) + .map(entry -> entry.getValue().build()) .collect(Collectors.toCollection(ArrayList::new)) ); } @@ -216,49 +213,22 @@ static class Bucket { private final Interval interval; private final Interval calculationInterval; - private final ArrayList sortedIntervals; + private final ArrayList sortedSegments; private final double[] leftSum; private final double[] rightSum; - private final double[] cumStart; - private final double[] cumStartExp; - private final double[] cumStartExpInv; - private final double[] cumEnd; - private final double[] cumEndExp; - private final double[] cumEndExpInv; - - Bucket(Interval interval, ArrayList sortedIntervals, double[] leftSum, double[] rightSum) + Bucket(Interval interval, ArrayList sortedSegments, double[] leftSum, double[] rightSum) { this.interval = Preconditions.checkNotNull(interval, "interval"); - this.sortedIntervals = Preconditions.checkNotNull(sortedIntervals, "sortedSegments"); + this.sortedSegments = Preconditions.checkNotNull(sortedSegments, "sortedSegments"); this.leftSum = Preconditions.checkNotNull(leftSum, "leftSum"); this.rightSum = Preconditions.checkNotNull(rightSum, "rightSum"); - Preconditions.checkArgument(sortedIntervals.size() == leftSum.length && sortedIntervals.size() == rightSum.length); - Preconditions.checkArgument(INTERVAL_ORDERING.isOrdered(sortedIntervals)); + Preconditions.checkArgument(sortedSegments.size() == leftSum.length && sortedSegments.size() == rightSum.length); + Preconditions.checkArgument(SEGMENT_ORDERING.isOrdered(sortedSegments)); this.calculationInterval = new Interval( interval.getStart().minus(LIFE_THRESHOLD), interval.getEnd().plus(LIFE_THRESHOLD) ); - - int n = leftSum.length; - - cumStart = new double[n + 1]; - cumStartExp = new double[n + 1]; - cumStartExpInv = new double[n + 1]; - cumEnd = new double[n + 1]; - cumEndExp = new double[n + 1]; - cumEndExpInv = new double[n + 1]; - for (int i = 
0; i < n; i++) { - double start = convertStart(sortedIntervals.get(i), interval); - cumStart[i + 1] = cumStart[i] + start; - cumStartExp[i + 1] = cumStartExp[i] + FastMath.exp(start); - cumStartExpInv[i + 1] = cumStartExpInv[i] + FastMath.exp(-start); - - double end = convertEnd(sortedIntervals.get(i), interval); - cumEnd[i + 1] = cumEnd[i] + end; - cumEndExp[i + 1] = cumEndExp[i] + FastMath.exp(end); - cumEndExpInv[i + 1] = cumEndExpInv[i] + FastMath.exp(-end); - } } Interval getInterval() @@ -274,105 +244,65 @@ boolean inCalculationInterval(DataSegment dataSegment) double cost(DataSegment dataSegment) { // cost is calculated relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); + double t0 = convertStart(dataSegment, interval); + double t1 = convertEnd(dataSegment, interval); // avoid calculation for segments outside of LIFE_THRESHOLD if (!inCalculationInterval(dataSegment)) { throw new ISE("Segment is not within calculation interval"); } - int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); + int index = Collections.binarySearch(sortedSegments, dataSegment, SEGMENT_INTERVAL_COMPARATOR); index = (index >= 0) ? 
index : -index - 1; - return leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); + return addLeftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); } - private double leftCost(DataSegment dataSegment, double t0, double t1, int index) + private double addLeftCost(DataSegment dataSegment, double t0, double t1, int index) { - if (index - 1 < 0) { - return 0; - } - double exp0 = FastMath.exp(t0); - double expInv0 = 1 / exp0; - double exp1 = FastMath.exp(t1); - double expInv1 = 1 / exp1; double leftCost = 0.0; // add to cost all left-overlapping segments - int rightBound = index - 1; - int leftBound = leftBoundary(0, index - 1, dataSegment.getInterval()); - leftCost += 2 * (cumEnd[rightBound + 1] - cumEnd[leftBound]); - leftCost -= 2 * (rightBound - leftBound + 1) * t0; - leftCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - leftCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - leftCost -= expInv0 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - leftCost -= expInv1 * (cumEndExp[rightBound + 1] - cumEndExp[leftBound]); + int leftIndex = index - 1; + while (leftIndex >= 0 + && sortedSegments.get(leftIndex).getInterval().overlaps(dataSegment.getInterval())) { + double start = convertStart(sortedSegments.get(leftIndex), interval); + double end = convertEnd(sortedSegments.get(leftIndex), interval); + leftCost += CostBalancerStrategy.intervalCost(end - start, t0 - start, t1 - start); + --leftIndex; + } // add left-non-overlapping segments - if (leftBound > 0) { - leftCost += leftSum[leftBound - 1] * (expInv1 - expInv0); + if (leftIndex >= 0) { + leftCost += leftSum[leftIndex] * (FastMath.exp(-t1) - FastMath.exp(-t0)); } return leftCost; } private double rightCost(DataSegment dataSegment, double t0, double t1, int index) { - int n = leftSum.length; - if (index >= n) { - return 0; - } - double exp0 = FastMath.exp(t0); - double exp1 = FastMath.exp(t1); - double expInv1 
= 1 / exp1; double rightCost = 0.0; - int leftBound = index; - int rightBound = rightBoundary(index, n - 1, dataSegment.getInterval()); // add all right-overlapping segments - rightCost += 2 * (rightBound - leftBound + 1) * t1; - rightCost -= 2 * (cumStart[rightBound + 1] - cumStart[leftBound]); - rightCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - rightCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - rightCost -= exp0 * (cumStartExpInv[rightBound + 1] - cumStartExpInv[leftBound]); - rightCost -= exp1 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + int rightIndex = index; + while (rightIndex < sortedSegments.size() && + sortedSegments.get(rightIndex).getInterval().overlaps(dataSegment.getInterval())) { + double start = convertStart(sortedSegments.get(rightIndex), interval); + double end = convertEnd(sortedSegments.get(rightIndex), interval); + rightCost += CostBalancerStrategy.intervalCost(t1 - t0, start - t0, end - t0); + ++rightIndex; + } // add right-non-overlapping segments - if (rightBound + 1 < n) { - rightCost += rightSum[rightBound + 1] * (exp0 - exp1); + if (rightIndex < sortedSegments.size()) { + rightCost += rightSum[rightIndex] * (FastMath.exp(t0) - FastMath.exp(t1)); } return rightCost; } - private int leftBoundary(int l, int r, Interval interval) - { - if (l == r) { - return interval.overlaps(sortedIntervals.get(l)) ? l : r + 1; - } - int m = (l + r) / 2; - if (interval.overlaps(sortedIntervals.get(m))) { - return leftBoundary(l, m, interval); - } else { - return leftBoundary(m + 1, r, interval); - } - } - - private int rightBoundary(int l, int r, Interval interval) - { - if (l == r) { - return interval.overlaps(sortedIntervals.get(r)) ? 
r : l - 1; - } - int m = (l + r + 1) / 2; - if (interval.overlaps(sortedIntervals.get(m))) { - return rightBoundary(m, r, interval); - } else { - return rightBoundary(l, m - 1, interval); - } - } - - private static double convertStart(Interval interval, Interval reference) + private static double convertStart(DataSegment dataSegment, Interval interval) { - return toLocalInterval(interval.getStartMillis(), reference); + return toLocalInterval(dataSegment.getInterval().getStartMillis(), interval); } - private static double convertEnd(Interval interval, Interval reference) + private static double convertEnd(DataSegment dataSegment, Interval interval) { - return toLocalInterval(interval.getEndMillis(), reference); + return toLocalInterval(dataSegment.getInterval().getEndMillis(), interval); } private static double toLocalInterval(long millis, Interval interval) @@ -387,8 +317,8 @@ public static Builder builder(Interval interval) static class Builder { - protected final Interval interval; - private final SegmentTreap treap = new SegmentTreap(); + private final Interval interval; + private final NavigableSet segments = new TreeSet<>(); public Builder(Interval interval) { @@ -402,8 +332,8 @@ public Builder addSegment(DataSegment dataSegment) } // all values are pre-computed relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); + double t0 = convertStart(dataSegment, interval); + double t1 = convertEnd(dataSegment, interval); double leftValue = FastMath.exp(t0) - FastMath.exp(t1); double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); @@ -411,22 +341,22 @@ public Builder addSegment(DataSegment dataSegment) SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, leftValue, rightValue); // left/right value should be added to left/right sums for elements greater/lower than current segment - treap.update(segmentAndSum, Pair.of(leftValue, 
0.0), true); - treap.update(segmentAndSum, Pair.of(0.0, rightValue), false); + segments.tailSet(segmentAndSum).forEach(v -> v.leftSum += leftValue); + segments.headSet(segmentAndSum).forEach(v -> v.rightSum += rightValue); // leftSum_i = leftValue_i + \sum leftValue_j = leftValue_i + leftSum_{i-1} , j < i - SegmentAndSum lower = treap.lower(segmentAndSum).val; + SegmentAndSum lower = segments.lower(segmentAndSum); if (lower != null) { segmentAndSum.leftSum = leftValue + lower.leftSum; } // rightSum_i = rightValue_i + \sum rightValue_j = rightValue_i + rightSum_{i+1} , j > i - SegmentAndSum higher = treap.upper(segmentAndSum).val; + SegmentAndSum higher = segments.higher(segmentAndSum); if (higher != null) { segmentAndSum.rightSum = rightValue + higher.rightSum; } - if (!treap.insert(segmentAndSum)) { + if (!segments.add(segmentAndSum)) { throw new ISE("expect new segment"); } return this; @@ -436,45 +366,44 @@ public Builder removeSegment(DataSegment dataSegment) { SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, 0.0, 0.0); - if (!treap.remove(segmentAndSum)) { + if (!segments.remove(segmentAndSum)) { return this; } - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); + double t0 = convertStart(dataSegment, interval); + double t1 = convertEnd(dataSegment, interval); double leftValue = FastMath.exp(t0) - FastMath.exp(t1); double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); - treap.update(segmentAndSum, Pair.of(-leftValue, 0.0), true); - treap.update(segmentAndSum, Pair.of(0.0, -rightValue), false); - + segments.tailSet(segmentAndSum).forEach(v -> v.leftSum -= leftValue); + segments.headSet(segmentAndSum).forEach(v -> v.rightSum -= rightValue); return this; } public boolean isEmpty() { - return treap.isEmpty(); + return segments.isEmpty(); } public Bucket build() { - ArrayList intervalsList = new ArrayList<>(); - double[] leftSum = new double[treap.root.size]; - double[] 
rightSum = new double[treap.root.size]; + ArrayList segmentsList = new ArrayList<>(segments.size()); + double[] leftSum = new double[segments.size()]; + double[] rightSum = new double[segments.size()]; int i = 0; - for (SegmentAndSum segmentAndSum : treap.toList()) { - intervalsList.add(segmentAndSum.dataSegment.getInterval()); + for (SegmentAndSum segmentAndSum : segments) { + segmentsList.add(segmentAndSum.dataSegment); leftSum[i] = segmentAndSum.leftSum; rightSum[i] = segmentAndSum.rightSum; ++i; } - long bucketEndMillis = intervalsList + long bucketEndMillis = segmentsList .stream() - .mapToLong(interval -> interval.getEndMillis()) + .mapToLong(s -> s.getInterval().getEndMillis()) .max() .orElseGet(interval::getEndMillis); - return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), intervalsList, leftSum, rightSum); + return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), segmentsList, leftSum, rightSum); } } } @@ -502,7 +431,7 @@ public int compareTo(SegmentAndSum o) @Override public boolean equals(Object obj) { - throw new UnsupportedOperationException("Use SegmentAndSum.compareTo()"); + throw new UnsupportedOperationException("Use IntervalAndSum.compareTo()"); } @Override @@ -511,466 +440,4 @@ public int hashCode() throw new UnsupportedOperationException(); } } - - - abstract static class Treap, Y> - { - protected TreapNode root; - protected final TreapNode NULL; - - public Treap() - { - NULL = new TreapNode(null); - NULL.left = NULL.right = NULL; - NULL.priority = Double.POSITIVE_INFINITY; - root = NULL; - } - - public boolean isEmpty() - { - return NULL.equals(root); - } - - public boolean contains(X val) - { - return contains(val, root); - } - - public TreapNode lower(X val) - { - return lower(val, root); - } - - public TreapNode upper(X val) - { - return upper(val, root); - } - - public TreapNode floor(X val) - { - return floor(val, root); - } - - public TreapNode ceil(X val) - { - return ceil(val, root); - } - - 
public boolean insert(X val) - { - int oldSize = root.size; - root = insert(new TreapNode(val), root); - return root.size > oldSize; - } - - public boolean remove(X val) - { - int oldSize = root.size; - root = remove(val, root); - return root.size < oldSize; - } - - public TreapNode getMin() - { - TreapNode node = root; - while (!NULL.equals(node.left)) { - node = node.left; - } - return node; - } - - public TreapNode getMax() - { - TreapNode node = root; - while (!NULL.equals(node.right)) { - node = node.right; - } - return node; - } - - public Y query() - { - return root.sum; - } - - public void update(X val, Y lazy, boolean dir) - { - if (dir) { - root = update(root, val, null, lazy); - } else { - root = update(root, null, val, lazy); - } - } - - protected abstract Y getVal(X val); - - protected abstract X setVal(X val, Y lazy); - - protected abstract Y add(Y a, Y b); - - protected abstract Y multiply(int a, Y b); - - protected abstract Y zero(); - - private boolean contains(X val, TreapNode node) - { - if (NULL.equals(node)) { - return false; - } - final int cmp = val.compareTo(node.val); - if (cmp < 0) { - return contains(val, node.left); - } - if (cmp > 0) { - return contains(val, node.right); - } - return true; - } - - private TreapNode lower(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp <= 0) { - return lower(val, node.left); - } else { - TreapNode ret = lower(val, node.right); - return (NULL.equals(ret)) ? node : ret; - } - } - - private TreapNode upper(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp >= 0) { - return upper(val, node.right); - } else { - TreapNode ret = upper(val, node.left); - return (NULL.equals(ret)) ? 
node : ret; - } - } - - private TreapNode floor(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp < 0) { - return floor(val, node.left); - } else { - TreapNode ret = floor(val, node.right); - return (NULL.equals(ret)) ? node : ret; - } - } - - private TreapNode ceil(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp > 0) { - return ceil(val, node.right); - } else { - TreapNode ret = ceil(val, node.left); - return (NULL.equals(ret)) ? node : ret; - } - } - - private TreapNode insert(TreapNode val, TreapNode node) - { - if (NULL.equals(node)) { - return val; - } - Pair pair = split(node, val.val); - node = merge(pair.lhs, val); - node = merge(node, pair.rhs); - return node; - } - - private TreapNode remove(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - Pair pair = split(node, val); - TreapNode lower = lower(val, pair.lhs); - if (NULL.equals(lower)) { - return pair.rhs; - } - return merge(split(pair.lhs, lower.val).lhs, pair.rhs); - } - - private Pair split(TreapNode node, X val) - { - if (NULL.equals(node)) { - return Pair.of(NULL, NULL); - } - node.lazyPropogate(); - final int cmp = val.compareTo(node.val); - Pair pair; - if (cmp < 0) { - pair = split(node.left, val); - node.left = pair.rhs; - pair = Pair.of(pair.lhs, node); - } else { - pair = split(node.right, val); - node.right = pair.lhs; - pair = Pair.of(node, pair.rhs); - } - node.recompute(); - return pair; - } - - private TreapNode merge(TreapNode left, TreapNode right) - { - if (NULL.equals(left)) { - return right; - } - if (NULL.equals(right)) { - return left; - } - left.lazyPropogate(); - right.lazyPropogate(); - TreapNode node; - if (left.priority < right.priority) { - left.right = merge(left.right, right); - node = left; - } else { - right.left = merge(left, right.left); - node = right; - } - node.recompute(); - return node; - 
} - - private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, Y lazy) - { - TreapNode left = NULL; - TreapNode right = NULL; - if (begin != null) { - Pair pair = split(node, begin); - left = pair.lhs; - node = pair.rhs; - } - if (end != null) { - Pair pair = split(node, end); - node = pair.lhs; - right = pair.rhs; - } - node.lazy = add(node.lazy, lazy); - node = merge(left, node); - node = merge(node, right); - return node; - } - - class TreapNode - { - X val; - Y sum; - Y lazy; - TreapNode left; - TreapNode right; - double priority; - int size; - - TreapNode(@Nullable X val) - { - this(val, NULL, NULL); - if (val != null) { - sum = getVal(val); - size = 1; - } - } - - TreapNode(@Nullable X val, @Nullable TreapNode left, @Nullable TreapNode right) - { - this.val = val; - this.left = left; - this.right = right; - this.priority = Math.random(); - this.sum = zero(); - this.lazy = zero(); - } - - void recompute() - { - if (NULL.equals(this)) { - return; - } - size = 1 + left.size + right.size; - sum = getVal(val); - left.lazyPropogate(); - right.lazyPropogate(); - sum = add(sum, add(left.sum, right.sum)); - } - - void lazyPropogate() - { - if (NULL.equals(this)) { - return; - } - val = setVal(val, lazy); - sum = add(sum, multiply(size, lazy)); - if (!NULL.equals(left)) { - left.lazy = add(left.lazy, lazy); - } - if (!NULL.equals(right)) { - right.lazy = add(right.lazy, lazy); - } - lazy = zero(); - } - - @Override - public boolean equals(Object that) - { - return this == that; - } - } - } - - public static class TestX implements Comparable - { - final String a; - double b; - - public TestX(String a, double b) - { - this.a = a; - this.b = b; - } - - public String getA() - { - return a; - } - - public double getB() - { - return b; - } - - public void setB(double b) - { - this.b = b; - } - - @Override - public int compareTo(TestX that) - { - return a.compareTo(that.getA()); - } - } - - public static class SegmentTreap extends Treap> - { - - static 
final Pair ZERO = Pair.of(0.0, 0.0); - - @Override - protected Pair getVal(SegmentAndSum val) - { - return Pair.of(val.leftSum, val.rightSum); - } - - @Override - protected SegmentAndSum setVal(SegmentAndSum val, Pair lazy) - { - val.leftSum += lazy.lhs; - val.rightSum += lazy.rhs; - return val; - } - - @Override - protected Pair zero() - { - return ZERO; - } - - @Override - protected Pair add(Pair a, Pair b) - { - return Pair.of(a.lhs + b.lhs, a.rhs + b.rhs); - } - - @Override - protected Pair multiply(int a, Pair b) - { - return Pair.of(a * b.lhs, a * b.rhs); - } - - public List toList() - { - List list = new ArrayList<>(); - accumulate(list, root); - return list; - } - - private void accumulate(List list, TreapNode node) - { - if (NULL.equals(node)) { - return; - } - node.lazyPropogate(); - accumulate(list, node.left); - list.add(node.val); - accumulate(list, node.right); - } - } - - public static class TestTreap extends Treap - { - @Override - protected Double getVal(TestX val) - { - return val.getB(); - } - - @Override - protected TestX setVal(TestX val, Double lazy) - { - val.setB(val.getB() + lazy); - return val; - } - - @Override - protected Double add(Double a, Double b) - { - return Double.sum(a, b); - } - - @Override - protected Double multiply(int a, Double b) - { - return a * b; - } - - @Override - protected Double zero() - { - return 0.0; - } - - public void print() - { - print(this.root); - System.out.println(); - } - - private void print(TreapNode node) - { - if (NULL.equals(node)) { - return; - } - print(node.left); - System.out.println(node.val.getA() + ", " + node.val.getB() + ", " + node.sum + ", " + node.lazy + ", " + node.priority); - print(node.right); - } - } } diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java new file mode 100644 index 000000000000..14150c10768f --- /dev/null +++ 
b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java @@ -0,0 +1,551 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.coordinator.cost; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Ordering; +import org.apache.commons.math3.util.FastMath; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.Treap; +import org.apache.druid.java.util.common.granularity.DurationGranularity; +import org.apache.druid.java.util.common.guava.Comparators; +import org.apache.druid.server.coordinator.CostBalancerStrategy; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.SegmentId; +import org.joda.time.Interval; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.ListIterator; +import java.util.NavigableMap; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * SegmentsCostCache provides faster way to calculate cost function proposed in {@link 
CostBalancerStrategy}. + * See https://github.com/apache/druid/pull/2972 for more details about the cost function. + * + * Joint cost for two segments (you can make formulas below readable by copy-pasting to + * https://www.codecogs.com/latex/eqneditor.php): + * + * cost(X, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy + * or + * cost(X, Y) = e^{-y_0 - y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) + * if x_0 <= x_1 <= y_0 <= y_1 + * (*) lambda coefficient is omitted for simplicity. + * + * For a group of segments {S_xi}, i = {0, n} total joint cost with segment S_y could be calculated as: + * + * cost(X, Y) = \sum cost(X_i, Y) = e^{-y_0 - y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) + * if xi_0 <= xi_1 <= y_0 <= y_1 + * and + * cost(X, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{-xi_0 - xi_1} (e^{xi_0} - e^{xi_1}) + * if y_0 <= y_1 <= xi_0 <= xi_1 + * + * SegmentsCostCache stores pre-computed sums for a group of segments {S_xi}: + * + * 1) \sum (e^{xi_0} - e^{xi_1}) -> leftSum + * 2) \sum e^{-xi_0 - xi_1} (e^{xi_0} - e^{xi_1}) -> rightSum + * + * so that calculation of joint cost function for segment S_y becomes an O(1 + m) complexity task, where m + * is the number of segments in {S_xi} that overlap S_y. + * + * Segments are stored in buckets. Bucket is a subset of segments contained in SegmentsCostCache, so that + * startTime of all segments inside a bucket are in the same time interval (with some granularity): + * + * |------------------------|--------------------------|-----------------------|-------- .... + * t_0 t_0+D t_0 + 2D t0 + 3D ....
+ * S_x1 S_x2 S_x3 S_x4 S_x5 S_x6 S_x7 S_x8 S_x9 + * bucket1 bucket2 bucket3 + * + * Reasons to store segments in Buckets: + * + * 1) Cost function tends to 0 as distance between segments' intervals increases; buckets + * are used to avoid redundant 0 calculations for thousands of times + * 2) To reduce number of calculations when segment is added or removed from SegmentsCostCache + * 3) To avoid infinite values during exponents calculations + * + */ +public class SegmentsCostCacheV2 +{ + /** + * HALF_LIFE_DAYS defines how fast joint cost function tends to 0 as distance between segments' intervals increases. + * The value of 1 day means that cost function of co-locating two segments which have 1 day between their intervals + * is 0.5 of the cost, if the intervals are adjacent. If the distance is 2 days, then 0.25, etc. + */ + private static final double HALF_LIFE_DAYS = 1.0; + private static final double LAMBDA = Math.log(2) / HALF_LIFE_DAYS; + private static final double MILLIS_FACTOR = TimeUnit.DAYS.toMillis(1) / LAMBDA; + + /** + * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" + * from each other and thus cost(X, Y) ~ 0 for these segments + */ + private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30); + + /** + * Bucket interval defines duration granularity for segment buckets.
Number of buckets control the trade-off + * between updates (add/remove segment operation) and joint cost calculation: + * 1) updates complexity is increasing when number of buckets is decreasing (as buckets contain more segments) + * 2) joint cost calculation complexity is increasing with increasing of buckets number + */ + private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); + private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); + + private static final Comparator INTERVAL_COMPARATOR = Comparators.intervalsByStartThenEnd(); + + private static final Comparator BUCKET_INTERVAL_COMPARATOR = + Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); + + private static final Ordering INTERVAL_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); + private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); + + private final ArrayList sortedBuckets; + private final ArrayList intervals; + + SegmentsCostCacheV2(ArrayList sortedBuckets) + { + this.sortedBuckets = Preconditions.checkNotNull(sortedBuckets, "buckets should not be null"); + this.intervals = sortedBuckets.stream().map(Bucket::getInterval).collect(Collectors.toCollection(ArrayList::new)); + Preconditions.checkArgument( + BUCKET_ORDERING.isOrdered(sortedBuckets), + "buckets must be ordered by interval" + ); + } + + public double cost(DataSegment segment) + { + double cost = 0.0; + int index = Collections.binarySearch(intervals, segment.getInterval(), Comparators.intervalsByStartThenEnd()); + index = (index >= 0) ? 
index : -index - 1; + + for (ListIterator it = sortedBuckets.listIterator(index); it.hasNext(); ) { + Bucket bucket = it.next(); + if (!bucket.inCalculationInterval(segment)) { + break; + } + cost += bucket.cost(segment); + } + + for (ListIterator it = sortedBuckets.listIterator(index); it.hasPrevious(); ) { + Bucket bucket = it.previous(); + if (!bucket.inCalculationInterval(segment)) { + break; + } + cost += bucket.cost(segment); + } + + return cost; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private final NavigableMap buckets = new TreeMap<>(Comparators.intervalsByStartThenEnd()); + + public Builder addSegment(DataSegment segment) + { + Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder); + builder.addSegment(segment); + return this; + } + + public Builder removeSegment(DataSegment segment) + { + Interval interval = getBucketInterval(segment); + buckets.computeIfPresent( + interval, + // If there are no move segments, returning null in computeIfPresent() removes the interval from the buckets + // map + (i, builder) -> builder.removeSegment(segment).isEmpty() ? 
null : builder + ); + return this; + } + + public boolean isEmpty() + { + return buckets.isEmpty(); + } + + public SegmentsCostCacheV2 build() + { + return new SegmentsCostCacheV2( + buckets + .values() + .stream() + .map(Bucket.Builder::build) + .collect(Collectors.toCollection(ArrayList::new)) + ); + } + + private static Interval getBucketInterval(DataSegment segment) + { + return BUCKET_GRANULARITY.bucket(segment.getInterval().getStart()); + } + } + + static class Bucket + { + private final Interval interval; + private final Interval calculationInterval; + private final ArrayList sortedIntervals; + private final double[] leftSum; + private final double[] rightSum; + + private final double[] cumStart; + private final double[] cumStartExp; + private final double[] cumStartExpInv; + private final double[] cumEnd; + private final double[] cumEndExp; + private final double[] cumEndExpInv; + + Bucket(Interval interval, ArrayList sortedIntervals, double[] leftSum, double[] rightSum) + { + this.interval = Preconditions.checkNotNull(interval, "interval"); + this.sortedIntervals = Preconditions.checkNotNull(sortedIntervals, "sortedSegments"); + this.leftSum = Preconditions.checkNotNull(leftSum, "leftSum"); + this.rightSum = Preconditions.checkNotNull(rightSum, "rightSum"); + Preconditions.checkArgument(sortedIntervals.size() == leftSum.length && sortedIntervals.size() == rightSum.length); + Preconditions.checkArgument(INTERVAL_ORDERING.isOrdered(sortedIntervals)); + this.calculationInterval = new Interval( + interval.getStart().minus(LIFE_THRESHOLD), + interval.getEnd().plus(LIFE_THRESHOLD) + ); + + int n = leftSum.length; + + cumStart = new double[n + 1]; + cumStartExp = new double[n + 1]; + cumStartExpInv = new double[n + 1]; + cumEnd = new double[n + 1]; + cumEndExp = new double[n + 1]; + cumEndExpInv = new double[n + 1]; + for (int i = 0; i < n; i++) { + double start = convertStart(sortedIntervals.get(i), interval); + cumStart[i + 1] = cumStart[i] + start; + 
cumStartExp[i + 1] = cumStartExp[i] + FastMath.exp(start); + cumStartExpInv[i + 1] = cumStartExpInv[i] + FastMath.exp(-start); + + double end = convertEnd(sortedIntervals.get(i), interval); + cumEnd[i + 1] = cumEnd[i] + end; + cumEndExp[i + 1] = cumEndExp[i] + FastMath.exp(end); + cumEndExpInv[i + 1] = cumEndExpInv[i] + FastMath.exp(-end); + } + } + + Interval getInterval() + { + return interval; + } + + boolean inCalculationInterval(DataSegment dataSegment) + { + return calculationInterval.overlaps(dataSegment.getInterval()); + } + + double cost(DataSegment dataSegment) + { + // cost is calculated relatively to bucket start (which is considered as 0) + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); + + // avoid calculation for segments outside of LIFE_THRESHOLD + if (!inCalculationInterval(dataSegment)) { + throw new ISE("Segment is not within calculation interval"); + } + + int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); + index = (index >= 0) ? 
index : -index - 1; + return leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); + } + + private double leftCost(DataSegment dataSegment, double t0, double t1, int index) + { + if (index - 1 < 0) { + return 0; + } + double exp0 = FastMath.exp(t0); + double expInv0 = 1 / exp0; + double exp1 = FastMath.exp(t1); + double expInv1 = 1 / exp1; + double leftCost = 0.0; + // add to cost all left-overlapping segments + int rightBound = index - 1; + int leftBound = leftBoundary(0, index - 1, dataSegment.getInterval()); + leftCost += 2 * (cumEnd[rightBound + 1] - cumEnd[leftBound]); + leftCost -= 2 * (rightBound - leftBound + 1) * t0; + leftCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + leftCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + leftCost -= expInv0 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + leftCost -= expInv1 * (cumEndExp[rightBound + 1] - cumEndExp[leftBound]); + // add left-non-overlapping segments + if (leftBound > 0) { + leftCost += leftSum[leftBound - 1] * (expInv1 - expInv0); + } + return leftCost; + } + + private double rightCost(DataSegment dataSegment, double t0, double t1, int index) + { + int n = leftSum.length; + if (index >= n) { + return 0; + } + double exp0 = FastMath.exp(t0); + double exp1 = FastMath.exp(t1); + double expInv1 = 1 / exp1; + double rightCost = 0.0; + int leftBound = index; + int rightBound = rightBoundary(index, n - 1, dataSegment.getInterval()); + // add all right-overlapping segments + rightCost += 2 * (rightBound - leftBound + 1) * t1; + rightCost -= 2 * (cumStart[rightBound + 1] - cumStart[leftBound]); + rightCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + rightCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + rightCost -= exp0 * (cumStartExpInv[rightBound + 1] - cumStartExpInv[leftBound]); + rightCost -= exp1 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + // add 
right-non-overlapping segments + if (rightBound + 1 < n) { + rightCost += rightSum[rightBound + 1] * (exp0 - exp1); + } + return rightCost; + } + + private int leftBoundary(int l, int r, Interval interval) + { + if (l == r) { + return interval.overlaps(sortedIntervals.get(l)) ? l : r + 1; + } + int m = (l + r) / 2; + if (interval.overlaps(sortedIntervals.get(m))) { + return leftBoundary(l, m, interval); + } else { + return leftBoundary(m + 1, r, interval); + } + } + + private int rightBoundary(int l, int r, Interval interval) + { + if (l == r) { + return interval.overlaps(sortedIntervals.get(r)) ? r : l - 1; + } + int m = (l + r + 1) / 2; + if (interval.overlaps(sortedIntervals.get(m))) { + return rightBoundary(m, r, interval); + } else { + return rightBoundary(l, m - 1, interval); + } + } + + private static double convertStart(Interval interval, Interval reference) + { + return toLocalInterval(interval.getStartMillis(), reference); + } + + private static double convertEnd(Interval interval, Interval reference) + { + return toLocalInterval(interval.getEndMillis(), reference); + } + + private static double toLocalInterval(long millis, Interval interval) + { + return millis / MILLIS_FACTOR - interval.getStartMillis() / MILLIS_FACTOR; + } + + public static Builder builder(Interval interval) + { + return new Builder(interval); + } + + static class Builder + { + protected final Interval interval; + private SegmentTreap treap = new SegmentTreap(); + public Builder(Interval interval) + { + this.interval = interval; + } + + public Builder addSegment(DataSegment dataSegment) + { + if (!interval.contains(dataSegment.getInterval().getStartMillis())) { + throw new ISE("Failed to add segment to bucket: interval is not covered by this bucket"); + } + + // all values are pre-computed relatively to bucket start (which is considered as 0) + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); + + double 
leftValue = FastMath.exp(t0) - FastMath.exp(t1); + double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); + + SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, leftValue, rightValue); + + // left/right value should be added to left/right sums for elements greater/lower than current segment + treap.update(segmentAndSum, Pair.of(leftValue, 0.0), true); + treap.update(segmentAndSum, Pair.of(0.0, rightValue), false); + + // leftSum_i = leftValue_i + \sum leftValue_j = leftValue_i + leftSum_{i-1} , j < i + SegmentAndSum lower = treap.lower(segmentAndSum); + if (lower != null) { + segmentAndSum.leftSum = leftValue + lower.leftSum; + } + + // rightSum_i = rightValue_i + \sum rightValue_j = rightValue_i + rightSum_{i+1} , j > i + SegmentAndSum higher = treap.upper(segmentAndSum); + if (higher != null) { + segmentAndSum.rightSum = rightValue + higher.rightSum; + } + + if (!treap.insert(segmentAndSum)) { + throw new ISE("expect new segment"); + } + return this; + } + + public Builder removeSegment(DataSegment dataSegment) + { + SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, 0.0, 0.0); + + if (!treap.remove(segmentAndSum)) { + return this; + } + + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); + + double leftValue = FastMath.exp(t0) - FastMath.exp(t1); + double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); + + treap.update(segmentAndSum, Pair.of(-leftValue, 0.0), true); + treap.update(segmentAndSum, Pair.of(0.0, -rightValue), false); + + return this; + } + + public boolean isEmpty() + { + return treap.isEmpty(); + } + + public Bucket build() + { + ArrayList intervalsList = new ArrayList<>(); + double[] leftSum = new double[treap.size()]; + double[] rightSum = new double[treap.size()]; + int i = 0; + for (SegmentAndSum segmentAndSum : treap.toList()) { + intervalsList.add(segmentAndSum.interval); + leftSum[i] = segmentAndSum.leftSum; + rightSum[i] = 
segmentAndSum.rightSum; + ++i; + } + treap = null; + long bucketEndMillis = intervalsList + .stream() + .mapToLong(Interval::getEndMillis) + .max() + .orElseGet(interval::getEndMillis); + return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), intervalsList, leftSum, rightSum); + } + } + } + + static class SegmentAndSum implements Comparable + { + private final Interval interval; + private final SegmentId segmentId; + private double leftSum; + private double rightSum; + + SegmentAndSum(DataSegment dataSegment, double leftSum, double rightSum) + { + this.interval = dataSegment.getInterval(); + this.segmentId = dataSegment.getId(); + this.leftSum = leftSum; + this.rightSum = rightSum; + } + + @Override + public int compareTo(SegmentAndSum o) + { + int c = Comparators.intervalsByStartThenEnd().compare(interval, o.interval); + return c != 0 ? c : segmentId.compareTo(o.segmentId); + } + + @Override + public boolean equals(Object obj) + { + throw new UnsupportedOperationException("Use SegmentAndSum.compareTo()"); + } + + @Override + public int hashCode() + { + throw new UnsupportedOperationException(); + } + } + + public static class SegmentTreap extends Treap> + { + + static final Pair ZERO = Pair.of(0.0, 0.0); + + @Override + protected Pair getVal(SegmentAndSum val) + { + return Pair.of(val.leftSum, val.rightSum); + } + + @Override + protected SegmentAndSum setVal(SegmentAndSum val, Pair lazy) + { + val.leftSum += lazy.lhs; + val.rightSum += lazy.rhs; + return val; + } + + @Override + protected Pair zero() + { + return ZERO; + } + + @Override + protected Pair add(Pair a, Pair b) + { + return Pair.of(a.lhs + b.lhs, a.rhs + b.rhs); + } + + @Override + protected Pair multiply(int a, Pair b) + { + return Pair.of(a * b.lhs, a * b.rhs); + } + } +} diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java new file mode 100644 
index 000000000000..706e42011a68 --- /dev/null +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.coordinator.cost; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Ordering; +import org.apache.commons.math3.util.FastMath; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.granularity.DurationGranularity; +import org.apache.druid.java.util.common.guava.Comparators; +import org.apache.druid.server.coordinator.CostBalancerStrategy; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.SegmentId; +import org.joda.time.Interval; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.ListIterator; +import java.util.NavigableMap; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * SegmentsCostCache provides faster way to calculate cost 
function proposed in {@link CostBalancerStrategy}. + * See https://github.com/apache/druid/pull/2972 for more details about the cost function. + * + * Joint cost for two segments (you can make formulas below readable by copy-pasting to + * https://www.codecogs.com/latex/eqneditor.php): + * + * cost(Y, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy + * or + * cost(Y, Y) = e^{y_0 + y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) + * if x_0 <= x_1 <= y_0 <= y_1 + * (*) lambda coefficient is omitted for simplicity. + * + * For a group of segments {S_xi}, i = {0, n} total joint cost with segment S_y could be calculated as: + * + * cost(Y, Y) = \sum cost(X_i, Y) = e^{y_0 + y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) + * if xi_0 <= xi_1 <= y_0 <= y_1 + * and + * cost(Y, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) + * if y_0 <= y_1 <= xi_0 <= xi_1 + * + * SegmentsCostCache stores pre-computed sums for a group of segments {S_xi}: + * + * 1) \sum (e^{xi_0} - e^{xi_1}) -> leftSum + * 2) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) -> rightSum + * + * so that calculation of joint cost function for segment S_y became a O(1 + m) complexity task, where m + * is the number of segments in {S_xi} that overlaps S_y. + * + * Segments are stored in buckets. Bucket is a subset of segments contained in SegmentsCostCache, so that + * startTime of all segments inside a bucket are in the same time interval (with some granularity): + * + * |------------------------|--------------------------|-----------------------|-------- .... + * t_0 t_0+D t_0 + 2D t0 + 3D .... 
+ * S_x1 S_x2 S_x3 S_x4 S_x5 S_x6 S_x7 S_x8 S_x9 + * bucket1 bucket2 bucket3 + * + * Reasons to store segments in Buckets: + * + * 1) Cost function tends to 0 as distance between segments' intervals increases; buckets + * are used to avoid redundant 0 calculations for thousands of times + * 2) To reduce number of calculations when segment is added or removed from SegmentsCostCache + * 3) To avoid infinite values during exponents calculations + * + */ +public class SegmentsCostCacheV3 +{ + /** + * HALF_LIFE_DAYS defines how fast joint cost function tends to 0 as distance between segments' intervals increasing. + * The value of 1 day means that cost function of co-locating two segments which have 1 days between their intervals + * is 0.5 of the cost, if the intervals are adjacent. If the distance is 2 days, then 0.25, etc. + */ + private static final double HALF_LIFE_DAYS = 1.0; + private static final double LAMBDA = Math.log(2) / HALF_LIFE_DAYS; + private static final double MILLIS_FACTOR = TimeUnit.DAYS.toMillis(1) / LAMBDA; + + /** + * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" + * from each other and thus cost(Y,Y) ~ 0 for these segments + */ + private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30); + + /** + * Bucket interval defines duration granularity for segment buckets. 
Number of buckets control the trade-off + * between updates (add/remove segment operation) and joint cost calculation: + * 1) updates complexity is increasing when number of buckets is decreasing (as buckets contain more segments) + * 2) joint cost calculation complexity is increasing with increasing of buckets number + */ + private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); + private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); + + private static final Comparator INTERVAL_START_COMPARATOR = Comparators.intervalsByStartThenEnd(); + + private static final Comparator INTERVAL_END_COMPARATOR = Comparators.intervalsByEndThenStart(); + + private static final Comparator BUCKET_INTERVAL_COMPARATOR = + Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); + + private static final Ordering INTERVAL_START_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); + + private static final Ordering INTERVAL_END_ORDERING = Ordering.from(Comparators.intervalsByEndThenStart()); + private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); + + private final ArrayList sortedBuckets; + private final ArrayList intervals; + + SegmentsCostCacheV3(ArrayList sortedBuckets) + { + this.sortedBuckets = Preconditions.checkNotNull(sortedBuckets, "buckets should not be null"); + this.intervals = sortedBuckets.stream().map(Bucket::getInterval).collect(Collectors.toCollection(ArrayList::new)); + Preconditions.checkArgument( + BUCKET_ORDERING.isOrdered(sortedBuckets), + "buckets must be ordered by interval" + ); + } + + public double cost(DataSegment segment) + { + double cost = 0.0; + int index = Collections.binarySearch(intervals, segment.getInterval(), Comparators.intervalsByStartThenEnd()); + index = (index >= 0) ? 
index : -index - 1; + + for (ListIterator it = sortedBuckets.listIterator(index); it.hasNext(); ) { + Bucket bucket = it.next(); + if (!bucket.inCalculationInterval(segment)) { + break; + } + cost += bucket.cost(segment); + } + + for (ListIterator it = sortedBuckets.listIterator(index); it.hasPrevious(); ) { + Bucket bucket = it.previous(); + if (!bucket.inCalculationInterval(segment)) { + break; + } + cost += bucket.cost(segment); + } + + return cost; + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private final NavigableMap buckets = new TreeMap<>(Comparators.intervalsByStartThenEnd()); + + public Builder addSegment(DataSegment segment) + { + Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder); + builder.addSegment(segment); + return this; + } + + public Builder removeSegment(DataSegment segment) + { + Interval interval = getBucketInterval(segment); + buckets.computeIfPresent( + interval, + // If there are no move segments, returning null in computeIfPresent() removes the interval from the buckets + // map + (i, builder) -> builder.removeSegment(segment).isEmpty() ? 
null : builder + ); + return this; + } + + public boolean isEmpty() + { + return buckets.isEmpty(); + } + + public SegmentsCostCacheV3 build() + { + return new SegmentsCostCacheV3( + buckets + .values() + .stream() + .map(Bucket.Builder::build) + .collect(Collectors.toCollection(ArrayList::new)) + ); + } + + private static Interval getBucketInterval(DataSegment segment) + { + return BUCKET_GRANULARITY.bucket(segment.getInterval().getStart()); + } + } + + static class Bucket + { + private final Interval interval; + private final Interval calculationInterval; + private final List intervalStartSortList; + private final List intervalEndSortList; + + private final double[] cumStart; + private final double[] cumStartExp; + private final double[] cumStartExpInv; + private final double[] cumEnd; + private final double[] cumEndExp; + private final double[] cumEndExpInv; + + Bucket(Interval interval, List intervalStartSortList, List intervalEndSortList) + { + this.interval = Preconditions.checkNotNull(interval, "interval"); + this.intervalStartSortList = Preconditions.checkNotNull(intervalStartSortList, "intervalStartSortList"); + this.intervalEndSortList = Preconditions.checkNotNull(intervalEndSortList, "intervalEndSortList"); + Preconditions.checkArgument(intervalStartSortList.size() == intervalEndSortList.size()); + Preconditions.checkArgument(INTERVAL_START_ORDERING.isOrdered(intervalStartSortList)); + Preconditions.checkArgument(INTERVAL_END_ORDERING.isOrdered(intervalEndSortList)); + this.calculationInterval = new Interval( + interval.getStart().minus(LIFE_THRESHOLD), + interval.getEnd().plus(LIFE_THRESHOLD) + ); + + int n = intervalStartSortList.size(); + double exp; + double expInv; + + cumStart = new double[n + 1]; + cumStartExp = new double[n + 1]; + cumStartExpInv = new double[n + 1]; + cumEnd = new double[n + 1]; + cumEndExp = new double[n + 1]; + cumEndExpInv = new double[n + 1]; + for (int i = 0; i < n; i++) { + double start = 
convertStart(intervalStartSortList.get(i), interval); + exp = FastMath.exp(start); + expInv = FastMath.exp(-start); + cumStart[i + 1] = cumStart[i] + start; + cumStartExp[i + 1] = cumStartExp[i] + exp; + cumStartExpInv[i + 1] = cumStartExpInv[i] + expInv; + + double end = convertEnd(intervalEndSortList.get(i), interval); + exp = FastMath.exp(end); + expInv = FastMath.exp(-end); + cumEnd[i + 1] = cumEnd[i] + end; + cumEndExp[i + 1] = cumEndExp[i] + exp; + cumEndExpInv[i + 1] = cumEndExpInv[i] + expInv; + } + } + + Interval getInterval() + { + return interval; + } + + boolean inCalculationInterval(DataSegment dataSegment) + { + return calculationInterval.overlaps(dataSegment.getInterval()); + } + + double cost(DataSegment dataSegment) + { + // cost is calculated relatively to bucket start (which is considered as 0) + double t0 = convertStart(dataSegment.getInterval(), interval); + double t1 = convertEnd(dataSegment.getInterval(), interval); + + // avoid calculation for segments outside of LIFE_THRESHOLD + if (!inCalculationInterval(dataSegment)) { + throw new ISE("Segment is not within calculation interval"); + } + + int index = Collections.binarySearch(intervalStartSortList, dataSegment.getInterval(), INTERVAL_START_COMPARATOR); + index = (index >= 0) ? 
index : -index - 1; + return leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); + } + + private double leftCost(DataSegment dataSegment, double t0, double t1, int index) + { + if (index - 1 < 0) { + return 0; + } + double exp0 = FastMath.exp(t0); + double expInv0 = 1 / exp0; + double exp1 = FastMath.exp(t1); + double expInv1 = 1 / exp1; + double leftCost = 0.0; + // add to cost all left-overlapping segments + int rightBound = index - 1; + int leftBound = leftBoundary(0, index - 1, dataSegment.getInterval(), intervalStartSortList); + leftCost += 2 * (cumEnd[rightBound + 1] - cumEnd[leftBound]); + leftCost -= 2 * (rightBound - leftBound + 1) * t0; + leftCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + leftCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + leftCost -= expInv0 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + leftCost -= expInv1 * (cumEndExp[rightBound + 1] - cumEndExp[leftBound]); + // add left-non-overlapping segments + if (leftBound > 0) { + leftCost += cumStartExp[leftBound] * (expInv1 - expInv0); + leftCost -= cumEndExp[leftBound] * (expInv1 - expInv0); + } + return leftCost; + } + + private double rightCost(DataSegment dataSegment, double t0, double t1, int index) + { + int n = intervalStartSortList.size(); + if (index >= n) { + return 0; + } + double exp0 = FastMath.exp(t0); + double exp1 = FastMath.exp(t1); + double expInv1 = 1 / exp1; + double rightCost = 0.0; + int leftBound = index; + int rightBound = rightBoundary(index, n - 1, dataSegment.getInterval(), intervalStartSortList); + // add all right-overlapping segments + rightCost += 2 * (rightBound - leftBound + 1) * t1; + rightCost -= 2 * (cumStart[rightBound + 1] - cumStart[leftBound]); + rightCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + rightCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); + rightCost -= exp0 * (cumStartExpInv[rightBound + 1] - 
cumStartExpInv[leftBound]); + rightCost -= exp1 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); + // add right-non-overlapping segments + rightBound++; + if (rightBound <= n) { + rightCost += (cumEndExpInv[n] - cumEndExpInv[rightBound]) * (exp0 - exp1); + rightCost -= (cumStartExpInv[n] - cumStartExpInv[rightBound]) * (exp0 - exp1); + } + return rightCost; + } + + private int leftBoundary(int l, int r, Interval interval, List intervalList) + { + if (l == r) { + return interval.overlaps(intervalList.get(l)) ? l : r + 1; + } + int m = (l + r) / 2; + if (interval.overlaps(intervalList.get(m))) { + return leftBoundary(l, m, interval, intervalList); + } else { + return leftBoundary(m + 1, r, interval, intervalList); + } + } + + private int rightBoundary(int l, int r, Interval interval, List intervalList) + { + if (l == r) { + return interval.overlaps(intervalList.get(r)) ? r : l - 1; + } + int m = (l + r + 1) / 2; + if (interval.overlaps(intervalList.get(m))) { + return rightBoundary(m, r, interval, intervalList); + } else { + return rightBoundary(l, m - 1, interval, intervalList); + } + } + + private static double convertStart(Interval interval, Interval reference) + { + return toLocalInterval(interval.getStartMillis(), reference); + } + + private static double convertEnd(Interval interval, Interval reference) + { + return toLocalInterval(interval.getEndMillis(), reference); + } + + private static double toLocalInterval(long millis, Interval interval) + { + return millis / MILLIS_FACTOR - interval.getStartMillis() / MILLIS_FACTOR; + } + + public static Builder builder(Interval interval) + { + return new Builder(interval); + } + + static class Builder + { + protected final Interval interval; + private final Set segmentSet = new HashSet<>(); + public Builder(Interval interval) + { + this.interval = interval; + } + + public Builder addSegment(DataSegment dataSegment) + { + if (!interval.contains(dataSegment.getInterval().getStartMillis())) { + throw new 
ISE("Failed to add segment to bucket: interval is not covered by this bucket"); + } + + if (!segmentSet.add(dataSegment.getId())) { + throw new ISE("expect new segment"); + } + + return this; + } + + public Builder removeSegment(DataSegment dataSegment) + { + segmentSet.remove(dataSegment.getId()); + + return this; + } + + public boolean isEmpty() + { + return segmentSet.isEmpty(); + } + + public Bucket build() + { + List intervalsStartSortList = segmentSet.stream() + .map(SegmentId::getInterval) + .sorted(INTERVAL_START_COMPARATOR) + .collect(Collectors.toList()); + + List intervalsEndSortList = segmentSet.stream() + .map(SegmentId::getInterval) + .sorted(INTERVAL_END_COMPARATOR) + .collect(Collectors.toList()); + + long bucketEndMillis = intervalsEndSortList.get(intervalsEndSortList.size() - 1).getEndMillis(); + bucketEndMillis = Long.max(bucketEndMillis, interval.getEndMillis()); + + segmentSet.clear(); + + return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), + intervalsStartSortList, + intervalsEndSortList); + } + } + } +} diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java index 446ac5cf3359..7fb331c69e1c 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java @@ -28,9 +28,7 @@ import java.util.ArrayList; import java.util.List; -import java.util.NavigableSet; import java.util.Random; -import java.util.TreeSet; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; @@ -147,75 +145,6 @@ public void multipleSegmentsCostTest() Assert.assertEquals(0.001574717989780039, segmentCost, EPSILON); } - @Test - public void treapBenchmarkTest() - { - final int n = 20000; - - List ids = new ArrayList<>(); - List vals = new 
ArrayList<>(); - for (int i = 0; i < n; i++) { - ids.add(UUID.randomUUID().toString()); - vals.add((double) i); - } - - System.out.println("Treap:"); - long start = System.currentTimeMillis(); - SegmentsCostCache.TestTreap treap = new SegmentsCostCache.TestTreap(); - for (int i = 0; i < n; i++) { - SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); - treap.update(val, 1.0, false); - treap.update(val, 3.0, true); - treap.insert(val); - } - System.out.println(treap.query()); - long end = System.currentTimeMillis(); - for (int i = n - 1; i >= 0; i -= 3) { - SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); - treap.remove(val); - treap.update(val, -1.0, false); - treap.update(val, -3.0, true); - } - System.out.println(treap.query()); - System.out.println(end - start + " ms"); - - System.out.println("TreeSet:"); - start = System.currentTimeMillis(); - NavigableSet set = new TreeSet<>(); - for (int i = 0; i < n; i++) { - SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); - for (SegmentsCostCache.TestX l : set.headSet(val)) { - l.setB(l.getB() + 1.0); - } - for (SegmentsCostCache.TestX u : set.tailSet(val)) { - u.setB(u.getB() + 3.0); - } - set.add(val); - } - double ans = 0; - for (SegmentsCostCache.TestX val : set) { - ans += val.getB(); - } - System.out.println(ans); - for (int i = n - 1; i >= 0; i -= 3) { - SegmentsCostCache.TestX val = new SegmentsCostCache.TestX(ids.get(i), vals.get(i)); - set.remove(val); - for (SegmentsCostCache.TestX l : set.headSet(val)) { - l.setB(l.getB() - 1.0); - } - for (SegmentsCostCache.TestX u : set.tailSet(val)) { - u.setB(u.getB() - 3.0); - } - } - ans = 0; - for (SegmentsCostCache.TestX val : set) { - ans += val.getB(); - } - System.out.println(ans); - end = System.currentTimeMillis(); - System.out.println(end - start + " ms"); - } - @Test public void randomSegmentsCostTest() { @@ -231,9 +160,15 @@ public void randomSegmentsCostTest() 
REFERENCE_TIME.minusHours(1), REFERENCE_TIME.plusHours(25) )); + + long start = System.currentTimeMillis(); + dataSegments.forEach(prototype::addSegment); SegmentsCostCache.Bucket bucket = prototype.build(); + long end = System.currentTimeMillis(); + System.out.println(end - start); + double cost = bucket.cost(referenceSegment); Assert.assertEquals(0.7065117101966677, cost, EPSILON); } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java new file mode 100644 index 000000000000..86f889e20e8b --- /dev/null +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.server.coordinator.cost; + +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.timeline.DataSegment; +import org.joda.time.DateTime; +import org.joda.time.Interval; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +public class SegmentsCostCacheV2Test +{ + + private static final String DATA_SOURCE = "dataSource"; + private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); + private static final double EPSILON = 0.00000001; + + @Test + public void segmentCacheTest() + { + SegmentsCostCacheV2.Builder cacheBuilder = SegmentsCostCacheV2.builder(); + cacheBuilder.addSegment(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100)); + SegmentsCostCacheV2 cache = cacheBuilder.build(); + Assert.assertEquals( + 7.8735899489011E-4, + cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100)), + EPSILON + ); + } + + @Test + public void notInCalculationIntervalCostTest() + { + SegmentsCostCacheV2.Builder cacheBuilder = SegmentsCostCacheV2.builder(); + cacheBuilder.addSegment( + createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100) + ); + SegmentsCostCacheV2 cache = cacheBuilder.build(); + Assert.assertEquals( + 0, + cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, (int) TimeUnit.DAYS.toHours(50)), 100)), + EPSILON + ); + } + + @Test + public void twoSegmentsCostTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100); + + SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(5), + REFERENCE_TIME.plusHours(5) + )); + + 
prototype.addSegment(segmentA); + SegmentsCostCacheV2.Bucket bucket = prototype.build(); + + double segmentCost = bucket.cost(segmentB); + Assert.assertEquals(7.8735899489011E-4, segmentCost, EPSILON); + } + + @Test + public void calculationIntervalTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentB = createSegment( + DATA_SOURCE, + shifted1HInterval(REFERENCE_TIME, (int) TimeUnit.DAYS.toHours(50)), + 100 + ); + + SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder( + new Interval(REFERENCE_TIME.minusHours(5), REFERENCE_TIME.plusHours(5)) + ); + prototype.addSegment(segmentA); + SegmentsCostCacheV2.Bucket bucket = prototype.build(); + + Assert.assertTrue(bucket.inCalculationInterval(segmentA)); + Assert.assertFalse(bucket.inCalculationInterval(segmentB)); + } + + @Test + public void sameSegmentCostTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + + SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(5), + REFERENCE_TIME.plusHours(5) + )); + + prototype.addSegment(segmentA); + SegmentsCostCacheV2.Bucket bucket = prototype.build(); + + double segmentCost = bucket.cost(segmentB); + Assert.assertEquals(8.26147353873985E-4, segmentCost, EPSILON); + } + + @Test + public void multipleSegmentsCostTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100); + DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentC = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 2), 100); + + SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(5), + 
REFERENCE_TIME.plusHours(5) + )); + + prototype.addSegment(segmentA); + prototype.addSegment(segmentC); + SegmentsCostCacheV2.Bucket bucket = prototype.build(); + + double segmentCost = bucket.cost(segmentB); + + Assert.assertEquals(0.001574717989780039, segmentCost, EPSILON); + } + + @Test + public void randomSegmentsCostTest() + { + List dataSegments = new ArrayList<>(1000); + Random random = new Random(1); + for (int i = 0; i < 1000; ++i) { + dataSegments.add(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, random.nextInt(20)), 100)); + } + + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + + SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(1), + REFERENCE_TIME.plusHours(25) + )); + + long start = System.currentTimeMillis(); + + dataSegments.forEach(prototype::addSegment); + SegmentsCostCacheV2.Bucket bucket = prototype.build(); + + long end = System.currentTimeMillis(); + System.out.println(end - start); + + double cost = bucket.cost(referenceSegment); + Assert.assertEquals(0.7065117101966677, cost, EPSILON); + } + + private static Interval shifted1HInterval(DateTime REFERENCE_TIME, int shiftInHours) + { + return new Interval( + REFERENCE_TIME.plusHours(shiftInHours), + REFERENCE_TIME.plusHours(shiftInHours + 1) + ); + } + + private static DataSegment createSegment(String dataSource, Interval interval, long size) + { + return new DataSegment( + dataSource, + interval, + UUID.randomUUID().toString(), + new ConcurrentHashMap<>(), + new ArrayList<>(), + new ArrayList<>(), + null, + 0, + size + ); + } +} diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java new file mode 100644 index 000000000000..6ac8dbd142e4 --- /dev/null +++ 
b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.coordinator.cost; + +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.timeline.DataSegment; +import org.joda.time.DateTime; +import org.joda.time.Interval; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +public class SegmentsCostCacheV3Test +{ + + private static final String DATA_SOURCE = "dataSource"; + private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); + private static final double EPSILON = 0.00000001; + + @Test + public void segmentCacheTest() + { + SegmentsCostCacheV3.Builder cacheBuilder = SegmentsCostCacheV3.builder(); + cacheBuilder.addSegment(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100)); + SegmentsCostCacheV3 cache = cacheBuilder.build(); + Assert.assertEquals( + 7.8735899489011E-4, + cache.cost(createSegment(DATA_SOURCE, 
shifted1HInterval(REFERENCE_TIME, -2), 100)), + EPSILON + ); + } + + @Test + public void notInCalculationIntervalCostTest() + { + SegmentsCostCacheV3.Builder cacheBuilder = SegmentsCostCacheV3.builder(); + cacheBuilder.addSegment( + createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100) + ); + SegmentsCostCacheV3 cache = cacheBuilder.build(); + Assert.assertEquals( + 0, + cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, (int) TimeUnit.DAYS.toHours(50)), 100)), + EPSILON + ); + } + + @Test + public void twoSegmentsCostTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100); + + SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(5), + REFERENCE_TIME.plusHours(5) + )); + + prototype.addSegment(segmentA); + SegmentsCostCacheV3.Bucket bucket = prototype.build(); + + double segmentCost = bucket.cost(segmentB); + Assert.assertEquals(7.8735899489011E-4, segmentCost, EPSILON); + } + + @Test + public void calculationIntervalTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentB = createSegment( + DATA_SOURCE, + shifted1HInterval(REFERENCE_TIME, (int) TimeUnit.DAYS.toHours(50)), + 100 + ); + + SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder( + new Interval(REFERENCE_TIME.minusHours(5), REFERENCE_TIME.plusHours(5)) + ); + prototype.addSegment(segmentA); + SegmentsCostCacheV3.Bucket bucket = prototype.build(); + + Assert.assertTrue(bucket.inCalculationInterval(segmentA)); + Assert.assertFalse(bucket.inCalculationInterval(segmentB)); + } + + @Test + public void sameSegmentCostTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentB = 
createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + + SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(5), + REFERENCE_TIME.plusHours(5) + )); + + prototype.addSegment(segmentA); + SegmentsCostCacheV3.Bucket bucket = prototype.build(); + + double segmentCost = bucket.cost(segmentB); + Assert.assertEquals(8.26147353873985E-4, segmentCost, EPSILON); + } + + @Test + public void multipleSegmentsCostTest() + { + DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100); + DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); + DataSegment segmentC = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 2), 100); + + SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(5), + REFERENCE_TIME.plusHours(5) + )); + + prototype.addSegment(segmentA); + prototype.addSegment(segmentC); + SegmentsCostCacheV3.Bucket bucket = prototype.build(); + + double segmentCost = bucket.cost(segmentB); + + Assert.assertEquals(0.001574717989780039, segmentCost, EPSILON); + } + + @Test + public void randomSegmentsCostTest() + { + List dataSegments = new ArrayList<>(1000); + Random random = new Random(1); + for (int i = 0; i < 1000; ++i) { + dataSegments.add(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, random.nextInt(20)), 100)); + } + + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + + SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(1), + REFERENCE_TIME.plusHours(25) + )); + + long start = System.currentTimeMillis(); + + dataSegments.forEach(prototype::addSegment); + SegmentsCostCacheV3.Bucket bucket = prototype.build(); + + long end = System.currentTimeMillis(); + System.out.println(end - 
start); + + double cost = bucket.cost(referenceSegment); + Assert.assertEquals(0.7065117101966677, cost, EPSILON); + } + + private static Interval shifted1HInterval(DateTime REFERENCE_TIME, int shiftInHours) + { + return new Interval( + REFERENCE_TIME.plusHours(shiftInHours), + REFERENCE_TIME.plusHours(shiftInHours + 1) + ); + } + + private static DataSegment createSegment(String dataSource, Interval interval, long size) + { + return new DataSegment( + dataSource, + interval, + UUID.randomUUID().toString(), + new ConcurrentHashMap<>(), + new ArrayList<>(), + new ArrayList<>(), + null, + 0, + size + ); + } +} From 7ae1a6a8e630ee805676d60674081f22e73fca39 Mon Sep 17 00:00:00 2001 From: Amatya Date: Sat, 16 Apr 2022 02:08:54 +0530 Subject: [PATCH 08/13] revert unnecessary changes --- .idea/misc.xml | 14 ++++---------- .../server/coordinator/cost/SegmentsCostCache.java | 2 +- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/.idea/misc.xml b/.idea/misc.xml index c6a0d654dbd4..bf2061d7392d 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -46,7 +46,7 @@ @@ -87,10 +84,7 @@ - + - - - \ No newline at end of file + diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 4b5432c8cd11..9271de28425b 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -431,7 +431,7 @@ public int compareTo(SegmentAndSum o) @Override public boolean equals(Object obj) { - throw new UnsupportedOperationException("Use IntervalAndSum.compareTo()"); + throw new UnsupportedOperationException("Use SegmentAndSum.compareTo()"); } @Override From 45dcc148347339ca858d8a69b832afed4a4e6741 Mon Sep 17 00:00:00 2001 From: Amatya Date: Sat, 16 Apr 2022 12:31:16 +0530 Subject: [PATCH 09/13] Make SegmentsCostCache equivalent to 
CostBalancerStrategy --- .../coordinator/cost/SegmentsCostCache.java | 16 +-- .../coordinator/cost/SegmentsCostCacheV3.java | 13 +- .../cost/SegmentsCostCacheTest.java | 123 +++++++++++++----- .../cost/SegmentsCostCacheV3Test.java | 123 +++++++++++++----- 4 files changed, 202 insertions(+), 73 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 9271de28425b..7af58348345c 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -93,9 +93,10 @@ public class SegmentsCostCache * The value of 1 day means that cost function of co-locating two segments which have 1 days between their intervals * is 0.5 of the cost, if the intervals are adjacent. If the distance is 2 days, then 0.25, etc. */ - private static final double HALF_LIFE_DAYS = 1.0; - private static final double LAMBDA = Math.log(2) / HALF_LIFE_DAYS; - private static final double MILLIS_FACTOR = TimeUnit.DAYS.toMillis(1) / LAMBDA; + private static final double HALF_LIFE_HOURS = 24.0; + private static final double LAMBDA = Math.log(2) / HALF_LIFE_HOURS; + private static final double NORMALIZATION_FACTOR = 1 / (LAMBDA * LAMBDA); + private static final double MILLIS_FACTOR = TimeUnit.HOURS.toMillis(1) / LAMBDA; /** * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" @@ -156,7 +157,7 @@ public double cost(DataSegment segment) cost += bucket.cost(segment); } - return cost; + return cost * NORMALIZATION_FACTOR; } public static Builder builder() @@ -262,17 +263,12 @@ private double addLeftCost(DataSegment dataSegment, double t0, double t1, int in double leftCost = 0.0; // add to cost all left-overlapping segments int leftIndex = index - 1; - while (leftIndex >= 0 - && 
sortedSegments.get(leftIndex).getInterval().overlaps(dataSegment.getInterval())) { + while (leftIndex >= 0) { double start = convertStart(sortedSegments.get(leftIndex), interval); double end = convertEnd(sortedSegments.get(leftIndex), interval); leftCost += CostBalancerStrategy.intervalCost(end - start, t0 - start, t1 - start); --leftIndex; } - // add left-non-overlapping segments - if (leftIndex >= 0) { - leftCost += leftSum[leftIndex] * (FastMath.exp(-t1) - FastMath.exp(-t0)); - } return leftCost; } diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java index 706e42011a68..406397eefab9 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java @@ -95,9 +95,10 @@ public class SegmentsCostCacheV3 * The value of 1 day means that cost function of co-locating two segments which have 1 days between their intervals * is 0.5 of the cost, if the intervals are adjacent. If the distance is 2 days, then 0.25, etc. 
*/ - private static final double HALF_LIFE_DAYS = 1.0; - private static final double LAMBDA = Math.log(2) / HALF_LIFE_DAYS; - private static final double MILLIS_FACTOR = TimeUnit.DAYS.toMillis(1) / LAMBDA; + private static final double HALF_LIFE_HOURS = 24.0; + private static final double LAMBDA = Math.log(2) / HALF_LIFE_HOURS; + private static final double NORMALIZATION_FACTOR = 1 / (LAMBDA * LAMBDA); + private static final double MILLIS_FACTOR = TimeUnit.HOURS.toMillis(1) / LAMBDA; /** * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" @@ -161,7 +162,7 @@ public double cost(DataSegment segment) cost += bucket.cost(segment); } - return cost; + return cost * NORMALIZATION_FACTOR; } public static Builder builder() @@ -291,7 +292,9 @@ boolean inCalculationInterval(DataSegment dataSegment) int index = Collections.binarySearch(intervalStartSortList, dataSegment.getInterval(), INTERVAL_START_COMPARATOR); index = (index >= 0) ? index : -index - 1; - return leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); + double normalizedCost = leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); + double denormalizingFactor = 1; + return normalizedCost * denormalizingFactor; } private double leftCost(DataSegment dataSegment, double t0, double t1, int index) diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java index 7fb331c69e1c..c8a65e241424 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java @@ -20,6 +20,7 @@ package org.apache.druid.server.coordinator.cost; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.server.coordinator.CostBalancerStrategy; import org.apache.druid.timeline.DataSegment; import 
org.joda.time.DateTime; import org.joda.time.Interval; @@ -36,23 +37,11 @@ public class SegmentsCostCacheTest { + private static final Random random = new Random(23894); private static final String DATA_SOURCE = "dataSource"; private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); private static final double EPSILON = 0.00000001; - @Test - public void segmentCacheTest() - { - SegmentsCostCache.Builder cacheBuilder = SegmentsCostCache.builder(); - cacheBuilder.addSegment(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100)); - SegmentsCostCache cache = cacheBuilder.build(); - Assert.assertEquals( - 7.8735899489011E-4, - cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100)), - EPSILON - ); - } - @Test public void notInCalculationIntervalCostTest() { @@ -146,38 +135,114 @@ public void multipleSegmentsCostTest() } @Test - public void randomSegmentsCostTest() + public void perfComparisonTest() { + final int N = 100000; + List dataSegments = new ArrayList<>(1000); - Random random = new Random(1); - for (int i = 0; i < 1000; ++i) { - dataSegments.add(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, random.nextInt(20)), 100)); + for (int i = 0; i < N; ++i) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 24 * random.nextInt(60)), 100)); } - DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shiftedRandomInterval(REFERENCE_TIME, 5), 100); - SegmentsCostCache.Bucket.Builder prototype = SegmentsCostCache.Bucket.builder(new Interval( - REFERENCE_TIME.minusHours(1), - REFERENCE_TIME.plusHours(25) - )); + SegmentsCostCache.Builder prototype = new SegmentsCostCache.Builder(); + + long start; + long end; - long start = System.currentTimeMillis(); + start = System.currentTimeMillis(); dataSegments.forEach(prototype::addSegment); - 
SegmentsCostCache.Bucket bucket = prototype.build(); + SegmentsCostCache cache = prototype.build(); + + end = System.currentTimeMillis(); + System.out.println("Insertion time for " + N + " segments: " + (end - start) + " ms"); - long end = System.currentTimeMillis(); - System.out.println(end - start); + start = System.currentTimeMillis(); + + double origCost = 0; + for (DataSegment segment : dataSegments) { + origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } - double cost = bucket.cost(referenceSegment); - Assert.assertEquals(0.7065117101966677, cost, EPSILON); + end = System.currentTimeMillis(); + System.out.println("Avg cost time: " + ((end - start) * 1000) + " us"); + + start = System.currentTimeMillis(); + + for (int i = 0; i < 1000; i++) { + cache.cost(referenceSegment); + } + + end = System.currentTimeMillis(); + System.out.println("Avg cache cost time: " + (end - start) + " us"); + + double cost = cache.cost(referenceSegment); + + System.out.println(origCost); + System.out.println(cost); + Assert.assertEquals(1, origCost / cost, EPSILON); + } + + @Test + public void correctnessTest() + { + List dataSegments = new ArrayList<>(); + + // Same as reference interval + for (int i = 0; i < 100; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + } + + // Overlapping intervals of larger size that enclose the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + } + + // intervals of small size that are enclosed within the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + } + + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + + SegmentsCostCache.Builder prototype 
= new SegmentsCostCache.Builder(); + + dataSegments.forEach(prototype::addSegment); + SegmentsCostCache cache = prototype.build(); + + double origCost = 0; + for (DataSegment segment : dataSegments) { + origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } + + double cost = cache.cost(referenceSegment); + + System.out.println("Actual cost : " + origCost); + System.out.println("Cached cost : " + cost); + Assert.assertEquals(1, origCost / cost, EPSILON); + } + + + private static Interval shiftedXHInterval(DateTime REFERENCE_TIME, int shiftInHours, int X) + { + return new Interval( + REFERENCE_TIME.plusHours(shiftInHours), + REFERENCE_TIME.plusHours(shiftInHours + X) + ); } private static Interval shifted1HInterval(DateTime REFERENCE_TIME, int shiftInHours) + { + return shiftedXHInterval(REFERENCE_TIME, shiftInHours, 1); + } + + private static Interval shiftedRandomInterval(DateTime REFERENCE_TIME, int shiftInHours) { return new Interval( REFERENCE_TIME.plusHours(shiftInHours), - REFERENCE_TIME.plusHours(shiftInHours + 1) + REFERENCE_TIME.plusHours(shiftInHours + random.nextInt(1000)) ); } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java index 6ac8dbd142e4..bf1584c6c6ee 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java @@ -20,6 +20,7 @@ package org.apache.druid.server.coordinator.cost; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.server.coordinator.CostBalancerStrategy; import org.apache.druid.timeline.DataSegment; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -36,23 +37,11 @@ public class SegmentsCostCacheV3Test { + private static final Random random = new Random(23894); private static 
final String DATA_SOURCE = "dataSource"; private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); private static final double EPSILON = 0.00000001; - @Test - public void segmentCacheTest() - { - SegmentsCostCacheV3.Builder cacheBuilder = SegmentsCostCacheV3.builder(); - cacheBuilder.addSegment(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100)); - SegmentsCostCacheV3 cache = cacheBuilder.build(); - Assert.assertEquals( - 7.8735899489011E-4, - cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100)), - EPSILON - ); - } - @Test public void notInCalculationIntervalCostTest() { @@ -146,38 +135,114 @@ public void multipleSegmentsCostTest() } @Test - public void randomSegmentsCostTest() + public void perfComparisonTest() { + final int N = 100000; + List dataSegments = new ArrayList<>(1000); - Random random = new Random(1); - for (int i = 0; i < 1000; ++i) { - dataSegments.add(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, random.nextInt(20)), 100)); + for (int i = 0; i < N; ++i) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 24 * random.nextInt(60)), 100)); } - DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shiftedRandomInterval(REFERENCE_TIME, 5), 100); - SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder(new Interval( - REFERENCE_TIME.minusHours(1), - REFERENCE_TIME.plusHours(25) - )); + SegmentsCostCacheV3.Builder prototype = new SegmentsCostCacheV3.Builder(); + + long start; + long end; - long start = System.currentTimeMillis(); + start = System.currentTimeMillis(); dataSegments.forEach(prototype::addSegment); - SegmentsCostCacheV3.Bucket bucket = prototype.build(); + SegmentsCostCacheV3 cache = prototype.build(); + + end = System.currentTimeMillis(); + System.out.println("Insertion time 
for " + N + " segments: " + (end - start) + " ms"); - long end = System.currentTimeMillis(); - System.out.println(end - start); + start = System.currentTimeMillis(); + + double origCost = 0; + for (DataSegment segment : dataSegments) { + origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } - double cost = bucket.cost(referenceSegment); - Assert.assertEquals(0.7065117101966677, cost, EPSILON); + end = System.currentTimeMillis(); + System.out.println("Avg cost time: " + ((end - start) * 1000) + " us"); + + start = System.currentTimeMillis(); + + for (int i = 0; i < 1000; i++) { + cache.cost(referenceSegment); + } + + end = System.currentTimeMillis(); + System.out.println("Avg cache cost time: " + (end - start) + " us"); + + double cost = cache.cost(referenceSegment); + + System.out.println(origCost); + System.out.println(cost); + Assert.assertEquals(1, origCost / cost, EPSILON); + } + + @Test + public void correctnessTest() + { + List dataSegments = new ArrayList<>(); + + // Same as reference interval + for (int i = 0; i < 100; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + } + + // Overlapping intervals of larger size that enclose the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + } + + // intervals of small size that are enclosed within the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + } + + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + + SegmentsCostCacheV3.Builder prototype = new SegmentsCostCacheV3.Builder(); + + dataSegments.forEach(prototype::addSegment); + SegmentsCostCacheV3 cache = prototype.build(); + + double origCost = 0; + for 
(DataSegment segment : dataSegments) { + origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } + + double cost = cache.cost(referenceSegment); + + System.out.println("Actual cost : " + origCost); + System.out.println("Cached cost : " + cost); + Assert.assertEquals(1, origCost / cost, EPSILON); + } + + + private static Interval shiftedXHInterval(DateTime REFERENCE_TIME, int shiftInHours, int X) + { + return new Interval( + REFERENCE_TIME.plusHours(shiftInHours), + REFERENCE_TIME.plusHours(shiftInHours + X) + ); } private static Interval shifted1HInterval(DateTime REFERENCE_TIME, int shiftInHours) + { + return shiftedXHInterval(REFERENCE_TIME, shiftInHours, 1); + } + + private static Interval shiftedRandomInterval(DateTime REFERENCE_TIME, int shiftInHours) { return new Interval( REFERENCE_TIME.plusHours(shiftInHours), - REFERENCE_TIME.plusHours(shiftInHours + 1) + REFERENCE_TIME.plusHours(shiftInHours + random.nextInt(100)) ); } From aad54e801a970d95813e85c9c8d4de7f5f716fdd Mon Sep 17 00:00:00 2001 From: Amatya Date: Sat, 16 Apr 2022 19:38:00 +0530 Subject: [PATCH 10/13] Faster equivalent implementation with tests --- .../coordinator/cost/SegmentsCostCache.java | 2 +- .../coordinator/cost/SegmentsCostCacheV3.java | 295 +++++++++--------- .../cost/SegmentsCostCacheTest.java | 69 +++- .../cost/SegmentsCostCacheV3Test.java | 69 +++- 4 files changed, 282 insertions(+), 153 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java index 7af58348345c..674844724a48 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCache.java @@ -95,7 +95,7 @@ public class SegmentsCostCache */ private static final double HALF_LIFE_HOURS = 24.0; private static final double LAMBDA = Math.log(2) 
/ HALF_LIFE_HOURS; - private static final double NORMALIZATION_FACTOR = 1 / (LAMBDA * LAMBDA); + static final double NORMALIZATION_FACTOR = 1 / (LAMBDA * LAMBDA); private static final double MILLIS_FACTOR = TimeUnit.HOURS.toMillis(1) / LAMBDA; /** diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java index 406397eefab9..df9396249861 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java @@ -24,6 +24,7 @@ import org.apache.commons.math3.util.FastMath; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.granularity.DurationGranularity; import org.apache.druid.java.util.common.guava.Comparators; import org.apache.druid.server.coordinator.CostBalancerStrategy; @@ -32,6 +33,7 @@ import org.joda.time.Interval; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.HashSet; @@ -97,7 +99,7 @@ public class SegmentsCostCacheV3 */ private static final double HALF_LIFE_HOURS = 24.0; private static final double LAMBDA = Math.log(2) / HALF_LIFE_HOURS; - private static final double NORMALIZATION_FACTOR = 1 / (LAMBDA * LAMBDA); + static final double NORMALIZATION_FACTOR = 1 / (LAMBDA * LAMBDA); private static final double MILLIS_FACTOR = TimeUnit.HOURS.toMillis(1) / LAMBDA; /** @@ -115,16 +117,9 @@ public class SegmentsCostCacheV3 private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); - private static final Comparator INTERVAL_START_COMPARATOR = Comparators.intervalsByStartThenEnd(); - 
- private static final Comparator INTERVAL_END_COMPARATOR = Comparators.intervalsByEndThenStart(); - private static final Comparator BUCKET_INTERVAL_COMPARATOR = Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); - private static final Ordering INTERVAL_START_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); - - private static final Ordering INTERVAL_END_ORDERING = Ordering.from(Comparators.intervalsByEndThenStart()); private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); private final ArrayList sortedBuckets; @@ -219,53 +214,67 @@ static class Bucket { private final Interval interval; private final Interval calculationInterval; - private final List intervalStartSortList; - private final List intervalEndSortList; - private final double[] cumStart; - private final double[] cumStartExp; - private final double[] cumStartExpInv; - private final double[] cumEnd; - private final double[] cumEndExp; - private final double[] cumEndExpInv; + private final long START; + private final long END; + private final double END_VAL; + private final double END_EXP; + private final double END_EXP_INV; + + private final long[] start; + private final long[] end; + + private final double[] startValSum; + private final double[] startExpSum; + private final double[] startExpInvSum; + + private final double[] endValSum; + private final double[] endExpSum; + private final double[] endExpInvSum; - Bucket(Interval interval, List intervalStartSortList, List intervalEndSortList) + Bucket(Interval interval, List> intervals) { this.interval = Preconditions.checkNotNull(interval, "interval"); - this.intervalStartSortList = Preconditions.checkNotNull(intervalStartSortList, "intervalStartSortList"); - this.intervalEndSortList = Preconditions.checkNotNull(intervalEndSortList, "intervalEndSortList"); - Preconditions.checkArgument(intervalStartSortList.size() == intervalEndSortList.size()); - 
Preconditions.checkArgument(INTERVAL_START_ORDERING.isOrdered(intervalStartSortList)); - Preconditions.checkArgument(INTERVAL_END_ORDERING.isOrdered(intervalEndSortList)); + this.calculationInterval = new Interval( interval.getStart().minus(LIFE_THRESHOLD), interval.getEnd().plus(LIFE_THRESHOLD) ); - int n = intervalStartSortList.size(); - double exp; - double expInv; + START = interval.getStartMillis(); + END = interval.getEndMillis(); + END_VAL = getVal(END); + END_EXP = FastMath.exp(END_VAL); + END_EXP_INV = FastMath.exp(-END_VAL); + + int n = intervals.size(); + start = new long[n]; + end = new long[n]; + for (int i = 0; i < n; i++) { + start[i] = intervals.get(i).lhs; + end[i] = intervals.get(i).rhs; + } + Arrays.sort(start); + Arrays.sort(end); + + startValSum = new double[n + 1]; + startExpSum = new double[n + 1]; + startExpInvSum = new double[n + 1]; + for (int i = 0; i < n; i++) { + double val = getVal(start[i]); + startValSum[i + 1] = startValSum[i] + val; + startExpSum[i + 1] = startExpSum[i] + FastMath.exp(val); + startExpInvSum[i + 1] = startExpInvSum[i] + FastMath.exp(-val); + } - cumStart = new double[n + 1]; - cumStartExp = new double[n + 1]; - cumStartExpInv = new double[n + 1]; - cumEnd = new double[n + 1]; - cumEndExp = new double[n + 1]; - cumEndExpInv = new double[n + 1]; + endValSum = new double[n + 1]; + endExpSum = new double[n + 1]; + endExpInvSum = new double[n + 1]; for (int i = 0; i < n; i++) { - double start = convertStart(intervalStartSortList.get(i), interval); - exp = FastMath.exp(start); - expInv = FastMath.exp(-start); - cumStart[i + 1] = cumStart[i] + start; - cumStartExp[i + 1] = cumStartExp[i] + exp; - cumStartExpInv[i + 1] = cumStartExpInv[i] + expInv; - - double end = convertEnd(intervalEndSortList.get(i), interval); - exp = FastMath.exp(end); - expInv = FastMath.exp(-end); - cumEnd[i + 1] = cumEnd[i] + end; - cumEndExp[i + 1] = cumEndExp[i] + exp; - cumEndExpInv[i + 1] = cumEndExpInv[i] + expInv; + double val = 
getVal(end[i]); + endValSum[i + 1] = endValSum[i] + val; + endExpSum[i + 1] = endExpSum[i] + FastMath.exp(val); + endExpInvSum[i + 1] = endExpInvSum[i] + FastMath.exp(-val); } } @@ -281,118 +290,127 @@ boolean inCalculationInterval(DataSegment dataSegment) double cost(DataSegment dataSegment) { - // cost is calculated relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); - // avoid calculation for segments outside of LIFE_THRESHOLD if (!inCalculationInterval(dataSegment)) { throw new ISE("Segment is not within calculation interval"); } - int index = Collections.binarySearch(intervalStartSortList, dataSegment.getInterval(), INTERVAL_START_COMPARATOR); - index = (index >= 0) ? index : -index - 1; - double normalizedCost = leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); - double denormalizingFactor = 1; - return normalizedCost * denormalizingFactor; + + + long x = dataSegment.getInterval().getStartMillis(); + long y = dataSegment.getInterval().getEndMillis(); + double cost = 0; + cost += solve(x, y, start, startValSum, startExpSum, startExpInvSum); + cost -= solve(x, y, end, endValSum, endExpSum, endExpInvSum); + return cost; } - private double leftCost(DataSegment dataSegment, double t0, double t1, int index) + // Sum of cost (, ) for all val in vals + private double solve(long x, long y, long[] vals, double[] sum, double[] expSum, double[] expInvSum) { - if (index - 1 < 0) { - return 0; - } - double exp0 = FastMath.exp(t0); - double expInv0 = 1 / exp0; - double exp1 = FastMath.exp(t1); - double expInv1 = 1 / exp1; - double leftCost = 0.0; - // add to cost all left-overlapping segments - int rightBound = index - 1; - int leftBound = leftBoundary(0, index - 1, dataSegment.getInterval(), intervalStartSortList); - leftCost += 2 * (cumEnd[rightBound + 1] - cumEnd[leftBound]); - leftCost -= 2 * (rightBound - 
leftBound + 1) * t0; - leftCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - leftCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - leftCost -= expInv0 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - leftCost -= expInv1 * (cumEndExp[rightBound + 1] - cumEndExp[leftBound]); - // add left-non-overlapping segments - if (leftBound > 0) { - leftCost += cumStartExp[leftBound] * (expInv1 - expInv0); - leftCost -= cumEndExp[leftBound] * (expInv1 - expInv0); + + int n = vals.length; + + double xVal = getVal(x); + double xExp = FastMath.exp(xVal); + double xExpInv = FastMath.exp(-xVal); + + double yVal = getVal(y); + double yExp = FastMath.exp(yVal); + double yExpInv = FastMath.exp(-yVal); + + double cost = 0; + + if (END < x) { + + // val , END , x , y + cost += expSum[n] * yExpInv; + cost -= expSum[n] * xExpInv; + cost += n * END_EXP * xExpInv; + cost -= n * END_EXP * yExpInv; + + } else if (END > y) { + + int l = lowerBound(0, n - 1, x, vals); + int r = upperBound(0, n - 1, y, vals); + + // val < j , y , E + cost += 2 * (l + 1) * yVal; + cost -= 2 * (l + 1) * xVal; + cost += expSum[l + 1] * yExpInv; + cost -= expSum[l + 1] * xExpInv; + cost += (l + 1) * xExp * END_EXP_INV; + cost -= (l + 1) * yExp * END_EXP_INV; + + // x <= val <= y , E + cost += 2 * (r - l - 1) * yVal; + cost -= 2 * (sum[r] - sum[l + 1]); + cost += (r - l - 1) * xExp * END_EXP_INV; + cost -= xExp * (expInvSum[r] - expInvSum[l + 1]); + cost -= (r - l - 1) * yExp * END_EXP_INV; + cost += (expSum[r] - expSum[l + 1]) * yExpInv; + + // x , y < val , E + cost += (n - r) * xExp * END_EXP_INV; + cost -= xExp * (expInvSum[n] - expInvSum[r]); + cost -= (n - r) * yExp * END_EXP_INV; + cost += yExp * (expInvSum[n] - expInvSum[r]); + + } else { + + int l = lowerBound(0, n - 1, x, vals); + + // val < x , END , y + cost += 2 * (l + 1) * END_VAL; + cost -= 2 * (l + 1) * xVal; + cost += expSum[l + 1] * yExpInv; + cost -= expSum[l + 1] * xExpInv; + cost -= (l + 
1) * END_EXP * yExpInv; + cost += (l + 1) * xExp * END_EXP_INV; + + // x <= val , END , y + cost += 2 * (n - l - 1) * END_VAL; + cost -= 2 * (sum[n] - sum[l - 1]); + cost += (n - l + 1) * xExp * END_EXP_INV; + cost -= xExp * (expInvSum[n] - expInvSum[l - 1]); + cost += (expSum[n] - expSum[l - 1]) * yExpInv; + cost -= (n - l + 1) * END_EXP * yExpInv; } - return leftCost; + + return cost; } - private double rightCost(DataSegment dataSegment, double t0, double t1, int index) + private double getVal(long millis) { - int n = intervalStartSortList.size(); - if (index >= n) { - return 0; - } - double exp0 = FastMath.exp(t0); - double exp1 = FastMath.exp(t1); - double expInv1 = 1 / exp1; - double rightCost = 0.0; - int leftBound = index; - int rightBound = rightBoundary(index, n - 1, dataSegment.getInterval(), intervalStartSortList); - // add all right-overlapping segments - rightCost += 2 * (rightBound - leftBound + 1) * t1; - rightCost -= 2 * (cumStart[rightBound + 1] - cumStart[leftBound]); - rightCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - rightCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - rightCost -= exp0 * (cumStartExpInv[rightBound + 1] - cumStartExpInv[leftBound]); - rightCost -= exp1 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - // add right-non-overlapping segments - rightBound++; - if (rightBound <= n) { - rightCost += (cumEndExpInv[n] - cumEndExpInv[rightBound]) * (exp0 - exp1); - rightCost -= (cumStartExpInv[n] - cumStartExpInv[rightBound]) * (exp0 - exp1); - } - return rightCost; + return millis / MILLIS_FACTOR - START / MILLIS_FACTOR; } - private int leftBoundary(int l, int r, Interval interval, List intervalList) + private int lowerBound(int l, int r, long x, long[] a) { if (l == r) { - return interval.overlaps(intervalList.get(l)) ? l : r + 1; + return a[l] < x ? 
r : l - 1; } - int m = (l + r) / 2; - if (interval.overlaps(intervalList.get(m))) { - return leftBoundary(l, m, interval, intervalList); + int m = (l + r + 1) / 2; + if (a[m] < x) { + return lowerBound(m, r, x, a); } else { - return leftBoundary(m + 1, r, interval, intervalList); + return lowerBound(l, m - 1, x, a); } } - private int rightBoundary(int l, int r, Interval interval, List intervalList) + private int upperBound(int l, int r, long x, long[] a) { if (l == r) { - return interval.overlaps(intervalList.get(r)) ? r : l - 1; + return a[r] > x ? l : r + 1; } - int m = (l + r + 1) / 2; - if (interval.overlaps(intervalList.get(m))) { - return rightBoundary(m, r, interval, intervalList); + int m = (l + r) / 2; + if (a[m] > x) { + return upperBound(l, m, x, a); } else { - return rightBoundary(l, m - 1, interval, intervalList); + return upperBound(m + 1, r, x, a); } } - private static double convertStart(Interval interval, Interval reference) - { - return toLocalInterval(interval.getStartMillis(), reference); - } - - private static double convertEnd(Interval interval, Interval reference) - { - return toLocalInterval(interval.getEndMillis(), reference); - } - - private static double toLocalInterval(long millis, Interval interval) - { - return millis / MILLIS_FACTOR - interval.getStartMillis() / MILLIS_FACTOR; - } - public static Builder builder(Interval interval) { return new Builder(interval); @@ -434,24 +452,19 @@ public boolean isEmpty() public Bucket build() { - List intervalsStartSortList = segmentSet.stream() - .map(SegmentId::getInterval) - .sorted(INTERVAL_START_COMPARATOR) - .collect(Collectors.toList()); + long bucketEndMillis = interval.getEndMillis(); - List intervalsEndSortList = segmentSet.stream() - .map(SegmentId::getInterval) - .sorted(INTERVAL_END_COMPARATOR) - .collect(Collectors.toList()); + List> intervals = new ArrayList<>(); - long bucketEndMillis = intervalsEndSortList.get(intervalsEndSortList.size() - 1).getEndMillis(); - bucketEndMillis = 
Long.max(bucketEndMillis, interval.getEndMillis()); + for (SegmentId segment : segmentSet) { + Interval i = segment.getInterval(); + intervals.add(Pair.of(i.getStartMillis(), i.getEndMillis())); + bucketEndMillis = Math.max(bucketEndMillis, i.getEndMillis()); + } segmentSet.clear(); - return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), - intervalsStartSortList, - intervalsEndSortList); + return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), intervals); } } } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java index c8a65e241424..6396425782ba 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java @@ -34,13 +34,15 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import static org.apache.druid.server.coordinator.cost.SegmentsCostCache.NORMALIZATION_FACTOR; + public class SegmentsCostCacheTest { private static final Random random = new Random(23894); private static final String DATA_SOURCE = "dataSource"; private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); - private static final double EPSILON = 0.00000001; + private static final double EPSILON = 0.0000001; @Test public void notInCalculationIntervalCostTest() @@ -180,13 +182,60 @@ public void perfComparisonTest() double cost = cache.cost(referenceSegment); - System.out.println(origCost); - System.out.println(cost); Assert.assertEquals(1, origCost / cost, EPSILON); } @Test - public void correctnessTest() + public void bucketCorrectnessTest() + { + List dataSegments = new ArrayList<>(); + + // Same as reference interval + for (int i = 0; i < 100; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 
random.nextInt(20), 10), 100)); + } + + // Overlapping intervals of larger size that enclose the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + } + + // intervals of small size that are enclosed within the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + } + + // intervals not intersecting, lying to its left + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, -90), 100)); + } + + // intervals not intersecting, lying to its right + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 60), 100)); + } + + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + + SegmentsCostCache.Bucket.Builder prototype = SegmentsCostCache.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(90), REFERENCE_TIME.plusHours(90) + )); + + dataSegments.forEach(prototype::addSegment); + SegmentsCostCache.Bucket bucket = prototype.build(); + + double origCost = 0; + for (DataSegment segment : dataSegments) { + origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } + + double cost = bucket.cost(referenceSegment); + + Assert.assertEquals(NORMALIZATION_FACTOR, origCost / cost, EPSILON); + } + + @Test + public void overallCorrectnessTest() { List dataSegments = new ArrayList<>(); @@ -205,6 +254,16 @@ public void correctnessTest() dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); } + // intervals not intersecting, lying to its left + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, -90), 100)); + } + + // 
intervals not intersecting, lying to its right + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 60), 100)); + } + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); SegmentsCostCache.Builder prototype = new SegmentsCostCache.Builder(); @@ -219,8 +278,6 @@ public void correctnessTest() double cost = cache.cost(referenceSegment); - System.out.println("Actual cost : " + origCost); - System.out.println("Cached cost : " + cost); Assert.assertEquals(1, origCost / cost, EPSILON); } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java index bf1584c6c6ee..6c81d3b75f66 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java @@ -34,13 +34,15 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import static org.apache.druid.server.coordinator.cost.SegmentsCostCacheV3.NORMALIZATION_FACTOR; + public class SegmentsCostCacheV3Test { private static final Random random = new Random(23894); private static final String DATA_SOURCE = "dataSource"; private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); - private static final double EPSILON = 0.00000001; + private static final double EPSILON = 0.0000001; @Test public void notInCalculationIntervalCostTest() @@ -176,7 +178,7 @@ public void perfComparisonTest() } end = System.currentTimeMillis(); - System.out.println("Avg cache cost time: " + (end - start) + " us"); + System.out.println("Avg new cache cost time: " + (end - start) + " us"); double cost = cache.cost(referenceSegment); @@ -186,7 +188,56 @@ public void perfComparisonTest() } @Test - public void 
correctnessTest() + public void bucketCorrectnessTest() + { + List dataSegments = new ArrayList<>(); + + // Same as reference interval + for (int i = 0; i < 100; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + } + + // Overlapping intervals of larger size that enclose the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + } + + // intervals of small size that are enclosed within the reference interval + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + } + + // intervals not intersecting, lying to its left + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, -90), 100)); + } + + // intervals not intersecting, lying to its right + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 60), 100)); + } + + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); + + SegmentsCostCacheV3.Bucket.Builder prototype = SegmentsCostCacheV3.Bucket.builder(new Interval( + REFERENCE_TIME.minusHours(90), REFERENCE_TIME.plusHours(90) + )); + + dataSegments.forEach(prototype::addSegment); + SegmentsCostCacheV3.Bucket bucket = prototype.build(); + + double origCost = 0; + for (DataSegment segment : dataSegments) { + origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } + + double cost = bucket.cost(referenceSegment); + + Assert.assertEquals(NORMALIZATION_FACTOR, origCost / cost, EPSILON); + } + + @Test + public void overallCorrectnessTest() { List dataSegments = new ArrayList<>(); @@ -205,6 +256,16 @@ public void correctnessTest() dataSegments.add(createSegment(DATA_SOURCE, 
shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); } + // intervals not intersecting, lying to its left + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, -90), 100)); + } + + // intervals not intersecting, lying to its right + for (int i = 0; i < 10; i++) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 60), 100)); + } + DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); SegmentsCostCacheV3.Builder prototype = new SegmentsCostCacheV3.Builder(); @@ -219,8 +280,6 @@ public void correctnessTest() double cost = cache.cost(referenceSegment); - System.out.println("Actual cost : " + origCost); - System.out.println("Cached cost : " + cost); Assert.assertEquals(1, origCost / cost, EPSILON); } From 26a898a602007dff9fa78bcd88bb43549105be7f Mon Sep 17 00:00:00 2001 From: Amatya Date: Sun, 17 Apr 2022 18:01:46 +0530 Subject: [PATCH 11/13] Handle all, adhoc granularities without overflow --- .../coordinator/cost/SegmentsCostCacheV3.java | 186 ++++++++++++----- .../cost/SegmentsCostCacheV3Test.java | 192 +++++++++++++++--- 2 files changed, 304 insertions(+), 74 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java index df9396249861..5556c45c602c 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java @@ -24,6 +24,7 @@ import org.apache.commons.math3.util.FastMath; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.JodaUtils; import org.apache.druid.java.util.common.Pair; import 
org.apache.druid.java.util.common.granularity.DurationGranularity; import org.apache.druid.java.util.common.guava.Comparators; @@ -82,18 +83,20 @@ * S_x1 S_x2 S_x3 S_x4 S_x5 S_x6 S_x7 S_x8 S_x9 * bucket1 bucket2 bucket3 * + * * Reasons to store segments in Buckets: * - * 1) Cost function tends to 0 as distance between segments' intervals increases; buckets - * are used to avoid redundant 0 calculations for thousands of times - * 2) To reduce number of calculations when segment is added or removed from SegmentsCostCache - * 3) To avoid infinite values during exponents calculations + * Unlike SegmentsCostCache updates are fast, and we can make do without buckets ideally. + * Unfortunately, large values for (time - bucketStart) cause overflows + * A threshold (say 1Y) is used for bucketing, and this allows us to compute cost within O(logN) for a bucket * + * If the interval duration exeeds it, we have to use CostBalancerStrategy#intervalCost over all the intervals + * This scenario has a complexity of O(M) where M is the number of "adhoc" buckets. */ public class SegmentsCostCacheV3 { /** - * HALF_LIFE_DAYS defines how fast joint cost function tends to 0 as distance between segments' intervals increasing. + * HALF_LIFE_HOURS defines how fast joint cost function tends to 0 as distance between segments' intervals increasing. * The value of 1 day means that cost function of co-locating two segments which have 1 days between their intervals * is 0.5 of the cost, if the intervals are adjacent. If the distance is 2 days, then 0.25, etc. */ @@ -104,17 +107,14 @@ public class SegmentsCostCacheV3 /** * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" - * from each other and thus cost(Y,Y) ~ 0 for these segments + * from each other and thus cost ~ 0 for these segments */ private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30); - /** - * Bucket interval defines duration granularity for segment buckets. 
Number of buckets control the trade-off - * between updates (add/remove segment operation) and joint cost calculation: - * 1) updates complexity is increasing when number of buckets is decreasing (as buckets contain more segments) - * 2) joint cost calculation complexity is increasing with increasing of buckets number - */ - private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); + // The max interval that can be added to a bucket + private static final long INTERVAL_THRESHOLD = TimeUnit.DAYS.toMillis(366); + // exp(BUCKET_INTERVAL + INTERVAL_THRESHOLD + 2 * LIFE_THRESHOLD) must be within limits + private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(90); private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); private static final Comparator BUCKET_INTERVAL_COMPARATOR = @@ -124,19 +124,32 @@ public class SegmentsCostCacheV3 private final ArrayList sortedBuckets; private final ArrayList intervals; + private final ArrayList> adhocNormalizedIntervals; - SegmentsCostCacheV3(ArrayList sortedBuckets) + private final int allGranularitySegmentCount; + private double allGranularitySegmentCost = -1; + + SegmentsCostCacheV3(ArrayList sortedBuckets, + ArrayList> adhocNormalizedIntervals, + int allGranularitySegmentCount) { this.sortedBuckets = Preconditions.checkNotNull(sortedBuckets, "buckets should not be null"); - this.intervals = sortedBuckets.stream().map(Bucket::getInterval).collect(Collectors.toCollection(ArrayList::new)); Preconditions.checkArgument( BUCKET_ORDERING.isOrdered(sortedBuckets), "buckets must be ordered by interval" ); + this.intervals = sortedBuckets.stream().map(Bucket::getInterval).collect(Collectors.toCollection(ArrayList::new)); + this.adhocNormalizedIntervals = Preconditions.checkNotNull(adhocNormalizedIntervals, "adhocIntervals should not be null"); + this.allGranularitySegmentCount = allGranularitySegmentCount; } public double cost(DataSegment segment) { + boolean 
allGranularity = isAllGranularity(segment); + if (allGranularity && allGranularitySegmentCost >= 0) { + return allGranularitySegmentCost; + } + double cost = 0.0; int index = Collections.binarySearch(intervals, segment.getInterval(), Comparators.intervalsByStartThenEnd()); index = (index >= 0) ? index : -index - 1; @@ -146,6 +159,7 @@ public double cost(DataSegment segment) if (!bucket.inCalculationInterval(segment)) { break; } + // O(logN) -> N segments per bucket cost += bucket.cost(segment); } @@ -154,10 +168,31 @@ public double cost(DataSegment segment) if (!bucket.inCalculationInterval(segment)) { break; } + // O(logN) -> N segments per bucket cost += bucket.cost(segment); } - return cost * NORMALIZATION_FACTOR; + double start = segment.getInterval().getStartMillis() / MILLIS_FACTOR; + double end = segment.getInterval().getEndMillis() / MILLIS_FACTOR; + + // O(1) -> for ALL granularity segments + double allStart = JodaUtils.MIN_INSTANT / MILLIS_FACTOR; + double allEnd = JodaUtils.MAX_INSTANT / MILLIS_FACTOR; + cost += allGranularitySegmentCount * CostBalancerStrategy.intervalCost(allEnd - allStart, start - allStart, end - allStart); + + // O(M) -> M adhoc buckets + for (Pair adhoc : adhocNormalizedIntervals) { + cost += CostBalancerStrategy.intervalCost(adhoc.rhs - adhoc.lhs, start - adhoc.lhs, end - adhoc.lhs); + } + + cost *= NORMALIZATION_FACTOR; + + // store cost for all granularity adhoc bucket for faster computation + if (allGranularity) { + allGranularitySegmentCost = cost; + } + + return cost; } public static Builder builder() @@ -165,42 +200,84 @@ public static Builder builder() return new Builder(); } + private static boolean isAllGranularity(DataSegment segment) { + return segment.getInterval().getStartMillis() == JodaUtils.MIN_INSTANT + && segment.getInterval().getEndMillis() == JodaUtils.MAX_INSTANT; + } + + public static class Builder { private final NavigableMap buckets = new TreeMap<>(Comparators.intervalsByStartThenEnd()); + private final 
HashSet allGranularitySegments = new HashSet<>(); + private final HashSet adhocSegments = new HashSet<>(); + public Builder addSegment(DataSegment segment) { - Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder); - builder.addSegment(segment); + if (isAllGranularity(segment)) { + if (!allGranularitySegments.add(segment.getId())) { + throw new ISE("expect new segment"); + } + } + else if (isAdhoc(segment)) { + if (!adhocSegments.add(segment.getId())) { + throw new ISE("expect new segment"); + } + } + else { + Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder); + builder.addSegment(segment); + } return this; } public Builder removeSegment(DataSegment segment) { - Interval interval = getBucketInterval(segment); - buckets.computeIfPresent( - interval, - // If there are no move segments, returning null in computeIfPresent() removes the interval from the buckets - // map - (i, builder) -> builder.removeSegment(segment).isEmpty() ? null : builder - ); + if (isAllGranularity(segment)) { + allGranularitySegments.remove(segment.getId()); + } + if (isAdhoc(segment)) { + adhocSegments.remove(segment.getId()); + } + else { + Interval interval = getBucketInterval(segment); + buckets.computeIfPresent( + interval, + // If there are no move segments, returning null in computeIfPresent() removes the interval from the buckets + // map + (i, builder) -> builder.removeSegment(segment).isEmpty() ? 
null : builder + ); + } return this; } public boolean isEmpty() { - return buckets.isEmpty(); + return buckets.isEmpty() && allGranularitySegments.isEmpty() && adhocSegments.isEmpty(); } public SegmentsCostCacheV3 build() { + final int allGranularitySegmentCount = allGranularitySegments.size(); + allGranularitySegments.clear(); + + final ArrayList> adhocNormalizedIntervals = new ArrayList<>(); + for (SegmentId segment : adhocSegments) { + double normalizedStart = segment.getInterval().getStartMillis() / MILLIS_FACTOR; + double normalizedEnd = segment.getInterval().getEndMillis() / MILLIS_FACTOR; + adhocNormalizedIntervals.add(Pair.of(normalizedStart, normalizedEnd)); + } + adhocSegments.clear(); + return new SegmentsCostCacheV3( buckets .values() .stream() .map(Bucket.Builder::build) - .collect(Collectors.toCollection(ArrayList::new)) + .collect(Collectors.toCollection(ArrayList::new)), + adhocNormalizedIntervals, + allGranularitySegmentCount ); } @@ -208,6 +285,12 @@ private static Interval getBucketInterval(DataSegment segment) { return BUCKET_GRANULARITY.bucket(segment.getInterval().getStart()); } + + private boolean isAdhoc(DataSegment segment) { + double duration = segment.getInterval().getEndMillis() / MILLIS_FACTOR + - segment.getInterval().getStartMillis() / MILLIS_FACTOR; + return duration > INTERVAL_THRESHOLD / MILLIS_FACTOR; + } } static class Bucket @@ -241,12 +324,6 @@ static class Bucket interval.getEnd().plus(LIFE_THRESHOLD) ); - START = interval.getStartMillis(); - END = interval.getEndMillis(); - END_VAL = getVal(END); - END_EXP = FastMath.exp(END_VAL); - END_EXP_INV = FastMath.exp(-END_VAL); - int n = intervals.size(); start = new long[n]; end = new long[n]; @@ -257,24 +334,32 @@ static class Bucket Arrays.sort(start); Arrays.sort(end); + START = Math.max(interval.getStartMillis(), start[0]); + END = Math.min(interval.getEndMillis(), end[n - 1]); + + END_VAL = getVal(END); + END_EXP = FastMath.exp(END_VAL); + END_EXP_INV = FastMath.exp(-END_VAL); 
+ + startValSum = new double[n + 1]; startExpSum = new double[n + 1]; startExpInvSum = new double[n + 1]; for (int i = 0; i < n; i++) { - double val = getVal(start[i]); - startValSum[i + 1] = startValSum[i] + val; - startExpSum[i + 1] = startExpSum[i] + FastMath.exp(val); - startExpInvSum[i + 1] = startExpInvSum[i] + FastMath.exp(-val); + double startVal = getVal(start[i]); + startValSum[i + 1] = startValSum[i] + startVal; + startExpSum[i + 1] = startExpSum[i] + FastMath.exp(startVal); + startExpInvSum[i + 1] = startExpInvSum[i] + FastMath.exp(-startVal); } endValSum = new double[n + 1]; endExpSum = new double[n + 1]; endExpInvSum = new double[n + 1]; for (int i = 0; i < n; i++) { - double val = getVal(end[i]); - endValSum[i + 1] = endValSum[i] + val; - endExpSum[i + 1] = endExpSum[i] + FastMath.exp(val); - endExpInvSum[i + 1] = endExpInvSum[i] + FastMath.exp(-val); + double endVal = getVal(end[i]); + endValSum[i + 1] = endValSum[i] + endVal; + endExpSum[i + 1] = endExpSum[i] + FastMath.exp(endVal); + endExpInvSum[i + 1] = endExpInvSum[i] + FastMath.exp(-endVal); } } @@ -295,13 +380,14 @@ boolean inCalculationInterval(DataSegment dataSegment) throw new ISE("Segment is not within calculation interval"); } - - - long x = dataSegment.getInterval().getStartMillis(); - long y = dataSegment.getInterval().getEndMillis(); + // The following bounds help avoid overflow. 
The cost beyond LIFE_THRESHOLD is insignificant anyway + long x = Math.max(dataSegment.getInterval().getStartMillis(), START - LIFE_THRESHOLD); + long y = Math.min(dataSegment.getInterval().getEndMillis(), END + LIFE_THRESHOLD); double cost = 0; + cost += solve(x, y, start, startValSum, startExpSum, startExpInvSum); cost -= solve(x, y, end, endValSum, endExpSum, endExpInvSum); + return cost; } @@ -370,11 +456,11 @@ private double solve(long x, long y, long[] vals, double[] sum, double[] expSum, // x <= val , END , y cost += 2 * (n - l - 1) * END_VAL; - cost -= 2 * (sum[n] - sum[l - 1]); - cost += (n - l + 1) * xExp * END_EXP_INV; - cost -= xExp * (expInvSum[n] - expInvSum[l - 1]); - cost += (expSum[n] - expSum[l - 1]) * yExpInv; - cost -= (n - l + 1) * END_EXP * yExpInv; + cost -= 2 * (sum[n] - sum[l + 1]); + cost += (n - l - 1) * xExp * END_EXP_INV; + cost -= xExp * (expInvSum[n] - expInvSum[l + 1]); + cost += (expSum[n] - expSum[l + 1]) * yExpInv; + cost -= (n - l - 1) * END_EXP * yExpInv; } return cost; diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java index 6c81d3b75f66..0555e46c3628 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java @@ -20,6 +20,7 @@ package org.apache.druid.server.coordinator.cost; import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.JodaUtils; import org.apache.druid.server.coordinator.CostBalancerStrategy; import org.apache.druid.timeline.DataSegment; import org.joda.time.DateTime; @@ -33,6 +34,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static 
org.apache.druid.server.coordinator.cost.SegmentsCostCacheV3.NORMALIZATION_FACTOR; @@ -41,7 +43,7 @@ public class SegmentsCostCacheV3Test private static final Random random = new Random(23894); private static final String DATA_SOURCE = "dataSource"; - private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); + private static final DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); private static final double EPSILON = 0.0000001; @Test @@ -154,37 +156,29 @@ public void perfComparisonTest() long end; start = System.currentTimeMillis(); - dataSegments.forEach(prototype::addSegment); SegmentsCostCacheV3 cache = prototype.build(); - end = System.currentTimeMillis(); System.out.println("Insertion time for " + N + " segments: " + (end - start) + " ms"); start = System.currentTimeMillis(); - - double origCost = 0; - for (DataSegment segment : dataSegments) { - origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + for (int i = 0; i < 1000; i++) { + getExpectedCost(dataSegments, referenceSegment); } - end = System.currentTimeMillis(); - System.out.println("Avg cost time: " + ((end - start) * 1000) + " us"); + System.out.println("Avg cost time: " + (end - start) + " us"); start = System.currentTimeMillis(); - for (int i = 0; i < 1000; i++) { cache.cost(referenceSegment); } - end = System.currentTimeMillis(); System.out.println("Avg new cache cost time: " + (end - start) + " us"); + double expectedCost = getExpectedCost(dataSegments, referenceSegment); double cost = cache.cost(referenceSegment); - System.out.println(origCost); - System.out.println(cost); - Assert.assertEquals(1, origCost / cost, EPSILON); + Assert.assertEquals(1, expectedCost / cost, EPSILON); } @Test @@ -226,14 +220,11 @@ public void bucketCorrectnessTest() dataSegments.forEach(prototype::addSegment); SegmentsCostCacheV3.Bucket bucket = prototype.build(); - double origCost = 0; - for (DataSegment segment : dataSegments) { - origCost += 
CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); - } + double expectedCost = getExpectedCost(dataSegments, referenceSegment); double cost = bucket.cost(referenceSegment); - Assert.assertEquals(NORMALIZATION_FACTOR, origCost / cost, EPSILON); + Assert.assertEquals(NORMALIZATION_FACTOR, expectedCost / cost, EPSILON); } @Test @@ -270,19 +261,172 @@ public void overallCorrectnessTest() SegmentsCostCacheV3.Builder prototype = new SegmentsCostCacheV3.Builder(); + double expectedCost = getExpectedCost(dataSegments, referenceSegment); + dataSegments.forEach(prototype::addSegment); SegmentsCostCacheV3 cache = prototype.build(); + double cost = cache.cost(referenceSegment); + + Assert.assertEquals(1, expectedCost / cost, EPSILON); + } + + @Test + public void testLargeIntervals() + { + List intervals = new ArrayList<>(); + // add ALL granularity buckets + for (int i = 0; i < 5; i++) { + intervals.add(new Interval(JodaUtils.MIN_INSTANT, JodaUtils.MAX_INSTANT)); + } + // add random large intervals + for (int i = 0; i < 15; i++) { + intervals.add(new Interval(REFERENCE_TIME.minusYears(random.nextInt(30)), + REFERENCE_TIME.plusYears(random.nextInt(30)))); + } + // add random medium intervals + for (int i = 0; i < 30; i++) { + intervals.add(new Interval(REFERENCE_TIME.minusWeeks(random.nextInt(30)), + REFERENCE_TIME.plusWeeks(random.nextInt(30)))); + } + // add random small intervals + for (int i = 0; i < 50; i++) { + intervals.add(new Interval(REFERENCE_TIME.minusHours(random.nextInt(30)), + REFERENCE_TIME.plusHours(random.nextInt(30)))); + } - double origCost = 0; - for (DataSegment segment : dataSegments) { - origCost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + List segments = intervals.stream() + .map(interval -> createSegment(DATA_SOURCE, interval, 100)) + .collect(Collectors.toList()); + List referenceSegments = intervals.stream() + .map(interval -> createSegment("ANOTHER_DATA_SOURCE", interval, 100)) + 
.collect(Collectors.toList()); + + for (DataSegment segment : segments) { + for (DataSegment referenceSegment : referenceSegments) { + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(segment); + SegmentsCostCacheV3 cache = builder.build(); + + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + double cost = cache.cost(referenceSegment); + Assert.assertEquals(1, expectedCost / cost, 0.0001); + } } - double cost = cache.cost(referenceSegment); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + segments.forEach(builder::addSegment); + SegmentsCostCacheV3 cache = builder.build(); + for (DataSegment referenceSegment : referenceSegments) { + double expectedCost = getExpectedCost(segments, referenceSegment); + double cost = cache.cost(referenceSegment); + Assert.assertEquals(1 , expectedCost / cost, 0.01); + } + } - Assert.assertEquals(1, origCost / cost, EPSILON); + // ( ) [ ] + @Test + public void leftOfBucketTest() + { + DataSegment origin = createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 0, 2), 100); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(origin); + SegmentsCostCacheV3 cache = builder.build(); + + DataSegment segment = createSegment("blah", shiftedXHInterval(REFERENCE_TIME, -3, 2), 100); + + double cost = cache.cost(segment); + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(origin, segment); + Assert.assertEquals(cost, expectedCost, EPSILON); } + // ( [ ) ] + @Test + public void leftOverlapWithBucketTest() + { + DataSegment origin = createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 0, 2), 100); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(origin); + SegmentsCostCacheV3 cache = builder.build(); + + DataSegment segment = createSegment("blah", shiftedXHInterval(REFERENCE_TIME, -1, 2), 100); + + double cost = 
cache.cost(segment); + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(origin, segment); + Assert.assertEquals(cost, expectedCost, EPSILON); + } + + // ( [ ] ) + @Test + public void enclosedByBucketTest() + { + DataSegment origin = createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 0, 4), 100); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(origin); + SegmentsCostCacheV3 cache = builder.build(); + + DataSegment segment = createSegment("blah", shiftedXHInterval(REFERENCE_TIME, 1, 2), 100); + + double cost = cache.cost(segment); + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(origin, segment); + Assert.assertEquals(cost, expectedCost, EPSILON); + } + + // [ ( ) ] + @Test + public void enclosesBucketTest() + { + DataSegment origin = createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 0, 2), 100); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(origin); + SegmentsCostCacheV3 cache = builder.build(); + + DataSegment segment = createSegment("blah", shiftedXHInterval(REFERENCE_TIME, -1, 4), 100); + + double cost = cache.cost(segment); + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(origin, segment); + Assert.assertEquals(cost, expectedCost, EPSILON); + } + + // [ ( ] ) + @Test + public void rightOverlapWithBucketTest() + { + DataSegment origin = createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 0, 2), 100); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(origin); + SegmentsCostCacheV3 cache = builder.build(); + + DataSegment segment = createSegment("blah", shiftedXHInterval(REFERENCE_TIME, 1, 2), 100); + + double cost = cache.cost(segment); + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(origin, segment); + Assert.assertEquals(cost, expectedCost, EPSILON); + } + + // [ ] ( ) + @Test + public void rightOfBucketTest() + { + 
DataSegment origin = createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, 0, 2), 100); + SegmentsCostCacheV3.Builder builder = SegmentsCostCacheV3.builder(); + builder.addSegment(origin); + SegmentsCostCacheV3 cache = builder.build(); + + DataSegment segment = createSegment("blah", shiftedXHInterval(REFERENCE_TIME, 3, 2), 100); + + double cost = cache.cost(segment); + double expectedCost = CostBalancerStrategy.computeJointSegmentsCost(origin, segment); + Assert.assertEquals(cost, expectedCost, EPSILON); + } + + private static double getExpectedCost(List segments, DataSegment referenceSegment) + { + double cost = 0; + for (DataSegment segment : segments) { + cost += CostBalancerStrategy.computeJointSegmentsCost(segment, referenceSegment); + } + return cost; + } private static Interval shiftedXHInterval(DateTime REFERENCE_TIME, int shiftInHours, int X) { From 918ff9e43ff8fb2ed55b4e220e725eb2a1cef5bb Mon Sep 17 00:00:00 2001 From: Amatya Date: Sun, 17 Apr 2022 18:23:30 +0530 Subject: [PATCH 12/13] Fix checkstyle, remove SegmentsCostCacheV2 --- .../apache/druid/java/util/common/Treap.java | 345 ----------- .../coordinator/cost/SegmentsCostCacheV2.java | 551 ------------------ .../coordinator/cost/SegmentsCostCacheV3.java | 17 +- .../cost/SegmentsCostCacheTest.java | 24 +- .../cost/SegmentsCostCacheV2Test.java | 198 ------- .../cost/SegmentsCostCacheV3Test.java | 38 +- 6 files changed, 39 insertions(+), 1134 deletions(-) delete mode 100644 core/src/main/java/org/apache/druid/java/util/common/Treap.java delete mode 100644 server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java delete mode 100644 server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java diff --git a/core/src/main/java/org/apache/druid/java/util/common/Treap.java b/core/src/main/java/org/apache/druid/java/util/common/Treap.java deleted file mode 100644 index b3e979fdc2cc..000000000000 --- 
a/core/src/main/java/org/apache/druid/java/util/common/Treap.java +++ /dev/null @@ -1,345 +0,0 @@ -package org.apache.druid.java.util.common; - -import javax.annotation.Nullable; -import java.util.ArrayList; -import java.util.List; - -public abstract class Treap, Y> -{ - protected TreapNode root; - protected final TreapNode NULL; - - public Treap() - { - NULL = new TreapNode(null); - NULL.left = NULL.right = NULL; - NULL.priority = Double.POSITIVE_INFINITY; - root = NULL; - } - - public boolean isEmpty() - { - return NULL.equals(root); - } - - public int size() - { - return root.size; - } - - public boolean contains(X val) - { - return contains(val, root); - } - - public X lower(X val) - { - return lower(val, root).val; - } - - public X upper(X val) - { - return upper(val, root).val; - } - - public X floor(X val) - { - return floor(val, root).val; - } - - public X ceil(X val) - { - return ceil(val, root).val; - } - - public boolean insert(X val) - { - int oldSize = root.size; - root = insert(new TreapNode(val), root); - return root.size > oldSize; - } - - public boolean remove(X val) - { - int oldSize = root.size; - root = remove(val, root); - return root.size < oldSize; - } - - public X getMin() - { - TreapNode node = root; - while (!NULL.equals(node.left)) { - node = node.left; - } - return node.val; - } - - public X getMax() - { - TreapNode node = root; - while (!NULL.equals(node.right)) { - node = node.right; - } - return node.val; - } - - public void update(X val, Y lazy, boolean dir) - { - if (dir) { - root = update(root, val, null, lazy); - } else { - root = update(root, null, val, lazy); - } - } - - public List toList() - { - List list = new ArrayList<>(); - accumulate(root, list); - return list; - } - - protected abstract Y getVal(X val); - - protected abstract X setVal(X val, Y lazy); - - protected abstract Y add(Y a, Y b); - - protected abstract Y multiply(int a, Y b); - - protected abstract Y zero(); - - private boolean contains(X val, TreapNode node) - 
{ - if (NULL.equals(node)) { - return false; - } - final int cmp = val.compareTo(node.val); - if (cmp < 0) { - return contains(val, node.left); - } - if (cmp > 0) { - return contains(val, node.right); - } - return true; - } - - private TreapNode lower(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp <= 0) { - return lower(val, node.left); - } else { - TreapNode ret = lower(val, node.right); - return (NULL.equals(ret)) ? node : ret; - } - } - - private TreapNode upper(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp >= 0) { - return upper(val, node.right); - } else { - TreapNode ret = upper(val, node.left); - return (NULL.equals(ret)) ? node : ret; - } - } - - private TreapNode floor(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp < 0) { - return floor(val, node.left); - } else { - TreapNode ret = floor(val, node.right); - return (NULL.equals(ret)) ? node : ret; - } - } - - private TreapNode ceil(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - final int cmp = val.compareTo(node.val); - if (cmp > 0) { - return ceil(val, node.right); - } else { - TreapNode ret = ceil(val, node.left); - return (NULL.equals(ret)) ? 
node : ret; - } - } - - private TreapNode insert(TreapNode val, TreapNode node) - { - if (NULL.equals(node)) { - return val; - } - Pair pair = split(node, val.val); - node = merge(pair.lhs, val); - node = merge(node, pair.rhs); - return node; - } - - private TreapNode remove(X val, TreapNode node) - { - if (NULL.equals(node)) { - return node; - } - Pair pair = split(node, val); - TreapNode lower = lower(val, pair.lhs); - if (NULL.equals(lower)) { - return pair.rhs; - } - return merge(split(pair.lhs, lower.val).lhs, pair.rhs); - } - - private Pair split(TreapNode node, X val) - { - if (NULL.equals(node)) { - return Pair.of(NULL, NULL); - } - node.lazyPropogate(); - final int cmp = val.compareTo(node.val); - Pair pair; - if (cmp < 0) { - pair = split(node.left, val); - node.left = pair.rhs; - pair = Pair.of(pair.lhs, node); - } else { - pair = split(node.right, val); - node.right = pair.lhs; - pair = Pair.of(node, pair.rhs); - } - node.recompute(); - return pair; - } - - private TreapNode merge(TreapNode left, TreapNode right) - { - if (NULL.equals(left)) { - return right; - } - if (NULL.equals(right)) { - return left; - } - left.lazyPropogate(); - right.lazyPropogate(); - TreapNode node; - if (left.priority < right.priority) { - left.right = merge(left.right, right); - node = left; - } else { - right.left = merge(left, right.left); - node = right; - } - node.recompute(); - return node; - } - - private TreapNode update(TreapNode node, @Nullable X begin, @Nullable X end, Y lazy) - { - TreapNode left = NULL; - TreapNode right = NULL; - if (begin != null) { - Pair pair = split(node, begin); - left = pair.lhs; - node = pair.rhs; - } - if (end != null) { - Pair pair = split(node, end); - node = pair.lhs; - right = pair.rhs; - } - node.lazy = add(node.lazy, lazy); - node = merge(left, node); - node = merge(node, right); - return node; - } - - private void accumulate(TreapNode node, List list) - { - if (NULL.equals(node)) { - return; - } - node.lazyPropogate(); - 
accumulate(node.left, list); - list.add(node.val); - accumulate(node.right, list); - } - - class TreapNode - { - X val; - Y sum; - Y lazy; - TreapNode left; - TreapNode right; - double priority; - int size; - - TreapNode(@Nullable X val) - { - this(val, NULL, NULL); - if (val != null) { - sum = getVal(val); - size = 1; - } - } - - TreapNode(@Nullable X val, @Nullable TreapNode left, @Nullable TreapNode right) - { - this.val = val; - this.left = left; - this.right = right; - this.priority = Math.random(); - this.sum = zero(); - this.lazy = zero(); - } - - public void recompute() - { - if (NULL.equals(this)) { - return; - } - size = 1 + left.size + right.size; - sum = getVal(val); - left.lazyPropogate(); - right.lazyPropogate(); - sum = add(sum, add(left.sum, right.sum)); - } - - public void lazyPropogate() - { - if (NULL.equals(this)) { - return; - } - val = setVal(val, lazy); - sum = add(sum, multiply(size, lazy)); - if (!NULL.equals(left)) { - left.lazy = add(left.lazy, lazy); - } - if (!NULL.equals(right)) { - right.lazy = add(right.lazy, lazy); - } - lazy = zero(); - } - - @Override - public boolean equals(Object that) - { - return this == that; - } - } -} diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java deleted file mode 100644 index 14150c10768f..000000000000 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2.java +++ /dev/null @@ -1,551 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.server.coordinator.cost; - -import com.google.common.base.Preconditions; -import com.google.common.collect.Ordering; -import org.apache.commons.math3.util.FastMath; -import org.apache.druid.java.util.common.ISE; -import org.apache.druid.java.util.common.Intervals; -import org.apache.druid.java.util.common.Pair; -import org.apache.druid.java.util.common.Treap; -import org.apache.druid.java.util.common.granularity.DurationGranularity; -import org.apache.druid.java.util.common.guava.Comparators; -import org.apache.druid.server.coordinator.CostBalancerStrategy; -import org.apache.druid.timeline.DataSegment; -import org.apache.druid.timeline.SegmentId; -import org.joda.time.Interval; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.ListIterator; -import java.util.NavigableMap; -import java.util.TreeMap; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; - -/** - * SegmentsCostCache provides faster way to calculate cost function proposed in {@link CostBalancerStrategy}. - * See https://github.com/apache/druid/pull/2972 for more details about the cost function. 
- * - * Joint cost for two segments (you can make formulas below readable by copy-pasting to - * https://www.codecogs.com/latex/eqneditor.php): - * - * cost(Y, Y) = \int_{x_0}^{x_1} \int_{y_0}^{y_1} e^{-\lambda |x-y|}dxdy - * or - * cost(Y, Y) = e^{y_0 + y_1} (e^{x_0} - e^{x_1})(e^{y_0} - e^{y_1}) (*) - * if x_0 <= x_1 <= y_0 <= y_1 - * (*) lambda coefficient is omitted for simplicity. - * - * For a group of segments {S_xi}, i = {0, n} total joint cost with segment S_y could be calculated as: - * - * cost(Y, Y) = \sum cost(X_i, Y) = e^{y_0 + y_1} (e^{y_0} - e^{y_1}) \sum (e^{xi_0} - e^{xi_1}) - * if xi_0 <= xi_1 <= y_0 <= y_1 - * and - * cost(Y, Y) = \sum cost(X_i, Y) = (e^{y_0} - e^{y_1}) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) - * if y_0 <= y_1 <= xi_0 <= xi_1 - * - * SegmentsCostCache stores pre-computed sums for a group of segments {S_xi}: - * - * 1) \sum (e^{xi_0} - e^{xi_1}) -> leftSum - * 2) \sum e^{xi_0 + xi_1} (e^{xi_0} - e^{xi_1}) -> rightSum - * - * so that calculation of joint cost function for segment S_y became a O(1 + m) complexity task, where m - * is the number of segments in {S_xi} that overlaps S_y. - * - * Segments are stored in buckets. Bucket is a subset of segments contained in SegmentsCostCache, so that - * startTime of all segments inside a bucket are in the same time interval (with some granularity): - * - * |------------------------|--------------------------|-----------------------|-------- .... - * t_0 t_0+D t_0 + 2D t0 + 3D .... 
- * S_x1 S_x2 S_x3 S_x4 S_x5 S_x6 S_x7 S_x8 S_x9 - * bucket1 bucket2 bucket3 - * - * Reasons to store segments in Buckets: - * - * 1) Cost function tends to 0 as distance between segments' intervals increases; buckets - * are used to avoid redundant 0 calculations for thousands of times - * 2) To reduce number of calculations when segment is added or removed from SegmentsCostCache - * 3) To avoid infinite values during exponents calculations - * - */ -public class SegmentsCostCacheV2 -{ - /** - * HALF_LIFE_DAYS defines how fast joint cost function tends to 0 as distance between segments' intervals increasing. - * The value of 1 day means that cost function of co-locating two segments which have 1 days between their intervals - * is 0.5 of the cost, if the intervals are adjacent. If the distance is 2 days, then 0.25, etc. - */ - private static final double HALF_LIFE_DAYS = 1.0; - private static final double LAMBDA = Math.log(2) / HALF_LIFE_DAYS; - private static final double MILLIS_FACTOR = TimeUnit.DAYS.toMillis(1) / LAMBDA; - - /** - * LIFE_THRESHOLD is used to avoid calculations for segments that are "far" - * from each other and thus cost(Y,Y) ~ 0 for these segments - */ - private static final long LIFE_THRESHOLD = TimeUnit.DAYS.toMillis(30); - - /** - * Bucket interval defines duration granularity for segment buckets. 
Number of buckets control the trade-off - * between updates (add/remove segment operation) and joint cost calculation: - * 1) updates complexity is increasing when number of buckets is decreasing (as buckets contain more segments) - * 2) joint cost calculation complexity is increasing with increasing of buckets number - */ - private static final long BUCKET_INTERVAL = TimeUnit.DAYS.toMillis(15); - private static final DurationGranularity BUCKET_GRANULARITY = new DurationGranularity(BUCKET_INTERVAL, 0); - - private static final Comparator INTERVAL_COMPARATOR = Comparators.intervalsByStartThenEnd(); - - private static final Comparator BUCKET_INTERVAL_COMPARATOR = - Comparator.comparing(Bucket::getInterval, Comparators.intervalsByStartThenEnd()); - - private static final Ordering INTERVAL_ORDERING = Ordering.from(Comparators.intervalsByStartThenEnd()); - private static final Ordering BUCKET_ORDERING = Ordering.from(BUCKET_INTERVAL_COMPARATOR); - - private final ArrayList sortedBuckets; - private final ArrayList intervals; - - SegmentsCostCacheV2(ArrayList sortedBuckets) - { - this.sortedBuckets = Preconditions.checkNotNull(sortedBuckets, "buckets should not be null"); - this.intervals = sortedBuckets.stream().map(Bucket::getInterval).collect(Collectors.toCollection(ArrayList::new)); - Preconditions.checkArgument( - BUCKET_ORDERING.isOrdered(sortedBuckets), - "buckets must be ordered by interval" - ); - } - - public double cost(DataSegment segment) - { - double cost = 0.0; - int index = Collections.binarySearch(intervals, segment.getInterval(), Comparators.intervalsByStartThenEnd()); - index = (index >= 0) ? 
index : -index - 1; - - for (ListIterator it = sortedBuckets.listIterator(index); it.hasNext(); ) { - Bucket bucket = it.next(); - if (!bucket.inCalculationInterval(segment)) { - break; - } - cost += bucket.cost(segment); - } - - for (ListIterator it = sortedBuckets.listIterator(index); it.hasPrevious(); ) { - Bucket bucket = it.previous(); - if (!bucket.inCalculationInterval(segment)) { - break; - } - cost += bucket.cost(segment); - } - - return cost; - } - - public static Builder builder() - { - return new Builder(); - } - - public static class Builder - { - private final NavigableMap buckets = new TreeMap<>(Comparators.intervalsByStartThenEnd()); - - public Builder addSegment(DataSegment segment) - { - Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder); - builder.addSegment(segment); - return this; - } - - public Builder removeSegment(DataSegment segment) - { - Interval interval = getBucketInterval(segment); - buckets.computeIfPresent( - interval, - // If there are no move segments, returning null in computeIfPresent() removes the interval from the buckets - // map - (i, builder) -> builder.removeSegment(segment).isEmpty() ? 
null : builder - ); - return this; - } - - public boolean isEmpty() - { - return buckets.isEmpty(); - } - - public SegmentsCostCacheV2 build() - { - return new SegmentsCostCacheV2( - buckets - .values() - .stream() - .map(Bucket.Builder::build) - .collect(Collectors.toCollection(ArrayList::new)) - ); - } - - private static Interval getBucketInterval(DataSegment segment) - { - return BUCKET_GRANULARITY.bucket(segment.getInterval().getStart()); - } - } - - static class Bucket - { - private final Interval interval; - private final Interval calculationInterval; - private final ArrayList sortedIntervals; - private final double[] leftSum; - private final double[] rightSum; - - private final double[] cumStart; - private final double[] cumStartExp; - private final double[] cumStartExpInv; - private final double[] cumEnd; - private final double[] cumEndExp; - private final double[] cumEndExpInv; - - Bucket(Interval interval, ArrayList sortedIntervals, double[] leftSum, double[] rightSum) - { - this.interval = Preconditions.checkNotNull(interval, "interval"); - this.sortedIntervals = Preconditions.checkNotNull(sortedIntervals, "sortedSegments"); - this.leftSum = Preconditions.checkNotNull(leftSum, "leftSum"); - this.rightSum = Preconditions.checkNotNull(rightSum, "rightSum"); - Preconditions.checkArgument(sortedIntervals.size() == leftSum.length && sortedIntervals.size() == rightSum.length); - Preconditions.checkArgument(INTERVAL_ORDERING.isOrdered(sortedIntervals)); - this.calculationInterval = new Interval( - interval.getStart().minus(LIFE_THRESHOLD), - interval.getEnd().plus(LIFE_THRESHOLD) - ); - - int n = leftSum.length; - - cumStart = new double[n + 1]; - cumStartExp = new double[n + 1]; - cumStartExpInv = new double[n + 1]; - cumEnd = new double[n + 1]; - cumEndExp = new double[n + 1]; - cumEndExpInv = new double[n + 1]; - for (int i = 0; i < n; i++) { - double start = convertStart(sortedIntervals.get(i), interval); - cumStart[i + 1] = cumStart[i] + start; - 
cumStartExp[i + 1] = cumStartExp[i] + FastMath.exp(start); - cumStartExpInv[i + 1] = cumStartExpInv[i] + FastMath.exp(-start); - - double end = convertEnd(sortedIntervals.get(i), interval); - cumEnd[i + 1] = cumEnd[i] + end; - cumEndExp[i + 1] = cumEndExp[i] + FastMath.exp(end); - cumEndExpInv[i + 1] = cumEndExpInv[i] + FastMath.exp(-end); - } - } - - Interval getInterval() - { - return interval; - } - - boolean inCalculationInterval(DataSegment dataSegment) - { - return calculationInterval.overlaps(dataSegment.getInterval()); - } - - double cost(DataSegment dataSegment) - { - // cost is calculated relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); - - // avoid calculation for segments outside of LIFE_THRESHOLD - if (!inCalculationInterval(dataSegment)) { - throw new ISE("Segment is not within calculation interval"); - } - - int index = Collections.binarySearch(sortedIntervals, dataSegment.getInterval(), INTERVAL_COMPARATOR); - index = (index >= 0) ? 
index : -index - 1; - return leftCost(dataSegment, t0, t1, index) + rightCost(dataSegment, t0, t1, index); - } - - private double leftCost(DataSegment dataSegment, double t0, double t1, int index) - { - if (index - 1 < 0) { - return 0; - } - double exp0 = FastMath.exp(t0); - double expInv0 = 1 / exp0; - double exp1 = FastMath.exp(t1); - double expInv1 = 1 / exp1; - double leftCost = 0.0; - // add to cost all left-overlapping segments - int rightBound = index - 1; - int leftBound = leftBoundary(0, index - 1, dataSegment.getInterval()); - leftCost += 2 * (cumEnd[rightBound + 1] - cumEnd[leftBound]); - leftCost -= 2 * (rightBound - leftBound + 1) * t0; - leftCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - leftCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - leftCost -= expInv0 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - leftCost -= expInv1 * (cumEndExp[rightBound + 1] - cumEndExp[leftBound]); - // add left-non-overlapping segments - if (leftBound > 0) { - leftCost += leftSum[leftBound - 1] * (expInv1 - expInv0); - } - return leftCost; - } - - private double rightCost(DataSegment dataSegment, double t0, double t1, int index) - { - int n = leftSum.length; - if (index >= n) { - return 0; - } - double exp0 = FastMath.exp(t0); - double exp1 = FastMath.exp(t1); - double expInv1 = 1 / exp1; - double rightCost = 0.0; - int leftBound = index; - int rightBound = rightBoundary(index, n - 1, dataSegment.getInterval()); - // add all right-overlapping segments - rightCost += 2 * (rightBound - leftBound + 1) * t1; - rightCost -= 2 * (cumStart[rightBound + 1] - cumStart[leftBound]); - rightCost += exp0 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - rightCost += expInv1 * (cumStartExp[rightBound + 1] - cumStartExp[leftBound]); - rightCost -= exp0 * (cumStartExpInv[rightBound + 1] - cumStartExpInv[leftBound]); - rightCost -= exp1 * (cumEndExpInv[rightBound + 1] - cumEndExpInv[leftBound]); - // add 
right-non-overlapping segments - if (rightBound + 1 < n) { - rightCost += rightSum[rightBound + 1] * (exp0 - exp1); - } - return rightCost; - } - - private int leftBoundary(int l, int r, Interval interval) - { - if (l == r) { - return interval.overlaps(sortedIntervals.get(l)) ? l : r + 1; - } - int m = (l + r) / 2; - if (interval.overlaps(sortedIntervals.get(m))) { - return leftBoundary(l, m, interval); - } else { - return leftBoundary(m + 1, r, interval); - } - } - - private int rightBoundary(int l, int r, Interval interval) - { - if (l == r) { - return interval.overlaps(sortedIntervals.get(r)) ? r : l - 1; - } - int m = (l + r + 1) / 2; - if (interval.overlaps(sortedIntervals.get(m))) { - return rightBoundary(m, r, interval); - } else { - return rightBoundary(l, m - 1, interval); - } - } - - private static double convertStart(Interval interval, Interval reference) - { - return toLocalInterval(interval.getStartMillis(), reference); - } - - private static double convertEnd(Interval interval, Interval reference) - { - return toLocalInterval(interval.getEndMillis(), reference); - } - - private static double toLocalInterval(long millis, Interval interval) - { - return millis / MILLIS_FACTOR - interval.getStartMillis() / MILLIS_FACTOR; - } - - public static Builder builder(Interval interval) - { - return new Builder(interval); - } - - static class Builder - { - protected final Interval interval; - private SegmentTreap treap = new SegmentTreap(); - public Builder(Interval interval) - { - this.interval = interval; - } - - public Builder addSegment(DataSegment dataSegment) - { - if (!interval.contains(dataSegment.getInterval().getStartMillis())) { - throw new ISE("Failed to add segment to bucket: interval is not covered by this bucket"); - } - - // all values are pre-computed relatively to bucket start (which is considered as 0) - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); - - double 
leftValue = FastMath.exp(t0) - FastMath.exp(t1); - double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); - - SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, leftValue, rightValue); - - // left/right value should be added to left/right sums for elements greater/lower than current segment - treap.update(segmentAndSum, Pair.of(leftValue, 0.0), true); - treap.update(segmentAndSum, Pair.of(0.0, rightValue), false); - - // leftSum_i = leftValue_i + \sum leftValue_j = leftValue_i + leftSum_{i-1} , j < i - SegmentAndSum lower = treap.lower(segmentAndSum); - if (lower != null) { - segmentAndSum.leftSum = leftValue + lower.leftSum; - } - - // rightSum_i = rightValue_i + \sum rightValue_j = rightValue_i + rightSum_{i+1} , j > i - SegmentAndSum higher = treap.upper(segmentAndSum); - if (higher != null) { - segmentAndSum.rightSum = rightValue + higher.rightSum; - } - - if (!treap.insert(segmentAndSum)) { - throw new ISE("expect new segment"); - } - return this; - } - - public Builder removeSegment(DataSegment dataSegment) - { - SegmentAndSum segmentAndSum = new SegmentAndSum(dataSegment, 0.0, 0.0); - - if (!treap.remove(segmentAndSum)) { - return this; - } - - double t0 = convertStart(dataSegment.getInterval(), interval); - double t1 = convertEnd(dataSegment.getInterval(), interval); - - double leftValue = FastMath.exp(t0) - FastMath.exp(t1); - double rightValue = FastMath.exp(-t1) - FastMath.exp(-t0); - - treap.update(segmentAndSum, Pair.of(-leftValue, 0.0), true); - treap.update(segmentAndSum, Pair.of(0.0, -rightValue), false); - - return this; - } - - public boolean isEmpty() - { - return treap.isEmpty(); - } - - public Bucket build() - { - ArrayList intervalsList = new ArrayList<>(); - double[] leftSum = new double[treap.size()]; - double[] rightSum = new double[treap.size()]; - int i = 0; - for (SegmentAndSum segmentAndSum : treap.toList()) { - intervalsList.add(segmentAndSum.interval); - leftSum[i] = segmentAndSum.leftSum; - rightSum[i] = 
segmentAndSum.rightSum; - ++i; - } - treap = null; - long bucketEndMillis = intervalsList - .stream() - .mapToLong(Interval::getEndMillis) - .max() - .orElseGet(interval::getEndMillis); - return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), intervalsList, leftSum, rightSum); - } - } - } - - static class SegmentAndSum implements Comparable - { - private final Interval interval; - private final SegmentId segmentId; - private double leftSum; - private double rightSum; - - SegmentAndSum(DataSegment dataSegment, double leftSum, double rightSum) - { - this.interval = dataSegment.getInterval(); - this.segmentId = dataSegment.getId(); - this.leftSum = leftSum; - this.rightSum = rightSum; - } - - @Override - public int compareTo(SegmentAndSum o) - { - int c = Comparators.intervalsByStartThenEnd().compare(interval, o.interval); - return c != 0 ? c : segmentId.compareTo(o.segmentId); - } - - @Override - public boolean equals(Object obj) - { - throw new UnsupportedOperationException("Use SegmentAndSum.compareTo()"); - } - - @Override - public int hashCode() - { - throw new UnsupportedOperationException(); - } - } - - public static class SegmentTreap extends Treap> - { - - static final Pair ZERO = Pair.of(0.0, 0.0); - - @Override - protected Pair getVal(SegmentAndSum val) - { - return Pair.of(val.leftSum, val.rightSum); - } - - @Override - protected SegmentAndSum setVal(SegmentAndSum val, Pair lazy) - { - val.leftSum += lazy.lhs; - val.rightSum += lazy.rhs; - return val; - } - - @Override - protected Pair zero() - { - return ZERO; - } - - @Override - protected Pair add(Pair a, Pair b) - { - return Pair.of(a.lhs + b.lhs, a.rhs + b.rhs); - } - - @Override - protected Pair multiply(int a, Pair b) - { - return Pair.of(a * b.lhs, a * b.rhs); - } - } -} diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java index 
5556c45c602c..77dc0a47ed7f 100644 --- a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java @@ -200,7 +200,8 @@ public static Builder builder() return new Builder(); } - private static boolean isAllGranularity(DataSegment segment) { + private static boolean isAllGranularity(DataSegment segment) + { return segment.getInterval().getStartMillis() == JodaUtils.MIN_INSTANT && segment.getInterval().getEndMillis() == JodaUtils.MAX_INSTANT; } @@ -219,13 +220,11 @@ public Builder addSegment(DataSegment segment) if (!allGranularitySegments.add(segment.getId())) { throw new ISE("expect new segment"); } - } - else if (isAdhoc(segment)) { + } else if (isAdhoc(segment)) { if (!adhocSegments.add(segment.getId())) { throw new ISE("expect new segment"); } - } - else { + } else { Bucket.Builder builder = buckets.computeIfAbsent(getBucketInterval(segment), Bucket::builder); builder.addSegment(segment); } @@ -239,8 +238,7 @@ public Builder removeSegment(DataSegment segment) } if (isAdhoc(segment)) { adhocSegments.remove(segment.getId()); - } - else { + } else { Interval interval = getBucketInterval(segment); buckets.computeIfPresent( interval, @@ -286,7 +284,8 @@ private static Interval getBucketInterval(DataSegment segment) return BUCKET_GRANULARITY.bucket(segment.getInterval().getStart()); } - private boolean isAdhoc(DataSegment segment) { + private boolean isAdhoc(DataSegment segment) + { double duration = segment.getInterval().getEndMillis() / MILLIS_FACTOR - segment.getInterval().getStartMillis() / MILLIS_FACTOR; return duration > INTERVAL_THRESHOLD / MILLIS_FACTOR; @@ -341,7 +340,6 @@ static class Bucket END_EXP = FastMath.exp(END_VAL); END_EXP_INV = FastMath.exp(-END_VAL); - startValSum = new double[n + 1]; startExpSum = new double[n + 1]; startExpInvSum = new double[n + 1]; @@ -506,6 +504,7 @@ static class Builder { protected final Interval interval; 
private final Set segmentSet = new HashSet<>(); + public Builder(Interval interval) { this.interval = interval; diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java index 6396425782ba..c24efbb02198 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheTest.java @@ -39,7 +39,7 @@ public class SegmentsCostCacheTest { - private static final Random random = new Random(23894); + private static final Random RANDOM = new Random(23894); private static final String DATA_SOURCE = "dataSource"; private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); private static final double EPSILON = 0.0000001; @@ -139,11 +139,11 @@ public void multipleSegmentsCostTest() @Test public void perfComparisonTest() { - final int N = 100000; + final int n = 100000; List dataSegments = new ArrayList<>(1000); - for (int i = 0; i < N; ++i) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 24 * random.nextInt(60)), 100)); + for (int i = 0; i < n; ++i) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 24 * RANDOM.nextInt(60)), 100)); } DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shiftedRandomInterval(REFERENCE_TIME, 5), 100); @@ -159,7 +159,7 @@ public void perfComparisonTest() SegmentsCostCache cache = prototype.build(); end = System.currentTimeMillis(); - System.out.println("Insertion time for " + N + " segments: " + (end - start) + " ms"); + System.out.println("Insertion time for " + n + " segments: " + (end - start) + " ms"); start = System.currentTimeMillis(); @@ -192,17 +192,17 @@ public void bucketCorrectnessTest() // Same as reference interval for (int i = 0; i < 100; i++) { - 
dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(20), 10), 100)); } // Overlapping intervals of larger size that enclose the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 70, 100), 100)); } // intervals of small size that are enclosed within the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 20, 1), 100)); } // intervals not intersecting, lying to its left @@ -241,17 +241,17 @@ public void overallCorrectnessTest() // Same as reference interval for (int i = 0; i < 100; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(20), 10), 100)); } // Overlapping intervals of larger size that enclose the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 70, 100), 100)); } // intervals of small size that are enclosed within the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 20, 1), 100)); } // 
intervals not intersecting, lying to its left @@ -299,7 +299,7 @@ private static Interval shiftedRandomInterval(DateTime REFERENCE_TIME, int shift { return new Interval( REFERENCE_TIME.plusHours(shiftInHours), - REFERENCE_TIME.plusHours(shiftInHours + random.nextInt(1000)) + REFERENCE_TIME.plusHours(shiftInHours + RANDOM.nextInt(1000)) ); } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java deleted file mode 100644 index 86f889e20e8b..000000000000 --- a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV2Test.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.druid.server.coordinator.cost; - -import org.apache.druid.java.util.common.DateTimes; -import org.apache.druid.timeline.DataSegment; -import org.joda.time.DateTime; -import org.joda.time.Interval; -import org.junit.Assert; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; - -public class SegmentsCostCacheV2Test -{ - - private static final String DATA_SOURCE = "dataSource"; - private static DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); - private static final double EPSILON = 0.00000001; - - @Test - public void segmentCacheTest() - { - SegmentsCostCacheV2.Builder cacheBuilder = SegmentsCostCacheV2.builder(); - cacheBuilder.addSegment(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100)); - SegmentsCostCacheV2 cache = cacheBuilder.build(); - Assert.assertEquals( - 7.8735899489011E-4, - cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100)), - EPSILON - ); - } - - @Test - public void notInCalculationIntervalCostTest() - { - SegmentsCostCacheV2.Builder cacheBuilder = SegmentsCostCacheV2.builder(); - cacheBuilder.addSegment( - createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100) - ); - SegmentsCostCacheV2 cache = cacheBuilder.build(); - Assert.assertEquals( - 0, - cache.cost(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, (int) TimeUnit.DAYS.toHours(50)), 100)), - EPSILON - ); - } - - @Test - public void twoSegmentsCostTest() - { - DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); - DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100); - - SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( - REFERENCE_TIME.minusHours(5), - REFERENCE_TIME.plusHours(5) - )); - - 
prototype.addSegment(segmentA); - SegmentsCostCacheV2.Bucket bucket = prototype.build(); - - double segmentCost = bucket.cost(segmentB); - Assert.assertEquals(7.8735899489011E-4, segmentCost, EPSILON); - } - - @Test - public void calculationIntervalTest() - { - DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); - DataSegment segmentB = createSegment( - DATA_SOURCE, - shifted1HInterval(REFERENCE_TIME, (int) TimeUnit.DAYS.toHours(50)), - 100 - ); - - SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder( - new Interval(REFERENCE_TIME.minusHours(5), REFERENCE_TIME.plusHours(5)) - ); - prototype.addSegment(segmentA); - SegmentsCostCacheV2.Bucket bucket = prototype.build(); - - Assert.assertTrue(bucket.inCalculationInterval(segmentA)); - Assert.assertFalse(bucket.inCalculationInterval(segmentB)); - } - - @Test - public void sameSegmentCostTest() - { - DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); - DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); - - SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( - REFERENCE_TIME.minusHours(5), - REFERENCE_TIME.plusHours(5) - )); - - prototype.addSegment(segmentA); - SegmentsCostCacheV2.Bucket bucket = prototype.build(); - - double segmentCost = bucket.cost(segmentB); - Assert.assertEquals(8.26147353873985E-4, segmentCost, EPSILON); - } - - @Test - public void multipleSegmentsCostTest() - { - DataSegment segmentA = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, -2), 100); - DataSegment segmentB = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 0), 100); - DataSegment segmentC = createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, 2), 100); - - SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( - REFERENCE_TIME.minusHours(5), - 
REFERENCE_TIME.plusHours(5) - )); - - prototype.addSegment(segmentA); - prototype.addSegment(segmentC); - SegmentsCostCacheV2.Bucket bucket = prototype.build(); - - double segmentCost = bucket.cost(segmentB); - - Assert.assertEquals(0.001574717989780039, segmentCost, EPSILON); - } - - @Test - public void randomSegmentsCostTest() - { - List dataSegments = new ArrayList<>(1000); - Random random = new Random(1); - for (int i = 0; i < 1000; ++i) { - dataSegments.add(createSegment(DATA_SOURCE, shifted1HInterval(REFERENCE_TIME, random.nextInt(20)), 100)); - } - - DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shifted1HInterval(REFERENCE_TIME, 5), 100); - - SegmentsCostCacheV2.Bucket.Builder prototype = SegmentsCostCacheV2.Bucket.builder(new Interval( - REFERENCE_TIME.minusHours(1), - REFERENCE_TIME.plusHours(25) - )); - - long start = System.currentTimeMillis(); - - dataSegments.forEach(prototype::addSegment); - SegmentsCostCacheV2.Bucket bucket = prototype.build(); - - long end = System.currentTimeMillis(); - System.out.println(end - start); - - double cost = bucket.cost(referenceSegment); - Assert.assertEquals(0.7065117101966677, cost, EPSILON); - } - - private static Interval shifted1HInterval(DateTime REFERENCE_TIME, int shiftInHours) - { - return new Interval( - REFERENCE_TIME.plusHours(shiftInHours), - REFERENCE_TIME.plusHours(shiftInHours + 1) - ); - } - - private static DataSegment createSegment(String dataSource, Interval interval, long size) - { - return new DataSegment( - dataSource, - interval, - UUID.randomUUID().toString(), - new ConcurrentHashMap<>(), - new ArrayList<>(), - new ArrayList<>(), - null, - 0, - size - ); - } -} diff --git a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java index 0555e46c3628..3e7f847bcec2 100644 --- 
a/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3Test.java @@ -41,7 +41,7 @@ public class SegmentsCostCacheV3Test { - private static final Random random = new Random(23894); + private static final Random RANDOM = new Random(23894); private static final String DATA_SOURCE = "dataSource"; private static final DateTime REFERENCE_TIME = DateTimes.of("2014-01-01T00:00:00"); private static final double EPSILON = 0.0000001; @@ -141,11 +141,11 @@ public void multipleSegmentsCostTest() @Test public void perfComparisonTest() { - final int N = 100000; + final int n = 100000; List dataSegments = new ArrayList<>(1000); - for (int i = 0; i < N; ++i) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 24 * random.nextInt(60)), 100)); + for (int i = 0; i < n; ++i) { + dataSegments.add(createSegment(DATA_SOURCE, shiftedRandomInterval(REFERENCE_TIME, 24 * RANDOM.nextInt(60)), 100)); } DataSegment referenceSegment = createSegment("ANOTHER_DATA_SOURCE", shiftedRandomInterval(REFERENCE_TIME, 5), 100); @@ -159,7 +159,7 @@ public void perfComparisonTest() dataSegments.forEach(prototype::addSegment); SegmentsCostCacheV3 cache = prototype.build(); end = System.currentTimeMillis(); - System.out.println("Insertion time for " + N + " segments: " + (end - start) + " ms"); + System.out.println("Insertion time for " + n + " segments: " + (end - start) + " ms"); start = System.currentTimeMillis(); for (int i = 0; i < 1000; i++) { @@ -188,17 +188,17 @@ public void bucketCorrectnessTest() // Same as reference interval for (int i = 0; i < 100; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(20), 10), 100)); } // Overlapping intervals of larger size that enclose the reference 
interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 70, 100), 100)); } // intervals of small size that are enclosed within the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 20, 1), 100)); } // intervals not intersecting, lying to its left @@ -234,17 +234,17 @@ public void overallCorrectnessTest() // Same as reference interval for (int i = 0; i < 100; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(20), 10), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(20), 10), 100)); } // Overlapping intervals of larger size that enclose the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 70, 100), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 70, 100), 100)); } // intervals of small size that are enclosed within the reference interval for (int i = 0; i < 10; i++) { - dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, random.nextInt(40) - 20, 1), 100)); + dataSegments.add(createSegment(DATA_SOURCE, shiftedXHInterval(REFERENCE_TIME, RANDOM.nextInt(40) - 20, 1), 100)); } // intervals not intersecting, lying to its left @@ -280,18 +280,18 @@ public void testLargeIntervals() } // add random large intervals for (int i = 0; i < 15; i++) { - intervals.add(new Interval(REFERENCE_TIME.minusYears(random.nextInt(30)), - REFERENCE_TIME.plusYears(random.nextInt(30)))); + 
intervals.add(new Interval(REFERENCE_TIME.minusYears(RANDOM.nextInt(30)), + REFERENCE_TIME.plusYears(RANDOM.nextInt(30)))); } // add random medium intervals for (int i = 0; i < 30; i++) { - intervals.add(new Interval(REFERENCE_TIME.minusWeeks(random.nextInt(30)), - REFERENCE_TIME.plusWeeks(random.nextInt(30)))); + intervals.add(new Interval(REFERENCE_TIME.minusWeeks(RANDOM.nextInt(30)), + REFERENCE_TIME.plusWeeks(RANDOM.nextInt(30)))); } // add random small intervals for (int i = 0; i < 50; i++) { - intervals.add(new Interval(REFERENCE_TIME.minusHours(random.nextInt(30)), - REFERENCE_TIME.plusHours(random.nextInt(30)))); + intervals.add(new Interval(REFERENCE_TIME.minusHours(RANDOM.nextInt(30)), + REFERENCE_TIME.plusHours(RANDOM.nextInt(30)))); } List segments = intervals.stream() @@ -319,7 +319,7 @@ public void testLargeIntervals() for (DataSegment referenceSegment : referenceSegments) { double expectedCost = getExpectedCost(segments, referenceSegment); double cost = cache.cost(referenceSegment); - Assert.assertEquals(1 , expectedCost / cost, 0.01); + Assert.assertEquals(1, expectedCost / cost, 0.01); } } @@ -445,7 +445,7 @@ private static Interval shiftedRandomInterval(DateTime REFERENCE_TIME, int shift { return new Interval( REFERENCE_TIME.plusHours(shiftInHours), - REFERENCE_TIME.plusHours(shiftInHours + random.nextInt(100)) + REFERENCE_TIME.plusHours(shiftInHours + RANDOM.nextInt(100)) ); } From 929606384d78a028d15a466032fe5fd2a54d1eb2 Mon Sep 17 00:00:00 2001 From: Amatya Date: Mon, 18 Apr 2022 09:52:38 +0530 Subject: [PATCH 13/13] Enable reuse of builder in V3 --- .../druid/server/coordinator/cost/SegmentsCostCacheV3.java | 4 ---- 1 file changed, 4 deletions(-) diff --git a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java index 77dc0a47ed7f..4c50de2a65c9 100644 --- 
a/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java +++ b/server/src/main/java/org/apache/druid/server/coordinator/cost/SegmentsCostCacheV3.java @@ -258,7 +258,6 @@ public boolean isEmpty() public SegmentsCostCacheV3 build() { final int allGranularitySegmentCount = allGranularitySegments.size(); - allGranularitySegments.clear(); final ArrayList> adhocNormalizedIntervals = new ArrayList<>(); for (SegmentId segment : adhocSegments) { @@ -266,7 +265,6 @@ public SegmentsCostCacheV3 build() double normalizedEnd = segment.getInterval().getEndMillis() / MILLIS_FACTOR; adhocNormalizedIntervals.add(Pair.of(normalizedStart, normalizedEnd)); } - adhocSegments.clear(); return new SegmentsCostCacheV3( buckets @@ -547,8 +545,6 @@ public Bucket build() bucketEndMillis = Math.max(bucketEndMillis, i.getEndMillis()); } - segmentSet.clear(); - return new Bucket(Intervals.utc(interval.getStartMillis(), bucketEndMillis), intervals); } }