diff --git a/extensions-core/kinesis-indexing-service/src/main/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplier.java b/extensions-core/kinesis-indexing-service/src/main/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplier.java index a05eade5f4cd..56b7d723201d 100644 --- a/extensions-core/kinesis-indexing-service/src/main/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplier.java +++ b/extensions-core/kinesis-indexing-service/src/main/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplier.java @@ -26,6 +26,8 @@ import com.amazonaws.client.builder.AwsClientBuilder; import com.amazonaws.services.kinesis.AmazonKinesis; import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder; +import com.amazonaws.services.kinesis.model.DescribeStreamRequest; +import com.amazonaws.services.kinesis.model.DescribeStreamResult; import com.amazonaws.services.kinesis.model.ExpiredIteratorException; import com.amazonaws.services.kinesis.model.GetRecordsRequest; import com.amazonaws.services.kinesis.model.GetRecordsResult; @@ -35,11 +37,13 @@ import com.amazonaws.services.kinesis.model.ResourceNotFoundException; import com.amazonaws.services.kinesis.model.Shard; import com.amazonaws.services.kinesis.model.ShardIteratorType; +import com.amazonaws.services.kinesis.model.StreamDescription; import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; import com.amazonaws.util.AwsHostNameUtils; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.Queues; import org.apache.druid.common.aws.AWSCredentialsConfig; import org.apache.druid.common.aws.AWSCredentialsUtils; @@ -61,12 +65,14 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.Executors; @@ -579,12 +585,31 @@ public String getPosition(StreamPartition partition) @Override public Set getPartitionIds(String stream) { - checkIfClosed(); - return kinesis.describeStream(stream) - .getStreamDescription() - .getShards() - .stream() - .map(Shard::getShardId).collect(Collectors.toSet()); + return wrapExceptions( + () -> { + final Set retVal = new HashSet<>(); + DescribeStreamRequest request = new DescribeStreamRequest(); + request.setStreamName(stream); + + while (request != null) { + final DescribeStreamResult result = kinesis.describeStream(request); + final StreamDescription streamDescription = result.getStreamDescription(); + final List shards = streamDescription.getShards(); + + for (Shard shard : shards) { + retVal.add(shard.getShardId()); + } + + if (streamDescription.isHasMoreShards()) { + request.setExclusiveStartShardId(Iterables.getLast(shards).getShardId()); + } else { + request = null; + } + } + + return retVal; + } + ); } @Override @@ -624,12 +649,12 @@ private void seekInternal(StreamPartition partition, String sequenceNumb sequenceNumber != null ? sequenceNumber : iteratorEnum.toString() ); - resource.shardIterator = kinesis.getShardIterator( + resource.shardIterator = wrapExceptions(() -> kinesis.getShardIterator( partition.getStream(), partition.getPartitionId(), iteratorEnum.toString(), sequenceNumber - ).getShardIterator(); + ).getShardIterator()); checkPartitionsStarted = true; } @@ -655,10 +680,10 @@ private void filterBufferAndResetFetchRunnable(Set> part // filter records in buffer and only retain ones whose partition was not seeked BlockingQueue> newQ = new LinkedBlockingQueue<>(recordBufferSize); - records - .stream() - .filter(x -> !partitions.contains(x.getStreamPartition())) - .forEachOrdered(newQ::offer); + + records.stream() + .filter(x -> !partitions.contains(x.getStreamPartition())) + .forEachOrdered(newQ::offer); records = newQ; @@ -670,20 +695,11 @@ private void filterBufferAndResetFetchRunnable(Set> part @Nullable private String getSequenceNumberInternal(StreamPartition partition, ShardIteratorType iteratorEnum) { - - String shardIterator = null; - try { - shardIterator = kinesis.getShardIterator( - partition.getStream(), - partition.getPartitionId(), - iteratorEnum.toString() - ).getShardIterator(); - } - catch (ResourceNotFoundException e) { - log.warn(e, "Caught ResourceNotFoundException while getting shardIterator"); - } - - return getSequenceNumberInternal(partition, shardIterator); + return wrapExceptions(() -> getSequenceNumberInternal( + partition, + kinesis.getShardIterator(partition.getStream(), partition.getPartitionId(), iteratorEnum.toString()) + .getShardIterator() + )); } @Nullable @@ -774,6 +790,16 @@ private static byte[] toByteArray(final ByteBuffer buffer) } } + private static T wrapExceptions(Callable callable) + { + try { + return callable.call(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + @VisibleForTesting public int bufferSize() { diff --git a/extensions-core/kinesis-indexing-service/src/test/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplierTest.java b/extensions-core/kinesis-indexing-service/src/test/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplierTest.java index bd1fbe8b824a..6dd7de9c6c51 100644 --- a/extensions-core/kinesis-indexing-service/src/test/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplierTest.java +++ b/extensions-core/kinesis-indexing-service/src/test/java/org/apache/druid/indexing/kinesis/KinesisRecordSupplierTest.java @@ -21,6 +21,7 @@ import com.amazonaws.services.kinesis.AmazonKinesis; import com.amazonaws.services.kinesis.AmazonKinesisClient; +import com.amazonaws.services.kinesis.model.DescribeStreamRequest; import com.amazonaws.services.kinesis.model.DescribeStreamResult; import com.amazonaws.services.kinesis.model.GetRecordsRequest; import com.amazonaws.services.kinesis.model.GetRecordsResult; @@ -65,12 +66,14 @@ public class KinesisRecordSupplierTest extends EasyMockSupport private static String shard1Iterator = "1"; private static String shard0Iterator = "0"; private static AmazonKinesis kinesis; - private static DescribeStreamResult describeStreamResult; + private static DescribeStreamResult describeStreamResult0; + private static DescribeStreamResult describeStreamResult1; private static GetShardIteratorResult getShardIteratorResult0; private static GetShardIteratorResult getShardIteratorResult1; private static GetRecordsResult getRecordsResult0; private static GetRecordsResult getRecordsResult1; - private static StreamDescription streamDescription; + private static StreamDescription streamDescription0; + private static StreamDescription streamDescription1; private static Shard shard0; private static Shard shard1; private static KinesisRecordSupplier recordSupplier; @@ -142,12 +145,14 @@ private static ByteBuffer jb(String timestamp, String dim1, String dim2, String public void setupTest() { kinesis = createMock(AmazonKinesisClient.class); - describeStreamResult = createMock(DescribeStreamResult.class); + describeStreamResult0 = createMock(DescribeStreamResult.class); + describeStreamResult1 = createMock(DescribeStreamResult.class); getShardIteratorResult0 = createMock(GetShardIteratorResult.class); getShardIteratorResult1 = createMock(GetShardIteratorResult.class); getRecordsResult0 = createMock(GetRecordsResult.class); getRecordsResult1 = createMock(GetRecordsResult.class); - streamDescription = createMock(StreamDescription.class); + streamDescription0 = createMock(StreamDescription.class); + streamDescription1 = createMock(StreamDescription.class); shard0 = createMock(Shard.class); shard1 = createMock(Shard.class); recordsPerFetch = 1; @@ -163,11 +168,17 @@ public void tearDownTest() @Test public void testSupplierSetup() { - Capture captured = Capture.newInstance(); - expect(kinesis.describeStream(capture(captured))).andReturn(describeStreamResult).once(); - expect(describeStreamResult.getStreamDescription()).andReturn(streamDescription).once(); - expect(streamDescription.getShards()).andReturn(ImmutableList.of(shard0, shard1)).once(); - expect(shard0.getShardId()).andReturn(shardId0).once(); + final Capture capturedRequest = Capture.newInstance(); + + expect(kinesis.describeStream(capture(capturedRequest))).andReturn(describeStreamResult0).once(); + expect(describeStreamResult0.getStreamDescription()).andReturn(streamDescription0).once(); + expect(streamDescription0.getShards()).andReturn(ImmutableList.of(shard0)).once(); + expect(streamDescription0.isHasMoreShards()).andReturn(true).once(); + expect(shard0.getShardId()).andReturn(shardId0).times(2); + expect(kinesis.describeStream(anyObject(DescribeStreamRequest.class))).andReturn(describeStreamResult1).once(); + expect(describeStreamResult1.getStreamDescription()).andReturn(streamDescription1).once(); + expect(streamDescription1.getShards()).andReturn(ImmutableList.of(shard1)).once(); + expect(streamDescription1.isHasMoreShards()).andReturn(false).once(); expect(shard1.getShardId()).andReturn(shardId1).once(); replayAll(); @@ -199,7 +210,11 @@ public void testSupplierSetup() Assert.assertEquals(Collections.emptyList(), recordSupplier.poll(100)); verifyAll(); - Assert.assertEquals(stream, captured.getValue()); + + final DescribeStreamRequest expectedRequest = new DescribeStreamRequest(); + expectedRequest.setStreamName(stream); + expectedRequest.setExclusiveStartShardId("0"); + Assert.assertEquals(expectedRequest, capturedRequest.getValue()); } private static GetRecordsRequest generateGetRecordsReq(String shardIterator, int limit)