diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java index c647b0e70bbf..3d3eb2aac275 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java @@ -1043,6 +1043,8 @@ public static FileNaming relativeFileNaming( abstract boolean getNoSpilling(); + abstract @Nullable Integer getMaxNumWritersPerBundle(); + abstract @Nullable ErrorHandler getBadRecordErrorHandler(); abstract Builder toBuilder(); @@ -1093,6 +1095,9 @@ abstract Builder setSharding( abstract Builder setNoSpilling(boolean noSpilling); + abstract Builder setMaxNumWritersPerBundle( + @Nullable Integer maxNumWritersPerBundle); + abstract Builder setBadRecordErrorHandler( @Nullable ErrorHandler badRecordErrorHandler); @@ -1326,6 +1331,15 @@ public Write withNoSpilling() { return toBuilder().setNoSpilling(true).build(); } + /** + * Set the maximum number of writers created in a bundle before spilling to shuffle. See {@link + * WriteFiles#withMaxNumWritersPerBundle(int)}. + */ + public Write withMaxNumWritersPerBundle( + @Nullable Integer maxNumWritersPerBundle) { + return toBuilder().setMaxNumWritersPerBundle(maxNumWritersPerBundle).build(); + } + /** * Configures a new {@link Write} with an ErrorHandler. For configuring an ErrorHandler, see * {@link ErrorHandler}.
Whenever a record is formatted, or a lookup for a dynamic destination @@ -1424,6 +1438,9 @@ public WriteFilesResult expand(PCollection input) { resolvedSpec.setIgnoreWindowing(getIgnoreWindowing()); resolvedSpec.setAutoSharding(getAutoSharding()); resolvedSpec.setNoSpilling(getNoSpilling()); + if (getMaxNumWritersPerBundle() != null) { + resolvedSpec.setMaxNumWritersPerBundle(getMaxNumWritersPerBundle()); + } Write resolved = resolvedSpec.build(); WriteFiles writeFiles = @@ -1445,6 +1462,9 @@ public WriteFilesResult expand(PCollection input) { if (getNoSpilling()) { writeFiles = writeFiles.withNoSpilling(); } + if (getMaxNumWritersPerBundle() != null) { + writeFiles = writeFiles.withMaxNumWritersPerBundle(getMaxNumWritersPerBundle()); + } if (getBadRecordErrorHandler() != null) { writeFiles = writeFiles.withBadRecordErrorHandler(getBadRecordErrorHandler()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java index dc76d9016577..52982e2fe160 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java @@ -274,6 +274,9 @@ public abstract static class Write extends PTransform, PDone /** Whether to skip the spilling of data caused by having maxNumWritersPerBundle. */ abstract boolean getNoSpilling(); + /** Maximum number of writers created in a bundle before spilling to shuffle. */ + abstract @Nullable Integer getMaxNumWritersPerBundle(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -290,6 +293,8 @@ abstract static class Builder { abstract Builder setNoSpilling(boolean noSpilling); + abstract Builder setMaxNumWritersPerBundle(@Nullable Integer maxNumWritersPerBundle); + abstract Write build(); } @@ -388,6 +393,11 @@ public Write withNoSpilling() { return toBuilder().setNoSpilling(true).build(); } + /** See {@link WriteFiles#withMaxNumWritersPerBundle(int)}.
*/ + public Write withMaxNumWritersPerBundle(@Nullable Integer maxNumWritersPerBundle) { + return toBuilder().setMaxNumWritersPerBundle(maxNumWritersPerBundle).build(); + } + @Override public PDone expand(PCollection input) { checkState( @@ -403,6 +413,9 @@ public PDone expand(PCollection input) { if (getNoSpilling()) { write = write.withNoSpilling(); } + if (getMaxNumWritersPerBundle() != null) { + write = write.withMaxNumWritersPerBundle(getMaxNumWritersPerBundle()); + } input.apply("Write", write); return PDone.in(input.getPipeline()); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformConfiguration.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformConfiguration.java index e123b5c0847e..8167d4a399e3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformConfiguration.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformConfiguration.java @@ -80,6 +80,11 @@ public static TFRecordWriteSchemaTransformConfiguration.Builder builder() { @Nullable public abstract Boolean getNoSpilling(); + @SchemaFieldDescription( + "Maximum number of writers created in a bundle before spilling to shuffle.") + @Nullable + public abstract Integer getMaxNumWritersPerBundle(); + @SchemaFieldDescription("This option specifies whether and where to output unwritable rows.") @Nullable public abstract ErrorHandling getErrorHandling(); @@ -99,6 +104,8 @@ public abstract static class Builder { public abstract Builder setNoSpilling(Boolean value); + public abstract Builder setMaxNumWritersPerBundle(@Nullable Integer maxNumWritersPerBundle); + public abstract Builder setErrorHandling(ErrorHandling errorHandling); /** Builds the {@link TFRecordWriteSchemaTransformConfiguration} configuration. 
*/ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformProvider.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformProvider.java index b45d8584be5e..bc9b7bbeac66 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformProvider.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordWriteSchemaTransformProvider.java @@ -132,6 +132,10 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { if (Boolean.TRUE.equals(configuration.getNoSpilling())) { writeTransform = writeTransform.withNoSpilling(); } + if (configuration.getMaxNumWritersPerBundle() != null) { + writeTransform = + writeTransform.withMaxNumWritersPerBundle(configuration.getMaxNumWritersPerBundle()); + } // Obtain input schema and verify only one field and its bytes Schema inputSchema = input.get(INPUT).getSchema(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java index f9844f3a73a5..f26f23b4656b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java @@ -718,6 +718,9 @@ public abstract static class TypedWrite /** Whether to skip the spilling of data caused by having maxNumWritersPerBundle. */ abstract boolean getNoSpilling(); + /** Maximum number of writers created in a bundle before spilling to shuffle. */ + abstract @Nullable Integer getMaxNumWritersPerBundle(); + /** Whether to skip writing any output files if the PCollection is empty. 
*/ abstract boolean getSkipIfEmpty(); @@ -779,6 +782,9 @@ abstract Builder setBatchMaxBufferingDuration( abstract Builder setNoSpilling(boolean noSpilling); + abstract Builder setMaxNumWritersPerBundle( + @Nullable Integer maxNumWritersPerBundle); + abstract Builder setSkipIfEmpty(boolean noSpilling); abstract Builder setWritableByteChannelFactory( @@ -1062,6 +1068,12 @@ public TypedWrite withNoSpilling() { return toBuilder().setNoSpilling(true).build(); } + /** Set the maximum number of writers created in a bundle before spilling to shuffle. */ + public TypedWrite withMaxNumWritersPerBundle( + @Nullable Integer maxNumWritersPerBundle) { + return toBuilder().setMaxNumWritersPerBundle(maxNumWritersPerBundle).build(); + } + /** See {@link FileIO.Write#withBadRecordErrorHandler(ErrorHandler)} for details on usage. */ public TypedWrite withBadRecordErrorHandler( ErrorHandler errorHandler) { @@ -1161,6 +1173,9 @@ public WriteFilesResult expand(PCollection input) { if (getNoSpilling()) { write = write.withNoSpilling(); } + if (getMaxNumWritersPerBundle() != null) { + write = write.withMaxNumWritersPerBundle(getMaxNumWritersPerBundle()); + } if (getBadRecordErrorHandler() != null) { write = write.withBadRecordErrorHandler(getBadRecordErrorHandler()); } @@ -1187,6 +1202,7 @@ public void populateDisplayData(DisplayData.Builder builder) { builder .addIfNotNull( DisplayData.item("numShards", getNumShards()).withLabel("Maximum Output Shards")) + .addIfNotNull(DisplayData.item("maxNumWritersPerBundle", getMaxNumWritersPerBundle())) .addIfNotNull( DisplayData.item("tempDirectory", getTempDirectory()) .withLabel("Directory for temporary files")) @@ -1348,6 +1364,11 @@ public Write withNoSpilling() { return new Write(inner.withNoSpilling()); } + /** See {@link TypedWrite#withMaxNumWritersPerBundle(Integer)}. 
*/ + public Write withMaxNumWritersPerBundle(@Nullable Integer maxNumWritersPerBundle) { + return new Write(inner.withMaxNumWritersPerBundle(maxNumWritersPerBundle)); + } + /** See {@link TypedWrite#withBatchSize(Integer)}. */ public Write withBatchSize(@Nullable Integer batchSize) { return new Write(inner.withBatchSize(batchSize)); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java index cb48931958ce..c1b56a2b4458 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/WriteFiles.java @@ -292,6 +292,13 @@ public WriteFiles withNumShards( /** Set the maximum number of writers created in a bundle before spilling to shuffle. */ public WriteFiles withMaxNumWritersPerBundle( int maxNumWritersPerBundle) { + checkArgument( + getMaxNumWritersPerBundle() != -1, + "Cannot use withMaxNumWritersPerBundle() after withNoSpilling() has been set."); + checkArgument( + maxNumWritersPerBundle > 0 && maxNumWritersPerBundle <= DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE, + "maxNumWritersPerBundle must be greater than 0 and less than or equal to %s", + DEFAULT_MAX_NUM_WRITERS_PER_BUNDLE); return toBuilder().setMaxNumWritersPerBundle(maxNumWritersPerBundle).build(); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java index 5adbcbb8152f..9c067a533e0b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordSchemaTransformProviderTest.java @@ -276,6 +276,7 @@ public void testWriteFindTransformAndMakeItWork() { "num_shards", "compression", "no_spilling", + "max_num_writers_per_bundle", "error_handling"), 
tfrecordProvider.configurationSchema().getFields().stream() .map(field -> field.getName()) diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java index 695ff4474d71..eba0f793265d 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOWriteTest.java @@ -650,6 +650,7 @@ public void testWriteDisplayData() { .withSuffix("bar") .withShardNameTemplate("-SS-of-NN-") .withNumShards(100) + .withMaxNumWritersPerBundle(5) .withFooter("myFooter") .withHeader("myHeader"); @@ -661,6 +662,7 @@ public void testWriteDisplayData() { assertThat(displayData, hasDisplayItem("fileFooter", "myFooter")); assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-")); assertThat(displayData, hasDisplayItem("numShards", 100)); + assertThat(displayData, hasDisplayItem("maxNumWritersPerBundle", 5)); assertThat(displayData, hasDisplayItem("writableByteChannelFactory", "UNCOMPRESSED")); } diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java index 2e4939560ad1..2ddde14bcc26 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroIO.java @@ -1426,6 +1426,8 @@ public abstract static class TypedWrite abstract boolean getNoSpilling(); + abstract @Nullable Integer getMaxNumWritersPerBundle(); + abstract @Nullable FilenamePolicy getFilenamePolicy(); abstract @Nullable DynamicAvroDestinations @@ -1483,6 +1485,9 @@ public Builder setGenericRecords(boolean genericRe abstract Builder setNoSpilling(boolean noSpilling); + abstract Builder setMaxNumWritersPerBundle( + @Nullable Integer maxNumWritersPerBundle); + abstract 
Builder setFilenamePolicy( FilenamePolicy filenamePolicy); @@ -1690,6 +1695,12 @@ public TypedWrite withNoSpilling() { return toBuilder().setNoSpilling(true).build(); } + /** See {@link WriteFiles#withMaxNumWritersPerBundle(int)}. */ + public TypedWrite withMaxNumWritersPerBundle( + @Nullable Integer maxNumWritersPerBundle) { + return toBuilder().setMaxNumWritersPerBundle(maxNumWritersPerBundle).build(); + } + /** Writes to Avro file(s) compressed using specified codec. */ public TypedWrite withCodec(CodecFactory codec) { return toBuilder().setCodec(new SerializableAvroCodecFactory(codec)).build(); @@ -1799,6 +1810,9 @@ public WriteFilesResult expand(PCollection input) { if (getNoSpilling()) { write = write.withNoSpilling(); } + if (getMaxNumWritersPerBundle() != null) { + write = write.withMaxNumWritersPerBundle(getMaxNumWritersPerBundle()); + } if (getBadRecordErrorHandler() != null) { write = write.withBadRecordErrorHandler(getBadRecordErrorHandler()); } diff --git a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java index fc2b68c0a893..d71299cceb69 100644 --- a/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java +++ b/sdks/java/io/csv/src/main/java/org/apache/beam/sdk/io/csv/CsvIO.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.commons.csv.CSVFormat; +import org.checkerframework.checker.nullness.qual.Nullable; /** * {@link PTransform}s for reading and writing CSV files. @@ -550,6 +551,16 @@ public Write withNoSpilling() { return toBuilder().setTextIOWrite(getTextIOWrite().withNoSpilling()).build(); } + /** + * Set the maximum number of writers created in a bundle before spilling to shuffle. See {@link + * WriteFiles#withMaxNumWritersPerBundle(int)}.
+ */ + public Write withMaxNumWritersPerBundle(@Nullable Integer maxNumWritersPerBundle) { + return toBuilder() + .setTextIOWrite(getTextIOWrite().withMaxNumWritersPerBundle(maxNumWritersPerBundle)) + .build(); + } + /** * Specifies to use a given fixed number of shards per window. See {@link * TextIO.Write#withNumShards}. diff --git a/sdks/java/io/json/src/main/java/org/apache/beam/sdk/io/json/JsonIO.java b/sdks/java/io/json/src/main/java/org/apache/beam/sdk/io/json/JsonIO.java index 3abb29a80427..1cb576e8f420 100644 --- a/sdks/java/io/json/src/main/java/org/apache/beam/sdk/io/json/JsonIO.java +++ b/sdks/java/io/json/src/main/java/org/apache/beam/sdk/io/json/JsonIO.java @@ -37,6 +37,7 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.Row; +import org.checkerframework.checker.nullness.qual.Nullable; /** * {@link PTransform}s for reading and writing JSON files. @@ -170,6 +171,16 @@ public Write withNoSpilling() { return toBuilder().setTextIOWrite(getTextIOWrite().withNoSpilling()).build(); } + /** + * Set the maximum number of writers created in a bundle before spilling to shuffle. See {@link + * WriteFiles#withMaxNumWritersPerBundle(int)}. + */ + public Write withMaxNumWritersPerBundle(@Nullable Integer maxNumWritersPerBundle) { + return toBuilder() + .setTextIOWrite(getTextIOWrite().withMaxNumWritersPerBundle(maxNumWritersPerBundle)) + .build(); + } + /** * Specifies to use a given fixed number of shards per window. See {@link * TextIO.Write#withNumShards}. diff --git a/sdks/standard_external_transforms.yaml b/sdks/standard_external_transforms.yaml index f5d71830145a..1c536ce319d2 100644 --- a/sdks/standard_external_transforms.yaml +++ b/sdks/standard_external_transforms.yaml @@ -19,7 +19,7 @@ # configuration in /sdks/standard_expansion_services.yaml. # Refer to gen_xlang_wrappers.py for more info.
# -# Last updated on: 2025-04-24 +# Last updated on: 2025-06-05 - default_service: sdks:java:io:expansion-service:shadowJar description: 'Outputs a PCollection of Beam Rows, each containing a single INT64 @@ -91,6 +91,11 @@ name: filename_suffix nullable: true type: str + - description: Maximum number of writers created in a bundle before spilling to + shuffle. + name: max_num_writers_per_bundle + nullable: true + type: int32 - description: Whether to skip the spilling of data caused by having maxNumWritersPerBundle. name: no_spilling nullable: true