Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;

import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult;
import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinKeyExtractor;
import org.apache.wayang.api.sql.calcite.rel.WayangJoin;
import org.apache.wayang.api.sql.calcite.rel.WayangTableScan;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.basic.operators.JoinOperator;
Expand All @@ -48,26 +50,21 @@ public class WayangMultiConditionJoinVisitor extends WayangRelNodeVisitor<Wayang
*
* @param wayangRelConverter
*/
WayangMultiConditionJoinVisitor(final WayangRelConverter wayangRelConverter) {
public WayangMultiConditionJoinVisitor(final WayangRelConverter wayangRelConverter) {
super(wayangRelConverter);
}

@Override
Operator visit(WayangJoin wayangRelNode) {
public Operator visit(WayangJoin wayangRelNode) {
final Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0));
final Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1));
final RexNode condition = ((Join) wayangRelNode).getCondition();
final RexCall call = (RexCall) condition;

//
final List<RexCall> subConditions = call.operands.stream()
.map(RexCall.class::cast)
.collect(Collectors.toList());

// calcite generates the RexInputRef indexes via looking at the union
// field list of the left and right input of a join.
// since the left input is always the first in this joined field list
// we can eagerly get the fields in the left input
final List<RexInputRef> leftTableInputRefs = subConditions.stream()
.map(sub -> sub.getOperands().stream()
.map(RexInputRef.class::cast)
Expand All @@ -79,9 +76,6 @@ Operator visit(WayangJoin wayangRelNode) {
.map(RexInputRef::getIndex)
.toArray(Integer[]::new);

// for the right table input refs, the indexes are offset by the amount of rows
// in the left
// input to the join
final List<RexInputRef> rightTableInputRefs = subConditions.stream()
.map(sub -> sub.getOperands().stream()
.map(RexInputRef.class::cast)
Expand All @@ -91,36 +85,40 @@ Operator visit(WayangJoin wayangRelNode) {

final Integer[] rightTableKeyIndexes = rightTableInputRefs.stream()
.map(RexInputRef::getIndex)
.map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount()) // apply offset
.map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount())
.toArray(Integer[]::new);

/*
final List<RelDataTypeField> leftFields = Arrays.stream(leftTableKeyIndexes)
.map(key -> wayangRelNode.getLeft().getRowType().getFieldList().get(key))
final List<RelDataTypeField> leftFields = leftTableInputRefs.stream()
.map(ref -> wayangRelNode.getLeft().getRowType().getFieldList().get(ref.getIndex()))
.collect(Collectors.toList());

final List<RelDataTypeField> rightFields = Arrays.stream(rightTableKeyIndexes)
.map(key -> wayangRelNode.getRight().getRowType().getFieldList().get(key))
final List<RelDataTypeField> rightFields = rightTableInputRefs.stream()
.map(ref -> wayangRelNode.getRight().getRowType().getFieldList().get(ref.getIndex() - wayangRelNode.getLeft().getRowType().getFieldCount()))
.collect(Collectors.toList());

final String joiningTableName = childOpLeft instanceof WayangTableScan ? childOpLeft.getName() : childOpRight.getName();
*/

// if join is joining the LHS of a join condition "JOIN left ON left = right"
// then we pick the first case, otherwise the 2nd "JOIN right ON left = right"
final JoinOperator<Record, Record, Record> join = this.getJoinOperator(
final String leftTableName = extractTableName(wayangRelNode.getLeft());
final String rightTableName = extractTableName(wayangRelNode.getRight());

final String leftFieldNames = leftFields.stream()
.map(RelDataTypeField::getName)
.collect(Collectors.joining(","));

final String rightFieldNames = rightFields.stream()
.map(RelDataTypeField::getName)
.collect(Collectors.joining(","));

final JoinOperator<Record, Record, Record> join = getJoinOperator(
leftTableKeyIndexes,
rightTableKeyIndexes,
wayangRelNode,
"",
"",
"",
"");
leftTableName,
leftFieldNames,
rightTableName,
rightFieldNames);

childOpLeft.connectTo(0, join, 0);
childOpRight.connectTo(0, join, 1);

// Join returns Tuple2 - map to a Record
final SerializableFunction<Tuple2<Record, Record>, Record> mp = new JoinFlattenResult();

final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
Expand All @@ -133,33 +131,34 @@ Operator visit(WayangJoin wayangRelNode) {
return mapOperator;
}

/**
* This method handles the {@link JoinOperator} creation
*
* @param wayangRelNode
* @param leftKeyIndex
* @param rightKeyIndex
* @return
*/
private String extractTableName(org.apache.calcite.rel.RelNode relNode) {
if (relNode instanceof WayangTableScan) {
return ((WayangTableScan) relNode).getTableName();
}
if (relNode.getInputs() != null && !relNode.getInputs().isEmpty()) {
return extractTableName(relNode.getInput(0));
}
return "UNKNOWN";
}

protected JoinOperator<Record, Record, Record> getJoinOperator(final Integer[] leftKeyIndexes,
final Integer[] rightKeyIndexes,
final WayangJoin wayangRelNode, final String leftTableName, final String leftFieldNames,
final String rightTableName, final String rightFieldNames) {
// TODO: needs withSqlImplementation() for sql support

if (wayangRelNode.getInputs().size() != 2)
throw new UnsupportedOperationException("Join had an unexpected amount of inputs, found: "
+ wayangRelNode.getInputs().size() + ", expected: 2");

final TransformationDescriptor<Record, Record> leftProjectionDescriptor = new TransformationDescriptor<Record, Record>(
new MultiConditionJoinKeyExtractor(leftKeyIndexes),
Record.class, Record.class);
// .withSqlImplementation(""," ")
Record.class, Record.class)
.withSqlImplementation(leftTableName, leftFieldNames);

final TransformationDescriptor<Record, Record> rightProjectionDescriptor = new TransformationDescriptor<Record, Record>(
new MultiConditionJoinKeyExtractor(rightKeyIndexes),
Record.class, Record.class);
// .withSqlImplementation(""," ")
Record.class, Record.class)
.withSqlImplementation(rightTableName, rightFieldNames);

final JoinOperator<Record, Record, Record> join = new JoinOperator<>(
leftProjectionDescriptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,37 @@ public String createSqlClause(Connection connection, FunctionCompiler compiler)
final Tuple<String, String> left = this.keyDescriptor0.getSqlImplementation();
final Tuple<String, String> right = this.keyDescriptor1.getSqlImplementation();
final String leftTableName = left.field0;
final String leftKey = left.field1;
final String leftKeys = left.field1;
final String rightTableName = right.field0;
final String rightKey = right.field1;
final String rightKeys = right.field1;

return "JOIN " + rightTableName + " ON " +
rightTableName + "." + rightKey
+ "=" + leftTableName + "." + leftKey;
if (leftKeys.contains(",") && rightKeys.contains(",")) {
final String[] leftColumns = leftKeys.split(",");
final String[] rightColumns = rightKeys.split(",");

if (leftColumns.length != rightColumns.length) {
throw new IllegalStateException(
"Mismatch in join key counts: left has " + leftColumns.length +
" keys, right has " + rightColumns.length + " keys");
}

final StringBuilder joinCondition = new StringBuilder();
for (int i = 0; i < leftColumns.length; i++) {
if (i > 0) {
joinCondition.append(" AND ");
}
joinCondition.append(leftTableName).append(".").append(leftColumns[i].trim())
.append("=")
.append(rightTableName).append(".").append(rightColumns[i].trim());
}

return "JOIN " + rightTableName + " ON " + joinCondition.toString();
} else {
// Backward compatibility
return "JOIN " + rightTableName + " ON " +
rightTableName + "." + rightKeys
+ "=" + leftTableName + "." + leftKeys;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,117 @@ void testWithHsqldb() throws SQLException {
sqlQueryChannelInstance.getSqlQuery()
);
}

@Test
void testMultiConditionJoinWithHsqldb() throws SQLException {
Configuration configuration = new Configuration();

Job job = mock(Job.class);
when(job.getConfiguration()).thenReturn(configuration);
when(job.getCrossPlatformExecutor()).thenReturn(new CrossPlatformExecutor(job, new NoInstrumentationStrategy()));
SqlQueryChannel.Descriptor sqlChannelDescriptor = HsqldbPlatform.getInstance().getSqlQueryChannelDescriptor();

HsqldbPlatform hsqldbPlatform = new HsqldbPlatform();

ExecutionStage sqlStage = mock(ExecutionStage.class);

Connection jdbcConnection = hsqldbPlatform.createDatabaseDescriptor(configuration).createJdbcConnection();
try {
final Statement statement = jdbcConnection.createStatement();
statement.execute("CREATE TABLE orders (order_id INT, customer_id INT, product_id INT, quantity INT);");
statement.execute("INSERT INTO orders VALUES (1, 100, 1001, 5);");
statement.execute("INSERT INTO orders VALUES (2, 101, 1002, 3);");
statement.execute("INSERT INTO orders VALUES (3, 100, 1003, 2);");
statement.execute("CREATE TABLE shipments (order_id INT, customer_id INT, ship_date VARCHAR(10));");
statement.execute("INSERT INTO shipments VALUES (1, 100, '2024-01-15');");
statement.execute("INSERT INTO shipments VALUES (2, 101, '2024-01-16');");
statement.execute("INSERT INTO shipments VALUES (3, 999, '2024-01-17');");

JdbcTableSource tableSourceOrders = new HsqldbTableSource("orders");
JdbcTableSource tableSourceShipments = new HsqldbTableSource("shipments");

ExecutionTask tableSourceOrdersTask = new ExecutionTask(tableSourceOrders);
tableSourceOrdersTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceOrders.getOutput(0)));
tableSourceOrdersTask.setStage(sqlStage);

ExecutionTask tableSourceShipmentsTask = new ExecutionTask(tableSourceShipments);
tableSourceShipmentsTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, tableSourceShipments.getOutput(0)));
tableSourceShipmentsTask.setStage(sqlStage);

final ExecutionOperator joinOperator = new HsqldbJoinOperator<Record>(
new TransformationDescriptor<Record, Record>(
(record) -> new Record(record.getField(0), record.getField(1)),
Record.class,
Record.class
).withSqlImplementation("orders", "order_id,customer_id"),
new TransformationDescriptor<Record, Record>(
(record) -> new Record(record.getField(0), record.getField(1)),
Record.class,
Record.class
).withSqlImplementation("shipments", "order_id,customer_id")
);

ExecutionTask joinTask = new ExecutionTask(joinOperator);
tableSourceOrdersTask.getOutputChannel(0).addConsumer(joinTask, 0);
tableSourceShipmentsTask.getOutputChannel(0).addConsumer(joinTask, 1);
joinTask.setOutputChannel(0, new SqlQueryChannel(sqlChannelDescriptor, joinOperator.getOutput(0)));
joinTask.setStage(sqlStage);

when(sqlStage.getStartTasks()).thenReturn(Collections.singleton(tableSourceOrdersTask));
when(sqlStage.getTerminalTasks()).thenReturn(Collections.singleton(joinTask));

ExecutionStage nextStage = mock(ExecutionStage.class);

SqlToStreamOperator sqlToStreamOperator = new SqlToStreamOperator(HsqldbPlatform.getInstance());
ExecutionTask sqlToStreamTask = new ExecutionTask(sqlToStreamOperator);
joinTask.getOutputChannel(0).addConsumer(sqlToStreamTask, 0);
sqlToStreamTask.setStage(nextStage);

JdbcExecutor executor = new JdbcExecutor(HsqldbPlatform.getInstance(), job);
executor.execute(sqlStage, new DefaultOptimizationContext(job), job.getCrossPlatformExecutor());

SqlQueryChannel.Instance sqlQueryChannelInstance =
(SqlQueryChannel.Instance) job.getCrossPlatformExecutor().getChannelInstance(sqlToStreamTask.getInputChannel(0));

String generatedSql = sqlQueryChannelInstance.getSqlQuery();
assertEquals(
"SELECT * FROM orders JOIN shipments ON orders.order_id=shipments.order_id AND orders.customer_id=shipments.customer_id;",
generatedSql
);

java.sql.ResultSet resultSet = statement.executeQuery(generatedSql);

int rowCount = 0;
boolean foundOrder1 = false;
boolean foundOrder2 = false;

while (resultSet.next()) {
rowCount++;
int orderId = resultSet.getInt("order_id");
int customerId = resultSet.getInt("customer_id");
String shipDate = resultSet.getString("ship_date");

if (orderId == 1 && customerId == 100) {
foundOrder1 = true;
assertEquals("2024-01-15", shipDate);
assertEquals(1001, resultSet.getInt("product_id"));
assertEquals(5, resultSet.getInt("quantity"));
} else if (orderId == 2 && customerId == 101) {
foundOrder2 = true;
assertEquals("2024-01-16", shipDate);
assertEquals(1002, resultSet.getInt("product_id"));
assertEquals(3, resultSet.getInt("quantity"));
}
}

assertEquals(2, rowCount, "Should return exactly 2 rows (order_id=3 with customer_id=100 should not match shipment with customer_id=999)");
assertEquals(true, foundOrder1, "Should find order 1 with customer 100");
assertEquals(true, foundOrder2, "Should find order 2 with customer 101");

resultSet.close();
} finally {
jdbcConnection.close();
}
}

}