From 30852135d494aeeb67193e07c050c16e844da18e Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 2 Sep 2025 14:14:50 -0400 Subject: [PATCH 01/12] Add postgres read to managed io --- .../pipeline/v1/external_transforms.proto | 2 + .../jdbc/JdbcReadSchemaTransformProvider.java | 2 + ...adFromPostgresSchemaTransformProvider.java | 42 ++++++++++++++++++- .../org/apache/beam/sdk/managed/Managed.java | 2 + .../python/apache_beam/transforms/external.py | 3 +- sdks/python/apache_beam/transforms/managed.py | 4 +- 6 files changed, 52 insertions(+), 3 deletions(-) diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index add8a1999caf..3a618c45acbe 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -76,6 +76,8 @@ message ManagedTransforms { "beam:schematransform:org.apache.beam:bigquery_write:v1"]; ICEBERG_CDC_READ = 6 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:iceberg_cdc_read:v1"]; + POSTGRES_READ = 7 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:postgres_read:v1"]; } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java index 6777be50ab50..0c35f6955611 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java @@ -401,6 +401,8 @@ public static Builder builder() { .Builder(); } + public abstract Builder toBuilder(); + @AutoValue.Builder public abstract static class Builder { public abstract Builder setDriverClassName(String value); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java index 62ff14c23e0a..317fcb67c5d1 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java @@ -18,20 +18,30 @@ package org.apache.beam.sdk.io.jdbc.providers; import static org.apache.beam.sdk.io.jdbc.JdbcUtil.POSTGRES; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class ReadFromPostgresSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + private static final Logger LOG 
= + LoggerFactory.getLogger(ReadFromPostgresSchemaTransformProvider.class); + @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:postgres_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_READ); } @Override @@ -43,4 +53,34 @@ public String description() { protected String jdbcType() { return POSTGRES; } + + @Override + public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + JdbcReadSchemaTransformConfiguration configuration) { + String jdbcType = configuration.getJdbcType(); + if (jdbcType != null && !jdbcType.equals(jdbcType())) { + throw new IllegalArgumentException( + String.format("Wrong JDBC type. Expected '%s' but got '%s'", jdbcType(), jdbcType)); + } + + List<@org.checkerframework.checker.nullness.qual.Nullable String> connectionInitSql = + configuration.getConnectionInitSql(); + if (connectionInitSql != null && !connectionInitSql.isEmpty()) { + LOG.warn("Postgres does not support connectionInitSql, ignoring."); + } + + Boolean disableAutoCommit = configuration.getDisableAutoCommit(); + if (disableAutoCommit != null && !disableAutoCommit) { + LOG.warn("Postgres reads require disableAutoCommit to be true, overriding to true."); + } + + // Override "connectionInitSql" and "disableAutoCommit" for postgres + configuration = + configuration + .toBuilder() + .setConnectionInitSql(Collections.emptyList()) + .setDisableAutoCommit(true) + .build(); + return super.from(configuration); + } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index 06aed06c71c4..f79a8ea3d6a6 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -96,6 +96,7 @@ public class Managed { public static final String ICEBERG_CDC = "iceberg_cdc"; public static final String KAFKA = "kafka"; public static final String BIGQUERY = "bigquery"; + public static final String POSTGRES = "postgres"; // Supported SchemaTransforms public static final Map READ_TRANSFORMS = @@ -104,6 +105,7 @@ public class Managed { .put(ICEBERG_CDC, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_CDC_READ)) .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ)) .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ)) + .put(POSTGRES, getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_READ)) .build(); public static final Map WRITE_TRANSFORMS = ImmutableMap.builder() diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index f0b69a047b7c..e6c3c132a5b7 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -80,7 +80,8 @@ ManagedTransforms.Urns.KAFKA_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.KAFKA_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.BIGQUERY_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, - ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET + ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, + ManagedTransforms.Urns.POSTGRES_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, } diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index bf680d5fd354..95cf2afd89d5 100644 --- 
a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py @@ -85,6 +85,7 @@ _ICEBERG_CDC = "iceberg_cdc" KAFKA = "kafka" BIGQUERY = "bigquery" +POSTGRES = "postgres" __all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"] @@ -95,7 +96,8 @@ class Read(PTransform): ICEBERG: ManagedTransforms.Urns.ICEBERG_READ.urn, _ICEBERG_CDC: ManagedTransforms.Urns.ICEBERG_CDC_READ.urn, KAFKA: ManagedTransforms.Urns.KAFKA_READ.urn, - BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn, + POSTGRES: ManagedTransforms.Urns.POSTGRES_READ.urn, } def __init__( From eb9366a53266091faf033c3771c1199577571e8a Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 2 Sep 2025 16:43:06 -0400 Subject: [PATCH 02/12] Add postgres write to managed io --- .../pipeline/v1/external_transforms.proto | 2 ++ .../JdbcWriteSchemaTransformProvider.java | 2 ++ ...riteToPostgresSchemaTransformProvider.java | 32 ++++++++++++++++++- .../org/apache/beam/sdk/managed/Managed.java | 1 + .../python/apache_beam/transforms/external.py | 1 + sdks/python/apache_beam/transforms/managed.py | 3 +- 6 files changed, 39 insertions(+), 2 deletions(-) diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index 3a618c45acbe..02a5dd18e2c6 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -78,6 +78,8 @@ message ManagedTransforms { "beam:schematransform:org.apache.beam:iceberg_cdc_read:v1"]; POSTGRES_READ = 7 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:postgres_read:v1"]; + POSTGRES_WRITE = 8 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:postgres_write:v1"]; } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java index 6f10df56aab5..26eb3cc7e826 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java @@ -382,6 +382,8 @@ public static Builder builder() { .Builder(); } + public abstract Builder toBuilder(); + @AutoValue.Builder public abstract static class Builder { public abstract Builder setDriverClassName(String value); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java index c50b84311630..d61baf47fa0c 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java @@ -18,20 +18,30 @@ package org.apache.beam.sdk.io.jdbc.providers; import static org.apache.beam.sdk.io.jdbc.JdbcUtil.POSTGRES; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import 
org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class WriteToPostgresSchemaTransformProvider extends JdbcWriteSchemaTransformProvider { + private static final Logger LOG = + LoggerFactory.getLogger(WriteToPostgresSchemaTransformProvider.class); + @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:postgres_write:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_WRITE); } @Override @@ -43,4 +53,24 @@ public String description() { protected String jdbcType() { return POSTGRES; } + + @Override + public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + JdbcWriteSchemaTransformConfiguration configuration) { + String jdbcType = configuration.getJdbcType(); + if (jdbcType != null && !jdbcType.equals(jdbcType())) { + throw new IllegalArgumentException( + String.format("Wrong JDBC type. Expected '%s' but got '%s'", jdbcType(), jdbcType)); + } + + List<@org.checkerframework.checker.nullness.qual.Nullable String> connectionInitSql = + configuration.getConnectionInitSql(); + if (connectionInitSql != null && !connectionInitSql.isEmpty()) { + LOG.warn("Postgres does not support connectionInitSql, ignoring."); + } + + // Override "connectionInitSql" for postgres + configuration = configuration.toBuilder().setConnectionInitSql(Collections.emptyList()).build(); + return super.from(configuration); + } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index f79a8ea3d6a6..cda84629a7d7 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -112,6 +112,7 @@ public class Managed { .put(ICEBERG, getUrn(ExternalTransforms.ManagedTransforms.Urns.ICEBERG_WRITE)) .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE)) .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE)) + .put(POSTGRES, getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_WRITE)) .build(); /** diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index e6c3c132a5b7..c12ce985737e 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -82,6 +82,7 @@ ManagedTransforms.Urns.BIGQUERY_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.POSTGRES_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, + ManagedTransforms.Urns.POSTGRES_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, } diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index 95cf2afd89d5..72dfb6fd9a0a 100644 --- a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py @@ -138,7 +138,8 @@ class Write(PTransform): 
_WRITE_TRANSFORMS = { ICEBERG: ManagedTransforms.Urns.ICEBERG_WRITE.urn, KAFKA: ManagedTransforms.Urns.KAFKA_WRITE.urn, - BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn + BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn, + POSTGRES: ManagedTransforms.Urns.POSTGRES_WRITE.urn, } def __init__( From c85ad614963491e883260a8eadaa8b9ae5a43a25 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 2 Sep 2025 22:36:50 -0400 Subject: [PATCH 03/12] Add integration tests for both managed and unmanaged postgres read and write. --- sdks/java/io/jdbc/build.gradle | 1 + .../java/org/apache/JdbcIOPostgresIT.java | 180 ++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle index 8c5fa685fdad..12f66428dcc4 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -39,6 +39,7 @@ dependencies { testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation project(path: ":sdks:java:extensions:avro", configuration: "testRuntimeMigration") testImplementation project(path: ":sdks:java:io:common") + testImplementation project(path: ":sdks:java:managed") testImplementation project(path: ":sdks:java:testing:test-utils") testImplementation library.java.junit testImplementation library.java.slf4j_api diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java new file mode 100644 index 000000000000..9b8395c5455d --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc; + +import static org.apache.beam.sdk.io.common.IOITHelper.readIOTestPipelineOptions; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.apache.beam.sdk.io.common.DatabaseTestHelper; +import org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions; +import org.apache.beam.sdk.io.jdbc.providers.ReadFromPostgresSchemaTransformProvider; +import org.apache.beam.sdk.io.jdbc.providers.WriteToPostgresSchemaTransformProvider; +import org.apache.beam.sdk.managed.Managed; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.postgresql.ds.PGSimpleDataSource; + +/** + * A test of {@link org.apache.beam.sdk.io.jdbc.JdbcIO} on an independent Postgres instance. + * + *
<p>Similar to JdbcIOIT, this test requires a running instance of Postgres. Pass in connection + * information using PipelineOptions: + *
+ * <pre>
+ *  ./gradlew integrationTest -p sdks/java/io/jdbc -DintegrationTestPipelineOptions='[
+ *  "--postgresServerName=1.2.3.4",
+ *  "--postgresUsername=postgres",
+ *  "--postgresDatabaseName=myfancydb",
+ *  "--postgresPassword=mypass",
+ *  "--postgresSsl=false" ]'
+ *  --tests org.apache.beam.sdk.io.jdbc.JdbcIOPostgresIT
+ *  -DintegrationTestRunner=direct
+ * </pre>
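+ *
+ * <p>For reference, a minimal sketch (not a working setup; the connection values below are
+ * placeholders) of driving the same transforms through Managed IO, as the managed test case in
+ * this class does:
+ *
+ * <pre>{@code
+ * // "p" is an existing Pipeline and "rows" an existing PCollection<Row>; the config keys
+ * // mirror those used by testManagedWriteThenManagedRead below.
+ * Map<String, Object> config =
+ *     ImmutableMap.<String, Object>builder()
+ *         .put("jdbc_url", "jdbc:postgresql://localhost:5432/mydb")
+ *         .put("username", "postgres")
+ *         .put("password", "mypass")
+ *         .put("location", "my_table")
+ *         .build();
+ * rows.apply(Managed.write(Managed.POSTGRES).withConfig(config));
+ * PCollectionRowTuple out = p.apply(Managed.read(Managed.POSTGRES).withConfig(config));
+ * }</pre>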
+ */ +@RunWith(JUnit4.class) +public class JdbcIOPostgresIT { + private static final Schema INPUT_SCHEMA = + Schema.of( + Schema.Field.of("id", Schema.FieldType.INT32), + Schema.Field.of("name", Schema.FieldType.STRING)); + + private static final List ROWS = + Arrays.asList( + Row.withSchema(INPUT_SCHEMA) + .withFieldValue("id", 1) + .withFieldValue("name", "foo") + .build(), + Row.withSchema(INPUT_SCHEMA) + .withFieldValue("id", 2) + .withFieldValue("name", "bar") + .build(), + Row.withSchema(INPUT_SCHEMA) + .withFieldValue("id", 3) + .withFieldValue("name", "baz") + .build()); + + private static PGSimpleDataSource dataSource; + private static String jdbcUrl; + + @Rule public TestPipeline writePipeline = TestPipeline.create(); + @Rule public TestPipeline readPipeline = TestPipeline.create(); + + @BeforeClass + public static void setup() { + PostgresIOTestPipelineOptions options; + try { + options = readIOTestPipelineOptions(PostgresIOTestPipelineOptions.class); + } catch (IllegalArgumentException e) { + options = null; + } + org.junit.Assume.assumeNotNull(options); + dataSource = DatabaseTestHelper.getPostgresDataSource(options); + jdbcUrl = DatabaseTestHelper.getPostgresDBUrl(options); + } + + @Test + public void testWriteThenRead() throws SQLException { + String tableName = DatabaseTestHelper.getTestTableName("JdbcIOPostgresIT"); + DatabaseTestHelper.createTable(dataSource, tableName); + + WriteToPostgresSchemaTransformProvider.JdbcWriteSchemaTransformConfiguration writeConfig = + WriteToPostgresSchemaTransformProvider.JdbcWriteSchemaTransformConfiguration.builder() + .setJdbcUrl(jdbcUrl) + .setUsername(dataSource.getUser()) + .setPassword(dataSource.getPassword()) + .setLocation(tableName) + .build(); + + ReadFromPostgresSchemaTransformProvider.JdbcReadSchemaTransformConfiguration readConfig = + ReadFromPostgresSchemaTransformProvider.JdbcReadSchemaTransformConfiguration.builder() + .setJdbcUrl(jdbcUrl) + .setUsername(dataSource.getUser()) + .setPassword(dataSource.getPassword()) + .setLocation(tableName) + .build(); + + try { + PCollection input = writePipeline.apply(Create.of(ROWS)).setRowSchema(INPUT_SCHEMA); + PCollectionRowTuple input_tuple = PCollectionRowTuple.of("input", input); + input_tuple.apply( + new WriteToPostgresSchemaTransformProvider.JdbcWriteSchemaTransform( + writeConfig, "postgres")); + writePipeline.run().waitUntilFinish(); + + PCollectionRowTuple pbegin_tuple = PCollectionRowTuple.empty(readPipeline); + PCollectionRowTuple output_tuple = + pbegin_tuple.apply( + new ReadFromPostgresSchemaTransformProvider.JdbcReadSchemaTransform( + readConfig, "postgres")); + PCollection output = output_tuple.get("output"); + PAssert.that(output).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } finally { + DatabaseTestHelper.deleteTable(dataSource, tableName); + } + } + + @Test + public void testManagedWriteThenManagedRead() throws SQLException { + String tableName = DatabaseTestHelper.getTestTableName("ManagedJdbcIOPostgresIT"); + DatabaseTestHelper.createTable(dataSource, tableName); + + Map writeConfig = + ImmutableMap.builder() + .put("jdbc_url", jdbcUrl) + .put("username", dataSource.getUser()) + .put("password", dataSource.getPassword()) + .put("location", tableName) + .build(); + + Map readConfig = + ImmutableMap.builder() + .put("jdbc_url", jdbcUrl) + .put("username", dataSource.getUser()) + .put("password", dataSource.getPassword()) + .put("location", tableName) + .build(); + + try { + PCollection input = 
writePipeline.apply(Create.of(ROWS)).setRowSchema(INPUT_SCHEMA); + input.apply(Managed.write(Managed.POSTGRES).withConfig(writeConfig)); + writePipeline.run().waitUntilFinish(); + + PCollectionRowTuple output = + readPipeline.apply(Managed.read(Managed.POSTGRES).withConfig(readConfig)); + PAssert.that(output.get("output")).containsInAnyOrder(ROWS); + readPipeline.run().waitUntilFinish(); + } finally { + DatabaseTestHelper.deleteTable(dataSource, tableName); + } + } +} From f60b1367cb39dbfadf4e71c172fff04fe2c12bc2 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 2 Sep 2025 22:50:26 -0400 Subject: [PATCH 04/12] Fix error in analyzeClassesDependencies gradle task --- sdks/java/io/jdbc/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle index 12f66428dcc4..b2b74f9cf005 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -29,6 +29,7 @@ ext.summary = "IO to read and write on JDBC datasource." dependencies { implementation library.java.vendored_guava_32_1_2_jre implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation project(path: ":model:pipeline", configuration: "shadow") implementation library.java.dbcp2 implementation library.java.joda_time implementation "org.apache.commons:commons-pool2:2.11.1" From 1c79e754296413033397b3f6810367fd1149bc8e Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 2 Sep 2025 23:20:56 -0400 Subject: [PATCH 05/12] Fix spotless failure. --- .../{ => beam/sdk/io/jdbc}/JdbcIOPostgresIT.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) rename sdks/java/io/jdbc/src/test/java/org/apache/{ => beam/sdk/io/jdbc}/JdbcIOPostgresIT.java (95%) diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java similarity index 95% rename from sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java rename to sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java index 9b8395c5455d..3a1aadb003bb 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/JdbcIOPostgresIT.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java @@ -124,18 +124,18 @@ public void testWriteThenRead() throws SQLException { try { PCollection input = writePipeline.apply(Create.of(ROWS)).setRowSchema(INPUT_SCHEMA); - PCollectionRowTuple input_tuple = PCollectionRowTuple.of("input", input); - input_tuple.apply( + PCollectionRowTuple inputTuple = PCollectionRowTuple.of("input", input); + inputTuple.apply( new WriteToPostgresSchemaTransformProvider.JdbcWriteSchemaTransform( writeConfig, "postgres")); writePipeline.run().waitUntilFinish(); - PCollectionRowTuple pbegin_tuple = PCollectionRowTuple.empty(readPipeline); - PCollectionRowTuple output_tuple = - pbegin_tuple.apply( + PCollectionRowTuple pbeginTuple = PCollectionRowTuple.empty(readPipeline); + PCollectionRowTuple outputTuple = + pbeginTuple.apply( new ReadFromPostgresSchemaTransformProvider.JdbcReadSchemaTransform( readConfig, "postgres")); - PCollection output = output_tuple.get("output"); + PCollection output = outputTuple.get("output"); PAssert.that(output).containsInAnyOrder(ROWS); readPipeline.run().waitUntilFinish(); } finally { From a2ebfea74a191b8cf9dcbebd039a0330a311b902 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 2 Sep 2025 23:24:23 -0400 Subject: [PATCH 06/12] Fix python lint --- 
sdks/python/apache_beam/transforms/external.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index c12ce985737e..59d9accdf06e 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -80,9 +80,9 @@ ManagedTransforms.Urns.KAFKA_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.KAFKA_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.BIGQUERY_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, - ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, + ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long ManagedTransforms.Urns.POSTGRES_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, - ManagedTransforms.Urns.POSTGRES_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, + ManagedTransforms.Urns.POSTGRES_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long } From 5c7b81709b365e6202c0c78dc891f41183df21d5 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 3 Sep 2025 13:14:32 -0400 Subject: [PATCH 07/12] Add schema transform translation for postgres read and write. --- .../jdbc/JdbcReadSchemaTransformProvider.java | 16 ++++ .../JdbcWriteSchemaTransformProvider.java | 16 ++++ .../PostgresSchemaTransformTranslation.java | 93 +++++++++++++++++++ ...adFromPostgresSchemaTransformProvider.java | 6 ++ ...riteToPostgresSchemaTransformProvider.java | 6 ++ .../beam/sdk/io/jdbc/JdbcIOPostgresIT.java | 6 +- 6 files changed, 139 insertions(+), 4 deletions(-) create mode 100644 sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslation.java diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java index 0c35f6955611..da75c9baaa45 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcReadSchemaTransformProvider.java @@ -27,6 +27,8 @@ import java.util.Objects; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; @@ -265,6 +267,20 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { } return PCollectionRowTuple.of("output", input.getPipeline().apply(readRows)); } + + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(JdbcReadSchemaTransformConfiguration.class) + .apply(config) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } } @Override diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java index 26eb3cc7e826..4dbb9b396f09 100644 --- 
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcWriteSchemaTransformProvider.java @@ -27,7 +27,9 @@ import java.util.Objects; import javax.annotation.Nullable; import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaRegistry; import org.apache.beam.sdk.schemas.annotations.DefaultSchema; import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; @@ -265,6 +267,20 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { .setRowSchema(Schema.of()); return PCollectionRowTuple.of("post_write", postWrite); } + + public Row getConfigurationRow() { + try { + // To stay consistent with our SchemaTransform configuration naming conventions, + // we sort lexicographically + return SchemaRegistry.createDefault() + .getToRowFunction(JdbcWriteSchemaTransformConfiguration.class) + .apply(config) + .sorted() + .toSnakeCase(); + } catch (NoSuchSchemaException e) { + throw new RuntimeException(e); + } + } } @Override diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslation.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslation.java new file mode 100644 index 000000000000..288b29642c5a --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslation.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.providers.ReadFromPostgresSchemaTransformProvider.PostgresReadSchemaTransform; +import static org.apache.beam.sdk.io.jdbc.providers.WriteToPostgresSchemaTransformProvider.PostgresWriteSchemaTransform; +import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +public class PostgresSchemaTransformTranslation { + static class PostgresReadSchemaTransformTranslator + extends SchemaTransformPayloadTranslator { + @Override + public SchemaTransformProvider provider() { + return new ReadFromPostgresSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(PostgresReadSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + ., PTransformTranslation.TransformPayloadTranslator>builder() + .put(PostgresReadSchemaTransform.class, new PostgresReadSchemaTransformTranslator()) + .build(); + } + } + + static class PostgresWriteSchemaTransformTranslator + extends SchemaTransformPayloadTranslator { + @Override + public SchemaTransformProvider provider() { + return new WriteToPostgresSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(PostgresWriteSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class, + ? 
extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + ., PTransformTranslation.TransformPayloadTranslator>builder() + .put(PostgresWriteSchemaTransform.class, new PostgresWriteSchemaTransformTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java index 317fcb67c5d1..68fe4f89cf8e 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java @@ -83,4 +83,10 @@ protected String jdbcType() { .build(); return super.from(configuration); } + + public static class PostgresReadSchemaTransform extends JdbcReadSchemaTransform { + public PostgresReadSchemaTransform(JdbcReadSchemaTransformConfiguration config) { + super(config, POSTGRES); + } + } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java index d61baf47fa0c..d5acd29b38d5 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java @@ -73,4 +73,10 @@ protected String jdbcType() { configuration = configuration.toBuilder().setConnectionInitSql(Collections.emptyList()).build(); return super.from(configuration); } + + public static class PostgresWriteSchemaTransform extends JdbcWriteSchemaTransform { + public PostgresWriteSchemaTransform(JdbcWriteSchemaTransformConfiguration config) { + super(config, POSTGRES); + } + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java index 3a1aadb003bb..9fd5a7b3d68b 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOPostgresIT.java @@ -126,15 +126,13 @@ public void testWriteThenRead() throws SQLException { PCollection input = writePipeline.apply(Create.of(ROWS)).setRowSchema(INPUT_SCHEMA); PCollectionRowTuple inputTuple = PCollectionRowTuple.of("input", input); inputTuple.apply( - new WriteToPostgresSchemaTransformProvider.JdbcWriteSchemaTransform( - writeConfig, "postgres")); + new WriteToPostgresSchemaTransformProvider.PostgresWriteSchemaTransform(writeConfig)); writePipeline.run().waitUntilFinish(); PCollectionRowTuple pbeginTuple = PCollectionRowTuple.empty(readPipeline); PCollectionRowTuple outputTuple = pbeginTuple.apply( - new ReadFromPostgresSchemaTransformProvider.JdbcReadSchemaTransform( - readConfig, "postgres")); + new ReadFromPostgresSchemaTransformProvider.PostgresReadSchemaTransform(readConfig)); PCollection output = outputTuple.get("output"); PAssert.that(output).containsInAnyOrder(ROWS); readPipeline.run().waitUntilFinish(); From 14b097b79e79577e3e8f7e6b2d68cd8dad7e5069 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 3 Sep 2025 14:36:06 -0400 Subject: [PATCH 08/12] Add test for postgres schema transform 
translation. --- sdks/java/io/jdbc/build.gradle | 1 + ...adFromPostgresSchemaTransformProvider.java | 3 +- ...riteToPostgresSchemaTransformProvider.java | 3 +- ...ostgresSchemaTransformTranslationTest.java | 233 ++++++++++++++++++ 4 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslationTest.java diff --git a/sdks/java/io/jdbc/build.gradle b/sdks/java/io/jdbc/build.gradle index b2b74f9cf005..87a231a5a42b 100644 --- a/sdks/java/io/jdbc/build.gradle +++ b/sdks/java/io/jdbc/build.gradle @@ -43,6 +43,7 @@ dependencies { testImplementation project(path: ":sdks:java:managed") testImplementation project(path: ":sdks:java:testing:test-utils") testImplementation library.java.junit + testImplementation library.java.mockito_inline testImplementation library.java.slf4j_api testImplementation library.java.postgres diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java index 68fe4f89cf8e..68a3eeef83c2 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromPostgresSchemaTransformProvider.java @@ -81,12 +81,13 @@ protected String jdbcType() { .setConnectionInitSql(Collections.emptyList()) .setDisableAutoCommit(true) .build(); - return super.from(configuration); + return new PostgresReadSchemaTransform(configuration); } public static class PostgresReadSchemaTransform extends JdbcReadSchemaTransform { public PostgresReadSchemaTransform(JdbcReadSchemaTransformConfiguration config) { super(config, POSTGRES); + config.validate(POSTGRES); } } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java index d5acd29b38d5..1b1e225481a4 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToPostgresSchemaTransformProvider.java @@ -71,12 +71,13 @@ protected String jdbcType() { // Override "connectionInitSql" for postgres configuration = configuration.toBuilder().setConnectionInitSql(Collections.emptyList()).build(); - return super.from(configuration); + return new PostgresWriteSchemaTransform(configuration); } public static class PostgresWriteSchemaTransform extends JdbcWriteSchemaTransform { public PostgresWriteSchemaTransform(JdbcWriteSchemaTransformConfiguration config) { super(config, POSTGRES); + config.validate(POSTGRES); } } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslationTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslationTest.java new file mode 100644 index 000000000000..503baaefc334 --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/PostgresSchemaTransformTranslationTest.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; +import static org.apache.beam.sdk.io.jdbc.providers.PostgresSchemaTransformTranslation.PostgresReadSchemaTransformTranslator; +import static org.apache.beam.sdk.io.jdbc.providers.PostgresSchemaTransformTranslation.PostgresWriteSchemaTransformTranslator; +import static org.apache.beam.sdk.io.jdbc.providers.ReadFromPostgresSchemaTransformProvider.PostgresReadSchemaTransform; +import static org.apache.beam.sdk.io.jdbc.providers.WriteToPostgresSchemaTransformProvider.PostgresWriteSchemaTransform; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.io.jdbc.JdbcIO; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.BeamUrns; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +public class PostgresSchemaTransformTranslationTest { + @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); + + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + static final WriteToPostgresSchemaTransformProvider WRITE_PROVIDER = + new WriteToPostgresSchemaTransformProvider(); + static final ReadFromPostgresSchemaTransformProvider READ_PROVIDER = + new ReadFromPostgresSchemaTransformProvider(); + + static final Row READ_CONFIG = + Row.withSchema(READ_PROVIDER.configurationSchema()) + .withFieldValue("jdbc_url", "jdbc:postgresql://host:port/database") + .withFieldValue("location", "test_table") + .withFieldValue("connection_properties", "some_property") + .withFieldValue("connection_init_sql", ImmutableList.builder().build()) + .withFieldValue("driver_class_name", null) + 
.withFieldValue("driver_jars", null) + .withFieldValue("disable_auto_commit", true) + .withFieldValue("fetch_size", 10) + .withFieldValue("num_partitions", 5) + .withFieldValue("output_parallelization", true) + .withFieldValue("partition_column", "col") + .withFieldValue("read_query", null) + .withFieldValue("username", "my_user") + .withFieldValue("password", "my_pass") + .build(); + + static final Row WRITE_CONFIG = + Row.withSchema(WRITE_PROVIDER.configurationSchema()) + .withFieldValue("jdbc_url", "jdbc:postgresql://host:port/database") + .withFieldValue("location", "test_table") + .withFieldValue("autosharding", true) + .withFieldValue("connection_init_sql", ImmutableList.builder().build()) + .withFieldValue("connection_properties", "some_property") + .withFieldValue("driver_class_name", null) + .withFieldValue("driver_jars", null) + .withFieldValue("batch_size", 100L) + .withFieldValue("username", "my_user") + .withFieldValue("password", "my_pass") + .withFieldValue("write_statement", null) + .build(); + + @Test + public void testRecreateWriteTransformFromRow() { + PostgresWriteSchemaTransform writeTransform = + (PostgresWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG); + + PostgresWriteSchemaTransformTranslator translator = + new PostgresWriteSchemaTransformTranslator(); + Row translatedRow = translator.toConfigRow(writeTransform); + + PostgresWriteSchemaTransform writeTransformFromRow = + translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG, writeTransformFromRow.getConfigurationRow()); + } + + @Test + public void testWriteTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + Schema inputSchema = Schema.builder().addStringField("name").build(); + PCollection input = + p.apply( + Create.of( + Collections.singletonList( + Row.withSchema(inputSchema).addValue("test").build()))) + .setRowSchema(inputSchema); + + PostgresWriteSchemaTransform writeTransform = + (PostgresWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG); + PCollectionRowTuple.of("input", input).apply(writeTransform); + + // Then translate the pipeline to a proto and extract PostgresWriteSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(WRITE_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, writeTransformProto.size()); + RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + + assertEquals(WRITE_CONFIG, rowFromSpec); + + // Use the information in the proto to recreate the PostgresWriteSchemaTransform + PostgresWriteSchemaTransformTranslator translator = + new 
PostgresWriteSchemaTransformTranslator(); + PostgresWriteSchemaTransform writeTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG, writeTransformFromSpec.getConfigurationRow()); + } + + @Test + public void testReCreateReadTransformFromRow() { + // setting a subset of fields here. + PostgresReadSchemaTransform readTransform = + (PostgresReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG); + + PostgresReadSchemaTransformTranslator translator = new PostgresReadSchemaTransformTranslator(); + Row row = translator.toConfigRow(readTransform); + + PostgresReadSchemaTransform readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG, readTransformFromRow.getConfigurationRow()); + } + + @Test + public void testReadTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + + PostgresReadSchemaTransform readTransform = + (PostgresReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG); + + // Mock inferBeamSchema since it requires database connection. + Schema expectedSchema = Schema.builder().addStringField("name").build(); + try (MockedStatic mock = Mockito.mockStatic(JdbcIO.ReadRows.class)) { + mock.when(() -> JdbcIO.ReadRows.inferBeamSchema(Mockito.any(), Mockito.any())) + .thenReturn(expectedSchema); + PCollectionRowTuple.empty(p).apply(readTransform); + } + + // Then translate the pipeline to a proto and extract PostgresReadSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List readTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(READ_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, readTransformProto.size()); + RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + assertEquals(READ_CONFIG, rowFromSpec); + + // Use the information in the proto to recreate the PostgresReadSchemaTransform + PostgresReadSchemaTransformTranslator translator = new PostgresReadSchemaTransformTranslator(); + PostgresReadSchemaTransform readTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG, readTransformFromSpec.getConfigurationRow()); + } +} From 7a0962364ca0478dce5c66517f87ef6b6834a20d Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 3 Sep 2025 10:13:47 -0400 Subject: [PATCH 09/12] Add mysql read to managed io --- .../pipeline/v1/external_transforms.proto | 2 ++ .../ReadFromMySqlSchemaTransformProvider.java | 31 ++++++++++++++++++- .../org/apache/beam/sdk/managed/Managed.java | 2 ++ .../python/apache_beam/transforms/external.py | 1 + sdks/python/apache_beam/transforms/managed.py | 2 ++ 5 files changed, 
37 insertions(+), 1 deletion(-) diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index 02a5dd18e2c6..afe82382a3dd 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -80,6 +80,8 @@ message ManagedTransforms { "beam:schematransform:org.apache.beam:postgres_read:v1"]; POSTGRES_WRITE = 8 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:postgres_write:v1"]; + MYSQL_READ = 9 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:mysql_read:v1"]; } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java index 3d0135ef8ecd..0ca3064f9528 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java @@ -18,20 +18,28 @@ package org.apache.beam.sdk.io.jdbc.providers; import static org.apache.beam.sdk.io.jdbc.JdbcUtil.MYSQL; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class ReadFromMySqlSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + private static final Logger LOG = + LoggerFactory.getLogger(ReadFromMySqlSchemaTransformProvider.class); + @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:mysql_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.MYSQL_READ); } @Override @@ -43,4 +51,25 @@ public String description() { protected String jdbcType() { return MYSQL; } + + @Override + public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + JdbcReadSchemaTransformConfiguration configuration) { + String jdbcType = configuration.getJdbcType(); + if (jdbcType != null && !jdbcType.equals(jdbcType())) { + throw new IllegalArgumentException( + String.format("Wrong JDBC type. Expected '%s' but got '%s'", jdbcType(), jdbcType)); + } + + Integer fetchSize = configuration.getFetchSize(); + if (fetchSize != null + && fetchSize > 0 + && configuration.getJdbcUrl() != null + && !configuration.getJdbcUrl().contains("useCursorFetch=true")) { + LOG.warn( + "The fetchSize option is ignored. 
It is required to set useCursorFetch=true" + + " in the JDBC URL when using fetchSize for MySQL"); + } + return super.from(configuration); + } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index cda84629a7d7..37ea11dd68c7 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -97,6 +97,7 @@ public class Managed { public static final String KAFKA = "kafka"; public static final String BIGQUERY = "bigquery"; public static final String POSTGRES = "postgres"; + public static final String MYSQL = "mysql"; // Supported SchemaTransforms public static final Map READ_TRANSFORMS = @@ -106,6 +107,7 @@ public class Managed { .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_READ)) .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ)) .put(POSTGRES, getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_READ)) + .put(MYSQL, getUrn(ExternalTransforms.ManagedTransforms.Urns.MYSQL_READ)) .build(); public static final Map WRITE_TRANSFORMS = ImmutableMap.builder() diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index b22ed6e0c645..463d669f3649 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -83,6 +83,7 @@ ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long ManagedTransforms.Urns.POSTGRES_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.POSTGRES_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long + ManagedTransforms.Urns.MYSQL_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, } diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index 72dfb6fd9a0a..45c3b2a40acc 100644 --- a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py @@ -86,6 +86,7 @@ KAFKA = "kafka" BIGQUERY = "bigquery" POSTGRES = "postgres" +MYSQL = "mysql" __all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"] @@ -98,6 +99,7 @@ class Read(PTransform): KAFKA: ManagedTransforms.Urns.KAFKA_READ.urn, BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn, POSTGRES: ManagedTransforms.Urns.POSTGRES_READ.urn, + MYSQL: ManagedTransforms.Urns.MYSQL_READ.urn, } def __init__( From f3e2edc9a39cd014aa00dd9b116c92569beade0b Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 3 Sep 2025 10:23:14 -0400 Subject: [PATCH 10/12] Add mysql write to managed io --- .../model/pipeline/v1/external_transforms.proto | 2 ++ .../WriteToMySqlSchemaTransformProvider.java | 16 +++++++++++++++- .../org/apache/beam/sdk/managed/Managed.java | 1 + sdks/python/apache_beam/transforms/external.py | 1 + sdks/python/apache_beam/transforms/managed.py | 1 + 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index afe82382a3dd..31232eb60671 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -82,6 +82,8 @@ message ManagedTransforms { 
"beam:schematransform:org.apache.beam:postgres_write:v1"]; MYSQL_READ = 9 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:mysql_read:v1"]; + MYSQL_WRITE = 10 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:mysql_write:v1"]; } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java index 57f085220162..9c234dfed988 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java @@ -18,9 +18,12 @@ package org.apache.beam.sdk.io.jdbc.providers; import static org.apache.beam.sdk.io.jdbc.JdbcUtil.MYSQL; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; @@ -31,7 +34,7 @@ public class WriteToMySqlSchemaTransformProvider extends JdbcWriteSchemaTransfor @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:mysql_write:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.MYSQL_WRITE); } @Override @@ -43,4 +46,15 @@ public String description() { protected String jdbcType() { return MYSQL; } + + @Override + public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + JdbcWriteSchemaTransformConfiguration configuration) { + String jdbcType = configuration.getJdbcType(); + if (jdbcType != null && !jdbcType.equals(jdbcType())) { + throw new IllegalArgumentException( + String.format("Wrong JDBC type. 
Expected '%s' but got '%s'", jdbcType(), jdbcType)); + } + return super.from(configuration); + } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index 37ea11dd68c7..4f45eeac861e 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -115,6 +115,7 @@ public class Managed { .put(KAFKA, getUrn(ExternalTransforms.ManagedTransforms.Urns.KAFKA_WRITE)) .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE)) .put(POSTGRES, getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_WRITE)) + .put(MYSQL, getUrn(ExternalTransforms.ManagedTransforms.Urns.MYSQL_WRITE)) .build(); /** diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index 463d669f3649..3f9f56a54139 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -84,6 +84,7 @@ ManagedTransforms.Urns.POSTGRES_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, ManagedTransforms.Urns.POSTGRES_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long ManagedTransforms.Urns.MYSQL_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, + ManagedTransforms.Urns.MYSQL_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET, } diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index 45c3b2a40acc..03449236ac92 100644 --- a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py @@ -142,6 +142,7 @@ class Write(PTransform): KAFKA: ManagedTransforms.Urns.KAFKA_WRITE.urn, BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn, POSTGRES: ManagedTransforms.Urns.POSTGRES_WRITE.urn, + MYSQL: ManagedTransforms.Urns.MYSQL_WRITE.urn, } def __init__( From d15e5b8ebcbb15133b2304989c442e30b4b9cc19 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 3 Sep 2025 20:52:08 -0400 Subject: [PATCH 11/12] Add schema transform translation and test for mysql read and write --- .../MySqlSchemaTransformTranslation.java | 93 +++++++ .../ReadFromMySqlSchemaTransformProvider.java | 9 +- .../WriteToMySqlSchemaTransformProvider.java | 9 +- .../MysqlSchemaTransformTranslationTest.java | 231 ++++++++++++++++++ 4 files changed, 340 insertions(+), 2 deletions(-) create mode 100644 sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/MySqlSchemaTransformTranslation.java create mode 100644 sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/MysqlSchemaTransformTranslationTest.java diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/MySqlSchemaTransformTranslation.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/MySqlSchemaTransformTranslation.java new file mode 100644 index 000000000000..3367248b7198 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/MySqlSchemaTransformTranslation.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.providers.ReadFromMySqlSchemaTransformProvider.MySqlReadSchemaTransform; +import static org.apache.beam.sdk.io.jdbc.providers.WriteToMySqlSchemaTransformProvider.MySqlWriteSchemaTransform; +import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +public class MySqlSchemaTransformTranslation { + static class MySqlReadSchemaTransformTranslator + extends SchemaTransformPayloadTranslator<MySqlReadSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new ReadFromMySqlSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(MySqlReadSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class<? extends PTransform>, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder() + .put(MySqlReadSchemaTransform.class, new MySqlReadSchemaTransformTranslator()) + .build(); + } + } + + static class MySqlWriteSchemaTransformTranslator + extends SchemaTransformPayloadTranslator<MySqlWriteSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new WriteToMySqlSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(MySqlWriteSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class<? extends PTransform>, + ? 
extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder() + .put(MySqlWriteSchemaTransform.class, new MySqlWriteSchemaTransformTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java index 0ca3064f9528..2bf6928b5eb0 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromMySqlSchemaTransformProvider.java @@ -70,6 +70,13 @@ protected String jdbcType() { "The fetchSize option is ignored. It is required to set useCursorFetch=true" + " in the JDBC URL when using fetchSize for MySQL"); - return super.from(configuration); + return new MySqlReadSchemaTransform(configuration); + } + + public static class MySqlReadSchemaTransform extends JdbcReadSchemaTransform { + public MySqlReadSchemaTransform(JdbcReadSchemaTransformConfiguration config) { + super(config, MYSQL); + config.validate(MYSQL); + } } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java index 9c234dfed988..a283a64d29c4 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToMySqlSchemaTransformProvider.java @@ -55,6 +55,13 @@ protected String jdbcType() { throw new IllegalArgumentException( String.format("Wrong JDBC type. Expected '%s' but got '%s'", jdbcType(), jdbcType)); - return super.from(configuration); + return new MySqlWriteSchemaTransform(configuration); + } + + public static class MySqlWriteSchemaTransform extends JdbcWriteSchemaTransform { + public MySqlWriteSchemaTransform(JdbcWriteSchemaTransformConfiguration config) { + super(config, MYSQL); + config.validate(MYSQL); + } } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/MysqlSchemaTransformTranslationTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/MysqlSchemaTransformTranslationTest.java new file mode 100644 index 000000000000..dbcd3cecf8bc --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/MysqlSchemaTransformTranslationTest.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; +import static org.apache.beam.sdk.io.jdbc.providers.MySqlSchemaTransformTranslation.MySqlReadSchemaTransformTranslator; +import static org.apache.beam.sdk.io.jdbc.providers.MySqlSchemaTransformTranslation.MySqlWriteSchemaTransformTranslator; +import static org.apache.beam.sdk.io.jdbc.providers.ReadFromMySqlSchemaTransformProvider.MySqlReadSchemaTransform; +import static org.apache.beam.sdk.io.jdbc.providers.WriteToMySqlSchemaTransformProvider.MySqlWriteSchemaTransform; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.io.jdbc.JdbcIO; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.BeamUrns; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +public class MysqlSchemaTransformTranslationTest { + @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); + + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + static final WriteToMySqlSchemaTransformProvider WRITE_PROVIDER = + new WriteToMySqlSchemaTransformProvider(); + static final ReadFromMySqlSchemaTransformProvider READ_PROVIDER = + new ReadFromMySqlSchemaTransformProvider(); + + static final Row READ_CONFIG = + Row.withSchema(READ_PROVIDER.configurationSchema()) + .withFieldValue("jdbc_url", "jdbc:mysql://host:port/database") + .withFieldValue("location", "test_table") + .withFieldValue("connection_properties", "some_property") + .withFieldValue("connection_init_sql", ImmutableList.<String>builder().build()) + .withFieldValue("driver_class_name", null) + .withFieldValue("driver_jars", null) + .withFieldValue("disable_auto_commit", true) + .withFieldValue("fetch_size", 10) + .withFieldValue("num_partitions", 5) + .withFieldValue("output_parallelization", true) + .withFieldValue("partition_column", "col") + .withFieldValue("read_query", null) + .withFieldValue("username", "my_user") + .withFieldValue("password", "my_pass") + .build(); + + static final Row WRITE_CONFIG = + Row.withSchema(WRITE_PROVIDER.configurationSchema()) + .withFieldValue("jdbc_url", "jdbc:mysql://host:port/database") + .withFieldValue("location", "test_table") + .withFieldValue("autosharding", true) + .withFieldValue("connection_init_sql", ImmutableList.<String>builder().build()) + .withFieldValue("connection_properties", "some_property") + 
.withFieldValue("driver_class_name", null) + .withFieldValue("driver_jars", null) + .withFieldValue("batch_size", 100L) + .withFieldValue("username", "my_user") + .withFieldValue("password", "my_pass") + .withFieldValue("write_statement", null) + .build(); + + @Test + public void testRecreateWriteTransformFromRow() { + MySqlWriteSchemaTransform writeTransform = + (MySqlWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG); + + MySqlWriteSchemaTransformTranslator translator = new MySqlWriteSchemaTransformTranslator(); + Row translatedRow = translator.toConfigRow(writeTransform); + + MySqlWriteSchemaTransform writeTransformFromRow = + translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG, writeTransformFromRow.getConfigurationRow()); + } + + @Test + public void testWriteTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + Schema inputSchema = Schema.builder().addStringField("name").build(); + PCollection<Row> input = + p.apply( + Create.of( + Collections.singletonList( + Row.withSchema(inputSchema).addValue("test").build()))) + .setRowSchema(inputSchema); + + MySqlWriteSchemaTransform writeTransform = + (MySqlWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG); + PCollectionRowTuple.of("input", input).apply(writeTransform); + + // Then translate the pipeline to a proto and extract MySqlWriteSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List<RunnerApi.PTransform> writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(WRITE_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, writeTransformProto.size()); + RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + + assertEquals(WRITE_CONFIG, rowFromSpec); + + // Use the information in the proto to recreate the MySqlWriteSchemaTransform + MySqlWriteSchemaTransformTranslator translator = new MySqlWriteSchemaTransformTranslator(); + MySqlWriteSchemaTransform writeTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG, writeTransformFromSpec.getConfigurationRow()); + } + + @Test + public void testReCreateReadTransformFromRow() { + // READ_CONFIG (defined above) sets a subset of the configuration fields. 
+ MySqlReadSchemaTransform readTransform = + (MySqlReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG); + + MySqlReadSchemaTransformTranslator translator = new MySqlReadSchemaTransformTranslator(); + Row row = translator.toConfigRow(readTransform); + + MySqlReadSchemaTransform readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG, readTransformFromRow.getConfigurationRow()); + } + + @Test + public void testReadTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + + MySqlReadSchemaTransform readTransform = + (MySqlReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG); + + // Mock inferBeamSchema since it requires database connection. + Schema expectedSchema = Schema.builder().addStringField("name").build(); + try (MockedStatic<JdbcIO.ReadRows> mock = Mockito.mockStatic(JdbcIO.ReadRows.class)) { + mock.when(() -> JdbcIO.ReadRows.inferBeamSchema(Mockito.any(), Mockito.any())) + .thenReturn(expectedSchema); + PCollectionRowTuple.empty(p).apply(readTransform); + } + + // Then translate the pipeline to a proto and extract MySqlReadSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List<RunnerApi.PTransform> readTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(READ_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, readTransformProto.size()); + RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + assertEquals(READ_CONFIG, rowFromSpec); + + // Use the information in the proto to recreate the MySqlReadSchemaTransform + MySqlReadSchemaTransformTranslator translator = new MySqlReadSchemaTransformTranslator(); + MySqlReadSchemaTransform readTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG, readTransformFromSpec.getConfigurationRow()); + } +} From 94c8f701af685d7c48e75aeea3ee62c9193ecdad Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Thu, 4 Sep 2025 14:24:42 -0400 Subject: [PATCH 12/12] Add oracle read and write to managed io. TransformTranslation and test are also included. 
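For context (not part of the change itself), a minimal Java sketch of how the new managed Oracle read is intended to be used. The connection values are placeholders, and the config keys ("jdbc_url", "location", "username", "password") are the snake_case fields of JdbcReadSchemaTransformConfiguration exercised by the tests in this commit; an Oracle JDBC driver still has to be provided at runtime:

    // Hypothetical usage sketch; assumes the usual Beam imports
    // (Managed, PCollectionRowTuple, PCollection, Row, ImmutableMap).
    Pipeline p = Pipeline.create();
    PCollection<Row> rows =
        PCollectionRowTuple.empty(p)
            .apply(
                Managed.read(Managed.ORACLE)
                    .withConfig(
                        ImmutableMap.<String, Object>of(
                            "jdbc_url", "jdbc:oracle:thin:@host:1521/service",
                            "location", "my_table",
                            "username", "my_user",
                            "password", "my_pass")))
            .getSinglePCollection();

The Python entry point added below should accept the same keys, roughly managed.Read(managed.ORACLE, config={...}), per the Read signature in managed.py.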
--- .../pipeline/v1/external_transforms.proto | 4 + .../OracleSchemaTransformTranslation.java | 93 +++++++ ...ReadFromOracleSchemaTransformProvider.java | 39 ++- .../WriteToOracleSchemaTransformProvider.java | 39 ++- .../OracleSchemaTransformTranslationTest.java | 231 ++++++++++++++++++ .../org/apache/beam/sdk/managed/Managed.java | 3 + sdks/python/apache_beam/transforms/managed.py | 3 + 7 files changed, 410 insertions(+), 2 deletions(-) create mode 100644 sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslation.java create mode 100644 sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslationTest.java diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto index 31232eb60671..3607d81bde8c 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/external_transforms.proto @@ -84,6 +84,10 @@ message ManagedTransforms { "beam:schematransform:org.apache.beam:mysql_read:v1"]; MYSQL_WRITE = 10 [(org.apache.beam.model.pipeline.v1.beam_urn) = "beam:schematransform:org.apache.beam:mysql_write:v1"]; + ORACLE_READ = 11 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:oracle_read:v1"]; + ORACLE_WRITE = 12 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:schematransform:org.apache.beam:oracle_write:v1"]; } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslation.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslation.java new file mode 100644 index 000000000000..7cb0abb040fa --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslation.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.sdk.io.jdbc.providers.ReadFromOracleSchemaTransformProvider.OracleReadSchemaTransform; +import static org.apache.beam.sdk.io.jdbc.providers.WriteToOracleSchemaTransformProvider.OracleWriteSchemaTransform; +import static org.apache.beam.sdk.schemas.transforms.SchemaTransformTranslation.SchemaTransformPayloadTranslator; + +import com.google.auto.service.AutoService; +import java.util.Map; +import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.util.construction.PTransformTranslation; +import org.apache.beam.sdk.util.construction.TransformPayloadTranslatorRegistrar; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; + +public class OracleSchemaTransformTranslation { + static class OracleReadSchemaTransformTranslator + extends SchemaTransformPayloadTranslator<OracleReadSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new ReadFromOracleSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(OracleReadSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class ReadRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class<? extends PTransform>, + ? extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder() + .put(OracleReadSchemaTransform.class, new OracleReadSchemaTransformTranslator()) + .build(); + } + } + + static class OracleWriteSchemaTransformTranslator + extends SchemaTransformPayloadTranslator<OracleWriteSchemaTransform> { + @Override + public SchemaTransformProvider provider() { + return new WriteToOracleSchemaTransformProvider(); + } + + @Override + public Row toConfigRow(OracleWriteSchemaTransform transform) { + return transform.getConfigurationRow(); + } + } + + @AutoService(TransformPayloadTranslatorRegistrar.class) + public static class WriteRegistrar implements TransformPayloadTranslatorRegistrar { + @Override + @SuppressWarnings({ + "rawtypes", + }) + public Map< + ? extends Class<? extends PTransform>, + ? 
extends PTransformTranslation.TransformPayloadTranslator> + getTransformPayloadTranslators() { + return ImmutableMap + .<Class<? extends PTransform>, PTransformTranslation.TransformPayloadTranslator>builder() + .put(OracleWriteSchemaTransform.class, new OracleWriteSchemaTransformTranslator()) + .build(); + } + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java index de18d5aa8189..8f2320f9e2ba 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/ReadFromOracleSchemaTransformProvider.java @@ -18,20 +18,30 @@ package org.apache.beam.sdk.io.jdbc.providers; import static org.apache.beam.sdk.io.jdbc.JdbcUtil.ORACLE; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.jdbc.JdbcReadSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class ReadFromOracleSchemaTransformProvider extends JdbcReadSchemaTransformProvider { + private static final Logger LOG = + LoggerFactory.getLogger(ReadFromOracleSchemaTransformProvider.class); + @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:oracle_read:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.ORACLE_READ); } @Override @@ -43,4 +53,31 @@ public String description() { protected String jdbcType() { return ORACLE; } + + @Override + public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + JdbcReadSchemaTransformConfiguration configuration) { + String jdbcType = configuration.getJdbcType(); + if (jdbcType != null && !jdbcType.equals(jdbcType())) { + throw new IllegalArgumentException( + String.format("Wrong JDBC type. 
Expected '%s' but got '%s'", jdbcType(), jdbcType)); + } + + List<@org.checkerframework.checker.nullness.qual.Nullable String> connectionInitSql = + configuration.getConnectionInitSql(); + if (connectionInitSql != null && !connectionInitSql.isEmpty()) { + LOG.warn("Oracle does not support connectionInitSql, ignoring."); + } + + // Override "connectionInitSql" for oracle + configuration = configuration.toBuilder().setConnectionInitSql(Collections.emptyList()).build(); + return new OracleReadSchemaTransform(configuration); + } + + public static class OracleReadSchemaTransform extends JdbcReadSchemaTransform { + public OracleReadSchemaTransform(JdbcReadSchemaTransformConfiguration config) { + super(config, ORACLE); + config.validate(ORACLE); + } + } } diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java index 5b3ae2c35e9d..2e8a7c8cd83c 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/providers/WriteToOracleSchemaTransformProvider.java @@ -18,20 +18,30 @@ package org.apache.beam.sdk.io.jdbc.providers; import static org.apache.beam.sdk.io.jdbc.JdbcUtil.ORACLE; +import static org.apache.beam.sdk.util.construction.BeamUrns.getUrn; import com.google.auto.service.AutoService; +import java.util.Collections; +import java.util.List; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; import org.apache.beam.sdk.io.jdbc.JdbcWriteSchemaTransformProvider; +import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.checkerframework.checker.initialization.qual.Initialized; import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; @AutoService(SchemaTransformProvider.class) public class WriteToOracleSchemaTransformProvider extends JdbcWriteSchemaTransformProvider { + private static final Logger LOG = + LoggerFactory.getLogger(WriteToOracleSchemaTransformProvider.class); + @Override public @UnknownKeyFor @NonNull @Initialized String identifier() { - return "beam:schematransform:org.apache.beam:oracle_write:v1"; + return getUrn(ExternalTransforms.ManagedTransforms.Urns.ORACLE_WRITE); } @Override @@ -43,4 +53,31 @@ public String description() { protected String jdbcType() { return ORACLE; } + + @Override + public @UnknownKeyFor @NonNull @Initialized SchemaTransform from( + JdbcWriteSchemaTransformConfiguration configuration) { + String jdbcType = configuration.getJdbcType(); + if (jdbcType != null && !jdbcType.equals(jdbcType())) { + throw new IllegalArgumentException( + String.format("Wrong JDBC type. 
Expected '%s' but got '%s'", jdbcType(), jdbcType)); + } + + List<@org.checkerframework.checker.nullness.qual.Nullable String> connectionInitSql = + configuration.getConnectionInitSql(); + if (connectionInitSql != null && !connectionInitSql.isEmpty()) { + LOG.warn("Oracle does not support connectionInitSql, ignoring."); + } + + // Override "connectionInitSql" for oracle + configuration = configuration.toBuilder().setConnectionInitSql(Collections.emptyList()).build(); + return new OracleWriteSchemaTransform(configuration); + } + + public static class OracleWriteSchemaTransform extends JdbcWriteSchemaTransform { + public OracleWriteSchemaTransform(JdbcWriteSchemaTransformConfiguration config) { + super(config, ORACLE); + config.validate(ORACLE); + } + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslationTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslationTest.java new file mode 100644 index 000000000000..4106a658ce4e --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/providers/OracleSchemaTransformTranslationTest.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.jdbc.providers; + +import static org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods.Enum.SCHEMA_TRANSFORM; +import static org.apache.beam.sdk.io.jdbc.providers.OracleSchemaTransformTranslation.OracleReadSchemaTransformTranslator; +import static org.apache.beam.sdk.io.jdbc.providers.OracleSchemaTransformTranslation.OracleWriteSchemaTransformTranslator; +import static org.apache.beam.sdk.io.jdbc.providers.ReadFromOracleSchemaTransformProvider.OracleReadSchemaTransform; +import static org.apache.beam.sdk.io.jdbc.providers.WriteToOracleSchemaTransformProvider.OracleWriteSchemaTransform; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.SchemaTransformPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.io.jdbc.JdbcIO; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.util.construction.BeamUrns; +import org.apache.beam.sdk.util.construction.PipelineTranslation; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.rules.TemporaryFolder; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +public class OracleSchemaTransformTranslationTest { + @ClassRule public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); + + @Rule public transient ExpectedException thrown = ExpectedException.none(); + + static final WriteToOracleSchemaTransformProvider WRITE_PROVIDER = + new WriteToOracleSchemaTransformProvider(); + static final ReadFromOracleSchemaTransformProvider READ_PROVIDER = + new ReadFromOracleSchemaTransformProvider(); + + static final Row READ_CONFIG = + Row.withSchema(READ_PROVIDER.configurationSchema()) + .withFieldValue("jdbc_url", "jdbc:oracle:thin:@host:port/database") + .withFieldValue("location", "test_table") + .withFieldValue("connection_properties", "some_property") + .withFieldValue("connection_init_sql", ImmutableList.<String>builder().build()) + .withFieldValue("driver_class_name", null) + .withFieldValue("driver_jars", null) + .withFieldValue("disable_auto_commit", true) + .withFieldValue("fetch_size", 10) + .withFieldValue("num_partitions", 5) + .withFieldValue("output_parallelization", true) + .withFieldValue("partition_column", "col") + .withFieldValue("read_query", null) + .withFieldValue("username", "my_user") + .withFieldValue("password", "my_pass") + .build(); + + static final Row WRITE_CONFIG = + Row.withSchema(WRITE_PROVIDER.configurationSchema()) + .withFieldValue("jdbc_url", "jdbc:oracle:thin:@host:port/database") + .withFieldValue("location", "test_table") + .withFieldValue("autosharding", true) + .withFieldValue("connection_init_sql", ImmutableList.<String>builder().build()) + .withFieldValue("connection_properties", 
"some_property") + .withFieldValue("driver_class_name", null) + .withFieldValue("driver_jars", null) + .withFieldValue("batch_size", 100L) + .withFieldValue("username", "my_user") + .withFieldValue("password", "my_pass") + .withFieldValue("write_statement", null) + .build(); + + @Test + public void testRecreateWriteTransformFromRow() { + OracleWriteSchemaTransform writeTransform = + (OracleWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG); + + OracleWriteSchemaTransformTranslator translator = new OracleWriteSchemaTransformTranslator(); + Row translatedRow = translator.toConfigRow(writeTransform); + + OracleWriteSchemaTransform writeTransformFromRow = + translator.fromConfigRow(translatedRow, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG, writeTransformFromRow.getConfigurationRow()); + } + + @Test + public void testWriteTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + Schema inputSchema = Schema.builder().addStringField("name").build(); + PCollection input = + p.apply( + Create.of( + Collections.singletonList( + Row.withSchema(inputSchema).addValue("test").build()))) + .setRowSchema(inputSchema); + + OracleWriteSchemaTransform writeTransform = + (OracleWriteSchemaTransform) WRITE_PROVIDER.from(WRITE_CONFIG); + PCollectionRowTuple.of("input", input).apply(writeTransform); + + // Then translate the pipeline to a proto and extract OracleWriteSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List writeTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(WRITE_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, writeTransformProto.size()); + RunnerApi.FunctionSpec spec = writeTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(WRITE_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + + assertEquals(WRITE_CONFIG, rowFromSpec); + + // Use the information in the proto to recreate the OracleWriteSchemaTransform + OracleWriteSchemaTransformTranslator translator = new OracleWriteSchemaTransformTranslator(); + OracleWriteSchemaTransform writeTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(WRITE_CONFIG, writeTransformFromSpec.getConfigurationRow()); + } + + @Test + public void testReCreateReadTransformFromRow() { + // setting a subset of fields here. 
+ OracleReadSchemaTransform readTransform = + (OracleReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG); + + OracleReadSchemaTransformTranslator translator = new OracleReadSchemaTransformTranslator(); + Row row = translator.toConfigRow(readTransform); + + OracleReadSchemaTransform readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG, readTransformFromRow.getConfigurationRow()); + } + + @Test + public void testReadTransformProtoTranslation() + throws InvalidProtocolBufferException, IOException { + // First build a pipeline + Pipeline p = Pipeline.create(); + + OracleReadSchemaTransform readTransform = + (OracleReadSchemaTransform) READ_PROVIDER.from(READ_CONFIG); + + // Mock inferBeamSchema since it requires database connection. + Schema expectedSchema = Schema.builder().addStringField("name").build(); + try (MockedStatic<JdbcIO.ReadRows> mock = Mockito.mockStatic(JdbcIO.ReadRows.class)) { + mock.when(() -> JdbcIO.ReadRows.inferBeamSchema(Mockito.any(), Mockito.any())) + .thenReturn(expectedSchema); + PCollectionRowTuple.empty(p).apply(readTransform); + } + + // Then translate the pipeline to a proto and extract OracleReadSchemaTransform proto + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + List<RunnerApi.PTransform> readTransformProto = + pipelineProto.getComponents().getTransformsMap().values().stream() + .filter( + tr -> { + RunnerApi.FunctionSpec spec = tr.getSpec(); + try { + return spec.getUrn().equals(BeamUrns.getUrn(SCHEMA_TRANSFORM)) + && SchemaTransformPayload.parseFrom(spec.getPayload()) + .getIdentifier() + .equals(READ_PROVIDER.identifier()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + }) + .collect(Collectors.toList()); + assertEquals(1, readTransformProto.size()); + RunnerApi.FunctionSpec spec = readTransformProto.get(0).getSpec(); + + // Check that the proto contains correct values + SchemaTransformPayload payload = SchemaTransformPayload.parseFrom(spec.getPayload()); + Schema schemaFromSpec = SchemaTranslation.schemaFromProto(payload.getConfigurationSchema()); + assertEquals(READ_PROVIDER.configurationSchema(), schemaFromSpec); + Row rowFromSpec = RowCoder.of(schemaFromSpec).decode(payload.getConfigurationRow().newInput()); + assertEquals(READ_CONFIG, rowFromSpec); + + // Use the information in the proto to recreate the OracleReadSchemaTransform + OracleReadSchemaTransformTranslator translator = new OracleReadSchemaTransformTranslator(); + OracleReadSchemaTransform readTransformFromSpec = + translator.fromConfigRow(rowFromSpec, PipelineOptionsFactory.create()); + + assertEquals(READ_CONFIG, readTransformFromSpec.getConfigurationRow()); + } +} diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index 4f45eeac861e..cdf7fcfa1bad 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -98,6 +98,7 @@ public class Managed { public static final String BIGQUERY = "bigquery"; public static final String POSTGRES = "postgres"; public static final String MYSQL = "mysql"; + public static final String ORACLE = "oracle"; // Supported SchemaTransforms public static final Map<String, String> READ_TRANSFORMS = @@ -108,6 +109,7 @@ public class Managed { .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_READ)) .put(POSTGRES, 
getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_READ)) .put(MYSQL, getUrn(ExternalTransforms.ManagedTransforms.Urns.MYSQL_READ)) + .put(ORACLE, getUrn(ExternalTransforms.ManagedTransforms.Urns.ORACLE_READ)) .build(); public static final Map<String, String> WRITE_TRANSFORMS = ImmutableMap.<String, String>builder() @@ -116,6 +118,7 @@ public class Managed { .put(BIGQUERY, getUrn(ExternalTransforms.ManagedTransforms.Urns.BIGQUERY_WRITE)) .put(POSTGRES, getUrn(ExternalTransforms.ManagedTransforms.Urns.POSTGRES_WRITE)) .put(MYSQL, getUrn(ExternalTransforms.ManagedTransforms.Urns.MYSQL_WRITE)) + .put(ORACLE, getUrn(ExternalTransforms.ManagedTransforms.Urns.ORACLE_WRITE)) .build(); /** diff --git a/sdks/python/apache_beam/transforms/managed.py b/sdks/python/apache_beam/transforms/managed.py index 03449236ac92..2b6aaddfd815 100644 --- a/sdks/python/apache_beam/transforms/managed.py +++ b/sdks/python/apache_beam/transforms/managed.py @@ -87,6 +87,7 @@ BIGQUERY = "bigquery" POSTGRES = "postgres" MYSQL = "mysql" +ORACLE = "oracle" __all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"] @@ -100,6 +101,7 @@ class Read(PTransform): BIGQUERY: ManagedTransforms.Urns.BIGQUERY_READ.urn, POSTGRES: ManagedTransforms.Urns.POSTGRES_READ.urn, MYSQL: ManagedTransforms.Urns.MYSQL_READ.urn, + ORACLE: ManagedTransforms.Urns.ORACLE_READ.urn, } def __init__( @@ -143,6 +145,7 @@ class Write(PTransform): BIGQUERY: ManagedTransforms.Urns.BIGQUERY_WRITE.urn, POSTGRES: ManagedTransforms.Urns.POSTGRES_WRITE.urn, MYSQL: ManagedTransforms.Urns.MYSQL_WRITE.urn, + ORACLE: ManagedTransforms.Urns.ORACLE_WRITE.urn, } def __init__(