-
Notifications
You must be signed in to change notification settings - Fork 4.5k
[BEAM-9379] Simplify BeamCalcRel inputs #13930
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,6 @@ | |
| import java.util.AbstractList; | ||
| import java.util.AbstractMap; | ||
| import java.util.Arrays; | ||
| import java.util.Collection; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.Set; | ||
|
|
@@ -46,7 +45,6 @@ | |
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType; | ||
| import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeWithLocalTzType; | ||
| import org.apache.beam.sdk.schemas.Schema; | ||
| import org.apache.beam.sdk.schemas.logicaltypes.DateTime; | ||
| import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; | ||
| import org.apache.beam.sdk.transforms.DoFn; | ||
| import org.apache.beam.sdk.transforms.PTransform; | ||
|
|
@@ -95,6 +93,7 @@ | |
| import org.checkerframework.checker.nullness.qual.Nullable; | ||
| import org.codehaus.commons.compiler.CompileException; | ||
| import org.codehaus.janino.ScriptEvaluator; | ||
| import org.joda.time.DateTime; | ||
| import org.joda.time.Instant; | ||
| import org.joda.time.ReadableInstant; | ||
|
|
||
|
|
@@ -375,7 +374,7 @@ private static Expression castOutputTime(Expression value, FieldType toType) { | |
| // Convert TIME to LocalTime | ||
| if (value.getType() == java.sql.Time.class) { | ||
| valueDateTime = Expressions.call(BuiltInMethod.TIME_TO_INT.method, valueDateTime); | ||
| } else if (value.getType() == Long.class) { | ||
| } else if (value.getType() == Integer.class || value.getType() == Long.class) { | ||
| valueDateTime = Expressions.unbox(valueDateTime); | ||
| } | ||
| valueDateTime = | ||
|
|
@@ -386,7 +385,7 @@ private static Expression castOutputTime(Expression value, FieldType toType) { | |
| // Convert DATE to LocalDate | ||
| if (value.getType() == java.sql.Date.class) { | ||
| valueDateTime = Expressions.call(BuiltInMethod.DATE_TO_INT.method, valueDateTime); | ||
| } else if (value.getType() == Long.class) { | ||
| } else if (value.getType() == Integer.class || value.getType() == Long.class) { | ||
| valueDateTime = Expressions.unbox(valueDateTime); | ||
| } | ||
| valueDateTime = Expressions.call(LocalDate.class, "ofEpochDay", valueDateTime); | ||
|
|
@@ -419,33 +418,6 @@ private static Expression castOutputTime(Expression value, FieldType toType) { | |
| } | ||
|
|
||
| private static class InputGetterImpl implements RexToLixTranslator.InputGetter { | ||
| private static final Map<TypeName, Class> TYPE_CONVERSION_MAP = | ||
| ImmutableMap.<TypeName, Class>builder() | ||
| .put(TypeName.BYTE, Byte.class) | ||
| .put(TypeName.BYTES, byte[].class) | ||
| .put(TypeName.INT16, Short.class) | ||
| .put(TypeName.INT32, Integer.class) | ||
| .put(TypeName.INT64, Long.class) | ||
| .put(TypeName.DECIMAL, BigDecimal.class) | ||
| .put(TypeName.FLOAT, Float.class) | ||
| .put(TypeName.DOUBLE, Double.class) | ||
| .put(TypeName.STRING, String.class) | ||
| .put(TypeName.DATETIME, ReadableInstant.class) | ||
| .put(TypeName.BOOLEAN, Boolean.class) | ||
| .put(TypeName.MAP, Map.class) | ||
| .put(TypeName.ARRAY, Collection.class) | ||
| .put(TypeName.ITERABLE, Iterable.class) | ||
| .put(TypeName.ROW, Row.class) | ||
| .build(); | ||
|
|
||
| private static final Map<String, Class> LOGICAL_TYPE_TO_BASE_TYPE_MAP = | ||
| ImmutableMap.<String, Class>builder() | ||
| .put(SqlTypes.DATE.getIdentifier(), Long.class) | ||
| .put(SqlTypes.TIME.getIdentifier(), Long.class) | ||
| .put(TimeWithLocalTzType.IDENTIFIER, ReadableInstant.class) | ||
| .put(SqlTypes.DATETIME.getIdentifier(), Row.class) | ||
| .put(CharType.IDENTIFIER, String.class) | ||
| .build(); | ||
|
|
||
| private final Expression input; | ||
| private final Schema inputSchema; | ||
|
|
@@ -457,84 +429,194 @@ private InputGetterImpl(Expression input, Schema inputSchema) { | |
|
|
||
| @Override | ||
| public Expression field(BlockBuilder list, int index, Type storageType) { | ||
| return value(list, index, storageType, input, inputSchema); | ||
| return getBeamField(list, index, storageType, input, inputSchema); | ||
| } | ||
|
|
||
| private static Expression value( | ||
| // Read field from Beam Row | ||
| private static Expression getBeamField( | ||
| BlockBuilder list, int index, Type storageType, Expression input, Schema schema) { | ||
| if (index >= schema.getFieldCount() || index < 0) { | ||
| throw new IllegalArgumentException("Unable to find value #" + index); | ||
| } | ||
|
|
||
| final Expression expression = list.append(list.newName("current"), input); | ||
|
|
||
| FieldType fromType = schema.getField(index).getType(); | ||
| Class convertTo = null; | ||
| if (storageType == Object.class) { | ||
| convertTo = Object.class; | ||
| } else if (fromType.getTypeName().isLogicalType()) { | ||
| convertTo = LOGICAL_TYPE_TO_BASE_TYPE_MAP.get(fromType.getLogicalType().getIdentifier()); | ||
| } else { | ||
| convertTo = TYPE_CONVERSION_MAP.get(fromType.getTypeName()); | ||
| } | ||
| if (convertTo == null) { | ||
| throw new UnsupportedOperationException("Unable to get " + fromType.getTypeName()); | ||
| FieldType fieldType = schema.getField(index).getType(); | ||
| Expression value; | ||
| switch (fieldType.getTypeName()) { | ||
| case BYTE: | ||
| value = Expressions.call(expression, "getByte", Expressions.constant(index)); | ||
apilloud marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| break; | ||
| case INT16: | ||
| value = Expressions.call(expression, "getInt16", Expressions.constant(index)); | ||
| break; | ||
| case INT32: | ||
| value = Expressions.call(expression, "getInt32", Expressions.constant(index)); | ||
| break; | ||
| case INT64: | ||
| value = Expressions.call(expression, "getInt64", Expressions.constant(index)); | ||
| break; | ||
| case DECIMAL: | ||
| value = Expressions.call(expression, "getDecimal", Expressions.constant(index)); | ||
| break; | ||
| case FLOAT: | ||
| value = Expressions.call(expression, "getFloat", Expressions.constant(index)); | ||
| break; | ||
| case DOUBLE: | ||
| value = Expressions.call(expression, "getDouble", Expressions.constant(index)); | ||
| break; | ||
| case STRING: | ||
| value = Expressions.call(expression, "getString", Expressions.constant(index)); | ||
| break; | ||
| case DATETIME: | ||
| value = Expressions.call(expression, "getDateTime", Expressions.constant(index)); | ||
| break; | ||
| case BOOLEAN: | ||
| value = Expressions.call(expression, "getBoolean", Expressions.constant(index)); | ||
| break; | ||
| case BYTES: | ||
| value = Expressions.call(expression, "getBytes", Expressions.constant(index)); | ||
| break; | ||
| case ARRAY: | ||
| value = Expressions.call(expression, "getArray", Expressions.constant(index)); | ||
| if (storageType == Object.class | ||
| && TypeName.ROW.equals(fieldType.getCollectionElementType().getTypeName())) { | ||
| // Workaround for missing row output support | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the one special case that remains, it requires rewriting the output code to support rows. That is the next PR. |
||
| return Expressions.convert_(value, Object.class); | ||
| } | ||
| break; | ||
| case MAP: | ||
| value = Expressions.call(expression, "getMap", Expressions.constant(index)); | ||
| break; | ||
| case ROW: | ||
| value = Expressions.call(expression, "getRow", Expressions.constant(index)); | ||
| break; | ||
| case LOGICAL_TYPE: | ||
| String identifier = fieldType.getLogicalType().getIdentifier(); | ||
| if (CharType.IDENTIFIER.equals(identifier)) { | ||
| value = Expressions.call(expression, "getString", Expressions.constant(index)); | ||
| } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) { | ||
| value = Expressions.call(expression, "getDateTime", Expressions.constant(index)); | ||
| } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) { | ||
| value = | ||
| Expressions.convert_( | ||
| Expressions.call( | ||
| expression, | ||
| "getLogicalTypeValue", | ||
| Expressions.constant(index), | ||
| Expressions.constant(LocalDate.class)), | ||
| LocalDate.class); | ||
| } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) { | ||
| value = | ||
| Expressions.convert_( | ||
| Expressions.call( | ||
| expression, | ||
| "getLogicalTypeValue", | ||
| Expressions.constant(index), | ||
| Expressions.constant(LocalTime.class)), | ||
| LocalTime.class); | ||
| } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) { | ||
| value = | ||
| Expressions.convert_( | ||
| Expressions.call( | ||
| expression, | ||
| "getLogicalTypeValue", | ||
| Expressions.constant(index), | ||
| Expressions.constant(LocalDateTime.class)), | ||
| LocalDateTime.class); | ||
| } else { | ||
| throw new UnsupportedOperationException("Unable to get logical type " + identifier); | ||
| } | ||
| break; | ||
| default: | ||
| throw new UnsupportedOperationException("Unable to get " + fieldType.getTypeName()); | ||
| } | ||
|
|
||
| Expression value = | ||
| Expressions.convert_( | ||
| Expressions.call( | ||
| expression, | ||
| "getBaseValue", | ||
| Expressions.constant(index), | ||
| Expressions.constant(convertTo)), | ||
| convertTo); | ||
| return (storageType != Object.class) ? value(value, fromType) : value; | ||
| return toCalciteValue(value, fieldType); | ||
| } | ||
|
|
||
| private static Expression value(Expression value, Schema.FieldType type) { | ||
| if (type.getTypeName().isLogicalType()) { | ||
| String logicalId = type.getLogicalType().getIdentifier(); | ||
| if (SqlTypes.TIME.getIdentifier().equals(logicalId)) { | ||
| // Value conversion: Beam => Calcite | ||
| private static Expression toCalciteValue(Expression value, FieldType fieldType) { | ||
| switch (fieldType.getTypeName()) { | ||
| case BYTE: | ||
| return Expressions.convert_(value, Byte.class); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are these conversions necessary? e.g. doesn't
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These may be coming out of a Map or List as an Object. This isn't strictly necessary, I'm doing it to ensure types are what we expect. It generates a cast which results in a ClassCastException if it isn't the type we expected. |
||
| case INT16: | ||
| return Expressions.convert_(value, Short.class); | ||
| case INT32: | ||
| return Expressions.convert_(value, Integer.class); | ||
| case INT64: | ||
| return Expressions.convert_(value, Long.class); | ||
| case DECIMAL: | ||
| return Expressions.convert_(value, BigDecimal.class); | ||
| case FLOAT: | ||
| return Expressions.convert_(value, Float.class); | ||
| case DOUBLE: | ||
| return Expressions.convert_(value, Double.class); | ||
| case STRING: | ||
| return Expressions.convert_(value, String.class); | ||
| case BOOLEAN: | ||
| return Expressions.convert_(value, Boolean.class); | ||
| case DATETIME: | ||
| return nullOr( | ||
| value, Expressions.divide(value, Expressions.constant(NANOS_PER_MILLISECOND))); | ||
| } else if (SqlTypes.DATE.getIdentifier().equals(logicalId)) { | ||
| return value; | ||
| } else if (SqlTypes.DATETIME.getIdentifier().equals(logicalId)) { | ||
| Expression dateValue = | ||
| Expressions.call(value, "getInt64", Expressions.constant(DateTime.DATE_FIELD_NAME)); | ||
| Expression timeValue = | ||
| Expressions.call(value, "getInt64", Expressions.constant(DateTime.TIME_FIELD_NAME)); | ||
| Expression returnValue = | ||
| Expressions.add( | ||
| Expressions.multiply(dateValue, Expressions.constant(MILLIS_PER_DAY)), | ||
| Expressions.divide(timeValue, Expressions.constant(NANOS_PER_MILLISECOND))); | ||
| return nullOr(value, returnValue); | ||
| } else if (!CharType.IDENTIFIER.equals(logicalId)) { | ||
| throw new UnsupportedOperationException( | ||
| "Unknown LogicalType " + type.getLogicalType().getIdentifier()); | ||
| } | ||
| } else if (type.getTypeName().isMapType()) { | ||
| return nullOr(value, map(value, type.getMapValueType())); | ||
| } else if (CalciteUtils.isDateTimeType(type)) { | ||
| return nullOr(value, Expressions.call(value, "getMillis")); | ||
| } else if (type.getTypeName().isCompositeType()) { | ||
| return nullOr(value, row(value, type.getRowSchema())); | ||
| } else if (type.getTypeName().isCollectionType()) { | ||
| return nullOr(value, list(value, type.getCollectionElementType())); | ||
| } else if (type.getTypeName() == TypeName.BYTES) { | ||
| return nullOr( | ||
| value, Expressions.new_(ByteString.class, Types.castIfNecessary(byte[].class, value))); | ||
| value, Expressions.call(Expressions.convert_(value, DateTime.class), "getMillis")); | ||
| case BYTES: | ||
| return nullOr( | ||
| value, Expressions.new_(ByteString.class, Expressions.convert_(value, byte[].class))); | ||
| case ARRAY: | ||
| return nullOr(value, toCalciteList(value, fieldType.getCollectionElementType())); | ||
| case MAP: | ||
| return nullOr(value, toCalciteMap(value, fieldType.getMapValueType())); | ||
| case ROW: | ||
| return nullOr(value, toCalciteRow(value, fieldType.getRowSchema())); | ||
| case LOGICAL_TYPE: | ||
| String identifier = fieldType.getLogicalType().getIdentifier(); | ||
| if (CharType.IDENTIFIER.equals(identifier)) { | ||
| return Expressions.convert_(value, String.class); | ||
| } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) { | ||
| return nullOr( | ||
| value, Expressions.call(Expressions.convert_(value, DateTime.class), "getMillis")); | ||
| } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) { | ||
| return nullOr( | ||
| value, | ||
| Expressions.call( | ||
| Expressions.box( | ||
| Expressions.call( | ||
| Expressions.convert_(value, LocalDate.class), "toEpochDay")), | ||
| "intValue")); | ||
| } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) { | ||
| return nullOr( | ||
| value, | ||
| Expressions.call( | ||
| Expressions.box( | ||
| Expressions.divide( | ||
| Expressions.call( | ||
| Expressions.convert_(value, LocalTime.class), "toNanoOfDay"), | ||
| Expressions.constant(NANOS_PER_MILLISECOND))), | ||
| "intValue")); | ||
| } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) { | ||
| value = Expressions.convert_(value, LocalDateTime.class); | ||
| Expression dateValue = | ||
| Expressions.call(Expressions.call(value, "toLocalDate"), "toEpochDay"); | ||
| Expression timeValue = | ||
| Expressions.call(Expressions.call(value, "toLocalTime"), "toNanoOfDay"); | ||
| Expression returnValue = | ||
| Expressions.add( | ||
| Expressions.multiply(dateValue, Expressions.constant(MILLIS_PER_DAY)), | ||
| Expressions.divide(timeValue, Expressions.constant(NANOS_PER_MILLISECOND))); | ||
| return nullOr(value, returnValue); | ||
| } else { | ||
| throw new UnsupportedOperationException("Unable to convert logical type " + identifier); | ||
| } | ||
| default: | ||
| throw new UnsupportedOperationException("Unable to convert " + fieldType.getTypeName()); | ||
| } | ||
|
|
||
| return value; | ||
| } | ||
|
|
||
| private static Expression list(Expression input, FieldType elementType) { | ||
| private static Expression toCalciteList(Expression input, FieldType elementType) { | ||
| ParameterExpression value = Expressions.parameter(Object.class); | ||
|
|
||
| BlockBuilder block = new BlockBuilder(); | ||
| block.add(value(value, elementType)); | ||
| block.add(toCalciteValue(value, elementType)); | ||
|
|
||
| return Expressions.new_( | ||
| WrappedList.class, | ||
|
|
@@ -548,11 +630,11 @@ private static Expression list(Expression input, FieldType elementType) { | |
| block.toBlock()))); | ||
| } | ||
|
|
||
| private static Expression map(Expression input, FieldType mapValueType) { | ||
| private static Expression toCalciteMap(Expression input, FieldType mapValueType) { | ||
| ParameterExpression value = Expressions.parameter(Object.class); | ||
|
|
||
| BlockBuilder block = new BlockBuilder(); | ||
| block.add(value(value, mapValueType)); | ||
| block.add(toCalciteValue(value, mapValueType)); | ||
|
|
||
| return Expressions.new_( | ||
| WrappedMap.class, | ||
|
|
@@ -566,14 +648,14 @@ private static Expression map(Expression input, FieldType mapValueType) { | |
| block.toBlock()))); | ||
| } | ||
|
|
||
| private static Expression row(Expression input, Schema schema) { | ||
| private static Expression toCalciteRow(Expression input, Schema schema) { | ||
| ParameterExpression row = Expressions.parameter(Row.class); | ||
| ParameterExpression index = Expressions.parameter(int.class); | ||
| BlockBuilder body = new BlockBuilder(/* optimizing= */ false); | ||
|
|
||
| for (int i = 0; i < schema.getFieldCount(); i++) { | ||
| BlockBuilder list = new BlockBuilder(/* optimizing= */ false, body); | ||
| Expression returnValue = value(list, i, /* storageType= */ null, row, schema); | ||
| Expression returnValue = getBeamField(list, i, /* storageType= */ null, row, schema); | ||
|
|
||
| list.append(returnValue); | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this need to include Integer now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The internal type is actually suppose to be int. We've been passing Calcite a Long instead and the compiler just happily upconverts when doing math between the two types. We are still doing math somewhere (window functions?) that turns this into a Long sometimes. I think even before this change we were at risk of receiving a Integer here. (In the next CL I'm switching this to the java Number interface so we are permissive on outputs.)
https://github.com/apache/calcite-avatica/blob/89e0deb510311b85b8c8bacde6d2ff70c309930e/core/src/main/java/org/apache/calcite/avatica/SqlType.java#L306