diff --git a/wayang-commons/wayang-basic/pom.xml b/wayang-commons/wayang-basic/pom.xml index 1d1b460ae..f8ce0fe0e 100644 --- a/wayang-commons/wayang-basic/pom.xml +++ b/wayang-commons/wayang-basic/pom.xml @@ -120,6 +120,11 @@ 20231013 + + org.apache.calcite + calcite-core + ${calcite.version} + com.azure azure-storage-blob diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java new file mode 100644 index 000000000..0b556519f --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/TableSink.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.plan.wayangplan.UnarySink; +import org.apache.wayang.core.types.DataSetType; + +import java.util.Properties; + +/** + * {@link UnarySink} that writes Records to a database table. 
+ */ + +public class TableSink extends UnarySink { + private final String tableName; + + private String[] columnNames; + + private final Properties props; + + private String mode; + + /** + * Creates a new instance. + * + * @param props database connection properties + * @param mode write mode + * @param tableName name of the table to be written + * @param columnNames names of the columns in the tables + */ + public TableSink(Properties props, String mode, String tableName, String... columnNames) { + this(props, mode, tableName, columnNames, (DataSetType) DataSetType.createDefault(Record.class)); + } + + public TableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(type); + this.tableName = tableName; + this.columnNames = columnNames; + this.props = props; + this.mode = mode; + } + + /** + * Copies an instance (exclusive of broadcasts). + * + * @param that that should be copied + */ + public TableSink(TableSink that) { + super(that); + this.tableName = that.getTableName(); + this.columnNames = that.getColumnNames(); + this.props = that.getProperties(); + this.mode = that.getMode(); + } + + public String getTableName() { + return this.tableName; + } + + protected void setColumnNames(String[] columnNames) { + this.columnNames = columnNames; + } + + public String[] getColumnNames() { + return this.columnNames; + } + + public Properties getProperties() { + return this.props; + } + + public String getMode() { + return mode; + } + + public void setMode(String mode) { + this.mode = mode; + } +} diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java new file mode 100644 index 000000000..541600b71 --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/SqlTypeUtils.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.util; + +import org.apache.calcite.sql.SqlDialect; +import org.apache.wayang.basic.data.Record; + +import java.lang.reflect.Field; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Utility for mapping Java types to SQL types across different dialects. 
+ */ +public class SqlTypeUtils { + + private static final Map, String>> dialectTypeMaps = new HashMap<>(); + + static { + // Default mappings (Standard SQL) + Map, String> defaultMap = new HashMap<>(); + defaultMap.put(Integer.class, "INT"); + defaultMap.put(int.class, "INT"); + defaultMap.put(Long.class, "BIGINT"); + defaultMap.put(long.class, "BIGINT"); + defaultMap.put(Double.class, "DOUBLE"); + defaultMap.put(double.class, "DOUBLE"); + defaultMap.put(Float.class, "FLOAT"); + defaultMap.put(float.class, "FLOAT"); + defaultMap.put(Boolean.class, "BOOLEAN"); + defaultMap.put(boolean.class, "BOOLEAN"); + defaultMap.put(String.class, "VARCHAR(255)"); + defaultMap.put(Date.class, "DATE"); + defaultMap.put(LocalDate.class, "DATE"); + defaultMap.put(Timestamp.class, "TIMESTAMP"); + defaultMap.put(LocalDateTime.class, "TIMESTAMP"); + + dialectTypeMaps.put(SqlDialect.DatabaseProduct.UNKNOWN, defaultMap); + + // PostgreSQL Overrides + Map, String> pgMap = new HashMap<>(defaultMap); + pgMap.put(Double.class, "DOUBLE PRECISION"); + pgMap.put(double.class, "DOUBLE PRECISION"); + dialectTypeMaps.put(SqlDialect.DatabaseProduct.POSTGRESQL, pgMap); + + // Add more dialects here as needed (MySQL, Oracle, etc.) + } + + /** + * Detects the database product from a JDBC URL. 
+ * + * @param url JDBC URL + * @return detected DatabaseProduct + */ + public static SqlDialect.DatabaseProduct detectProduct(String url) { + if (url == null) + return SqlDialect.DatabaseProduct.UNKNOWN; + String lowerUrl = url.toLowerCase(); + if (lowerUrl.contains("postgresql") || lowerUrl.contains("postgres")) + return SqlDialect.DatabaseProduct.POSTGRESQL; + if (lowerUrl.contains("mysql")) + return SqlDialect.DatabaseProduct.MYSQL; + if (lowerUrl.contains("oracle")) + return SqlDialect.DatabaseProduct.ORACLE; + if (lowerUrl.contains("sqlite")) { + try { + return SqlDialect.DatabaseProduct.valueOf("SQLITE"); + } catch (Exception e) { + return SqlDialect.DatabaseProduct.UNKNOWN; + } + } + if (lowerUrl.contains("h2")) + return SqlDialect.DatabaseProduct.H2; + if (lowerUrl.contains("derby")) + return SqlDialect.DatabaseProduct.DERBY; + if (lowerUrl.contains("mssql") || lowerUrl.contains("sqlserver")) + return SqlDialect.DatabaseProduct.MSSQL; + return SqlDialect.DatabaseProduct.UNKNOWN; + } + + /** + * Returns the SQL type for a given Java class and database product. + * + * @param cls Java class + * @param product database product + * @return SQL type string + */ + public static String getSqlType(Class cls, SqlDialect.DatabaseProduct product) { + Map, String> typeMap = dialectTypeMaps.getOrDefault(product, + dialectTypeMaps.get(SqlDialect.DatabaseProduct.UNKNOWN)); + return typeMap.getOrDefault(cls, "VARCHAR(255)"); + } + + /** + * Extracts schema information from a POJO class or a Record. 
+ * + * @param cls POJO class + * @param product database product + * @return a list of schema fields + */ + public static List getSchema(Class cls, SqlDialect.DatabaseProduct product) { + List schema = new ArrayList<>(); + if (cls == Record.class) { + // For Record.class without an instance, we can't derive names/types easily + // Users should use the instance-based getSchema or provide columnNames + return schema; + } + + for (Field field : cls.getDeclaredFields()) { + if (java.lang.reflect.Modifier.isStatic(field.getModifiers())) { + continue; + } + schema.add(new SchemaField(field.getName(), field.getType(), getSqlType(field.getType(), product))); + } + return schema; + } + + /** + * Extracts schema information from a Record instance by inspecting its fields. + * + * @param record representative record + * @param product database product + * @param userNames optional user-provided column names + * @return a list of schema fields + */ + public static List getSchema(Record record, SqlDialect.DatabaseProduct product, String[] userNames) { + List schema = new ArrayList<>(); + if (record == null) + return schema; + + int size = record.size(); + for (int i = 0; i < size; i++) { + String name = (userNames != null && i < userNames.length) ? userNames[i] : "c_" + i; + Object val = record.getField(i); + Class typeClass = val != null ? 
val.getClass() : String.class; + String type = getSqlType(typeClass, product); + schema.add(new SchemaField(name, typeClass, type)); + } + return schema; + } + + public static class SchemaField { + private final String name; + private final Class javaClass; + private final String sqlType; + + public SchemaField(String name, Class javaClass, String sqlType) { + this.name = name; + this.javaClass = javaClass; + this.sqlType = sqlType; + } + + public String getName() { + return name; + } + + public Class getJavaClass() { + return javaClass; + } + + public String getSqlType() { + return sqlType; + } + } +} diff --git a/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java new file mode 100644 index 000000000..28e043e12 --- /dev/null +++ b/wayang-commons/wayang-basic/src/test/java/org/apache/wayang/basic/util/SqlTypeUtilsTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.basic.util; + +import org.apache.calcite.sql.SqlDialect; +import org.apache.wayang.basic.data.Record; +import org.junit.jupiter.api.Test; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SqlTypeUtilsTest { + + @Test + public void testDetectProduct() { + assertEquals(SqlDialect.DatabaseProduct.POSTGRESQL, + SqlTypeUtils.detectProduct("jdbc:postgresql://localhost:5432/db")); + assertEquals(SqlDialect.DatabaseProduct.MYSQL, SqlTypeUtils.detectProduct("jdbc:mysql://localhost:3306/db")); + assertEquals(SqlDialect.DatabaseProduct.ORACLE, + SqlTypeUtils.detectProduct("jdbc:oracle:thin:@localhost:1521:xe")); + assertEquals(SqlDialect.DatabaseProduct.H2, SqlTypeUtils.detectProduct("jdbc:h2:mem:test")); + assertEquals(SqlDialect.DatabaseProduct.DERBY, + SqlTypeUtils.detectProduct("jdbc:derby:memory:test;create=true")); + assertEquals(SqlDialect.DatabaseProduct.MSSQL, + SqlTypeUtils.detectProduct("jdbc:sqlserver://localhost:1433;databaseName=db")); + assertEquals(SqlDialect.DatabaseProduct.UNKNOWN, SqlTypeUtils.detectProduct("jdbc:unknown:db")); + } + + @Test + public void testGetSqlTypeDefault() { + SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.UNKNOWN; + assertEquals("INT", SqlTypeUtils.getSqlType(Integer.class, product)); + assertEquals("INT", SqlTypeUtils.getSqlType(int.class, product)); + assertEquals("BIGINT", SqlTypeUtils.getSqlType(Long.class, product)); + assertEquals("DOUBLE", SqlTypeUtils.getSqlType(Double.class, product)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + assertEquals("DATE", SqlTypeUtils.getSqlType(Date.class, product)); + assertEquals("TIMESTAMP", SqlTypeUtils.getSqlType(Timestamp.class, product)); + } + + @Test + public void testGetSqlTypePostgres() { + SqlDialect.DatabaseProduct product = SqlDialect.DatabaseProduct.POSTGRESQL; + assertEquals("INT", 
SqlTypeUtils.getSqlType(Integer.class, product)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(Double.class, product)); + assertEquals("DOUBLE PRECISION", SqlTypeUtils.getSqlType(double.class, product)); + assertEquals("VARCHAR(255)", SqlTypeUtils.getSqlType(String.class, product)); + } + + @Test + public void testGetSchema() { + List schema = SqlTypeUtils.getSchema(TestPojo.class, + SqlDialect.DatabaseProduct.POSTGRESQL); + assertEquals(3, schema.size()); + + assertEquals("id", schema.get(0).getName()); + assertEquals("INT", schema.get(0).getSqlType()); + + assertEquals("name", schema.get(1).getName()); + assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + + assertEquals("value", schema.get(2).getName()); + assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + } + + @Test + public void testGetSchemaRecord() { + Record record = new Record(1, "test", 1.5); + List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + null); + + assertEquals(3, schema.size()); + assertEquals("c_0", schema.get(0).getName()); + assertEquals("INT", schema.get(0).getSqlType()); + assertEquals(Integer.class, schema.get(0).getJavaClass()); + + assertEquals("c_1", schema.get(1).getName()); + assertEquals("VARCHAR(255)", schema.get(1).getSqlType()); + assertEquals(String.class, schema.get(1).getJavaClass()); + + assertEquals("c_2", schema.get(2).getName()); + assertEquals("DOUBLE PRECISION", schema.get(2).getSqlType()); + assertEquals(Double.class, schema.get(2).getJavaClass()); + } + + @Test + public void testGetSchemaRecordWithNames() { + Record record = new Record(1, "test"); + String[] names = { "id", "description" }; + List schema = SqlTypeUtils.getSchema(record, SqlDialect.DatabaseProduct.POSTGRESQL, + names); + + assertEquals(2, schema.size()); + assertEquals("id", schema.get(0).getName()); + assertEquals("description", schema.get(1).getName()); + } + + public static class TestPojo { + public int id; + public String name; + 
public Double value; + } +} diff --git a/wayang-platforms/wayang-java/pom.xml b/wayang-platforms/wayang-java/pom.xml index 9c58a78fb..70966b92d 100644 --- a/wayang-platforms/wayang-java/pom.xml +++ b/wayang-platforms/wayang-java/pom.xml @@ -78,7 +78,19 @@ log4j-slf4j-impl 2.20.0 + + org.postgresql + postgresql + 42.7.2 + test + + + com.h2database + h2 + 2.2.224 + test + org.mockito mockito-core diff --git a/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java new file mode 100644 index 000000000..f8fc02c32 --- /dev/null +++ b/wayang-platforms/wayang-java/src/main/java/org/apache/wayang/java/operators/JavaTableSink.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.SqlTypeUtils; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.ReflectionUtils; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.java.channels.CollectionChannel; +import org.apache.wayang.java.channels.JavaChannelInstance; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Properties; + +public class JavaTableSink extends TableSink implements JavaExecutionOperator { + + private void setRecordValue(PreparedStatement ps, int index, Object value) throws SQLException { + if (value == null) { + ps.setNull(index, java.sql.Types.NULL); + } else if (value instanceof Integer) { + ps.setInt(index, (Integer) value); + } else if (value instanceof Long) { + ps.setLong(index, (Long) value); + } else if (value instanceof Double) { + ps.setDouble(index, (Double) value); + } else if (value instanceof Float) { + ps.setFloat(index, (Float) value); + } else if (value instanceof Boolean) { + ps.setBoolean(index, (Boolean) value); + } else if (value instanceof java.sql.Date) { + ps.setDate(index, (java.sql.Date) value); + } else if (value instanceof java.sql.Timestamp) { + ps.setTimestamp(index, 
(java.sql.Timestamp) value); + } else { + ps.setString(index, value.toString()); + } + } + + public JavaTableSink(Properties props, String mode, String tableName) { + this(props, mode, tableName, null); + } + + public JavaTableSink(Properties props, String mode, String tableName, String... columnNames) { + super(props, mode, tableName, columnNames); + + } + + public JavaTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(props, mode, tableName, columnNames, type); + + } + + public JavaTableSink(TableSink that) { + super(that); + } + + @Override + public Tuple, Collection> evaluate( + ChannelInstance[] inputs, + ChannelInstance[] outputs, + JavaExecutor javaExecutor, + OptimizationContext.OperatorContext operatorContext) { + assert inputs.length == 1; + assert outputs.length == 0; + JavaChannelInstance input = (JavaChannelInstance) inputs[0]; + + // The stream is converted to an Iterator so that we can read the first element + // w/o consuming the entire stream. + Iterator recordIterator = input.provideStream().iterator(); + + if (!recordIterator.hasNext()) { + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + // We read the first element to derive the Record schema. 
+ T firstElement = recordIterator.next(); + Class typeClass = this.getType().getDataUnitType().getTypeClass(); + + String url = this.getProperties().getProperty("url"); + org.apache.calcite.sql.SqlDialect.DatabaseProduct product = SqlTypeUtils.detectProduct(url); + + List schemaFields; + if (typeClass != Record.class) { + schemaFields = SqlTypeUtils.getSchema(typeClass, product); + } else { + schemaFields = SqlTypeUtils.getSchema((Record) firstElement, product, this.getColumnNames()); + } + + String[] currentColumnNames = this.getColumnNames(); + if (currentColumnNames == null || currentColumnNames.length == 0) { + currentColumnNames = new String[schemaFields.size()]; + for (int i = 0; i < schemaFields.size(); i++) { + currentColumnNames[i] = schemaFields.get(i).getName(); + } + this.setColumnNames(currentColumnNames); + } + + String[] sqlTypes = new String[currentColumnNames.length]; + for (int i = 0; i < currentColumnNames.length; i++) { + sqlTypes[i] = "VARCHAR(255)"; // Default + for (SqlTypeUtils.SchemaField field : schemaFields) { + if (field.getName().equals(currentColumnNames[i])) { + sqlTypes[i] = field.getSqlType(); + break; + } + } + } + + final String[] finalColumnNames = currentColumnNames; + final String[] finalSqlTypes = sqlTypes; + + this.getProperties().setProperty("streamingBatchInsert", "True"); + + Connection conn; + try { + Class.forName(this.getProperties().getProperty("driver")); + conn = DriverManager.getConnection(this.getProperties().getProperty("url"), this.getProperties()); + conn.setAutoCommit(false); + + Statement stmt = conn.createStatement(); + + // Drop existing table if the mode is 'overwrite'. + if (this.getMode().equals("overwrite")) { + stmt.execute("DROP TABLE IF EXISTS " + this.getTableName()); + } + + // Create a new table if the specified table name does not exist yet. 
+ StringBuilder sb = new StringBuilder(); + sb.append("CREATE TABLE IF NOT EXISTS ").append(this.getTableName()).append(" ("); + String separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("\"").append(finalColumnNames[i]).append("\" ").append(finalSqlTypes[i]); + separator = ", "; + } + sb.append(")"); + stmt.execute(sb.toString()); + + // Create a prepared statement to insert value from the recordIterator. + sb = new StringBuilder(); + sb.append("INSERT INTO ").append(this.getTableName()).append(" ("); + separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("\"").append(finalColumnNames[i]).append("\""); + separator = ", "; + } + sb.append(") VALUES ("); + separator = ""; + for (int i = 0; i < finalColumnNames.length; i++) { + sb.append(separator).append("?"); + separator = ", "; + } + sb.append(")"); + PreparedStatement ps = conn.prepareStatement(sb.toString()); + + // The schema Record has to be pushed to the database too. 
+ this.pushToStatement(ps, firstElement, typeClass, finalColumnNames); + ps.addBatch(); + + // Iterate through all remaining records and add them to the prepared statement + recordIterator.forEachRemaining( + r -> { + try { + this.pushToStatement(ps, r, typeClass, finalColumnNames); + ps.addBatch(); + } catch (SQLException e) { + e.printStackTrace(); + } + }); + + ps.executeBatch(); + conn.commit(); + conn.close(); + } catch (ClassNotFoundException e) { + System.out.println("Please specify a correct database driver."); + e.printStackTrace(); + } catch (SQLException e) { + e.printStackTrace(); + } + + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + private void pushToStatement(PreparedStatement ps, T element, Class typeClass, String[] columnNames) + throws SQLException { + if (typeClass == Record.class) { + Record r = (Record) element; + for (int i = 0; i < columnNames.length; i++) { + setRecordValue(ps, i + 1, r.getField(i)); + } + } else { + for (int i = 0; i < columnNames.length; i++) { + Object val = ReflectionUtils.getProperty(element, columnNames[i]); + setRecordValue(ps, i + 1, val); + } + } + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "rheem.java.tablesink.load"; + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(CollectionChannel.DESCRIPTOR, StreamChannel.DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } + +} diff --git a/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java new file mode 100644 index 000000000..02b719e0f --- /dev/null +++ b/wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTableSinkTest.java @@ -0,0 +1,328 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.java.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.java.channels.StreamChannel; +import org.apache.wayang.java.execution.JavaExecutor; +import org.apache.wayang.java.platform.JavaPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Properties; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link JavaTableSink}. 
+ */ +class JavaTableSinkTest extends JavaExecutionOperatorTestBase { + + private static final String JDBC_URL = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; + private static final String TABLE_NAME = "test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); + } + + @AfterEach + void teardownTest() throws Exception { + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingRecordToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job); + + StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR + .createChannel(mock(OutputSlot.class), configuration) + .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0); + + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + + inputChannelInstance.accept(Stream.of(record1, record2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + 
            rs.next();
            assertEquals(2, rs.getInt(1));
        }
    }

    // Writes POJOs (schema derived via reflection from the bean class) and
    // checks the persisted rows and column values.
    @Test
    void testWritingPojoToH2() throws Exception {
        Configuration configuration = new Configuration();
        Properties dbProps = new Properties();
        dbProps.setProperty("url", JDBC_URL);
        dbProps.setProperty("user", "sa");
        dbProps.setProperty("password", "");
        dbProps.setProperty("driver", DRIVER);

        JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME,
                null, // schema detected via reflection
                DataSetType.createDefault(TestPojo.class));

        Job job = mock(Job.class);
        when(job.getConfiguration()).thenReturn(configuration);
        final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job);

        StreamChannel.Instance inputChannelInstance = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
                .createChannel(mock(OutputSlot.class), configuration)
                .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0);

        TestPojo p1 = new TestPojo(1, "Alice");
        TestPojo p2 = new TestPojo(2, "Bob");

        inputChannelInstance.accept(Stream.of(p1, p2));

        evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) {
            rs.next();
            assertEquals(1, rs.getInt("id"));
            assertEquals("Alice", rs.getString("name"));
            rs.next();
            assertEquals(2, rs.getInt("id"));
            assertEquals("Bob", rs.getString("name"));
        }
    }

    // First writes with mode "overwrite", then with mode "append"; expects the
    // second write to add to (not replace) the first.
    @Test
    void testAppendMode() throws Exception {
        Configuration configuration = new Configuration();
        Properties dbProps = new Properties();
        dbProps.setProperty("url", JDBC_URL);
        dbProps.setProperty("user", "sa");
        dbProps.setProperty("password", "");
        dbProps.setProperty("driver", DRIVER);

        // 1. Initial write (overwrite)
        JavaTableSink sink1 = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME,
                new String[] { "id", "name" },
                DataSetType.createDefault(Record.class));

        Job job = mock(Job.class);
        when(job.getConfiguration()).thenReturn(configuration);
        final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job);

        StreamChannel.Instance input1 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
                .createChannel(mock(OutputSlot.class), configuration)
                .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0);
        input1.accept(Stream.of(new Record(1, "Alice")));
        evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]);

        // 2. Append write
        JavaTableSink sink2 = new JavaTableSink<>(dbProps, "append", TABLE_NAME,
                new String[] { "id", "name" },
                DataSetType.createDefault(Record.class));

        Job job2 = mock(Job.class);
        when(job2.getConfiguration()).thenReturn(configuration);
        final JavaExecutor javaExecutor2 = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job2);

        StreamChannel.Instance input2 = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
                .createChannel(mock(OutputSlot.class), configuration)
                .createInstance(javaExecutor2, mock(OptimizationContext.OperatorContext.class), 0);
        input2.accept(Stream.of(new Record(2, "Bob")));
        evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) {
            rs.next();
            assertEquals(2, rs.getInt(1));
        }
    }

    // Overwriting a table whose existing schema differs from the new one must
    // replace the table, dropping columns that are no longer present.
    @Test
    void testOverwriteWithSchemaMismatch() throws Exception {
        Configuration configuration = new Configuration();
        Properties dbProps = new Properties();
        dbProps.setProperty("url", JDBC_URL);
        dbProps.setProperty("user", "sa");
        dbProps.setProperty("password", "");
        dbProps.setProperty("driver", DRIVER);

        // 1. Create table with old schema (id, name)
        try (Statement stmt = connection.createStatement()) {
            stmt.execute("CREATE TABLE " + TABLE_NAME + " (id INT, name VARCHAR(255))");
            stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')");
        }

        // 2. Overwrite with new schema (id, age, city)
        JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME,
                new String[] { "id", "age", "city" },
                DataSetType.createDefault(Record.class));

        Job job = mock(Job.class);
        when(job.getConfiguration()).thenReturn(configuration);
        final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job);

        StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
                .createChannel(mock(OutputSlot.class), configuration)
                .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0);
        input.accept(Stream.of(new Record(2, 30, "Berlin")));
        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) {
            rs.next();
            assertEquals(2, rs.getInt("id"));
            assertEquals(30, rs.getInt("age"));
            assertEquals("Berlin", rs.getString("city"));

            // Verify 'name' column is gone
            boolean hasName = false;
            for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) {
                if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) {
                    hasName = true;
                }
            }
            assertFalse(hasName, "Column 'name' should have been dropped");
        }
    }

    // A null Record field must round-trip as SQL NULL (checked via wasNull()).
    @Test
    void testNullValues() throws Exception {
        Configuration configuration = new Configuration();
        Properties dbProps = new Properties();
        dbProps.setProperty("url", JDBC_URL);
        dbProps.setProperty("user", "sa");
        dbProps.setProperty("password", "");
        dbProps.setProperty("driver", DRIVER);

        JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME,
                new String[] { "id", "name" },
                DataSetType.createDefault(Record.class));

        Job job = mock(Job.class);
        when(job.getConfiguration()).thenReturn(configuration);
        final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job);

        StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
                .createChannel(mock(OutputSlot.class), configuration)
                .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0);

        input.accept(Stream.of(new Record(1, null)));
        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) {
            rs.next();
            assertEquals(null, rs.getString(1));
            assertTrue(rs.wasNull());
        }
    }

    // Covers boolean, double, and float columns alongside the integer key.
    @Test
    void testSupportedTypes() throws Exception {
        Configuration configuration = new Configuration();
        Properties dbProps = new Properties();
        dbProps.setProperty("url", JDBC_URL);
        dbProps.setProperty("user", "sa");
        dbProps.setProperty("password", "");
        dbProps.setProperty("driver", DRIVER);

        JavaTableSink sink = new JavaTableSink<>(dbProps, "overwrite", TABLE_NAME,
                new String[] { "id", "is_active", "salary", "score" },
                DataSetType.createDefault(Record.class));

        Job job = mock(Job.class);
        when(job.getConfiguration()).thenReturn(configuration);
        final JavaExecutor javaExecutor = (JavaExecutor) JavaPlatform.getInstance().createExecutor(job);

        StreamChannel.Instance input = (StreamChannel.Instance) StreamChannel.DESCRIPTOR
                .createChannel(mock(OutputSlot.class), configuration)
                .createInstance(javaExecutor, mock(OptimizationContext.OperatorContext.class), 0);

        input.accept(Stream.of(new Record(1, true, 5000.50, 95.5f)));
        evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]);

        try (Statement stmt = connection.createStatement();
                ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) {
            rs.next();
            assertTrue(rs.getBoolean("is_active"));
            assertEquals(5000.50, rs.getDouble("salary"), 0.001);
            assertEquals(95.5f, rs.getFloat("score"), 0.001f);
        }
    }

    // Minimal JavaBean fixture for the reflection-based schema test above.
    public static class TestPojo {
        private int id;
        private String name;

        public TestPojo() {
        }

        public TestPojo(int id, String name) {
            this.id = id;
            this.name = name;
        }

        public int getId() {
            return id;
        }

        public String getName() {
            return name;
        }
    }
}
\ No newline at end of file
diff --git a/wayang-platforms/wayang-spark/pom.xml b/wayang-platforms/wayang-spark/pom.xml
index 1e89fd15e..abdd225d6 100644
--- a/wayang-platforms/wayang-spark/pom.xml
+++ b/wayang-platforms/wayang-spark/pom.xml
@@ -121,5 +121,18 @@
         4.8
+
+            org.postgresql
+            postgresql
+            42.7.2
+            compile
+
+
+            com.h2database
+            h2
+            2.2.224
+            test
+
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java
new file mode 100644
index 000000000..433e7b199
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/SparkTableSink.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.wayang.spark.operators; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.basic.operators.TableSink; +import org.apache.wayang.basic.util.SqlTypeUtils; +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.platform.ChannelDescriptor; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.platform.lineage.ExecutionLineageNode; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.execution.SparkExecutor; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Properties; + +public class SparkTableSink extends TableSink implements SparkExecutionOperator { + + private SaveMode mode; + + public SparkTableSink(Properties props, String mode, String tableName, String... 
columnNames) { + super(props, mode, tableName, columnNames); + this.setMode(mode); + } + + public SparkTableSink(Properties props, String mode, String tableName, String[] columnNames, DataSetType type) { + super(props, mode, tableName, columnNames, type); + this.setMode(mode); + } + + public SparkTableSink(TableSink that) { + super(that); + this.setMode(that.getMode()); + } + + @Override + public Tuple, Collection> evaluate( + ChannelInstance[] inputs, + ChannelInstance[] outputs, + SparkExecutor sparkExecutor, + OptimizationContext.OperatorContext operatorContext) { + assert inputs.length == 1; + assert outputs.length == 0; + + JavaRDD recordRDD = ((RddChannel.Instance) inputs[0]).provideRdd(); + Class typeClass = (Class) this.getType().getDataUnitType().getTypeClass(); + SparkSession sparkSession = SparkSession.builder().sparkContext(sparkExecutor.sc.sc()).getOrCreate(); + SQLContext sqlContext = sparkSession.sqlContext(); + + Dataset df; + if (typeClass == Record.class) { + // Records need manual schema handling + if (recordRDD.isEmpty()) { + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + Record first = (Record) recordRDD.first(); + + // Centralized Schema Derivation + List schemaFields = SqlTypeUtils.getSchema(first, + SqlTypeUtils.detectProduct(this.getProperties().getProperty("url")), + this.getColumnNames()); + + // Map Record to Row + JavaRDD rowRDD = recordRDD.map(rec -> RowFactory.create(((Record) rec).getValues())); + + // Build Spark Schema + StructField[] fields = new StructField[schemaFields.size()]; + for (int i = 0; i < schemaFields.size(); i++) { + SqlTypeUtils.SchemaField sf = schemaFields.get(i); + org.apache.spark.sql.types.DataType sparkType = getSparkDataType(sf.getJavaClass()); + fields[i] = new StructField(sf.getName(), sparkType, true, Metadata.empty()); + } + + // Update column names in the operator if they were generated + String[] newColNames = 
schemaFields.stream().map(SqlTypeUtils.SchemaField::getName).toArray(String[]::new); + this.setColumnNames(newColNames); + + df = sqlContext.createDataFrame(rowRDD, new StructType(fields)); + } else { + // POJO Case: Let Spark handle it natively + df = sqlContext.createDataFrame(recordRDD, typeClass); + // If columnNames are provided, we should probably select/rename them, + // but usually createDataFrame(rdd, beanClass) maps fields to columns. + if (this.getColumnNames() != null && this.getColumnNames().length > 0) { + // Optionally filter or reorder columns to match this.getColumnNames() + // For now, Spark's native mapping is preferred. + } + } + + this.getProperties().setProperty("batchSize", "250000"); + df.write() + .mode(this.mode) + .jdbc(this.getProperties().getProperty("url"), this.getTableName(), this.getProperties()); + + return ExecutionOperator.modelEagerExecution(inputs, outputs, operatorContext); + } + + private org.apache.spark.sql.types.DataType getSparkDataType(Class cls) { + if (cls == Integer.class || cls == int.class) + return DataTypes.IntegerType; + if (cls == Long.class || cls == long.class) + return DataTypes.LongType; + if (cls == Double.class || cls == double.class) + return DataTypes.DoubleType; + if (cls == Float.class || cls == float.class) + return DataTypes.FloatType; + if (cls == Boolean.class || cls == boolean.class) + return DataTypes.BooleanType; + if (cls == java.sql.Date.class || cls == java.time.LocalDate.class) + return DataTypes.DateType; + if (cls == java.sql.Timestamp.class || cls == java.time.LocalDateTime.class) + return DataTypes.TimestampType; + return DataTypes.StringType; + } + + public void setMode(String mode) { + if (mode == null) { + throw new WayangException("Unspecified write mode for SparkTableSink."); + } else if (mode.equals("append")) { + this.mode = SaveMode.Append; + } else if (mode.equals("overwrite")) { + this.mode = SaveMode.Overwrite; + } else if (mode.equals("errorIfExists")) { + this.mode = 
SaveMode.ErrorIfExists; + } else if (mode.equals("ignore")) { + this.mode = SaveMode.Ignore; + } else { + throw new WayangException( + String.format("Specified write mode for SparkTableSink does not exist: %s", mode)); + } + } + + @Override + public List getSupportedInputChannels(int index) { + return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR); + } + + @Override + public List getSupportedOutputChannels(int index) { + throw new UnsupportedOperationException("This operator has no outputs."); + } + + @Override + public boolean containsAction() { + return true; + } + + @Override + public String getLoadProfileEstimatorConfigurationKey() { + return "rheem.spark.tablesink.load"; + } +} diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java new file mode 100644 index 000000000..0197c3749 --- /dev/null +++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkTableSinkTest.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.wayang.spark.operators; + +import org.apache.wayang.basic.data.Record; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.platform.ChannelInstance; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.spark.channels.RddChannel; +import org.apache.wayang.spark.platform.SparkPlatform; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Test suite for {@link SparkTableSink}. 
+ */ +class SparkTableSinkTest extends SparkOperatorTestBase { + + private static final String JDBC_URL = "jdbc:h2:mem:sparktestdb;DB_CLOSE_DELAY=-1"; + private static final String DRIVER = "org.h2.Driver"; + private static final String TABLE_NAME = "spark_test_table"; + + private Connection connection; + + @BeforeEach + void setupTest() throws Exception { + Class.forName(DRIVER); + connection = DriverManager.getConnection(JDBC_URL, "sa", ""); + } + + @AfterEach + void teardownTest() throws Exception { + if (connection != null && !connection.isClosed()) { + try (Statement stmt = connection.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + } + connection.close(); + } + } + + @Test + void testWritingRecordToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name", "value" }, + DataSetType.createDefault(Record.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + Record record1 = new Record(1, "Alice", 100.5); + Record record2 = new Record(2, "Bob", 200.75); + + RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( + Arrays.asList(record1, record2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testWritingPojoToH2() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + 
dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + null, // schema detected via reflection + DataSetType.createDefault(TestPojo.class)); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + + TestPojo p1 = new TestPojo(1, "Alice"); + TestPojo p2 = new TestPojo(2, "Bob"); + + RddChannel.Instance inputChannelInstance = this.createRddChannelInstance( + Arrays.asList(p1, p2)); + + evaluate(sink, new ChannelInstance[] { inputChannelInstance }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " ORDER BY \"id\"")) { + rs.next(); + assertEquals(1, rs.getInt("id")); + assertEquals("Alice", rs.getString("name")); + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals("Bob", rs.getString("name")); + } + } + + @Test + void testAppendMode() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Initial write (overwrite) + SparkTableSink sink1 = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input1 = this.createRddChannelInstance(Arrays.asList(new Record(1, "Alice"))); + evaluate(sink1, new ChannelInstance[] { input1 }, new ChannelInstance[0]); + + // 2. 
Append write + SparkTableSink sink2 = new SparkTableSink<>(dbProps, "append", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input2 = this.createRddChannelInstance(Arrays.asList(new Record(2, "Bob"))); + evaluate(sink2, new ChannelInstance[] { input2 }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT COUNT(*) FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt(1)); + } + } + + @Test + void testOverwriteWithSchemaMismatch() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + // 1. Create table with old schema (id, name) + try (Statement stmt = connection.createStatement()) { + stmt.execute("CREATE TABLE " + TABLE_NAME + " (\"id\" INT, \"name\" VARCHAR(255))"); + stmt.execute("INSERT INTO " + TABLE_NAME + " VALUES (1, 'Old')"); + } + + // 2. 
Overwrite with new schema (id, age, city) + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "age", "city" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(2, 30, "Berlin"))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME)) { + rs.next(); + assertEquals(2, rs.getInt("id")); + assertEquals(30, rs.getInt("age")); + assertEquals("Berlin", rs.getString("city")); + + // Verify 'name' column is gone + boolean hasName = false; + for (int i = 1; i <= rs.getMetaData().getColumnCount(); i++) { + if ("name".equalsIgnoreCase(rs.getMetaData().getColumnName(i))) { + hasName = true; + } + } + assertFalse(hasName, "Column 'name' should have been dropped"); + } + } + + @Test + void testNullValues() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "name" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, null))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT \"name\" FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertEquals(null, rs.getString(1)); + assertTrue(rs.wasNull()); + } + } + + @Test + void testSupportedTypes() throws Exception { + Configuration configuration = new Configuration(); + Properties dbProps = new Properties(); + 
dbProps.setProperty("url", JDBC_URL); + dbProps.setProperty("user", "sa"); + dbProps.setProperty("password", ""); + dbProps.setProperty("driver", DRIVER); + + SparkTableSink sink = new SparkTableSink<>(dbProps, "overwrite", TABLE_NAME, + new String[] { "id", "is_active", "salary", "score" }, + DataSetType.createDefault(Record.class)); + + RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(new Record(1, true, 5000.50, 95.5f))); + evaluate(sink, new ChannelInstance[] { input }, new ChannelInstance[0]); + + try (Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT * FROM " + TABLE_NAME + " WHERE \"id\" = 1")) { + rs.next(); + assertTrue(rs.getBoolean("is_active")); + assertEquals(5000.50, rs.getDouble("salary"), 0.001); + assertEquals(95.5f, rs.getFloat("score"), 0.001f); + } + } + + public static class TestPojo implements java.io.Serializable { + private int id; + private String name; + + public TestPojo() { + } + + public TestPojo(int id, String name) { + this.id = id; + this.name = name; + } + + public int getId() { + return id; + } + + public String getName() { + return name; + } + } +} \ No newline at end of file