Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
Expand Down Expand Up @@ -55,7 +56,103 @@ public final class TaskAssignmentUtils {
private TaskAssignmentUtils() {}

/**
* Return an {@code AssignmentError} for a task assignment created for an application.
* A simple config container for necessary parameters and optional overrides to apply when
* running the active or standby task rack-aware optimizations.
*/
public static class RackAwareOptimizationParams {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: make this final

private final ApplicationState applicationState;
private final Optional<Integer> trafficCostOverride;
private final Optional<Integer> nonOverlapCostOverride;
private final Optional<SortedSet<TaskId>> tasksToOptimize;

private RackAwareOptimizationParams(final ApplicationState applicationState,
final Optional<Integer> trafficCostOverride,
final Optional<Integer> nonOverlapCostOverride,
final Optional<SortedSet<TaskId>> tasksToOptimize) {
Comment on lines +69 to +71
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indentation is off by one space I think?

this.applicationState = applicationState;
this.trafficCostOverride = trafficCostOverride;
this.nonOverlapCostOverride = nonOverlapCostOverride;
this.tasksToOptimize = tasksToOptimize;
}

/**
* Return a new config object with no overrides and the tasksToOptimize initialized to the set of all tasks in the given ApplicationState
*/
public static RackAwareOptimizationParams of(final ApplicationState applicationState) {
return new RackAwareOptimizationParams(applicationState, Optional.empty(), Optional.empty(), Optional.empty());
}

/**
* Return a new config object with the tasksToOptimize set to all stateful tasks in the given ApplicationState
*/
public RackAwareOptimizationParams forStatefulTasks() {
final SortedSet<TaskId> tasks = applicationState.allTasks().values()
.stream()
.filter(TaskInfo::isStateful)
.map(TaskInfo::id)
.collect(Collectors.toCollection(TreeSet::new));
return forTasks(tasks);
}

/**
* Return a new config object with the tasksToOptimize set to all stateless tasks in the given ApplicationState
*/
public RackAwareOptimizationParams forStatelessTasks() {
final SortedSet<TaskId> tasks = applicationState.allTasks().values()
.stream()
.filter(taskInfo -> !taskInfo.isStateful())
.map(TaskInfo::id)
.collect(Collectors.toCollection(TreeSet::new));
return forTasks(tasks);
}

/**
* Return a new config object with the provided tasksToOptimize
*/
public RackAwareOptimizationParams forTasks(final SortedSet<TaskId> tasksToOptimize) {
return new RackAwareOptimizationParams(
applicationState,
trafficCostOverride,
nonOverlapCostOverride,
Optional.of(tasksToOptimize)
);
}

/**
* Return a new config object with the provided trafficCost override applied
*/
public RackAwareOptimizationParams withTrafficCostOverride(final int trafficCostOverride) {
return new RackAwareOptimizationParams(
applicationState,
Optional.of(trafficCostOverride),
nonOverlapCostOverride,
tasksToOptimize
);
}

/**
* Return a new config object with the provided nonOverlapCost override applied
*/
public RackAwareOptimizationParams withNonOverlapCostOverride(final int nonOverlapCostOverride) {
return new RackAwareOptimizationParams(
applicationState,
trafficCostOverride,
Optional.of(nonOverlapCostOverride),
tasksToOptimize
);
}
}

/**
* Validate the passed-in {@link TaskAssignment} and return an {@link AssignmentError} representing the
* first error detected in the assignment, or {@link AssignmentError#NONE} if the assignment passes the
* verification check.
* <p>
* Note: this verification is performed automatically by the StreamsPartitionAssignor on the assignment
* returned by the TaskAssignor, and the error returned to the assignor via the {@link TaskAssignor#onAssignmentComputed}
* callback. Therefore, it is not required to call this manually from the {@link TaskAssignor#assign} method.
* However, if an invalid assignment is returned it will fail the rebalance and kill the thread, so it may be useful to
* utilize this method in an assignor to verify the assignment before returning it and fix any errors it finds.
*
* @param applicationState The application for which this task assignment is being assessed.
* @param taskAssignment The task assignment that will be validated.
Expand Down Expand Up @@ -153,16 +250,14 @@ public static Map<ProcessId, KafkaStreamsAssignment> identityAssignment(final Ap
* If rack-aware client tags are configured, the rack-aware standby task assignor will be used
*
* @param applicationState the metadata and other info describing the current application state
* @param kafkaStreamsAssignments the current assignment of tasks to KafkaStreams clients
*
* @return a new map containing the mappings from KafkaStreamsAssignments updated with the default standby assignment
* @param kafkaStreamsAssignments the KafkaStreams client assignments to add standby tasks to
*/
public static Map<ProcessId, KafkaStreamsAssignment> defaultStandbyTaskAssignment(final ApplicationState applicationState,
final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
public static void defaultStandbyTaskAssignment(final ApplicationState applicationState,
final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
if (!applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty()) {
return tagBasedStandbyTaskAssignment(applicationState, kafkaStreamsAssignments);
tagBasedStandbyTaskAssignment(applicationState, kafkaStreamsAssignments);
} else {
return loadBasedStandbyTaskAssignment(applicationState, kafkaStreamsAssignments);
loadBasedStandbyTaskAssignment(applicationState, kafkaStreamsAssignments);
}
}

Expand All @@ -185,34 +280,43 @@ public static Map<ProcessId, KafkaStreamsAssignment> defaultStandbyTaskAssignmen
* <p>
* This method optimizes cross-rack traffic for active tasks only. For standby task optimization,
* use {@link #optimizeRackAwareStandbyTasks}.
* <p>
* It is recommended to run this optimization before assigning any standby tasks, especially if you have configured
* your KafkaStreams clients with assignment tags via the rack.aware.assignment.tags config since this method may
* shuffle around active tasks without considering the client tags and can result in a violation of the original
* client tag assignment's constraints.
*
* @param applicationState the metadata and other info describing the current application state
* @param optimizationParams optional configuration parameters to apply
* @param kafkaStreamsAssignments the current assignment of tasks to KafkaStreams clients
* @param tasks the set of tasks to reassign if possible. Must already be assigned to a KafkaStreams client
*
* @return a map with the KafkaStreamsAssignments updated to minimize cross-rack traffic for active tasks
*/
public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareActiveTasks(final ApplicationState applicationState,
final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments,
final SortedSet<TaskId> tasks) {
if (tasks.isEmpty()) {
return kafkaStreamsAssignments;
public static void optimizeRackAwareActiveTasks(final RackAwareOptimizationParams optimizationParams,
final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
final ApplicationState applicationState = optimizationParams.applicationState;
final SortedSet<TaskId> activeTasksToOptimize = getTasksToOptimize(kafkaStreamsAssignments, optimizationParams, AssignedTask.Type.ACTIVE);
if (activeTasksToOptimize.isEmpty()) {
return;
}

if (!canPerformRackAwareOptimization(applicationState, AssignedTask.Type.ACTIVE)) {
return kafkaStreamsAssignments;
if (!canPerformRackAwareOptimization(applicationState, optimizationParams, AssignedTask.Type.ACTIVE)) {
return;
}

initializeAssignmentsForAllClients(applicationState, kafkaStreamsAssignments);

final int crossRackTrafficCost = applicationState.assignmentConfigs().rackAwareTrafficCost().getAsInt();
final int nonOverlapCost = applicationState.assignmentConfigs().rackAwareNonOverlapCost().getAsInt();
final int crossRackTrafficCost =
optimizationParams.trafficCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
.rackAwareTrafficCost()
.getAsInt());
final int nonOverlapCost =
optimizationParams.nonOverlapCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
.rackAwareNonOverlapCost()
.getAsInt());

final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = applicationState.kafkaStreamsStates(false);
final List<TaskId> taskIds = new ArrayList<>(tasks);
final List<TaskId> taskIds = new ArrayList<>(activeTasksToOptimize);

final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId = applicationState.allTasks().values().stream()
.filter(taskInfo -> tasks.contains(taskInfo.id()))
.filter(taskInfo -> activeTasksToOptimize.contains(taskInfo.id()))
.collect(Collectors.toMap(TaskInfo::id, TaskInfo::topicPartitions));

final List<ProcessId> clientIds = new ArrayList<>(kafkaStreamsStates.keySet());
Expand Down Expand Up @@ -259,8 +363,6 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareActiveTask
(assignment, taskId) -> assignment.removeTask(new AssignedTask(taskId, AssignedTask.Type.ACTIVE)),
(assignment, taskId) -> assignment.tasks().containsKey(taskId) && assignment.tasks().get(taskId).type() == AssignedTask.Type.ACTIVE
);

return kafkaStreamsAssignments;
}

/**
Expand All @@ -283,31 +385,34 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareActiveTask
* This method optimizes cross-rack traffic for standby tasks only. For active task optimization,
* use {@link #optimizeRackAwareActiveTasks}.
*
* @param optimizationParams optional configuration parameters to apply
* @param kafkaStreamsAssignments the current assignment of tasks to KafkaStreams clients
* @param applicationState the metadata and other info describing the current application state
*
* @return a map with the KafkaStreamsAssignments updated to minimize cross-rack traffic for standby tasks
*/
public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTasks(final ApplicationState applicationState,
final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
if (!canPerformRackAwareOptimization(applicationState, AssignedTask.Type.STANDBY)) {
return kafkaStreamsAssignments;
public static void optimizeRackAwareStandbyTasks(final RackAwareOptimizationParams optimizationParams,
final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
final ApplicationState applicationState = optimizationParams.applicationState;
final SortedSet<TaskId> standbyTasksToOptimize = getTasksToOptimize(kafkaStreamsAssignments, optimizationParams, AssignedTask.Type.STANDBY);
if (standbyTasksToOptimize.isEmpty()) {
return;
}

if (!canPerformRackAwareOptimization(applicationState, optimizationParams, AssignedTask.Type.STANDBY)) {
return;
}

initializeAssignmentsForAllClients(applicationState, kafkaStreamsAssignments);

final int crossRackTrafficCost = applicationState.assignmentConfigs().rackAwareTrafficCost().getAsInt();
final int nonOverlapCost = applicationState.assignmentConfigs().rackAwareNonOverlapCost().getAsInt();
final int crossRackTrafficCost =
optimizationParams.trafficCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
.rackAwareTrafficCost()
.getAsInt());
final int nonOverlapCost =
optimizationParams.nonOverlapCostOverride.orElseGet(() -> applicationState.assignmentConfigs()
.rackAwareNonOverlapCost()
.getAsInt());

final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = applicationState.kafkaStreamsStates(false);

final List<TaskId> standbyTasksToOptimize = kafkaStreamsAssignments.values().stream()
.flatMap(r -> r.tasks().values().stream())
.filter(task -> task.type() == AssignedTask.Type.STANDBY)
.map(AssignedTask::id)
.distinct()
.collect(Collectors.toList());

final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId =
applicationState.allTasks().values().stream().collect(Collectors.toMap(
TaskInfo::id,
Expand All @@ -317,7 +422,7 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTas
final List<ProcessId> clientIds = new ArrayList<>(kafkaStreamsStates.keySet());
final long initialCost = computeTotalAssignmentCost(
topicPartitionsByTaskId,
standbyTasksToOptimize,
new ArrayList<>(standbyTasksToOptimize),
clientIds,
kafkaStreamsAssignments,
kafkaStreamsStates,
Expand Down Expand Up @@ -411,7 +516,7 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTas
}
final long finalCost = computeTotalAssignmentCost(
topicPartitionsByTaskId,
standbyTasksToOptimize,
new ArrayList<>(standbyTasksToOptimize),
clientIds,
kafkaStreamsAssignments,
kafkaStreamsStates,
Expand All @@ -424,7 +529,6 @@ public static Map<ProcessId, KafkaStreamsAssignment> optimizeRackAwareStandbyTas
final long duration = System.currentTimeMillis() - startTime;
LOG.info("Assignment after {} rounds and {} milliseconds for standby task optimization is {}\n with cost {}",
round, duration, kafkaStreamsAssignments, finalCost);
return kafkaStreamsAssignments;
}

private static long computeTotalAssignmentCost(final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId,
Expand Down Expand Up @@ -541,6 +645,7 @@ private static int getCrossRackTrafficCost(final Set<TaskTopicPartition> topicPa
* is set.
*/
private static boolean canPerformRackAwareOptimization(final ApplicationState applicationState,
final RackAwareOptimizationParams optimizationParams,
final AssignedTask.Type taskType) {
final AssignmentConfigs assignmentConfigs = applicationState.assignmentConfigs();
final String rackAwareAssignmentStrategy = assignmentConfigs.rackAwareAssignmentStrategy();
Expand Down Expand Up @@ -902,6 +1007,20 @@ private static void initializeAssignmentsForAllClients(final ApplicationState ap
}
}

private static SortedSet<TaskId> getTasksToOptimize(final Map<ProcessId, KafkaStreamsAssignment> assignments,
final RackAwareOptimizationParams optimizationParams,
final AssignedTask.Type taskType) {
if (optimizationParams != null && optimizationParams.tasksToOptimize.isPresent()) {
return optimizationParams.tasksToOptimize.get();
}

return assignments.values().stream()
.flatMap(r -> r.tasks().values().stream())
.filter(task -> task.type() == taskType)
.map(AssignedTask::id)
.collect(Collectors.toCollection(TreeSet::new));
}

private static class TagStatistics {
private final Map<String, Set<String>> tagKeyToValues;
private final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToClients;
Expand Down
Loading