diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java index 64a651d25156..984d3a70970d 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java @@ -33,7 +33,6 @@ import java.util.AbstractList; import java.util.AbstractMap; import java.util.Arrays; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; @@ -46,7 +45,6 @@ import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeWithLocalTzType; import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.schemas.logicaltypes.DateTime; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; @@ -95,6 +93,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.codehaus.commons.compiler.CompileException; import org.codehaus.janino.ScriptEvaluator; +import org.joda.time.DateTime; import org.joda.time.Instant; import org.joda.time.ReadableInstant; @@ -375,7 +374,7 @@ private static Expression castOutputTime(Expression value, FieldType toType) { // Convert TIME to LocalTime if (value.getType() == java.sql.Time.class) { valueDateTime = Expressions.call(BuiltInMethod.TIME_TO_INT.method, valueDateTime); - } else if (value.getType() == Long.class) { + } else if (value.getType() == Integer.class || value.getType() == Long.class) { valueDateTime = Expressions.unbox(valueDateTime); } valueDateTime = @@ -386,7 +385,7 @@ private static Expression castOutputTime(Expression value, FieldType toType) { // Convert DATE to LocalDate if (value.getType() == java.sql.Date.class) { valueDateTime = Expressions.call(BuiltInMethod.DATE_TO_INT.method, valueDateTime); - } else if (value.getType() == Long.class) { + } else if (value.getType() == Integer.class || value.getType() == Long.class) { valueDateTime = Expressions.unbox(valueDateTime); } valueDateTime = Expressions.call(LocalDate.class, "ofEpochDay", valueDateTime); @@ -419,33 +418,6 @@ private static Expression castOutputTime(Expression value, FieldType toType) { } private static class InputGetterImpl implements RexToLixTranslator.InputGetter { - private static final Map TYPE_CONVERSION_MAP = - ImmutableMap.builder() - .put(TypeName.BYTE, Byte.class) - .put(TypeName.BYTES, byte[].class) - .put(TypeName.INT16, Short.class) - .put(TypeName.INT32, Integer.class) - .put(TypeName.INT64, Long.class) - .put(TypeName.DECIMAL, BigDecimal.class) - .put(TypeName.FLOAT, Float.class) - .put(TypeName.DOUBLE, Double.class) - .put(TypeName.STRING, String.class) - .put(TypeName.DATETIME, ReadableInstant.class) - .put(TypeName.BOOLEAN, Boolean.class) - .put(TypeName.MAP, Map.class) - .put(TypeName.ARRAY, Collection.class) - .put(TypeName.ITERABLE, Iterable.class) - .put(TypeName.ROW, Row.class) - .build(); - - private static final Map LOGICAL_TYPE_TO_BASE_TYPE_MAP = - ImmutableMap.builder() - .put(SqlTypes.DATE.getIdentifier(), Long.class) - .put(SqlTypes.TIME.getIdentifier(), Long.class) - .put(TimeWithLocalTzType.IDENTIFIER, ReadableInstant.class) - .put(SqlTypes.DATETIME.getIdentifier(), Row.class) - .put(CharType.IDENTIFIER, String.class) - .build(); private final Expression input; private final Schema inputSchema; @@ -457,10 +429,11 @@ private InputGetterImpl(Expression input, Schema inputSchema) { @Override public Expression field(BlockBuilder list, int index, Type storageType) { - return value(list, index, storageType, input, inputSchema); + return getBeamField(list, index, storageType, input, inputSchema); } - private static Expression value( + // Read field from Beam Row + private static Expression getBeamField( BlockBuilder list, int index, Type storageType, Expression input, Schema schema) { if (index >= schema.getFieldCount() || index < 0) { throw new IllegalArgumentException("Unable to find value #" + index); @@ -468,73 +441,182 @@ private static Expression value( final Expression expression = list.append(list.newName("current"), input); - FieldType fromType = schema.getField(index).getType(); - Class convertTo = null; - if (storageType == Object.class) { - convertTo = Object.class; - } else if (fromType.getTypeName().isLogicalType()) { - convertTo = LOGICAL_TYPE_TO_BASE_TYPE_MAP.get(fromType.getLogicalType().getIdentifier()); - } else { - convertTo = TYPE_CONVERSION_MAP.get(fromType.getTypeName()); - } - if (convertTo == null) { - throw new UnsupportedOperationException("Unable to get " + fromType.getTypeName()); + FieldType fieldType = schema.getField(index).getType(); + Expression value; + switch (fieldType.getTypeName()) { + case BYTE: + value = Expressions.call(expression, "getByte", Expressions.constant(index)); + break; + case INT16: + value = Expressions.call(expression, "getInt16", Expressions.constant(index)); + break; + case INT32: + value = Expressions.call(expression, "getInt32", Expressions.constant(index)); + break; + case INT64: + value = Expressions.call(expression, "getInt64", Expressions.constant(index)); + break; + case DECIMAL: + value = Expressions.call(expression, "getDecimal", Expressions.constant(index)); + break; + case FLOAT: + value = Expressions.call(expression, "getFloat", Expressions.constant(index)); + break; + case DOUBLE: + value = Expressions.call(expression, "getDouble", Expressions.constant(index)); + break; + case STRING: + value = Expressions.call(expression, "getString", Expressions.constant(index)); + break; + case DATETIME: + value = Expressions.call(expression, "getDateTime", Expressions.constant(index)); + break; + case BOOLEAN: + value = Expressions.call(expression, "getBoolean", Expressions.constant(index)); + break; + case BYTES: + value = Expressions.call(expression, "getBytes", Expressions.constant(index)); + break; + case ARRAY: + value = Expressions.call(expression, "getArray", Expressions.constant(index)); + if (storageType == Object.class + && TypeName.ROW.equals(fieldType.getCollectionElementType().getTypeName())) { + // Workaround for missing row output support + return Expressions.convert_(value, Object.class); + } + break; + case MAP: + value = Expressions.call(expression, "getMap", Expressions.constant(index)); + break; + case ROW: + value = Expressions.call(expression, "getRow", Expressions.constant(index)); + break; + case LOGICAL_TYPE: + String identifier = fieldType.getLogicalType().getIdentifier(); + if (CharType.IDENTIFIER.equals(identifier)) { + value = Expressions.call(expression, "getString", Expressions.constant(index)); + } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) { + value = Expressions.call(expression, "getDateTime", Expressions.constant(index)); + } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) { + value = + Expressions.convert_( + Expressions.call( + expression, + "getLogicalTypeValue", + Expressions.constant(index), + Expressions.constant(LocalDate.class)), + LocalDate.class); + } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) { + value = + Expressions.convert_( + Expressions.call( + expression, + "getLogicalTypeValue", + Expressions.constant(index), + Expressions.constant(LocalTime.class)), + LocalTime.class); + } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) { + value = + Expressions.convert_( + Expressions.call( + expression, + "getLogicalTypeValue", + Expressions.constant(index), + Expressions.constant(LocalDateTime.class)), + LocalDateTime.class); + } else { + throw new UnsupportedOperationException("Unable to get logical type " + identifier); + } + break; + default: + throw new UnsupportedOperationException("Unable to get " + fieldType.getTypeName()); } - Expression value = - Expressions.convert_( - Expressions.call( - expression, - "getBaseValue", - Expressions.constant(index), - Expressions.constant(convertTo)), - convertTo); - return (storageType != Object.class) ? value(value, fromType) : value; + return toCalciteValue(value, fieldType); } - private static Expression value(Expression value, Schema.FieldType type) { - if (type.getTypeName().isLogicalType()) { - String logicalId = type.getLogicalType().getIdentifier(); - if (SqlTypes.TIME.getIdentifier().equals(logicalId)) { + // Value conversion: Beam => Calcite + private static Expression toCalciteValue(Expression value, FieldType fieldType) { + switch (fieldType.getTypeName()) { + case BYTE: + return Expressions.convert_(value, Byte.class); + case INT16: + return Expressions.convert_(value, Short.class); + case INT32: + return Expressions.convert_(value, Integer.class); + case INT64: + return Expressions.convert_(value, Long.class); + case DECIMAL: + return Expressions.convert_(value, BigDecimal.class); + case FLOAT: + return Expressions.convert_(value, Float.class); + case DOUBLE: + return Expressions.convert_(value, Double.class); + case STRING: + return Expressions.convert_(value, String.class); + case BOOLEAN: + return Expressions.convert_(value, Boolean.class); + case DATETIME: return nullOr( - value, Expressions.divide(value, Expressions.constant(NANOS_PER_MILLISECOND))); - } else if (SqlTypes.DATE.getIdentifier().equals(logicalId)) { - return value; - } else if (SqlTypes.DATETIME.getIdentifier().equals(logicalId)) { - Expression dateValue = - Expressions.call(value, "getInt64", Expressions.constant(DateTime.DATE_FIELD_NAME)); - Expression timeValue = - Expressions.call(value, "getInt64", Expressions.constant(DateTime.TIME_FIELD_NAME)); - Expression returnValue = - Expressions.add( - Expressions.multiply(dateValue, Expressions.constant(MILLIS_PER_DAY)), - Expressions.divide(timeValue, Expressions.constant(NANOS_PER_MILLISECOND))); - return nullOr(value, returnValue); - } else if (!CharType.IDENTIFIER.equals(logicalId)) { - throw new UnsupportedOperationException( - "Unknown LogicalType " + type.getLogicalType().getIdentifier()); - } - } else if (type.getTypeName().isMapType()) { - return nullOr(value, map(value, type.getMapValueType())); - } else if (CalciteUtils.isDateTimeType(type)) { - return nullOr(value, Expressions.call(value, "getMillis")); - } else if (type.getTypeName().isCompositeType()) { - return nullOr(value, row(value, type.getRowSchema())); - } else if (type.getTypeName().isCollectionType()) { - return nullOr(value, list(value, type.getCollectionElementType())); - } else if (type.getTypeName() == TypeName.BYTES) { - return nullOr( - value, Expressions.new_(ByteString.class, Types.castIfNecessary(byte[].class, value))); + value, Expressions.call(Expressions.convert_(value, DateTime.class), "getMillis")); + case BYTES: + return nullOr( + value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class))); + case ARRAY: + return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType())); + case MAP: + return nullOr(value, toCalciteMap(value, fieldType.getMapValueType())); + case ROW: + return nullOr(value, toCalciteRow(value, fieldType.getRowSchema())); + case LOGICAL_TYPE: + String identifier = fieldType.getLogicalType().getIdentifier(); + if (CharType.IDENTIFIER.equals(identifier)) { + return Expressions.convert_(value, String.class); + } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) { + return nullOr( + value, Expressions.call(Expressions.convert_(value, DateTime.class), "getMillis")); + } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) { + return nullOr( + value, + Expressions.call( + Expressions.box( + Expressions.call( + Expressions.convert_(value, LocalDate.class), "toEpochDay")), + "intValue")); + } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) { + return nullOr( + value, + Expressions.call( + Expressions.box( + Expressions.divide( + Expressions.call( + Expressions.convert_(value, LocalTime.class), "toNanoOfDay"), + Expressions.constant(NANOS_PER_MILLISECOND))), + "intValue")); + } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) { + value = Expressions.convert_(value, LocalDateTime.class); + Expression dateValue = + Expressions.call(Expressions.call(value, "toLocalDate"), "toEpochDay"); + Expression timeValue = + Expressions.call(Expressions.call(value, "toLocalTime"), "toNanoOfDay"); + Expression returnValue = + Expressions.add( + Expressions.multiply(dateValue, Expressions.constant(MILLIS_PER_DAY)), + Expressions.divide(timeValue, Expressions.constant(NANOS_PER_MILLISECOND))); + return nullOr(value, returnValue); + } else { + throw new UnsupportedOperationException("Unable to convert logical type " + identifier); + } + default: + throw new UnsupportedOperationException("Unable to convert " + fieldType.getTypeName()); } - - return value; } - private static Expression list(Expression input, FieldType elementType) { + private static Expression toCalciteList(Expression input, FieldType elementType) { ParameterExpression value = Expressions.parameter(Object.class); BlockBuilder block = new BlockBuilder(); - block.add(value(value, elementType)); + block.add(toCalciteValue(value, elementType)); return Expressions.new_( WrappedList.class, @@ -548,11 +630,11 @@ private static Expression list(Expression input, FieldType elementType) { block.toBlock()))); } - private static Expression map(Expression input, FieldType mapValueType) { + private static Expression toCalciteMap(Expression input, FieldType mapValueType) { ParameterExpression value = Expressions.parameter(Object.class); BlockBuilder block = new BlockBuilder(); - block.add(value(value, mapValueType)); + block.add(toCalciteValue(value, mapValueType)); return Expressions.new_( WrappedMap.class, @@ -566,14 +648,14 @@ private static Expression map(Expression input, FieldType mapValueType) { block.toBlock()))); } - private static Expression row(Expression input, Schema schema) { + private static Expression toCalciteRow(Expression input, Schema schema) { ParameterExpression row = Expressions.parameter(Row.class); ParameterExpression index = Expressions.parameter(int.class); BlockBuilder body = new BlockBuilder(/* optimizing= */ false); for (int i = 0; i < schema.getFieldCount(); i++) { BlockBuilder list = new BlockBuilder(/* optimizing= */ false, body); - Expression returnValue = value(list, i, /* storageType= */ null, row, schema); + Expression returnValue = getBeamField(list, i, /* storageType= */ null, row, schema); list.append(returnValue); diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java index 0519798d65fd..85490e5229d8 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamComplexTypeTest.java @@ -386,7 +386,10 @@ public void testDatetimeFields() { .setRowSchema(dateTimeFieldSchema) .apply( SqlTransform.query( - "select EXTRACT(YEAR from dateTimeField) as yyyy, " + "select " + + " dateTimeField, " + + " nullableDateTimeField, " + + " EXTRACT(YEAR from dateTimeField) as yyyy, " + " EXTRACT(YEAR from nullableDateTimeField) as year_with_null, " + " EXTRACT(MONTH from dateTimeField) as mm, " + " EXTRACT(MONTH from nullableDateTimeField) as month_with_null " @@ -394,6 +397,8 @@ public void testDatetimeFields() { Schema outputRowSchema = Schema.builder() + .addField("dateTimeField", FieldType.DATETIME) + .addNullableField("nullableDateTimeField", FieldType.DATETIME) .addField("yyyy", FieldType.INT64) .addNullableField("year_with_null", FieldType.INT64) .addField("mm", FieldType.INT64) @@ -402,7 +407,9 @@ public void testDatetimeFields() { PAssert.that(outputRow) .containsInAnyOrder( - Row.withSchema(outputRowSchema).addValues(2019L, null, 06L, null).build()); + Row.withSchema(outputRowSchema) + .addValues(current, null, 2019L, null, 06L, null) + .build()); pipeline.run().waitUntilFinish(Duration.standardMinutes(2)); } @@ -424,7 +431,10 @@ public void testSqlLogicalTypeDateFields() { .setRowSchema(dateTimeFieldSchema) .apply( SqlTransform.query( - "select EXTRACT(DAY from dateTypeField) as dd, " + "select " + + " dateTypeField, " + + " nullableDateTypeField, " + + " EXTRACT(DAY from dateTypeField) as dd, " + " EXTRACT(DAY from nullableDateTypeField) as day_with_null, " + " dateTypeField + interval '1' day as date_with_day_added, " + " nullableDateTypeField + interval '1' day as day_added_with_null " @@ -432,6 +442,8 @@ public void testSqlLogicalTypeDateFields() { Schema outputRowSchema = Schema.builder() + .addField("dateTypeField", FieldType.logicalType(SqlTypes.DATE)) + .addNullableField("nullableDateTypeField", FieldType.logicalType(SqlTypes.DATE)) .addField("dd", FieldType.INT64) .addNullableField("day_with_null", FieldType.INT64) .addField("date_with_day_added", FieldType.logicalType(SqlTypes.DATE)) @@ -441,7 +453,8 @@ public void testSqlLogicalTypeDateFields() { PAssert.that(outputRow) .containsInAnyOrder( Row.withSchema(outputRowSchema) - .addValues(27L, null, LocalDate.of(2019, 6, 28), null) + .addValues( + LocalDate.of(2019, 6, 27), null, 27L, null, LocalDate.of(2019, 6, 28), null) .build()); pipeline.run().waitUntilFinish(Duration.standardMinutes(2)); @@ -464,7 +477,10 @@ public void testSqlLogicalTypeTimeFields() { .setRowSchema(dateTimeFieldSchema) .apply( SqlTransform.query( - "select timeTypeField + interval '1' hour as time_with_hour_added, " + "select " + + " timeTypeField, " + + " nullableTimeTypeField, " + + " timeTypeField + interval '1' hour as time_with_hour_added, " + " nullableTimeTypeField + interval '1' hour as hour_added_with_null, " + " timeTypeField - INTERVAL '60' SECOND as time_with_seconds_added, " + " nullableTimeTypeField - INTERVAL '60' SECOND as seconds_added_with_null " @@ -472,6 +488,8 @@ public void testSqlLogicalTypeTimeFields() { Schema outputRowSchema = Schema.builder() + .addField("timeTypeField", FieldType.logicalType(SqlTypes.TIME)) + .addNullableField("nullableTimeTypeField", FieldType.logicalType(SqlTypes.TIME)) .addField("time_with_hour_added", FieldType.logicalType(SqlTypes.TIME)) .addNullableField("hour_added_with_null", FieldType.logicalType(SqlTypes.TIME)) .addField("time_with_seconds_added", FieldType.logicalType(SqlTypes.TIME)) @@ -481,7 +499,13 @@ public void testSqlLogicalTypeTimeFields() { PAssert.that(outputRow) .containsInAnyOrder( Row.withSchema(outputRowSchema) - .addValues(LocalTime.of(2, 0, 0), null, LocalTime.of(0, 59, 0), null) + .addValues( + LocalTime.of(1, 0, 0), + null, + LocalTime.of(2, 0, 0), + null, + LocalTime.of(0, 59, 0), + null) .build()); pipeline.run().waitUntilFinish(Duration.standardMinutes(2)); @@ -506,7 +530,10 @@ public void testSqlLogicalTypeDatetimeFields() { .setRowSchema(dateTimeFieldSchema) .apply( SqlTransform.query( - "select EXTRACT(YEAR from dateTimeField) as yyyy, " + "select " + + " dateTimeField, " + + " nullableDateTimeField, " + + " EXTRACT(YEAR from dateTimeField) as yyyy, " + " EXTRACT(YEAR from nullableDateTimeField) as year_with_null, " + " EXTRACT(MONTH from dateTimeField) as mm, " + " EXTRACT(MONTH from nullableDateTimeField) as month_with_null, " @@ -522,6 +549,8 @@ public void testSqlLogicalTypeDatetimeFields() { Schema outputRowSchema = Schema.builder() + .addField("dateTimeField", FieldType.logicalType(SqlTypes.DATETIME)) + .addNullableField("nullableDateTimeField", FieldType.logicalType(SqlTypes.DATETIME)) .addField("yyyy", FieldType.INT64) .addNullableField("year_with_null", FieldType.INT64) .addField("mm", FieldType.INT64) @@ -540,6 +569,8 @@ public void testSqlLogicalTypeDatetimeFields() { .containsInAnyOrder( Row.withSchema(outputRowSchema) .addValues( + LocalDateTime.of(2008, 12, 25, 15, 30, 0), + null, 2008L, null, 12L,