From 416fc9ba75251f17543f8556071885a026ef5745 Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Thu, 17 Mar 2022 18:00:38 +0100 Subject: [PATCH] [BEAM-14104] Support shard aware aggregation in Kinesis writer. --- .../beam/sdk/io/aws2/common/ClientPool.java | 123 -------- .../beam/sdk/io/aws2/common/ObjectPool.java | 151 ++++++++++ .../io/aws2/common/RetryConfiguration.java | 2 +- .../beam/sdk/io/aws2/kinesis/KinesisIO.java | 270 +++++++++++++++--- .../io/aws2/kinesis/KinesisPartitioner.java | 28 +- .../io/aws2/kinesis/RecordsAggregator.java | 4 - ...lientPoolTest.java => ObjectPoolTest.java} | 99 ++++--- .../io/aws2/kinesis/KinesisIOWriteTest.java | 243 ++++++++++++++-- .../io/aws2/kinesis/PutRecordsHelpers.java | 8 + .../io/aws2/kinesis/testing/KinesisIOIT.java | 5 +- 10 files changed, 677 insertions(+), 256 deletions(-) delete mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java create mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java rename sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/{ClientPoolTest.java => ObjectPoolTest.java} (63%) diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java deleted file mode 100644 index 1a7cd29ec98c..000000000000 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws2.common; - -import java.util.function.BiFunction; -import org.apache.beam.sdk.io.aws2.options.AwsOptions; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap; -import org.apache.commons.lang3.tuple.Pair; -import org.checkerframework.checker.nullness.qual.Nullable; -import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; - -/** - * Reference counting pool to easily share AWS clients or similar by individual client provider and - * configuration (optional). - * - *

NOTE: This relies heavily on the implementation of {@link #equals(Object)} for {@link - * ProviderT} and {@link ConfigT}. If not implemented properly, clients can't be shared between - * instances of {@link org.apache.beam.sdk.transforms.DoFn}. - * - * @param Client provider - * @param Optional, nullable configuration - * @param Client - */ -public class ClientPool { - private final BiMap, RefCounted> pool = HashBiMap.create(2); - private final BiFunction builder; - - public static < - ClientT extends AutoCloseable, BuilderT extends AwsClientBuilder> - ClientPool pooledClientFactory(BuilderT builder) { - return new ClientPool<>((opts, conf) -> ClientBuilderFactory.buildClient(opts, builder, conf)); - } - - public ClientPool(BiFunction builder) { - this.builder = builder; - } - - /** Retain a reference to a shared client instance. If not available, an instance is created. */ - public ClientT retain(ProviderT provider, @Nullable ConfigT config) { - @SuppressWarnings("nullness") - Pair key = Pair.of(provider, config); - synchronized (pool) { - RefCounted ref = pool.computeIfAbsent(key, RefCounted::new); - ref.count++; - return ref.client; - } - } - - /** - * Release a reference to a shared client instance using {@link ProviderT} and {@link ConfigT} . - * If that instance is not used anymore, it will be removed and destroyed. - */ - public void release(ProviderT provider, @Nullable ConfigT config) throws Exception { - @SuppressWarnings("nullness") - Pair key = Pair.of(provider, config); - RefCounted ref; - synchronized (pool) { - ref = pool.get(key); - if (ref == null || --ref.count > 0) { - return; - } - pool.remove(key); - } - ref.client.close(); - } - - /** - * Release a reference to a shared client instance. If that instance is not used anymore, it will - * be removed and destroyed. - */ - public void release(ClientT client) throws Exception { - Pair pair = pool.inverse().get(new RefCounted(client)); - if (pair != null) { - release(pair.getLeft(), pair.getRight()); - } - } - - private class RefCounted { - private int count = 0; - private final ClientT client; - - RefCounted(ClientT client) { - this.client = client; - } - - RefCounted(Pair key) { - this(builder.apply(key.getLeft(), key.getRight())); - } - - @Override - public boolean equals(@Nullable Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - // only identity of ref counted client matters - return client == ((RefCounted) o).client; - } - - @Override - public int hashCode() { - return client.hashCode(); - } - } -} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java new file mode 100644 index 000000000000..a17c6b56e5f0 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.common; + +import static org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient; + +import java.util.function.Function; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.sdk.function.ThrowingConsumer; +import org.apache.beam.sdk.io.aws2.options.AwsOptions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap; +import org.apache.commons.lang3.tuple.Pair; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder; +import software.amazon.awssdk.core.SdkClient; + +/** + * Reference counting object pool to easily share & destroy objects. + * + *

Internal only, subject to incompatible changes or removal at any time! + * + *
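For orientation, a usage sketch of this pool; openResource(..) is a hypothetical factory, only ObjectPool itself is from this file:

// Usage sketch: share one instance per key, close it on the last release.
// openResource(..) is a hypothetical factory method.
ObjectPool<String, AutoCloseable> pool =
    new ObjectPool<>(cfg -> openResource(cfg), res -> res.close());

AutoCloseable shared = pool.retain("config"); // created on first retain
AutoCloseable same = pool.retain("config");   // same instance, ref count = 2
pool.release(shared);                         // still open, one reference left
pool.release(same);                           // last release invokes the finalizer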

NOTE: This relies heavily on the implementation of {@link #equals(Object)} for {@link KeyT}. + * If not implemented properly, clients can't be shared between instances of {@link + * org.apache.beam.sdk.transforms.DoFn}. + * + * @param > Key to share objects by + * @param > Shared object + */ +@Internal +@Experimental +public class ObjectPool { + private final BiMap pool = HashBiMap.create(2); + private final Function builder; + private final @Nullable ThrowingConsumer finalizer; + + public ObjectPool(Function builder) { + this(builder, null); + } + + public ObjectPool( + Function builder, @Nullable ThrowingConsumer finalizer) { + this.builder = builder; + this.finalizer = finalizer; + } + + /** Retain a reference to a shared client instance. If not available, an instance is created. */ + public ObjectT retain(KeyT key) { + synchronized (pool) { + RefCounted ref = pool.computeIfAbsent(key, k -> new RefCounted(builder.apply(k))); + ref.count++; + return ref.shared; + } + } + + /** + * Release a reference to a shared object instance using {@link KeyT}. If that instance is not + * used anymore, it will be removed and destroyed. + */ + public void releaseByKey(KeyT key) { + RefCounted ref; + synchronized (pool) { + ref = pool.get(key); + if (ref == null || --ref.count > 0) { + return; + } + pool.remove(key); + } + if (finalizer != null) { + try { + finalizer.accept(ref.shared); + } catch (Exception e) { + LoggerFactory.getLogger(ObjectPool.class).warn("Exception destroying pooled object.", e); + } + } + } + + /** + * Release a reference to a shared client instance. If that instance is not used anymore, it will + * be removed and destroyed. + */ + public void release(ObjectT object) { + KeyT key = pool.inverse().get(new RefCounted(object)); + if (key != null) { + releaseByKey(key); + } + } + + public static > + ClientPool pooledClientFactory(BuilderT builder) { + return new ClientPool<>(c -> buildClient(c.getLeft(), builder, c.getRight())); + } + + /** Client pool to easily share AWS clients per configuration. */ + public static class ClientPool + extends ObjectPool, ClientT> { + + private ClientPool(Function, ClientT> builder) { + super(builder, c -> c.close()); + } + + /** Retain a reference to a shared client instance. If not available, an instance is created. 
*/ + public ClientT retain(AwsOptions provider, ClientConfiguration config) { + return retain(Pair.of(provider, config)); + } + } + + private class RefCounted { + private int count = 0; + private final ObjectT shared; + + RefCounted(ObjectT client) { + this.shared = client; + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + // only identity of ref counted shared object matters + return shared == ((RefCounted) o).shared; + } + + @Override + public int hashCode() { + return shared.hashCode(); + } + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java index 6ca342927a89..4a816cb3d91e 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java @@ -68,7 +68,7 @@ public abstract class RetryConfiguration implements Serializable { public abstract RetryConfiguration.Builder toBuilder(); public static Builder builder() { - return Builder.builder(); + return Builder.builder().numRetries(3); } @AutoValue.Builder diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java index 69f4f28427cb..5042a574592c 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java @@ -22,6 +22,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import static org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY; import static org.apache.commons.lang3.StringUtils.isEmpty; +import static software.amazon.awssdk.services.kinesis.model.ShardFilterType.AT_LATEST; import com.google.auto.value.AutoValue; import java.io.Serializable; @@ -33,17 +34,24 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.NavigableSet; +import java.util.TreeSet; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; +import javax.annotation.concurrent.NotThreadSafe; +import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.io.Read.Unbounded; import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory; import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; -import org.apache.beam.sdk.io.aws2.common.ClientPool; +import org.apache.beam.sdk.io.aws2.common.ObjectPool; +import org.apache.beam.sdk.io.aws2.common.ObjectPool.ClientPool; import org.apache.beam.sdk.io.aws2.common.RetryConfiguration; +import org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.ExplicitPartitioner; import org.apache.beam.sdk.io.aws2.options.AwsOptions; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Distribution; @@ -65,7 +73,9 @@ import org.apache.beam.sdk.values.POutput; import 
org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSortedSet; import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.dataflow.qual.Pure; import org.joda.time.DateTimeUtils; @@ -79,7 +89,9 @@ import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; +import software.amazon.awssdk.services.kinesis.model.ListShardsRequest; import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry; +import software.amazon.awssdk.services.kinesis.model.Shard; import software.amazon.kinesis.common.InitialPositionInStream; /** @@ -173,7 +185,8 @@ * utilized at all. * *

If you require finer control over the distribution of records, override {@link - * KinesisPartitioner#getExplicitHashKey(Object)} according to your needs. + * KinesisPartitioner#getExplicitHashKey(Object)} according to your needs. However, this might + * impact record aggregation. * *
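As a sketch of such an override (MyEvent and tenantId() are hypothetical placeholder names; MAX_HASH_KEY is the constant defined on KinesisPartitioner):

// Sketch: pin each tenant to a fixed position in the 128-bit hash key space,
// assuming at most 100 tenants. MyEvent and tenantId() are placeholders.
KinesisPartitioner<MyEvent> partitioner =
    new KinesisPartitioner<MyEvent>() {
      @Override
      public String getPartitionKey(MyEvent record) {
        return String.valueOf(record.tenantId());
      }

      @Override
      public String getExplicitHashKey(MyEvent record) {
        BigInteger slice = MAX_HASH_KEY.divide(BigInteger.valueOf(100));
        return slice.multiply(BigInteger.valueOf(record.tenantId())).toString();
      }
    };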

<h3>Aggregation of records</h3>

* @@ -182,10 +195,29 @@ * href="https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation">aggregated * KPL record. * - *

However, only records with the same effective hash key are aggregated, in which the effective - * hash key is either the explicit hash key if defined, or otherwise the hashed partition key. + *

Records with the same effective hash key are aggregated. The effective hash key is: * - *

Record aggregation can be explicitly disabled using {@link + *

<ol> + *   <li>the explicit hash key, if provided. + *   <li>the lower bound of the hash key range of the target shard according to the given partition + *       key, if available. + *   <li>or otherwise the hashed partition key + * </ol> + *

To provide shard aware aggregation in 2., hash key ranges of shards are loaded and refreshed + * periodically. This allows aggregating records into a number of aggregates that matches the + * number of shards in the stream, making the best possible use of the Kinesis API limits. + * + *
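Mechanically, the shard aware key is a floor lookup of the MD5-hashed partition key against the lower bounds of the shards' hash key ranges; a condensed sketch of what PartitionKeyHasher and ShardRanges further down in this patch do (assumed imports: java.math.BigInteger, java.security.MessageDigest, java.util.NavigableSet):

// Condensed sketch: MD5-hash the partition key, then align it to the greatest
// shard lower bound <= hash via NavigableSet#floor.
static BigInteger effectiveHashKey(String partitionKey, NavigableSet<BigInteger> shardLowerBounds)
    throws NoSuchAlgorithmException {
  byte[] md5 = MessageDigest.getInstance("MD5").digest(partitionKey.getBytes(UTF_8));
  BigInteger hashedKey = new BigInteger(1, md5); // unsigned 128-bit integer
  return shardLowerBounds.floor(hashedKey);      // lower bound of the target shard
}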

Note: There's an important downside to consider when using shard aware aggregation: + * records get assigned to a shard (via an explicit hash key) on the client side, but the + * respective client side state can't be guaranteed to always be up-to-date. If a shard gets + * split, all aggregates are mapped to the lower child shard until state is refreshed. The timing + * of the refresh, however, will diverge between the different workers. + * + *

If using an {@link ExplicitPartitioner} or disabling shard refresh via {@link + * RecordAggregation}, no shard details will be loaded or used. + * + *
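Since ExplicitPartitioner (introduced below) has a single abstract method, it can be supplied as a lambda; a sketch with a hypothetical event type, spreading records over four fixed hash key slots:

// Sketch: an ExplicitPartitioner only supplies the explicit hash key, the
// partition key is defaulted and ignored. MyEvent and slot() are hypothetical.
KinesisPartitioner.ExplicitPartitioner<MyEvent> partitioner =
    event ->
        MAX_HASH_KEY
            .divide(BigInteger.valueOf(4))
            .multiply(BigInteger.valueOf(event.slot() % 4))
            .toString();

For evenly spread random keys, KinesisPartitioner.explicitRandomPartitioner(int) in this patch provides the same ready-made.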

Record aggregation can be entirely disabled using {@link * Write#withRecordAggregationDisabled()}. * *
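Putting these knobs together, a configuration sketch ("my-stream" and the byte[] element type are placeholders; all methods shown exist in this patch):

// Sketch: writer with tuned aggregation settings (joda-time Durations).
KinesisIO.Write<byte[]> write =
    KinesisIO.<byte[]>write()
        .withStreamName("my-stream")
        .withSerializer(bytes -> bytes)
        .withPartitioner(bytes -> String.valueOf(bytes.length))
        .withRecordAggregation(
            b ->
                b.maxBufferedTime(Duration.millis(200))
                    .shardRefreshInterval(Duration.standardMinutes(5)));
// Or opt out entirely via withRecordAggregationDisabled().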

<h3>Configuration of AWS clients</h3>
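As exercised by the tests later in this patch, client behavior such as retries can be tuned via ClientConfiguration; a rough sketch:

// Sketch, mirroring KinesisIOWriteTest: tune retry behavior of the AWS client.
RetryConfiguration retry =
    RetryConfiguration.builder().numRetries(5).maxBackoff(Duration.millis(500)).build();
write = write.withClientConfiguration(ClientConfiguration.builder().retry(retry).build());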

@@ -536,11 +568,30 @@ public abstract static class RecordAggregation implements Serializable { abstract double maxBufferedTimeJitter(); + abstract Duration shardRefreshInterval(); + + abstract double shardRefreshIntervalJitter(); + + Instant nextBufferTimeout() { + return nextInstant(maxBufferedTime(), maxBufferedTimeJitter()); + } + + Instant nextShardRefresh() { + return nextInstant(shardRefreshInterval(), shardRefreshIntervalJitter()); + } + + private Instant nextInstant(Duration duration, double jitter) { + double millis = (1 - jitter + jitter * Math.random()) * duration.getMillis(); + return Instant.ofEpochMilli(DateTimeUtils.currentTimeMillis() + (long) millis); + } + public static Builder builder() { return new AutoValue_KinesisIO_RecordAggregation.Builder() .maxBytes(Write.MAX_BYTES_PER_RECORD) .maxBufferedTimeJitter(0.7) // 70% jitter - .maxBufferedTime(Duration.standardSeconds(1)); + .maxBufferedTime(Duration.millis(500)) + .shardRefreshIntervalJitter(0.5) // 50% jitter + .shardRefreshInterval(Duration.standardMinutes(2)); } @AutoValue.Builder @@ -556,8 +607,19 @@ public abstract static class Builder { */ public abstract Builder maxBufferedTime(Duration interval); + /** + * Refresh interval for shards. + * + *

This is used for shard aware record aggregation to assign all records hashed to a + * particular shard to the same explicit hash key. Set to {@link Duration#ZERO} to disable + * loading shards. + */ + public abstract Builder shardRefreshInterval(Duration interval); + abstract Builder maxBufferedTimeJitter(double jitter); + abstract Builder shardRefreshIntervalJitter(double jitter); + abstract RecordAggregation autoBuild(); public RecordAggregation build() { @@ -670,9 +732,6 @@ public Write withConcurrentRequests(int concurrentRequests) { * Enable record aggregation that is compatible with the KPL / KCL. * *

https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation - * - *

Note: The aggregation is a lot simpler than the one offered by KPL. It only aggregates - * records with the same partition key as it's not aware of explicit hash key ranges per shard. */ public Write withRecordAggregation(RecordAggregation aggregation) { return builder().recordAggregation(aggregation).build(); @@ -682,9 +741,6 @@ public Write withRecordAggregation(RecordAggregation aggregation) { * Enable record aggregation that is compatible with the KPL / KCL. * *

https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation - * - *

Note: The aggregation is a lot simpler than the one offered by KPL. It only aggregates - * records with the same partition key as it's not aware of explicit hash key ranges per shard. */ public Write withRecordAggregation(Consumer aggregation) { RecordAggregation.Builder builder = RecordAggregation.builder(); @@ -803,14 +859,14 @@ private static class Writer implements AutoCloseable { private static final int PARTIAL_RETRIES = 10; // Retries for partial success (throttling) - private static final ClientPool CLIENTS = - ClientPool.pooledClientFactory(KinesisAsyncClient.builder()); + private static final ClientPool CLIENTS = + ObjectPool.pooledClientFactory(KinesisAsyncClient.builder()); protected final Write spec; protected final Stats stats; protected final AsyncPutRecordsHandler handler; + protected final KinesisAsyncClient kinesis; - private final KinesisAsyncClient kinesis; private List requestEntries; private int requestBytes = 0; @@ -947,25 +1003,37 @@ private void validateExplicitHashKey(String hashKey) { * with KCL to correctly implement the binary protocol, specifically {@link * software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord}. * - *

Note: The aggregation is a lot simpler than the one offered by KPL. While the KPL is aware - * of effective hash key ranges assigned to each shard, we're not and don't want to be to keep - * complexity manageable and avoid the risk of silently loosing records in the KCL: + *

To aggregate records the best possible way, records are assigned an explicit hash key that + * corresponds to the lower bound of the hash key range of the target shard. If a record already + * has an explicit hash key assigned, it is kept unchanged. * - *
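Condensed, the precedence described here (implemented in write(...) below) is:

// Aggregation key precedence, mirroring AggregatedWriter#write:
// 1) an existing explicit hash key wins, 2) else align to a shard lower bound,
// 3) else fall back to the hashed partition key itself.
BigInteger aggKey;
if (explicitHashKey != null) {
  aggKey = new BigInteger(explicitHashKey);
} else {
  BigInteger hashed = pkHasher.hashKey(partitionKey);
  BigInteger shardKey = shardRanges.shardAwareHashKey(hashed); // null if shards unknown
  aggKey = shardKey != null ? shardKey : hashed;
}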

{@link software.amazon.kinesis.retrieval.AggregatorUtil#deaggregate(List, BigInteger, - * BigInteger)} drops records not matching the expected hash key range. + *

Hash key ranges of shards are expected to be only slowly changing and get refreshed + * infrequently. If using an {@link ExplicitPartitioner} or disabling shard refresh via {@link + * RecordAggregation}, no shard details will be pulled. */ static class AggregatedWriter extends Writer { private static final Logger LOG = LoggerFactory.getLogger(AggregatedWriter.class); + private static final ObjectPool SHARD_RANGES_BY_STREAM = + new ObjectPool<>(ShardRanges::of); private final RecordAggregation aggSpec; private final Map aggregators; - private final MessageDigest md5Digest; + private final PartitionKeyHasher pkHasher; + + private final ShardRanges shardRanges; AggregatedWriter(PipelineOptions options, Write spec, RecordAggregation aggSpec) { super(options, spec); this.aggSpec = aggSpec; - this.aggregators = new LinkedHashMap<>(); - this.md5Digest = md5Digest(); + aggregators = new LinkedHashMap<>(); + pkHasher = new PartitionKeyHasher(); + if (aggSpec.shardRefreshInterval().isLongerThan(Duration.ZERO) + && !(spec.partitioner() instanceof ExplicitPartitioner)) { + shardRanges = SHARD_RANGES_BY_STREAM.retain(spec.streamName()); + shardRanges.refreshPeriodically(kinesis, aggSpec::nextShardRefresh); + } else { + shardRanges = ShardRanges.EMPTY; + } } @Override @@ -977,20 +1045,36 @@ public void startBundle() { @Override protected void write(String partitionKey, @Nullable String explicitHashKey, byte[] data) throws Throwable { - BigInteger hashKey = effectiveHashKey(partitionKey, explicitHashKey); - RecordsAggregator agg = aggregators.computeIfAbsent(hashKey, k -> newRecordsAggregator()); + shardRanges.refreshPeriodically(kinesis, aggSpec::nextShardRefresh); + + // calculate the effective hash key used for aggregation + BigInteger aggKey; + if (explicitHashKey != null) { + aggKey = new BigInteger(explicitHashKey); + } else { + BigInteger hashedPartitionKey = pkHasher.hashKey(partitionKey); + aggKey = shardRanges.shardAwareHashKey(hashedPartitionKey); + if (aggKey != null) { + // use the shard aware aggregation key as explicit hash key for optimal aggregation + explicitHashKey = aggKey.toString(); + } else { + aggKey = hashedPartitionKey; + } + } + + RecordsAggregator agg = aggregators.computeIfAbsent(aggKey, k -> newRecordsAggregator()); if (!agg.addRecord(partitionKey, explicitHashKey, data)) { // aggregated record too full, add a request entry and reset aggregator - addRequestEntry(agg.getAndReset(aggregationTimeoutWithJitter())); - aggregators.remove(hashKey); + addRequestEntry(agg.getAndReset(aggSpec.nextBufferTimeout())); + aggregators.remove(aggKey); if (agg.addRecord(partitionKey, explicitHashKey, data)) { - aggregators.put(hashKey, agg); // new aggregation started + aggregators.put(aggKey, agg); // new aggregation started } else { super.write(partitionKey, explicitHashKey, data); // skip aggregation } } else if (!agg.hasCapacity()) { addRequestEntry(agg.get()); - aggregators.remove(hashKey); + aggregators.remove(aggKey); } // only check timeouts sporadically if concurrency is already maxed out @@ -1001,14 +1085,7 @@ protected void write(String partitionKey, @Nullable String explicitHashKey, byte private RecordsAggregator newRecordsAggregator() { return new RecordsAggregator( - Math.min(aggSpec.maxBytes(), spec.batchMaxBytes()), aggregationTimeoutWithJitter()); - } - - private Instant aggregationTimeoutWithJitter() { - double millis = - (1 - aggSpec.maxBufferedTimeJitter() + aggSpec.maxBufferedTimeJitter() * Math.random()) - * aggSpec.maxBufferedTime().getMillis(); - return 
Instant.ofEpochMilli(DateTimeUtils.currentTimeMillis() + (long) millis); + Math.min(aggSpec.maxBytes(), spec.batchMaxBytes()), aggSpec.nextBufferTimeout()); } private void checkAggregationTimeouts() throws Throwable { @@ -1021,9 +1098,8 @@ private void checkAggregationTimeouts() throws Throwable { if (agg.timeout().isAfter(now)) { break; } - LOG.debug( - "Adding aggregated entry after timeout [delay = {} ms]", - now.getMillis() - agg.timeout().getMillis()); + long delayMillis = now.getMillis() - agg.timeout().getMillis(); + LOG.debug("Adding aggregated entry after timeout [delay = {} ms]", delayMillis); addRequestEntry(agg.get()); removals.add(e.getKey()); } @@ -1041,16 +1117,23 @@ public void finishBundle() throws Throwable { super.finishBundle(); } - private BigInteger effectiveHashKey(String partitionKey, @Nullable String explicitHashKey) { - return explicitHashKey == null - ? new BigInteger(1, md5(partitionKey.getBytes(UTF_8))) - : new BigInteger(explicitHashKey); + @Override + public void close() throws Exception { + super.close(); + SHARD_RANGES_BY_STREAM.release(shardRanges); } + } - private byte[] md5(byte[] data) { - byte[] hash = md5Digest.digest(data); + @VisibleForTesting + @NotThreadSafe + static class PartitionKeyHasher { + private final MessageDigest md5Digest = md5Digest(); + + /** Hash partition key to 128 bit integer. */ + BigInteger hashKey(String partitionKey) { + byte[] hashedBytes = md5Digest.digest(partitionKey.getBytes(UTF_8)); md5Digest.reset(); - return hash; + return new BigInteger(1, hashedBytes); } private static MessageDigest md5Digest() { @@ -1062,6 +1145,99 @@ private static MessageDigest md5Digest() { } } + /** Shard hash ranges per stream to generate shard aware hash keys for record aggregation. */ + @VisibleForTesting + @ThreadSafe + interface ShardRanges { + ShardRanges EMPTY = new ShardRanges() {}; + + static ShardRanges of(String stream) { + return new ShardRangesImpl(stream); + } + + /** + * Align partition key hash to lower bound of key range of the target shard. If unavailable + * {@code null} is returned. + */ + default @Nullable BigInteger shardAwareHashKey(BigInteger hashedPartitionKey) { + return null; + } + + /** Check for and trigger periodic refresh if needed. 
*/ + default void refreshPeriodically( + KinesisAsyncClient kinesis, Supplier nextRefreshFn) {} + + class ShardRangesImpl implements ShardRanges { + private static final Logger LOG = LoggerFactory.getLogger(ShardRanges.class); + + private final String streamName; + + private final AtomicBoolean running = new AtomicBoolean(false); + private NavigableSet shardBounds = ImmutableSortedSet.of(); + private Instant nextRefresh = Instant.EPOCH; + + private ShardRangesImpl(String streamName) { + this.streamName = streamName; + } + + @Override + public @Nullable BigInteger shardAwareHashKey(BigInteger hashedPartitionKey) { + BigInteger lowerBound = shardBounds.floor(hashedPartitionKey); + if (!shardBounds.isEmpty() && lowerBound == null) { + LOG.warn("No shard found for {} [shards={}]", hashedPartitionKey, shardBounds.size()); + } + return lowerBound; + } + + @Override + public void refreshPeriodically( + KinesisAsyncClient client, Supplier nextRefreshFn) { + if (nextRefresh.isBeforeNow() && running.compareAndSet(false, true)) { + refresh(client, nextRefreshFn, new TreeSet<>(), null); + } + } + + @SuppressWarnings("FutureReturnValueIgnored") // safe to ignore + private void refresh( + KinesisAsyncClient client, + Supplier nextRefreshFn, + TreeSet bounds, + @Nullable String nextToken) { + ListShardsRequest.Builder reqBuilder = + ListShardsRequest.builder().shardFilter(f -> f.type(AT_LATEST)); + if (nextToken != null) { + reqBuilder.nextToken(nextToken); + } else { + reqBuilder.streamName(streamName); + } + client + .listShards(reqBuilder.build()) + .whenComplete( + (resp, exc) -> { + if (exc != null) { + LOG.warn("Failed to refresh shards.", exc); + nextRefresh = nextRefreshFn.get(); // retry later + running.set(false); + return; + } + resp.shards().forEach(shard -> bounds.add(lowerHashKey(shard))); + if (resp.nextToken() != null) { + refresh(client, nextRefreshFn, bounds, resp.nextToken()); + return; + } + LOG.debug("Done refreshing {} shards.", bounds.size()); + nextRefresh = nextRefreshFn.get(); + running.set(false); + shardBounds = bounds; // swap key ranges + }); + } + + private BigInteger lowerHashKey(Shard shard) { + return new BigInteger(shard.hashKeyRange().startingHashKey()); + } + } + } + private static class Stats implements AsyncPutRecordsHandler.Stats { private static final Logger LOG = LoggerFactory.getLogger(Stats.class); private static final Duration LOG_STATS_PERIOD = Duration.standardSeconds(10); diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java index 99d5f9154464..eaf9b0b76070 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java @@ -47,6 +47,26 @@ public interface KinesisPartitioner extends Serializable { return null; } + /** + * An explicit partitioner that always returns a {@code Nonnull} explicit hash key. The partition + * key is irrelevant in this case, though it cannot be {@code null}. 
+ */ + interface ExplicitPartitioner extends KinesisPartitioner { + @Override + default @Nonnull String getPartitionKey(T record) { + return "a"; // will be ignored, but can't be null or empty + } + + /** + * Required hash value (128-bit integer) to determine explicitly the shard a record is assigned + * to based on the hash key range of each shard. The explicit hash key overrides the partition + * key hash. + */ + @Override + @Nonnull + String getExplicitHashKey(T record); + } + /** * Explicit hash key partitioner that randomly returns one of x precalculated hash keys. Hash keys * are derived by equally dividing the 128-bit hash universe, assuming that hash ranges of shards @@ -70,15 +90,9 @@ static KinesisPartitioner explicitRandomPartitioner(int shards) { hashKey = hashKey.add(distance); } - return new KinesisPartitioner() { + return new ExplicitPartitioner() { @Nonnull @Override - public String getPartitionKey(T record) { - return "a"; // ignored, but can't be null - } - - @Nullable - @Override public String getExplicitHashKey(T record) { return hashKeys[new Random().nextInt(shards)]; } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java index 5694ec7f14f3..bcf546e792e8 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java @@ -41,10 +41,6 @@ * Record aggregator compatible with the record (de)aggregation of the Kinesis Producer Library * (KPL) and Kinesis Client Library (KCL). * - *

However, only records with the same effective hash key should be aggregated to keep complexity - * manageable. Otherwise, the aggregator would have to be aware of the most up-to-date explicit hash - * key ranges per shard. - * *

https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation */ @NotThreadSafe diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java similarity index 63% rename from sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java rename to sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java index 228e214f8b3c..154957ee3c09 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java @@ -20,10 +20,10 @@ import static java.util.concurrent.ForkJoinPool.commonPool; import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -34,30 +34,24 @@ import java.util.concurrent.ForkJoinTask; import java.util.function.Function; import java.util.stream.Stream; -import org.junit.Before; +import org.apache.beam.sdk.testing.ExpectedLogs; +import org.junit.Rule; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Spy; -import org.mockito.junit.MockitoJUnitRunner; - -@RunWith(MockitoJUnitRunner.class) -public class ClientPoolTest { - @Spy ClientProvider provider = new ClientProvider(); - ClientPool, String, AutoCloseable> pool; - - @Before - public void init() { - pool = new ClientPool<>((p, c) -> p.apply(c)); - } + +public class ObjectPoolTest { + Function provider = spy(new Provider()); + ObjectPool pool = new ObjectPool<>(provider, obj -> obj.close()); + + @Rule public ExpectedLogs logs = ExpectedLogs.none(ObjectPool.class); class ResourceTask implements Callable { @Override - public AutoCloseable call() throws Exception { - AutoCloseable client = pool.retain(provider, "config"); - pool.retain(provider, "config"); - pool.release(provider, "config"); + public AutoCloseable call() { + AutoCloseable client = pool.retain("config"); + pool.retain("config"); + pool.release(client); verifyNoInteractions(client); - pool.release(provider, "config"); + pool.release(client); return client; } } @@ -83,8 +77,8 @@ public void shareClientsOfSameConfiguration() { String config1 = "config1"; String config2 = "config2"; - assertThat(pool.retain(provider, config1)).isSameAs(pool.retain(provider, config1)); - assertThat(pool.retain(provider, config1)).isNotSameAs(pool.retain(provider, config2)); + assertThat(pool.retain(config1)).isSameAs(pool.retain(config1)); + assertThat(pool.retain(config1)).isNotSameAs(pool.retain(config2)); verify(provider, times(2)).apply(anyString()); verify(provider, times(1)).apply(config1); verify(provider, times(1)).apply(config2); @@ -97,47 +91,70 @@ public void closeClientsOnceReleased() throws Exception { AutoCloseable client = null; for (int i = 0; i < sharedInstances; i++) { - client = pool.retain(provider, config); + client = pool.retain(config); } - for (int i = 1; i < 
sharedInstances; i++) { - pool.release(provider, config); + for (int i = 0; i < sharedInstances - 1; i++) { + pool.release(client); } verifyNoInteractions(client); // verify close on last release - pool.release(provider, config); + pool.release(client); verify(client).close(); // verify further attempts to release have no effect - pool.release(provider, config); + pool.release(client); verifyNoMoreInteractions(client); } @Test - public void recreateClientOnceReleased() throws Exception { + public void closeClientsOnceReleasedByKey() throws Exception { String config = "config"; - AutoCloseable client1 = pool.retain(provider, config); - pool.release(provider, config); - AutoCloseable client2 = pool.retain(provider, config); + int sharedInstances = 10; - verify(provider, times(2)).apply(config); + AutoCloseable client = null; + for (int i = 0; i < sharedInstances; i++) { + client = pool.retain(config); + } + + for (int i = 0; i < sharedInstances - 1; i++) { + pool.releaseByKey(config); + } + verifyNoInteractions(client); + // verify close on last release + pool.releaseByKey(config); + verify(client).close(); + // verify further attempts to release have no effect + pool.releaseByKey(config); + verifyNoMoreInteractions(client); + } + + @Test + public void recreateClientOnceReleased() throws Exception { + String config = "config"; + AutoCloseable client1 = pool.retain(config); + pool.release(client1); verify(client1).close(); + + AutoCloseable client2 = pool.retain(config); verifyNoInteractions(client2); + + verify(provider, times(2)).apply(config); + assertThat(client1).isNotSameAs(client2); } @Test public void releaseWithError() throws Exception { - String config = "config"; - AutoCloseable client1 = pool.retain(provider, config); - doThrow(new Exception("error on close")).when(client1).close(); - assertThatThrownBy(() -> pool.release(provider, config)).hasMessage("error on close"); + Exception onClose = new Exception("error on close"); - AutoCloseable client2 = pool.retain(provider, config); - verify(provider, times(2)).apply(config); - verify(client1).close(); - verifyNoInteractions(client2); + AutoCloseable client = pool.retain("config"); + doThrow(onClose).when(client).close(); + pool.release(client); + + verify(client).close(); + logs.verifyWarn("Exception destroying pooled object.", onClose); } - static class ClientProvider implements Function { + static class Provider implements Function { @Override public AutoCloseable apply(String configName) { return mock(AutoCloseable.class, configName); diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java index 1db39161306c..e5418f4ff142 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java @@ -17,18 +17,26 @@ */ package org.apache.beam.sdk.io.aws2.kinesis; +import static java.math.BigInteger.ONE; +import static java.util.Arrays.stream; import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.supplyAsync; import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toList; import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_BYTES_PER_RECORD; import static 
org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_BYTES_PER_REQUEST; import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_RECORDS_PER_REQUEST; +import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.MAX_HASH_KEY; +import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.MIN_HASH_KEY; +import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.explicitRandomPartitioner; import static org.apache.beam.sdk.io.common.TestRow.getExpectedValues; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.concat; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.transform; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.joda.time.Duration.ZERO; +import static org.joda.time.Duration.millis; import static org.joda.time.Duration.standardSeconds; import static org.mockito.AdditionalMatchers.and; import static org.mockito.ArgumentMatchers.any; @@ -40,14 +48,20 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import java.math.BigInteger; import java.util.List; +import java.util.concurrent.CompletableFuture; import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory; +import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; +import org.apache.beam.sdk.io.aws2.common.RetryConfiguration; import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write; import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.AggregatedWriter; +import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.PartitionKeyHasher; +import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.ShardRanges; import org.apache.beam.sdk.io.common.TestRow; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -55,22 +69,27 @@ import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Objects; import org.assertj.core.api.ThrowableAssert; import org.joda.time.DateTimeUtils; -import org.joda.time.Duration; +import org.joda.time.Instant; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; import org.mockito.InOrder; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; import software.amazon.awssdk.services.kinesis.KinesisAsyncClient; import software.amazon.awssdk.services.kinesis.KinesisAsyncClientBuilder; +import software.amazon.awssdk.services.kinesis.model.HashKeyRange; +import software.amazon.awssdk.services.kinesis.model.ListShardsRequest; +import software.amazon.awssdk.services.kinesis.model.ListShardsResponse; import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest; -import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse; +import software.amazon.awssdk.services.kinesis.model.Shard; /** Tests for {@link KinesisIO#write()}. 
*/ @RunWith(MockitoJUnitRunner.StrictStubs.class) @@ -82,8 +101,12 @@ public class KinesisIOWriteTest extends PutRecordsHelpers { @Mock public KinesisAsyncClient client; @Before - public void configureClientBuilderFactory() { + public void configure() { MockClientBuilderFactory.set(pipeline, KinesisAsyncClientBuilder.class, client); + + CompletableFuture errorResp = new CompletableFuture<>(); + errorResp.completeExceptionally(new RuntimeException("Unavailable, retried later")); + when(client.listShards(any(ListShardsRequest.class))).thenReturn(errorResp); } @After @@ -158,11 +181,11 @@ public void testWriteWithBatchMaxBytes() { @Test public void testWriteFailure() { - when(client.putRecords(any(PutRecordsRequest.class))) + when(client.putRecords(anyRequest())) .thenReturn( - completedFuture(PutRecordsResponse.builder().build()), + completedFuture(successResponse), supplyAsync(() -> checkNotNull(null, "putRecords failed")), - completedFuture(PutRecordsResponse.builder().build())); + completedFuture(successResponse)); pipeline .apply(GenerateSequence.from(0).to(100)) @@ -178,17 +201,20 @@ public void testWriteFailure() { @Test public void testWriteWithPartialSuccess() { - when(client.putRecords(any(PutRecordsRequest.class))) + when(client.putRecords(anyRequest())) .thenReturn(completedFuture(partialSuccessResponse(70, 30))) .thenReturn(completedFuture(partialSuccessResponse(10, 20))) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + .thenReturn(completedFuture(successResponse)); + + // minimize delay due to retries + RetryConfiguration retry = RetryConfiguration.builder().maxBackoff(millis(1)).build(); pipeline .apply(Create.of(100)) .apply(ParDo.of(new GenerateTestRows())) .apply( kinesisWrite() - // .withRetryConfiguration(RetryConfiguration.fixed(10, Duration.millis(1))) + .withClientConfiguration(ClientConfiguration.builder().retry(retry).build()) .withRecordAggregationDisabled()); pipeline.run().waitUntilFinish(); @@ -201,9 +227,8 @@ public void testWriteWithPartialSuccess() { } @Test - public void testWriteAggregated() { - when(client.putRecords(any(PutRecordsRequest.class))) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + public void testWriteAggregatedByDefault() { + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); pipeline .apply(Create.of(100)) @@ -215,10 +240,82 @@ public void testWriteAggregated() { verify(client).close(); } + @Test + public void testWriteAggregatedShardAware() { + mockShardRanges(MIN_HASH_KEY, MAX_HASH_KEY.shiftRight(1)); // 2 shards + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); + + pipeline + .apply(Create.of(100)) + .apply(ParDo.of(new GenerateTestRows())) + .apply(kinesisWrite().withPartitioner(row -> row.id().toString())); + + pipeline.run().waitUntilFinish(); + verify(client).putRecords(argThat(hasSize(2))); // 1 aggregated record per shard + verify(client).listShards(any(ListShardsRequest.class)); + verify(client).close(); + } + + @Test + public void testWriteAggregatedShardRefreshPending() { + CompletableFuture resp = new CompletableFuture<>(); + when(client.listShards(any(ListShardsRequest.class))).thenReturn(resp); + + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); + + pipeline + .apply(Create.of(100)) + .apply(ParDo.of(new GenerateTestRows())) + .apply(kinesisWrite().withPartitioner(row -> row.id().toString())); + + pipeline.run().waitUntilFinish(); + 
resp.complete(ListShardsResponse.builder().build()); // complete list shards after pipeline + + // while shards are unknown, each row is aggregated into an individual aggregated record + verify(client).putRecords(argThat(hasSize(100))); + verify(client).listShards(any(ListShardsRequest.class)); + verify(client).close(); + } + + @Test + public void testWriteAggregatedShardRefreshDisabled() { + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); + + pipeline + .apply(Create.of(100)) + .apply(ParDo.of(new GenerateTestRows())) + .apply( + kinesisWrite() + .withRecordAggregation(b -> b.shardRefreshInterval(ZERO)) // disable refresh + .withPartitioner(row -> row.id().toString())); + + pipeline.run().waitUntilFinish(); + + // each row is aggregated into an individual aggregated record + verify(client).putRecords(argThat(hasSize(100))); + verify(client, times(0)).listShards(any(ListShardsRequest.class)); // disabled + verify(client).close(); + } + + @Test + public void testWriteAggregatedUsingExplicitPartitioner() { + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); + + pipeline + .apply(Create.of(100)) + .apply(ParDo.of(new GenerateTestRows())) + .apply(kinesisWrite().withPartitioner(explicitRandomPartitioner(2))); + + pipeline.run().waitUntilFinish(); + verify(client).putRecords(argThat(hasSize(2))); // configuration of partitioner + verify(client, times(0)) + .listShards(any(ListShardsRequest.class)); // disabled for explicit partitioner + verify(client).close(); + } + @Test public void testWriteAggregatedWithMaxBytes() { - when(client.putRecords(any(PutRecordsRequest.class))) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); // overhead protocol + key overhead + 500 records, each 4 bytes data + overhead final int expectedBytes = 20 + 3 + 500 * 10; @@ -241,8 +338,7 @@ public void testWriteAggregatedWithMaxBytes() { @Test public void testWriteAggregatedWithMaxBytesAndBatchMaxBytes() { - when(client.putRecords(any(PutRecordsRequest.class))) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); // overhead protocol + key overhead + 500 records, each 4 bytes data + overhead final int expectedBytes = 20 + 3 + 500 * 10; @@ -267,8 +363,7 @@ public void testWriteAggregatedWithMaxBytesAndBatchMaxBytes() { @Test public void testWriteAggregatedWithMaxBytesAndBatchMaxRecords() { - when(client.putRecords(any(PutRecordsRequest.class))) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); // overhead protocol + key overhead + 500 records, each 4 bytes data + overhead final int expectedBytes = 20 + 3 + 500 * 10; @@ -293,21 +388,19 @@ public void testWriteAggregatedWithMaxBytesAndBatchMaxRecords() { @Test public void testWriteAggregatedWithMaxBufferTime() throws Throwable { - when(client.putRecords(any(PutRecordsRequest.class))) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); Write write = kinesisWrite() .withPartitioner(r -> r.id().toString()) - .withRecordAggregation( - b -> b.maxBufferedTime(Duration.millis(100)).maxBufferedTimeJitter(0.2)); + .withRecordAggregation(b -> b.maxBufferedTime(millis(100)).maxBufferedTimeJitter(0.2)); + 
DateTimeUtils.setCurrentMillisFixed(0); AggregatedWriter writer = new AggregatedWriter<>(pipeline.getOptions(), write, write.recordAggregation()); writer.startBundle(); - DateTimeUtils.setCurrentMillisFixed(0); for (int i = 1; i <= 3; i++) { writer.write(TestRow.fromSeed(i)); } @@ -328,15 +421,82 @@ public void testWriteAggregatedWithMaxBufferTime() throws Throwable { writer.close(); InOrder ordered = inOrder(client); - ordered - .verify(client) - .putRecords(and(argThat(hasSize(3)), argThat(hasPartitions("1", "2", "3")))); - ordered.verify(client).putRecords(and(argThat(hasSize(2)), argThat(hasPartitions("4", "5")))); - ordered.verify(client).putRecords(and(argThat(hasSize(1)), argThat(hasPartitions("6")))); + ordered.verify(client).putRecords(argThat(hasPartitions("1", "2", "3"))); + ordered.verify(client).putRecords(argThat(hasPartitions("4", "5"))); + ordered.verify(client).putRecords(argThat(hasPartitions("6"))); + ordered.verify(client).close(); + verifyNoMoreInteractions(client); + } + + @Test + public void testWriteAggregatedWithShardsRefresh() throws Throwable { + when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse)); + + Write write = + kinesisWrite() + .withPartitioner(r -> r.id().toString()) + .withRecordAggregation(b -> b.shardRefreshInterval(millis(1000))); + + DateTimeUtils.setCurrentMillisFixed(1); + AggregatedWriter writer = + new AggregatedWriter<>(pipeline.getOptions(), write, write.recordAggregation()); + + // initially, no shards known + for (int i = 1; i <= 3; i++) { + writer.write(TestRow.fromSeed(i)); + } + + // forward clock, trigger timeouts and refresh shards + DateTimeUtils.setCurrentMillisFixed(1500); + mockShardRanges(MIN_HASH_KEY); + + for (int i = 1; i <= 10; i++) { + writer.write(TestRow.fromSeed(i)); // all aggregated into one record + } + + writer.finishBundle(); + writer.close(); + + InOrder ordered = inOrder(client); + ordered.verify(client).putRecords(argThat(hasPartitions("1", "2", "3"))); + ordered.verify(client).putRecords(argThat(hasExplicitPartitions(MIN_HASH_KEY.toString()))); ordered.verify(client).close(); + verify(client, times(2)).listShards(any(ListShardsRequest.class)); verifyNoMoreInteractions(client); } + @Test + public void testShardRangesRefresh() { + BigInteger shard1 = MIN_HASH_KEY; + BigInteger shard2 = MAX_HASH_KEY.shiftRight(2); + BigInteger shard3 = MAX_HASH_KEY.shiftRight(1); + + when(client.listShards(argThat(isRequest(STREAM, null)))) + .thenReturn(completedFuture(listShardsResponse("a", shard(shard1)))); + when(client.listShards(argThat(isRequest(null, "a")))) + .thenReturn(completedFuture(listShardsResponse("b", shard(shard2)))); + when(client.listShards(argThat(isRequest(null, "b")))) + .thenReturn(completedFuture(listShardsResponse(null, shard(shard3)))); + + PartitionKeyHasher pkHasher = new PartitionKeyHasher(); + ShardRanges shardRanges = ShardRanges.of(STREAM); + shardRanges.refreshPeriodically(client, Instant::now); + + verify(client, times(3)).listShards(any(ListShardsRequest.class)); + + BigInteger hashKeyA = pkHasher.hashKey("a"); + assertThat(shardRanges.shardAwareHashKey(hashKeyA)).isEqualTo(shard1); + assertThat(hashKeyA).isBetween(shard1, shard2.subtract(ONE)); + + BigInteger hashKeyB = pkHasher.hashKey("b"); + assertThat(shardRanges.shardAwareHashKey(hashKeyB)).isEqualTo(shard3); + assertThat(hashKeyB).isBetween(shard3, MAX_HASH_KEY); + + BigInteger hashKeyC = pkHasher.hashKey("c"); + assertThat(shardRanges.shardAwareHashKey(hashKeyC)).isEqualTo(shard2); + 
assertThat(hashKeyC).isBetween(shard2, shard3.subtract(ONE)); + } + @Test public void validateMissingStreamName() { assertThrown(identity()) @@ -407,6 +567,30 @@ public void validateRecordAggregationMaxBytesAboveLimit() { .hasMessage("maxBytes must be positive and <= " + MAX_BYTES_PER_RECORD); } + private Shard shard(BigInteger lowerRange) { + return Shard.builder() + .hashKeyRange(HashKeyRange.builder().startingHashKey(lowerRange.toString()).build()) + .build(); + } + + private ListShardsResponse listShardsResponse(String nextToken, Shard... shards) { + return ListShardsResponse.builder().shards(shards).nextToken(nextToken).build(); + } + + protected ArgumentMatcher isRequest(String stream, String nextToken) { + return req -> + req != null + && Objects.equal(stream, req.streamName()) + && Objects.equal(nextToken, req.nextToken()); + } + + private void mockShardRanges(BigInteger... lowerBounds) { + List shards = stream(lowerBounds).map(lower -> shard(lower)).collect(toList()); + + when(client.listShards(any(ListShardsRequest.class))) + .thenReturn(completedFuture(ListShardsResponse.builder().shards(shards).build())); + } + private ThrowableAssert assertThrown(Function, Write> writeConfig) { pipeline.enableAbandonedNodeEnforcement(false); PCollection input = mock(PCollection.class); @@ -422,8 +606,7 @@ private ThrowableAssert assertThrown(Function, Write> wr private Supplier>> captureBatchRecords(KinesisAsyncClient mock) { ArgumentCaptor cap = ArgumentCaptor.forClass(PutRecordsRequest.class); - when(mock.putRecords(cap.capture())) - .thenReturn(completedFuture(PutRecordsResponse.builder().build())); + when(mock.putRecords(cap.capture())).thenReturn(completedFuture(successResponse)); return () -> transform(cap.getAllValues(), req -> transform(req.records(), this::toTestRow)); } diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java index 0a4976caa164..ae178118d51b 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java @@ -40,6 +40,8 @@ public abstract class PutRecordsHelpers { protected static final String ERROR_CODE = "ProvisionedThroughputExceededException"; + PutRecordsResponse successResponse = PutRecordsResponse.builder().build(); + protected PutRecordsRequest anyRequest() { return any(); } @@ -62,6 +64,12 @@ protected ArgumentMatcher hasPartitions(String... partitions) && transform(req.records(), r -> r.partitionKey()).containsAll(asList(partitions)); } + protected ArgumentMatcher hasExplicitPartitions(String... 
partitions) { + return req -> + hasSize(partitions.length).matches(req) + && transform(req.records(), r -> r.explicitHashKey()).containsAll(asList(partitions)); + } + protected PutRecordsResponse partialSuccessResponse(int successes, int errors) { PutRecordsResultEntry e = PutRecordsResultEntry.builder().errorCode(ERROR_CODE).build(); PutRecordsResultEntry s = PutRecordsResultEntry.builder().build(); diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java index e049c4376632..c100ddca7f78 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java @@ -18,7 +18,6 @@ package org.apache.beam.sdk.io.aws2.kinesis.testing; import static java.nio.charset.StandardCharsets.UTF_8; -import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.explicitRandomPartitioner; import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KINESIS; import java.io.Serializable; @@ -74,7 +73,7 @@ public interface ITOptions extends ITEnvironment.ITOptions { void setKinesisStream(String value); @Description("Number of shards of stream") - @Default.Integer(2) + @Default.Integer(8) Integer getKinesisShards(); void setKinesisShards(Integer count); @@ -120,7 +119,7 @@ private void runWrite() { KinesisIO.Write write = KinesisIO.write() .withStreamName(env.options().getKinesisStream()) - .withPartitioner(explicitRandomPartitioner(env.options().getKinesisShards())) + .withPartitioner(row -> row.name()) .withSerializer(testRowToBytes); if (!options.getUseRecordAggregation()) { write = write.withRecordAggregationDisabled();