From 6f531583344b8fe7a5a059bb9d75083d0d674044 Mon Sep 17 00:00:00 2001 From: "Colin P. Mccabe" Date: Thu, 25 Jan 2018 11:33:23 -0800 Subject: [PATCH] KAFKA-6254: Incremental fetch requests Implement incremental fetch requests as described by KIP-227. --- .../kafka/clients/FetchSessionHandler.java | 443 +++++++++++ .../clients/consumer/internals/Fetcher.java | 88 ++- .../FetchSessionIdNotFoundException.java | 29 + .../InvalidFetchSessionEpochException.java | 29 + .../apache/kafka/common/protocol/Errors.java | 16 + .../kafka/common/protocol/types/Struct.java | 6 + .../kafka/common/requests/FetchMetadata.java | 154 ++++ .../kafka/common/requests/FetchRequest.java | 187 ++++- .../kafka/common/requests/FetchResponse.java | 79 +- .../common/utils/ImplicitLinkedHashSet.java | 354 +++++++++ .../clients/FetchSessionHandlerTest.java | 356 +++++++++ .../clients/consumer/KafkaConsumerTest.java | 3 +- .../consumer/internals/FetcherTest.java | 175 +++-- .../common/requests/RequestResponseTest.java | 70 +- .../utils/ImplicitLinkedHashSetTest.java | 239 ++++++ .../src/main/scala/kafka/api/ApiVersion.scala | 7 +- .../scala/kafka/server/FetchSession.scala | 720 ++++++++++++++++++ .../main/scala/kafka/server/KafkaApis.scala | 144 ++-- .../main/scala/kafka/server/KafkaConfig.scala | 15 + .../main/scala/kafka/server/KafkaServer.scala | 7 +- .../kafka/server/ReplicaFetcherThread.scala | 50 +- .../unit/kafka/server/FetchRequestTest.scala | 55 ++ .../unit/kafka/server/FetchSessionTest.scala | 312 ++++++++ .../unit/kafka/server/KafkaApisTest.scala | 2 + .../util/ReplicaFetcherMockBlockingSend.scala | 7 +- 25 files changed, 3329 insertions(+), 218 deletions(-) create mode 100644 clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java create mode 100644 clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java create mode 100644 clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java create mode 100644 clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java create mode 100644 clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java create mode 100644 clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java create mode 100644 clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java create mode 100644 core/src/main/scala/kafka/server/FetchSession.scala create mode 100755 core/src/test/scala/unit/kafka/server/FetchSessionTest.scala diff --git a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java new file mode 100644 index 0000000000000..195324e833607 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FetchMetadata; +import org.apache.kafka.common.requests.FetchRequest.PartitionData; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; + +/** + * FetchSessionHandler maintains the fetch session state for connecting to a broker. + * + * Using the protocol outlined by KIP-227, clients can create incremental fetch sessions. + * These sessions allow the client to fetch information about a set of partition over + * and over, without explicitly enumerating all the partitions in the request and the + * response. + * + * FetchSessionHandler tracks the partitions which are in the session. It also + * determines which partitions need to be included in each fetch request, and what + * the attached fetch session metadata should be for each request. The corresponding + * class on the receiving broker side is FetchManager. + */ +public class FetchSessionHandler { + private final Logger log; + + private final int node; + + /** + * The metadata for the next fetch request. + */ + private FetchMetadata nextMetadata = FetchMetadata.INITIAL; + + public FetchSessionHandler(LogContext logContext, int node) { + this.log = logContext.logger(FetchSessionHandler.class); + this.node = node; + } + + /** + * All of the partitions which exist in the fetch request session. + */ + private LinkedHashMap sessionPartitions = + new LinkedHashMap<>(0); + + public static class FetchRequestData { + /** + * The partitions to send in the fetch request. + */ + private final Map toSend; + + /** + * The partitions to send in the request's "forget" list. + */ + private final List toForget; + + /** + * All of the partitions which exist in the fetch request session. + */ + private final Map sessionPartitions; + + /** + * The metadata to use in this fetch request. + */ + private final FetchMetadata metadata; + + FetchRequestData(Map toSend, + List toForget, + Map sessionPartitions, + FetchMetadata metadata) { + this.toSend = toSend; + this.toForget = toForget; + this.sessionPartitions = sessionPartitions; + this.metadata = metadata; + } + + /** + * Get the set of partitions to send in this fetch request. + */ + public Map toSend() { + return toSend; + } + + /** + * Get a list of partitions to forget in this fetch request. + */ + public List toForget() { + return toForget; + } + + /** + * Get the full set of partitions involved in this fetch request. 
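+     * In an incremental fetch request this is a superset of toSend(): it also
+     * covers the partitions that are implied by the session but left unchanged
+     * since the previous request.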
+ */ + public Map sessionPartitions() { + return sessionPartitions; + } + + public FetchMetadata metadata() { + return metadata; + } + + @Override + public String toString() { + if (metadata.isFull()) { + StringBuilder bld = new StringBuilder("FullFetchRequest("); + String prefix = ""; + for (TopicPartition partition : toSend.keySet()) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + bld.append(")"); + return bld.toString(); + } else { + StringBuilder bld = new StringBuilder("IncrementalFetchRequest(toSend=("); + String prefix = ""; + for (TopicPartition partition : toSend.keySet()) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + bld.append("), toForget=("); + prefix = ""; + for (TopicPartition partition : toForget) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + bld.append("), implied=("); + prefix = ""; + for (TopicPartition partition : sessionPartitions.keySet()) { + if (!toSend.containsKey(partition)) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + } + bld.append("))"); + return bld.toString(); + } + } + } + + public class Builder { + /** + * The next partitions which we want to fetch. + * + * It is important to maintain the insertion order of this list by using a LinkedHashMap rather + * than a regular Map. + * + * One reason is that when dealing with FULL fetch requests, if there is not enough response + * space to return data from all partitions, the server will only return data from partitions + * early in this list. + * + * Another reason is because we make use of the list ordering to optimize the preparation of + * incremental fetch requests (see below). + */ + private LinkedHashMap next = new LinkedHashMap<>(); + + /** + * Mark that we want data from this partition in the upcoming fetch. + */ + public void add(TopicPartition topicPartition, PartitionData data) { + next.put(topicPartition, data); + } + + public FetchRequestData build() { + if (nextMetadata.isFull()) { + log.debug("Built full fetch {} for node {} with {}.", + nextMetadata, node, partitionsToLogString(next.keySet())); + sessionPartitions = next; + next = null; + Map toSend = + Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions)); + return new FetchRequestData(toSend, Collections.emptyList(), toSend, nextMetadata); + } + + List added = new ArrayList<>(); + List removed = new ArrayList<>(); + List altered = new ArrayList<>(); + for (Iterator> iter = + sessionPartitions.entrySet().iterator(); iter.hasNext(); ) { + Entry entry = iter.next(); + TopicPartition topicPartition = entry.getKey(); + PartitionData prevData = entry.getValue(); + PartitionData nextData = next.get(topicPartition); + if (nextData != null) { + if (prevData.equals(nextData)) { + // Omit this partition from the FetchRequest, because it hasn't changed + // since the previous request. + next.remove(topicPartition); + } else { + // Move the altered partition to the end of 'next' + next.remove(topicPartition); + next.put(topicPartition, nextData); + entry.setValue(nextData); + altered.add(topicPartition); + } + } else { + // Remove this partition from the session. + iter.remove(); + // Indicate that we no longer want to listen to this partition. + removed.add(topicPartition); + } + } + // Add any new partitions to the session. 
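+            // At this point 'next' contains only the partitions that survived the loop
+            // above: the genuinely new ones (in their original insertion order) followed
+            // by any altered ones that were re-inserted at the end of the map.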
+ for (Iterator> iter = + next.entrySet().iterator(); iter.hasNext(); ) { + Entry entry = iter.next(); + TopicPartition topicPartition = entry.getKey(); + PartitionData nextData = entry.getValue(); + if (sessionPartitions.containsKey(topicPartition)) { + // In the previous loop, all the partitions which existed in both sessionPartitions + // and next were moved to the end of next, or removed from next. Therefore, + // once we hit one of them, we know there are no more unseen entries to look + // at in next. + break; + } + sessionPartitions.put(topicPartition, nextData); + added.add(topicPartition); + } + log.debug("Built incremental fetch {} for node {}. Added {}, altered {}, removed {} " + + "out of {}", nextMetadata, node, partitionsToLogString(added), + partitionsToLogString(altered), partitionsToLogString(removed), + partitionsToLogString(sessionPartitions.keySet())); + Map toSend = + Collections.unmodifiableMap(new LinkedHashMap<>(next)); + Map curSessionPartitions = + Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions)); + next = null; + return new FetchRequestData(toSend, Collections.unmodifiableList(removed), + curSessionPartitions, nextMetadata); + } + } + + public Builder newBuilder() { + return new Builder(); + } + + private String partitionsToLogString(Collection partitions) { + if (!log.isTraceEnabled()) { + return String.format("%d partition(s)", partitions.size()); + } + return "(" + Utils.join(partitions, ", ") + ")"; + } + + /** + * Return some partitions which are expected to be in a particular set, but which are not. + * + * @param toFind The partitions to look for. + * @param toSearch The set of partitions to search. + * @return null if all partitions were found; some of the missing ones + * in string form, if not. + */ + static Set findMissing(Set toFind, Set toSearch) { + Set ret = new LinkedHashSet<>(); + for (TopicPartition partition : toFind) { + if (!toSearch.contains(partition)) { + ret.add(partition); + } + } + return ret; + } + + /** + * Verify that a full fetch response contains all the partitions in the fetch session. + * + * @param response The response. + * @return True if the full fetch response partitions are valid. + */ + private String verifyFullFetchResponsePartitions(FetchResponse response) { + StringBuilder bld = new StringBuilder(); + Set omitted = + findMissing(response.responseData().keySet(), sessionPartitions.keySet()); + Set extra = + findMissing(sessionPartitions.keySet(), response.responseData().keySet()); + if (!omitted.isEmpty()) { + bld.append("omitted=(").append(Utils.join(omitted, ", ")).append(", "); + } + if (!extra.isEmpty()) { + bld.append("extra=(").append(Utils.join(extra, ", ")).append(", "); + } + if ((!omitted.isEmpty()) || (!extra.isEmpty())) { + bld.append("response=(").append(Utils.join(response.responseData().keySet(), ", ")); + return bld.toString(); + } + return null; + } + + /** + * Verify that the partitions in an incremental fetch response are contained in the session. + * + * @param response The response. + * @return True if the incremental fetch response partitions are valid. 
+ */ + private String verifyIncrementalFetchResponsePartitions(FetchResponse response) { + Set extra = + findMissing(response.responseData().keySet(), sessionPartitions.keySet()); + if (!extra.isEmpty()) { + StringBuilder bld = new StringBuilder(); + bld.append("extra=(").append(Utils.join(extra, ", ")).append("), "); + bld.append("response=(").append( + Utils.join(response.responseData().keySet(), ", ")).append("), "); + return bld.toString(); + } + return null; + } + + /** + * Create a string describing the partitions in a FetchResponse. + * + * @param response The FetchResponse. + * @return The string to log. + */ + private String responseDataToLogString(FetchResponse response) { + if (!log.isTraceEnabled()) { + int implied = sessionPartitions.size() - response.responseData().size(); + if (implied > 0) { + return String.format(" with %d response partition(s), %d implied partition(s)", + response.responseData().size(), implied); + } else { + return String.format(" with %d response partition(s)", + response.responseData().size()); + } + } + StringBuilder bld = new StringBuilder(); + bld.append(" with response=("). + append(Utils.join(response.responseData().keySet(), ", ")). + append(")"); + String prefix = ", implied=("; + String suffix = ""; + for (TopicPartition partition : sessionPartitions.keySet()) { + if (!response.responseData().containsKey(partition)) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + suffix = ")"; + } + } + bld.append(suffix); + return bld.toString(); + } + + /** + * Handle the fetch response. + * + * @param response The response. + * @return True if the response is well-formed; false if it can't be processed + * because of missing or unexpected partitions. + */ + public boolean handleResponse(FetchResponse response) { + if (response.error() != Errors.NONE) { + log.info("Node {} was unable to process the fetch request with {}: {}.", + node, nextMetadata, response.error()); + if (response.error() == Errors.FETCH_SESSION_ID_NOT_FOUND) { + nextMetadata = FetchMetadata.INITIAL; + } else { + nextMetadata = nextMetadata.nextCloseExisting(); + } + return false; + } else if (nextMetadata.isFull()) { + String problem = verifyFullFetchResponsePartitions(response); + if (problem != null) { + log.info("Node {} sent an invalid full fetch response with {}", node, problem); + nextMetadata = FetchMetadata.INITIAL; + return false; + } else if (response.sessionId() == INVALID_SESSION_ID) { + log.debug("Node {} sent a full fetch response{}", + node, responseDataToLogString(response)); + nextMetadata = FetchMetadata.INITIAL; + return true; + } else { + // The server created a new incremental fetch session. + log.debug("Node {} sent a full fetch response that created a new incremental " + + "fetch session {}{}", node, response.sessionId(), responseDataToLogString(response)); + nextMetadata = FetchMetadata.newIncremental(response.sessionId()); + return true; + } + } else { + String problem = verifyIncrementalFetchResponsePartitions(response); + if (problem != null) { + log.info("Node {} sent an invalid incremental fetch response with {}", node, problem); + nextMetadata = nextMetadata.nextCloseExisting(); + return false; + } else if (response.sessionId() == INVALID_SESSION_ID) { + // The incremental fetch session was closed by the server. 
+ log.debug("Node {} sent an incremental fetch response closing session {}{}", + node, nextMetadata.sessionId(), responseDataToLogString(response)); + nextMetadata = FetchMetadata.INITIAL; + return true; + } else { + // The incremental fetch session was continued by the server. + log.debug("Node {} sent an incremental fetch response for session {}{}", + node, response.sessionId(), responseDataToLogString(response)); + nextMetadata = nextMetadata.nextIncremental(); + return true; + } + } + } + + /** + * Handle an error sending the prepared request. + * + * When a network error occurs, we close any existing fetch session on our next request, + * and try to create a new session. + * + * @param t The exception. + */ + public void handleError(Throwable t) { + log.info("Error sending fetch request {} to node {}: {}.", nextMetadata, node, t.toString()); + nextMetadata = nextMetadata.nextCloseExisting(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java index 6d56139118a28..32782ee53db70 100644 --- a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java @@ -17,6 +17,7 @@ package org.apache.kafka.clients.consumer.internals; import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.FetchSessionHandler; import org.apache.kafka.clients.Metadata; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; @@ -92,6 +93,7 @@ */ public class Fetcher implements SubscriptionState.Listener, Closeable { private final Logger log; + private final LogContext logContext; private final ConsumerNetworkClient client; private final Time time; private final int minBytes; @@ -110,6 +112,7 @@ public class Fetcher implements SubscriptionState.Listener, Closeable { private final ExtendedDeserializer keyDeserializer; private final ExtendedDeserializer valueDeserializer; private final IsolationLevel isolationLevel; + private final Map sessionHandlers; private PartitionRecords nextInLineRecords = null; @@ -131,6 +134,7 @@ public Fetcher(LogContext logContext, long retryBackoffMs, IsolationLevel isolationLevel) { this.log = logContext.logger(Fetcher.class); + this.logContext = logContext; this.time = time; this.client = client; this.metadata = metadata; @@ -147,6 +151,7 @@ public Fetcher(LogContext logContext, this.sensors = new FetchManagerMetrics(metrics, metricsRegistry); this.retryBackoffMs = retryBackoffMs; this.isolationLevel = isolationLevel; + this.sessionHandlers = new HashMap<>(); subscriptions.addListener(this); } @@ -181,36 +186,37 @@ public boolean hasCompletedFetches() { return !completedFetches.isEmpty(); } - private boolean matchesRequestedPartitions(FetchRequest.Builder request, FetchResponse response) { - Set requestedPartitions = request.fetchData().keySet(); - Set fetchedPartitions = response.responseData().keySet(); - return fetchedPartitions.equals(requestedPartitions); - } - /** * Set-up a fetch request for any node that we have assigned partitions for which doesn't already have * an in-flight fetch or pending fetch data. 
* @return number of fetches sent */ public int sendFetches() { - Map fetchRequestMap = createFetchRequests(); - for (Map.Entry fetchEntry : fetchRequestMap.entrySet()) { - final FetchRequest.Builder request = fetchEntry.getValue(); - final Node fetchTarget = fetchEntry.getKey(); - - log.debug("Sending {} fetch for partitions {} to broker {}", isolationLevel, request.fetchData().keySet(), - fetchTarget); + Map fetchRequestMap = prepareFetchRequests(); + for (Map.Entry entry : fetchRequestMap.entrySet()) { + final Node fetchTarget = entry.getKey(); + final FetchSessionHandler.FetchRequestData data = entry.getValue(); + final FetchRequest.Builder request = FetchRequest.Builder. + forConsumer(this.maxWaitMs, this.minBytes, data.toSend()) + .isolationLevel(isolationLevel) + .setMaxBytes(this.maxBytes) + .metadata(data.metadata()) + .toForget(data.toForget()); + if (log.isDebugEnabled()) { + log.debug("Sending {} {} to broker {}", isolationLevel, data.toString(), fetchTarget); + } client.send(fetchTarget, request) .addListener(new RequestFutureListener() { @Override public void onSuccess(ClientResponse resp) { FetchResponse response = (FetchResponse) resp.responseBody(); - if (!matchesRequestedPartitions(request, response)) { - // obviously we expect the broker to always send us valid responses, so this check - // is mainly for test cases where mock fetch responses must be manually crafted. - log.warn("Ignoring fetch response containing partitions {} since it does not match " + - "the requested partitions {}", response.responseData().keySet(), - request.fetchData().keySet()); + FetchSessionHandler handler = sessionHandlers.get(fetchTarget.id()); + if (handler == null) { + log.error("Unable to find FetchSessionHandler for node {}. Ignoring fetch response.", + fetchTarget.id()); + return; + } + if (!handler.handleResponse(response)) { return; } @@ -219,7 +225,7 @@ public void onSuccess(ClientResponse resp) { for (Map.Entry entry : response.responseData().entrySet()) { TopicPartition partition = entry.getKey(); - long fetchOffset = request.fetchData().get(partition).fetchOffset; + long fetchOffset = data.sessionPartitions().get(partition).fetchOffset; FetchResponse.PartitionData fetchData = entry.getValue(); log.debug("Fetch {} at offset {} for partition {} returned fetch data {}", @@ -233,7 +239,10 @@ public void onSuccess(ClientResponse resp) { @Override public void onFailure(RuntimeException e) { - log.debug("Fetch request {} to {} failed", request.fetchData(), fetchTarget, e); + FetchSessionHandler handler = sessionHandlers.get(fetchTarget.id()); + if (handler != null) { + handler.handleError(e); + } } }); } @@ -772,42 +781,41 @@ private List fetchablePartitions() { * Create fetch requests for all nodes for which we have assigned partitions * that have no existing requests in flight. 
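+     * One FetchSessionHandler.Builder is kept per eligible node; the FetchRequestData
+     * it builds carries the partitions to send, the partitions to forget, and the fetch
+     * session metadata for that node's request.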
*/ - private Map createFetchRequests() { + private Map prepareFetchRequests() { // create the fetch info Cluster cluster = metadata.fetch(); - Map> fetchable = new LinkedHashMap<>(); + Map fetchable = new LinkedHashMap<>(); for (TopicPartition partition : fetchablePartitions()) { Node node = cluster.leaderFor(partition); if (node == null) { metadata.requestUpdate(); } else if (!this.client.hasPendingRequests(node)) { // if there is a leader and no in-flight requests, issue a new fetch - LinkedHashMap fetch = fetchable.get(node); - if (fetch == null) { - fetch = new LinkedHashMap<>(); - fetchable.put(node, fetch); + FetchSessionHandler.Builder builder = fetchable.get(node); + if (builder == null) { + FetchSessionHandler handler = sessionHandlers.get(node.id()); + if (handler == null) { + handler = new FetchSessionHandler(logContext, node.id()); + sessionHandlers.put(node.id(), handler); + } + builder = handler.newBuilder(); + fetchable.put(node, builder); } long position = this.subscriptions.position(partition); - fetch.put(partition, new FetchRequest.PartitionData(position, FetchRequest.INVALID_LOG_START_OFFSET, - this.fetchSize)); + builder.add(partition, new FetchRequest.PartitionData(position, FetchRequest.INVALID_LOG_START_OFFSET, + this.fetchSize)); log.debug("Added {} fetch request for partition {} at offset {} to node {}", isolationLevel, - partition, position, node); + partition, position, node); } else { log.trace("Skipping fetch for partition {} because there is an in-flight request to {}", partition, node); } } - - // create the fetches - Map requests = new HashMap<>(); - for (Map.Entry> entry : fetchable.entrySet()) { - Node node = entry.getKey(); - FetchRequest.Builder fetch = FetchRequest.Builder.forConsumer(this.maxWaitMs, this.minBytes, - entry.getValue(), isolationLevel) - .setMaxBytes(this.maxBytes); - requests.put(node, fetch); + Map reqs = new LinkedHashMap<>(); + for (Map.Entry entry : fetchable.entrySet()) { + reqs.put(entry.getKey(), entry.getValue().build()); } - return requests; + return reqs; } /** diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java new file mode 100644 index 0000000000000..2ce5f740d6719 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.errors; + +public class FetchSessionIdNotFoundException extends RetriableException { + private static final long serialVersionUID = 1L; + + public FetchSessionIdNotFoundException() { + } + + public FetchSessionIdNotFoundException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java new file mode 100644 index 0000000000000..3b135c0147d06 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.errors; + +public class InvalidFetchSessionEpochException extends RetriableException { + private static final long serialVersionUID = 1L; + + public InvalidFetchSessionEpochException() { + } + + public InvalidFetchSessionEpochException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java index e2b8aeaf709fe..4b44c18430fc9 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java @@ -30,6 +30,7 @@ import org.apache.kafka.common.errors.DelegationTokenExpiredException; import org.apache.kafka.common.errors.DelegationTokenNotFoundException; import org.apache.kafka.common.errors.DelegationTokenOwnerMismatchException; +import org.apache.kafka.common.errors.FetchSessionIdNotFoundException; import org.apache.kafka.common.errors.GroupAuthorizationException; import org.apache.kafka.common.errors.GroupIdNotFoundException; import org.apache.kafka.common.errors.GroupNotEmptyException; @@ -38,6 +39,7 @@ import org.apache.kafka.common.errors.InconsistentGroupProtocolException; import org.apache.kafka.common.errors.InvalidCommitOffsetSizeException; import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.errors.InvalidFetchSessionEpochException; import org.apache.kafka.common.errors.InvalidFetchSizeException; import org.apache.kafka.common.errors.InvalidGroupIdException; import org.apache.kafka.common.errors.InvalidPartitionsException; @@ -608,6 +610,20 @@ public ApiException build(String message) { public ApiException build(String message) { return new GroupIdNotFoundException(message); } + }), + FETCH_SESSION_ID_NOT_FOUND(70, "The fetch session ID was not found", + new ApiExceptionBuilder() { + @Override + public ApiException build(String message) { + return new FetchSessionIdNotFoundException(message); + } + }), + 
INVALID_FETCH_SESSION_EPOCH(71, "The fetch session epoch is invalid", + new ApiExceptionBuilder() { + @Override + public ApiException build(String message) { + return new InvalidFetchSessionEpochException(message); + } }); private interface ApiExceptionBuilder { diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java index 6fb6b20ca311d..ac24a1b69b2b3 100644 --- a/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java @@ -105,6 +105,12 @@ public Long getOrElse(Field.Int64 field, long alternative) { return alternative; } + public Short getOrElse(Field.Int16 field, short alternative) { + if (hasField(field.name)) + return getShort(field.name); + return alternative; + } + public Integer getOrElse(Field.Int32 field, int alternative) { if (hasField(field.name)) return getInt(field.name); diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java new file mode 100644 index 0000000000000..feb6953f9dafd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; + +public class FetchMetadata { + public static final Logger log = LoggerFactory.getLogger(FetchMetadata.class); + + /** + * The session ID used by clients with no session. + */ + public static final int INVALID_SESSION_ID = 0; + + /** + * The first epoch. When used in a fetch request, indicates that the client + * wants to create or recreate a session. + */ + public static final int INITIAL_EPOCH = 0; + + /** + * An invalid epoch. When used in a fetch request, indicates that the client + * wants to close any existing session, and not create a new one. + */ + public static final int FINAL_EPOCH = -1; + + /** + * The FetchMetadata that is used when initializing a new FetchSessionHandler. + */ + public static final FetchMetadata INITIAL = new FetchMetadata(INVALID_SESSION_ID, INITIAL_EPOCH); + + /** + * The FetchMetadata that is implicitly used for handling older FetchRequests that + * don't include fetch metadata. + */ + public static final FetchMetadata LEGACY = new FetchMetadata(INVALID_SESSION_ID, FINAL_EPOCH); + + /** + * Returns the next epoch. + * + * @param prevEpoch The previous epoch. + * @return The next epoch. + */ + public static int nextEpoch(int prevEpoch) { + if (prevEpoch < 0) { + // The next epoch after FINAL_EPOCH is always FINAL_EPOCH itself. 
+ return FINAL_EPOCH; + } else if (prevEpoch == Integer.MAX_VALUE) { + return 1; + } else { + return prevEpoch + 1; + } + } + + /** + * The fetch session ID. + */ + private final int sessionId; + + /** + * The fetch session epoch. + */ + private final int epoch; + + public FetchMetadata(int sessionId, int epoch) { + this.sessionId = sessionId; + this.epoch = epoch; + } + + /** + * Returns true if this is a full fetch request. + */ + public boolean isFull() { + return (this.epoch == INITIAL_EPOCH) || (this.epoch == FINAL_EPOCH); + } + + public int sessionId() { + return sessionId; + } + + public int epoch() { + return epoch; + } + + @Override + public int hashCode() { + return Objects.hash(sessionId, epoch); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FetchMetadata that = (FetchMetadata) o; + return sessionId == that.sessionId && epoch == that.epoch; + } + + /** + * Return the metadata for the next error response. + */ + public FetchMetadata nextCloseExisting() { + return new FetchMetadata(sessionId, INITIAL_EPOCH); + } + + /** + * Return the metadata for the next full fetch request. + */ + public static FetchMetadata newIncremental(int sessionId) { + return new FetchMetadata(sessionId, nextEpoch(INITIAL_EPOCH)); + } + + /** + * Return the metadata for the next incremental response. + */ + public FetchMetadata nextIncremental() { + return new FetchMetadata(sessionId, nextEpoch(epoch)); + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + if (sessionId == INVALID_SESSION_ID) { + bld.append("(sessionId=INVALID, "); + } else { + bld.append("(sessionId=").append(sessionId).append(", "); + } + if (epoch == INITIAL_EPOCH) { + bld.append("epoch=INITIAL)"); + } else if (epoch == FINAL_EPOCH) { + bld.append("epoch=FINAL)"); + } else { + bld.append("epoch=").append(epoch).append(")"); + } + return bld.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java index 18425d0360ba8..65cf7fea6b1d7 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java @@ -23,19 +23,27 @@ import org.apache.kafka.common.protocol.types.Field; import org.apache.kafka.common.protocol.types.Schema; import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.utils.Utils; import java.nio.ByteBuffer; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import static org.apache.kafka.common.protocol.CommonFields.PARTITION_ID; import static org.apache.kafka.common.protocol.CommonFields.TOPIC_NAME; import static org.apache.kafka.common.protocol.types.Type.INT32; import static org.apache.kafka.common.protocol.types.Type.INT64; import static org.apache.kafka.common.protocol.types.Type.INT8; +import static org.apache.kafka.common.requests.FetchMetadata.FINAL_EPOCH; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; public class FetchRequest extends AbstractRequest { public static final int CONSUMER_REPLICA_ID = -1; @@ -44,6 +52,7 @@ public class 
FetchRequest extends AbstractRequest { private static final String MIN_BYTES_KEY_NAME = "min_bytes"; private static final String ISOLATION_LEVEL_KEY_NAME = "isolation_level"; private static final String TOPICS_KEY_NAME = "topics"; + private static final String FORGOTTEN_TOPICS_DATA = "forgetten_topics_data"; // request and partition level name private static final String MAX_BYTES_KEY_NAME = "max_bytes"; @@ -139,9 +148,36 @@ public class FetchRequest extends AbstractRequest { */ private static final Schema FETCH_REQUEST_V6 = FETCH_REQUEST_V5; + // FETCH_REQUEST_V7 added incremental fetch requests. + public static final Field.Int32 SESSION_ID = new Field.Int32("session_id", "The fetch session ID"); + public static final Field.Int32 EPOCH = new Field.Int32("epoch", "The fetch epoch"); + + private static final Schema FORGOTTEN_TOPIC_DATA = new Schema( + TOPIC_NAME, + new Field(PARTITIONS_KEY_NAME, new ArrayOf(Type.INT32), + "Partitions to remove from the fetch session.")); + + private static final Schema FETCH_REQUEST_V7 = new Schema( + new Field(REPLICA_ID_KEY_NAME, INT32, "Broker id of the follower. For normal consumers, use -1."), + new Field(MAX_WAIT_KEY_NAME, INT32, "Maximum time in ms to wait for the response."), + new Field(MIN_BYTES_KEY_NAME, INT32, "Minimum bytes to accumulate in the response."), + new Field(MAX_BYTES_KEY_NAME, INT32, "Maximum bytes to accumulate in the response. Note that this is not an absolute maximum, " + + "if the first message in the first non-empty partition of the fetch is larger than this " + + "value, the message will still be returned to ensure that progress can be made."), + new Field(ISOLATION_LEVEL_KEY_NAME, INT8, "This setting controls the visibility of transactional records. Using READ_UNCOMMITTED " + + "(isolation_level = 0) makes all records visible. With READ_COMMITTED (isolation_level = 1), " + + "non-transactional and COMMITTED transactional records are visible. To be more concrete, " + + "READ_COMMITTED returns all data from offsets smaller than the current LSO (last stable offset), " + + "and enables the inclusion of the list of aborted transactions in the result, which allows " + + "consumers to discard ABORTED transactional records"), + SESSION_ID, + EPOCH, + new Field(TOPICS_KEY_NAME, new ArrayOf(FETCH_REQUEST_TOPIC_V5), "Topics to fetch in the order provided."), + new Field(FORGOTTEN_TOPICS_DATA, new ArrayOf(FORGOTTEN_TOPIC_DATA), "Topics to remove from the fetch session.")); + public static Schema[] schemaVersions() { return new Schema[]{FETCH_REQUEST_V0, FETCH_REQUEST_V1, FETCH_REQUEST_V2, FETCH_REQUEST_V3, FETCH_REQUEST_V4, - FETCH_REQUEST_V5, FETCH_REQUEST_V6}; + FETCH_REQUEST_V5, FETCH_REQUEST_V6, FETCH_REQUEST_V7}; }; // default values for older versions where a request level limit did not exist @@ -153,7 +189,14 @@ public static Schema[] schemaVersions() { private final int minBytes; private final int maxBytes; private final IsolationLevel isolationLevel; - private final LinkedHashMap fetchData; + + // Note: the iteration order of this map is significant, since it determines the order + // in which partitions appear in the message. For this reason, this map should have a + // deterministic iteration order, like LinkedHashMap or TreeMap (but unlike HashMap). 
+ private final Map fetchData; + + private final List toForget; + private final FetchMetadata metadata; public static final class PartitionData { public final long fetchOffset; @@ -170,6 +213,21 @@ public PartitionData(long fetchOffset, long logStartOffset, int maxBytes) { public String toString() { return "(offset=" + fetchOffset + ", logStartOffset=" + logStartOffset + ", maxBytes=" + maxBytes + ")"; } + + @Override + public int hashCode() { + return Objects.hash(fetchOffset, logStartOffset, maxBytes); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PartitionData that = (PartitionData) o; + return Objects.equals(fetchOffset, that.fetchOffset) && + Objects.equals(logStartOffset, that.logStartOffset) && + Objects.equals(maxBytes, that.maxBytes); + } } static final class TopicAndPartitionData { @@ -181,9 +239,10 @@ public TopicAndPartitionData(String topic) { this.partitions = new LinkedHashMap<>(); } - public static List> batchByTopic(LinkedHashMap data) { + public static List> batchByTopic(Iterator> iter) { List> topics = new ArrayList<>(); - for (Map.Entry topicEntry : data.entrySet()) { + while (iter.hasNext()) { + Map.Entry topicEntry = iter.next(); String topic = topicEntry.getKey().topic(); int partition = topicEntry.getKey().partition(); T partitionData = topicEntry.getValue(); @@ -199,37 +258,42 @@ public static class Builder extends AbstractRequest.Builder { private final int maxWait; private final int minBytes; private final int replicaId; - private final LinkedHashMap fetchData; - private final IsolationLevel isolationLevel; + private final Map fetchData; + private IsolationLevel isolationLevel = IsolationLevel.READ_UNCOMMITTED; private int maxBytes = DEFAULT_RESPONSE_MAX_BYTES; + private FetchMetadata metadata = FetchMetadata.LEGACY; + private List toForget = Collections.emptyList(); - public static Builder forConsumer(int maxWait, int minBytes, LinkedHashMap fetchData) { - return forConsumer(maxWait, minBytes, fetchData, IsolationLevel.READ_UNCOMMITTED); - } - - public static Builder forConsumer(int maxWait, int minBytes, LinkedHashMap fetchData, - IsolationLevel isolationLevel) { - return new Builder(ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(), CONSUMER_REPLICA_ID, - maxWait, minBytes, fetchData, isolationLevel); + public static Builder forConsumer(int maxWait, int minBytes, Map fetchData) { + return new Builder(ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(), + CONSUMER_REPLICA_ID, maxWait, minBytes, fetchData); } public static Builder forReplica(short allowedVersion, int replicaId, int maxWait, int minBytes, - LinkedHashMap fetchData) { - return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData, - IsolationLevel.READ_UNCOMMITTED); + Map fetchData) { + return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData); } - private Builder(short minVersion, short maxVersion, int replicaId, int maxWait, int minBytes, - LinkedHashMap fetchData, IsolationLevel isolationLevel) { + public Builder(short minVersion, short maxVersion, int replicaId, int maxWait, int minBytes, + Map fetchData) { super(ApiKeys.FETCH, minVersion, maxVersion); this.replicaId = replicaId; this.maxWait = maxWait; this.minBytes = minBytes; this.fetchData = fetchData; + } + + public Builder isolationLevel(IsolationLevel isolationLevel) { this.isolationLevel = isolationLevel; + return this; + } + + public Builder 
metadata(FetchMetadata metadata) { + this.metadata = metadata; + return this; } - public LinkedHashMap fetchData() { + public Map fetchData() { return this.fetchData; } @@ -238,13 +302,23 @@ public Builder setMaxBytes(int maxBytes) { return this; } + public List toForget() { + return toForget; + } + + public Builder toForget(List toForget) { + this.toForget = toForget; + return this; + } + @Override public FetchRequest build(short version) { if (version < 3) { maxBytes = DEFAULT_RESPONSE_MAX_BYTES; } - return new FetchRequest(version, replicaId, maxWait, minBytes, maxBytes, fetchData, isolationLevel); + return new FetchRequest(version, replicaId, maxWait, minBytes, maxBytes, fetchData, + isolationLevel, toForget, metadata); } @Override @@ -257,13 +331,16 @@ public String toString() { append(", maxBytes=").append(maxBytes). append(", fetchData=").append(fetchData). append(", isolationLevel=").append(isolationLevel). + append(", toForget=").append(Utils.join(toForget, ", ")). + append(", metadata=").append(metadata). append(")"); return bld.toString(); } } private FetchRequest(short version, int replicaId, int maxWait, int minBytes, int maxBytes, - LinkedHashMap fetchData, IsolationLevel isolationLevel) { + Map fetchData, IsolationLevel isolationLevel, + List toForget, FetchMetadata metadata) { super(version); this.replicaId = replicaId; this.maxWait = maxWait; @@ -271,6 +348,8 @@ private FetchRequest(short version, int replicaId, int maxWait, int minBytes, in this.maxBytes = maxBytes; this.fetchData = fetchData; this.isolationLevel = isolationLevel; + this.toForget = toForget; + this.metadata = metadata; } public FetchRequest(Struct struct, short version) { @@ -282,11 +361,23 @@ public FetchRequest(Struct struct, short version) { maxBytes = struct.getInt(MAX_BYTES_KEY_NAME); else maxBytes = DEFAULT_RESPONSE_MAX_BYTES; - if (struct.hasField(ISOLATION_LEVEL_KEY_NAME)) isolationLevel = IsolationLevel.forId(struct.getByte(ISOLATION_LEVEL_KEY_NAME)); else isolationLevel = IsolationLevel.READ_UNCOMMITTED; + toForget = new ArrayList<>(0); + if (struct.hasField(FORGOTTEN_TOPICS_DATA)) { + for (Object forgottenTopicObj : struct.getArray(FORGOTTEN_TOPICS_DATA)) { + Struct forgottenTopic = (Struct) forgottenTopicObj; + String topicName = forgottenTopic.get(TOPIC_NAME); + for (Object partObj : forgottenTopic.getArray(PARTITIONS_KEY_NAME)) { + Integer part = (Integer) partObj; + toForget.add(new TopicPartition(topicName, part)); + } + } + } + metadata = new FetchMetadata(struct.getOrElse(SESSION_ID, INVALID_SESSION_ID), + struct.getOrElse(EPOCH, FINAL_EPOCH)); fetchData = new LinkedHashMap<>(); for (Object topicResponseObj : struct.getArray(TOPICS_KEY_NAME)) { @@ -307,15 +398,21 @@ public FetchRequest(Struct struct, short version) { @Override public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + // The error is indicated in two ways: by setting the same error code in all partitions, and by + // setting the top-level error code. The form where we set the same error code in all partitions + // is needed in order to maintain backwards compatibility with older versions of the protocol + // in which there was no top-level error code. Note that for incremental fetch responses, there + // may not be any partitions at all in the response. For this reason, the top-level error code + // is essential for them. 
+ Errors error = Errors.forException(e); LinkedHashMap responseData = new LinkedHashMap<>(); - - for (Map.Entry entry: fetchData.entrySet()) { - FetchResponse.PartitionData partitionResponse = new FetchResponse.PartitionData(Errors.forException(e), - FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, - null, MemoryRecords.EMPTY); + for (Map.Entry entry : fetchData.entrySet()) { + FetchResponse.PartitionData partitionResponse = new FetchResponse.PartitionData(error, + FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, + FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY); responseData.put(entry.getKey(), partitionResponse); } - return new FetchResponse(responseData, throttleTimeMs); + return new FetchResponse(error, responseData, throttleTimeMs, metadata.sessionId()); } public int replicaId() { @@ -338,6 +435,10 @@ public Map fetchData() { return fetchData; } + public List toForget() { + return toForget; + } + public boolean isFromFollower() { return replicaId >= 0; } @@ -346,6 +447,10 @@ public IsolationLevel isolationLevel() { return isolationLevel; } + public FetchMetadata metadata() { + return metadata; + } + public static FetchRequest parse(ByteBuffer buffer, short version) { return new FetchRequest(ApiKeys.FETCH.parseRequest(version, buffer), version); } @@ -353,7 +458,8 @@ public static FetchRequest parse(ByteBuffer buffer, short version) { @Override protected Struct toStruct() { Struct struct = new Struct(ApiKeys.FETCH.requestSchema(version())); - List> topicsData = TopicAndPartitionData.batchByTopic(fetchData); + List> topicsData = + TopicAndPartitionData.batchByTopic(fetchData.entrySet().iterator()); struct.set(REPLICA_ID_KEY_NAME, replicaId); struct.set(MAX_WAIT_KEY_NAME, maxWait); @@ -362,6 +468,8 @@ protected Struct toStruct() { struct.set(MAX_BYTES_KEY_NAME, maxBytes); if (struct.hasField(ISOLATION_LEVEL_KEY_NAME)) struct.set(ISOLATION_LEVEL_KEY_NAME, isolationLevel.id()); + struct.setIfExists(SESSION_ID, metadata.sessionId()); + struct.setIfExists(EPOCH, metadata.epoch()); List topicArray = new ArrayList<>(); for (TopicAndPartitionData topicEntry : topicsData) { @@ -382,6 +490,25 @@ protected Struct toStruct() { topicArray.add(topicData); } struct.set(TOPICS_KEY_NAME, topicArray.toArray()); + if (struct.hasField(FORGOTTEN_TOPICS_DATA)) { + Map> topicsToPartitions = new HashMap<>(); + for (TopicPartition part : toForget) { + List partitions = topicsToPartitions.get(part.topic()); + if (partitions == null) { + partitions = new ArrayList<>(); + topicsToPartitions.put(part.topic(), partitions); + } + partitions.add(part.partition()); + } + List toForgetStructs = new ArrayList<>(); + for (Map.Entry> entry : topicsToPartitions.entrySet()) { + Struct toForgetStruct = struct.instance(FORGOTTEN_TOPICS_DATA); + toForgetStruct.set(TOPIC_NAME, entry.getKey()); + toForgetStruct.set(PARTITIONS_KEY_NAME, entry.getValue().toArray()); + toForgetStructs.add(toForgetStruct); + } + struct.set(FORGOTTEN_TOPICS_DATA, toForgetStructs.toArray()); + } return struct; } } diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java index 0d09027802ab8..98c6be333e15b 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java @@ -31,6 +31,7 @@ import java.nio.ByteBuffer; import 
java.util.ArrayList; import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -42,6 +43,7 @@ import static org.apache.kafka.common.protocol.types.Type.INT64; import static org.apache.kafka.common.protocol.types.Type.RECORDS; import static org.apache.kafka.common.protocol.types.Type.STRING; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; /** * This wrapper supports all versions of the Fetch API @@ -148,9 +150,19 @@ public class FetchResponse extends AbstractResponse { */ private static final Schema FETCH_RESPONSE_V6 = FETCH_RESPONSE_V5; + // FETCH_REESPONSE_V7 added incremental fetch responses and a top-level error code. + public static final Field.Int32 SESSION_ID = new Field.Int32("session_id", "The fetch session ID"); + + private static final Schema FETCH_RESPONSE_V7 = new Schema( + THROTTLE_TIME_MS, + ERROR_CODE, + SESSION_ID, + new Field(RESPONSES_KEY_NAME, new ArrayOf(FETCH_RESPONSE_TOPIC_V5))); + public static Schema[] schemaVersions() { return new Schema[] {FETCH_RESPONSE_V0, FETCH_RESPONSE_V1, FETCH_RESPONSE_V2, - FETCH_RESPONSE_V3, FETCH_RESPONSE_V4, FETCH_RESPONSE_V5, FETCH_RESPONSE_V6}; + FETCH_RESPONSE_V3, FETCH_RESPONSE_V4, FETCH_RESPONSE_V5, FETCH_RESPONSE_V6, + FETCH_RESPONSE_V7}; } @@ -168,8 +180,10 @@ public static Schema[] schemaVersions() { * UNKNOWN (-1) */ - private final LinkedHashMap responseData; private final int throttleTimeMs; + private final Errors error; + private final int sessionId; + private final LinkedHashMap responseData; public static final class AbortedTransaction { public final long producerId; @@ -268,17 +282,20 @@ public String toString() { } /** - * Constructor for all versions. - * * From version 3 or later, the entries in `responseData` should be in the same order as the entries in * `FetchRequest.fetchData`. * - * @param responseData fetched data grouped by topic-partition - * @param throttleTimeMs Time in milliseconds the response was throttled + * @param error The top-level error code. + * @param responseData The fetched data grouped by partition. + * @param throttleTimeMs The time in milliseconds that the response was throttled + * @param sessionId The fetch session id. 
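+     *                  (FetchMetadata.INVALID_SESSION_ID when the response is not part of a fetch session.)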
*/ - public FetchResponse(LinkedHashMap responseData, int throttleTimeMs) { + public FetchResponse(Errors error, LinkedHashMap responseData, + int throttleTimeMs, int sessionId) { + this.error = error; this.responseData = responseData; this.throttleTimeMs = throttleTimeMs; + this.sessionId = sessionId; } public FetchResponse(Struct struct) { @@ -316,17 +333,19 @@ public FetchResponse(Struct struct) { } PartitionData partitionData = new PartitionData(error, highWatermark, lastStableOffset, logStartOffset, - abortedTransactions, records); + abortedTransactions, records); responseData.put(new TopicPartition(topic, partition), partitionData); } } this.responseData = responseData; this.throttleTimeMs = struct.getOrElse(THROTTLE_TIME_MS, DEFAULT_THROTTLE_TIME); + this.error = Errors.forCode(struct.getOrElse(ERROR_CODE, (short) 0)); + this.sessionId = struct.getOrElse(SESSION_ID, INVALID_SESSION_ID); } @Override public Struct toStruct(short version) { - return toStruct(version, responseData, throttleTimeMs); + return toStruct(version, throttleTimeMs, error, responseData.entrySet().iterator(), sessionId); } @Override @@ -346,6 +365,10 @@ protected Send toSend(String dest, ResponseHeader responseHeader, short apiVersi return new MultiSend(dest, sends); } + public Errors error() { + return error; + } + public LinkedHashMap responseData() { return responseData; } @@ -354,6 +377,10 @@ public int throttleTimeMs() { return this.throttleTimeMs; } + public int sessionId() { + return sessionId; + } + @Override public Map errorCounts() { Map errorCounts = new HashMap<>(); @@ -369,7 +396,15 @@ public static FetchResponse parse(ByteBuffer buffer, short version) { private static void addResponseData(Struct struct, int throttleTimeMs, String dest, List sends) { Object[] allTopicData = struct.getArray(RESPONSES_KEY_NAME); - if (struct.hasField(THROTTLE_TIME_MS)) { + if (struct.hasField(ERROR_CODE)) { + ByteBuffer buffer = ByteBuffer.allocate(14); + buffer.putInt(throttleTimeMs); + buffer.putShort(struct.get(ERROR_CODE)); + buffer.putInt(struct.get(SESSION_ID)); + buffer.putInt(allTopicData.length); + buffer.rewind(); + sends.add(new ByteBufferSend(dest, buffer)); + } else if (struct.hasField(THROTTLE_TIME_MS)) { ByteBuffer buffer = ByteBuffer.allocate(8); buffer.putInt(throttleTimeMs); buffer.putInt(allTopicData.length); @@ -416,9 +451,14 @@ private static void addPartitionData(String dest, List sends, Struct parti sends.add(new RecordsSend(dest, records)); } - private static Struct toStruct(short version, LinkedHashMap responseData, int throttleTimeMs) { + private static Struct toStruct(short version, int throttleTimeMs, Errors error, + Iterator> partIterator, int sessionId) { Struct struct = new Struct(ApiKeys.FETCH.responseSchema(version)); - List> topicsData = FetchRequest.TopicAndPartitionData.batchByTopic(responseData); + struct.setIfExists(THROTTLE_TIME_MS, throttleTimeMs); + struct.setIfExists(ERROR_CODE, error.code()); + struct.setIfExists(SESSION_ID, sessionId); + List> topicsData = + FetchRequest.TopicAndPartitionData.batchByTopic(partIterator); List topicArray = new ArrayList<>(); for (FetchRequest.TopicAndPartitionData topicEntry: topicsData) { Struct topicData = struct.instance(RESPONSES_KEY_NAME); @@ -466,13 +506,20 @@ private static Struct toStruct(short version, LinkedHashMap responseData) { - return 4 + toStruct(version, responseData, 0).sizeOf(); + /** + * Convenience method to find the size of a response. + * + * @param version The version of the response to use. 
+ * @param partIterator The partition iterator. + * @return The response size in bytes. + */ + public static int sizeOf(short version, Iterator> partIterator) { + // Since the throttleTimeMs and metadata field sizes are constant and fixed, we can + // use arbitrary values here without affecting the result. + return 4 + toStruct(version, 0, Errors.NONE, partIterator, INVALID_SESSION_ID).sizeOf(); } } diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java new file mode 100644 index 0000000000000..701684dd0b451 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashSet.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.util.AbstractSet; +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * A LinkedHashSet which is more memory-efficient than the standard implementation. + * + * This set preserves the order of insertion. The order of iteration will always be + * the order of insertion. + * + * This collection requires previous and next indexes to be embedded into each + * element. Using array indices rather than pointers saves space on large heaps + * where pointer compression is not in use. It also reduces the amount of time + * the garbage collector has to spend chasing pointers. + * + * This class uses linear probing. Unlike HashMap (but like HashTable), we don't force + * the size to be a power of 2. This saves memory. + * + * This class does not have internal synchronization. 
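+ *
+ * A sketch (illustrative only, not part of this class) of an element type that embeds
+ * the required prev/next indexes:
+ *
+ *   static class ExampleElement implements Element {
+ *       private final int key;
+ *       private int prev = INVALID_INDEX;
+ *       private int next = INVALID_INDEX;
+ *
+ *       ExampleElement(int key) { this.key = key; }
+ *       public int prev() { return prev; }
+ *       public void setPrev(int prev) { this.prev = prev; }
+ *       public int next() { return next; }
+ *       public void setNext(int next) { this.next = next; }
+ *       public boolean equals(Object o) { return (o instanceof ExampleElement) && ((ExampleElement) o).key == key; }
+ *       public int hashCode() { return key; }
+ *   }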
+ */ +@SuppressWarnings("unchecked") +public class ImplicitLinkedHashSet extends AbstractSet { + public interface Element { + int prev(); + void setPrev(int e); + int next(); + void setNext(int e); + } + + private static final int HEAD_INDEX = -1; + + public static final int INVALID_INDEX = -2; + + private static class HeadElement implements Element { + private int prev = HEAD_INDEX; + private int next = HEAD_INDEX; + + @Override + public int prev() { + return prev; + } + + @Override + public void setPrev(int prev) { + this.prev = prev; + } + + @Override + public int next() { + return next; + } + + @Override + public void setNext(int next) { + this.next = next; + } + } + + private static Element indexToElement(Element head, Element[] elements, int index) { + if (index == HEAD_INDEX) { + return head; + } + return elements[index]; + } + + private static void addToListTail(Element head, Element[] elements, int elementIdx) { + int oldTailIdx = head.prev(); + Element element = indexToElement(head, elements, elementIdx); + Element oldTail = indexToElement(head, elements, oldTailIdx); + head.setPrev(elementIdx); + oldTail.setNext(elementIdx); + element.setPrev(oldTailIdx); + element.setNext(HEAD_INDEX); + } + + private static void removeFromList(Element head, Element[] elements, int elementIdx) { + Element element = indexToElement(head, elements, elementIdx); + elements[elementIdx] = null; + int prevIdx = element.prev(); + int nextIdx = element.next(); + Element prev = indexToElement(head, elements, prevIdx); + Element next = indexToElement(head, elements, nextIdx); + prev.setNext(nextIdx); + next.setPrev(prevIdx); + element.setNext(INVALID_INDEX); + element.setPrev(INVALID_INDEX); + } + + private class ImplicitLinkedHashSetIterator implements Iterator { + private Element cur = head; + + private Element next = indexToElement(head, elements, head.next()); + + @Override + public boolean hasNext() { + return next != head; + } + + @Override + public E next() { + if (next == head) { + throw new NoSuchElementException(); + } + cur = next; + next = indexToElement(head, elements, cur.next()); + return (E) cur; + } + + @Override + public void remove() { + if (cur == head) { + throw new IllegalStateException(); + } + ImplicitLinkedHashSet.this.remove(cur); + cur = head; + } + } + + private Element head; + + private Element[] elements; + + private int size; + + @Override + public Iterator iterator() { + return new ImplicitLinkedHashSetIterator(); + } + + private static int slot(Element[] curElements, Element e) { + return (e.hashCode() & 0x7fffffff) % curElements.length; + } + + /** + * Find an element matching an example element. + * + * Using the element's hash code, we can look up the slot where it belongs. + * However, it may not have ended up in exactly this slot, due to a collision. + * Therefore, we must search forward in the array until we hit a null, before + * concluding that the element is not present. + * + * @param example The element to match. + * @return The match index, or INVALID_INDEX if no match was found. + */ + private int findIndex(E example) { + int slot = slot(elements, example); + for (int seen = 0; seen < elements.length; seen++) { + Element element = elements[slot]; + if (element == null) { + return INVALID_INDEX; + } + if (element.equals(example)) { + return slot; + } + slot = (slot + 1) % elements.length; + } + return INVALID_INDEX; + } + + /** + * Find the element which equals() the given example element. + * + * @param example The example element. 
+ * @return Null if no element was found; the element, otherwise. + */ + public E find(E example) { + int index = findIndex(example); + if (index == INVALID_INDEX) { + return null; + } + return (E) elements[index]; + } + + /** + * Returns the number of elements in the set. + */ + @Override + public int size() { + return size; + } + + @Override + public boolean contains(Object o) { + E example = null; + try { + example = (E) o; + } catch (ClassCastException e) { + return false; + } + return find(example) != null; + } + + @Override + public boolean add(E newElement) { + if ((size + 1) >= elements.length / 2) { + // Avoid using even-sized capacities, to get better key distribution. + changeCapacity((2 * elements.length) + 1); + } + int slot = addInternal(newElement, elements); + if (slot >= 0) { + addToListTail(head, elements, slot); + size++; + return true; + } + return false; + } + + public void mustAdd(E newElement) { + if (!add(newElement)) { + throw new RuntimeException("Unable to add " + newElement); + } + } + + /** + * Adds a new element to the appropriate place in the elements array. + * + * @param newElement The new element to add. + * @param addElements The elements array. + * @return The index at which the element was inserted, or INVALID_INDEX + * if the element could not be inserted because there was already + * an equivalent element. + */ + private static int addInternal(Element newElement, Element[] addElements) { + int slot = slot(addElements, newElement); + for (int seen = 0; seen < addElements.length; seen++) { + Element element = addElements[slot]; + if (element == null) { + addElements[slot] = newElement; + return slot; + } + if (element.equals(newElement)) { + return INVALID_INDEX; + } + slot = (slot + 1) % addElements.length; + } + throw new RuntimeException("Not enough hash table slots to add a new element."); + } + + private void changeCapacity(int newCapacity) { + Element[] newElements = new Element[newCapacity]; + HeadElement newHead = new HeadElement(); + int oldSize = size; + for (Iterator iter = iterator(); iter.hasNext(); ) { + Element element = iter.next(); + iter.remove(); + int newSlot = addInternal(element, newElements); + addToListTail(newHead, newElements, newSlot); + } + this.elements = newElements; + this.head = newHead; + this.size = oldSize; + } + + @Override + public boolean remove(Object o) { + E example = null; + try { + example = (E) o; + } catch (ClassCastException e) { + return false; + } + int slot = findIndex(example); + if (slot == INVALID_INDEX) { + return false; + } + size--; + removeFromList(head, elements, slot); + slot = (slot + 1) % elements.length; + + // Find the next empty slot + int endSlot = slot; + for (int seen = 0; seen < elements.length; seen++) { + Element element = elements[endSlot]; + if (element == null) { + break; + } + endSlot = (endSlot + 1) % elements.length; + } + + // We must preserve the denseness invariant. The denseness invariant says that + // any element is either in the slot indicated by its hash code, or a slot which + // is not separated from that slot by any nulls. + // Reseat all elements in between the deleted element and the next empty slot. 
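+ // For example, suppose A, B and C all hash to slot s and currently sit in s, s+1 and s+2.
+ // Removing A leaves a null at s; if B and C stayed put, a later findIndex(B) would start
+ // probing at s, hit the null immediately, and wrongly report that B is absent. Reseating
+ // shifts such elements back toward their natural slots, so that no null ever separates an
+ // element from the slot its hash code points at.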
+ while (slot != endSlot) { + reseat(slot); + slot = (slot + 1) % elements.length; + } + return true; + } + + private void reseat(int prevSlot) { + Element element = elements[prevSlot]; + int newSlot = slot(elements, element); + for (int seen = 0; seen < elements.length; seen++) { + Element e = elements[newSlot]; + if ((e == null) || (e == element)) { + break; + } + newSlot = (newSlot + 1) % elements.length; + } + if (newSlot == prevSlot) { + return; + } + Element prev = indexToElement(head, elements, element.prev()); + prev.setNext(newSlot); + Element next = indexToElement(head, elements, element.next()); + next.setPrev(newSlot); + elements[prevSlot] = null; + elements[newSlot] = element; + } + + @Override + public void clear() { + reset(elements.length); + } + + public ImplicitLinkedHashSet() { + this(5); + } + + public ImplicitLinkedHashSet(int initialCapacity) { + reset(initialCapacity); + } + + private void reset(int capacity) { + this.head = new HeadElement(); + // Avoid using even-sized capacities, to get better key distribution. + this.elements = new Element[(2 * capacity) + 1]; + this.size = 0; + } + + int numSlots() { + return elements.length; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java new file mode 100644 index 0000000000000..3095717e8dd72 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; + +import static org.apache.kafka.common.requests.FetchMetadata.INITIAL_EPOCH; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * A unit test for FetchSessionHandler. 
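As background for the tests that follow, the client-side call sequence they exercise is roughly the sketch below. It is illustrative only: the FetchFlowSketch class, its method names and the literal topic, offset and byte values are invented for the example, and only API visible in this patch (newBuilder/add/build, the FetchRequestData accessors, handleResponse) is used; the real Fetcher wires these calls into its networking code.

    import org.apache.kafka.clients.FetchSessionHandler;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.requests.FetchRequest;
    import org.apache.kafka.common.requests.FetchResponse;

    class FetchFlowSketch {
        // Build the data for the next fetch request to one broker.
        static FetchSessionHandler.FetchRequestData nextRequest(FetchSessionHandler handler) {
            FetchSessionHandler.Builder builder = handler.newBuilder();
            // Register every partition we currently want to fetch from this node.
            builder.add(new TopicPartition("foo", 0),
                new FetchRequest.PartitionData(0L, 0L, 1048576));
            // build() decides what actually goes on the wire: toSend() holds the partitions to
            // include, toForget() the ones to drop from the session, and metadata() carries the
            // session id and epoch to attach to the request.
            return builder.build();
        }

        // After the broker replies, let the handler update (or abandon) the session state.
        static void onResponse(FetchSessionHandler handler, FetchResponse response) {
            handler.handleResponse(response);
        }
    }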
 */ +public class FetchSessionHandlerTest { + @Rule + final public Timeout globalTimeout = Timeout.millis(120000); + + private static final LogContext LOG_CONTEXT = new LogContext("[FetchSessionHandler]="); + + private static final Logger log = LOG_CONTEXT.logger(FetchSessionHandler.class); + + /** + * Create a set of TopicPartitions. We use a TreeSet, in order to get a deterministic + * ordering for test purposes. + */ + private final static Set toSet(TopicPartition... arr) { + TreeSet set = new TreeSet<>(new Comparator() { + @Override + public int compare(TopicPartition o1, TopicPartition o2) { + return o1.toString().compareTo(o2.toString()); + } + }); + set.addAll(Arrays.asList(arr)); + return set; + } + + @Test + public void testFindMissing() throws Exception { + TopicPartition foo0 = new TopicPartition("foo", 0); + TopicPartition foo1 = new TopicPartition("foo", 1); + TopicPartition bar0 = new TopicPartition("bar", 0); + TopicPartition bar1 = new TopicPartition("bar", 1); + TopicPartition baz0 = new TopicPartition("baz", 0); + TopicPartition baz1 = new TopicPartition("baz", 1); + assertEquals(toSet(), FetchSessionHandler.findMissing(toSet(foo0), toSet(foo0))); + assertEquals(toSet(foo0), FetchSessionHandler.findMissing(toSet(foo0), toSet(foo1))); + assertEquals(toSet(foo0, foo1), + FetchSessionHandler.findMissing(toSet(foo0, foo1), toSet(baz0))); + assertEquals(toSet(bar1, foo0, foo1), + FetchSessionHandler.findMissing(toSet(foo0, foo1, bar0, bar1), + toSet(bar0, baz0, baz1))); + assertEquals(toSet(), + FetchSessionHandler.findMissing(toSet(foo0, foo1, bar0, bar1, baz1), + toSet(foo0, foo1, bar0, bar1, baz0, baz1))); + } + + private static final class ReqEntry { + final TopicPartition part; + final FetchRequest.PartitionData data; + + ReqEntry(String topic, int partition, long fetchOffset, long logStartOffset, int maxBytes) { + this.part = new TopicPartition(topic, partition); + this.data = new FetchRequest.PartitionData(fetchOffset, logStartOffset, maxBytes); + } + } + + private static LinkedHashMap reqMap(ReqEntry... entries) { + LinkedHashMap map = new LinkedHashMap<>(); + for (ReqEntry entry : entries) { + map.put(entry.part, entry.data); + } + return map; + } + + private static void assertMapEquals(Map expected, + Map actual) { + Iterator> expectedIter = + expected.entrySet().iterator(); + Iterator> actualIter = + actual.entrySet().iterator(); + int i = 1; + while (expectedIter.hasNext()) { + Map.Entry expectedEntry = expectedIter.next(); + if (!actualIter.hasNext()) { + fail("Element " + i + " not found."); + } + Map.Entry actualEntry = actualIter.next(); + assertEquals("Element " + i + " had a different TopicPartition than expected.", + expectedEntry.getKey(), actualEntry.getKey()); + assertEquals("Element " + i + " had different PartitionData than expected.", + expectedEntry.getValue(), actualEntry.getValue()); + i++; + } + if (actualIter.hasNext()) { + fail("Unexpected element " + i + " found."); + } + } + + private static void assertMapsEqual(Map expected, + Map...
actuals) { + for (Map actual : actuals) { + assertMapEquals(expected, actual); + } + } + + private static void assertListEquals(List expected, List actual) { + for (TopicPartition expectedPart : expected) { + if (!actual.contains(expectedPart)) { + fail("Failed to find expected partition " + expectedPart); + } + } + for (TopicPartition actualPart : actual) { + if (!expected.contains(actualPart)) { + fail("Found unexpected partition " + actualPart); + } + } + } + + private static final class RespEntry { + final TopicPartition part; + final FetchResponse.PartitionData data; + + RespEntry(String topic, int partition, long highWatermark, long lastStableOffset) { + this.part = new TopicPartition(topic, partition); + this.data = new FetchResponse.PartitionData( + Errors.NONE, + highWatermark, + lastStableOffset, + 0, + null, + null); + } + } + + private static LinkedHashMap respMap(RespEntry... entries) { + LinkedHashMap map = new LinkedHashMap<>(); + for (RespEntry entry : entries) { + map.put(entry.part, entry.data); + } + return map; + } + + /** + * Test the handling of SESSIONLESS responses. + * Pre-KIP-227 brokers always supply this kind of response. + */ + @Test + public void testSessionless() throws Exception { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + builder.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(10, 110, 210)); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200), + new ReqEntry("foo", 1, 10, 110, 210)), + data.toSend(), data.sessionPartitions()); + assertEquals(INVALID_SESSION_ID, data.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data.metadata().epoch()); + + FetchResponse resp = new FetchResponse(Errors.NONE, + respMap(new RespEntry("foo", 0, 0, 0), + new RespEntry("foo", 1, 0, 0)), + 0, INVALID_SESSION_ID); + handler.handleResponse(resp); + + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertEquals(INVALID_SESSION_ID, data2.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data2.metadata().epoch()); + assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)), + data2.toSend(), data2.sessionPartitions()); + } + + /** + * Test handling an incremental fetch session.
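+ * In outline: the first build() yields a full request; once the broker's response carries a
+ * real session id, subsequent build() calls yield incremental requests whose toSend() map
+ * contains only the added or changed partitions, while sessionPartitions() still reflects the
+ * whole session. An INVALID_FETCH_SESSION_EPOCH error then forces the next request back to a
+ * full one that keeps the same session id but resets the epoch to INITIAL_EPOCH.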
+ */ + @Test + public void testIncrementals() throws Exception { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + builder.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(10, 110, 210)); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200), + new ReqEntry("foo", 1, 10, 110, 210)), + data.toSend(), data.sessionPartitions()); + assertEquals(INVALID_SESSION_ID, data.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data.metadata().epoch()); + + FetchResponse resp = new FetchResponse(Errors.NONE, + respMap(new RespEntry("foo", 0, 10, 20), + new RespEntry("foo", 1, 10, 20)), + 0, 123); + handler.handleResponse(resp); + + // Test an incremental fetch request which adds one partition and modifies another. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + builder2.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(10, 120, 210)); + builder2.add(new TopicPartition("bar", 0), + new FetchRequest.PartitionData(20, 200, 200)); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertFalse(data2.metadata().isFull()); + assertMapEquals(reqMap(new ReqEntry("foo", 0, 0, 100, 200), + new ReqEntry("foo", 1, 10, 120, 210), + new ReqEntry("bar", 0, 20, 200, 200)), + data2.sessionPartitions()); + assertMapEquals(reqMap(new ReqEntry("bar", 0, 20, 200, 200), + new ReqEntry("foo", 1, 10, 120, 210)), + data2.toSend()); + + FetchResponse resp2 = new FetchResponse(Errors.NONE, + respMap(new RespEntry("foo", 1, 20, 20)), + 0, 123); + handler.handleResponse(resp2); + + // Skip building a new request. Test that handling an invalid fetch session epoch response results + // in a request which closes the session. + FetchResponse resp3 = new FetchResponse(Errors.INVALID_FETCH_SESSION_EPOCH, respMap(), + 0, INVALID_SESSION_ID); + handler.handleResponse(resp3); + + FetchSessionHandler.Builder builder4 = handler.newBuilder(); + builder4.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + builder4.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(10, 120, 210)); + builder4.add(new TopicPartition("bar", 0), + new FetchRequest.PartitionData(20, 200, 200)); + FetchSessionHandler.FetchRequestData data4 = builder4.build(); + assertTrue(data4.metadata().isFull()); + assertEquals(data2.metadata().sessionId(), data4.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data4.metadata().epoch()); + assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200), + new ReqEntry("foo", 1, 10, 120, 210), + new ReqEntry("bar", 0, 20, 200, 200)), + data4.sessionPartitions(), data4.toSend()); + } + + /** + * Test that calling FetchSessionHandler#Builder#build twice fails. 
+ */ + @Test + public void testDoubleBuild() throws Exception { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + builder.build(); + try { + builder.build(); + fail("Expected calling build twice to fail."); + } catch (Throwable t) { + // expected + } + } + + @Test + public void testIncrementalPartitionRemoval() throws Exception { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + builder.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(10, 110, 210)); + builder.add(new TopicPartition("bar", 0), + new FetchRequest.PartitionData(20, 120, 220)); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200), + new ReqEntry("foo", 1, 10, 110, 210), + new ReqEntry("bar", 0, 20, 120, 220)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + + FetchResponse resp = new FetchResponse(Errors.NONE, + respMap(new RespEntry("foo", 0, 10, 20), + new RespEntry("foo", 1, 10, 20), + new RespEntry("bar", 0, 10, 20)), + 0, 123); + handler.handleResponse(resp); + + // Test an incremental fetch request which removes two partitions. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(10, 110, 210)); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertFalse(data2.metadata().isFull()); + assertEquals(123, data2.metadata().sessionId()); + assertEquals(1, data2.metadata().epoch()); + assertMapEquals(reqMap(new ReqEntry("foo", 1, 10, 110, 210)), + data2.sessionPartitions()); + assertMapEquals(reqMap(), data2.toSend()); + ArrayList expectedToForget2 = new ArrayList<>(); + expectedToForget2.add(new TopicPartition("foo", 0)); + expectedToForget2.add(new TopicPartition("bar", 0)); + assertListEquals(expectedToForget2, data2.toForget()); + + // A FETCH_SESSION_ID_NOT_FOUND response triggers us to close the session. + // The next request is a session establishing FULL request. 
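+ // (Contrast with the INVALID_FETCH_SESSION_EPOCH case in testIncrementals: there the next
+ // request is also a FULL one with the epoch reset to INITIAL_EPOCH, but the existing
+ // session id is kept; here the session id reverts to INVALID_SESSION_ID as well.)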
+ FetchResponse resp2 = new FetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, + respMap(), 0, INVALID_SESSION_ID); + handler.handleResponse(resp2); + FetchSessionHandler.Builder builder3 = handler.newBuilder(); + builder3.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(0, 100, 200)); + FetchSessionHandler.FetchRequestData data3 = builder3.build(); + assertTrue(data3.metadata().isFull()); + assertEquals(INVALID_SESSION_ID, data3.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data3.metadata().epoch()); + assertMapsEqual(reqMap(new ReqEntry("foo", 0, 0, 100, 200)), + data3.sessionPartitions(), data3.toSend()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java index a827168039ecb..d47124f8f3e56 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java @@ -99,6 +99,7 @@ import static java.util.Collections.singleton; import static java.util.Collections.singletonList; import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -1578,7 +1579,7 @@ private FetchResponse fetchResponse(Map fetches) { tpResponses.put(partition, new FetchResponse.PartitionData(Errors.NONE, 0, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records)); } - return new FetchResponse(tpResponses, 0); + return new FetchResponse(Errors.NONE, tpResponses, 0, INVALID_SESSION_ID); } private FetchResponse fetchResponse(TopicPartition partition, long fetchOffset, int count) { diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java index a3ea79356f94e..a0205e7f19ec2 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java @@ -103,6 +103,7 @@ import java.util.Set; import static java.util.Collections.singleton; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -139,6 +140,7 @@ public class FetcherTest { private MemoryRecords records; private MemoryRecords nextRecords; + private MemoryRecords emptyRecords; private Fetcher fetcher = createFetcher(subscriptions, metrics); private Metrics fetcherMetrics = new Metrics(time); private Fetcher fetcherNoAutoReset = createFetcher(subscriptionsNoAutoReset, fetcherMetrics); @@ -158,6 +160,9 @@ public void setup() throws Exception { builder.append(0L, "key".getBytes(), "value-4".getBytes()); builder.append(0L, "key".getBytes(), "value-5".getBytes()); nextRecords = builder.build(); + + builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 0L); + emptyRecords = builder.build(); } @After @@ -177,7 +182,7 @@ public void testFetchNormal() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, 
this.records, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -219,7 +224,7 @@ public void testFetcherIgnoresControlRecords() { buffer.flip(); - client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -242,7 +247,7 @@ public void testFetchError() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -283,7 +288,7 @@ public byte[] deserialize(String topic, byte[] data) { subscriptions.assignFromUser(singleton(tp0)); subscriptions.seek(tp0, 1); - client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); assertEquals(1, fetcher.sendFetches()); consumerClient.poll(0); @@ -345,7 +350,7 @@ public void testParseCorruptedRecord() throws Exception { // normal fetch assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); consumerClient.poll(0); // the first fetchedRecords() should return the first valid message @@ -383,7 +388,7 @@ private void seekAndConsumeRecord(ByteBuffer responseBuffer, long toOffset) { // Should not throw exception after the seek. fetcher.fetchedRecords(); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0)); consumerClient.poll(0); List> records = fetcher.fetchedRecords().get(tp0); @@ -416,7 +421,7 @@ public void testInvalidDefaultRecordBatch() { // normal fetch assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); consumerClient.poll(0); // the fetchedRecords() should always throw exception due to the bad batch. 
@@ -447,7 +452,7 @@ public void testParseInvalidRecordBatch() throws Exception { // normal fetch assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); consumerClient.poll(0); try { fetcher.fetchedRecords(); @@ -480,7 +485,7 @@ public void testHeaders() { subscriptions.assignFromUser(singleton(tp0)); subscriptions.seek(tp0, 1); - client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, memoryRecords, Errors.NONE, 100L, 0)); assertEquals(1, fetcher.sendFetches()); consumerClient.poll(0); @@ -510,8 +515,8 @@ public void testFetchMaxPollRecords() { subscriptions.assignFromUser(singleton(tp0)); subscriptions.seek(tp0, 1); - client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); - client.prepareResponse(matchesOffset(tp0, 4), fetchResponse(tp0, this.nextRecords, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tp0, 4), fullFetchResponse(tp0, this.nextRecords, Errors.NONE, 100L, 0)); assertEquals(1, fetcher.sendFetches()); consumerClient.poll(0); @@ -551,7 +556,7 @@ public void testFetchAfterPartitionWithFetchedRecordsIsUnassigned() { subscriptions.seek(tp0, 1); // Returns 3 records while `max.poll.records` is configured to 2 - client.prepareResponse(matchesOffset(tp0, 1), fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tp0, 1), fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); assertEquals(1, fetcher.sendFetches()); consumerClient.poll(0); @@ -562,7 +567,7 @@ public void testFetchAfterPartitionWithFetchedRecordsIsUnassigned() { assertEquals(2, records.get(1).offset()); subscriptions.assignFromUser(singleton(tp1)); - client.prepareResponse(matchesOffset(tp1, 4), fetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tp1, 4), fullFetchResponse(tp1, this.nextRecords, Errors.NONE, 100L, 0)); subscriptions.seek(tp1, 4); assertEquals(1, fetcher.sendFetches()); @@ -594,7 +599,7 @@ public void testFetchNonContinuousRecords() { // normal fetch assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, records, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, records, Errors.NONE, 100L, 0)); consumerClient.poll(0); consumerRecords = fetcher.fetchedRecords().get(tp0); assertEquals(3, consumerRecords.size()); @@ -654,7 +659,7 @@ private void makeFetchRequestWithIncompleteRecord() { assertFalse(fetcher.hasCompletedFetches()); MemoryRecords partialRecord = MemoryRecords.readableRecords( ByteBuffer.wrap(new byte[]{0, 0, 0, 0, 0, 0, 0, 0})); - client.prepareResponse(fetchResponse(tp0, partialRecord, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, partialRecord, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); } @@ -666,7 +671,7 @@ public void testUnauthorizedTopic() { // resize the limit of the buffer to pretend it is only fetch-size large assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0)); + 
client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0)); consumerClient.poll(0); try { fetcher.fetchedRecords(); @@ -686,7 +691,7 @@ public void testFetchDuringRebalance() { // Now the rebalance happens and fetch positions are cleared subscriptions.assignFromSubscribed(singleton(tp0)); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); consumerClient.poll(0); // The active fetch should be ignored since its position is no longer valid @@ -701,7 +706,7 @@ public void testInFlightFetchOnPausedPartition() { assertEquals(1, fetcher.sendFetches()); subscriptions.pause(tp0); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertNull(fetcher.fetchedRecords().get(tp0)); } @@ -722,7 +727,7 @@ public void testFetchNotLeaderForPartition() { subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NOT_LEADER_FOR_PARTITION, 100L, 0)); consumerClient.poll(0); assertEquals(0, fetcher.fetchedRecords().size()); assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); @@ -734,7 +739,7 @@ public void testFetchUnknownTopicOrPartition() { subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0)); consumerClient.poll(0); assertEquals(0, fetcher.fetchedRecords().size()); assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); @@ -746,7 +751,7 @@ public void testFetchOffsetOutOfRange() { subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); consumerClient.poll(0); assertEquals(0, fetcher.fetchedRecords().size()); assertTrue(subscriptions.isOffsetResetNeeded(tp0)); @@ -761,7 +766,7 @@ public void testStaleOutOfRangeError() { subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); subscriptions.seek(tp0, 1); consumerClient.poll(0); assertEquals(0, fetcher.fetchedRecords().size()); @@ -775,7 +780,7 @@ public void testFetchedRecordsAfterSeek() { subscriptionsNoAutoReset.seek(tp0, 0); assertTrue(fetcherNoAutoReset.sendFetches() > 0); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); consumerClient.poll(0); assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0)); subscriptionsNoAutoReset.seek(tp0, 2); @@ -788,7 +793,7 @@ public void testFetchOffsetOutOfRangeException() { subscriptionsNoAutoReset.seek(tp0, 0); fetcherNoAutoReset.sendFetches(); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 
0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); consumerClient.poll(0); assertFalse(subscriptionsNoAutoReset.isOffsetResetNeeded(tp0)); @@ -818,7 +823,8 @@ public void testFetchPositionAfterException() { FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, records)); partitions.put(tp0, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)); - client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0)); + client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), + 0, INVALID_SESSION_ID)); consumerClient.poll(0); List> fetchedRecords = new ArrayList<>(); @@ -856,7 +862,7 @@ public void testSeekBeforeException() { Map partitions = new HashMap<>(); partitions.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, records)); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertEquals(2, fetcher.fetchedRecords().get(tp0).size()); @@ -867,7 +873,7 @@ public void testSeekBeforeException() { partitions = new HashMap<>(); partitions.put(tp1, new FetchResponse.PartitionData(Errors.OFFSET_OUT_OF_RANGE, 100, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY)); - client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0)); + client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), 0, INVALID_SESSION_ID)); consumerClient.poll(0); assertEquals(1, fetcher.fetchedRecords().get(tp0).size()); @@ -882,7 +888,7 @@ public void testFetchDisconnected() { subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0), true); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0), true); consumerClient.poll(0); assertEquals(0, fetcher.fetchedRecords().size()); @@ -1148,7 +1154,7 @@ public void testQuotaMetrics() throws Exception { ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true, null); client.send(request, time.milliseconds()); client.poll(1, time.milliseconds()); - FetchResponse response = fetchResponse(tp0, nextRecords, Errors.NONE, i, throttleTimeMs); + FetchResponse response = fullFetchResponse(tp0, nextRecords, Errors.NONE, i, throttleTimeMs); buffer = response.serialize(ApiKeys.FETCH.latestVersion(), new ResponseHeader(request.correlationId())); selector.completeReceive(new NetworkReceive(node.idString(), buffer)); client.poll(1, time.milliseconds()); @@ -1325,7 +1331,8 @@ public void testFetchResponseMetricsWithOnePartitionError() { FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, MemoryRecords.EMPTY)); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0)); + client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), + 0, INVALID_SESSION_ID)); consumerClient.poll(0); fetcher.fetchedRecords(); @@ -1364,7 +1371,8 @@ public void testFetchResponseMetricsWithOnePartitionAtTheWrongOffset() { FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, 
MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("val".getBytes())))); - client.prepareResponse(new FetchResponse(new LinkedHashMap<>(partitions), 0)); + client.prepareResponse(new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), + 0, INVALID_SESSION_ID)); consumerClient.poll(0); fetcher.fetchedRecords(); @@ -1390,7 +1398,7 @@ public void testFetcherMetricsTemplates() throws Exception { subscriptions.assignFromUser(singleton(tp0)); subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, this.records, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); Map>> partitionRecords = fetcher.fetchedRecords(); @@ -1417,7 +1425,7 @@ private Map>> fetchRecords( private Map>> fetchRecords( TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) { assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp, records, error, hw, lastStableOffset, throttleTime)); + client.prepareResponse(fullFetchResponse(tp, records, error, hw, lastStableOffset, throttleTime)); consumerClient.poll(0); return fetcher.fetchedRecords(); } @@ -1495,7 +1503,7 @@ public void testSkippingAbortedTransactions() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1533,7 +1541,7 @@ public boolean matches(AbstractRequest body) { assertEquals(IsolationLevel.READ_COMMITTED, request.isolationLevel()); return true; } - }, fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + }, fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1604,7 +1612,7 @@ public void testReadCommittedWithCommittedAndAbortedTransactions() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1651,7 +1659,7 @@ public void testMultipleAbortMarkers() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1695,7 +1703,7 @@ public void testReadCommittedAbortMarkerWithNoData() { List abortedTransactions = new ArrayList<>(); abortedTransactions.add(new FetchResponse.AbortedTransaction(producerId, 0L)); - client.prepareResponse(fetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer), + 
client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer), abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1733,7 +1741,7 @@ protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { subscriptions.assignFromUser(singleton(tp0)); subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, compactedRecords, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1768,7 +1776,7 @@ public void testUpdatePositionOnEmptyBatch() { subscriptions.assignFromUser(singleton(tp0)); subscriptions.seek(tp0, 0); assertEquals(1, fetcher.sendFetches()); - client.prepareResponse(fetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0)); + client.prepareResponse(fullFetchResponse(tp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1829,7 +1837,7 @@ public void testReadCommittedWithCompactedTopic() { abortedTransactions.add(new FetchResponse.AbortedTransaction(pid2, 6L)); abortedTransactions.add(new FetchResponse.AbortedTransaction(pid1, 0L)); - client.prepareResponse(fetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer), + client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer), abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1867,7 +1875,7 @@ public void testReturnAbortedTransactionsinUncommittedMode() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1900,7 +1908,7 @@ public void testConsumerPositionUpdatedWhenSkippingAbortedTransactions() { assertEquals(1, fetcher.sendFetches()); assertFalse(fetcher.hasCompletedFetches()); - client.prepareResponse(fetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); consumerClient.poll(0); assertTrue(fetcher.hasCompletedFetches()); @@ -1911,6 +1919,75 @@ public void testConsumerPositionUpdatedWhenSkippingAbortedTransactions() { assertEquals(currentOffset, (long) subscriptions.position(tp0)); } + @Test + public void testConsumingViaIncrementalFetchRequests() { + Fetcher fetcher = createFetcher(subscriptions, new Metrics(time), 2); + + List> records; + subscriptions.assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1))); + subscriptions.seek(tp0, 0); + subscriptions.seek(tp1, 1); + + // Fetch some records and establish an incremental fetch session. 
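+ // The non-INVALID session id (123) in this first response is what establishes the session
+ // on the client side; the later responses in this test reuse that id and simply omit
+ // partitions that have nothing new to return.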
+ LinkedHashMap partitions1 = new LinkedHashMap<>(); + partitions1.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 2L, + 2, 0L, null, this.records)); + partitions1.put(tp1, new FetchResponse.PartitionData(Errors.NONE, 100L, + FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, emptyRecords)); + FetchResponse resp1 = new FetchResponse(Errors.NONE, partitions1, 0, 123); + client.prepareResponse(resp1); + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + consumerClient.poll(0); + assertTrue(fetcher.hasCompletedFetches()); + Map>> fetchedRecords = fetcher.fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp1)); + records = fetchedRecords.get(tp0); + assertEquals(2, records.size()); + assertEquals(3L, subscriptions.position(tp0).longValue()); + assertEquals(1L, subscriptions.position(tp1).longValue()); + assertEquals(1, records.get(0).offset()); + assertEquals(2, records.get(1).offset()); + + // There is still a buffered record. + assertEquals(0, fetcher.sendFetches()); + fetchedRecords = fetcher.fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp1)); + records = fetchedRecords.get(tp0); + assertEquals(1, records.size()); + assertEquals(3, records.get(0).offset()); + assertEquals(4L, subscriptions.position(tp0).longValue()); + + // The second response contains no new records. + LinkedHashMap partitions2 = new LinkedHashMap<>(); + FetchResponse resp2 = new FetchResponse(Errors.NONE, partitions2, 0, 123); + client.prepareResponse(resp2); + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(0); + fetchedRecords = fetcher.fetchedRecords(); + assertTrue(fetchedRecords.isEmpty()); + assertEquals(4L, subscriptions.position(tp0).longValue()); + assertEquals(1L, subscriptions.position(tp1).longValue()); + + // The third response contains some new records for tp0. + LinkedHashMap partitions3 = new LinkedHashMap<>(); + partitions3.put(tp0, new FetchResponse.PartitionData(Errors.NONE, 100L, + 4, 0L, null, this.nextRecords)); + FetchResponse resp3 = new FetchResponse(Errors.NONE, partitions3, 0, 123); + client.prepareResponse(resp3); + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(0); + fetchedRecords = fetcher.fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp1)); + records = fetchedRecords.get(tp0); + assertEquals(2, records.size()); + assertEquals(6L, subscriptions.position(tp0).longValue()); + assertEquals(1L, subscriptions.position(tp1).longValue()); + assertEquals(4, records.get(0).offset()); + assertEquals(5, records.get(1).offset()); + } + + private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, int baseSequence, SimpleRecord...
records) { MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, TimestampType.CREATE_TIME, baseOffset, time.milliseconds(), pid, (short) 0, baseSequence, true, @@ -2033,7 +2110,7 @@ private ListOffsetResponse listOffsetResponse(TopicPartition tp, Errors error, l return new ListOffsetResponse(allPartitionData); } - private FetchResponse fetchResponseWithAbortedTransactions(MemoryRecords records, + private FetchResponse fullFetchResponseWithAbortedTransactions(MemoryRecords records, List abortedTransactions, Errors error, long lastStableOffset, @@ -2041,18 +2118,18 @@ private FetchResponse fetchResponseWithAbortedTransactions(MemoryRecords records int throttleTime) { Map partitions = Collections.singletonMap(tp0, new FetchResponse.PartitionData(error, hw, lastStableOffset, 0L, abortedTransactions, records)); - return new FetchResponse(new LinkedHashMap<>(partitions), throttleTime); + return new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), throttleTime, INVALID_SESSION_ID); } - private FetchResponse fetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) { - return fetchResponse(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime); + private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) { + return fullFetchResponse(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime); } - private FetchResponse fetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, + private FetchResponse fullFetchResponse(TopicPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) { Map partitions = Collections.singletonMap(tp, new FetchResponse.PartitionData(error, hw, lastStableOffset, 0L, null, records)); - return new FetchResponse(new LinkedHashMap<>(partitions), throttleTime); + return new FetchResponse(Errors.NONE, new LinkedHashMap<>(partitions), throttleTime, INVALID_SESSION_ID); } private MetadataResponse newMetadataResponse(String topic, Errors error) { diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java index b5420b5195de5..69b37e2e9a446 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java @@ -73,6 +73,7 @@ import static java.util.Arrays.asList; import static java.util.Collections.singletonList; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; import static org.apache.kafka.test.TestUtils.toBuffer; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -95,6 +96,13 @@ public void testSerialization() throws Exception { checkErrorResponse(createControlledShutdownRequest(0), new UnknownServerException()); checkRequest(createFetchRequest(4)); checkResponse(createFetchResponse(), 4); + List toForgetTopics = new ArrayList<>(); + toForgetTopics.add(new TopicPartition("foo", 0)); + toForgetTopics.add(new TopicPartition("foo", 2)); + toForgetTopics.add(new TopicPartition("bar", 0)); + checkRequest(createFetchRequest(7, new FetchMetadata(123, 456), toForgetTopics)); + checkResponse(createFetchResponse(123), 7); + checkResponse(createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123), 7); 
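+ // The three checks above cover the new v7 fetch wire format: a request carrying explicit
+ // session metadata plus a "forget" list, a response with a live session id, and a response
+ // whose top-level error (FETCH_SESSION_ID_NOT_FOUND) has no per-partition data.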
checkErrorResponse(createFetchRequest(4), new UnknownServerException()); checkRequest(createHeartBeatRequest()); checkErrorResponse(createHeartBeatRequest(), new UnknownServerException()); @@ -426,8 +434,8 @@ public void fetchResponseVersionTest() { responseData.put(new TopicPartition("test", 0), new FetchResponse.PartitionData(Errors.NONE, 1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records)); - FetchResponse v0Response = new FetchResponse(responseData, 0); - FetchResponse v1Response = new FetchResponse(responseData, 10); + FetchResponse v0Response = new FetchResponse(Errors.NONE, responseData, 0, INVALID_SESSION_ID); + FetchResponse v1Response = new FetchResponse(Errors.NONE, responseData, 10, INVALID_SESSION_ID); assertEquals("Throttle time must be zero", 0, v0Response.throttleTimeMs()); assertEquals("Throttle time must be 10", 10, v1Response.throttleTimeMs()); assertEquals("Should use schema version 0", ApiKeys.FETCH.responseSchema((short) 0), @@ -455,15 +463,22 @@ public void testFetchResponseV4() { responseData.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData(Errors.NONE, 70000, 6, FetchResponse.INVALID_LOG_START_OFFSET, Collections.emptyList(), records)); - FetchResponse response = new FetchResponse(responseData, 10); + FetchResponse response = new FetchResponse(Errors.NONE, responseData, 10, INVALID_SESSION_ID); FetchResponse deserialized = FetchResponse.parse(toBuffer(response.toStruct((short) 4)), (short) 4); assertEquals(responseData, deserialized.responseData()); } @Test - public void verifyFetchResponseFullWrite() throws Exception { - FetchResponse fetchResponse = createFetchResponse(); - short apiVersion = ApiKeys.FETCH.latestVersion(); + public void verifyFetchResponseFullWrites() throws Exception { + verifyFetchResponseFullWrite(ApiKeys.FETCH.latestVersion(), createFetchResponse(123)); + verifyFetchResponseFullWrite(ApiKeys.FETCH.latestVersion(), + createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123)); + for (short version = 0; version <= ApiKeys.FETCH.latestVersion(); version++) { + verifyFetchResponseFullWrite(version, createFetchResponse()); + } + } + + private void verifyFetchResponseFullWrite(short apiVersion, FetchResponse fetchResponse) throws Exception { int correlationId = 15; Send send = fetchResponse.toSend("1", new ResponseHeader(correlationId), apiVersion); @@ -525,6 +540,19 @@ public void testFetchRequestIsolationLevel() throws Exception { assertEquals(request.isolationLevel(), deserialized.isolationLevel()); } + @Test + public void testFetchRequestWithMetadata() throws Exception { + FetchRequest request = createFetchRequest(4, IsolationLevel.READ_COMMITTED); + Struct struct = request.toStruct(); + FetchRequest deserialized = (FetchRequest) deserialize(request, struct, request.version()); + assertEquals(request.isolationLevel(), deserialized.isolationLevel()); + + request = createFetchRequest(4, IsolationLevel.READ_UNCOMMITTED); + struct = request.toStruct(); + deserialized = (FetchRequest) deserialize(request, struct, request.version()); + assertEquals(request.isolationLevel(), deserialized.isolationLevel()); + } + @Test public void testJoinGroupRequestVersion0RebalanceTimeout() throws Exception { final short version = 0; @@ -556,11 +584,20 @@ private FindCoordinatorResponse createFindCoordinatorResponse() { return new FindCoordinatorResponse(Errors.NONE, new Node(10, "host1", 2014)); } + private FetchRequest createFetchRequest(int version, FetchMetadata metadata, List toForget) { + LinkedHashMap fetchData = new 
LinkedHashMap<>(); + fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, 0L, 1000000)); + fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, 0L, 1000000)); + return FetchRequest.Builder.forConsumer(100, 100000, fetchData). + metadata(metadata).setMaxBytes(1000).toForget(toForget).build((short) version); + } + private FetchRequest createFetchRequest(int version, IsolationLevel isolationLevel) { LinkedHashMap fetchData = new LinkedHashMap<>(); fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, 0L, 1000000)); fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, 0L, 1000000)); - return FetchRequest.Builder.forConsumer(100, 100000, fetchData, isolationLevel).setMaxBytes(1000).build((short) version); + return FetchRequest.Builder.forConsumer(100, 100000, fetchData). + isolationLevel(isolationLevel).setMaxBytes(1000).build((short) version); } private FetchRequest createFetchRequest(int version) { @@ -570,6 +607,23 @@ private FetchRequest createFetchRequest(int version) { return FetchRequest.Builder.forConsumer(100, 100000, fetchData).setMaxBytes(1000).build((short) version); } + private FetchResponse createFetchResponse(Errors error, int sessionId) { + return new FetchResponse(error, new LinkedHashMap(), + 25, sessionId); + } + + private FetchResponse createFetchResponse(int sessionId) { + LinkedHashMap responseData = new LinkedHashMap<>(); + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes())); + responseData.put(new TopicPartition("test", 0), new FetchResponse.PartitionData(Errors.NONE, + 1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, null, records)); + List abortedTransactions = Collections.singletonList( + new FetchResponse.AbortedTransaction(234L, 999L)); + responseData.put(new TopicPartition("test", 1), new FetchResponse.PartitionData(Errors.NONE, + 1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, abortedTransactions, MemoryRecords.EMPTY)); + return new FetchResponse(Errors.NONE, responseData, 25, sessionId); + } + private FetchResponse createFetchResponse() { LinkedHashMap responseData = new LinkedHashMap<>(); MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes())); @@ -581,7 +635,7 @@ private FetchResponse createFetchResponse() { responseData.put(new TopicPartition("test", 1), new FetchResponse.PartitionData(Errors.NONE, 1000000, FetchResponse.INVALID_LAST_STABLE_OFFSET, 0L, abortedTransactions, MemoryRecords.EMPTY)); - return new FetchResponse(responseData, 25); + return new FetchResponse(Errors.NONE, responseData, 25, INVALID_SESSION_ID); } private HeartbeatRequest createHeartBeatRequest() { diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java new file mode 100644 index 0000000000000..20084a266f208 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashSetTest.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Random; + +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertEquals; + +/** + * A unit test for ImplicitLinkedHashSet. + */ +public class ImplicitLinkedHashSetTest { + @Rule + final public Timeout globalTimeout = Timeout.millis(120000); + + private final static class TestElement implements ImplicitLinkedHashSet.Element { + private int prev = ImplicitLinkedHashSet.INVALID_INDEX; + private int next = ImplicitLinkedHashSet.INVALID_INDEX; + private final int val; + + TestElement(int val) { + this.val = val; + } + + @Override + public int prev() { + return prev; + } + + @Override + public void setPrev(int prev) { + this.prev = prev; + } + + @Override + public int next() { + return next; + } + + @Override + public void setNext(int next) { + this.next = next; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if ((o == null) || (o.getClass() != TestElement.class)) return false; + TestElement that = (TestElement) o; + return val == that.val; + } + + @Override + public String toString() { + return "TestElement(" + val + ")"; + } + + @Override + public int hashCode() { + return val; + } + } + + @Test + public void testInsertDelete() throws Exception { + ImplicitLinkedHashSet set = new ImplicitLinkedHashSet<>(100); + assertTrue(set.add(new TestElement(1))); + TestElement second = new TestElement(2); + assertTrue(set.add(second)); + assertTrue(set.add(new TestElement(3))); + assertFalse(set.add(new TestElement(3))); + assertEquals(3, set.size()); + assertTrue(set.contains(new TestElement(1))); + assertFalse(set.contains(new TestElement(4))); + TestElement secondAgain = set.find(new TestElement(2)); + assertTrue(second == secondAgain); + assertTrue(set.remove(new TestElement(1))); + assertFalse(set.remove(new TestElement(1))); + assertEquals(2, set.size()); + set.clear(); + assertEquals(0, set.size()); + } + + private static void expectTraversal(Iterator iterator, Integer... 
sequence) { + int i = 0; + while (iterator.hasNext()) { + TestElement element = iterator.next(); + Assert.assertTrue("Iterator yieled " + (i + 1) + " elements, but only " + + sequence.length + " were expected.", i < sequence.length); + Assert.assertEquals("Iterator value number " + (i + 1) + " was incorrect.", + sequence[i].intValue(), element.val); + i = i + 1; + } + Assert.assertTrue("Iterator yieled " + (i + 1) + " elements, but " + + sequence.length + " were expected.", i == sequence.length); + } + + private static void expectTraversal(Iterator iter, + Iterator expectedIter) { + int i = 0; + while (iter.hasNext()) { + TestElement element = iter.next(); + Assert.assertTrue("Iterator yieled " + (i + 1) + " elements, but only " + + i + " were expected.", expectedIter.hasNext()); + Integer expected = expectedIter.next(); + Assert.assertEquals("Iterator value number " + (i + 1) + " was incorrect.", + expected.intValue(), element.val); + i = i + 1; + } + Assert.assertFalse("Iterator yieled " + i + " elements, but at least " + + (i + 1) + " were expected.", expectedIter.hasNext()); + } + + @Test + public void testTraversal() throws Exception { + ImplicitLinkedHashSet set = new ImplicitLinkedHashSet<>(100); + expectTraversal(set.iterator()); + assertTrue(set.add(new TestElement(2))); + expectTraversal(set.iterator(), 2); + assertTrue(set.add(new TestElement(1))); + expectTraversal(set.iterator(), 2, 1); + assertTrue(set.add(new TestElement(100))); + expectTraversal(set.iterator(), 2, 1, 100); + assertTrue(set.remove(new TestElement(1))); + expectTraversal(set.iterator(), 2, 100); + assertTrue(set.add(new TestElement(1))); + expectTraversal(set.iterator(), 2, 100, 1); + Iterator iter = set.iterator(); + iter.next(); + iter.next(); + iter.remove(); + iter.next(); + assertFalse(iter.hasNext()); + expectTraversal(set.iterator(), 2, 1); + List list = new ArrayList<>(); + list.add(new TestElement(1)); + list.add(new TestElement(2)); + assertTrue(set.removeAll(list)); + assertFalse(set.removeAll(list)); + expectTraversal(set.iterator()); + assertEquals(0, set.size()); + assertTrue(set.isEmpty()); + } + + @Test + public void testCollisions() throws Exception { + ImplicitLinkedHashSet set = new ImplicitLinkedHashSet<>(5); + assertEquals(11, set.numSlots()); + assertTrue(set.add(new TestElement(11))); + assertTrue(set.add(new TestElement(0))); + assertTrue(set.add(new TestElement(22))); + assertTrue(set.add(new TestElement(33))); + assertEquals(11, set.numSlots()); + expectTraversal(set.iterator(), 11, 0, 22, 33); + assertTrue(set.remove(new TestElement(22))); + expectTraversal(set.iterator(), 11, 0, 33); + assertEquals(3, set.size()); + assertFalse(set.isEmpty()); + } + + @Test + public void testEnlargement() throws Exception { + ImplicitLinkedHashSet set = new ImplicitLinkedHashSet<>(5); + assertEquals(11, set.numSlots()); + for (int i = 0; i < 6; i++) { + assertTrue(set.add(new TestElement(i))); + } + assertEquals(23, set.numSlots()); + assertEquals(6, set.size()); + expectTraversal(set.iterator(), 0, 1, 2, 3, 4, 5); + for (int i = 0; i < 6; i++) { + assertTrue("Failed to find element " + i, set.contains(new TestElement(i))); + } + set.remove(new TestElement(3)); + assertEquals(23, set.numSlots()); + assertEquals(5, set.size()); + expectTraversal(set.iterator(), 0, 1, 2, 4, 5); + } + + @Test + public void testManyInsertsAndDeletes() throws Exception { + Random random = new Random(123); + LinkedHashSet existing = new LinkedHashSet<>(); + ImplicitLinkedHashSet set = new ImplicitLinkedHashSet<>(); + for 
(int i = 0; i < 100; i++) { + addRandomElement(random, existing, set); + addRandomElement(random, existing, set); + addRandomElement(random, existing, set); + removeRandomElement(random, existing, set); + expectTraversal(set.iterator(), existing.iterator()); + } + } + + private void addRandomElement(Random random, LinkedHashSet existing, + ImplicitLinkedHashSet set) { + int next; + do { + next = random.nextInt(); + } while (existing.contains(next)); + existing.add(next); + set.add(new TestElement(next)); + } + + private void removeRandomElement(Random random, LinkedHashSet existing, + ImplicitLinkedHashSet set) { + int removeIdx = random.nextInt(existing.size()); + Iterator iter = existing.iterator(); + Integer element = null; + for (int i = 0; i <= removeIdx; i++) { + element = iter.next(); + } + existing.remove(new TestElement(element)); + } +} diff --git a/core/src/main/scala/kafka/api/ApiVersion.scala b/core/src/main/scala/kafka/api/ApiVersion.scala index f95fb89279937..b8329c1ece218 100644 --- a/core/src/main/scala/kafka/api/ApiVersion.scala +++ b/core/src/main/scala/kafka/api/ApiVersion.scala @@ -73,8 +73,10 @@ object ApiVersion { // Introduced LeaderAndIsrRequest V1, UpdateMetadataRequest V4 and FetchRequest V6 via KIP-112 "1.0-IV0" -> KAFKA_1_0_IV0, "1.0" -> KAFKA_1_0_IV0, - // Introduced DeleteGroupsRequest V0 via KIP-229 - "1.1-IV0" -> KAFKA_1_1_IV0 + // Introduced DeleteGroupsRequest V0 via KIP-229, plus KIP-227 incremental fetch requests, + // and KafkaStorageException for fetch requests. + "1.1-IV0" -> KAFKA_1_1_IV0, + "1.1" -> KAFKA_1_1_IV0 ) private val versionPattern = "\\.".r @@ -191,4 +193,3 @@ case object KAFKA_1_1_IV0 extends ApiVersion { val messageFormatVersion: Byte = RecordBatch.MAGIC_VALUE_V2 val id: Int = 14 } - diff --git a/core/src/main/scala/kafka/server/FetchSession.scala b/core/src/main/scala/kafka/server/FetchSession.scala new file mode 100644 index 0000000000000..0a825f1f18720 --- /dev/null +++ b/core/src/main/scala/kafka/server/FetchSession.scala @@ -0,0 +1,720 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util +import java.util.concurrent.{ThreadLocalRandom, TimeUnit} + +import com.yammer.metrics.core.Gauge +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.Logging +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.apache.kafka.common.requests.{FetchMetadata => JFetchMetadata} +import org.apache.kafka.common.utils.{ImplicitLinkedHashSet, Time, Utils} + +import scala.math.Ordered.orderingToOrdered +import scala.collection.{mutable, _} +import scala.collection.JavaConverters._ + +object FetchSession { + type REQ_MAP = util.Map[TopicPartition, FetchRequest.PartitionData] + type RESP_MAP = util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + type CACHE_MAP = ImplicitLinkedHashSet[CachedPartition] + + val NUM_INCREMENTAL_FETCH_SESSISONS = "NumIncrementalFetchSessions" + val NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED = "NumIncrementalFetchPartitionsCached" + val INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC = "IncrementalFetchSessionEvictionsPerSec" + val EVICTIONS = "evictions" + + def partitionsToLogString(partitions: util.Collection[TopicPartition], traceEnabled: Boolean): String = { + if (traceEnabled) { + "(" + Utils.join(partitions, ", ") + ")" + } else { + s"${partitions.size} partition(s)" + } + } +} + +/** + * A cached partition. + * + * The broker maintains a set of these objects for each incremental fetch session. + * When an incremental fetch request is made, any partitions which are not explicitly + * enumerated in the fetch request are loaded from the cache. Similarly, when an + * incremental fetch response is being prepared, any partitions that have not changed + * are left out of the response. + * + * We store many of these objects, so it is important for them to be memory-efficient. + * That is why we store topic and partition separately rather than storing a TopicPartition + * object. The TP object takes up more memory because it is a separate JVM object, and + * because it stores the cached hash code in memory. + * + * Note that fetcherLogStartOffset is the LSO of the follower performing the fetch, whereas + * localLogStartOffset is the log start offset of the partition on this broker. 
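+ *
+ * Note that equality and the hash code are based only on the topic and partition, so a
+ * CachedPartition built from just a TopicPartition can be used as a key to look up the
+ * corresponding entry in a session's partition map.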
+ */ +class CachedPartition(val topic: String, + val partition: Int, + var maxBytes: Int, + var fetchOffset: Long, + var highWatermark: Long, + var fetcherLogStartOffset: Long, + var localLogStartOffset: Long) + extends ImplicitLinkedHashSet.Element { + + var cachedNext: Int = ImplicitLinkedHashSet.INVALID_INDEX + var cachedPrev: Int = ImplicitLinkedHashSet.INVALID_INDEX + + override def next = cachedNext + override def setNext(next: Int) = this.cachedNext = next + override def prev = cachedPrev + override def setPrev(prev: Int) = this.cachedPrev = prev + + def this(topic: String, partition: Int) = + this(topic, partition, -1, -1, -1, -1, -1) + + def this(part: TopicPartition) = + this(part.topic(), part.partition()) + + def this(part: TopicPartition, reqData: FetchRequest.PartitionData) = + this(part.topic(), part.partition(), + reqData.maxBytes, reqData.fetchOffset, -1, + reqData.logStartOffset, -1) + + def this(part: TopicPartition, reqData: FetchRequest.PartitionData, + respData: FetchResponse.PartitionData) = + this(part.topic(), part.partition(), + reqData.maxBytes, reqData.fetchOffset, respData.highWatermark, + reqData.logStartOffset, respData.logStartOffset) + + def topicPartition() = new TopicPartition(topic, partition) + + def reqData() = new FetchRequest.PartitionData(fetchOffset, fetcherLogStartOffset, maxBytes) + + def updateRequestParams(reqData: FetchRequest.PartitionData): Unit = { + // Update our cached request parameters. + maxBytes = reqData.maxBytes + fetchOffset = reqData.fetchOffset + fetcherLogStartOffset = reqData.logStartOffset + } + + /** + * Update this CachedPartition with new request and response data. + * + * This function should be called while holding the appropriate session + * lock. + * + * @return True if this partition should be included in the FetchResponse + * we send back to the fetcher; false if it can be omitted. + */ + def updateResponseData(respData: FetchResponse.PartitionData): Boolean = { + // Check the response data. + var mustRespond = false + if ((respData.records != null) && (respData.records.sizeInBytes() > 0)) { + // Partitions with new data are always included in the response. + mustRespond = true + } + if (highWatermark != respData.highWatermark) { + mustRespond = true + highWatermark = respData.highWatermark + } + if (localLogStartOffset != respData.logStartOffset) { + mustRespond = true + localLogStartOffset = respData.logStartOffset + } + if (respData.error.code() != 0) { + // Partitions with errors are always included in the response. + // We also set the cached highWatermark to an invalid offset, -1. + // This ensures that when the error goes away, we re-send the partition. + highWatermark = -1 + mustRespond = true + } + mustRespond + } + + override def hashCode() = (31 * partition) + topic.hashCode + + def canEqual(that: Any) = that.isInstanceOf[CachedPartition] + + override def equals(that: Any): Boolean = + that match { + case that: CachedPartition => that.canEqual(this) && + this.topic.equals(that.topic) && + this.partition.equals(that.partition) + case _ => false + } + + override def toString() = synchronized { + "CachedPartition(topic=" + topic + + ", partition=" + partition + + ", maxBytes=" + maxBytes + + ", fetchOffset=" + fetchOffset + + ", highWatermark=" + highWatermark + + ", fetcherLogStartOffset=" + fetcherLogStartOffset + + ", localLogStartOffset=" + localLogStartOffset + + ")" + } +} + +/** + * The fetch session. 
+ *
+ * Each fetch session is protected by its own lock, which must be taken before mutable
+ * fields are read or modified. This includes modification of the session partition map.
+ *
+ * @param id The unique fetch session ID.
+ * @param privileged True if this session is privileged. Sessions created by followers
+ * are privileged; sessions created by consumers are not.
+ * @param partitionMap The CachedPartitionMap.
+ * @param creationMs The time in milliseconds when this session was created.
+ * @param lastUsedMs The last used time in milliseconds. This should only be updated by
+ * FetchSessionCache#touch.
+ * @param epoch The fetch session sequence number.
+ */
+case class FetchSession(val id: Int,
+ val privileged: Boolean,
+ val partitionMap: FetchSession.CACHE_MAP,
+ val creationMs: Long,
+ var lastUsedMs: Long,
+ var epoch: Int) {
+ // This is used by the FetchSessionCache to store the last known size of this session.
+ // If this is -1, the Session is not in the cache.
+ var cachedSize = -1
+
+ def size(): Int = synchronized {
+ partitionMap.size()
+ }
+
+ def isEmpty(): Boolean = synchronized {
+ partitionMap.isEmpty
+ }
+
+ def lastUsedKey(): LastUsedKey = synchronized {
+ LastUsedKey(lastUsedMs, id)
+ }
+
+ def evictableKey(): EvictableKey = synchronized {
+ EvictableKey(privileged, cachedSize, id)
+ }
+
+ def metadata(): JFetchMetadata = synchronized { new JFetchMetadata(id, epoch) }
+
+ def getFetchOffset(topicPartition: TopicPartition): Option[Long] = synchronized {
+ Option(partitionMap.find(new CachedPartition(topicPartition))).map(_.fetchOffset)
+ }
+
+ type TL = util.ArrayList[TopicPartition]
+
+ // Update the cached partition data based on the request.
+ def update(fetchData: FetchSession.REQ_MAP,
+ toForget: util.List[TopicPartition],
+ reqMetadata: JFetchMetadata): (TL, TL, TL) = synchronized {
+ val added = new TL
+ val updated = new TL
+ val removed = new TL
+ fetchData.entrySet().iterator().asScala.foreach(entry => {
+ val topicPart = entry.getKey
+ val reqData = entry.getValue
+ val newCachedPart = new CachedPartition(topicPart, reqData)
+ val cachedPart = partitionMap.find(newCachedPart)
+ if (cachedPart == null) {
+ partitionMap.mustAdd(newCachedPart)
+ added.add(topicPart)
+ } else {
+ cachedPart.updateRequestParams(reqData)
+ updated.add(topicPart)
+ }
+ })
+ toForget.iterator().asScala.foreach(p => {
+ if (partitionMap.remove(new CachedPartition(p.topic(), p.partition()))) {
+ removed.add(p)
+ }
+ })
+ (added, updated, removed)
+ }
+
+ override def toString(): String = synchronized {
+ "FetchSession(id=" + id +
+ ", privileged=" + privileged +
+ ", partitionMap.size=" + partitionMap.size() +
+ ", creationMs=" + creationMs +
+ ", lastUsedMs=" + lastUsedMs +
+ ", epoch=" + epoch + ")"
+ }
+}
+
+trait FetchContext extends Logging {
+ /**
+ * Get the fetch offset for a given partition.
+ */
+ def getFetchOffset(part: TopicPartition): Option[Long]
+
+ /**
+ * Apply a function to each partition in the fetch request.
+ */
+ def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit
+
+ /**
+ * Updates the fetch context with new partition information. Generates response data.
+ * The response data may require subsequent down-conversion.
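+ *
+ * @param updates The partition data generated for this fetch, keyed by topic-partition.
+ * @return The FetchResponse to send back to the fetcher; implementations may prune
+ * entries from updates that do not need to be resent.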
+ */ + def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse + + def partitionsToLogString(partitions: util.Collection[TopicPartition]): String = + FetchSession.partitionsToLogString(partitions, isTraceEnabled) +} + +/** + * The fetch context for a fetch request that had a session error. + */ +class SessionErrorContext(val error: Errors, + val reqMetadata: JFetchMetadata) extends FetchContext { + override def getFetchOffset(part: TopicPartition): Option[Long] = None + + override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = {} + + // Because of the fetch session error, we don't know what partitions were supposed to be in this request. + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + debug(s"Session error fetch context returning $error") + new FetchResponse(error, new FetchSession.RESP_MAP, 0, INVALID_SESSION_ID) + } +} + +/** + * The fetch context for a sessionless fetch request. + * + * @param fetchData The partition data from the fetch request. + */ +class SessionlessFetchContext(val fetchData: util.Map[TopicPartition, FetchRequest.PartitionData]) extends FetchContext { + override def getFetchOffset(part: TopicPartition): Option[Long] = + Option(fetchData.get(part)).map(_.fetchOffset) + + override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = { + fetchData.entrySet().asScala.foreach(entry => fun(entry.getKey, entry.getValue)) + } + + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + debug(s"Sessionless fetch context returning ${partitionsToLogString(updates.keySet())}") + new FetchResponse(Errors.NONE, updates, 0, INVALID_SESSION_ID) + } +} + +/** + * The fetch context for a full fetch request. + * + * @param time The clock to use. + * @param cache The fetch session cache. + * @param reqMetadata The request metadata. + * @param fetchData The partition data from the fetch request. + * @param isFromFollower True if this fetch request came from a follower. + */ +class FullFetchContext(private val time: Time, + private val cache: FetchSessionCache, + private val reqMetadata: JFetchMetadata, + private val fetchData: util.Map[TopicPartition, FetchRequest.PartitionData], + private val isFromFollower: Boolean) extends FetchContext { + override def getFetchOffset(part: TopicPartition): Option[Long] = + Option(fetchData.get(part)).map(_.fetchOffset) + + override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = { + fetchData.entrySet().asScala.foreach(entry => fun(entry.getKey, entry.getValue)) + } + + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + def createNewSession(): FetchSession.CACHE_MAP = { + val cachedPartitions = new FetchSession.CACHE_MAP(updates.size()) + updates.entrySet().asScala.foreach(entry => { + val part = entry.getKey + val respData = entry.getValue + val reqData = fetchData.get(part) + cachedPartitions.mustAdd(new CachedPartition(part, reqData, respData)) + }) + cachedPartitions + } + val responseSessionId = cache.maybeCreateSession(time.milliseconds(), isFromFollower, + updates.size(), createNewSession) + debug(s"Full fetch context with session id $responseSessionId returning " + + s"${partitionsToLogString(updates.keySet())}") + new FetchResponse(Errors.NONE, updates, 0, responseSessionId) + } +} + +/** + * The fetch context for an incremental fetch request. 
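+ *
+ * Unlike a full fetch context, the set of partitions to read comes from the cached session
+ * rather than from the request itself, and only partitions whose response data has changed
+ * (as judged by CachedPartition.updateResponseData) are included in the generated response.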
+ * + * @param time The clock to use. + * @param reqMetadata The request metadata. + * @param session The incremental fetch request session. + */ +class IncrementalFetchContext(private val time: Time, + private val reqMetadata: JFetchMetadata, + private val session: FetchSession) extends FetchContext { + + override def getFetchOffset(tp: TopicPartition): Option[Long] = session.getFetchOffset(tp) + + override def foreachPartition(fun: (TopicPartition, FetchRequest.PartitionData) => Unit): Unit = { + // Take the session lock and iterate over all the cached partitions. + session.synchronized { + session.partitionMap.iterator().asScala.foreach(part => { + fun(new TopicPartition(part.topic, part.partition), part.reqData()) + }) + } + } + + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + session.synchronized { + // Check to make sure that the session epoch didn't change in between + // creating this fetch context and generating this response. + val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch()) + if (session.epoch != expectedEpoch) { + info(s"Incremental fetch session ${session.id} expected epoch $expectedEpoch, but " + + s"got ${session.epoch}. Possible duplicate request.") + new FetchResponse(Errors.INVALID_FETCH_SESSION_EPOCH, new FetchSession.RESP_MAP, 0, session.id) + } else { + // Iterate over the update list. Prune updates which don't need to be sent. + val iter = updates.entrySet().iterator() + while (iter.hasNext()) { + val entry = iter.next() + val topicPart = entry.getKey + val respData = entry.getValue + val cachedPart = session.partitionMap.find(new CachedPartition(topicPart)) + val mustRespond = cachedPart.updateResponseData(respData) + if (mustRespond) { + // Move this to the end of the cached partition map. + // This is important for ensuring fairness when lots of partitions + // have data to return. + session.partitionMap.remove(cachedPart) + session.partitionMap.mustAdd(cachedPart) + } else { + // Do not include this partition in the FetchResponse. + iter.remove() + } + } + debug(s"Incremental fetch context with session id ${session.id} returning " + + s"${partitionsToLogString(updates.keySet())}") + new FetchResponse(Errors.NONE, updates, 0, session.id) + } + } + } +} + +case class LastUsedKey(val lastUsedMs: Long, + val id: Int) extends Comparable[LastUsedKey] { + override def compareTo(other: LastUsedKey): Int = + (lastUsedMs, id) compare (other.lastUsedMs, other.id) +} + +case class EvictableKey(val privileged: Boolean, + val size: Int, + val id: Int) extends Comparable[EvictableKey] { + override def compareTo(other: EvictableKey): Int = + (privileged, size, id) compare (other.privileged, other.size, other.id) +} + +/** + * Caches fetch sessions. + * + * See tryEvict for an explanation of the cache eviction strategy. + * + * The FetchSessionCache is thread-safe because all of its methods are synchronized. + * Note that individual fetch sessions have their own locks which are separate from the + * FetchSessionCache lock. In order to avoid deadlock, the FetchSessionCache lock + * must never be acquired while an individual FetchSession lock is already held. + * + * @param maxEntries The maximum number of entries that can be in the cache. + * @param evictionMs The minimum time that an entry must be unused in order to be evictable. 
+ */ +class FetchSessionCache(private val maxEntries: Int, + private val evictionMs: Long) extends Logging with KafkaMetricsGroup { + private var numPartitions: Long = 0 + + // A map of session ID to FetchSession. + private val sessions = new mutable.HashMap[Int, FetchSession] + + // Maps last used times to sessions. + private val lastUsed = new util.TreeMap[LastUsedKey, FetchSession] + + // A map containing sessions which can be evicted by both privileged and + // unprivileged sessions. + private val evictableByAll = new util.TreeMap[EvictableKey, FetchSession] + + // A map containing sessions which can be evicted by privileged sessions. + private val evictableByPrivileged = new util.TreeMap[EvictableKey, FetchSession] + + // Set up metrics. + removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSISONS) + newGauge(FetchSession.NUM_INCREMENTAL_FETCH_SESSISONS, + new Gauge[Int] { + def value = FetchSessionCache.this.size + } + ) + removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED) + newGauge(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED, + new Gauge[Long] { + def value = FetchSessionCache.this.totalPartitions + } + ) + removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC) + val evictionsMeter = newMeter(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC, + FetchSession.EVICTIONS, TimeUnit.SECONDS, Map.empty) + + /** + * Get a session by session ID. + * + * @param sessionId The session ID. + * @return The session, or None if no such session was found. + */ + def get(sessionId: Int): Option[FetchSession] = synchronized { + sessions.get(sessionId) + } + + /** + * Get the number of entries currently in the fetch session cache. + */ + def size(): Int = synchronized { + sessions.size + } + + /** + * Get the total number of cached partitions. + */ + def totalPartitions(): Long = synchronized { + numPartitions + } + + /** + * Creates a new random session ID. The new session ID will be positive and unique on this broker. + * + * @return The new session ID. + */ + def newSessionId(): Int = synchronized { + var id = 0 + do { + id = ThreadLocalRandom.current().nextInt(1, Int.MaxValue) + } while (sessions.contains(id) || id == INVALID_SESSION_ID) + id + } + + /** + * Try to create a new session. + * + * @param now The current time in milliseconds. + * @param privileged True if the new entry we are trying to create is privileged. + * @param size The number of cached partitions in the new entry we are trying to create. + * @param createPartitions A callback function which creates the map of cached partitions. + * @return If we created a session, the ID; INVALID_SESSION_ID otherwise. + */ + def maybeCreateSession(now: Long, + privileged: Boolean, + size: Int, + createPartitions: () => FetchSession.CACHE_MAP): Int = + synchronized { + // If there is room, create a new session entry. + if ((sessions.size < maxEntries) || + tryEvict(privileged, EvictableKey(privileged, size, 0), now)) { + val partitionMap = createPartitions() + val session = new FetchSession(newSessionId(), privileged, partitionMap, + now, now, JFetchMetadata.nextEpoch(INITIAL_EPOCH)) + debug(s"Created fetch session ${session.toString()}") + sessions.put(session.id, session) + touch(session, now) + session.id + } else { + debug(s"No fetch session created for privileged=$privileged, size=$size.") + INVALID_SESSION_ID + } + } + + /** + * Try to evict an entry from the session cache. + * + * A proposed new element A may evict an existing element B if: + * 1. A is privileged and B is not, or + * 2. 
B is considered "stale" because it has been inactive for a long time, or + * 3. A contains more partitions than B, and B is not recently created. + * + * @param privileged True if the new entry we would like to add is privileged. + * @param key The EvictableKey for the new entry we would like to add. + * @param now The current time in milliseconds. + * @return True if an entry was evicted; false otherwise. + */ + def tryEvict(privileged: Boolean, key: EvictableKey, now: Long): Boolean = synchronized { + // Try to evict an entry which is stale. + val lastUsedEntry = lastUsed.firstEntry() + if (lastUsedEntry == null) { + trace("There are no cache entries to evict.") + false + } else if (now - lastUsedEntry.getKey().lastUsedMs > evictionMs) { + val session = lastUsedEntry.getValue() + trace(s"Evicting stale FetchSession ${session.id}.") + remove(session) + evictionsMeter.mark() + true + } else { + // If there are no stale entries, check the first evictable entry. + // If it is less valuable than our proposed entry, evict it. + val map = if (privileged) evictableByPrivileged else evictableByAll + val evictableEntry = map.firstEntry() + if (evictableEntry == null) { + trace("No evictable entries found.") + false + } else if (key.compareTo(evictableEntry.getKey()) < 0) { + trace(s"Can't evict ${evictableEntry.getKey()} with ${key.toString}") + false + } else { + trace(s"Evicting ${evictableEntry.getKey()} with ${key.toString}.") + remove(evictableEntry.getValue()) + evictionsMeter.mark() + true + } + } + } + + def remove(sessionId: Int): Option[FetchSession] = synchronized { + get(sessionId) match { + case None => None + case Some(session) => remove(session) + } + } + + /** + * Remove an entry from the session cache. + * + * @param session The session. + * + * @return The removed session, or None if there was no such session. + */ + def remove(session: FetchSession): Option[FetchSession] = synchronized { + val evictableKey = session.synchronized { + lastUsed.remove(session.lastUsedKey()) + session.evictableKey() + } + evictableByAll.remove(evictableKey) + evictableByPrivileged.remove(evictableKey) + val removeResult = sessions.remove(session.id) + if (removeResult.isDefined) { + numPartitions = numPartitions - session.cachedSize + } + removeResult + } + + /** + * Update a session's position in the lastUsed and evictable trees. + * + * @param session The session. + * @param now The current time in milliseconds. + */ + def touch(session: FetchSession, now: Long): Unit = synchronized { + session.synchronized { + // Update the lastUsed map. 
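+ // Remove the old entry, advance the timestamp, and re-insert the session so that it is
+ // keyed by the new last-used time.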
+ lastUsed.remove(session.lastUsedKey()) + session.lastUsedMs = now + lastUsed.put(session.lastUsedKey(), session) + + val oldSize = session.cachedSize + if (oldSize != -1) { + val oldEvictableKey = session.evictableKey() + evictableByPrivileged.remove(oldEvictableKey) + evictableByAll.remove(oldEvictableKey) + numPartitions = numPartitions - oldSize + } + session.cachedSize = session.size() + val newEvictableKey = session.evictableKey() + if ((!session.privileged) || (now - session.creationMs > evictionMs)) { + evictableByPrivileged.put(newEvictableKey, session) + } + if (now - session.creationMs > evictionMs) { + evictableByAll.put(newEvictableKey, session) + } + numPartitions = numPartitions + session.cachedSize + } + } +} + +class FetchManager(private val time: Time, + private val cache: FetchSessionCache) extends Logging { + def newContext(reqMetadata: JFetchMetadata, + fetchData: FetchSession.REQ_MAP, + toForget: util.List[TopicPartition], + isFollower: Boolean): FetchContext = { + val context = if (reqMetadata.isFull) { + var removedFetchSessionStr = "" + if (reqMetadata.sessionId() != INVALID_SESSION_ID) { + // Any session specified in a FULL fetch request will be closed. + if (cache.remove(reqMetadata.sessionId()).isDefined) { + removedFetchSessionStr = s" Removed fetch session ${reqMetadata.sessionId()}." + } + } + var suffix = "" + val context = if (reqMetadata.epoch() == FINAL_EPOCH) { + // If the epoch is FINAL_EPOCH, don't try to create a new session. + suffix = " Will not try to create a new session." + new SessionlessFetchContext(fetchData) + } else { + new FullFetchContext(time, cache, reqMetadata, fetchData, isFollower) + } + debug(s"Created a new full FetchContext with ${partitionsToLogString(fetchData.keySet())}."+ + s"${removedFetchSessionStr}${suffix}") + context + } else { + cache.synchronized { + cache.get(reqMetadata.sessionId()) match { + case None => { + info(s"Created a new error FetchContext for session id ${reqMetadata.sessionId()}: " + + "no such session ID found.") + new SessionErrorContext(Errors.FETCH_SESSION_ID_NOT_FOUND, reqMetadata) + } + case Some(session) => session.synchronized { + if (session.epoch != reqMetadata.epoch()) { + debug(s"Created a new error FetchContext for session id ${session.id}: expected " + + s"epoch ${session.epoch}, but got epoch ${reqMetadata.epoch()}.") + new SessionErrorContext(Errors.INVALID_FETCH_SESSION_EPOCH, reqMetadata) + } else { + val (added, updated, removed) = session.update(fetchData, toForget, reqMetadata) + if (session.isEmpty) { + debug(s"Created a new sessionless FetchContext and closing session id ${session.id}, " + + s"epoch ${session.epoch}: after removing ${partitionsToLogString(removed)}, " + + s"there are no more partitions left.") + cache.remove(session) + new SessionlessFetchContext(fetchData) + } else { + if (session.size() != session.cachedSize) { + // If the number of partitions in the session changed, update the session's + // position in the cache. 
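+ // Passing the existing lastUsedMs updates the cache's size accounting for this session
+ // without refreshing its last-used time.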
+ cache.touch(session, session.lastUsedMs) + } + session.epoch = JFetchMetadata.nextEpoch(session.epoch) + debug(s"Created a new incremental FetchContext for session id ${session.id}, " + + s"epoch ${session.epoch}: added ${partitionsToLogString(added)}, " + + s"updated ${partitionsToLogString(updated)}, " + + s"removed ${partitionsToLogString(removed)}") + new IncrementalFetchContext(time, reqMetadata, session) + } + } + } + } + } + } + context + } + + def partitionsToLogString(partitions: util.Collection[TopicPartition]): String = + FetchSession.partitionsToLogString(partitions, isTraceEnabled) +} diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 2dd6951f46d83..1e1dc1a0fcdcf 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -80,6 +80,7 @@ class KafkaApis(val requestChannel: RequestChannel, val metrics: Metrics, val authorizer: Option[Authorizer], val quotas: QuotaManagers, + val fetchManager: FetchManager, brokerTopicStats: BrokerTopicStats, val clusterId: String, time: Time, @@ -481,35 +482,52 @@ class KafkaApis(val requestChannel: RequestChannel, * Handle a fetch request */ def handleFetchRequest(request: RequestChannel.Request) { - val fetchRequest = request.body[FetchRequest] val versionId = request.header.apiVersion val clientId = request.header.clientId - - val unauthorizedTopicResponseData = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]() - val nonExistingTopicResponseData = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]() - val authorizedRequestInfo = mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]() - - if (fetchRequest.isFromFollower() && !authorize(request.session, ClusterAction, Resource.ClusterResource)) - for (topicPartition <- fetchRequest.fetchData.asScala.keys) - unauthorizedTopicResponseData += topicPartition -> new FetchResponse.PartitionData(Errors.CLUSTER_AUTHORIZATION_FAILED, - FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, - FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY) - else - for ((topicPartition, partitionData) <- fetchRequest.fetchData.asScala) { - if (!authorize(request.session, Read, new Resource(Topic, topicPartition.topic))) - unauthorizedTopicResponseData += topicPartition -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED, + val fetchRequest = request.body[FetchRequest] + val fetchContext = fetchManager.newContext(fetchRequest.metadata(), + fetchRequest.fetchData(), + fetchRequest.toForget(), + fetchRequest.isFromFollower()) + + val erroneous = mutable.ArrayBuffer[(TopicPartition, FetchResponse.PartitionData)]() + val interesting = mutable.ArrayBuffer[(TopicPartition, FetchRequest.PartitionData)]() + if (fetchRequest.isFromFollower()) { + // The follower must have ClusterAction on ClusterResource in order to fetch partition data. 
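+ // Unauthorized followers get TOPIC_AUTHORIZATION_FAILED for every partition in the
+ // context; authorized followers still get UNKNOWN_TOPIC_OR_PARTITION for any topic
+ // missing from the metadata cache.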
+ if (authorize(request.session, ClusterAction, Resource.ClusterResource)) { + fetchContext.foreachPartition((part, data) => { + if (!metadataCache.contains(part.topic)) { + erroneous += part -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION, + FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, + FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY) + } else { + interesting += (part -> data) + } + }) + } else { + fetchContext.foreachPartition((part, data) => { + erroneous += part -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED, FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY) - else if (!metadataCache.contains(topicPartition.topic)) - nonExistingTopicResponseData += topicPartition -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION, + }) + } + } else { + // Regular Kafka consumers need READ permission on each partition they are fetching. + fetchContext.foreachPartition((part, data) => { + if (!authorize(request.session, Read, new Resource(Topic, part.topic))) + erroneous += part -> new FetchResponse.PartitionData(Errors.TOPIC_AUTHORIZATION_FAILED, + FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, + FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY) + else if (!metadataCache.contains(part.topic)) + erroneous += part -> new FetchResponse.PartitionData(Errors.UNKNOWN_TOPIC_OR_PARTITION, FetchResponse.INVALID_HIGHWATERMARK, FetchResponse.INVALID_LAST_STABLE_OFFSET, FetchResponse.INVALID_LOG_START_OFFSET, null, MemoryRecords.EMPTY) else - authorizedRequestInfo += (topicPartition -> partitionData) - } + interesting += (part -> data) + }) + } def convertedPartitionData(tp: TopicPartition, data: FetchResponse.PartitionData) = { - // Down-conversion of the fetched records is needed when the stored magic version is // greater than that supported by the client (as indicated by the fetch request version). 
If the // configured magic version for the topic is less than or equal to that supported by the version of the @@ -529,7 +547,7 @@ class KafkaApis(val requestChannel: RequestChannel, downConvertMagic.map { magic => trace(s"Down converting records from partition $tp to message format version $magic for fetch request from $clientId") - val converted = data.records.downConvert(magic, fetchRequest.fetchData.get(tp).fetchOffset, time) + val converted = data.records.downConvert(magic, fetchContext.getFetchOffset(tp).get, time) updateRecordsProcessingStats(request, tp, converted.recordsProcessingStats) new FetchResponse.PartitionData(data.error, data.highWatermark, FetchResponse.INVALID_LAST_STABLE_OFFSET, data.logStartOffset, data.abortedTransactions, converted.records) @@ -540,34 +558,28 @@ class KafkaApis(val requestChannel: RequestChannel, // the callback for process a fetch response, invoked before throttling def processResponseCallback(responsePartitionData: Seq[(TopicPartition, FetchPartitionData)]) { - val partitionData = { - responsePartitionData.map { case (tp, data) => - val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull - val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET) - tp -> new FetchResponse.PartitionData(data.error, data.highWatermark, lastStableOffset, - data.logStartOffset, abortedTransactions, data.records) - } - } - - val mergedPartitionData = partitionData ++ unauthorizedTopicResponseData ++ nonExistingTopicResponseData - val fetchedPartitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]() - - mergedPartitionData.foreach { case (topicPartition, data) => - if (data.error != Errors.NONE) - debug(s"Fetch request with correlation id ${request.header.correlationId} from client $clientId " + - s"on partition $topicPartition failed due to ${data.error.exceptionName}") - - fetchedPartitionData.put(topicPartition, data) + val partitions = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + responsePartitionData.foreach{ case (tp, data) => + val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull + val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET) + partitions.put(tp, new FetchResponse.PartitionData(data.error, data.highWatermark, lastStableOffset, + data.logStartOffset, abortedTransactions, data.records)) } + erroneous.foreach{case (tp, data) => partitions.put(tp, data)} + val unconvertedFetchResponse = fetchContext.updateAndGenerateResponseData(partitions) // fetch response callback invoked after any throttling def fetchResponseCallback(bandwidthThrottleTimeMs: Int) { def createResponse(requestThrottleTimeMs: Int): FetchResponse = { val convertedData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] - fetchedPartitionData.asScala.foreach { case (tp, partitionData) => + unconvertedFetchResponse.responseData().asScala.foreach { case (tp, partitionData) => + if (partitionData.error != Errors.NONE) + debug(s"Fetch request with correlation id ${request.header.correlationId} from client $clientId " + + s"on partition $tp failed due to ${partitionData.error.exceptionName}") convertedData.put(tp, convertedPartitionData(tp, partitionData)) } - val response = new FetchResponse(convertedData, bandwidthThrottleTimeMs + requestThrottleTimeMs) + val response = new FetchResponse(unconvertedFetchResponse.error(), convertedData, + bandwidthThrottleTimeMs + requestThrottleTimeMs, 
unconvertedFetchResponse.sessionId()) response.responseData.asScala.foreach { case (topicPartition, data) => // record the bytes out metrics only when the response is being sent brokerTopicStats.updateBytesOut(topicPartition.topic, fetchRequest.isFromFollower, data.records.sizeInBytes) @@ -575,6 +587,9 @@ class KafkaApis(val requestChannel: RequestChannel, response } + trace(s"Sending Fetch response with partitions.size=${unconvertedFetchResponse.responseData().size()}, " + + s"metadata=${unconvertedFetchResponse.sessionId()}") + if (fetchRequest.isFromFollower) sendResponseExemptThrottle(request, createResponse(0)) else @@ -587,21 +602,20 @@ class KafkaApis(val requestChannel: RequestChannel, if (fetchRequest.isFromFollower) { // We've already evaluated against the quota and are good to go. Just need to record it now. - val responseSize = sizeOfThrottledPartitions(versionId, fetchRequest, mergedPartitionData, quotas.leader) + val responseSize = sizeOfThrottledPartitions(versionId, unconvertedFetchResponse, quotas.leader) quotas.leader.record(responseSize) fetchResponseCallback(bandwidthThrottleTimeMs = 0) } else { // Fetch size used to determine throttle time is calculated before any down conversions. // This may be slightly different from the actual response size. But since down conversions // result in data being loaded into memory, it is better to do this after throttling to avoid OOM. - val response = new FetchResponse(fetchedPartitionData, 0) - val responseStruct = response.toStruct(versionId) + val responseStruct = unconvertedFetchResponse.toStruct(versionId) quotas.fetch.maybeRecordAndThrottle(request.session.sanitizedUser, clientId, responseStruct.sizeOf, fetchResponseCallback) } } - if (authorizedRequestInfo.isEmpty) + if (interesting.isEmpty) processResponseCallback(Seq.empty) else { // call the replica manager to fetch messages from the local replica @@ -611,23 +625,45 @@ class KafkaApis(val requestChannel: RequestChannel, fetchRequest.minBytes, fetchRequest.maxBytes, versionId <= 2, - authorizedRequestInfo, + interesting, replicationQuota(fetchRequest), processResponseCallback, fetchRequest.isolationLevel) } } + class SelectingIterator(val partitions: util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData], + val quota: ReplicationQuotaManager) + extends util.Iterator[util.Map.Entry[TopicPartition, FetchResponse.PartitionData]] { + val iter = partitions.entrySet().iterator() + + var nextElement: util.Map.Entry[TopicPartition, FetchResponse.PartitionData] = null + + override def hasNext: Boolean = { + while ((nextElement == null) && iter.hasNext()) { + val element = iter.next() + if (quota.isThrottled(element.getKey)) { + nextElement = element + } + } + nextElement != null + } + + override def next(): util.Map.Entry[TopicPartition, FetchResponse.PartitionData] = { + if (!hasNext()) throw new NoSuchElementException() + val element = nextElement + nextElement = null + element + } + + override def remove() = throw new UnsupportedOperationException() + } + private def sizeOfThrottledPartitions(versionId: Short, - fetchRequest: FetchRequest, - mergedPartitionData: Seq[(TopicPartition, FetchResponse.PartitionData)], + unconvertedResponse: FetchResponse, quota: ReplicationQuotaManager): Int = { - val partitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] - mergedPartitionData.foreach { case (tp, data) => - if (quota.isThrottled(tp)) - partitionData.put(tp, data) - } - FetchResponse.sizeOf(versionId, partitionData) + val iter = new 
SelectingIterator(unconvertedResponse.responseData(), quota) + FetchResponse.sizeOf(versionId, iter) } def replicationQuota(fetchRequest: FetchRequest): ReplicaQuota = diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala index 144dd65c58099..6402ad0717ac0 100755 --- a/core/src/main/scala/kafka/server/KafkaConfig.scala +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -174,6 +174,9 @@ object Defaults { val TransactionsAbortTimedOutTransactionsCleanupIntervalMS = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs val TransactionsRemoveExpiredTransactionsCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs + /** ********* Fetch Session Configuration **************/ + val MaxIncrementalFetchSessionCacheSlots = 1000 + /** ********* Quota Configuration ***********/ val ProducerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault val ConsumerQuotaBytesPerSecondDefault = ClientQuotaManagerConfig.QuotaBytesPerSecondDefault @@ -370,6 +373,9 @@ object KafkaConfig { val TransactionsAbortTimedOutTransactionCleanupIntervalMsProp = "transaction.abort.timed.out.transaction.cleanup.interval.ms" val TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp = "transaction.remove.expired.transaction.cleanup.interval.ms" + /** ********* Fetch Session Configuration **************/ + val MaxIncrementalFetchSessionCacheSlots = "max.incremental.fetch.session.cache.slots" + /** ********* Quota Configuration ***********/ val ProducerQuotaBytesPerSecondDefaultProp = "quota.producer.default" val ConsumerQuotaBytesPerSecondDefaultProp = "quota.consumer.default" @@ -638,6 +644,9 @@ object KafkaConfig { val TransactionsAbortTimedOutTransactionsIntervalMsDoc = "The interval at which to rollback transactions that have timed out" val TransactionsRemoveExpiredTransactionsIntervalMsDoc = "The interval at which to remove transactions that have expired due to transactional.id.expiration.ms passing" + /** ********* Fetch Session Configuration **************/ + val MaxIncrementalFetchSessionCacheSlotsDoc = "The maximum number of incremental fetch sessions that we will maintain." + /** ********* Quota Configuration ***********/ val ProducerQuotaBytesPerSecondDefaultDoc = "DEPRECATED: Used only when dynamic default quotas are not configured for , or in Zookeeper. 
" + "Any producer distinguished by clientId will get throttled if it produces more bytes than this value per-second" @@ -861,6 +870,9 @@ object KafkaConfig { .define(TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, INT, Defaults.TransactionsAbortTimedOutTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsAbortTimedOutTransactionsIntervalMsDoc) .define(TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, INT, Defaults.TransactionsRemoveExpiredTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsRemoveExpiredTransactionsIntervalMsDoc) + /** ********* Fetch Session Configuration **************/ + .define(MaxIncrementalFetchSessionCacheSlots, INT, Defaults.MaxIncrementalFetchSessionCacheSlots, atLeast(0), MEDIUM, MaxIncrementalFetchSessionCacheSlotsDoc) + /** ********* Kafka Metrics Configuration ***********/ .define(MetricNumSamplesProp, INT, Defaults.MetricNumSamples, atLeast(1), LOW, MetricNumSamplesDoc) .define(MetricSampleWindowMsProp, LONG, Defaults.MetricSampleWindowMs, atLeast(1), LOW, MetricSampleWindowMsDoc) @@ -1168,6 +1180,9 @@ class KafkaConfig(val props: java.util.Map[_, _], doLog: Boolean, dynamicConfigO /** ********* Transaction Configuration **************/ val transactionIdExpirationMs = getInt(KafkaConfig.TransactionalIdExpirationMsProp) + /** ********* Fetch Session Configuration **************/ + val maxIncrementalFetchSessionCacheSlots = getInt(KafkaConfig.MaxIncrementalFetchSessionCacheSlots) + val deleteTopicEnable = getBoolean(KafkaConfig.DeleteTopicEnableProp) def compressionType = getString(KafkaConfig.CompressionTypeProp) val listeners: Seq[EndPoint] = getListeners diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index 747a0dfdf021d..f15ac66eb4726 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -89,6 +89,7 @@ object KafkaServer { .timeWindow(kafkaConfig.metricSampleWindowMs, TimeUnit.MILLISECONDS) } + val MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS: Long = 120000 } /** @@ -279,10 +280,14 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP authZ } + val fetchManager = new FetchManager(Time.SYSTEM, + new FetchSessionCache(config.maxIncrementalFetchSessionCacheSlots, + KafkaServer.MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS)) + /* start processing requests */ apis = new KafkaApis(socketServer.requestChannel, replicaManager, adminManager, groupCoordinator, transactionCoordinator, kafkaController, zkClient, config.brokerId, config, metadataCache, metrics, authorizer, quotaManagers, - brokerTopicStats, clusterId, time, tokenManager) + fetchManager, brokerTopicStats, clusterId, time, tokenManager) requestHandlerPool = new KafkaRequestHandlerPool(config.brokerId, socketServer.requestChannel, apis, time, config.numIoThreads) diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala index da94c4a8f2163..8344d5beb349a 100644 --- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala @@ -26,6 +26,7 @@ import kafka.log.LogConfig import kafka.server.ReplicaFetcherThread._ import kafka.server.epoch.LeaderEpochCache import kafka.zk.AdminZkClient +import org.apache.kafka.clients.FetchSessionHandler import org.apache.kafka.common.requests.EpochEndOffset._ import org.apache.kafka.common.TopicPartition import 
org.apache.kafka.common.errors.KafkaStorageException @@ -35,6 +36,7 @@ import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.record.MemoryRecords import org.apache.kafka.common.requests.{EpochEndOffset, FetchResponse, ListOffsetRequest, ListOffsetResponse, OffsetsForLeaderEpochRequest, OffsetsForLeaderEpochResponse, FetchRequest => JFetchRequest} import org.apache.kafka.common.utils.{LogContext, Time} + import scala.collection.JavaConverters._ import scala.collection.{Map, mutable} @@ -65,17 +67,20 @@ class ReplicaFetcherThread(name: String, new ReplicaFetcherBlockingSend(sourceBroker, brokerConfig, metrics, time, fetcherId, s"broker-$replicaId-fetcher-$fetcherId", logContext)) private val fetchRequestVersion: Short = - if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV1) 5 + if (brokerConfig.interBrokerProtocolVersion >= KAFKA_1_1_IV0) 7 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV1) 5 else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV0) 4 else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV1) 3 else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_0_IV0) 2 else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_9_0) 1 else 0 + private val fetchMetadataSupported = brokerConfig.interBrokerProtocolVersion >= KAFKA_1_1_IV0 private val maxWait = brokerConfig.replicaFetchWaitMaxMs private val minBytes = brokerConfig.replicaFetchMinBytes private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes private val fetchSize = brokerConfig.replicaFetchMaxBytes private val shouldSendLeaderEpochRequest: Boolean = brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV2 + private val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id) private def epochCacheOpt(tp: TopicPartition): Option[LeaderEpochCache] = replicaMgr.getReplica(tp).map(_.epochs.get) @@ -211,10 +216,20 @@ class ReplicaFetcherThread(name: String, } protected def fetch(fetchRequest: FetchRequest): Seq[(TopicPartition, PartitionData)] = { - val clientResponse = leaderEndpoint.sendRequest(fetchRequest.underlying) - val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse] - fetchResponse.responseData.asScala.toSeq.map { case (key, value) => - key -> new PartitionData(value) + try { + val clientResponse = leaderEndpoint.sendRequest(fetchRequest.underlying) + val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse] + if (!fetchSessionHandler.handleResponse(fetchResponse)) { + Nil + } else { + fetchResponse.responseData.asScala.toSeq.map { case (key, value) => + key -> new PartitionData(value) + } + } + } catch { + case t: Throwable => + fetchSessionHandler.handleError(t) + throw t } } @@ -240,15 +255,16 @@ class ReplicaFetcherThread(name: String, } override def buildFetchRequest(partitionMap: Seq[(TopicPartition, PartitionFetchState)]): ResultWithPartitions[FetchRequest] = { - val requestMap = new util.LinkedHashMap[TopicPartition, JFetchRequest.PartitionData] val partitionsWithError = mutable.Set[TopicPartition]() + val builder = fetchSessionHandler.newBuilder() partitionMap.foreach { case (topicPartition, partitionFetchState) => // We will not include a replica in the fetch request if it should be throttled. 
if (partitionFetchState.isReadyForFetch && !shouldFollowerThrottle(quota, topicPartition)) { try { val logStartOffset = replicaMgr.getReplicaOrException(topicPartition).logStartOffset - requestMap.put(topicPartition, new JFetchRequest.PartitionData(partitionFetchState.fetchOffset, logStartOffset, fetchSize)) + builder.add(topicPartition, new JFetchRequest.PartitionData( + partitionFetchState.fetchOffset, logStartOffset, fetchSize)) } catch { case _: KafkaStorageException => // The replica has already been marked offline due to log directory failure and the original failure should have already been logged. @@ -258,9 +274,15 @@ class ReplicaFetcherThread(name: String, } } - val requestBuilder = JFetchRequest.Builder.forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, requestMap) - .setMaxBytes(maxBytes) - ResultWithPartitions(new FetchRequest(requestBuilder), partitionsWithError) + val fetchData = builder.build() + val requestBuilder = JFetchRequest.Builder. + forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, fetchData.toSend()) + .setMaxBytes(maxBytes) + .toForget(fetchData.toForget) + if (fetchMetadataSupported) { + requestBuilder.metadata(fetchData.metadata()) + } + ResultWithPartitions(new FetchRequest(fetchData.sessionPartitions(), requestBuilder), partitionsWithError) } /** @@ -365,10 +387,12 @@ class ReplicaFetcherThread(name: String, object ReplicaFetcherThread { - private[server] class FetchRequest(val underlying: JFetchRequest.Builder) extends AbstractFetcherThread.FetchRequest { - def isEmpty: Boolean = underlying.fetchData.isEmpty + private[server] class FetchRequest(val sessionParts: util.Map[TopicPartition, JFetchRequest.PartitionData], + val underlying: JFetchRequest.Builder) + extends AbstractFetcherThread.FetchRequest { def offset(topicPartition: TopicPartition): Long = - underlying.fetchData.asScala(topicPartition).fetchOffset + sessionParts.get(topicPartition).fetchOffset + override def isEmpty = sessionParts.isEmpty && underlying.toForget().isEmpty override def toString = underlying.toString } diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala index 9090fdac3ac90..f2b3552feaf5b 100644 --- a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala +++ b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala @@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.record.{Record, RecordBatch} import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.apache.kafka.common.requests.{FetchMetadata => JFetchMetadata} import org.apache.kafka.common.serialization.{ByteArraySerializer, StringSerializer} import org.junit.Assert._ import org.junit.Test @@ -294,6 +295,60 @@ class FetchRequestTest extends BaseRequestTest { expectedMagic = RecordBatch.MAGIC_VALUE_V2) } + /** + * Test that when an incremental fetch session contains partitions with an error, + * those partitions are returned in all incremental fetch requests. 
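+ *
+ * The test builds a session over two partitions of an existing topic plus one partition of
+ * a topic that does not exist yet, and verifies that the missing partition keeps appearing
+ * in incremental responses with UNKNOWN_TOPIC_OR_PARTITION until the topic is created,
+ * after which it is returned once more with no error and then dropped from later responses.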
+ */ + @Test + def testCreateIncrementalFetchWithPartitionsInError(): Unit = { + def createFetchRequest(topicPartitions: Seq[TopicPartition], + metadata: JFetchMetadata, + toForget: Seq[TopicPartition]): FetchRequest = + FetchRequest.Builder.forConsumer(Int.MaxValue, 0, + createPartitionMap(Integer.MAX_VALUE, topicPartitions, Map.empty)) + .toForget(toForget.asJava) + .metadata(metadata) + .build() + val foo0 = new TopicPartition("foo", 0) + val foo1 = new TopicPartition("foo", 1) + createTopic("foo", Map(0 -> List(0, 1), 1 -> List(0, 2))) + val bar0 = new TopicPartition("bar", 0) + val req1 = createFetchRequest(List(foo0, foo1, bar0), JFetchMetadata.INITIAL, Nil) + val resp1 = sendFetchRequest(0, req1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue("Expected the broker to create a new incremental fetch session", resp1.sessionId() > 0) + debug(s"Test created an incremental fetch session ${resp1.sessionId}") + assertTrue(resp1.responseData().containsKey(foo0)) + assertTrue(resp1.responseData().containsKey(foo1)) + assertTrue(resp1.responseData().containsKey(bar0)) + assertEquals(Errors.NONE, resp1.responseData().get(foo0).error) + assertEquals(Errors.NONE, resp1.responseData().get(foo1).error) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, resp1.responseData().get(bar0).error) + val req2 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 1), Nil) + val resp2 = sendFetchRequest(0, req2) + assertEquals(Errors.NONE, resp2.error()) + assertEquals("Expected the broker to continue the incremental fetch session", + resp1.sessionId(), resp2.sessionId()) + assertFalse(resp2.responseData().containsKey(foo0)) + assertFalse(resp2.responseData().containsKey(foo1)) + assertTrue(resp2.responseData().containsKey(bar0)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, resp2.responseData().get(bar0).error) + createTopic("bar", Map(0 -> List(0, 1))) + val req3 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 2), Nil) + val resp3 = sendFetchRequest(0, req3) + assertEquals(Errors.NONE, resp3.error()) + assertFalse(resp3.responseData().containsKey(foo0)) + assertFalse(resp3.responseData().containsKey(foo1)) + assertTrue(resp3.responseData().containsKey(bar0)) + assertEquals(Errors.NONE, resp3.responseData().get(bar0).error) + val req4 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 3), Nil) + val resp4 = sendFetchRequest(0, req4) + assertEquals(Errors.NONE, resp4.error()) + assertFalse(resp4.responseData().containsKey(foo0)) + assertFalse(resp4.responseData().containsKey(foo1)) + assertFalse(resp4.responseData().containsKey(bar0)) + } + private def records(partitionData: FetchResponse.PartitionData): Seq[Record] = { partitionData.records.records.asScala.toIndexedSeq } diff --git a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala new file mode 100755 index 0000000000000..3320b63938e94 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala @@ -0,0 +1,312 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.server + +import java.util +import java.util.Collections + +import kafka.utils.MockTime +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata} +import org.junit.{Rule, Test} +import org.junit.Assert._ +import org.junit.rules.Timeout + +class FetchSessionTest { + @Rule + def globalTimeout = Timeout.millis(120000) + + @Test + def testNewSessionId(): Unit = { + val cache = new FetchSessionCache(3, 100) + for (i <- 0 to 10000) { + val id = cache.newSessionId() + assertTrue(id > 0) + } + } + + def assertCacheContains(cache: FetchSessionCache, sessionIds: Int*) = { + var i = 0 + for (sessionId <- sessionIds) { + i = i + 1 + assertTrue("Missing session " + i + " out of " + sessionIds.size + "(" + sessionId + ")", + cache.get(sessionId).isDefined) + } + assertEquals(sessionIds.size, cache.size()) + } + + private def dummyCreate(size: Int)() = { + val cacheMap = new FetchSession.CACHE_MAP(size) + for (i <- 0 to (size - 1)) { + cacheMap.add(new CachedPartition("test", i)) + } + cacheMap + } + + @Test + def testSessionCache(): Unit = { + val cache = new FetchSessionCache(3, 100) + assertEquals(0, cache.size()) + val id1 = cache.maybeCreateSession(0, false, 10, dummyCreate(10)) + val id2 = cache.maybeCreateSession(10, false, 20, dummyCreate(20)) + val id3 = cache.maybeCreateSession(20, false, 30, dummyCreate(30)) + assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(30, false, 40, dummyCreate(40))) + assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(40, false, 5, dummyCreate(5))) + assertCacheContains(cache, id1, id2, id3) + cache.touch(cache.get(id1).get, 200) + val id4 = cache.maybeCreateSession(210, false, 11, dummyCreate(11)) + assertCacheContains(cache, id1, id3, id4) + cache.touch(cache.get(id1).get, 400) + cache.touch(cache.get(id3).get, 390) + cache.touch(cache.get(id4).get, 400) + val id5 = cache.maybeCreateSession(410, false, 50, dummyCreate(50)) + assertCacheContains(cache, id3, id4, id5) + assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(410, false, 5, dummyCreate(5))) + val id6 = cache.maybeCreateSession(410, true, 5, dummyCreate(5)) + assertCacheContains(cache, id3, id5, id6) + } + + @Test + def testResizeCachedSessions(): Unit = { + val cache = new FetchSessionCache(2, 100) + assertEquals(0, cache.totalPartitions()) + assertEquals(0, cache.size()) + assertEquals(0, cache.evictionsMeter.count()) + val id1 = cache.maybeCreateSession(0, false, 2, dummyCreate(2)) + assertTrue(id1 > 0) + assertCacheContains(cache, id1) + val session1 = cache.get(id1).get + assertEquals(2, session1.size()) + assertEquals(2, cache.totalPartitions()) + assertEquals(1, cache.size()) + assertEquals(0, cache.evictionsMeter.count()) + val id2 = cache.maybeCreateSession(0, false, 4, dummyCreate(4)) + val session2 = cache.get(id2).get + assertTrue(id2 > 0) + assertCacheContains(cache, id1, id2) + assertEquals(6, 
cache.totalPartitions()) + assertEquals(2, cache.size()) + assertEquals(0, cache.evictionsMeter.count()) + cache.touch(session1, 200) + cache.touch(session2, 200) + val id3 = cache.maybeCreateSession(200, false, 5, dummyCreate(5)) + assertTrue(id3 > 0) + assertCacheContains(cache, id2, id3) + assertEquals(9, cache.totalPartitions()) + assertEquals(2, cache.size()) + assertEquals(1, cache.evictionsMeter.count()) + cache.remove(id3) + assertCacheContains(cache, id2) + assertEquals(1, cache.size()) + assertEquals(1, cache.evictionsMeter.count()) + assertEquals(4, cache.totalPartitions()) + val iter = session2.partitionMap.iterator() + iter.next() + iter.remove() + assertEquals(3, session2.size()) + assertEquals(4, session2.cachedSize) + cache.touch(session2, session2.lastUsedMs) + assertEquals(3, cache.totalPartitions()) + } + + val EMPTY_PART_LIST = Collections.unmodifiableList(new util.ArrayList[TopicPartition]()) + + @Test + def testFetchRequests(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + // Verify that SESSIONLESS requests get a SessionlessFetchContext + val context = fetchManager.newContext(JFetchMetadata.LEGACY, + new util.HashMap[TopicPartition, FetchRequest.PartitionData](), EMPTY_PART_LIST, true) + assertEquals(classOf[SessionlessFetchContext], context.getClass) + + // Create a new fetch session with a FULL fetch request + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData2.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100)) + reqData2.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100)) + val context2 = fetchManager.newContext(JFetchMetadata.INITIAL, reqData2, EMPTY_PART_LIST, false) + assertEquals(classOf[FullFetchContext], context2.getClass) + val reqData2Iter = reqData2.entrySet().iterator() + context2.foreachPartition((topicPart, data) => { + val entry = reqData2Iter.next() + assertEquals(entry.getKey, topicPart) + assertEquals(entry.getValue, data) + }) + assertEquals(0, context2.getFetchOffset(new TopicPartition("foo", 0)).get) + assertEquals(10, context2.getFetchOffset(new TopicPartition("foo", 1)).get) + val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + respData2.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData( + Errors.NONE, 100, 100, 100, null, null)) + respData2.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData( + Errors.NONE, 10, 10, 10, null, null)) + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp2.error()) + assertTrue(resp2.sessionId() != INVALID_SESSION_ID) + assertEquals(respData2, resp2.responseData()) + + // Test trying to create a new session with an invalid epoch + val context3 = fetchManager.newContext( + new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false) + assertEquals(classOf[SessionErrorContext], context3.getClass) + assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, + context3.updateAndGenerateResponseData(respData2).error()) + + // Test trying to create a new session with a non-existent session id + val context4 = fetchManager.newContext( + new JFetchMetadata(resp2.sessionId() + 1, 1), reqData2, EMPTY_PART_LIST, false) + assertEquals(classOf[SessionErrorContext], context4.getClass) + assertEquals(Errors.FETCH_SESSION_ID_NOT_FOUND, + context4.updateAndGenerateResponseData(respData2).error()) + + // Continue the first fetch session we 
created. + val reqData5 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val context5 = fetchManager.newContext( + new JFetchMetadata(resp2.sessionId(), 1), reqData5, EMPTY_PART_LIST, false) + assertEquals(classOf[IncrementalFetchContext], context5.getClass) + val reqData5Iter = reqData2.entrySet().iterator() + context5.foreachPartition((topicPart, data) => { + val entry = reqData5Iter.next() + assertEquals(entry.getKey, topicPart) + assertEquals(entry.getValue, data) + }) + assertEquals(10, context5.getFetchOffset(new TopicPartition("foo", 1)).get) + val resp5 = context5.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp5.error()) + assertEquals(resp2.sessionId(), resp5.sessionId()) + assertEquals(0, resp5.responseData().size()) + + // Test setting an invalid fetch session epoch. + val context6 = fetchManager.newContext( + new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false) + assertEquals(classOf[SessionErrorContext], context6.getClass) + assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, + context6.updateAndGenerateResponseData(respData2).error()) + + // Close the incremental fetch session. + var prevSessionId = resp5.sessionId() + var nextSessionId = prevSessionId + do { + val reqData7 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData7.put(new TopicPartition("bar", 0), new FetchRequest.PartitionData(0, 0, 100)) + reqData7.put(new TopicPartition("bar", 1), new FetchRequest.PartitionData(10, 0, 100)) + val context7 = fetchManager.newContext( + new JFetchMetadata(prevSessionId, FINAL_EPOCH), reqData7, EMPTY_PART_LIST, false) + assertEquals(classOf[SessionlessFetchContext], context7.getClass) + assertEquals(0, cache.size()) + val respData7 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + respData7.put(new TopicPartition("bar", 0), + new FetchResponse.PartitionData(Errors.NONE, 100, 100, 100, null, null)) + respData7.put(new TopicPartition("bar", 1), + new FetchResponse.PartitionData(Errors.NONE, 100, 100, 100, null, null)) + val resp7 = context7.updateAndGenerateResponseData(respData7) + assertEquals(Errors.NONE, resp7.error()) + nextSessionId = resp7.sessionId() + } while (nextSessionId == prevSessionId) + } + + @Test + def testIncrementalFetchSession(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + // Create a new fetch session with foo-0 and foo-1 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100)) + reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100)) + val context1 = fetchManager.newContext(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + respData1.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData( + Errors.NONE, 100, 100, 100, null, null)) + respData1.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData( + Errors.NONE, 10, 10, 10, null, null)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + assertEquals(2, resp1.responseData().size()) + + // Create an incremental fetch request that removes foo-0 and adds bar-0 + val reqData2 = new 
util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData2.put(new TopicPartition("bar", 0), new FetchRequest.PartitionData(15, 0, 0)) + val removed2 = new util.ArrayList[TopicPartition] + removed2.add(new TopicPartition("foo", 0)) + val context2 = fetchManager.newContext( + new JFetchMetadata(resp1.sessionId(), 1), reqData2, removed2, false) + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + val parts2 = Set(new TopicPartition("foo", 1), new TopicPartition("bar", 0)) + val reqData2Iter = parts2.iterator + context2.foreachPartition((topicPart, data) => { + assertEquals(reqData2Iter.next(), topicPart) + }) + assertEquals(None, context2.getFetchOffset(new TopicPartition("foo", 0))) + assertEquals(10, context2.getFetchOffset(new TopicPartition("foo", 1)).get) + assertEquals(15, context2.getFetchOffset(new TopicPartition("bar", 0)).get) + assertEquals(None, context2.getFetchOffset(new TopicPartition("bar", 2))) + val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + respData2.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData( + Errors.NONE, 10, 10, 10, null, null)) + respData2.put(new TopicPartition("bar", 0), new FetchResponse.PartitionData( + Errors.NONE, 10, 10, 10, null, null)) + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp2.error()) + assertEquals(1, resp2.responseData().size()) + assertTrue(resp2.sessionId() > 0) + } + + @Test + def testZeroSizeFetchSession(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + // Create a new fetch session with foo-0 and foo-1 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(new TopicPartition("foo", 0), new FetchRequest.PartitionData(0, 0, 100)) + reqData1.put(new TopicPartition("foo", 1), new FetchRequest.PartitionData(10, 0, 100)) + val context1 = fetchManager.newContext(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + respData1.put(new TopicPartition("foo", 0), new FetchResponse.PartitionData( + Errors.NONE, 100, 100, 100, null, null)) + respData1.put(new TopicPartition("foo", 1), new FetchResponse.PartitionData( + Errors.NONE, 10, 10, 10, null, null)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + assertEquals(2, resp1.responseData().size()) + + // Create an incremental fetch request that removes foo-0 and foo-1 + // Verify that the previous fetch session was closed. 
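+    // (Removing the only partitions in the session shrinks it to zero cached partitions, so
+    // the broker closes it: the context becomes sessionless, the response carries
+    // INVALID_SESSION_ID, and the entry disappears from the cache.)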
+ val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val removed2 = new util.ArrayList[TopicPartition] + removed2.add(new TopicPartition("foo", 0)) + removed2.add(new TopicPartition("foo", 1)) + val context2 = fetchManager.newContext( + new JFetchMetadata(resp1.sessionId(), 1), reqData2, removed2, false) + assertEquals(classOf[SessionlessFetchContext], context2.getClass) + val respData2 = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(INVALID_SESSION_ID, resp2.sessionId()) + assertTrue(resp2.responseData().isEmpty) + assertEquals(0, cache.size()) + } +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 8e907d930e9c7..5de978cd0beb8 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -69,6 +69,7 @@ class KafkaApisTest { private val clientRequestQuotaManager = EasyMock.createNiceMock(classOf[ClientRequestQuotaManager]) private val replicaQuotaManager = EasyMock.createNiceMock(classOf[ReplicationQuotaManager]) private val quotas = QuotaManagers(clientQuotaManager, clientQuotaManager, clientRequestQuotaManager, replicaQuotaManager, replicaQuotaManager, replicaQuotaManager) + private val fetchManager = EasyMock.createNiceMock(classOf[FetchManager]) private val brokerTopicStats = new BrokerTopicStats private val clusterId = "clusterId" private val time = new MockTime @@ -96,6 +97,7 @@ class KafkaApisTest { metrics, authorizer, quotas, + fetchManager, brokerTopicStats, clusterId, time, diff --git a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala index 0692afb6591c1..1f5bec1cf8220 100644 --- a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala +++ b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala @@ -20,10 +20,10 @@ import kafka.cluster.BrokerEndPoint import kafka.server.BlockingSend import org.apache.kafka.clients.{ClientRequest, ClientResponse, MockClient} import org.apache.kafka.common.{Node, TopicPartition} -import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.requests.AbstractRequest.Builder import org.apache.kafka.common.requests.FetchResponse.PartitionData -import org.apache.kafka.common.requests.{AbstractRequest, EpochEndOffset, FetchResponse, OffsetsForLeaderEpochResponse} +import org.apache.kafka.common.requests.{AbstractRequest, EpochEndOffset, FetchResponse, OffsetsForLeaderEpochResponse, FetchMetadata => JFetchMetadata} import org.apache.kafka.common.utils.{SystemTime, Time} /** @@ -54,7 +54,8 @@ class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, Epoc case ApiKeys.FETCH => fetchCount += 1 - new FetchResponse(new java.util.LinkedHashMap[TopicPartition, PartitionData], 0) + new FetchResponse(Errors.NONE, new java.util.LinkedHashMap[TopicPartition, PartitionData], 0, + JFetchMetadata.INVALID_SESSION_ID) case _ => throw new UnsupportedOperationException
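For reference (not part of the patch): a minimal sketch of how a caller can drive the new
FetchSessionHandler across a full and then an incremental fetch, mirroring the
ReplicaFetcherThread changes above. The node id, replica id, offsets, and size limits below
are placeholder values chosen only for the example.

    import org.apache.kafka.clients.FetchSessionHandler
    import org.apache.kafka.common.TopicPartition
    import org.apache.kafka.common.requests.{FetchResponse, FetchRequest => JFetchRequest}
    import org.apache.kafka.common.utils.LogContext

    object FetchSessionExample {
      private val version: Short = 7   // first fetch request version that carries session metadata
      private val handler = new FetchSessionHandler(new LogContext("[example] "), 1)

      // Build the next fetch request. On the first call the metadata is INITIAL, so the request
      // enumerates every partition; on later calls the handler fills in the session id and epoch,
      // and toSend()/toForget() only carry the changes since the previous request.
      def buildRequest(offsets: Map[TopicPartition, Long]): JFetchRequest.Builder = {
        val builder = handler.newBuilder()
        offsets.foreach { case (tp, fetchOffset) =>
          // logStartOffset = 0 and maxBytes = 1 MiB are placeholder values.
          builder.add(tp, new JFetchRequest.PartitionData(fetchOffset, 0L, 1048576))
        }
        val data = builder.build()
        JFetchRequest.Builder.forReplica(version, 1000, 500, 1, data.toSend())
          .toForget(data.toForget)
          .metadata(data.metadata())
      }

      // Feed the response back into the handler. handleResponse() returns false when the broker
      // rejected the session (for example FETCH_SESSION_ID_NOT_FOUND), in which case the returned
      // data should be ignored and the next buildRequest() falls back to a full fetch.
      def onResponse(response: FetchResponse): Boolean =
        handler.handleResponse(response)
    }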