Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.druid.server.coordinator.balancer;

import com.google.common.collect.Lists;
import org.apache.druid.client.DruidServer;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.server.coordination.ServerType;
Expand All @@ -31,16 +32,16 @@
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ReservoirSegmentSamplerTest
{
Expand All @@ -55,9 +56,6 @@ public class ReservoirSegmentSamplerTest
.withNumPartitions(10)
.eachOfSizeInMb(100);

private final Function<ServerHolder, Collection<DataSegment>> GET_SERVED_SEGMENTS
= serverHolder -> serverHolder.getServer().iterateAllSegments();

@Before
public void setUp()
{
Expand All @@ -80,7 +78,7 @@ public void testEverySegmentGetsPickedAtleastOnce()
// due to the pseudo-randomness of this method, we may not select a segment every single time no matter what.
segmentCountMap.compute(
ReservoirSegmentSampler
.pickMovableSegmentsFrom(servers, 1, GET_SERVED_SEGMENTS, Collections.emptySet())
.pickMovableSegmentsFrom(servers, 1, ServerHolder::getServedSegments, Collections.emptySet())
.get(0).getSegment(),
(segment, count) -> count == null ? 1 : count + 1
);
Expand Down Expand Up @@ -151,9 +149,16 @@ public void testPickLoadingOrLoadedSegments()
Assert.assertTrue(pickedSegments.containsAll(loadingSegments));

// Pick only loaded segments
pickedSegments = ReservoirSegmentSampler
.pickMovableSegmentsFrom(Arrays.asList(server1, server2), 10, GET_SERVED_SEGMENTS, Collections.emptySet())
.stream().map(BalancerSegmentHolder::getSegment).collect(Collectors.toSet());
List<BalancerSegmentHolder> pickedHolders = ReservoirSegmentSampler.pickMovableSegmentsFrom(
Arrays.asList(server1, server2),
10,
ServerHolder::getServedSegments,
Collections.emptySet()
);
pickedSegments = pickedHolders
.stream()
.map(BalancerSegmentHolder::getSegment)
.collect(Collectors.toSet());

// Verify that only loaded segments are picked
Assert.assertEquals(loadedSegments.size(), pickedSegments.size());
Expand All @@ -177,7 +182,7 @@ public void testSegmentsOnBrokersAreIgnored()
List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
Arrays.asList(historical, broker),
10,
GET_SERVED_SEGMENTS,
ServerHolder::getServedSegments,
Collections.emptySet()
);

Expand Down Expand Up @@ -206,8 +211,12 @@ public void testBroadcastSegmentsAreIgnored()
);

// Try to pick all the segments on the servers
List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler
.pickMovableSegmentsFrom(servers, 10, GET_SERVED_SEGMENTS, Collections.singleton(broadcastDatasource));
List<BalancerSegmentHolder> pickedSegments = ReservoirSegmentSampler.pickMovableSegmentsFrom(
servers,
10,
ServerHolder::getServedSegments,
Collections.singleton(broadcastDatasource)
);

// Verify that none of the broadcast segments are picked
Assert.assertEquals(2, pickedSegments.size());
Expand All @@ -216,21 +225,83 @@ public void testBroadcastSegmentsAreIgnored()
}
}

/**
 * Verifies that when the segments are divided equally among 4 servers, each
 * server accounts for roughly a quarter of all the picked segments.
 */
@Test
public void testSegmentsFromAllServersAreEquallyLikelyToBePicked()
{
  // Partition the segments into chunks of (total / 4) and load one chunk on each of 4 servers
  final int numServers = 4;
  final List<List<DataSegment>> segmentsPerServer = Lists.partition(segments, segments.size() / numServers);
  final List<ServerHolder> historicals = new ArrayList<>();
  for (int serverId = 0; serverId < numServers; ++serverId) {
    final DataSegment[] loaded = segmentsPerServer.get(serverId).toArray(new DataSegment[0]);
    historicals.add(createHistorical("server_" + serverId, loaded));
  }

  // Repeat the distribution check for several sample percentages
  for (int percentage : new int[]{50, 20, 10, 5}) {
    final int[] pickedCounts = pickSegmentsAndGetPickedCountPerServer(historicals, percentage, 50);
    final int totalPicked = IntStream.of(pickedCounts).sum();

    // Each server should contribute ~25% of the picks, within a tolerance of 2% of the total
    final double expectedPerServer = totalPicked * 0.25;
    final double tolerance = totalPicked * 0.02;
    for (int pickedCount : pickedCounts) {
      Assert.assertEquals(expectedPerServer, pickedCount, tolerance);
    }
  }
}

/**
 * Verifies that a server holding twice as many segments as each of the other
 * servers also contributes about twice as many picked segments.
 */
@Test
public void testSegmentsFromMorePopulousServerAreMoreLikelyToBePicked()
{
  // Partition the segments into chunks of (total / 5): server_0 is loaded with the
  // first two chunks, servers 1 to 3 are loaded with one of the remaining chunks each
  final List<List<DataSegment>> chunks = Lists.partition(segments, segments.size() / 5);

  final List<DataSegment> doubleLoad = new ArrayList<>(chunks.get(0));
  doubleLoad.addAll(chunks.get(1));

  final List<ServerHolder> historicals = new ArrayList<>();
  historicals.add(createHistorical("server_" + 0, doubleLoad));
  for (int serverId = 1; serverId < 4; ++serverId) {
    historicals.add(createHistorical("server_" + serverId, chunks.get(serverId + 1)));
  }

  // Repeat the distribution check for several sample percentages
  for (int percentage : new int[]{50, 20, 10, 5}) {
    final int[] pickedCounts = pickSegmentsAndGetPickedCountPerServer(historicals, percentage, 50);
    final int totalPicked = IntStream.of(pickedCounts).sum();

    // Allow a deviation of 2% of the total number of picked segments
    final double tolerance = totalPicked * 0.02;

    // server_0 should account for ~40% of the picks ...
    Assert.assertEquals(totalPicked * 0.40, pickedCounts[0], tolerance);

    // ... and every other server for ~20% each
    final double expectedForOthers = totalPicked * 0.20;
    for (int serverId = 1; serverId < historicals.size(); ++serverId) {
      Assert.assertEquals(expectedForOthers, pickedCounts[serverId], tolerance);
    }
  }
}

@Test(timeout = 60_000)
public void testNumberOfIterationsToCycleThroughAllSegments()
public void testNumberOfSamplingsRequiredToPickAllSegments()
{
// The number of runs required for each sample percentage
// The number of sampling iterations required for each sample percentage
// remains more or less fixed, even with a larger number of segments
final int[] samplePercentages = {100, 50, 10, 5, 1};
final int[] expectedIterations = {1, 20, 100, 200, 1000};

final int[] totalObservedIterations = new int[5];

// For every sample percentage, count the minimum number of required samplings
for (int i = 0; i < 50; ++i) {
for (int j = 0; j < samplePercentages.length; ++j) {
totalObservedIterations[j] += countMinRunsWithSamplePercent(samplePercentages[j]);
totalObservedIterations[j] += countMinRunsToPickAllSegments(samplePercentages[j]);
}
}

// Compute the avg value from the 50 observations for each sample percentage
for (int j = 0; j < samplePercentages.length; ++j) {
double avgObservedIterations = totalObservedIterations[j] / 50.0;
Assert.assertTrue(avgObservedIterations <= expectedIterations[j]);
Expand All @@ -244,7 +315,7 @@ public void testNumberOfIterationsToCycleThroughAllSegments()
* <p>
* {@code k = sampleSize = totalNumSegments * samplePercentage}
*/
private int countMinRunsWithSamplePercent(int samplePercentage)
private int countMinRunsToPickAllSegments(int samplePercentage)
{
final int numSegments = segments.size();
final List<ServerHolder> servers = Arrays.asList(
Expand All @@ -259,7 +330,7 @@ private int countMinRunsWithSamplePercent(int samplePercentage)
int numIterations = 1;
for (; numIterations < 10000; ++numIterations) {
ReservoirSegmentSampler
.pickMovableSegmentsFrom(servers, sampleSize, GET_SERVED_SEGMENTS, Collections.emptySet())
.pickMovableSegmentsFrom(servers, sampleSize, ServerHolder::getServedSegments, Collections.emptySet())
.forEach(holder -> pickedSegments.add(holder.getSegment()));

if (pickedSegments.size() >= numSegments) {
Expand All @@ -270,6 +341,38 @@ private int countMinRunsWithSamplePercent(int samplePercentage)
return numIterations;
}

/**
 * Runs the sampler {@code numIterations} times against the given servers and
 * tallies how many of the picked segments came from each server.
 *
 * @param servers          servers to pick segments from
 * @param samplePercentage percentage of {@code segments.size()} requested in every run
 * @param numIterations    number of sampling runs to perform
 * @return array indexed by position in {@code servers}, holding the total number
 *         of segments picked from that server across all the runs
 */
private int[] pickSegmentsAndGetPickedCountPerServer(
    List<ServerHolder> servers,
    int samplePercentage,
    int numIterations
)
{
  final int sampleSize = (int) (segments.size() * samplePercentage / 100.0);
  final int[] pickedCounts = new int[servers.size()];

  for (int remainingRuns = numIterations; remainingRuns > 0; --remainingRuns) {
    // Attribute every picked segment to the index of the server it was picked from
    ReservoirSegmentSampler
        .pickMovableSegmentsFrom(servers, sampleSize, ServerHolder::getServedSegments, Collections.emptySet())
        .forEach(holder -> pickedCounts[servers.indexOf(holder.getServer())]++);
  }

  return pickedCounts;
}

/**
 * Convenience overload of {@code createHistorical(String, DataSegment...)}
 * which accepts the loaded segments as a list.
 */
private ServerHolder createHistorical(String serverName, List<DataSegment> loadedSegments)
{
  final DataSegment[] segmentArray = loadedSegments.toArray(new DataSegment[0]);
  return createHistorical(serverName, segmentArray);
}

private ServerHolder createHistorical(String serverName, DataSegment... loadedSegments)
{
final DruidServer server =
Expand Down