From a97403c38738b92b04506b4456700bbbac9823da Mon Sep 17 00:00:00 2001 From: harry Date: Fri, 19 Dec 2025 17:03:01 +0100 Subject: [PATCH 1/8] introduced table sink operator, added java & spark platform implementations and simple tests --- .../wayang/basic/operators/TableSink.java | 109 ++++++++++ wayang-platforms/wayang-java/pom.xml | 6 + .../wayang/java/operators/JavaTableSink.java | 203 ++++++++++++++++++ .../java/operators/JavaTableSinkTest.java | 133 ++++++++++++ wayang-platforms/wayang-spark/pom.xml | 7 + .../spark/operators/SparkTableSink.java | 155 +++++++++++++ .../spark/operators/SparkTableSinkTest.java | 121 +++++++++++ 7 files changed, 734 insertions(+) create mode 100644 wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java create mode 100644 wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java create mode 100644 wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java create mode 100644 wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java create mode 100644 wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java new file mode 100644 index 000000000..967bdaa48 --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.types.RecordType; +import org.apache.wayang.core.plan.wayangplan.UnarySink; +import org.apache.wayang.core.types.DataSetType; + +import java.util.Properties; + +/** + * {@link UnarySink} that writes Records to a database table. + */ + +public class TableSink extends UnarySink { + private final String tableName; + + private String[] columnNames; + + private final Properties props; + + private String mode; + + /** + * Creates a new instance. + * + * @param props database connection properties + * @param tableName name of the table to be written + * @param columnNames names of the columns in the tables + */ + public TableSink(Properties props, String mode, String tableName, String... columnNames) { + this(props, mode, tableName, columnNames, DataSetType.createDefault(Record.class)); + } + + public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(type); + this.tableName = tableName; + this.columnNames = columnNames; + this.props = props; + this.mode = mode; + } + + /** + * Copies an instance (exclusive of broadcasts). 
+ * + * @param that that should be copied + */ + public TableSink(TableSink that) { + super(that); + this.tableName = that.getTableName(); + this.columnNames = that.getColumnNames(); + this.props = that.getProperties(); + this.mode = that.getMode(); + } + + public String getTableName() { + return this.tableName; + } + + protected void setColumnNames(String[] columnNames) { + this.columnNames = columnNames; + } + + public String[] getColumnNames() { + return this.columnNames; + } + + public Properties getProperties() { + return this.props; + } + + public String getMode() { + return mode; + } + + public void setMode(String mode) { + this.mode = mode; + } + + /** + * Constructs an appropriate output {@link DataSetType} for the given column names. + * + * @param columnNames the column names or an empty array if unknown + * @return the output {@link DataSetType}, which will be based upon a {@link RecordType} unless no {@code columnNames} + * is empty + */ + private static DataSetType createOutputDataSetType(String[] columnNames) { + return columnNames.length == 0 ? + DataSetType.createDefault(Record.class) : + DataSetType.createDefault(new RecordType(columnNames)); + } +} diff --git a/wayang-platforms/wayang-java/pom.xml b/wayang-platforms/wayang-java/pom.xml index 9c58a78fb..b1ffca4c2 100644 --- a/wayang-platforms/wayang-java/pom.xml +++ b/wayang-platforms/wayang-java/pom.xml @@ -78,6 +78,12 @@ log4j-slf4j-impl 2.20.0 + + org.postgresql + postgresql + 42.7.2 + test + org.mockito diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java new file mode 100644 index 000000000..8c4949689 --- /dev/null +++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.java.channels.CollectionChannel; +import org.apache.wayang.java.channels.JavaChannelInstance; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Properties; + + +public class JavaTableSink extends TableSink implements JavaExecutionOperator { + + private void setRecordValue(PreparedStatement ps, int index, Object value) throws SQLException { + if (value == null) { + 
ps.setNull(index, java.sql.Types.NULL); + } else if (value instanceof Integer) { + ps.setInt(index, (Integer) value); + } else if (value instanceof Long) { + ps.setLong(index, (Long) value); + } else if (value instanceof Double) { + ps.setDouble(index, (Double) value); + } else if (value instanceof Float) { + ps.setFloat(index, (Float) value); + } else if (value instanceof Boolean) { + ps.setBoolean(index, (Boolean) value); + } else { + ps.setString(index, value.toString()); + } + } + + public JavaTableSink(Properties props, String mode, String tableName) { + this(props, mode, tableName, null); + } + + public JavaTableSink(Properties props, String mode, String tableName, String... columnNames) { + super(props, mode, tableName, columnNames); + + } + + public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(props, mode, tableName, columnNames, type); + + } + + public JavaTableSink(TableSink that) { + super(that); + } + + @Override + public Tuple, Collection> evaluate( + ChannelInstance[] inputs, + ChannelInstance[] outputs, + JavaExecutor javaExecutor, + OptimizationContext.OperatorContext operatorContext) { + assert inputs.length == 1; + assert outputs.length == 0; + JavaChannelInstance input = (JavaChannelInstance) inputs[0]; + + // The stream is converted to an Iterator so that we can read the first element w/o consuming the entire stream. + Iterator recordIterator = input.provideStream().iterator(); + // We read the first element to derive the Record schema. + Record schemaRecord = recordIterator.next(); + + // We assume that all records have the same length and only check the first record. 
+ int recordLength = schemaRecord.size(); + if (this.getColumnNames() != null) { + assert recordLength == this.getColumnNames().length; + } else { + String[] columnNames = new String[recordLength]; + for (int i = 0; i < recordLength; i++) { + columnNames[i] = "c_" + i; + } + this.setColumnNames(columnNames); + } + + // TODO: Check if we need this property. + this.getProperties().setProperty("streamingBatchInsert", "True"); + + Connection conn; + try { + Class.forName(this.getProperties().getProperty("driver")); + conn = DriverManager.getConnection(this.getProperties().getProperty("url"), this.getProperties()); + conn.setAutoCommit(false); + + Statement stmt = conn.createStatement(); + + // Drop existing table if the mode is 'overwrite'. + if (this.getMode().equals("overwrite")) { + stmt.execute("DROP TABLE IF EXISTS " + this.getTableName()); + } + + // Create a new table if the specified table name does not exist yet. + StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE IF NOT EXISTS ").append(this.getTableName()).append(" ("); + String separator = ""; + for (int i = 0; i < recordLength; i++) { + sb.append(separator).append(this.getColumnNames()[i]).append(" VARCHAR(255)"); + separator = ", "; + } + sb.append(")"); + stmt.execute(sb.toString()); + + // Create a prepared statement to insert value from the recordIterator. + sb = new StringBuilder(); + sb.append("INSERT INTO ").append(this.getTableName()).append(" ("); + separator = ""; + for (int i = 0; i < recordLength; i++) { + sb.append(separator).append(this.getColumnNames()[i]); + separator = ", "; + } + sb.append(") VALUES ("); + separator = ""; + for (int i = 0; i < recordLength; i++) { + sb.append(separator).append("?"); + separator = ", "; + } + sb.append(")"); + PreparedStatement ps = conn.prepareStatement(sb.toString()); + + // The schema Record has to be pushed to the database too. 
+ for (int i = 0; i < recordLength; i++) { + setRecordValue(ps, i + 1, schemaRecord.getField(i)); + } + ps.addBatch(); + + // Iterate through all remaining records and add them to the prepared statement + recordIterator.forEachRemaining( + r -> { + try { + for (int i = 0; i < recordLength; i++) { + setRecordValue(ps, i + 1, r.getField(i)); + } + ps.addBatch(); + } catch (SQLException e) { + e.printStackTrace(); + } + } + ); + + ps.executeBatch(); + conn.commit(); + conn.close(); + } catch (ClassNotFoundException e) { + System.out.println("Please specify a correct database driver."); + e.printStackTrace(); + } catch (SQLException e) { + e.printStackTrace(); + } + + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "rheem.java.tablesink.load"; + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(CollectionChannel.DESCRIPTOR, StreamChannel.DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } +} diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java new file mode 100644 index 000000000..565cd9e20 --- /dev/null +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; +import org.apache.wayang.java.platform.JavaPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Properties; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link JavaTableSink}. 
+ */ +class JavaTableSinkTest extends JavaExecutionOperatorTestBase { + + private static final String JDBC_URL = "jdbc:postgresql://localhost:5432/default"; + private static final String USERNAME = "postgres"; + private static final String PASSWORD = "123456"; + private static final String TABLE_NAME = "test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + // Load PostgreSQL driver + Class.forName("org.postgresql.Driver"); + + // Connect to database + connection = DriverManager.getConnection(JDBC_URL, USERNAME, PASSWORD); + + // Create test table + try (Statement stmt = connection.createStatement()) { + //stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + //stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(100), value DOUBLE PRECISION)"); + } + } + + @AfterEach + void teardownTest() throws Exception { + // Clean up test table + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + //stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingToPostgres() throws Exception { + Configuration configuration = new Configuration(); + + // Configure database properties + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", USERNAME); + dbProps.setProperty("password", PASSWORD); + dbProps.setProperty("driver", "org.postgresql.Driver"); + + JavaTableSink sink = new JavaTableSink(dbProps, "overwrite", TABLE_NAME, + new String[]{"id", "name", "value"}, + DataSetType.createDefault(org.apache.wayang.basic.data.Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + // Create input channel with test data + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) 
StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + // Create test records + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + Record record3 = new Record(3, "Charlie", 300.25); + + inputChannelInstance.accept(Stream.of(record1, record2, record3)); + + // Execute the sink + evaluate(sink, new ChannelInstance[]{inputChannelInstance}, new ChannelInstance[0]); + + // Verify data was written to database + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(3, rs.getInt(1), "Should have written 3 records"); + } + + // Verify specific record + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE id = 1")) { + rs.next(); + assertEquals("Alice", rs.getString("name")); + assertEquals(100.5, rs.getDouble("value"), 0.01); + } + } +} \ No newline at end of file diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml index 1e89fd15e..3863f952c 100644 --- a/wayang-platforms/wayang-spark/pom.xml +++ b/wayang-platforms/wayang-spark/pom.xml @@ -121,5 +121,12 @@ 4.8 + + org.postgresql + postgresql + 42.7.2 + compile + + diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java new file mode 100644 index 000000000..ae25f3cb6 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.wayang.spark.operators; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.execution.SparkExecutor; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Properties; + +public class SparkTableSink 
extends TableSink implements SparkExecutionOperator { + + private SaveMode mode; + + private org.apache.spark.sql.types.DataType getDataType(Object value) { + if (value == null) return DataTypes.StringType; + if (value instanceof Integer) return DataTypes.IntegerType; + if (value instanceof Long) return DataTypes.LongType; + if (value instanceof Double) return DataTypes.DoubleType; + if (value instanceof Float) return DataTypes.FloatType; + if (value instanceof Boolean) return DataTypes.BooleanType; + return DataTypes.StringType; + } + + public SparkTableSink(Properties props, String mode, String tableName, String... columnNames) { + super(props, mode, tableName, columnNames); + this.setMode(mode); + } + + public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(props, mode, tableName, columnNames, type); + this.setMode(mode); + } + + public SparkTableSink(TableSink that) { + super(that); + this.setMode(that.getMode()); + } + + @Override + public Tuple, Collection> evaluate( + ChannelInstance[] inputs, + ChannelInstance[] outputs, + SparkExecutor sparkExecutor, + OptimizationContext.OperatorContext operatorContext) { + assert inputs.length == 1; + assert outputs.length == 0; + + JavaRDD recordRDD = ((RddChannel.Instance) inputs[0]).provideRdd(); + + //nothing to write if rdd empty + recordRDD.cache(); + + boolean isEmpty = recordRDD.isEmpty(); + + if (!isEmpty) { + int recordLength = recordRDD.first().size(); + + JavaRDD rowRDD = recordRDD.map(record -> { + Object[] values = record.getValues(); + return RowFactory.create(values); + }); + + StructField[] fields = new StructField[recordLength]; + Record firstRecord = recordRDD.first(); + for (int i = 0; i < recordLength; i++) { + Object value = firstRecord.getField(i); + org.apache.spark.sql.types.DataType dataType = getDataType(value); + fields[i] = new StructField(this.getColumnNames()[i], dataType, true, Metadata.empty()); + } + StructType schema = 
new StructType(fields); + + SQLContext sqlcontext = new SQLContext(sparkExecutor.sc.sc()); + Dataset dataSet = sqlcontext.createDataFrame(rowRDD, schema); + this.getProperties().setProperty("batchSize", "250000"); + dataSet.write().mode(this.mode).jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties()); + } else { + System.out.println("RDD is empty, nothing to write!"); + } + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + public void setMode(String mode) { + if (mode == null) { + throw new WayangException("Unspecified write mode for SparkTableSink."); + } else if (mode.equals("append")) { + this.mode = SaveMode.Append; + } else if (mode.equals("overwrite")) { + this.mode = SaveMode.Overwrite; + } else if (mode.equals("errorIfExists")) { + this.mode = SaveMode.ErrorIfExists; + } else if (mode.equals("ignore")) { + this.mode = SaveMode.Ignore; + } else { + throw new WayangException(String.format("Specified write mode for SparkTableSink does not exist: %s", mode)); + } + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } + + @Override + public boolean containsAction() { + return true; + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "rheem.spark.tablesink.load"; + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java new file mode 100644 index 000000000..78665e90c --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.wayang.spark.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.platform.SparkPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link SparkTableSink}. 
+ */ +class SparkTableSinkTest extends SparkOperatorTestBase { + + private static final String JDBC_URL = "jdbc:postgresql://localhost:5432/default"; + private static final String USERNAME = "postgres"; + private static final String PASSWORD = "123456"; + private static final String TABLE_NAME = "spark_test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + // Load PostgreSQL driver + Class.forName("org.postgresql.Driver"); + + // Connect to database + connection = DriverManager.getConnection(JDBC_URL, USERNAME, PASSWORD); + } + + @AfterEach + void teardownTest() throws Exception { + // Clean up test table + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + //stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingToPostgres() throws Exception { + Configuration configuration = new Configuration(); + + // Configure database properties + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", USERNAME); + dbProps.setProperty("password", PASSWORD); + dbProps.setProperty("driver", "org.postgresql.Driver"); + + SparkTableSink sink = new SparkTableSink(dbProps, "overwrite", TABLE_NAME, + new String[]{"id", "name", "value"}, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + // Create input RDD with test data + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + Record record3 = new Record(3, "Charlie", 300.25); + + RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( + Arrays.asList(record1, record2, record3) + ); + + // Execute the sink + evaluate(sink, new ChannelInstance[]{inputChannelInstance}, new ChannelInstance[0]); + + // Verify data was written to database + try (Statement stmt = 
connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(3, rs.getInt(1), "Should have written 3 records"); + } + + // Verify specific record + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE id = 1")) { + rs.next(); + assertEquals("Alice", rs.getString("name")); + assertEquals(100.5, rs.getDouble("value"), 0.01); + } + } +} \ No newline at end of file From a24e9abd2f5748f358a63c1054b4bf93d2e3ab97 Mon Sep 17 00:00:00 2001 From: Sujay Barui Date: Fri, 9 Jan 2026 00:44:00 +0530 Subject: [PATCH 2/8] Update .asf.yaml About description and labels --- .asf.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.asf.yaml b/.asf.yaml index 8675b9053..043705f45 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -16,7 +16,7 @@ # github: - description: Apache Wayang is the first cross-platform data processing system. + description: Apache Wayang is a cross-platform data processing system. 
homepage: https://wayang.apache.org/ labels: - big-data @@ -34,6 +34,7 @@ github: - machine-learning - algorithm - privacy-preserving + - federated-learning features: # Disable wiki for documentation wiki: false From c4d5ebec3f9e39b86c3452a72ae865d24fe4a431 Mon Sep 17 00:00:00 2001 From: Sujay Barui Date: Sat, 10 Jan 2026 19:39:06 +0530 Subject: [PATCH 3/8] Remove federated learning label --- .asf.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.asf.yaml b/.asf.yaml index 043705f45..33094cfd1 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -34,7 +34,11 @@ github: - machine-learning - algorithm - privacy-preserving +<<<<<<< HEAD - federated-learning +======= + +>>>>>>> 25f47fbd (Remove federated learning label) features: # Disable wiki for documentation wiki: false From 6ea8ace94fd13f1d8a5f9917ce6d2b02a9fa0f66 Mon Sep 17 00:00:00 2001 From: harry Date: Mon, 23 Feb 2026 18:17:38 +0100 Subject: [PATCH 4/8] Revert changes to .asf.yaml to match origin/main --- .asf.yaml | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/.asf.yaml b/.asf.yaml index 33094cfd1..f31e2c7b5 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -16,10 +16,10 @@ # github: - description: Apache Wayang is a cross-platform data processing system. - homepage: https://wayang.apache.org/ + description: Apache Wayang(incubating) is the first cross-platform data processing system. 
+ homepage: https://wayang.incubator.apache.org/ labels: - - big-data + - big-data - apache - data-management-platform - cross-platform @@ -34,11 +34,7 @@ github: - machine-learning - algorithm - privacy-preserving -<<<<<<< HEAD - federated-learning -======= - ->>>>>>> 25f47fbd (Remove federated learning label) features: # Disable wiki for documentation wiki: false From dfe1672f5f5ba92d87c49e28b009325f860df709 Mon Sep 17 00:00:00 2001 From: harry Date: Mon, 23 Feb 2026 21:59:34 +0100 Subject: [PATCH 5/8] Enhance TableSink and add comprehensive tests Introduce generic type support and dialect-aware SQL mapping using Calcite. Add extensive H2 integration tests for Java and Spark covering various edge cases. --- wayang-commons/wayang-basic/pom.xml | 5 + .../wayang/basic/operators/TableSink.java | 23 +- .../wayang/basic/util/SqlTypeUtils.java | 187 ++++++++++++ .../wayang/basic/util/SqlTypeUtilsTest.java | 122 ++++++++ wayang-platforms/wayang-java/pom.xml | 6 + .../wayang/java/operators/JavaTableSink.java | 101 +++++-- .../java/operators/JavaTableSinkTest.java | 269 +++++++++++++++--- wayang-platforms/wayang-spark/pom.xml | 6 + .../spark/operators/SparkTableSink.java | 110 ++++--- .../spark/operators/SparkTableSinkTest.java | 220 ++++++++++++-- 10 files changed, 895 insertions(+), 154 deletions(-) create mode 100644 wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java create mode 100644 wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java diff --git a/wayang-commons/wayang-basic/pom.xml b/wayang-commons/wayang-basic/pom.xml index 1d1b460ae..f8ce0fe0e 100644 --- a/wayang-commons/wayang-basic/pom.xml +++ b/wayang-commons/wayang-basic/pom.xml @@ -120,6 +120,11 @@ 20231013 + + org.apache.calcite + calcite-core + ${calcite.version} + com.azure azure-storage-blob diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java 
b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java index 967bdaa48..0b556519f 100644 --- a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java @@ -19,7 +19,6 @@ package org.apache.wayang.basic.operators; import org.apache.wayang.basic.data.Record; -import org.apache.wayang.basic.types.RecordType; import org.apache.wayang.core.plan.wayangplan.UnarySink; import org.apache.wayang.core.types.DataSetType; @@ -29,7 +28,7 @@ * {@link UnarySink} that writes Records to a database table. */ -public class TableSink extends UnarySink { +public class TableSink extends UnarySink { private final String tableName; private String[] columnNames; @@ -42,14 +41,15 @@ public class TableSink extends UnarySink { * Creates a new instance. * * @param props database connection properties + * @param mode write mode * @param tableName name of the table to be written * @param columnNames names of the columns in the tables */ public TableSink(Properties props, String mode, String tableName, String... 
columnNames) { - this(props, mode, tableName, columnNames, DataSetType.createDefault(Record.class)); + this(props, mode, tableName, columnNames, (DataSetType) DataSetType.createDefault(Record.class)); } - public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { super(type); this.tableName = tableName; this.columnNames = columnNames; @@ -62,7 +62,7 @@ public TableSink(Properties props, String mode, String tableName, String[] colum * * @param that that should be copied */ - public TableSink(TableSink that) { + public TableSink(TableSink that) { super(that); this.tableName = that.getTableName(); this.columnNames = that.getColumnNames(); @@ -93,17 +93,4 @@ public String getMode() { public void setMode(String mode) { this.mode = mode; } - - /** - * Constructs an appropriate output {@link DataSetType} for the given column names. - * - * @param columnNames the column names or an empty array if unknown - * @return the output {@link DataSetType}, which will be based upon a {@link RecordType} unless no {@code columnNames} - * is empty - */ - private static DataSetType createOutputDataSetType(String[] columnNames) { - return columnNames.length == 0 ? - DataSetType.createDefault(Record.class) : - DataSetType.createDefault(new RecordType(columnNames)); - } } diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java new file mode 100644 index 000000000..541600b71 --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.util; + +import org.apache.calcite.sql.SqlDialect; +import org.apache.wayang.basic.data.Record; + +import java.lang.reflect.Field; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Utility for mapping Java types to SQL types across different dialects. 
+ */ +public class SqlTypeUtils { + + private static final Map, String>> dialectTypeMaps = new HashMap<>(); + + static { + // Default mappings (Standard SQL) + Map, String> defaultMap = new HashMap<>(); + defaultMap.put(Integer.class, "INT"); + defaultMap.put(int.class, "INT"); + defaultMap.put(Long.class, "BIGINT"); + defaultMap.put(long.class, "BIGINT"); + defaultMap.put(Double.class, "DOUBLE"); + defaultMap.put(double.class, "DOUBLE"); + defaultMap.put(Float.class, "FLOAT"); + defaultMap.put(float.class, "FLOAT"); + defaultMap.put(Boolean.class, "BOOLEAN"); + defaultMap.put(boolean.class, "BOOLEAN"); + defaultMap.put(String.class, "VARCHAR(255)"); + defaultMap.put(Date.class, "DATE"); + defaultMap.put(LocalDate.class, "DATE"); + defaultMap.put(Timestamp.class, "TIMESTAMP"); + defaultMap.put(LocalDateTime.class, "TIMESTAMP"); + + dialectTypeMaps.put(SqlDialect.DatabaseProduct.UNKNOWN, defaultMap); + + // PostgreSQL Overrides + Map, String> pgMap = new HashMap<>(defaultMap); + pgMap.put(Double.class, "DOUBLE PRECISION"); + pgMap.put(double.class, "DOUBLE PRECISION"); + dialectTypeMaps.put(SqlDialect.DatabaseProduct.POSTGRESQL, pgMap); + + // Add more dialects here as needed (MySQL, Oracle, etc.) + } + + /** + * Detects the database product from a JDBC URL. 
+ * + * @param url JDBC URL + * @return detected DatabaseProduct + */ + public static SqlDialect.DatabaseProduct detectProduct(String url) { + if (url == null) + return SqlDialect.DatabaseProduct.UNKNOWN; + String lowerUrl = url.toLowerCase(); + if (lowerUrl.contains("postgresql") || lowerUrl.contains("postgres")) + return SqlDialect.DatabaseProduct.POSTGRESQL; + if (lowerUrl.contains("mysql")) + return SqlDialect.DatabaseProduct.MYSQL; + if (lowerUrl.contains("oracle")) + return SqlDialect.DatabaseProduct.ORACLE; + if (lowerUrl.contains("sqlite")) { + try { + return SqlDialect.DatabaseProduct.valueOf("SQLITE"); + } catch (Exception e) { + return SqlDialect.DatabaseProduct.UNKNOWN; + } + } + if (lowerUrl.contains("h2")) + return SqlDialect.DatabaseProduct.H2; + if (lowerUrl.contains("derby")) + return SqlDialect.DatabaseProduct.DERBY; + if (lowerUrl.contains("mssql") || lowerUrl.contains("sqlserver")) + return SqlDialect.DatabaseProduct.MSSQL; + return SqlDialect.DatabaseProduct.UNKNOWN; + } + + /** + * Returns the SQL type for a given Java class and database product. + * + * @param cls Java class + * @param product database product + * @return SQL type string + */ + public static String getSqlType(Class cls, SqlDialect.DatabaseProduct product) { + Map, String> typeMap = dialectTypeMaps.getOrDefault(product, + dialectTypeMaps.get(SqlDialect.DatabaseProduct.UNKNOWN)); + return typeMap.getOrDefault(cls, "VARCHAR(255)"); + } + + /** + * Extracts schema information from a POJO class or a Record. 
+ * + * @param cls POJO class + * @param product database product + * @return a list of schema fields + */ + public static List getSchema(Class cls, SqlDialect.DatabaseProduct product) { + List schema = new ArrayList<>(); + if (cls == Record.class) { + // For Record.class without an instance, we can't derive names/types easily + // Users should use the instance-based getSchema or provide columnNames + return schema; + } + + for (Field field : cls.getDeclaredFields()) { + if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) { + continue; + } + schema.add(new SchemaField(field.getName(), field.getType(), getSqlType(field.getType(), product))); + } + return schema; + } + + /** + * Extracts schema information from a Record instance by inspecting its fields. + * + * @param record representative record + * @param product database product + * @param userNames optional user-provided column names + * @return a list of schema fields + */ + public static List getSchema(Record record, SqlDialect.DatabaseProduct product, String[] userNames) { + List schema = new ArrayList<>(); + if (record == null) + return schema; + + int size = record.size(); + for (int i = 0; i < size; i++) { + String name = (userNames != null && i < userNames.length) ? userNames[i] : "c_" + i; + Object val = record.getField(i); + Class typeClass = val != null ? 
val.getClass() : String.class; + String type = getSqlType(typeClass, product); + schema.add(new SchemaField(name, typeClass, type)); + } + return schema; + } + + public static class SchemaField { + private final String name; + private final Class javaClass; + private final String sqlType; + + public SchemaField(String name, Class javaClass, String sqlType) { + this.name = name; + this.javaClass = javaClass; + this.sqlType = sqlType; + } + + public String getName() { + return name; + } + + public Class getJavaClass() { + return javaClass; + } + + public String getSqlType() { + return sqlType; + } + } +} diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java new file mode 100644 index 000000000..28e043e12 --- /dev/null +++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.basic.util; + +import org.apache.calcite.sql.SqlDialect; +import org.apache.wayang.basic.data.Record; +import org.junit.jupiter.api.Test; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SqlTypeUtilsTest { + + @Test + public void testDetectProduct() { + assertEquals(SqlDialect.DatabaseProduct.POSTGRESQL, + SqlTypeUtils.detectProduct("jdbc:postgresql://localhost:5432/db")); + assertEquals(SqlDialect.DatabaseProduct.MYSQL, SqlTypeUtils.detectProduct("jdbc:mysql://localhost:3306/db")); + assertEquals(SqlDialect.DatabaseProduct.ORACLE, + SqlTypeUtils.detectProduct("jdbc:oracle:thin:@localhost:1521:xe")); + assertEquals(SqlDialect.DatabaseProduct.H2, SqlTypeUtils.detectProduct("jdbc:h2:mem:test")); + assertEquals(SqlDialect.DatabaseProduct.DERBY, + SqlTypeUtils.detectProduct("jdbc:derby:memory:test;create=true")); + assertEquals(SqlDialect.DatabaseProduct.MSSQL, + SqlTypeUtils.detectProduct("jdbc:sqlserver://localhost:1433;databaseName=db")); + assertEquals(SqlDialect.DatabaseProduct.UNKNOWN, SqlTypeUtils.detectProduct("jdbc:unknown:db")); + } + + @Test + public void testGetSqlTypeDefault() { + SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.UNKNOWN; + assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, product)); + assertEquals("INT", SqlTypeUtils.getSqlType(int.class, product)); + assertEquals("BIGINT", SqlTypeUtils.getSqlType(Long.class, product)); + assertEquals("DOUBLE", SqlTypeUtils.getSqlType(Double.class, product)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + assertEquals("DATE", SqlTypeUtils.getSqlType(Date.class, product)); + assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(Timestamp.class, product)); + } + + @Test + public void testGetSqlTypePostgres() { + SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.POSTGRESQL; + assertEquals("INT", 
SqlTypeUtils.getSqlType(Integer.class, product)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(Double.class, product)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(double.class, product)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + } + + @Test + public void testGetSchema() { + List schema = SqlTypeUtils.getSchema(TestPojo.class, + SqlDialect.DatabaseProduct.POSTGRESQL); + assertEquals(3, schema.size()); + + assertEquals("id", schema.get(0).getName()); + assertEquals("INT", schema.get(0).getSqlType()); + + assertEquals("name", schema.get(1).getName()); + assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + + assertEquals("value", schema.get(2).getName()); + assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + } + + @Test + public void testGetSchemaRecord() { + Record record = new Record(1, "test", 1.5); + List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + null); + + assertEquals(3, schema.size()); + assertEquals("c_0", schema.get(0).getName()); + assertEquals("INT", schema.get(0).getSqlType()); + assertEquals(Integer.class, schema.get(0).getJavaClass()); + + assertEquals("c_1", schema.get(1).getName()); + assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + assertEquals(String.class, schema.get(1).getJavaClass()); + + assertEquals("c_2", schema.get(2).getName()); + assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + assertEquals(Double.class, schema.get(2).getJavaClass()); + } + + @Test + public void testGetSchemaRecordWithNames() { + Record record = new Record(1, "test"); + String[] names = { "id", "description" }; + List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + names); + + assertEquals(2, schema.size()); + assertEquals("id", schema.get(0).getName()); + assertEquals("description", schema.get(1).getName()); + } + + public static class TestPojo { + public int id; + public String name; + 
public Double value; + } +} diff --git a/wayang-platforms/wayang-java/pom.xml b/wayang-platforms/wayang-java/pom.xml index b1ffca4c2..70966b92d 100644 --- a/wayang-platforms/wayang-java/pom.xml +++ b/wayang-platforms/wayang-java/pom.xml @@ -85,6 +85,12 @@ test + + com.h2database + h2 + 2.2.224 + test + org.mockito mockito-core diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java index 8c4949689..f8fc02c32 100644 --- a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java +++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java @@ -20,12 +20,14 @@ import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.SqlTypeUtils; import org.apache.wayang.core.optimizer.OptimizationContext; import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; import org.apache.wayang.core.platform.ChannelDescriptor; import org.apache.wayang.core.platform.ChannelInstance; import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.ReflectionUtils; import org.apache.wayang.core.util.Tuple; import org.apache.wayang.java.channels.CollectionChannel; import org.apache.wayang.java.channels.JavaChannelInstance; @@ -43,8 +45,7 @@ import java.util.List; import java.util.Properties; - -public class JavaTableSink extends TableSink implements JavaExecutionOperator { +public class JavaTableSink extends TableSink implements JavaExecutionOperator { private void setRecordValue(PreparedStatement ps, int index, Object value) throws SQLException { if (value == null) { @@ -59,6 +60,10 @@ private void setRecordValue(PreparedStatement ps, int index, Object value) throw ps.setFloat(index, (Float) value); } 
else if (value instanceof Boolean) { ps.setBoolean(index, (Boolean) value); + } else if (value instanceof java.sql.Date) { + ps.setDate(index, (java.sql.Date) value); + } else if (value instanceof java.sql.Timestamp) { + ps.setTimestamp(index, (java.sql.Timestamp) value); } else { ps.setString(index, value.toString()); } @@ -73,12 +78,12 @@ public JavaTableSink(Properties props, String mode, String tableName, String... } - public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { super(props, mode, tableName, columnNames, type); } - public JavaTableSink(TableSink that) { + public JavaTableSink(TableSink that) { super(that); } @@ -92,24 +97,51 @@ public Tuple, Collection> eval assert outputs.length == 0; JavaChannelInstance input = (JavaChannelInstance) inputs[0]; - // The stream is converted to an Iterator so that we can read the first element w/o consuming the entire stream. - Iterator recordIterator = input.provideStream().iterator(); + // The stream is converted to an Iterator so that we can read the first element + // w/o consuming the entire stream. + Iterator recordIterator = input.provideStream().iterator(); + + if (!recordIterator.hasNext()) { + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + // We read the first element to derive the Record schema. - Record schemaRecord = recordIterator.next(); + T firstElement = recordIterator.next(); + Class typeClass = this.getType().getDataUnitType().getTypeClass(); + + String url = this.getProperties().getProperty("url"); + org.apache.calcite.sql.SqlDialect.DatabaseProduct product = SqlTypeUtils.detectProduct(url); - // We assume that all records have the same length and only check the first record. 
- int recordLength = schemaRecord.size(); - if (this.getColumnNames() != null) { - assert recordLength == this.getColumnNames().length; + List schemaFields; + if (typeClass != Record.class) { + schemaFields = SqlTypeUtils.getSchema(typeClass, product); } else { - String[] columnNames = new String[recordLength]; - for (int i = 0; i < recordLength; i++) { - columnNames[i] = "c_" + i; + schemaFields = SqlTypeUtils.getSchema((Record) firstElement, product, this.getColumnNames()); + } + + String[] currentColumnNames = this.getColumnNames(); + if (currentColumnNames == null || currentColumnNames.length == 0) { + currentColumnNames = new String[schemaFields.size()]; + for (int i = 0; i < schemaFields.size(); i++) { + currentColumnNames[i] = schemaFields.get(i).getName(); + } + this.setColumnNames(currentColumnNames); + } + + String[] sqlTypes = new String[currentColumnNames.length]; + for (int i = 0; i < currentColumnNames.length; i++) { + sqlTypes[i] = "VARCHAR(255)"; // Default + for (SqlTypeUtils.SchemaField field : schemaFields) { + if (field.getName().equals(currentColumnNames[i])) { + sqlTypes[i] = field.getSqlType(); + break; + } } - this.setColumnNames(columnNames); } - // TODO: Check if we need this property. 
+ final String[] finalColumnNames = currentColumnNames; + final String[] finalSqlTypes = sqlTypes; + this.getProperties().setProperty("streamingBatchInsert", "True"); Connection conn; @@ -129,8 +161,8 @@ public Tuple, Collection> eval StringBuilder sb = new StringBuilder(); sb.append("CREATE TABLE IF NOT EXISTS ").append(this.getTableName()).append(" ("); String separator = ""; - for (int i = 0; i < recordLength; i++) { - sb.append(separator).append(this.getColumnNames()[i]).append(" VARCHAR(255)"); + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("\"").append(finalColumnNames[i]).append("\" ").append(finalSqlTypes[i]); separator = ", "; } sb.append(")"); @@ -140,13 +172,13 @@ public Tuple, Collection> eval sb = new StringBuilder(); sb.append("INSERT INTO ").append(this.getTableName()).append(" ("); separator = ""; - for (int i = 0; i < recordLength; i++) { - sb.append(separator).append(this.getColumnNames()[i]); + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("\"").append(finalColumnNames[i]).append("\""); separator = ", "; } sb.append(") VALUES ("); separator = ""; - for (int i = 0; i < recordLength; i++) { + for (int i = 0; i < finalColumnNames.length; i++) { sb.append(separator).append("?"); separator = ", "; } @@ -154,24 +186,19 @@ public Tuple, Collection> eval PreparedStatement ps = conn.prepareStatement(sb.toString()); // The schema Record has to be pushed to the database too. 
- for (int i = 0; i < recordLength; i++) { - setRecordValue(ps, i + 1, schemaRecord.getField(i)); - } + this.pushToStatement(ps, firstElement, typeClass, finalColumnNames); ps.addBatch(); // Iterate through all remaining records and add them to the prepared statement recordIterator.forEachRemaining( r -> { try { - for (int i = 0; i < recordLength; i++) { - setRecordValue(ps, i + 1, r.getField(i)); - } + this.pushToStatement(ps, r, typeClass, finalColumnNames); ps.addBatch(); } catch (SQLException e) { e.printStackTrace(); } - } - ); + }); ps.executeBatch(); conn.commit(); @@ -186,6 +213,21 @@ public Tuple, Collection> eval return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); } + private void pushToStatement(PreparedStatement ps, T element, Class typeClass, String[] columnNames) + throws SQLException { + if (typeClass == Record.class) { + Record r = (Record) element; + for (int i = 0; i < columnNames.length; i++) { + setRecordValue(ps, i + 1, r.getField(i)); + } + } else { + for (int i = 0; i < columnNames.length; i++) { + Object val = ReflectionUtils.getProperty(element, columnNames[i]); + setRecordValue(ps, i + 1, val); + } + } + } + @Override public String getLoadProfileEstimatorConfigurationKey() { return "rheem.java.tablesink.load"; @@ -200,4 +242,5 @@ public List getSupportedInputChannels(int index) { public List getSupportedOutputChannels(int index) { throw new UnsupportedOperationException("This operator has no outputs."); } + } diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java index 565cd9e20..02b719e0f 100644 --- a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java @@ -40,6 +40,8 @@ import java.util.stream.Stream; 
import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -48,86 +50,279 @@ */ class JavaTableSinkTest extends JavaExecutionOperatorTestBase { - private static final String JDBC_URL = "jdbc:postgresql://localhost:5432/default"; - private static final String USERNAME = "postgres"; - private static final String PASSWORD = "123456"; + private static final String JDBC_URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; private static final String TABLE_NAME = "test_table"; private Connection connection; @BeforeEach void setupTest() throws Exception { - // Load PostgreSQL driver - Class.forName("org.postgresql.Driver"); - - // Connect to database - connection = DriverManager.getConnection(JDBC_URL, USERNAME, PASSWORD); - - // Create test table - try (Statement stmt = connection.createStatement()) { - //stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); - //stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(100), value DOUBLE PRECISION)"); - } + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); } @AfterEach void teardownTest() throws Exception { - // Clean up test table if (connection != null && !connection.isClosed()) { try (Statement stmt = connection.createStatement()) { - //stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); } connection.close(); } } @Test - void testWritingToPostgres() throws Exception { + void testWritingRecordToH2() throws Exception { Configuration configuration = new Configuration(); - - // Configure database properties Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); - dbProps.setProperty("user", USERNAME); - dbProps.setProperty("password", PASSWORD); - 
dbProps.setProperty("driver", "org.postgresql.Driver"); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); - JavaTableSink sink = new JavaTableSink(dbProps, "overwrite", TABLE_NAME, - new String[]{"id", "name", "value"}, - DataSetType.createDefault(org.apache.wayang.basic.data.Record.class)); + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, + DataSetType.createDefault(Record.class)); Job job = mock(Job.class); when(job.getConfiguration()).thenReturn(configuration); final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); - // Create input channel with test data StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR .createChannel(mock(OutputSlot.class), configuration) .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); - // Create test records Record record1 = new Record(1, "Alice", 100.5); Record record2 = new Record(2, "Bob", 200.75); - Record record3 = new Record(3, "Charlie", 300.25); - inputChannelInstance.accept(Stream.of(record1, record2, record3)); + inputChannelInstance.accept(Stream.of(record1, record2)); - // Execute the sink - evaluate(sink, new ChannelInstance[]{inputChannelInstance}, new ChannelInstance[0]); + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); - // Verify data was written to database try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { rs.next(); - assertEquals(3, rs.getInt(1), "Should have written 3 records"); + assertEquals(2, rs.getInt(1)); } + } + + @Test + void testWritingPojoToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + 
dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + null, // schema detected via reflection + DataSetType.createDefault(TestPojo.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + TestPojo p1 = new TestPojo(1, "Alice"); + TestPojo p2 = new TestPojo(2, "Bob"); + + inputChannelInstance.accept(Stream.of(p1, p2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); - // Verify specific record try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE id = 1")) { + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { rs.next(); + assertEquals(1, rs.getInt("id")); assertEquals("Alice", rs.getString("name")); - assertEquals(100.5, rs.getDouble("value"), 0.01); + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals("Bob", rs.getString("name")); + } + } + + @Test + void testAppendMode() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. 
Initial write (overwrite) + JavaTableSink sink1 = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input1 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + input1.accept(Stream.of(new Record(1, "Alice"))); + evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]); + + // 2. Append write + JavaTableSink sink2 = new JavaTableSink<>(dbProps, "append", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + Job job2 = mock(Job.class); + when(job2.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor2 = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job2); + + StreamChannel.Instance input2 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor2, mock(OptimizationContext.OperatorContext.class), 0); + input2.accept(Stream.of(new Record(2, "Bob"))); + evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testOverwriteWithSchemaMismatch() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. 
Create table with old schema (id, name) + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(255))"); + stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); + } + + // 2. Overwrite with new schema (id, age, city) + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "age", "city" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + input.accept(Stream.of(new Record(2, 30, "Berlin"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals(30, rs.getInt("age")); + assertEquals("Berlin", rs.getString("city")); + + // Verify 'name' column is gone + boolean hasName = false; + for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) { + if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) { + hasName = true; + } + } + assertFalse(hasName, "Column 'name' should have been dropped"); + } + } + + @Test + void testNullValues() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + 
DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + input.accept(Stream.of(new Record(1, null))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertEquals(null, rs.getString(1)); + assertTrue(rs.wasNull()); + } + } + + @Test + void testSupportedTypes() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "is_active", "salary", "score" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + input.accept(Stream.of(new Record(1, true, 5000.50, 95.5f))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE 
\"id\" = 1")) { + rs.next(); + assertTrue(rs.getBoolean("is_active")); + assertEquals(5000.50, rs.getDouble("salary"), 0.001); + assertEquals(95.5f, rs.getFloat("score"), 0.001f); + } + } + + public static class TestPojo { + private int id; + private String name; + + public TestPojo() { + } + + public TestPojo(int id, String name) { + this.id = id; + this.name = name; + } + + public int getId() { + return id; + } + + public String getName() { + return name; } } } \ No newline at end of file diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml index 3863f952c..abdd225d6 100644 --- a/wayang-platforms/wayang-spark/pom.xml +++ b/wayang-platforms/wayang-spark/pom.xml @@ -128,5 +128,11 @@ compile + + com.h2database + h2 + 2.2.224 + test + diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java index ae25f3cb6..433e7b199 100644 --- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java +++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java @@ -23,12 +23,14 @@ import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.SqlTypeUtils; import org.apache.wayang.core.api.exception.WayangException; import org.apache.wayang.core.optimizer.OptimizationContext; import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; @@ -45,31 +47,21 @@ import java.util.List; import 
java.util.Properties; -public class SparkTableSink extends TableSink implements SparkExecutionOperator { +public class SparkTableSink extends TableSink implements SparkExecutionOperator { private SaveMode mode; - private org.apache.spark.sql.types.DataType getDataType(Object value) { - if (value == null) return DataTypes.StringType; - if (value instanceof Integer) return DataTypes.IntegerType; - if (value instanceof Long) return DataTypes.LongType; - if (value instanceof Double) return DataTypes.DoubleType; - if (value instanceof Float) return DataTypes.FloatType; - if (value instanceof Boolean) return DataTypes.BooleanType; - return DataTypes.StringType; - } - public SparkTableSink(Properties props, String mode, String tableName, String... columnNames) { super(props, mode, tableName, columnNames); this.setMode(mode); } - public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { super(props, mode, tableName, columnNames, type); this.setMode(mode); } - public SparkTableSink(TableSink that) { + public SparkTableSink(TableSink that) { super(that); this.setMode(that.getMode()); } @@ -83,40 +75,77 @@ public Tuple, Collection> eval assert inputs.length == 1; assert outputs.length == 0; - JavaRDD recordRDD = ((RddChannel.Instance) inputs[0]).provideRdd(); - - //nothing to write if rdd empty - recordRDD.cache(); + JavaRDD recordRDD = ((RddChannel.Instance) inputs[0]).provideRdd(); + Class typeClass = (Class) this.getType().getDataUnitType().getTypeClass(); + SparkSession sparkSession = SparkSession.builder().sparkContext(sparkExecutor.sc.sc()).getOrCreate(); + SQLContext sqlContext = sparkSession.sqlContext(); - boolean isEmpty = recordRDD.isEmpty(); - - if (!isEmpty) { - int recordLength = recordRDD.first().size(); - - JavaRDD rowRDD = recordRDD.map(record -> { - Object[] values = record.getValues(); - 
return RowFactory.create(values); - }); - - StructField[] fields = new StructField[recordLength]; - Record firstRecord = recordRDD.first(); - for (int i = 0; i < recordLength; i++) { - Object value = firstRecord.getField(i); - org.apache.spark.sql.types.DataType dataType = getDataType(value); - fields[i] = new StructField(this.getColumnNames()[i], dataType, true, Metadata.empty()); + Dataset df; + if (typeClass == Record.class) { + // Records need manual schema handling + if (recordRDD.isEmpty()) { + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + Record first = (Record) recordRDD.first(); + + // Centralized Schema Derivation + List schemaFields = SqlTypeUtils.getSchema(first, + SqlTypeUtils.detectProduct(this.getProperties().getProperty("url")), + this.getColumnNames()); + + // Map Record to Row + JavaRDD rowRDD = recordRDD.map(rec -> RowFactory.create(((Record) rec).getValues())); + + // Build Spark Schema + StructField[] fields = new StructField[schemaFields.size()]; + for (int i = 0; i < schemaFields.size(); i++) { + SqlTypeUtils.SchemaField sf = schemaFields.get(i); + org.apache.spark.sql.types.DataType sparkType = getSparkDataType(sf.getJavaClass()); + fields[i] = new StructField(sf.getName(), sparkType, true, Metadata.empty()); } - StructType schema = new StructType(fields); - SQLContext sqlcontext = new SQLContext(sparkExecutor.sc.sc()); - Dataset dataSet = sqlcontext.createDataFrame(rowRDD, schema); - this.getProperties().setProperty("batchSize", "250000"); - dataSet.write().mode(this.mode).jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties()); + // Update column names in the operator if they were generated + String[] newColNames = schemaFields.stream().map(SqlTypeUtils.SchemaField::getName).toArray(String[]::new); + this.setColumnNames(newColNames); + + df = sqlContext.createDataFrame(rowRDD, new StructType(fields)); } else { - System.out.println("RDD is empty, nothing to write!"); + 
// POJO Case: Let Spark handle it natively + df = sqlContext.createDataFrame(recordRDD, typeClass); + // If columnNames are provided, we should probably select/rename them, + // but usually createDataFrame(rdd, beanClass) maps fields to columns. + if (this.getColumnNames() != null && this.getColumnNames().length > 0) { + // Optionally filter or reorder columns to match this.getColumnNames() + // For now, Spark's native mapping is preferred. + } } + + this.getProperties().setProperty("batchSize", "250000"); + df.write() + .mode(this.mode) + .jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties()); + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); } + private org.apache.spark.sql.types.DataType getSparkDataType(Class cls) { + if (cls == Integer.class || cls == int.class) + return DataTypes.IntegerType; + if (cls == Long.class || cls == long.class) + return DataTypes.LongType; + if (cls == Double.class || cls == double.class) + return DataTypes.DoubleType; + if (cls == Float.class || cls == float.class) + return DataTypes.FloatType; + if (cls == Boolean.class || cls == boolean.class) + return DataTypes.BooleanType; + if (cls == java.sql.Date.class || cls == java.time.LocalDate.class) + return DataTypes.DateType; + if (cls == java.sql.Timestamp.class || cls == java.time.LocalDateTime.class) + return DataTypes.TimestampType; + return DataTypes.StringType; + } + public void setMode(String mode) { if (mode == null) { throw new WayangException("Unspecified write mode for SparkTableSink."); @@ -129,7 +158,8 @@ public void setMode(String mode) { } else if (mode.equals("ignore")) { this.mode = SaveMode.Ignore; } else { - throw new WayangException(String.format("Specified write mode for SparkTableSink does not exist: %s", mode)); + throw new WayangException( + String.format("Specified write mode for SparkTableSink does not exist: %s", mode)); } } diff --git 
a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java index 78665e90c..0197c3749 100644 --- a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java @@ -38,6 +38,8 @@ import java.util.Properties; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -46,76 +48,234 @@ */ class SparkTableSinkTest extends SparkOperatorTestBase { - private static final String JDBC_URL = "jdbc:postgresql://localhost:5432/default"; - private static final String USERNAME = "postgres"; - private static final String PASSWORD = "123456"; + private static final String JDBC_URL = "jdbc:h2:mem:sparktestdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; private static final String TABLE_NAME = "spark_test_table"; private Connection connection; @BeforeEach void setupTest() throws Exception { - // Load PostgreSQL driver - Class.forName("org.postgresql.Driver"); - - // Connect to database - connection = DriverManager.getConnection(JDBC_URL, USERNAME, PASSWORD); + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); } @AfterEach void teardownTest() throws Exception { - // Clean up test table if (connection != null && !connection.isClosed()) { try (Statement stmt = connection.createStatement()) { - //stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); } connection.close(); } } @Test - void testWritingToPostgres() throws Exception { + void testWritingRecordToH2() throws 
Exception { Configuration configuration = new Configuration(); - - // Configure database properties Properties dbProps = new Properties(); dbProps.setProperty("url", JDBC_URL); - dbProps.setProperty("user", USERNAME); - dbProps.setProperty("password", PASSWORD); - dbProps.setProperty("driver", "org.postgresql.Driver"); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); - SparkTableSink sink = new SparkTableSink(dbProps, "overwrite", TABLE_NAME, - new String[]{"id", "name", "value"}, + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, DataSetType.createDefault(Record.class)); Job job = mock(Job.class); when(job.getConfiguration()).thenReturn(configuration); - // Create input RDD with test data Record record1 = new Record(1, "Alice", 100.5); Record record2 = new Record(2, "Bob", 200.75); - Record record3 = new Record(3, "Charlie", 300.25); RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( - Arrays.asList(record1, record2, record3) - ); + Arrays.asList(record1, record2)); - // Execute the sink - evaluate(sink, new ChannelInstance[]{inputChannelInstance}, new ChannelInstance[0]); + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); - // Verify data was written to database try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { rs.next(); - assertEquals(3, rs.getInt(1), "Should have written 3 records"); + assertEquals(2, rs.getInt(1)); } + } + + @Test + void testWritingPojoToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + 
dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + null, // schema detected via reflection + DataSetType.createDefault(TestPojo.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + TestPojo p1 = new TestPojo(1, "Alice"); + TestPojo p2 = new TestPojo(2, "Bob"); + + RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( + Arrays.asList(p1, p2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); - // Verify specific record try (Statement stmt = connection.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE id = 1")) { + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { rs.next(); + assertEquals(1, rs.getInt("id")); assertEquals("Alice", rs.getString("name")); - assertEquals(100.5, rs.getDouble("value"), 0.01); + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals("Bob", rs.getString("name")); + } + } + + @Test + void testAppendMode() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Initial write (overwrite) + SparkTableSink sink1 = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input1 = this.createRddChannelInstance(Arrays.asList(new Record(1, "Alice"))); + evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]); + + // 2. 
Append write + SparkTableSink sink2 = new SparkTableSink<>(dbProps, "append", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input2 = this.createRddChannelInstance(Arrays.asList(new Record(2, "Bob"))); + evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testOverwriteWithSchemaMismatch() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Create table with old schema (id, name) + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + " (\"id\" INT, \"name\" VARCHAR(255))"); + stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); + } + + // 2. 
Overwrite with new schema (id, age, city) + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "age", "city" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(2, 30, "Berlin"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals(30, rs.getInt("age")); + assertEquals("Berlin", rs.getString("city")); + + // Verify 'name' column is gone + boolean hasName = false; + for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) { + if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) { + hasName = true; + } + } + assertFalse(hasName, "Column 'name' should have been dropped"); + } + } + + @Test + void testNullValues() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, null))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertEquals(null, rs.getString(1)); + assertTrue(rs.wasNull()); + } + } + + @Test + void testSupportedTypes() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + 
dbProps.setProperty("url", JDBC_URL);
+        dbProps.setProperty("user", "sa");
+        dbProps.setProperty("password", "");
+        dbProps.setProperty("driver", DRIVER);
+
+        SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME,
+                new String[] { "id", "is_active", "salary", "score" },
+                DataSetType.createDefault(Record.class));
+
+        RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, true, 5000.50, 95.5f)));
+        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);
+
+        try (Statement stmt = connection.createStatement();
+                ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) {
+            rs.next();
+            assertTrue(rs.getBoolean("is_active"));
+            assertEquals(5000.50, rs.getDouble("salary"), 0.001);
+            assertEquals(95.5f, rs.getFloat("score"), 0.001f);
+        }
+    }
+
+    public static class TestPojo implements java.io.Serializable {
+        private int id;
+        private String name;
+
+        public TestPojo() {
+        }
+
+        public TestPojo(int id, String name) {
+            this.id = id;
+            this.name = name;
+        }
+
+        public int getId() {
+            return id;
+        }
+
+        public String getName() {
+            return name;
         }
     }
 }
\ No newline at end of file

From 820233d3674af07c1fc964f881918a4de47287c6 Mon Sep 17 00:00:00 2001
From: harry
Date: Tue, 24 Feb 2026 22:05:14 +0100
Subject: [PATCH 6/8] changes in asf.yaml

---
 .asf.yaml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.asf.yaml b/.asf.yaml
index f31e2c7b5..47b374a48 100644
--- a/.asf.yaml
+++ b/.asf.yaml
@@ -16,8 +16,8 @@
 #
 github:
-  description: Apache Wayang(incubating) is the first cross-platform data processing system.
-  homepage: https://wayang.incubator.apache.org/
+  description: Apache Wayang is the first cross-platform data processing system.
+ homepage: https://wayang.apache.org/ labels: - big-data - apache From a94686b9177fefbc4f2668499c87418f5f4cc3f9 Mon Sep 17 00:00:00 2001 From: harry Date: Tue, 24 Feb 2026 22:06:58 +0100 Subject: [PATCH 7/8] changes in asf.yaml --- .asf.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.asf.yaml b/.asf.yaml index 47b374a48..43a40a5cc 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -19,7 +19,7 @@ github: description: Apache Wayang is the first cross-platform data processing system. homepage: https://wayang.apache.org/ labels: - - big-data + - big-data - apache - data-management-platform - cross-platform @@ -34,7 +34,7 @@ github: - machine-learning - algorithm - privacy-preserving - - federated-learning + features: # Disable wiki for documentation wiki: false From 7cd4822fc9730fede0281c6334dc32e865183f95 Mon Sep 17 00:00:00 2001 From: harry Date: Tue, 24 Feb 2026 22:07:16 +0100 Subject: [PATCH 8/8] changes in asf.yaml --- .asf.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.asf.yaml b/.asf.yaml index 43a40a5cc..8675b9053 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -34,7 +34,6 @@ github: - machine-learning - algorithm - privacy-preserving - features: # Disable wiki for documentation wiki: false