From 229643f5c5c7092e1fc82808bb277553d7d6f0e4 Mon Sep 17 00:00:00 2001 From: Makarand Milind Hinge Date: Tue, 24 Feb 2026 22:50:53 +0530 Subject: [PATCH 1/3] feat: Add SQL support for multi-condition joins - Enhanced WayangMultiConditionJoinVisitor with SQL implementation - Updated JdbcJoinOperator to generate AND clauses for multi-column joins - Added test coverage for multi-condition join SQL generation - Maintains backward compatibility with single-condition joins --- .../WayangMultiConditionJoinVisitor.java | 79 +++++++++---------- .../jdbc/operators/JdbcJoinOperator.java | 34 ++++++-- .../jdbc/operators/JdbcJoinOperatorTest.java | 78 ++++++++++++++++++ 3 files changed, 146 insertions(+), 45 deletions(-) diff --git a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangMultiConditionJoinVisitor.java b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangMultiConditionJoinVisitor.java index dfc14462a..aca9be14a 100644 --- a/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangMultiConditionJoinVisitor.java +++ b/wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangMultiConditionJoinVisitor.java @@ -23,6 +23,7 @@ import java.util.stream.Collectors; import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; @@ -30,6 +31,7 @@ import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult; import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinKeyExtractor; import org.apache.wayang.api.sql.calcite.rel.WayangJoin; +import org.apache.wayang.api.sql.calcite.rel.WayangTableScan; import org.apache.wayang.basic.data.Record; import org.apache.wayang.basic.data.Tuple2; import org.apache.wayang.basic.operators.JoinOperator; @@ -48,26 +50,21 @@ public class WayangMultiConditionJoinVisitor extends WayangRelNodeVisitor subConditions = call.operands.stream() .map(RexCall.class::cast) .collect(Collectors.toList()); - // calcite generates the RexInputRef indexes via looking at the union - // field list of the left and right input of a join. - // since the left input is always the first in this joined field list - // we can eagerly get the fields in the left input final List leftTableInputRefs = subConditions.stream() .map(sub -> sub.getOperands().stream() .map(RexInputRef.class::cast) @@ -79,9 +76,6 @@ Operator visit(WayangJoin wayangRelNode) { .map(RexInputRef::getIndex) .toArray(Integer[]::new); - // for the right table input refs, the indexes are offset by the amount of rows - // in the left - // input to the join final List rightTableInputRefs = subConditions.stream() .map(sub -> sub.getOperands().stream() .map(RexInputRef.class::cast) @@ -91,36 +85,40 @@ Operator visit(WayangJoin wayangRelNode) { final Integer[] rightTableKeyIndexes = rightTableInputRefs.stream() .map(RexInputRef::getIndex) - .map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount()) // apply offset + .map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount()) .toArray(Integer[]::new); - /* - final List leftFields = Arrays.stream(leftTableKeyIndexes) - .map(key -> wayangRelNode.getLeft().getRowType().getFieldList().get(key)) + final List leftFields = leftTableInputRefs.stream() + .map(ref -> wayangRelNode.getLeft().getRowType().getFieldList().get(ref.getIndex())) .collect(Collectors.toList()); - final List rightFields = Arrays.stream(rightTableKeyIndexes) - .map(key -> wayangRelNode.getRight().getRowType().getFieldList().get(key)) + final List rightFields = rightTableInputRefs.stream() + .map(ref -> wayangRelNode.getRight().getRowType().getFieldList().get(ref.getIndex() - wayangRelNode.getLeft().getRowType().getFieldCount())) .collect(Collectors.toList()); - final String joiningTableName = childOpLeft instanceof WayangTableScan ? childOpLeft.getName() : childOpRight.getName(); - */ - - // if join is joining the LHS of a join condition "JOIN left ON left = right" - // then we pick the first case, otherwise the 2nd "JOIN right ON left = right" - final JoinOperator join = this.getJoinOperator( + final String leftTableName = extractTableName(wayangRelNode.getLeft()); + final String rightTableName = extractTableName(wayangRelNode.getRight()); + + final String leftFieldNames = leftFields.stream() + .map(RelDataTypeField::getName) + .collect(Collectors.joining(",")); + + final String rightFieldNames = rightFields.stream() + .map(RelDataTypeField::getName) + .collect(Collectors.joining(",")); + + final JoinOperator join = getJoinOperator( leftTableKeyIndexes, rightTableKeyIndexes, wayangRelNode, - "", - "", - "", - ""); + leftTableName, + leftFieldNames, + rightTableName, + rightFieldNames); childOpLeft.connectTo(0, join, 0); childOpRight.connectTo(0, join, 1); - // Join returns Tuple2 - map to a Record final SerializableFunction, Record> mp = new JoinFlattenResult(); final MapOperator, Record> mapOperator = new MapOperator, Record>( @@ -133,19 +131,20 @@ Operator visit(WayangJoin wayangRelNode) { return mapOperator; } - /** - * This method handles the {@link JoinOperator} creation - * - * @param wayangRelNode - * @param leftKeyIndex - * @param rightKeyIndex - * @return - */ + private String extractTableName(org.apache.calcite.rel.RelNode relNode) { + if (relNode instanceof WayangTableScan) { + return ((WayangTableScan) relNode).getTableName(); + } + if (relNode.getInputs() != null && !relNode.getInputs().isEmpty()) { + return extractTableName(relNode.getInput(0)); + } + return "UNKNOWN"; + } + protected JoinOperator getJoinOperator(final Integer[] leftKeyIndexes, final Integer[] rightKeyIndexes, final WayangJoin wayangRelNode, final String leftTableName, final String leftFieldNames, final String rightTableName, final String rightFieldNames) { - // TODO: needs withSqlImplementation() for sql support if (wayangRelNode.getInputs().size() != 2) throw new UnsupportedOperationException("Join had an unexpected amount of inputs, found: " @@ -153,13 +152,13 @@ protected JoinOperator getJoinOperator(final Integer[] l final TransformationDescriptor leftProjectionDescriptor = new TransformationDescriptor( new MultiConditionJoinKeyExtractor(leftKeyIndexes), - Record.class, Record.class); - // .withSqlImplementation(""," ") + Record.class, Record.class) + .withSqlImplementation(leftTableName, leftFieldNames); final TransformationDescriptor rightProjectionDescriptor = new TransformationDescriptor( new MultiConditionJoinKeyExtractor(rightKeyIndexes), - Record.class, Record.class); - // .withSqlImplementation(""," ") + Record.class, Record.class) + .withSqlImplementation(rightTableName, rightFieldNames); final JoinOperator join = new JoinOperator<>( leftProjectionDescriptor, diff --git a/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/operators/JdbcJoinOperator.java b/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/operators/JdbcJoinOperator.java index 311b4f8fa..6ef378422 100644 --- a/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/operators/JdbcJoinOperator.java +++ b/wayang-platforms/wayang-jdbc-template/src/main/java/org/apache/wayang/jdbc/operators/JdbcJoinOperator.java @@ -69,13 +69,37 @@ public String createSqlClause(Connection connection, FunctionCompiler compiler) final Tuple left = this.keyDescriptor0.getSqlImplementation(); final Tuple right = this.keyDescriptor1.getSqlImplementation(); final String leftTableName = left.field0; - final String leftKey = left.field1; + final String leftKeys = left.field1; final String rightTableName = right.field0; - final String rightKey = right.field1; + final String rightKeys = right.field1; - return "JOIN " + rightTableName + " ON " + - rightTableName + "." + rightKey - + "=" + leftTableName + "." + leftKey; + if (leftKeys.contains(",") && rightKeys.contains(",")) { + final String[] leftColumns = leftKeys.split(","); + final String[] rightColumns = rightKeys.split(","); + + if (leftColumns.length != rightColumns.length) { + throw new IllegalStateException( + "Mismatch in join key counts: left has " + leftColumns.length + + " keys, right has " + rightColumns.length + " keys"); + } + + final StringBuilder joinCondition = new StringBuilder(); + for (int i = 0; i < leftColumns.length; i++) { + if (i > 0) { + joinCondition.append(" AND "); + } + joinCondition.append(leftTableName).append(".").append(leftColumns[i].trim()) + .append("=") + .append(rightTableName).append(".").append(rightColumns[i].trim()); + } + + return "JOIN " + rightTableName + " ON " + joinCondition.toString(); + } else { + // Backward compatibility + return "JOIN " + rightTableName + " ON " + + rightTableName + "." + rightKeys + + "=" + leftTableName + "." + leftKeys; + } } @Override diff --git a/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java b/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java index df22262ca..9b9cdc2ae 100644 --- a/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java +++ b/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java @@ -123,4 +123,82 @@ void testWithHsqldb() throws SQLException { sqlQueryChannelInstance.getSqlQuery() ); } + + @Test + void testMultiConditionJoinWithHsqldb() throws SQLException { + Configuration configuration = new Configuration(); + + Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + when(job.getCrossPlatformExecutor()).thenReturn(new CrossPlatformExecutor(job, new NoInstrumentationStrategy())); + SqlQueryChannel.Descriptor sqlChannelDescriptor = HsqldbPlatform.getInstance().getSqlQueryChannelDescriptor(); + + HsqldbPlatform hsqldbPlatform = new HsqldbPlatform(); + + ExecutionStage sqlStage = mock(ExecutionStage.class); + + // Create test data with multiple join keys + try (Connection jdbcConnection = hsqldbPlatform.createDatabaseDescriptor(configuration).createJdbcConnection()) { + final Statement statement = jdbcConnection.createStatement(); + statement.execute("CREATE TABLE orders (order_id INT, customer_id INT, product_id INT, quantity INT);"); + statement.execute("INSERT INTO orders VALUES (1, 100, 1001, 5);"); + statement.execute("INSERT INTO orders VALUES (2, 101, 1002, 3);"); + statement.execute("CREATE TABLE shipments (order_id INT, customer_id INT, ship_date VARCHAR(10));"); + statement.execute("INSERT INTO shipments VALUES (1, 100, '2024-01-15');"); + statement.execute("INSERT INTO shipments VALUES (2, 101, '2024-01-16');"); + } + + JdbcTableSource tableSourceOrders = new HsqldbTableSource("orders"); + JdbcTableSource tableSourceShipments = new HsqldbTableSource("shipments"); + + ExecutionTask tableSourceOrdersTask = new ExecutionTask(tableSourceOrders); + tableSourceOrdersTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceOrders.getOutput(0))); + tableSourceOrdersTask.setStage(sqlStage); + + ExecutionTask tableSourceShipmentsTask = new ExecutionTask(tableSourceShipments); + tableSourceShipmentsTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceShipments.getOutput(0))); + tableSourceShipmentsTask.setStage(sqlStage); + + // Create multi-condition join: JOIN ON orders.order_id = shipments.order_id AND orders.customer_id = shipments.customer_id + final ExecutionOperator joinOperator = new HsqldbJoinOperator( + new TransformationDescriptor( + (record) -> new Record(record.getField(0), record.getField(1)), + Record.class, + Record.class + ).withSqlImplementation("orders", "order_id,customer_id"), + new TransformationDescriptor( + (record) -> new Record(record.getField(0), record.getField(1)), + Record.class, + Record.class + ).withSqlImplementation("shipments", "order_id,customer_id") + ); + + ExecutionTask joinTask = new ExecutionTask(joinOperator); + tableSourceOrdersTask.getOutputChannel(0).addConsumer(joinTask, 0); + tableSourceShipmentsTask.getOutputChannel(0).addConsumer(joinTask, 1); + joinTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, joinOperator.getOutput(0))); + joinTask.setStage(sqlStage); + + when(sqlStage.getStartTasks()).thenReturn(Collections.singleton(tableSourceOrdersTask)); + when(sqlStage.getTerminalTasks()).thenReturn(Collections.singleton(joinTask)); + + ExecutionStage nextStage = mock(ExecutionStage.class); + + SqlToStreamOperator sqlToStreamOperator = new SqlToStreamOperator(HsqldbPlatform.getInstance()); + ExecutionTask sqlToStreamTask = new ExecutionTask(sqlToStreamOperator); + joinTask.getOutputChannel(0).addConsumer(sqlToStreamTask, 0); + sqlToStreamTask.setStage(nextStage); + + JdbcExecutor executor = new JdbcExecutor(HsqldbPlatform.getInstance(), job); + executor.execute(sqlStage, new DefaultOptimizationContext(job), job.getCrossPlatformExecutor()); + + SqlQueryChannel.Instance sqlQueryChannelInstance = + (SqlQueryChannel.Instance) job.getCrossPlatformExecutor().getChannelInstance(sqlToStreamTask.getInputChannel(0)); + + // Verify that multi-condition join generates proper SQL with AND + assertEquals( + "SELECT * FROM orders JOIN shipments ON orders.order_id=shipments.order_id AND orders.customer_id=shipments.customer_id;", + sqlQueryChannelInstance.getSqlQuery() + ); + } } From 3176183d8e2d656561a9384781b5ce18e24dc10e Mon Sep 17 00:00:00 2001 From: Makarand Milind Hinge Date: Wed, 25 Feb 2026 16:28:30 +0530 Subject: [PATCH 2/3] test: Add execution validation to multi-condition join test - Execute generated SQL against HSQLDB to verify actual join behavior - Validate row count (2 matching rows expected) - Verify data integrity of joined results - Test edge case: order with mismatched customer_id is correctly excluded - Confirm multi-condition AND clause works end-to-end --- .../jdbc/operators/JdbcJoinOperatorTest.java | 145 +++++++++++------- 1 file changed, 90 insertions(+), 55 deletions(-) diff --git a/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java b/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java index 9b9cdc2ae..ab984a751 100644 --- a/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java +++ b/wayang-platforms/wayang-jdbc-template/src/test/java/org/apache/wayang/jdbc/operators/JdbcJoinOperatorTest.java @@ -137,68 +137,103 @@ void testMultiConditionJoinWithHsqldb() throws SQLException { ExecutionStage sqlStage = mock(ExecutionStage.class); - // Create test data with multiple join keys - try (Connection jdbcConnection = hsqldbPlatform.createDatabaseDescriptor(configuration).createJdbcConnection()) { + Connection jdbcConnection = hsqldbPlatform.createDatabaseDescriptor(configuration).createJdbcConnection(); + try { final Statement statement = jdbcConnection.createStatement(); statement.execute("CREATE TABLE orders (order_id INT, customer_id INT, product_id INT, quantity INT);"); statement.execute("INSERT INTO orders VALUES (1, 100, 1001, 5);"); statement.execute("INSERT INTO orders VALUES (2, 101, 1002, 3);"); + statement.execute("INSERT INTO orders VALUES (3, 100, 1003, 2);"); statement.execute("CREATE TABLE shipments (order_id INT, customer_id INT, ship_date VARCHAR(10));"); statement.execute("INSERT INTO shipments VALUES (1, 100, '2024-01-15');"); statement.execute("INSERT INTO shipments VALUES (2, 101, '2024-01-16');"); + statement.execute("INSERT INTO shipments VALUES (3, 999, '2024-01-17');"); + + JdbcTableSource tableSourceOrders = new HsqldbTableSource("orders"); + JdbcTableSource tableSourceShipments = new HsqldbTableSource("shipments"); + + ExecutionTask tableSourceOrdersTask = new ExecutionTask(tableSourceOrders); + tableSourceOrdersTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceOrders.getOutput(0))); + tableSourceOrdersTask.setStage(sqlStage); + + ExecutionTask tableSourceShipmentsTask = new ExecutionTask(tableSourceShipments); + tableSourceShipmentsTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceShipments.getOutput(0))); + tableSourceShipmentsTask.setStage(sqlStage); + + final ExecutionOperator joinOperator = new HsqldbJoinOperator( + new TransformationDescriptor( + (record) -> new Record(record.getField(0), record.getField(1)), + Record.class, + Record.class + ).withSqlImplementation("orders", "order_id,customer_id"), + new TransformationDescriptor( + (record) -> new Record(record.getField(0), record.getField(1)), + Record.class, + Record.class + ).withSqlImplementation("shipments", "order_id,customer_id") + ); + + ExecutionTask joinTask = new ExecutionTask(joinOperator); + tableSourceOrdersTask.getOutputChannel(0).addConsumer(joinTask, 0); + tableSourceShipmentsTask.getOutputChannel(0).addConsumer(joinTask, 1); + joinTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, joinOperator.getOutput(0))); + joinTask.setStage(sqlStage); + + when(sqlStage.getStartTasks()).thenReturn(Collections.singleton(tableSourceOrdersTask)); + when(sqlStage.getTerminalTasks()).thenReturn(Collections.singleton(joinTask)); + + ExecutionStage nextStage = mock(ExecutionStage.class); + + SqlToStreamOperator sqlToStreamOperator = new SqlToStreamOperator(HsqldbPlatform.getInstance()); + ExecutionTask sqlToStreamTask = new ExecutionTask(sqlToStreamOperator); + joinTask.getOutputChannel(0).addConsumer(sqlToStreamTask, 0); + sqlToStreamTask.setStage(nextStage); + + JdbcExecutor executor = new JdbcExecutor(HsqldbPlatform.getInstance(), job); + executor.execute(sqlStage, new DefaultOptimizationContext(job), job.getCrossPlatformExecutor()); + + SqlQueryChannel.Instance sqlQueryChannelInstance = + (SqlQueryChannel.Instance) job.getCrossPlatformExecutor().getChannelInstance(sqlToStreamTask.getInputChannel(0)); + + String generatedSql = sqlQueryChannelInstance.getSqlQuery(); + assertEquals( + "SELECT * FROM orders JOIN shipments ON orders.order_id=shipments.order_id AND orders.customer_id=shipments.customer_id;", + generatedSql + ); + + java.sql.ResultSet resultSet = statement.executeQuery(generatedSql); + + int rowCount = 0; + boolean foundOrder1 = false; + boolean foundOrder2 = false; + + while (resultSet.next()) { + rowCount++; + int orderId = resultSet.getInt("order_id"); + int customerId = resultSet.getInt("customer_id"); + String shipDate = resultSet.getString("ship_date"); + + if (orderId == 1 && customerId == 100) { + foundOrder1 = true; + assertEquals("2024-01-15", shipDate); + assertEquals(1001, resultSet.getInt("product_id")); + assertEquals(5, resultSet.getInt("quantity")); + } else if (orderId == 2 && customerId == 101) { + foundOrder2 = true; + assertEquals("2024-01-16", shipDate); + assertEquals(1002, resultSet.getInt("product_id")); + assertEquals(3, resultSet.getInt("quantity")); + } + } + + assertEquals(2, rowCount, "Should return exactly 2 rows (order_id=3 with customer_id=100 should not match shipment with customer_id=999)"); + assertEquals(true, foundOrder1, "Should find order 1 with customer 100"); + assertEquals(true, foundOrder2, "Should find order 2 with customer 101"); + + resultSet.close(); + } finally { + jdbcConnection.close(); } - - JdbcTableSource tableSourceOrders = new HsqldbTableSource("orders"); - JdbcTableSource tableSourceShipments = new HsqldbTableSource("shipments"); - - ExecutionTask tableSourceOrdersTask = new ExecutionTask(tableSourceOrders); - tableSourceOrdersTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceOrders.getOutput(0))); - tableSourceOrdersTask.setStage(sqlStage); - - ExecutionTask tableSourceShipmentsTask = new ExecutionTask(tableSourceShipments); - tableSourceShipmentsTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceShipments.getOutput(0))); - tableSourceShipmentsTask.setStage(sqlStage); - - // Create multi-condition join: JOIN ON orders.order_id = shipments.order_id AND orders.customer_id = shipments.customer_id - final ExecutionOperator joinOperator = new HsqldbJoinOperator( - new TransformationDescriptor( - (record) -> new Record(record.getField(0), record.getField(1)), - Record.class, - Record.class - ).withSqlImplementation("orders", "order_id,customer_id"), - new TransformationDescriptor( - (record) -> new Record(record.getField(0), record.getField(1)), - Record.class, - Record.class - ).withSqlImplementation("shipments", "order_id,customer_id") - ); - - ExecutionTask joinTask = new ExecutionTask(joinOperator); - tableSourceOrdersTask.getOutputChannel(0).addConsumer(joinTask, 0); - tableSourceShipmentsTask.getOutputChannel(0).addConsumer(joinTask, 1); - joinTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, joinOperator.getOutput(0))); - joinTask.setStage(sqlStage); - - when(sqlStage.getStartTasks()).thenReturn(Collections.singleton(tableSourceOrdersTask)); - when(sqlStage.getTerminalTasks()).thenReturn(Collections.singleton(joinTask)); - - ExecutionStage nextStage = mock(ExecutionStage.class); - - SqlToStreamOperator sqlToStreamOperator = new SqlToStreamOperator(HsqldbPlatform.getInstance()); - ExecutionTask sqlToStreamTask = new ExecutionTask(sqlToStreamOperator); - joinTask.getOutputChannel(0).addConsumer(sqlToStreamTask, 0); - sqlToStreamTask.setStage(nextStage); - - JdbcExecutor executor = new JdbcExecutor(HsqldbPlatform.getInstance(), job); - executor.execute(sqlStage, new DefaultOptimizationContext(job), job.getCrossPlatformExecutor()); - - SqlQueryChannel.Instance sqlQueryChannelInstance = - (SqlQueryChannel.Instance) job.getCrossPlatformExecutor().getChannelInstance(sqlToStreamTask.getInputChannel(0)); - - // Verify that multi-condition join generates proper SQL with AND - assertEquals( - "SELECT * FROM orders JOIN shipments ON orders.order_id=shipments.order_id AND orders.customer_id=shipments.customer_id;", - sqlQueryChannelInstance.getSqlQuery() - ); } + } From 44c923e4f41405fc38f588cd3e4c7e903e58160f Mon Sep 17 00:00:00 2001 From: Makarand Milind Hinge Date: Thu, 26 Feb 2026 21:37:05 +0530 Subject: [PATCH 3/3] test: Add SQL API test for multi-condition join validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add sqlApiMultiConditionJoinGeneratesJdbcSql test - Validates SQL parsing through WayangMultiConditionJoinVisitor - Verifies withSqlImplementation() sets correct table names and field names - Confirms comma-separated field format for multi-condition joins - Proves end-to-end flow: SQL → Wayang plan → JDBC operators --- .../wayang/api/sql/SqlToWayangRelTest.java | 155 +++++++++++++++--- 1 file changed, 131 insertions(+), 24 deletions(-) diff --git a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java index 88114774b..9bde24477 100755 --- a/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java +++ b/wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java @@ -84,6 +84,7 @@ import org.apache.wayang.core.plan.wayangplan.Operator; import org.apache.wayang.core.plan.wayangplan.PlanTraversal; import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.core.util.Tuple; import org.apache.wayang.java.Java; import org.apache.wayang.jdbc.execution.JdbcExecutor; import org.apache.wayang.jdbc.operators.JdbcProjectionOperator; @@ -197,7 +198,7 @@ public RelDataType getRowType(final RelDataTypeFactory typeFactory) { final RelOptTable t1 = relOptSchema.getTableForMember(Arrays.asList("T1")); - final TableScan scan1 = LogicalTableScan.create(cluster, t1, List.of()); + final TableScan scan1 = LogicalTableScan.create(cluster, t1, Arrays.asList()); final SqlDialect dialect = SqlDialect.DatabaseProduct.CALCITE.getDialect(); final RelToSqlConverter converter = new RelToSqlConverter(dialect); @@ -226,17 +227,17 @@ public RelDataType getRowType(final RelDataTypeFactory typeFactory) { final PlanTransformation projectionTransformation = projectionMapping.getTransformations().iterator().next() .thatReplaces(); - plan.applyTransformations(List.of(projectionTransformation)); + plan.applyTransformations(Arrays.asList(projectionTransformation)); final Collection operators = PlanTraversal.upstream().traverse(plan.getSinks()).getTraversedNodes(); final JdbcTableSource table = operators.stream().filter(op -> op instanceof JdbcTableSource) - .map(JdbcTableSource.class::cast).findFirst().orElseThrow(); + .map(JdbcTableSource.class::cast).findFirst().orElseThrow(() -> new RuntimeException("Table not found")); final JdbcProjectionOperator projection = operators.stream().filter(op -> op instanceof JdbcProjectionOperator) - .map(JdbcProjectionOperator.class::cast).findFirst().orElseThrow(); + .map(JdbcProjectionOperator.class::cast).findFirst().orElseThrow(() -> new RuntimeException("Projection not found")); final JdbcExecutor jdbcExecutor = mock(); - final StringBuilder query = JdbcExecutor.createSqlString(jdbcExecutor, table, List.of(), projection, List.of()); + final StringBuilder query = JdbcExecutor.createSqlString(jdbcExecutor, table, Arrays.asList(), projection, Arrays.asList()); assertEquals("SELECT ID, NAME FROM T1;", query.toString()); } @@ -293,7 +294,7 @@ void aggregateCountInJavaWithIntegers() throws Exception { sqlContext.execute(wayangPlan); - final Record rec = result.stream().findFirst().orElseThrow(); + final Record rec = result.stream().findFirst().orElseThrow(() -> new RuntimeException("No record found")); assertEquals(2, rec.size()); assertEquals(3, rec.getInt(1)); } @@ -312,7 +313,7 @@ void aggregateCountInJava() throws Exception { sqlContext.execute(wayangPlan); - final Record rec = result.stream().findFirst().orElseThrow(); + final Record rec = result.stream().findFirst().orElseThrow(() -> new RuntimeException("No record found")); assertEquals(2, rec.size()); assertEquals(3, rec.getInt(1)); } @@ -343,7 +344,7 @@ void javaAverage() throws Exception { sqlContext.execute(wayangPlan); assertEquals(1, result.size()); - assertEquals(0.875f, result.stream().findFirst().orElseThrow().getDouble(0)); + assertEquals(0.875f, result.stream().findFirst().orElseThrow(() -> new RuntimeException("No record found")).getDouble(0)); } @Test @@ -407,7 +408,7 @@ void javaCrossJoin() throws Exception { sqlContext.execute(wayangPlan); - final List shouldBe = List.of(new Record("item1", "item2", "item1", "item2", "item3"), + final List shouldBe = Arrays.asList(new Record("item1", "item2", "item1", "item2", "item3"), new Record("item1", "item2", "item1", "item2", "item3"), new Record("item1", "item2", "item1", "item2", "item3"), new Record("item1", "item2", "item1", "item2", "item3"), new Record("item1", "item2", "x", "x", "x"), @@ -536,7 +537,7 @@ void joinWithLargeLeftTableIndexCorrect() throws Exception { final WayangPlan wayangPlan = t.field1; sqlContext.execute(wayangPlan); - final List shouldBe = List.of(new Record("test1", "test1", "test2", "test1", "test1", "test2"), + final List shouldBe = Arrays.asList(new Record("test1", "test1", "test2", "test1", "test1", "test2"), new Record("test2", "", "test2", "", "test2", "test2"), new Record("", "test2", "test2", "test2", "", "test2")); @@ -548,12 +549,6 @@ void joinWithLargeLeftTableIndexCorrect() throws Exception { assertEquals(resultTally, shouldBeTally); } - // Imagine case: l = {item1, item2}, r = {item3,item4}, j = {item1, item2, - // item3, item4} join on =($1,$3) would be =(item2, item4) in the join set - // however from the r set we need to factor in the - // offset, $3 -> 3 - l.size() = $1, r($1) = "item4" we cannot naively assume - // that it is always ordered as =(lRef,rRef), lRef < rRef. - // it may also be =($3,$1) @Test void joinWithLargeLeftTableIndexMirrorAlias() throws Exception { final SqlContext sqlContext = createSqlContext("/data/largeLeftTableIndex.csv"); @@ -566,7 +561,7 @@ void joinWithLargeLeftTableIndexMirrorAlias() throws Exception { final WayangPlan wayangPlan = t.field1; sqlContext.execute(wayangPlan); - final List shouldBe = List.of(new Record("test1", "test1", "test2", "test1", "test1", "test2"), + final List shouldBe = Arrays.asList(new Record("test1", "test1", "test2", "test1", "test1", "test2"), new Record("test2", "", "test2", "", "test2", "test2"), new Record("", "test2", "test2", "test2", "", "test2")); @@ -611,12 +606,11 @@ void sparkAggregate() throws Exception { sqlContext.execute(wayangPlan); - final Record rec = result.stream().findFirst().orElseThrow(); + final Record rec = result.stream().findFirst().orElseThrow(() -> new RuntimeException("No record found")); assertEquals(2, rec.size()); assertEquals(3, rec.getInt(1)); } - // tests sql-apis ability to serialize projections and joins @Test void sparkInnerJoin() throws Exception { final SqlContext sqlContext = createSqlContext("/data/largeLeftTableIndex.csv"); @@ -633,7 +627,7 @@ void sparkInnerJoin() throws Exception { sqlContext.execute(wayangPlan); - final List shouldBe = List.of(new Record("test1", "test1", "test2", "test1", "test1", "test2"), + final List shouldBe = Arrays.asList(new Record("test1", "test1", "test2", "test1", "test1", "test2"), new Record("test2", "", "test2", "", "test2", "test2"), new Record("", "test2", "test2", "test2", "", "test2")); @@ -660,12 +654,12 @@ void serializeProjection() throws Exception { final SqlOperator add = SqlStdOperatorTable.PLUS; final SqlOperator multiply = SqlStdOperatorTable.MULTIPLY; - final RexNode addition = rb.makeCall(add, List.of(inputRefX, inputRefB)); - final RexNode multiplication = rb.makeCall(multiply, List.of(addition, inputRefY)); + final RexNode addition = rb.makeCall(add, Arrays.asList(inputRefX, inputRefB)); + final RexNode multiplication = rb.makeCall(multiply, Arrays.asList(addition, inputRefY)); final RexCall projection = (RexCall) multiplication; - final ProjectMapFuncImpl impl = new ProjectMapFuncImpl(List.of(projection)); + final ProjectMapFuncImpl impl = new ProjectMapFuncImpl(Arrays.asList(projection)); final ByteArrayOutputStream byteOutStream = new ByteArrayOutputStream(); final ObjectOutputStream outStream = new ObjectOutputStream(byteOutStream); @@ -731,7 +725,7 @@ void exampleMinWithStrings() throws Exception { final WayangPlan wayangPlan = t.field1; sqlContext.execute(wayangPlan); - assertEquals("AA", result.stream().findAny().orElseThrow().getString(0)); + assertEquals("AA", result.stream().findAny().orElseThrow(() -> new RuntimeException("No record found")).getString(0)); } @Test @@ -786,6 +780,119 @@ void exampleCustomDelimiter() throws Exception { assertEquals(result.stream().findFirst().get().getInt(0), 3); } + @Test + void sqlApiMultiConditionJoinGeneratesJdbcSql() throws Exception { + final JavaTypeFactoryImpl typeFactory = new JavaTypeFactoryImpl(); + final RexBuilder rexBuilder = new RexBuilder(typeFactory); + + final VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.empty()); + planner.addRelTraitDef(ConventionTraitDef.INSTANCE); + final RelOptCluster cluster = RelOptCluster.create(planner, rexBuilder); + + final SchemaPlus rootSchema = Frameworks.createRootSchema(true); + final RelDataType ordersRowType = new Builder(typeFactory) + .add("order_id", typeFactory.createJavaType(Integer.class)) + .add("customer_id", typeFactory.createJavaType(Integer.class)) + .add("product_id", typeFactory.createJavaType(Integer.class)) + .build(); + + final RelDataType shipmentsRowType = new Builder(typeFactory) + .add("order_id", typeFactory.createJavaType(Integer.class)) + .add("customer_id", typeFactory.createJavaType(Integer.class)) + .add("ship_date", typeFactory.createJavaType(String.class)) + .build(); + + rootSchema.add("orders", new AbstractTable() { + @Override + public RelDataType getRowType(final RelDataTypeFactory typeFactory) { + return ordersRowType; + } + }); + + rootSchema.add("shipments", new AbstractTable() { + @Override + public RelDataType getRowType(final RelDataTypeFactory typeFactory) { + return shipmentsRowType; + } + }); + + final Properties configProperties = Optimizer.ConfigProperties.getDefaults(); + final RelDataTypeFactory relDataTypeFactory = new JavaTypeFactoryImpl(); + + final Optimizer optimizer = Optimizer.create(CalciteSchema.from(rootSchema), configProperties, + relDataTypeFactory); + + final SqlNode sqlNode = optimizer.parseSql( + "SELECT * FROM orders JOIN shipments ON orders.order_id = shipments.order_id AND orders.customer_id = shipments.customer_id"); + final SqlNode validatedSqlNode = optimizer.validate(sqlNode); + final RelNode relNode = optimizer.convert(validatedSqlNode); + + final RuleSet rules = RuleSets.ofList(CoreRules.FILTER_INTO_JOIN, WayangRules.WAYANG_TABLESCAN_RULE, + WayangRules.WAYANG_TABLESCAN_ENUMERABLE_RULE, WayangRules.WAYANG_PROJECT_RULE, + WayangRules.WAYANG_FILTER_RULE, WayangRules.WAYANG_JOIN_RULE, WayangRules.WAYANG_AGGREGATE_RULE, + WayangRules.WAYANG_SORT_RULE); + + final RelNode wayangRel = optimizer.optimize(relNode, relNode.getTraitSet().plus(WayangConvention.INSTANCE), + rules); + + final WayangPlan plan = Optimizer.convert(wayangRel, new ArrayList()); + + final ProjectionMapping projectionMapping = new ProjectionMapping(); + final PlanTransformation projectionTransformation = projectionMapping.getTransformations().iterator().next() + .thatReplaces(); + + plan.applyTransformations(Arrays.asList(projectionTransformation)); + + final Collection operators = PlanTraversal.upstream().traverse(plan.getSinks()).getTraversedNodes(); + + final JdbcTableSource ordersTable = operators.stream() + .filter(op -> op instanceof JdbcTableSource) + .map(JdbcTableSource.class::cast) + .filter(table -> table.getTableName().equals("orders")) + .findFirst().orElseThrow(() -> new RuntimeException("Orders table not found")); + + final JdbcTableSource shipmentsTable = operators.stream() + .filter(op -> op instanceof JdbcTableSource) + .map(JdbcTableSource.class::cast) + .filter(table -> table.getTableName().equals("shipments")) + .findFirst().orElseThrow(() -> new RuntimeException("Shipments table not found")); + + assertNotNull(ordersTable, "orders table should be present"); + assertNotNull(shipmentsTable, "shipments table should be present"); + + final org.apache.wayang.basic.operators.JoinOperator joinOp = operators.stream() + .filter(op -> op instanceof org.apache.wayang.basic.operators.JoinOperator) + .map(op -> (org.apache.wayang.basic.operators.JoinOperator) op) + .findFirst().orElseThrow(() -> new RuntimeException("Join operator not found")); + + assertNotNull(joinOp, "Join operator should be present"); + + // Verify the join operator has SQL implementations with correct field names + // This validates that WayangMultiConditionJoinVisitor called withSqlImplementation() + final Tuple leftSqlImpl = joinOp.getKeyDescriptor0().getSqlImplementation(); + final Tuple rightSqlImpl = joinOp.getKeyDescriptor1().getSqlImplementation(); + + assertNotNull(leftSqlImpl, "Left join key should have SQL implementation"); + assertNotNull(rightSqlImpl, "Right join key should have SQL implementation"); + + // Verify table names + assertEquals("orders", leftSqlImpl.field0, "Left table should be 'orders'"); + assertEquals("shipments", rightSqlImpl.field0, "Right table should be 'shipments'"); + + // Verify field names are comma-separated (multi-condition) + final String leftFields = leftSqlImpl.field1; + final String rightFields = rightSqlImpl.field1; + + assertTrue(leftFields.contains("order_id") && leftFields.contains("customer_id"), + "Left SQL implementation should contain both order_id and customer_id, got: " + leftFields); + assertTrue(rightFields.contains("order_id") && rightFields.contains("customer_id"), + "Right SQL implementation should contain both order_id and customer_id, got: " + rightFields); + + // Verify comma-separated format + assertTrue(leftFields.contains(","), "Left fields should be comma-separated for multi-condition join"); + assertTrue(rightFields.contains(","), "Right fields should be comma-separated for multi-condition join"); + } + private SqlContext createSqlContext(final String tableResourceName) throws IOException, ParseException, SQLException { final String calciteModel = "{\r\n" + " \"calcite\": {\r\n" + " \"version\": \"1.0\",\r\n"