diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index e1ad4b447864..59db2f95cf90 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -2725,9 +2725,6 @@ public WriteResult expand(PCollection<T> input) {
       if (input.isBounded() == IsBounded.BOUNDED) {
         checkArgument(!getAutoSharding(), "Auto-sharding is only applicable to unbounded input.");
       }
-      if (method == Write.Method.STORAGE_WRITE_API) {
-        checkArgument(!getAutoSharding(), "Auto sharding not yet available for Storage API writes");
-      }
 
       if (getJsonTimePartitioning() != null) {
         checkArgument(
@@ -3023,7 +3020,8 @@ private <DestinationT, ElementT> WriteResult continueExpandTyped(
                 getStorageApiTriggeringFrequency(bqOptions),
                 getBigQueryServices(),
                 getStorageApiNumStreams(bqOptions),
-                method == Method.STORAGE_API_AT_LEAST_ONCE);
+                method == Method.STORAGE_API_AT_LEAST_ONCE,
+                getAutoSharding());
         return input.apply("StorageApiLoads", storageApiLoads);
       } else {
         throw new RuntimeException("Unexpected write method " + method);
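The two hunks above delete the validation guard that rejected `withAutoSharding()` for `STORAGE_WRITE_API` and thread the flag into `StorageApiLoads`. For orientation, here is a minimal user-side sketch of what this enables. It is not part of the patch: the table spec and schema are placeholders, and `rows` is assumed to be an unbounded `PCollection<TableRow>`.

```java
import com.google.api.services.bigquery.model.TableRow;
import com.google.api.services.bigquery.model.TableSchema;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.WriteResult;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;

public class AutoShardedStorageWriteSketch {
  // 'rows' is assumed to be an unbounded PCollection<TableRow>; the table
  // spec and schema are illustrative placeholders, not values from the patch.
  static WriteResult writeAutoSharded(PCollection<TableRow> rows, TableSchema schema) {
    return rows.apply(
        "WriteToBQ",
        BigQueryIO.writeTableRows()
            .to("my-project:my_dataset.my_table") // placeholder
            .withSchema(schema)
            .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
            .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
            // Storage API writes are still batched on a trigger, so a
            // triggering frequency accompanies withAutoSharding(), mirroring
            // the new test added below.
            .withTriggeringFrequency(Duration.standardSeconds(5))
            .withAutoSharding());
  }
}
```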
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
index a38c283f442c..e48b9a196902 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
@@ -52,6 +52,7 @@ public class StorageApiLoads<DestinationT, ElementT>
   private final BigQueryServices bqServices;
   private final int numShards;
   private final boolean allowInconsistentWrites;
+  private final boolean allowAutosharding;
 
   public StorageApiLoads(
       Coder<DestinationT> destinationCoder,
@@ -61,7 +62,8 @@ public StorageApiLoads(
       Duration triggeringFrequency,
       BigQueryServices bqServices,
       int numShards,
-      boolean allowInconsistentWrites) {
+      boolean allowInconsistentWrites,
+      boolean allowAutosharding) {
     this.destinationCoder = destinationCoder;
     this.dynamicDestinations = dynamicDestinations;
     this.createDisposition = createDisposition;
@@ -70,6 +72,7 @@ public StorageApiLoads(
     this.bqServices = bqServices;
     this.numShards = numShards;
     this.allowInconsistentWrites = allowInconsistentWrites;
+    this.allowAutosharding = allowAutosharding;
   }
 
   @Override
@@ -136,9 +139,6 @@ public WriteResult expandTriggered(
     // Handle triggered, low-latency loads into BigQuery.
     PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
         input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows()));
-
-    // First shard all the records.
-    // TODO(reuvenlax): Add autosharding support so that users don't have to pick a shard count.
     PCollectionTuple result =
         inputInGlobalWindow
             .apply(
@@ -154,43 +154,33 @@ public WriteResult expandTriggered(
                 successfulRowsTag,
                 BigQueryStorageApiInsertErrorCoder.of(),
                 successCoder));
-    PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
-        result
-            .get(successfulRowsTag)
-            .apply(
-                "AddShard",
-                ParDo.of(
-                    new DoFn<
-                        KV<DestinationT, StorageApiWritePayload>,
-                        KV<ShardedKey<DestinationT>, StorageApiWritePayload>>() {
-                      int shardNumber;
-
-                      @Setup
-                      public void setup() {
-                        shardNumber = ThreadLocalRandom.current().nextInt(numShards);
-                      }
-
-                      @ProcessElement
-                      public void processElement(
-                          @Element KV<DestinationT, StorageApiWritePayload> element,
-                          OutputReceiver<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> o) {
-                        DestinationT destination = element.getKey();
-                        ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
-                        buffer.putInt(++shardNumber % numShards);
-                        o.output(
-                            KV.of(ShardedKey.of(destination, buffer.array()), element.getValue()));
-                      }
-                    }))
-            .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));
-
-    PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> groupedRecords =
-        shardedRecords.apply(
-            "GroupIntoBatches",
-            GroupIntoBatches.<ShardedKey<DestinationT>, StorageApiWritePayload>ofByteSize(
-                    MAX_BATCH_SIZE_BYTES,
-                    (StorageApiWritePayload e) -> (long) e.getPayload().length)
-                .withMaxBufferingDuration(triggeringFrequency));
-
+    PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> groupedRecords;
+
+    if (this.allowAutosharding) {
+      groupedRecords =
+          result
+              .get(successfulRowsTag)
+              .apply(
+                  "GroupIntoBatches",
+                  GroupIntoBatches.<DestinationT, StorageApiWritePayload>ofByteSize(
+                          MAX_BATCH_SIZE_BYTES,
+                          (StorageApiWritePayload e) -> (long) e.getPayload().length)
+                      .withMaxBufferingDuration(triggeringFrequency)
+                      .withShardedKey());
+
+    } else {
+      PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
+          createShardedKeyValuePairs(result)
+              .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));
+      groupedRecords =
+          shardedRecords.apply(
+              "GroupIntoBatches",
+              GroupIntoBatches.<ShardedKey<DestinationT>, StorageApiWritePayload>ofByteSize(
+                      MAX_BATCH_SIZE_BYTES,
+                      (StorageApiWritePayload e) -> (long) e.getPayload().length)
+                  .withMaxBufferingDuration(triggeringFrequency));
+    }
     groupedRecords.apply(
         "StorageApiWriteSharded",
         new StorageApiWritesShardedRecords<>(
@@ -207,6 +197,35 @@ public void processElement(
         result.get(failedRowsTag));
   }
 
+  private PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>>
+      createShardedKeyValuePairs(PCollectionTuple pCollection) {
+    return pCollection
+        .get(successfulRowsTag)
+        .apply(
+            "AddShard",
+            ParDo.of(
+                new DoFn<
+                    KV<DestinationT, StorageApiWritePayload>,
+                    KV<ShardedKey<DestinationT>, StorageApiWritePayload>>() {
+                  int shardNumber;
+
+                  @Setup
+                  public void setup() {
+                    shardNumber = ThreadLocalRandom.current().nextInt(numShards);
+                  }
+
+                  @ProcessElement
+                  public void processElement(
+                      @Element KV<DestinationT, StorageApiWritePayload> element,
+                      OutputReceiver<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> o) {
+                    DestinationT destination = element.getKey();
+                    ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
+                    buffer.putInt(++shardNumber % numShards);
+                    o.output(KV.of(ShardedKey.of(destination, buffer.array()), element.getValue()));
+                  }
+                }));
+  }
+
   public WriteResult expandUntriggered(
       PCollection<KV<DestinationT, ElementT>> input,
       Coder<KV<DestinationT, StorageApiWritePayload>> successCoder) {
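This is the core of the change: when `allowAutosharding` is set, the hand-rolled `AddShard` DoFn (a randomly seeded round-robin over a fixed `numShards`) is bypassed in favor of `GroupIntoBatches.withShardedKey()`, which lets the runner choose and rebalance the number of shards per destination key. Below is a standalone sketch of that pattern with the destination/payload generics simplified to `String`; everything in it is illustrative and not part of the patch.

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.joda.time.Duration;

public class ShardedKeyBatchingSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // Stand-in for the keyed payloads StorageApiLoads produces per destination.
    PCollection<KV<String, String>> keyedRows =
        p.apply(Create.of(KV.of("tableA", "row1"), KV.of("tableA", "row2")));

    // Batch up to 2 MiB per shard and flush at least every 5 seconds; with
    // withShardedKey() the runner, not the pipeline author, decides how many
    // shards each (possibly hot) key is split across.
    PCollection<KV<ShardedKey<String>, Iterable<String>>> batched =
        keyedRows.apply(
            "GroupIntoBatches",
            GroupIntoBatches.<String, String>ofByteSize(
                    2L * 1024 * 1024, (String e) -> (long) e.length())
                .withMaxBufferingDuration(Duration.standardSeconds(5))
                .withShardedKey());

    p.run().waitUntilFinish();
  }
}
```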
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index 8c8a93a40e58..18a5b1c0db8b 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -1195,9 +1195,6 @@ public void testStreamingWrite() throws Exception {
 
   @Test
   public void testStreamingWriteWithAutoSharding() throws Exception {
-    if (useStorageApi) {
-      return;
-    }
     streamingWrite(true);
   }
 
@@ -1240,6 +1237,54 @@ private void streamingWrite(boolean autoSharding) throws Exception {
             new TableRow().set("name", "d").set("number", "4")));
   }
 
+  @Test
+  public void testStorageApiWriteWithAutoSharding() throws Exception {
+    storageWrite(true);
+  }
+
+  private void storageWrite(boolean autoSharding) throws Exception {
+    if (!useStorageApi) {
+      return;
+    }
+    BigQueryIO.Write<TableRow> write =
+        BigQueryIO.writeTableRows()
+            .to("project-id:dataset-id.table-id")
+            .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)
+            .withSchema(
+                new TableSchema()
+                    .setFields(
+                        ImmutableList.of(
+                            new TableFieldSchema().setName("name").setType("STRING"),
+                            new TableFieldSchema().setName("number").setType("INTEGER"))))
+            .withTestServices(fakeBqServices)
+            .withoutValidation();
+    if (autoSharding) {
+      write =
+          write
+              .withAutoSharding()
+              .withTriggeringFrequency(Duration.standardSeconds(5))
+              .withMethod(Method.STORAGE_WRITE_API);
+    }
+    p.apply(
+            Create.of(
+                    new TableRow().set("name", "a").set("number", "1"),
+                    new TableRow().set("name", "b").set("number", "2"),
+                    new TableRow().set("name", "c").set("number", "3"),
+                    new TableRow().set("name", "d").set("number", "4"))
+                .withCoder(TableRowJsonCoder.of()))
+        .setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED)
+        .apply("WriteToBQ", write);
+    p.run();
+
+    assertThat(
+        fakeDatasetService.getAllRows("project-id", "dataset-id", "table-id"),
+        containsInAnyOrder(
+            new TableRow().set("name", "a").set("number", "1"),
+            new TableRow().set("name", "b").set("number", "2"),
+            new TableRow().set("name", "c").set("number", "3"),
+            new TableRow().set("name", "d").set("number", "4")));
+  }
+
   @DefaultSchema(JavaFieldSchema.class)
   static class SchemaPojo {
     final String name;
@@ -1927,7 +1972,8 @@ public void testWriteValidateFailsBothFormatFunctions() {
 
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage(
-        "Only one of withFormatFunction or withAvroFormatFunction/withAvroWriter maybe set, not both.");
+        "Only one of withFormatFunction or withAvroFormatFunction/withAvroWriter maybe set, not"
+            + " both.");
     p.apply(Create.empty(INPUT_RECORD_CODER))
         .apply(
             BigQueryIO.write()
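One note on the new test: `setIsBoundedInternal(...)` is an internal hook for forcing an unbounded `PCollection` inside Beam's own test suite. A user-level test would more typically obtain an unbounded input from `TestStream`. A hedged sketch follows; `p` and `write` are assumed to be supplied by the surrounding test, with `write` configured for `STORAGE_WRITE_API` plus `withAutoSharding()` as in `storageWrite` above.

```java
import com.google.api.services.bigquery.model.TableRow;
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO;
import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.TestStream;

public class AutoShardingTestStreamSketch {
  // 'p' and 'write' are assumptions: a TestPipeline and an auto-sharded
  // Storage Write API transform configured as in the test above.
  static void runUnbounded(TestPipeline p, BigQueryIO.Write<TableRow> write) {
    TestStream<TableRow> rows =
        TestStream.create(TableRowJsonCoder.of())
            .addElements(
                new TableRow().set("name", "a").set("number", "1"),
                new TableRow().set("name", "b").set("number", "2"))
            .advanceWatermarkToInfinity(); // yields an unbounded PCollection
    p.apply(rows).apply("WriteToBQ", write);
    p.run().waitUntilFinish();
  }
}
```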