diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
index e1ad4b447864..59db2f95cf90 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java
@@ -2725,9 +2725,6 @@ public WriteResult expand(PCollection<T> input) {
     if (input.isBounded() == IsBounded.BOUNDED) {
       checkArgument(!getAutoSharding(), "Auto-sharding is only applicable to unbounded input.");
     }
-    if (method == Write.Method.STORAGE_WRITE_API) {
-      checkArgument(!getAutoSharding(), "Auto sharding not yet available for Storage API writes");
-    }
 
     if (getJsonTimePartitioning() != null) {
       checkArgument(
@@ -3023,7 +3020,8 @@ private <DestinationT, ElementT> WriteResult continueExpandTyped(
               getStorageApiTriggeringFrequency(bqOptions),
               getBigQueryServices(),
               getStorageApiNumStreams(bqOptions),
-              method == Method.STORAGE_API_AT_LEAST_ONCE);
+              method == Method.STORAGE_API_AT_LEAST_ONCE,
+              getAutoSharding());
       return input.apply("StorageApiLoads", storageApiLoads);
     } else {
       throw new RuntimeException("Unexpected write method " + method);
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
index a38c283f442c..20d1be9a2959 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiLoads.java
@@ -52,6 +52,7 @@ public class StorageApiLoads<DestinationT, ElementT>
   private final BigQueryServices bqServices;
   private final int numShards;
   private final boolean allowInconsistentWrites;
+  private final boolean allowAutosharding;
 
   public StorageApiLoads(
       Coder<DestinationT> destinationCoder,
@@ -61,7 +62,8 @@ public StorageApiLoads(
       Duration triggeringFrequency,
       BigQueryServices bqServices,
       int numShards,
-      boolean allowInconsistentWrites) {
+      boolean allowInconsistentWrites,
+      boolean allowAutosharding) {
     this.destinationCoder = destinationCoder;
     this.dynamicDestinations = dynamicDestinations;
     this.createDisposition = createDisposition;
@@ -70,6 +72,7 @@ public StorageApiLoads(
     this.bqServices = bqServices;
     this.numShards = numShards;
     this.allowInconsistentWrites = allowInconsistentWrites;
+    this.allowAutosharding = allowAutosharding;
   }
 
   @Override
@@ -136,9 +139,6 @@ public WriteResult expandTriggered(
     // Handle triggered, low-latency loads into BigQuery.
     PCollection<KV<DestinationT, ElementT>> inputInGlobalWindow =
         input.apply("rewindowIntoGlobal", Window.into(new GlobalWindows()));
-
-    // First shard all the records.
-    // TODO(reuvenlax): Add autosharding support so that users don't have to pick a shard count.
     PCollectionTuple result =
         inputInGlobalWindow
             .apply(
@@ -154,43 +154,61 @@ public WriteResult expandTriggered(
                     successfulRowsTag,
                     BigQueryStorageApiInsertErrorCoder.of(),
                     successCoder));
-    PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
-        result
-            .get(successfulRowsTag)
-            .apply(
-                "AddShard",
-                ParDo.of(
-                    new DoFn<
-                        KV<DestinationT, StorageApiWritePayload>,
-                        KV<ShardedKey<DestinationT>, StorageApiWritePayload>>() {
-                      int shardNumber;
-
-                      @Setup
-                      public void setup() {
-                        shardNumber = ThreadLocalRandom.current().nextInt(numShards);
-                      }
-
-                      @ProcessElement
-                      public void processElement(
-                          @Element KV<DestinationT, StorageApiWritePayload> element,
-                          OutputReceiver<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> o) {
-                        DestinationT destination = element.getKey();
-                        ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
-                        buffer.putInt(++shardNumber % numShards);
-                        o.output(
-                            KV.of(ShardedKey.of(destination, buffer.array()), element.getValue()));
-                      }
-                    }))
-            .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));
-
-    PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> groupedRecords =
-        shardedRecords.apply(
-            "GroupIntoBatches",
-            GroupIntoBatches.<ShardedKey<DestinationT>, StorageApiWritePayload>ofByteSize(
-                    MAX_BATCH_SIZE_BYTES,
-                    (StorageApiWritePayload e) -> (long) e.getPayload().length)
-                .withMaxBufferingDuration(triggeringFrequency));
+    PCollection<KV<ShardedKey<DestinationT>, Iterable<StorageApiWritePayload>>> groupedRecords;
+
+    if (this.allowAutosharding) {
+      groupedRecords =
+          result
+              .get(successfulRowsTag)
+              .apply(
+                  "GroupIntoBatches",
+                  GroupIntoBatches.<DestinationT, StorageApiWritePayload>ofByteSize(
+                          MAX_BATCH_SIZE_BYTES,
+                          (StorageApiWritePayload e) -> (long) e.getPayload().length)
+                      .withMaxBufferingDuration(triggeringFrequency)
+                      .withShardedKey());
+
+    } else {
+      PCollection<KV<ShardedKey<DestinationT>, StorageApiWritePayload>> shardedRecords =
+          result
+              .get(successfulRowsTag)
+              .apply(
+                  "AddShard",
+                  ParDo.of(
+                      new DoFn<
+                          KV<DestinationT, StorageApiWritePayload>,
+                          KV<ShardedKey<DestinationT>, StorageApiWritePayload>>() {
+                        int shardNumber;
+
+                        @Setup
+                        public void setup() {
+                          shardNumber = ThreadLocalRandom.current().nextInt(numShards);
+                        }
+
+                        @ProcessElement
+                        public void processElement(
+                            @Element KV<DestinationT, StorageApiWritePayload> element,
+                            OutputReceiver<KV<ShardedKey<DestinationT>, StorageApiWritePayload>>
+                                o) {
+                          DestinationT destination = element.getKey();
+                          ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES);
+                          buffer.putInt(++shardNumber % numShards);
+                          o.output(
+                              KV.of(
+                                  ShardedKey.of(destination, buffer.array()), element.getValue()));
+                        }
+                      }))
+              .setCoder(KvCoder.of(ShardedKey.Coder.of(destinationCoder), payloadCoder));
+
+      groupedRecords =
+          shardedRecords.apply(
+              "GroupIntoBatches",
+              GroupIntoBatches.<ShardedKey<DestinationT>, StorageApiWritePayload>ofByteSize(
+                      MAX_BATCH_SIZE_BYTES,
+                      (StorageApiWritePayload e) -> (long) e.getPayload().length)
+                  .withMaxBufferingDuration(triggeringFrequency));
+    }
     groupedRecords.apply(
         "StorageApiWriteSharded",
         new StorageApiWritesShardedRecords<>(
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
index 1bf78ed70d82..270149acd398 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java
@@ -1190,19 +1190,21 @@ protected void writeString(org.apache.avro.Schema schema, Object datum, Encoder
 
   @Test
   public void testStreamingWrite() throws Exception {
-    streamingWrite(false);
+    autoShardingWrite(false, false);
   }
 
   @Test
   public void testStreamingWriteWithAutoSharding() throws Exception {
-    if (useStorageApi) {
-      return;
-    }
-    streamingWrite(true);
+    autoShardingWrite(true, false);
   }
 
-  private void streamingWrite(boolean autoSharding) throws Exception {
-    if (!useStreaming) {
+  @Test
+  public void testStorageApiWriteWithAutoSharding() throws Exception {
+    autoShardingWrite(true, true);
+  }
+
+  private void autoShardingWrite(boolean autoSharding, boolean storageAPIWrite) throws Exception {
+    if (!useStreaming && !storageAPIWrite) {
       return;
     }
     BigQueryIO.Write<TableRow> write =
@@ -1220,6 +1222,12 @@ private void autoShardingWrite(boolean autoSharding, boolean storageAPIWrite) th
     if (autoSharding) {
       write = write.withAutoSharding();
     }
+    if (storageAPIWrite) {
+      write =
+          write
+              .withTriggeringFrequency(Duration.standardSeconds(5))
+              .withMethod(Method.STORAGE_WRITE_API);
+    }
     p.apply(
             Create.of(
                 new TableRow().set("name", "a").set("number", "1"),
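
Note for reviewers: the new autosharding branch replaces the fixed AddShard fan-out, where each bundle writes to one of numShards pseudo-random shards per destination, with GroupIntoBatches.withShardedKey(), which lets the runner pick and dynamically rebalance the shard count per destination key. A minimal, self-contained sketch of how that transform behaves follows; the pipeline, keys, and byte budget are illustrative and not taken from this patch.

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.transforms.Create;
    import org.apache.beam.sdk.transforms.GroupIntoBatches;
    import org.apache.beam.sdk.util.ShardedKey;
    import org.apache.beam.sdk.values.KV;
    import org.apache.beam.sdk.values.PCollection;
    import org.joda.time.Duration;

    public class GroupIntoBatchesShardedKeySketch {
      public static void main(String[] args) {
        Pipeline p = Pipeline.create();

        // Stand-ins for the (destination, payload) pairs in StorageApiLoads.
        PCollection<KV<String, String>> rows =
            p.apply(
                Create.of(
                    KV.of("dest1", "rowA"), KV.of("dest1", "rowB"), KV.of("dest2", "rowC")));

        // withShardedKey() wraps each key in a runner-chosen ShardedKey, so a hot
        // destination can be split across parallel workers; a batch is emitted when
        // it reaches the byte budget or the buffering deadline, whichever comes first.
        PCollection<KV<ShardedKey<String>, Iterable<String>>> batches =
            rows.apply(
                GroupIntoBatches.<String, String>ofByteSize(
                        1024L, (String e) -> (long) e.length())
                    .withMaxBufferingDuration(Duration.standardSeconds(5))
                    .withShardedKey());

        p.run().waitUntilFinish();
      }
    }

The design consequence is that batch parallelism becomes a runner decision, which is why the numShards knob is bypassed entirely on the autosharding path.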
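On the user-facing side, the combination that the deleted checkArgument in BigQueryIO.expand used to reject is now legal on unbounded input. A hedged sketch of a write configured the same way as the autoShardingWrite test above; the table spec and schema are placeholders, not values from the patch.

    BigQueryIO.Write<TableRow> write =
        BigQueryIO.writeTableRows()
            .to("my-project:my_dataset.my_table") // placeholder destination
            .withSchema(tableSchema) // assumes a TableSchema defined elsewhere
            .withMethod(BigQueryIO.Write.Method.STORAGE_WRITE_API)
            .withTriggeringFrequency(Duration.standardSeconds(5))
            .withAutoSharding(); // no longer rejected for STORAGE_WRITE_API

As before, auto-sharding still requires an unbounded input; the bounded-input checkArgument kept at the top of expand continues to enforce that.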