@@ -2725,9 +2725,6 @@ public WriteResult expand(PCollection<T> input) {
if (input.isBounded() == IsBounded.BOUNDED) {
checkArgument(!getAutoSharding(), "Auto-sharding is only applicable to unbounded input.");
}
if (method == Write.Method.STORAGE_WRITE_API) {
checkArgument(!getAutoSharding(), "Auto sharding not yet available for Storage API writes");
}

if (getJsonTimePartitioning() != null) {
checkArgument(
@@ -3023,7 +3020,8 @@ private <DestinationT> WriteResult continueExpandTyped(
getStorageApiTriggeringFrequency(bqOptions),
getBigQueryServices(),
getStorageApiNumStreams(bqOptions),
method == Method.STORAGE_API_AT_LEAST_ONCE);
method == Method.STORAGE_API_AT_LEAST_ONCE,
getAutoSharding());
Contributor: Does autosharding default to true or false?
Contributor Author: The Write method sets autosharding to false by default; the user has to explicitly enable it.
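
For reference, a minimal sketch of how a pipeline author opts in, mirroring the test added later in this change (the destination table and the 5-second triggering frequency are illustrative, not required values):

    // Auto-sharding is off by default for Storage API writes; it must be enabled explicitly.
    BigQueryIO.Write<TableRow> write =
        BigQueryIO.writeTableRows()
            .to("project-id:dataset-id.table-id")                    // illustrative destination
            .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
            .withTriggeringFrequency(Duration.standardSeconds(5))    // illustrative frequency
            .withAutoSharding();                                     // opt in to runner-determined sharding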

return input.apply("StorageApiLoads", storageApiLoads);
} else {
throw new RuntimeException("Unexpected write method " + method);
@@ -52,6 +52,7 @@ public class StorageApiLoads<DestinationT, ElementT>
private final BigQueryServices bqServices;
private final int numShards;
private final boolean allowInconsistentWrites;
private final boolean allowAutosharding;

public StorageApiLoads(
Coder<DestinationT> destinationCoder,
@@ -61,7 +62,8 @@ public StorageApiLoads(
Duration triggeringFrequency,
BigQueryServices bqServices,
int numShards,
boolean allowInconsistentWrites) {
boolean allowInconsistentWrites,
boolean allowAutosharding) {
this.destinationCoder = destinationCoder;
this.dynamicDestinations = dynamicDestinations;
this.createDisposition = createDisposition;
@@ -70,6 +72,7 @@ public StorageApiLoads(
this.bqServices = bqServices;
this.numShards = numShards;
this.allowInconsistentWrites = allowInconsistentWrites;
this.allowAutosharding = allowAutosharding;
}

@Override
@@ -136,9 +139,6 @@ public WriteResult expandTriggered(
// Handle triggered, low-latency loads into BigQuery.
PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows()));

// First shard all the records.
// TODO(reuvenlax): Add autosharding support so that users don't have to pick a shard count.
PCollectionTuple result =
inputInGlobalWindow
.apply(
@@ -154,43 +154,33 @@ public WriteResult expandTriggered(
successfulRowsTag,
BigQueryStorageApiInsertErrorCoder.of(),
successCoder));
PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
result
.get(successfulRowsTag)
.apply(
"AddShard",
ParDo.of(
new DoFn<
KV<DestinationT, StorageApiWritePayload>,
KV<ShardedKey<DestinationT>, StorageApiWritePayload>>() {
int shardNumber;

@Setup
public void setup() {
shardNumber = ThreadLocalRandom.current().nextInt(numShards);
}

@ProcessElement
public void processElement(
@Element KV<DestinationT, StorageApiWritePayload> element,
OutputReceiver<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> o) {
DestinationT destination = element.getKey();
ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
buffer.putInt(++shardNumber % numShards);
o.output(
KV.of(ShardedKey.of(destination, buffer.array()), element.getValue()));
}
}))
.setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));

PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> groupedRecords =
shardedRecords.apply(
"GroupIntoBatches",
GroupIntoBatches.<ShardedKey<DestinationT>, StorageApiWritePayload>ofByteSize(
MAX_BATCH_SIZE_BYTES,
(StorageApiWritePayload e) -> (long) e.getPayload().length)
.withMaxBufferingDuration(triggeringFrequency));

PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> groupedRecords;

if (this.allowAutosharding) {
groupedRecords =
result
.get(successfulRowsTag)
.apply(
"GroupIntoBatches",
GroupIntoBatches.<DestinationT, StorageApiWritePayload>ofByteSize(
MAX_BATCH_SIZE_BYTES,
(StorageApiWritePayload e) -> (long) e.getPayload().length)
.withMaxBufferingDuration(triggeringFrequency)
.withShardedKey());

} else {
PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
Contributor: Can we push this to a helper method to improve readability?
Contributor Author: Refactored, PTAL.
createShardedKeyValuePairs(result)
.setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));
groupedRecords =
shardedRecords.apply(
"GroupIntoBatches",
GroupIntoBatches.<ShardedKey<DestinationT>, StorageApiWritePayload>ofByteSize(
MAX_BATCH_SIZE_BYTES,
(StorageApiWritePayload e) -> (long) e.getPayload().length)
.withMaxBufferingDuration(triggeringFrequency));
}
groupedRecords.apply(
"StorageApiWriteSharded",
new StorageApiWritesShardedRecords<>(
@@ -207,6 +197,35 @@ public void processElement(
result.get(failedRowsTag));
}

private PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>>
createShardedKeyValuePairs(PCollectionTuple pCollection) {
return pCollection
.get(successfulRowsTag)
.apply(
"AddShard",
ParDo.of(
new DoFn<
KV<DestinationT, StorageApiWritePayload>,
KV<ShardedKey<DestinationT>, StorageApiWritePayload>>() {
int shardNumber;

@Setup
public void setup() {
shardNumber = ThreadLocalRandom.current().nextInt(numShards);
}

@ProcessElement
public void processElement(
@Element KV<DestinationT, StorageApiWritePayload> element,
OutputReceiver<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> o) {
DestinationT destination = element.getKey();
ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
buffer.putInt(++shardNumber % numShards);
o.output(KV.of(ShardedKey.of(destination, buffer.array()), element.getValue()));
}
}));
}

public WriteResult expandUntriggered(
PCollection<KV<DestinationT, ElementT>> input,
Coder<KV<DestinationT, StorageApiWritePayload>> successCoder) {
@@ -1195,9 +1195,6 @@ public void testStreamingWrite() throws Exception {

@Test
public void testStreamingWriteWithAutoSharding() throws Exception {
if (useStorageApi) {
return;
}
streamingWrite(true);
}

@@ -1240,6 +1237,54 @@ private void streamingWrite(boolean autoSharding) throws Exception {
new TableRow().set("name", "d").set("number", "4")));
}

@Test
public void testStorageApiWriteWithAutoSharding() throws Exception {
storageWrite(true);
}

private void storageWrite(boolean autoSharding) throws Exception {
if (!useStorageApi) {
return;
}
BigQueryIO.Write<TableRow> write =
BigQueryIO.writeTableRows()
.to("project-id:dataset-id.table-id")
.withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
.withSchema(
new TableSchema()
.setFields(
ImmutableList.of(
new TableFieldSchema().setName("name").setType("STRING"),
new TableFieldSchema().setName("number").setType("INTEGER"))))
.withTestServices(fakeBqServices)
.withoutValidation();
if (autoSharding) {
write =
write
.withAutoSharding()
.withTriggeringFrequency(Duration.standardSeconds(5))
.withMethod(Method.STORAGE_WRITE_API);
}
p.apply(
Create.of(
new TableRow().set("name", "a").set("number", "1"),
new TableRow().set("name", "b").set("number", "2"),
new TableRow().set("name", "c").set("number", "3"),
new TableRow().set("name", "d").set("number", "4"))
.withCoder(TableRowJsonCoder.of()))
.setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED)
.apply("WriteToBQ", write);
p.run();

assertThat(
fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"),
containsInAnyOrder(
new TableRow().set("name", "a").set("number", "1"),
new TableRow().set("name", "b").set("number", "2"),
new TableRow().set("name", "c").set("number", "3"),
new TableRow().set("name", "d").set("number", "4")));
}

@DefaultSchema(JavaFieldSchema.class)
static class SchemaPojo {
final String name;
@@ -1927,7 +1972,8 @@ public void testWriteValidateFailsBothFormatFunctions() {

thrown.expect(IllegalArgumentException.class);
thrown.expectMessage(
"Only one of withFormatFunction or withAvroFormatFunction/withAvroWriter maybe set, not both.");
"Only one of withFormatFunction or withAvroFormatFunction/withAvroWriter maybe set, not"
+ " both.");
p.apply(Create.empty(INPUT_RECORD_CODER))
.apply(
BigQueryIO.<InputRecord>write()