Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.services.kinesis.AmazonKinesis;
import com.amazonaws.services.kinesis.AmazonKinesisClientBuilder;
import com.amazonaws.services.kinesis.model.DescribeStreamRequest;
import com.amazonaws.services.kinesis.model.DescribeStreamResult;
import com.amazonaws.services.kinesis.model.ExpiredIteratorException;
import com.amazonaws.services.kinesis.model.GetRecordsRequest;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
Expand All @@ -35,11 +37,13 @@
import com.amazonaws.services.kinesis.model.ResourceNotFoundException;
import com.amazonaws.services.kinesis.model.Shard;
import com.amazonaws.services.kinesis.model.ShardIteratorType;
import com.amazonaws.services.kinesis.model.StreamDescription;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.util.AwsHostNameUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Queues;
import org.apache.druid.common.aws.AWSCredentialsConfig;
import org.apache.druid.common.aws.AWSCredentialsUtils;
Expand All @@ -62,6 +66,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -581,11 +586,31 @@ public String getPosition(StreamPartition<String> partition)
@Override
public Set<String> getPartitionIds(String stream)
{
return wrapExceptions(() -> kinesis.describeStream(stream)
.getStreamDescription()
.getShards()
.stream()
.map(Shard::getShardId).collect(Collectors.toSet()));
return wrapExceptions(
() -> {
final Set<String> retVal = new HashSet<>();
DescribeStreamRequest request = new DescribeStreamRequest();
request.setStreamName(stream);

while (request != null) {
final DescribeStreamResult result = kinesis.describeStream(request);
final StreamDescription streamDescription = result.getStreamDescription();
final List<Shard> shards = streamDescription.getShards();

for (Shard shard : shards) {
retVal.add(shard.getShardId());
}

if (streamDescription.isHasMoreShards()) {
request.setExclusiveStartShardId(Iterables.getLast(shards).getShardId());
} else {
request = null;
}
}

return retVal;
}
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.amazonaws.services.kinesis.AmazonKinesis;
import com.amazonaws.services.kinesis.AmazonKinesisClient;
import com.amazonaws.services.kinesis.model.DescribeStreamRequest;
import com.amazonaws.services.kinesis.model.DescribeStreamResult;
import com.amazonaws.services.kinesis.model.GetRecordsRequest;
import com.amazonaws.services.kinesis.model.GetRecordsResult;
Expand Down Expand Up @@ -65,12 +66,14 @@ public class KinesisRecordSupplierTest extends EasyMockSupport
private static String shard1Iterator = "1";
private static String shard0Iterator = "0";
private static AmazonKinesis kinesis;
private static DescribeStreamResult describeStreamResult;
private static DescribeStreamResult describeStreamResult0;
private static DescribeStreamResult describeStreamResult1;
private static GetShardIteratorResult getShardIteratorResult0;
private static GetShardIteratorResult getShardIteratorResult1;
private static GetRecordsResult getRecordsResult0;
private static GetRecordsResult getRecordsResult1;
private static StreamDescription streamDescription;
private static StreamDescription streamDescription0;
private static StreamDescription streamDescription1;
private static Shard shard0;
private static Shard shard1;
private static KinesisRecordSupplier recordSupplier;
Expand Down Expand Up @@ -142,12 +145,14 @@ private static ByteBuffer jb(String timestamp, String dim1, String dim2, String
public void setupTest()
{
kinesis = createMock(AmazonKinesisClient.class);
describeStreamResult = createMock(DescribeStreamResult.class);
describeStreamResult0 = createMock(DescribeStreamResult.class);
describeStreamResult1 = createMock(DescribeStreamResult.class);
getShardIteratorResult0 = createMock(GetShardIteratorResult.class);
getShardIteratorResult1 = createMock(GetShardIteratorResult.class);
getRecordsResult0 = createMock(GetRecordsResult.class);
getRecordsResult1 = createMock(GetRecordsResult.class);
streamDescription = createMock(StreamDescription.class);
streamDescription0 = createMock(StreamDescription.class);
streamDescription1 = createMock(StreamDescription.class);
shard0 = createMock(Shard.class);
shard1 = createMock(Shard.class);
recordsPerFetch = 1;
Expand All @@ -163,11 +168,17 @@ public void tearDownTest()
@Test
public void testSupplierSetup()
{
Capture<String> captured = Capture.newInstance();
expect(kinesis.describeStream(capture(captured))).andReturn(describeStreamResult).once();
expect(describeStreamResult.getStreamDescription()).andReturn(streamDescription).once();
expect(streamDescription.getShards()).andReturn(ImmutableList.of(shard0, shard1)).once();
expect(shard0.getShardId()).andReturn(shardId0).once();
final Capture<DescribeStreamRequest> capturedRequest = Capture.newInstance();

expect(kinesis.describeStream(capture(capturedRequest))).andReturn(describeStreamResult0).once();
expect(describeStreamResult0.getStreamDescription()).andReturn(streamDescription0).once();
expect(streamDescription0.getShards()).andReturn(ImmutableList.of(shard0)).once();
expect(streamDescription0.isHasMoreShards()).andReturn(true).once();
expect(shard0.getShardId()).andReturn(shardId0).times(2);
expect(kinesis.describeStream(anyObject(DescribeStreamRequest.class))).andReturn(describeStreamResult1).once();
expect(describeStreamResult1.getStreamDescription()).andReturn(streamDescription1).once();
expect(streamDescription1.getShards()).andReturn(ImmutableList.of(shard1)).once();
expect(streamDescription1.isHasMoreShards()).andReturn(false).once();
expect(shard1.getShardId()).andReturn(shardId1).once();

replayAll();
Expand Down Expand Up @@ -199,7 +210,11 @@ public void testSupplierSetup()
Assert.assertEquals(Collections.emptyList(), recordSupplier.poll(100));

verifyAll();
Assert.assertEquals(stream, captured.getValue());

final DescribeStreamRequest expectedRequest = new DescribeStreamRequest();
expectedRequest.setStreamName(stream);
expectedRequest.setExclusiveStartShardId("0");
Assert.assertEquals(expectedRequest, capturedRequest.getValue());
}

private static GetRecordsRequest generateGetRecordsReq(String shardIterator, int limit)
Expand Down